added support for 2-way RPC sessions. closes #1

This commit is contained in:
Justin 2015-04-06 07:08:26 -04:00
parent 24010b4a95
commit 5dfe2ab054
2 changed files with 83 additions and 19 deletions

View File

@ -1,6 +1,7 @@
package sshrpc package sshrpc
import ( import (
"log"
"net/rpc" "net/rpc"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -12,6 +13,7 @@ type Client struct {
Config *ssh.ClientConfig Config *ssh.ClientConfig
ChannelName string ChannelName string
sshClient *ssh.Client sshClient *ssh.Client
RPCServer *rpc.Server
} }
// NewClient returns a new Client to handle RPC requests. // NewClient returns a new Client to handle RPC requests.
@ -24,7 +26,7 @@ func NewClient() *Client {
}, },
} }
return &Client{nil, config, DefaultRPCChannel, nil} return &Client{nil, config, DefaultRPCChannel, nil, rpc.NewServer()}
} }
@ -37,37 +39,27 @@ func (c *Client) Connect(address string) {
} }
c.sshClient = sshClient c.sshClient = sshClient
c.openRPCServerChannel(c.ChannelName + "-reverse")
// Each ClientConn can support multiple channels // Each ClientConn can support multiple channels
channel, err := c.openRPCChannel(c.ChannelName) channel, err := openRPCClientChannel(c.sshClient.Conn, c.ChannelName)
if err != nil { if err != nil {
panic("Failed to create channel: " + err.Error()) panic("Failed to create channel: " + err.Error())
} }
c.Client = rpc.NewClient(channel) c.Client = rpc.NewClient(channel)
return
} }
// openRPCChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start // openRPCClientChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start
func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) { func openRPCClientChannel(conn ssh.Conn, channelName string) (ssh.Channel, error) {
channel, in, err := c.sshClient.OpenChannel(channelName, nil) channel, in, err := conn.OpenChannel(channelName, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// SSH Documentation states that this go channel of requests needs to be serviced // SSH Documentation states that this go channel of requests needs to be serviced
go func(reqs <-chan *ssh.Request) error { go ssh.DiscardRequests(in)
for msg := range reqs {
switch msg.Type {
default:
if msg.WantReply {
msg.Reply(false, nil)
}
}
}
return nil
}(in)
var msg struct { var msg struct {
Subsystem string Subsystem string
@ -81,3 +73,63 @@ func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) {
return channel, nil return channel, nil
} }
func (c *Client) openRPCServerChannel(channelName string) error {
subChannel := c.sshClient.HandleChannelOpen(channelName)
go func(chans <-chan ssh.NewChannel) {
for newChannel := range chans {
/* Don't need to check, already know what channel type is coming in
// Check the type of channel
if t := newChannel.ChannelType(); t != channelName {
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue
}
*/
acceptRPCServerRequest(c.RPCServer, newChannel)
}
}(subChannel)
return nil
}
func acceptRPCServerRequest(rpcServer *rpc.Server, newChannel ssh.NewChannel) {
channel, requests, err := newChannel.Accept()
if err != nil {
log.Printf("could not accept channel (%s)", err)
return
}
log.Printf("Accepted channel")
// Channels can have out-of-band requests
go func(in <-chan *ssh.Request) {
for req := range in {
ok := false
switch req.Type {
case "subsystem":
ok = true
log.Printf("subsystem '%s'", req.Payload)
switch string(req.Payload[4:]) {
//RPCSubsystem Request made indicates client desires RPC Server access
case RPCSubsystem:
go rpcServer.ServeConn(channel)
log.Printf("Started SSH RPC")
default:
log.Printf("Unknown subsystem: %s", req.Payload)
}
}
if !ok {
log.Printf("declining %s request...", req.Type)
}
req.Reply(ok, nil)
}
}(requests)
}
// Wait allows clients also acting as an RPC server to detect when the ssh connection ends
func (c *Client) Wait() error {
return c.sshClient.Conn.Wait()
}

View File

@ -20,6 +20,8 @@ type Server struct {
*rpc.Server *rpc.Server
Config *ssh.ServerConfig Config *ssh.ServerConfig
ChannelName string ChannelName string
RPCClient *rpc.Client
sshConn ssh.Conn
} }
// NewServer returns a new Server to handle incoming SSH and RPC requests. // NewServer returns a new Server to handle incoming SSH and RPC requests.
@ -32,7 +34,7 @@ func NewServer() *Server {
return nil, fmt.Errorf("password rejected for %q", c.User()) return nil, fmt.Errorf("password rejected for %q", c.User())
}, },
} }
return &Server{rpc.NewServer(), c, DefaultRPCChannel} return &Server{rpc.NewServer(), c, DefaultRPCChannel, nil, nil}
} }
@ -55,6 +57,7 @@ func (s *Server) StartServer(address string) {
} }
// Before use, a handshake must be performed on the incoming net.Conn. // Before use, a handshake must be performed on the incoming net.Conn.
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config) sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config)
s.sshConn = sshConn
if err != nil { if err != nil {
log.Printf("failed to handshake (%s)", err) log.Printf("failed to handshake (%s)", err)
continue continue
@ -65,6 +68,7 @@ func (s *Server) StartServer(address string) {
go s.handleRequests(reqs) go s.handleRequests(reqs)
// Accept all channels // Accept all channels
go s.handleChannels(chans) go s.handleChannels(chans)
} }
} }
@ -107,6 +111,14 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
case RPCSubsystem: case RPCSubsystem:
go s.ServeConn(channel) go s.ServeConn(channel)
log.Printf("Started SSH RPC") log.Printf("Started SSH RPC")
// triggers reverse RPC connection as well
clientChannel, err := openRPCClientChannel(s.sshConn, s.ChannelName+"-reverse")
if err != nil {
log.Printf("Failed to create client channel: " + err.Error())
continue
}
s.RPCClient = rpc.NewClient(clientChannel)
log.Printf("Started SSH RPC client")
default: default:
log.Printf("Unknown subsystem: %s", req.Payload) log.Printf("Unknown subsystem: %s", req.Payload)
} }