added support for 2-way RPC sessions. closes #1
This commit is contained in:
parent
24010b4a95
commit
5dfe2ab054
88
client.go
88
client.go
@ -1,6 +1,7 @@
|
||||
package sshrpc
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/rpc"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -12,6 +13,7 @@ type Client struct {
|
||||
Config *ssh.ClientConfig
|
||||
ChannelName string
|
||||
sshClient *ssh.Client
|
||||
RPCServer *rpc.Server
|
||||
}
|
||||
|
||||
// NewClient returns a new Client to handle RPC requests.
|
||||
@ -24,7 +26,7 @@ func NewClient() *Client {
|
||||
},
|
||||
}
|
||||
|
||||
return &Client{nil, config, DefaultRPCChannel, nil}
|
||||
return &Client{nil, config, DefaultRPCChannel, nil, rpc.NewServer()}
|
||||
|
||||
}
|
||||
|
||||
@ -37,37 +39,27 @@ func (c *Client) Connect(address string) {
|
||||
}
|
||||
c.sshClient = sshClient
|
||||
|
||||
c.openRPCServerChannel(c.ChannelName + "-reverse")
|
||||
|
||||
// Each ClientConn can support multiple channels
|
||||
channel, err := c.openRPCChannel(c.ChannelName)
|
||||
channel, err := openRPCClientChannel(c.sshClient.Conn, c.ChannelName)
|
||||
if err != nil {
|
||||
panic("Failed to create channel: " + err.Error())
|
||||
}
|
||||
|
||||
c.Client = rpc.NewClient(channel)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// openRPCChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start
|
||||
func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) {
|
||||
channel, in, err := c.sshClient.OpenChannel(channelName, nil)
|
||||
// openRPCClientChannel opens an SSH RPC channel and makes an ssh subsystem request to trigger remote RPC server start
|
||||
func openRPCClientChannel(conn ssh.Conn, channelName string) (ssh.Channel, error) {
|
||||
channel, in, err := conn.OpenChannel(channelName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// SSH Documentation states that this go channel of requests needs to be serviced
|
||||
go func(reqs <-chan *ssh.Request) error {
|
||||
for msg := range reqs {
|
||||
switch msg.Type {
|
||||
|
||||
default:
|
||||
if msg.WantReply {
|
||||
msg.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(in)
|
||||
go ssh.DiscardRequests(in)
|
||||
|
||||
var msg struct {
|
||||
Subsystem string
|
||||
@ -81,3 +73,63 @@ func (c *Client) openRPCChannel(channelName string) (ssh.Channel, error) {
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func (c *Client) openRPCServerChannel(channelName string) error {
|
||||
subChannel := c.sshClient.HandleChannelOpen(channelName)
|
||||
go func(chans <-chan ssh.NewChannel) {
|
||||
|
||||
for newChannel := range chans {
|
||||
|
||||
/* Don't need to check, already know what channel type is coming in
|
||||
// Check the type of channel
|
||||
if t := newChannel.ChannelType(); t != channelName {
|
||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||
continue
|
||||
}
|
||||
*/
|
||||
acceptRPCServerRequest(c.RPCServer, newChannel)
|
||||
}
|
||||
}(subChannel)
|
||||
return nil
|
||||
}
|
||||
|
||||
func acceptRPCServerRequest(rpcServer *rpc.Server, newChannel ssh.NewChannel) {
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
log.Printf("could not accept channel (%s)", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Accepted channel")
|
||||
|
||||
// Channels can have out-of-band requests
|
||||
go func(in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
ok := false
|
||||
switch req.Type {
|
||||
|
||||
case "subsystem":
|
||||
ok = true
|
||||
log.Printf("subsystem '%s'", req.Payload)
|
||||
switch string(req.Payload[4:]) {
|
||||
//RPCSubsystem Request made indicates client desires RPC Server access
|
||||
case RPCSubsystem:
|
||||
go rpcServer.ServeConn(channel)
|
||||
log.Printf("Started SSH RPC")
|
||||
default:
|
||||
log.Printf("Unknown subsystem: %s", req.Payload)
|
||||
}
|
||||
|
||||
}
|
||||
if !ok {
|
||||
log.Printf("declining %s request...", req.Type)
|
||||
}
|
||||
req.Reply(ok, nil)
|
||||
}
|
||||
|
||||
}(requests)
|
||||
}
|
||||
|
||||
// Wait allows clients also acting as an RPC server to detect when the ssh connection ends
|
||||
func (c *Client) Wait() error {
|
||||
return c.sshClient.Conn.Wait()
|
||||
}
|
||||
|
14
server.go
14
server.go
@ -20,6 +20,8 @@ type Server struct {
|
||||
*rpc.Server
|
||||
Config *ssh.ServerConfig
|
||||
ChannelName string
|
||||
RPCClient *rpc.Client
|
||||
sshConn ssh.Conn
|
||||
}
|
||||
|
||||
// NewServer returns a new Server to handle incoming SSH and RPC requests.
|
||||
@ -32,7 +34,7 @@ func NewServer() *Server {
|
||||
return nil, fmt.Errorf("password rejected for %q", c.User())
|
||||
},
|
||||
}
|
||||
return &Server{rpc.NewServer(), c, DefaultRPCChannel}
|
||||
return &Server{rpc.NewServer(), c, DefaultRPCChannel, nil, nil}
|
||||
|
||||
}
|
||||
|
||||
@ -55,6 +57,7 @@ func (s *Server) StartServer(address string) {
|
||||
}
|
||||
// Before use, a handshake must be performed on the incoming net.Conn.
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config)
|
||||
s.sshConn = sshConn
|
||||
if err != nil {
|
||||
log.Printf("failed to handshake (%s)", err)
|
||||
continue
|
||||
@ -65,6 +68,7 @@ func (s *Server) StartServer(address string) {
|
||||
go s.handleRequests(reqs)
|
||||
// Accept all channels
|
||||
go s.handleChannels(chans)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,6 +111,14 @@ func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
|
||||
case RPCSubsystem:
|
||||
go s.ServeConn(channel)
|
||||
log.Printf("Started SSH RPC")
|
||||
// triggers reverse RPC connection as well
|
||||
clientChannel, err := openRPCClientChannel(s.sshConn, s.ChannelName+"-reverse")
|
||||
if err != nil {
|
||||
log.Printf("Failed to create client channel: " + err.Error())
|
||||
continue
|
||||
}
|
||||
s.RPCClient = rpc.NewClient(clientChannel)
|
||||
log.Printf("Started SSH RPC client")
|
||||
default:
|
||||
log.Printf("Unknown subsystem: %s", req.Payload)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user