diff --git a/client.go b/client.go index 4c494cf..86e1802 100644 --- a/client.go +++ b/client.go @@ -1,37 +1,17 @@ package sshrpc import ( - "fmt" "net/rpc" "golang.org/x/crypto/ssh" ) -type sshrpcSession struct { - *ssh.Session -} - -func (s sshrpcSession) Read(p []byte) (n int, err error) { - pipe, err := s.StdoutPipe() - if err != nil { - return 0, err - } - return pipe.Read(p) -} - -func (s sshrpcSession) Write(p []byte) (n int, err error) { - pipe, err := s.StdinPipe() - if err != nil { - return 0, err - } - return pipe.Write(p) -} - // Client represents an RPC client using an SSH backed connection. type Client struct { *rpc.Client - Config *ssh.ClientConfig - Subsystem string + Config *ssh.ClientConfig + ChannelName string + sshClient *ssh.Client } // NewClient returns a new Client to handle RPC requests. @@ -44,7 +24,7 @@ func NewClient() *Client { }, } - return &Client{nil, config, "sshrpc"} + return &Client{nil, config, DefaultRPCChannel, nil} } @@ -55,22 +35,49 @@ func (c *Client) Connect(address string) { if err != nil { panic("Failed to dial: " + err.Error()) } + c.sshClient = sshClient - // Each ClientConn can support multiple interactive sessions, - // represented by a Session. - sshSession, err := sshClient.NewSession() + // Each ClientConn can support multiple channels + channel, err := c.openRPCChannel(c.ChannelName) if err != nil { - panic("Failed to create session: " + err.Error()) - } - //defer sshSession.Close() - - err = sshSession.RequestSubsystem(c.Subsystem) - if err != nil { - fmt.Println("Unable to start subsystem:", err.Error()) + panic("Failed to create channel: " + err.Error()) } - session := sshrpcSession{sshSession} - c.Client = rpc.NewClient(session) + c.Client = rpc.NewClient(channel) return } + +// openRPCChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start +func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) { + channel, in, err := c.sshClient.OpenChannel(channelName, nil) + if err != nil { + return nil, err + } + + // SSH Documentation states that this go channel of requests needs to be serviced + go func(reqs <-chan *ssh.Request) error { + for msg := range reqs { + switch msg.Type { + + default: + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + return nil + }(in) + + var msg struct { + Subsystem string + } + msg.Subsystem = RPCSubsystem + + ok, err := channel.SendRequest("subsystem", true, ssh.Marshal(&msg)) + if err == nil && !ok { + return nil, err + } + + return channel, nil +} diff --git a/server.go b/server.go index dc3980b..7fc3d8c 100644 --- a/server.go +++ b/server.go @@ -9,11 +9,17 @@ import ( "golang.org/x/crypto/ssh" ) +// DefaultRPCChannel is the Channel Name that will be used to carry the RPC traffic, can be changed in Serverr and Client +const DefaultRPCChannel = "RPCChannel" + +// RPCSubsystem is the subsystem that will be used to trigger RPC endpoint creation +const RPCSubsystem = "RPCSubsystem" + // Server represents an SSH Server that spins up RPC servers when requested. type Server struct { *rpc.Server - Config *ssh.ServerConfig - Subsystem string + Config *ssh.ServerConfig + ChannelName string } // NewServer returns a new Server to handle incoming SSH and RPC requests. @@ -26,7 +32,7 @@ func NewServer() *Server { return nil, fmt.Errorf("password rejected for %q", c.User()) }, } - return &Server{rpc.NewServer(), c, "sshrpc"} + return &Server{rpc.NewServer(), c, DefaultRPCChannel} } @@ -62,20 +68,21 @@ func (s *Server) StartServer(address string) { } } +// handleRequests handles global out-of-band SSH Requests func (s *Server) handleRequests(reqs <-chan *ssh.Request) { for req := range reqs { log.Printf("recieved out-of-band request: %+v", req) } } +// handleChannels handels SSH Channel requests and their local out-of-band SSH Requests func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { // Service the incoming Channel channel. for newChannel := range chans { - // Channels have a type, depending on the application level - // protocol intended. In the case of a shell, the type is - // "session" and ServerShell may be used to present a simple - // terminal interface. - if t := newChannel.ChannelType(); t != "session" { + + log.Printf("Received channel: %v", newChannel.ChannelType()) + // Check the type of channel + if t := newChannel.ChannelType(); t != s.ChannelName { newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) continue } @@ -84,8 +91,9 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { log.Printf("could not accept channel (%s)", err) continue } + log.Printf("Accepted channel") - // Sessions have out-of-band requests such as "shell", "pty-req" and "env" + // Channels can have out-of-band requests go func(in <-chan *ssh.Request) { for req := range in { ok := false @@ -95,7 +103,8 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { ok = true log.Printf("subsystem '%s'", req.Payload) switch string(req.Payload[4:]) { - case s.Subsystem: + //RPCSubsystem Request made indicates client desires RPC Server access + case RPCSubsystem: go s.ServeConn(channel) log.Printf("Started SSH RPC") default: