diff --git a/server.go b/server.go index 299f1cd..4cf0e5c 100644 --- a/server.go +++ b/server.go @@ -15,13 +15,16 @@ const DefaultRPCChannel = "RPCChannel" // RPCSubsystem is the subsystem that will be used to trigger RPC endpoint creation const RPCSubsystem = "RPCSubsystem" +// CallbackFunc to be called when reverse RPC client is created +type CallbackFunc func(RPCClient *rpc.Client) + // Server represents an SSH Server that spins up RPC servers when requested. type Server struct { *rpc.Server Config *ssh.ServerConfig ChannelName string - RPCClient *rpc.Client - sshConn ssh.Conn + + CallbackFunc } // NewServer returns a new Server to handle incoming SSH and RPC requests. @@ -34,7 +37,7 @@ func NewServer() *Server { return nil, fmt.Errorf("password rejected for %q", c.User()) }, } - return &Server{rpc.NewServer(), c, DefaultRPCChannel, nil, nil} + return &Server{Server: rpc.NewServer(), Config: c, ChannelName: DefaultRPCChannel} } @@ -57,7 +60,7 @@ func (s *Server) StartServer(address string) { } // Before use, a handshake must be performed on the incoming net.Conn. sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config) - s.sshConn = sshConn + if err != nil { log.Printf("failed to handshake (%s)", err) continue @@ -67,7 +70,7 @@ func (s *Server) StartServer(address string) { // Print incoming out-of-band Requests go s.handleRequests(reqs) // Accept all channels - go s.handleChannels(chans) + go s.handleChannels(chans, sshConn) } } @@ -80,7 +83,7 @@ func (s *Server) handleRequests(reqs <-chan *ssh.Request) { } // handleChannels handels SSH Channel requests and their local out-of-band SSH Requests -func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { +func (s *Server) handleChannels(chans <-chan ssh.NewChannel, sshConn ssh.Conn) { // Service the incoming Channel channel. for newChannel := range chans { @@ -105,19 +108,22 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) { case "subsystem": ok = true - log.Printf("subsystem '%s'", req.Payload) + log.Printf("subsystem '%s'", req.Payload[4:]) switch string(req.Payload[4:]) { //RPCSubsystem Request made indicates client desires RPC Server access case RPCSubsystem: go s.ServeConn(channel) log.Printf("Started SSH RPC") // triggers reverse RPC connection as well - clientChannel, err := openRPCClientChannel(s.sshConn, s.ChannelName+"-reverse") + clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse") if err != nil { log.Printf("Failed to create client channel: " + err.Error()) continue } - s.RPCClient = rpc.NewClient(clientChannel) + rpcClient := rpc.NewClient(clientChannel) + if s.CallbackFunc != nil { + s.CallbackFunc(rpcClient) + } log.Printf("Started SSH RPC client") default: log.Printf("Unknown subsystem: %s", req.Payload)