From b4b856a17ad868b9cf6d72dc9f2d7e32023234ad Mon Sep 17 00:00:00 2001 From: Justin Date: Sat, 30 May 2015 13:25:21 -0400 Subject: [PATCH] updated to use easyssh library --- client.go | 10 ++-- doc.go | 1 + example_server_test.go | 34 +++++++++++ example_test.go | 33 +---------- server.go | 128 +++++++++++++---------------------------- server_test.go | 4 +- 6 files changed, 85 insertions(+), 125 deletions(-) create mode 100644 example_server_test.go diff --git a/client.go b/client.go index a86889d..5981d8e 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,8 @@ import ( "log" "net/rpc" + "dev.justinjudd.org/justin/easyssh" + "golang.org/x/crypto/ssh" ) @@ -12,7 +14,7 @@ type Client struct { *rpc.Client Config *ssh.ClientConfig ChannelName string - sshClient *ssh.Client + sshClient *easyssh.Client RPCServer *rpc.Server } @@ -33,7 +35,7 @@ func NewClient() *Client { // Connect starts a client connection to the given SSH/RPC server. func (c *Client) Connect(address string) { - sshClient, err := ssh.Dial("tcp", address, c.Config) + sshClient, err := easyssh.Dial("tcp", address, c.Config) if err != nil { panic("Failed to dial: " + err.Error()) } @@ -66,7 +68,7 @@ func openRPCClientChannel(conn ssh.Conn, channelName string) (ssh.Channel, error } msg.Subsystem = RPCSubsystem - ok, err := channel.SendRequest("subsystem", true, ssh.Marshal(&msg)) + ok, err := channel.SendRequest(easyssh.SubsystemRequest, true, ssh.Marshal(&msg)) if err == nil && !ok { return nil, err } @@ -107,7 +109,7 @@ func acceptRPCServerRequest(rpcServer *rpc.Server, newChannel ssh.NewChannel) { ok := false switch req.Type { - case "subsystem": + case easyssh.SubsystemRequest: ok = true log.Printf("subsystem '%s'", req.Payload) switch string(req.Payload[4:]) { diff --git a/doc.go b/doc.go index b8e1fb0..7e8d734 100644 --- a/doc.go +++ b/doc.go @@ -1,4 +1,5 @@ /* Package sshrpc provides RPC access over SSH. It uses the built-in RPC(net/rpc) library, so no RPC methods need to be rewritten. +Can be used as a replacement transport for RPC. */ package sshrpc diff --git a/example_server_test.go b/example_server_test.go new file mode 100644 index 0000000..c643975 --- /dev/null +++ b/example_server_test.go @@ -0,0 +1,34 @@ +package sshrpc + +import ( + "fmt" + "io/ioutil" + "log" + + "golang.org/x/crypto/ssh" +) + +type ExampleServer struct{} + +func (s *ExampleServer) Hello(name *string, out *string) error { + *out = fmt.Sprintf("Hello %s", *name) + return nil +} + +func ExampleServer_StartServer() { + + s := NewServer() + privateBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key (./id_rsa)") + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key") + } + + s.Config.AddHostKey(private) + s.Register(new(ExampleServer)) + s.StartServer("localhost:2022") +} diff --git a/example_test.go b/example_test.go index a744180..dcdfc6b 100644 --- a/example_test.go +++ b/example_test.go @@ -1,37 +1,6 @@ package sshrpc -import ( - "fmt" - "io/ioutil" - "log" - - "golang.org/x/crypto/ssh" -) - -type ExampleServer struct{} - -func (s *ExampleServer) Hello(name *string, out *string) error { - *out = fmt.Sprintf("Hello %s", *name) - return nil -} - -func ExampleServer_StartServer() { - - s := NewServer() - privateBytes, err := ioutil.ReadFile("id_rsa") - if err != nil { - log.Fatal("Failed to load private key (./id_rsa)") - } - - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - log.Fatal("Failed to parse private key") - } - - s.Config.AddHostKey(private) - s.Register(new(ExampleServer)) - s.StartServer("localhost:2022") -} +import "fmt" func ExampleClient_Connect() { diff --git a/server.go b/server.go index 0da7a75..4441861 100644 --- a/server.go +++ b/server.go @@ -3,9 +3,10 @@ package sshrpc import ( "fmt" "log" - "net" "net/rpc" + "dev.justinjudd.org/justin/easyssh" + "golang.org/x/crypto/ssh" ) @@ -44,97 +45,50 @@ func NewServer() *Server { // StartServer starts the server listening for requests func (s *Server) StartServer(address string) { - // Once a ServerConfig has been configured, connections can be accepted. - listener, err := net.Listen("tcp", address) - if err != nil { - log.Fatal("failed to listen on ", address) - } + server := easyssh.Server{Addr: address, Config: s.Config} + handler := easyssh.NewSSHConnHandler() + channelsMux := easyssh.NewChannelsMux() + channelsMux.HandleChannel(s.ChannelName, s) + handler.MultipleChannelsHandler = channelsMux + server.Handler = handler + server.ListenAndServe() - // Accept all connections - log.Print("listening on ", address) - for { - tcpConn, err := listener.Accept() - if err != nil { - log.Printf("failed to accept incoming connection (%s)", err) - continue - } - // Before use, a handshake must be performed on the incoming net.Conn. - sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config) - - if err != nil { - log.Printf("failed to handshake (%s)", err) - continue - } - - log.Printf("new ssh connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion()) - // Print incoming out-of-band Requests - go s.handleRequests(reqs) - // Accept all channels - go s.handleChannels(chans, sshConn) - - } } -// handleRequests handles global out-of-band SSH Requests -func (s *Server) handleRequests(reqs <-chan *ssh.Request) { - for req := range reqs { - log.Printf("recieved out-of-band request: %+v", req) - } -} +func (s *Server) HandleChannel(newChannel ssh.NewChannel, channel ssh.Channel, reqs <-chan *ssh.Request, sshConn ssh.Conn) { + go func(in <-chan *ssh.Request) { + for req := range in { + ok := false + switch req.Type { -// handleChannels handels SSH Channel requests and their local out-of-band SSH Requests -func (s *Server) handleChannels(chans <-chan ssh.NewChannel, sshConn ssh.Conn) { - // Service the incoming Channel channel. - for newChannel := range chans { - - log.Printf("Received channel: %v", newChannel.ChannelType()) - // Check the type of channel - if t := newChannel.ChannelType(); t != s.ChannelName { - newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - continue - } - channel, requests, err := newChannel.Accept() - if err != nil { - log.Printf("could not accept channel (%s)", err) - continue - } - log.Printf("Accepted channel") - - // Channels can have out-of-band requests - go func(in <-chan *ssh.Request) { - for req := range in { - ok := false - switch req.Type { - - case "subsystem": - ok = true - log.Printf("subsystem '%s'", req.Payload[4:]) - switch string(req.Payload[4:]) { - //RPCSubsystem Request made indicates client desires RPC Server access - case RPCSubsystem: - go s.ServeConn(channel) - log.Printf("Started SSH RPC") - // triggers reverse RPC connection as well - clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse") - if err != nil { - log.Printf("Failed to create client channel: " + err.Error()) - continue - } - rpcClient := rpc.NewClient(clientChannel) - if s.CallbackFunc != nil { - s.CallbackFunc(rpcClient, sshConn) - } - log.Printf("Started SSH RPC client") - default: - log.Printf("Unknown subsystem: %s", req.Payload) + case easyssh.SubsystemRequest: + ok = true + log.Printf("subsystem '%s'", req.Payload[4:]) + switch string(req.Payload[4:]) { + //RPCSubsystem Request made indicates client desires RPC Server access + case RPCSubsystem: + go s.ServeConn(channel) + log.Printf("Started SSH RPC") + // triggers reverse RPC connection as well + clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse") + if err != nil { + log.Printf("Failed to create client channel: " + err.Error()) + continue } + rpcClient := rpc.NewClient(clientChannel) + if s.CallbackFunc != nil { + s.CallbackFunc(rpcClient, sshConn) + } + log.Printf("Started SSH RPC client") + default: + log.Printf("Unknown subsystem: %s", req.Payload) + } - } - if !ok { - log.Printf("declining %s request...", req.Type) - } - req.Reply(ok, nil) } - }(requests) - } + if !ok { + log.Printf("declining %s request...", req.Type) + } + req.Reply(ok, nil) + } + }(reqs) } diff --git a/server_test.go b/server_test.go index be78f11..1c3e5d1 100644 --- a/server_test.go +++ b/server_test.go @@ -66,7 +66,7 @@ func (s *AdvancedServer) SortInts(req *AdvancedType, out *AdvancedType) error { func TestAdvancedServer(t *testing.T) { s := NewServer() - s.Subsystem = "Advanced" + s.ChannelName = "Advanced" private, err := ssh.ParsePrivateKey(testdata.ServerRSAKey) if err != nil { @@ -84,7 +84,7 @@ func TestAdvancedServer(t *testing.T) { t.Log("preparing client") client := NewClient() - client.Subsystem = "Advanced" + client.ChannelName = "Advanced" client.Connect("localhost:3022") defer client.Close() var reply AdvancedType