updated to use easyssh library
This commit is contained in:
parent
b1cb51d887
commit
b4b856a17a
10
client.go
10
client.go
@ -4,6 +4,8 @@ import (
|
||||
"log"
|
||||
"net/rpc"
|
||||
|
||||
"dev.justinjudd.org/justin/easyssh"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@ -12,7 +14,7 @@ type Client struct {
|
||||
*rpc.Client
|
||||
Config *ssh.ClientConfig
|
||||
ChannelName string
|
||||
sshClient *ssh.Client
|
||||
sshClient *easyssh.Client
|
||||
RPCServer *rpc.Server
|
||||
}
|
||||
|
||||
@ -33,7 +35,7 @@ func NewClient() *Client {
|
||||
// Connect starts a client connection to the given SSH/RPC server.
|
||||
func (c *Client) Connect(address string) {
|
||||
|
||||
sshClient, err := ssh.Dial("tcp", address, c.Config)
|
||||
sshClient, err := easyssh.Dial("tcp", address, c.Config)
|
||||
if err != nil {
|
||||
panic("Failed to dial: " + err.Error())
|
||||
}
|
||||
@ -66,7 +68,7 @@ func openRPCClientChannel(conn ssh.Conn, channelName string) (ssh.Channel, error
|
||||
}
|
||||
msg.Subsystem = RPCSubsystem
|
||||
|
||||
ok, err := channel.SendRequest("subsystem", true, ssh.Marshal(&msg))
|
||||
ok, err := channel.SendRequest(easyssh.SubsystemRequest, true, ssh.Marshal(&msg))
|
||||
if err == nil && !ok {
|
||||
return nil, err
|
||||
}
|
||||
@ -107,7 +109,7 @@ func acceptRPCServerRequest(rpcServer *rpc.Server, newChannel ssh.NewChannel) {
|
||||
ok := false
|
||||
switch req.Type {
|
||||
|
||||
case "subsystem":
|
||||
case easyssh.SubsystemRequest:
|
||||
ok = true
|
||||
log.Printf("subsystem '%s'", req.Payload)
|
||||
switch string(req.Payload[4:]) {
|
||||
|
1
doc.go
1
doc.go
@ -1,4 +1,5 @@
|
||||
/*
|
||||
Package sshrpc provides RPC access over SSH. It uses the built-in RPC(net/rpc) library, so no RPC methods need to be rewritten.
|
||||
Can be used as a replacement transport for RPC.
|
||||
*/
|
||||
package sshrpc
|
||||
|
34
example_server_test.go
Normal file
34
example_server_test.go
Normal file
@ -0,0 +1,34 @@
|
||||
package sshrpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ExampleServer struct{}
|
||||
|
||||
func (s *ExampleServer) Hello(name *string, out *string) error {
|
||||
*out = fmt.Sprintf("Hello %s", *name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExampleServer_StartServer() {
|
||||
|
||||
s := NewServer()
|
||||
privateBytes, err := ioutil.ReadFile("id_rsa")
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load private key (./id_rsa)")
|
||||
}
|
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to parse private key")
|
||||
}
|
||||
|
||||
s.Config.AddHostKey(private)
|
||||
s.Register(new(ExampleServer))
|
||||
s.StartServer("localhost:2022")
|
||||
}
|
@ -1,37 +1,6 @@
|
||||
package sshrpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ExampleServer struct{}
|
||||
|
||||
func (s *ExampleServer) Hello(name *string, out *string) error {
|
||||
*out = fmt.Sprintf("Hello %s", *name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExampleServer_StartServer() {
|
||||
|
||||
s := NewServer()
|
||||
privateBytes, err := ioutil.ReadFile("id_rsa")
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load private key (./id_rsa)")
|
||||
}
|
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to parse private key")
|
||||
}
|
||||
|
||||
s.Config.AddHostKey(private)
|
||||
s.Register(new(ExampleServer))
|
||||
s.StartServer("localhost:2022")
|
||||
}
|
||||
import "fmt"
|
||||
|
||||
func ExampleClient_Connect() {
|
||||
|
||||
|
128
server.go
128
server.go
@ -3,9 +3,10 @@ package sshrpc
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/rpc"
|
||||
|
||||
"dev.justinjudd.org/justin/easyssh"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@ -44,97 +45,50 @@ func NewServer() *Server {
|
||||
// StartServer starts the server listening for requests
|
||||
func (s *Server) StartServer(address string) {
|
||||
|
||||
// Once a ServerConfig has been configured, connections can be accepted.
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
log.Fatal("failed to listen on ", address)
|
||||
}
|
||||
server := easyssh.Server{Addr: address, Config: s.Config}
|
||||
handler := easyssh.NewSSHConnHandler()
|
||||
channelsMux := easyssh.NewChannelsMux()
|
||||
channelsMux.HandleChannel(s.ChannelName, s)
|
||||
handler.MultipleChannelsHandler = channelsMux
|
||||
server.Handler = handler
|
||||
server.ListenAndServe()
|
||||
|
||||
// Accept all connections
|
||||
log.Print("listening on ", address)
|
||||
for {
|
||||
tcpConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept incoming connection (%s)", err)
|
||||
continue
|
||||
}
|
||||
// Before use, a handshake must be performed on the incoming net.Conn.
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("failed to handshake (%s)", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("new ssh connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())
|
||||
// Print incoming out-of-band Requests
|
||||
go s.handleRequests(reqs)
|
||||
// Accept all channels
|
||||
go s.handleChannels(chans, sshConn)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequests handles global out-of-band SSH Requests
|
||||
func (s *Server) handleRequests(reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
log.Printf("recieved out-of-band request: %+v", req)
|
||||
}
|
||||
}
|
||||
func (s *Server) HandleChannel(newChannel ssh.NewChannel, channel ssh.Channel, reqs <-chan *ssh.Request, sshConn ssh.Conn) {
|
||||
go func(in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
ok := false
|
||||
switch req.Type {
|
||||
|
||||
// handleChannels handels SSH Channel requests and their local out-of-band SSH Requests
|
||||
func (s *Server) handleChannels(chans <-chan ssh.NewChannel, sshConn ssh.Conn) {
|
||||
// Service the incoming Channel channel.
|
||||
for newChannel := range chans {
|
||||
|
||||
log.Printf("Received channel: %v", newChannel.ChannelType())
|
||||
// Check the type of channel
|
||||
if t := newChannel.ChannelType(); t != s.ChannelName {
|
||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||
continue
|
||||
}
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
log.Printf("could not accept channel (%s)", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("Accepted channel")
|
||||
|
||||
// Channels can have out-of-band requests
|
||||
go func(in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
ok := false
|
||||
switch req.Type {
|
||||
|
||||
case "subsystem":
|
||||
ok = true
|
||||
log.Printf("subsystem '%s'", req.Payload[4:])
|
||||
switch string(req.Payload[4:]) {
|
||||
//RPCSubsystem Request made indicates client desires RPC Server access
|
||||
case RPCSubsystem:
|
||||
go s.ServeConn(channel)
|
||||
log.Printf("Started SSH RPC")
|
||||
// triggers reverse RPC connection as well
|
||||
clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse")
|
||||
if err != nil {
|
||||
log.Printf("Failed to create client channel: " + err.Error())
|
||||
continue
|
||||
}
|
||||
rpcClient := rpc.NewClient(clientChannel)
|
||||
if s.CallbackFunc != nil {
|
||||
s.CallbackFunc(rpcClient, sshConn)
|
||||
}
|
||||
log.Printf("Started SSH RPC client")
|
||||
default:
|
||||
log.Printf("Unknown subsystem: %s", req.Payload)
|
||||
case easyssh.SubsystemRequest:
|
||||
ok = true
|
||||
log.Printf("subsystem '%s'", req.Payload[4:])
|
||||
switch string(req.Payload[4:]) {
|
||||
//RPCSubsystem Request made indicates client desires RPC Server access
|
||||
case RPCSubsystem:
|
||||
go s.ServeConn(channel)
|
||||
log.Printf("Started SSH RPC")
|
||||
// triggers reverse RPC connection as well
|
||||
clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse")
|
||||
if err != nil {
|
||||
log.Printf("Failed to create client channel: " + err.Error())
|
||||
continue
|
||||
}
|
||||
rpcClient := rpc.NewClient(clientChannel)
|
||||
if s.CallbackFunc != nil {
|
||||
s.CallbackFunc(rpcClient, sshConn)
|
||||
}
|
||||
log.Printf("Started SSH RPC client")
|
||||
default:
|
||||
log.Printf("Unknown subsystem: %s", req.Payload)
|
||||
}
|
||||
|
||||
}
|
||||
if !ok {
|
||||
log.Printf("declining %s request...", req.Type)
|
||||
}
|
||||
req.Reply(ok, nil)
|
||||
}
|
||||
}(requests)
|
||||
}
|
||||
if !ok {
|
||||
log.Printf("declining %s request...", req.Type)
|
||||
}
|
||||
req.Reply(ok, nil)
|
||||
}
|
||||
}(reqs)
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func (s *AdvancedServer) SortInts(req *AdvancedType, out *AdvancedType) error {
|
||||
|
||||
func TestAdvancedServer(t *testing.T) {
|
||||
s := NewServer()
|
||||
s.Subsystem = "Advanced"
|
||||
s.ChannelName = "Advanced"
|
||||
|
||||
private, err := ssh.ParsePrivateKey(testdata.ServerRSAKey)
|
||||
if err != nil {
|
||||
@ -84,7 +84,7 @@ func TestAdvancedServer(t *testing.T) {
|
||||
|
||||
t.Log("preparing client")
|
||||
client := NewClient()
|
||||
client.Subsystem = "Advanced"
|
||||
client.ChannelName = "Advanced"
|
||||
client.Connect("localhost:3022")
|
||||
defer client.Close()
|
||||
var reply AdvancedType
|
||||
|
Loading…
Reference in New Issue
Block a user