updated to use easyssh library
This commit is contained in:
parent
b1cb51d887
commit
b4b856a17a
10
client.go
10
client.go
@ -4,6 +4,8 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
|
|
||||||
|
"dev.justinjudd.org/justin/easyssh"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,7 +14,7 @@ type Client struct {
|
|||||||
*rpc.Client
|
*rpc.Client
|
||||||
Config *ssh.ClientConfig
|
Config *ssh.ClientConfig
|
||||||
ChannelName string
|
ChannelName string
|
||||||
sshClient *ssh.Client
|
sshClient *easyssh.Client
|
||||||
RPCServer *rpc.Server
|
RPCServer *rpc.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,7 +35,7 @@ func NewClient() *Client {
|
|||||||
// Connect starts a client connection to the given SSH/RPC server.
|
// Connect starts a client connection to the given SSH/RPC server.
|
||||||
func (c *Client) Connect(address string) {
|
func (c *Client) Connect(address string) {
|
||||||
|
|
||||||
sshClient, err := ssh.Dial("tcp", address, c.Config)
|
sshClient, err := easyssh.Dial("tcp", address, c.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to dial: " + err.Error())
|
panic("Failed to dial: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -66,7 +68,7 @@ func openRPCClientChannel(conn ssh.Conn, channelName string) (ssh.Channel, error
|
|||||||
}
|
}
|
||||||
msg.Subsystem = RPCSubsystem
|
msg.Subsystem = RPCSubsystem
|
||||||
|
|
||||||
ok, err := channel.SendRequest("subsystem", true, ssh.Marshal(&msg))
|
ok, err := channel.SendRequest(easyssh.SubsystemRequest, true, ssh.Marshal(&msg))
|
||||||
if err == nil && !ok {
|
if err == nil && !ok {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -107,7 +109,7 @@ func acceptRPCServerRequest(rpcServer *rpc.Server, newChannel ssh.NewChannel) {
|
|||||||
ok := false
|
ok := false
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
|
|
||||||
case "subsystem":
|
case easyssh.SubsystemRequest:
|
||||||
ok = true
|
ok = true
|
||||||
log.Printf("subsystem '%s'", req.Payload)
|
log.Printf("subsystem '%s'", req.Payload)
|
||||||
switch string(req.Payload[4:]) {
|
switch string(req.Payload[4:]) {
|
||||||
|
1
doc.go
1
doc.go
@ -1,4 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Package sshrpc provides RPC access over SSH. It uses the built-in RPC(net/rpc) library, so no RPC methods need to be rewritten.
|
Package sshrpc provides RPC access over SSH. It uses the built-in RPC(net/rpc) library, so no RPC methods need to be rewritten.
|
||||||
|
Can be used as a replacement transport for RPC.
|
||||||
*/
|
*/
|
||||||
package sshrpc
|
package sshrpc
|
||||||
|
34
example_server_test.go
Normal file
34
example_server_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package sshrpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ExampleServer struct{}
|
||||||
|
|
||||||
|
func (s *ExampleServer) Hello(name *string, out *string) error {
|
||||||
|
*out = fmt.Sprintf("Hello %s", *name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleServer_StartServer() {
|
||||||
|
|
||||||
|
s := NewServer()
|
||||||
|
privateBytes, err := ioutil.ReadFile("id_rsa")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to load private key (./id_rsa)")
|
||||||
|
}
|
||||||
|
|
||||||
|
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to parse private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Config.AddHostKey(private)
|
||||||
|
s.Register(new(ExampleServer))
|
||||||
|
s.StartServer("localhost:2022")
|
||||||
|
}
|
@ -1,37 +1,6 @@
|
|||||||
package sshrpc
|
package sshrpc
|
||||||
|
|
||||||
import (
|
import "fmt"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ExampleServer struct{}
|
|
||||||
|
|
||||||
func (s *ExampleServer) Hello(name *string, out *string) error {
|
|
||||||
*out = fmt.Sprintf("Hello %s", *name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleServer_StartServer() {
|
|
||||||
|
|
||||||
s := NewServer()
|
|
||||||
privateBytes, err := ioutil.ReadFile("id_rsa")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to load private key (./id_rsa)")
|
|
||||||
}
|
|
||||||
|
|
||||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed to parse private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Config.AddHostKey(private)
|
|
||||||
s.Register(new(ExampleServer))
|
|
||||||
s.StartServer("localhost:2022")
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleClient_Connect() {
|
func ExampleClient_Connect() {
|
||||||
|
|
||||||
|
128
server.go
128
server.go
@ -3,9 +3,10 @@ package sshrpc
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
|
|
||||||
|
"dev.justinjudd.org/justin/easyssh"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -44,97 +45,50 @@ func NewServer() *Server {
|
|||||||
// StartServer starts the server listening for requests
|
// StartServer starts the server listening for requests
|
||||||
func (s *Server) StartServer(address string) {
|
func (s *Server) StartServer(address string) {
|
||||||
|
|
||||||
// Once a ServerConfig has been configured, connections can be accepted.
|
server := easyssh.Server{Addr: address, Config: s.Config}
|
||||||
listener, err := net.Listen("tcp", address)
|
handler := easyssh.NewSSHConnHandler()
|
||||||
if err != nil {
|
channelsMux := easyssh.NewChannelsMux()
|
||||||
log.Fatal("failed to listen on ", address)
|
channelsMux.HandleChannel(s.ChannelName, s)
|
||||||
}
|
handler.MultipleChannelsHandler = channelsMux
|
||||||
|
server.Handler = handler
|
||||||
|
server.ListenAndServe()
|
||||||
|
|
||||||
// Accept all connections
|
|
||||||
log.Print("listening on ", address)
|
|
||||||
for {
|
|
||||||
tcpConn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to accept incoming connection (%s)", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Before use, a handshake must be performed on the incoming net.Conn.
|
|
||||||
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.Config)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to handshake (%s)", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("new ssh connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())
|
|
||||||
// Print incoming out-of-band Requests
|
|
||||||
go s.handleRequests(reqs)
|
|
||||||
// Accept all channels
|
|
||||||
go s.handleChannels(chans, sshConn)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRequests handles global out-of-band SSH Requests
|
func (s *Server) HandleChannel(newChannel ssh.NewChannel, channel ssh.Channel, reqs <-chan *ssh.Request, sshConn ssh.Conn) {
|
||||||
func (s *Server) handleRequests(reqs <-chan *ssh.Request) {
|
go func(in <-chan *ssh.Request) {
|
||||||
for req := range reqs {
|
for req := range in {
|
||||||
log.Printf("recieved out-of-band request: %+v", req)
|
ok := false
|
||||||
}
|
switch req.Type {
|
||||||
}
|
|
||||||
|
|
||||||
// handleChannels handels SSH Channel requests and their local out-of-band SSH Requests
|
case easyssh.SubsystemRequest:
|
||||||
func (s *Server) handleChannels(chans <-chan ssh.NewChannel, sshConn ssh.Conn) {
|
ok = true
|
||||||
// Service the incoming Channel channel.
|
log.Printf("subsystem '%s'", req.Payload[4:])
|
||||||
for newChannel := range chans {
|
switch string(req.Payload[4:]) {
|
||||||
|
//RPCSubsystem Request made indicates client desires RPC Server access
|
||||||
log.Printf("Received channel: %v", newChannel.ChannelType())
|
case RPCSubsystem:
|
||||||
// Check the type of channel
|
go s.ServeConn(channel)
|
||||||
if t := newChannel.ChannelType(); t != s.ChannelName {
|
log.Printf("Started SSH RPC")
|
||||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
// triggers reverse RPC connection as well
|
||||||
continue
|
clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse")
|
||||||
}
|
if err != nil {
|
||||||
channel, requests, err := newChannel.Accept()
|
log.Printf("Failed to create client channel: " + err.Error())
|
||||||
if err != nil {
|
continue
|
||||||
log.Printf("could not accept channel (%s)", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Printf("Accepted channel")
|
|
||||||
|
|
||||||
// Channels can have out-of-band requests
|
|
||||||
go func(in <-chan *ssh.Request) {
|
|
||||||
for req := range in {
|
|
||||||
ok := false
|
|
||||||
switch req.Type {
|
|
||||||
|
|
||||||
case "subsystem":
|
|
||||||
ok = true
|
|
||||||
log.Printf("subsystem '%s'", req.Payload[4:])
|
|
||||||
switch string(req.Payload[4:]) {
|
|
||||||
//RPCSubsystem Request made indicates client desires RPC Server access
|
|
||||||
case RPCSubsystem:
|
|
||||||
go s.ServeConn(channel)
|
|
||||||
log.Printf("Started SSH RPC")
|
|
||||||
// triggers reverse RPC connection as well
|
|
||||||
clientChannel, err := openRPCClientChannel(sshConn, s.ChannelName+"-reverse")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to create client channel: " + err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rpcClient := rpc.NewClient(clientChannel)
|
|
||||||
if s.CallbackFunc != nil {
|
|
||||||
s.CallbackFunc(rpcClient, sshConn)
|
|
||||||
}
|
|
||||||
log.Printf("Started SSH RPC client")
|
|
||||||
default:
|
|
||||||
log.Printf("Unknown subsystem: %s", req.Payload)
|
|
||||||
}
|
}
|
||||||
|
rpcClient := rpc.NewClient(clientChannel)
|
||||||
|
if s.CallbackFunc != nil {
|
||||||
|
s.CallbackFunc(rpcClient, sshConn)
|
||||||
|
}
|
||||||
|
log.Printf("Started SSH RPC client")
|
||||||
|
default:
|
||||||
|
log.Printf("Unknown subsystem: %s", req.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
log.Printf("declining %s request...", req.Type)
|
|
||||||
}
|
|
||||||
req.Reply(ok, nil)
|
|
||||||
}
|
}
|
||||||
}(requests)
|
if !ok {
|
||||||
}
|
log.Printf("declining %s request...", req.Type)
|
||||||
|
}
|
||||||
|
req.Reply(ok, nil)
|
||||||
|
}
|
||||||
|
}(reqs)
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ func (s *AdvancedServer) SortInts(req *AdvancedType, out *AdvancedType) error {
|
|||||||
|
|
||||||
func TestAdvancedServer(t *testing.T) {
|
func TestAdvancedServer(t *testing.T) {
|
||||||
s := NewServer()
|
s := NewServer()
|
||||||
s.Subsystem = "Advanced"
|
s.ChannelName = "Advanced"
|
||||||
|
|
||||||
private, err := ssh.ParsePrivateKey(testdata.ServerRSAKey)
|
private, err := ssh.ParsePrivateKey(testdata.ServerRSAKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -84,7 +84,7 @@ func TestAdvancedServer(t *testing.T) {
|
|||||||
|
|
||||||
t.Log("preparing client")
|
t.Log("preparing client")
|
||||||
client := NewClient()
|
client := NewClient()
|
||||||
client.Subsystem = "Advanced"
|
client.ChannelName = "Advanced"
|
||||||
client.Connect("localhost:3022")
|
client.Connect("localhost:3022")
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
var reply AdvancedType
|
var reply AdvancedType
|
||||||
|
Loading…
Reference in New Issue
Block a user