Updated to use custom channel for RPC connection instead of session

This commit is contained in:
Justin 2015-04-05 06:41:07 -04:00
parent 732e1bd69c
commit 24010b4a95
2 changed files with 62 additions and 46 deletions

View File

@ -1,37 +1,17 @@
package sshrpc package sshrpc
import ( import (
"fmt"
"net/rpc" "net/rpc"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type sshrpcSession struct {
*ssh.Session
}
func (s sshrpcSession) Read(p []byte) (n int, err error) {
pipe, err := s.StdoutPipe()
if err != nil {
return 0, err
}
return pipe.Read(p)
}
func (s sshrpcSession) Write(p []byte) (n int, err error) {
pipe, err := s.StdinPipe()
if err != nil {
return 0, err
}
return pipe.Write(p)
}
// Client represents an RPC client using an SSH backed connection. // Client represents an RPC client using an SSH backed connection.
type Client struct { type Client struct {
*rpc.Client *rpc.Client
Config *ssh.ClientConfig Config *ssh.ClientConfig
Subsystem string ChannelName string
sshClient *ssh.Client
} }
// NewClient returns a new Client to handle RPC requests. // NewClient returns a new Client to handle RPC requests.
@ -44,7 +24,7 @@ func NewClient() *Client {
}, },
} }
return &Client{nil, config, "sshrpc"} return &Client{nil, config, DefaultRPCChannel, nil}
} }
@ -55,22 +35,49 @@ func (c *Client) Connect(address string) {
if err != nil { if err != nil {
panic("Failed to dial: " + err.Error()) panic("Failed to dial: " + err.Error())
} }
c.sshClient = sshClient
// Each ClientConn can support multiple interactive sessions, // Each ClientConn can support multiple channels
// represented by a Session. channel, err := c.openRPCChannel(c.ChannelName)
sshSession, err := sshClient.NewSession()
if err != nil { if err != nil {
panic("Failed to create session: " + err.Error()) panic("Failed to create channel: " + err.Error())
}
//defer sshSession.Close()
err = sshSession.RequestSubsystem(c.Subsystem)
if err != nil {
fmt.Println("Unable to start subsystem:", err.Error())
} }
session := sshrpcSession{sshSession} c.Client = rpc.NewClient(channel)
c.Client = rpc.NewClient(session)
return return
} }
// openRPCChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start
func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) {
channel, in, err := c.sshClient.OpenChannel(channelName, nil)
if err != nil {
return nil, err
}
// SSH Documentation states that this go channel of requests needs to be serviced
go func(reqs <-chan *ssh.Request) error {
for msg := range reqs {
switch msg.Type {
default:
if msg.WantReply {
msg.Reply(false, nil)
}
}
}
return nil
}(in)
var msg struct {
Subsystem string
}
msg.Subsystem = RPCSubsystem
ok, err := channel.SendRequest("subsystem", true, ssh.Marshal(&msg))
if err == nil && !ok {
return nil, err
}
return channel, nil
}

View File

@ -9,11 +9,17 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// DefaultRPCChannel is the Channel Name that will be used to carry the RPC traffic, can be changed in Serverr and Client
const DefaultRPCChannel = "RPCChannel"
// RPCSubsystem is the subsystem that will be used to trigger RPC endpoint creation
const RPCSubsystem = "RPCSubsystem"
// Server represents an SSH Server that spins up RPC servers when requested. // Server represents an SSH Server that spins up RPC servers when requested.
type Server struct { type Server struct {
*rpc.Server *rpc.Server
Config *ssh.ServerConfig Config *ssh.ServerConfig
Subsystem string ChannelName string
} }
// NewServer returns a new Server to handle incoming SSH and RPC requests. // NewServer returns a new Server to handle incoming SSH and RPC requests.
@ -26,7 +32,7 @@ func NewServer() *Server {
return nil, fmt.Errorf("password rejected for %q", c.User()) return nil, fmt.Errorf("password rejected for %q", c.User())
}, },
} }
return &Server{rpc.NewServer(), c, "sshrpc"} return &Server{rpc.NewServer(), c, DefaultRPCChannel}
} }
@ -62,20 +68,21 @@ func (s *Server) StartServer(address string) {
} }
} }
// handleRequests handles global out-of-band SSH Requests
func (s *Server) handleRequests(reqs <-chan *ssh.Request) { func (s *Server) handleRequests(reqs <-chan *ssh.Request) {
for req := range reqs { for req := range reqs {
log.Printf("recieved out-of-band request: %+v", req) log.Printf("recieved out-of-band request: %+v", req)
} }
} }
// handleChannels handels SSH Channel requests and their local out-of-band SSH Requests
func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
// Service the incoming Channel channel. // Service the incoming Channel channel.
for newChannel := range chans { for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of a shell, the type is log.Printf("Received channel: %v", newChannel.ChannelType())
// "session" and ServerShell may be used to present a simple // Check the type of channel
// terminal interface. if t := newChannel.ChannelType(); t != s.ChannelName {
if t := newChannel.ChannelType(); t != "session" {
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue continue
} }
@ -84,8 +91,9 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
log.Printf("could not accept channel (%s)", err) log.Printf("could not accept channel (%s)", err)
continue continue
} }
log.Printf("Accepted channel")
// Sessions have out-of-band requests such as "shell", "pty-req" and "env" // Channels can have out-of-band requests
go func(in <-chan *ssh.Request) { go func(in <-chan *ssh.Request) {
for req := range in { for req := range in {
ok := false ok := false
@ -95,7 +103,8 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
ok = true ok = true
log.Printf("subsystem '%s'", req.Payload) log.Printf("subsystem '%s'", req.Payload)
switch string(req.Payload[4:]) { switch string(req.Payload[4:]) {
case s.Subsystem: //RPCSubsystem Request made indicates client desires RPC Server access
case RPCSubsystem:
go s.ServeConn(channel) go s.ServeConn(channel)
log.Printf("Started SSH RPC") log.Printf("Started SSH RPC")
default: default: