diff --git a/client.go b/client.go index 86e1802..a86889d 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package sshrpc import ( + "log" "net/rpc" "golang.org/x/crypto/ssh" @@ -12,6 +13,7 @@ type Client struct { Config *ssh.ClientConfig ChannelName string sshClient *ssh.Client + RPCServer *rpc.Server } // NewClient returns a new Client to handle RPC requests. @@ -24,7 +26,7 @@ func NewClient() *Client { }, } - return &Client{nil, config, DefaultRPCChannel, nil} + return &Client{nil, config, DefaultRPCChannel, nil, rpc.NewServer()} } @@ -37,37 +39,27 @@ func (c *Client) Connect(address string) { } c.sshClient = sshClient + c.openRPCServerChannel(c.ChannelName + "-reverse") + // Each ClientConn can support multiple channels - channel, err := c.openRPCChannel(c.ChannelName) + channel, err := openRPCClientChannel(c.sshClient.Conn, c.ChannelName) if err != nil { panic("Failed to create channel: " + err.Error()) } c.Client = rpc.NewClient(channel) - return } -// openRPCChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start -func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) { - channel, in, err := c.sshClient.OpenChannel(channelName, nil) +// openRPCClientChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start +func openRPCClientChannel(conn ssh.Conn, channelName string) (ssh.Channel, error) { + channel, in, err := conn.OpenChannel(channelName, nil) if err != nil { return nil, err } // SSH Documentation states that this go channel of requests needs to be serviced - go func(reqs <-chan *ssh.Request) error { - for msg := range reqs { - switch msg.Type { - - default: - if msg.WantReply { - msg.Reply(false, nil) - } - } - } - return nil - }(in) + go ssh.DiscardRequests(in) var msg struct { Subsystem string @@ -81,3 +73,63 @@ func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) { return channel, nil } + +func (c *Client) openRPCServerChannel(channelName string) error { + subChannel := c.sshClient.HandleChannelOpen(channelName) + go func(chans <-chan ssh.NewChannel) { + + for newChannel := range chans { + + /* Don't need to check, already know what channel type is coming in + // Check the type of channel + if t := newChannel.ChannelType(); t != channelName { + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + continue + } + */ + acceptRPCServerRequest(c.RPCServer, newChannel) + } + }(subChannel) + return nil +} + +func acceptRPCServerRequest(rpcServer *rpc.Server, newChannel ssh.NewChannel) { + channel, requests, err := newChannel.Accept() + if err != nil { + log.Printf("could not accept channel (%s)", err) + return + } + log.Printf("Accepted channel") + + // Channels can have out-of-band requests + go func(in <-chan *ssh.Request) { + for req := range in { + ok := false + switch req.Type { + + case "subsystem": + ok = true + log.Printf("subsystem '%s'", req.Payload) + switch string(req.Payload[4:]) { + //RPCSubsystem Request made indicates client desires RPC Server access + case RPCSubsystem: + go rpcServer.ServeConn(channel) + log.Printf("Started SSH RPC") + default: + log.Printf("Unknown subsystem: %s", req.Payload) + } + + } + if !ok { + log.Printf("declining %s request...", req.Type) + } + req.Reply(ok, nil) + } + + }(requests) +} + +// Wait allows clients also acting as an RPC server to detect when the ssh connection ends +func (c *Client) Wait() error { + return c.sshClient.Conn.Wait() +} diff --git a/server.go b/server.go index 7fc3d8c..299f1cd 100644 --- a/server.go +++ b/server.go @@ -20,6 +20,8 @@ type Server struct { *rpc.Server Config *ssh.ServerConfig ChannelName string + RPCClient *rpc.Client + sshConn ssh.Conn } // NewServer returns a new Server to handle incoming SSH and RPC requests. @@ -32,7 +34,7 @@ func NewServer() *Server { return nil, fmt.Errorf("password rejected for %q", c.User()) }, } - return &Server{rpc.NewServer(), c, DefaultRPCChannel} + return &Server{rpc.NewServer(), c, DefaultRPCChannel, nil, nil} } @@ -55,6 +57,7 @@ func (s *Server) StartServer(address string) { } // Before use, a handshake must be performed on the incoming net.Conn. sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config) + s.sshConn = sshConn if err != nil { log.Printf("failed to handshake (%s)", err) continue @@ -65,6 +68,7 @@ func (s *Server) StartServer(address string) { go s.handleRequests(reqs) // Accept all channels go s.handleChannels(chans) + } } @@ -107,6 +111,14 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { case RPCSubsystem: go s.ServeConn(channel) log.Printf("Started SSH RPC") + // triggers reverse RPC connection as well + clientChannel, err := openRPCClientChannel(s.sshConn, s.ChannelName+"-reverse") + if err != nil { + log.Printf("Failed to create client channel: " + err.Error()) + continue + } + s.RPCClient = rpc.NewClient(clientChannel) + log.Printf("Started SSH RPC client") default: log.Printf("Unknown subsystem: %s", req.Payload) }