//
// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

package main

import (
	"bufio"
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"

	"golang.org/x/crypto/ssh"
)

// sshserver runs the SSH front end. It builds the server configuration,
// loads the host key stored under the "sshprivatekey" config entry, listens
// on the address stored under "sshlistenaddr", and hands every accepted
// connection to servessh in its own goroutine.
//
// Two authentication methods are offered:
//   - password: only the user "anon" with an empty password is accepted;
//   - public key: the offered key must be byte-for-byte equal to the key
//     stored for the user returned by getuser.
//
// The function does not return until Accept fails; an unparsable host key or
// a failure to listen is fatal.
func sshserver() {
	noPass := fmt.Errorf("no password auth")
	noKey := fmt.Errorf("unauthorized key")
	config := &ssh.ServerConfig{
		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
			// Anonymous access: fixed user name, empty password.
			if c.User() == "anon" && len(pass) == 0 {
				return nil, nil
			}
			return nil, noPass
		},

		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
			dlog.Printf("checking pubkey")
			u := getuser(c.User())
			if u == nil || u.sshKey == nil {
				return nil, noKey
			}
			// Compare the wire encodings of the offered and the stored key.
			if !bytes.Equal(pubKey.Marshal(), u.sshKey.Marshal()) {
				return nil, noKey
			}
			ilog.Printf("login for %s", c.User())
			return &ssh.Permissions{Extensions: map[string]string{
				"user-name": u.Name,
			}}, nil
		},
		//NoClientAuth: true,
	}

	var privateKeyStr []byte
	getconfig("sshprivatekey", &privateKeyStr)
	servKey, err := ssh.ParsePrivateKey(privateKeyStr)
	if err != nil {
		elog.Fatalf("Failed to parse private key: %s", err)
	}

	config.AddHostKey(servKey)

	var addr string
	getconfig("sshlistenaddr", &addr)
	ilog.Printf("starting ssh server on %s", addr)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		elog.Fatalf("failed to listen for connection: %s", err)
	}
	// Release the listening socket when the accept loop gives up; without
	// this the port stays bound for the rest of the process lifetime.
	defer listener.Close()

	for {
		netConn, err := listener.Accept()
		if err != nil {
			elog.Printf("failed to accept incoming connection: %s", err)
			break
		}
		go servessh(netConn, config)
	}
}

// servessh performs the SSH handshake on netConn using config and then
// dispatches channels for the lifetime of the connection. Global requests
// are discarded, channels of any type other than "session" are rejected,
// and each accepted session is handled by servesshsession in its own
// goroutine. netConn is closed when the function returns.
func servessh(netConn net.Conn, config *ssh.ServerConfig) {
	defer netConn.Close()

	sconn, newChans, globalReqs, err := ssh.NewServerConn(netConn, config)
	if err != nil {
		elog.Printf("failed to handshake %s: %s", netConn.RemoteAddr(), err)
		return
	}
	username := sconn.User()
	ilog.Printf("logged in as user %s", username)

	// Global (connection level) requests have to be drained.
	go ssh.DiscardRequests(globalReqs)

	// Handle channel open requests until the connection goes away.
	for nc := range newChans {
		// Only "session" channels are supported.
		if kind := nc.ChannelType(); kind != "session" {
			elog.Printf("not dealing with %s", kind)
			nc.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}

		session, sessionReqs, err := nc.Accept()
		if err != nil {
			elog.Printf("Could not accept channel: %v", err)
			return
		}
		go servesshsession(username, session, sessionReqs)
	}
}

// servesshsession handles one accepted "session" channel. It waits for an
// "exec" request, requires the command line to be exactly
// "hg -R <repo> serve --stdio", and then relays wire protocol commands
// between the SSH channel and the connection returned by getwireconn for
// that repo.
//
// username is the authenticated SSH user; it is passed to canwrite to decide
// whether "unbundle" is permitted. requests is the channel's request stream.
// The channel is closed when the function returns.
//
// Everything read from the channel is untrusted, so malformed input ends the
// session instead of panicking (a panic in this goroutine would take down the
// whole process).
func servesshsession(username string, channel ssh.Channel, requests <-chan *ssh.Request) {
	defer channel.Close()

	// We're looking for an exec command
	var cmdline string
	for req := range requests {
		if req.Type != "exec" {
			req.Reply(false, nil)
			continue
		}
		// The exec payload is an SSH string: a 4 byte length followed by the
		// command text. Anything shorter cannot be sliced safely.
		if len(req.Payload) < 4 {
			req.Reply(false, nil)
			ilog.Printf("malformed exec request")
			return
		}
		req.Reply(true, nil)
		cmdline = string(req.Payload[4:])
		break
	}
	// The request channel must keep being serviced after the exec request,
	// otherwise later requests on this channel can stall the connection.
	go ssh.DiscardRequests(requests)

	// strings.Split always yields at least one element, so args[0] is safe.
	args := strings.Split(cmdline, " ")
	switch {
	case args[0] != "hg":
		ilog.Printf("command is not hg")
		return
	case len(args) < 5:
		// Previously a short command line skipped the remaining checks.
		ilog.Printf("too few arguments: %q", cmdline)
		return
	case len(args) > 5:
		ilog.Printf("too many arguments: %s", cmdline)
		return
	case args[1] != "-R":
		ilog.Printf("argument is not -R")
		return
	case args[3] != "serve":
		ilog.Printf("argument is not serve")
		return
	case args[4] != "--stdio":
		ilog.Printf("argument is not --stdio")
		return
	}
	reponame := args[2]
	ilog.Printf("connecting to %s", reponame)

	hgconn := getwireconn(reponame)
	if hgconn == nil {
		channel.Stderr().Write([]byte("not found\n"))
		return
	}
	defer putbackconn(hgconn)

	wpipe := bufio.NewWriter(hgconn.Wpipe)
	rpipe := bufio.NewReader(hgconn.Rpipe)

	rnet := bufio.NewReader(channel)
	wnet := bufio.NewWriter(channel)

	// copyargs forwards nargs arguments from the client to the backend. Each
	// argument is a header line "<name> <n>\n"; for the name "*" it is
	// followed by n nested arguments, otherwise by n bytes of data.
	var copyargs func(int) error
	copyargs = func(nargs int) error {
		for i := 0; i < nargs; i++ {
			l, err := rnet.ReadString('\n')
			if err != nil {
				return err
			}
			fields := strings.Split(l[:len(l)-1], " ")
			if len(fields) < 2 {
				return fmt.Errorf("malformed argument header %q", l)
			}
			n, err := strconv.Atoi(fields[1])
			if err != nil || n < 0 {
				return fmt.Errorf("bad argument length %q", fields[1])
			}
			if _, err := io.WriteString(wpipe, l); err != nil {
				return err
			}
			if fields[0] == "*" {
				err = copyargs(n)
			} else {
				_, err = io.CopyN(wpipe, rnet, int64(n))
			}
			if err != nil {
				return err
			}
		}
		return nil
	}

	readonly := !canwrite(username, reponame)

	for {
		cmd, err := rnet.ReadString('\n')
		if err != nil {
			return
		}
		cmd = cmd[:len(cmd)-1]
		// NOTE(review): only a top-level "unbundle" is refused for read-only
		// users. Whether anything forwarded inside "batch" can modify the
		// repository is not visible from this file -- confirm.
		if readonly && cmd == "unbundle" {
			channel.Stderr().Write([]byte("no writing\n"))
			return
		}
		io.WriteString(wpipe, cmd)
		io.WriteString(wpipe, "\n")
		if cmd == "getbundle" || cmd == "unbundle" {
			// Streaming commands: stop framing and copy raw bytes in both
			// directions until the client side reaches EOF.
			wpipe.Flush()
			done := make(chan bool)
			go func() {
				io.Copy(channel, rpipe)
				done <- true
			}()
			io.Copy(hgconn.Wpipe, rnet)
			hgconn.Wpipe.Close()
			hgconn.Wpipe = nil
			<-done
			if cmd == "unbundle" {
				ilog.Printf("push to %s", reponame)
				recache(reponame)
			}
			return
		}

		// Number of arguments that follow each supported command.
		nargs := 1
		switch cmd {
		case "hello", "branchmap", "heads":
			nargs = 0
		case "batch":
			nargs = 2
		case "between", "protocaps", "listkeys", "known", "lookup":
		default:
			elog.Printf("unknown cmd")
			return
		}
		if err := copyargs(nargs); err != nil {
			elog.Printf("failed to copy arguments for %s: %s", cmd, err)
			return
		}
		if err := wpipe.Flush(); err != nil {
			elog.Printf("failed to send %s to backend: %s", cmd, err)
			return
		}

		// Relay the length prefixed reply. A failed read must end the
		// session; otherwise the client would be told the reply was empty.
		datalen, err := readamt(rpipe)
		if err != nil {
			elog.Printf("failed to read reply length for %s: %s", cmd, err)
			return
		}
		fmt.Fprintf(wnet, "%d\n", datalen)
		if _, err := io.CopyN(wnet, rpipe, int64(datalen)); err != nil {
			elog.Printf("failed to relay reply for %s: %s", cmd, err)
			return
		}
		if err := wnet.Flush(); err != nil {
			return
		}
	}
}

// encodekey returns the PEM encoding of an RSA key.
//
// A *rsa.PrivateKey is encoded as a PKCS #1 "RSA PRIVATE KEY" block and a
// *rsa.PublicKey as a PKIX "PUBLIC KEY" block. Any other value yields an
// error naming the offending type.
func encodekey(i interface{}) (string, error) {
	var b pem.Block
	switch k := i.(type) {
	case *rsa.PrivateKey:
		b.Type = "RSA PRIVATE KEY"
		b.Bytes = x509.MarshalPKCS1PrivateKey(k)
	case *rsa.PublicKey:
		der, err := x509.MarshalPKIXPublicKey(k)
		if err != nil {
			return "", err
		}
		b.Type = "PUBLIC KEY"
		b.Bytes = der
	default:
		// %T reports the dynamic type; %s on an arbitrary value produced
		// "%!s(...)" noise instead of a usable message.
		return "", fmt.Errorf("unknown key type: %T", k)
	}
	return string(pem.EncodeToMemory(&b)), nil
}

// sshgenhostkey generates a fresh 3072 bit RSA key pair, PEM-encodes both
// halves with encodekey, and stores them under the "sshprivatekey" and
// "sshpublickey" config entries. Any failure panics.
func sshgenhostkey() {
	must := func(err error) {
		if err != nil {
			panic(err)
		}
	}

	key, err := rsa.GenerateKey(rand.Reader, 3072)
	must(err)

	public, err := encodekey(&key.PublicKey)
	must(err)

	private, err := encodekey(key)
	must(err)

	setconfig("sshprivatekey", private)
	setconfig("sshpublickey", public)
}
