//
// Copyright (c) 2018 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

package main

import (
	"bytes"
	"compress/zlib"
	"encoding/binary"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"

	"github.com/gorilla/mux"
	"humungus.tedunangst.com/r/gerc"
)

// Functions to talk to an hg wire server, mostly to support pull.

// readamt reads one '\n'-terminated decimal length line from rpipe and
// returns its value. The newline is consumed; nothing after it is.
//
// It returns an error if the stream fails before the newline, if the line
// is not a valid integer, if it is implausibly long, or if the value is
// negative. A negative length could never be used to size a reply buffer.
func readamt(rpipe io.Reader) (int, error) {
	// Enough digits for any int64; anything longer is a corrupt stream.
	const maxdigits = 20
	var onebyte [1]byte
	lenbuf := make([]byte, 0, 10)
	for {
		// Read a single byte at a time so the payload that follows the
		// newline stays in the pipe for the caller. io.ReadFull honours the
		// io.Reader contract: a byte delivered together with an error is
		// still returned, and a 0, nil read is retried rather than being
		// mistaken for end of line.
		if _, err := io.ReadFull(rpipe, onebyte[:]); err != nil {
			return 0, err
		}
		if onebyte[0] == '\n' {
			break
		}
		if len(lenbuf) >= maxdigits {
			return 0, fmt.Errorf("length line too long")
		}
		lenbuf = append(lenbuf, onebyte[0])
	}
	amt, err := strconv.Atoi(string(lenbuf))
	if err != nil {
		return 0, err
	}
	if amt < 0 {
		return 0, fmt.Errorf("negative length %d", amt)
	}
	return amt, nil
}

var (
	// badbundle is the error copybundle reports when the changegroup
	// coming back from the wire server cannot be read or copied intact.
	badbundle = fmt.Errorf("bad bundle")
)

// copybundle copies a bundle from r to w chunk by chunk, unchanged.
//
// Each chunk starts with a 4-byte big-endian signed length that counts the
// length field itself, followed by length-4 payload bytes. A zero length is
// an empty marker chunk. Copying stops after two consecutive zero-length
// chunks.
//
// It returns badbundle if r is truncated or carries a length too small to
// hold its own header. It returns a wrapped error if writing a header to w
// fails.
func copybundle(w io.Writer, r io.Reader) error {
	var hdr [4]byte
	prevzero := false
	for {
		if _, err := io.ReadFull(r, hdr[:]); err != nil {
			elog.Println("error reading bundle len", err)
			return badbundle
		}
		chunklen := int32(binary.BigEndian.Uint32(hdr[:]))
		if chunklen != 0 && chunklen < 4 {
			// Too short to contain its own header. Carrying on would read
			// payload bytes as lengths and silently desync the stream.
			elog.Println("invalid bundle chunk len", chunklen)
			return badbundle
		}
		if _, err := w.Write(hdr[:]); err != nil {
			return fmt.Errorf("writing bundle chunk header: %w", err)
		}
		if chunklen == 0 {
			// it seems bundle ends with two zero chunks
			if prevzero {
				return nil
			}
			prevzero = true
			continue
		}
		prevzero = false
		if payload := int64(chunklen) - 4; payload > 0 {
			if _, err := io.CopyN(w, r, payload); err != nil {
				elog.Println("error copying bundle chunk", err)
				return badbundle
			}
		}
	}
}

// pullhandler answers an hg "getbundle" request for reponame and sends the
// bundle as a zlib-compressed response body.
//
// It rejects the request with 429 when toobusy says so. It requires exactly
// two arguments; otherwise it returns 404. It writes the command and its
// arguments to a wire connection from getwireconn and copies the returned
// bundle through a zlib writer into memory before sending it.
//
// If the bundle cannot be copied, the wire connection has lost sync. Its
// write pipe is closed and cleared before the connection is handed back,
// and the client gets a 500.
func pullhandler(w http.ResponseWriter, r *http.Request, reponame string, args url.Values) {
	ilog.Printf("pulling from %s", reponame)
	remote := r.Header.Get("X-Forwarded-For")
	if toobusy(reponame, remote) {
		http.Error(w, "chill the fuck out", http.StatusTooManyRequests)
		return
	}

	// Validate before taking a wire connection, so bad requests never tie
	// one up.
	if len(args) != 2 {
		elog.Printf("cmd with missing arg")
		http.NotFound(w, r)
		return
	}

	conn := getwireconn(reponame)
	if conn == nil {
		http.NotFound(w, r)
		return
	}
	defer putbackconn(conn)

	wpipe := conn.Wpipe
	io.WriteString(wpipe, "getbundle\n")
	fmt.Fprintf(wpipe, "* %d\n", len(args))
	for k, v0 := range args {
		v := v0[0]
		fmt.Fprintf(wpipe, "%s %d\n%s", k, len(v), v)
	}

	var bigbuf bytes.Buffer
	zwriter := zlib.NewWriter(&bigbuf)
	if err := copybundle(zwriter, conn.Rpipe); err != nil {
		elog.Printf("problem copying chunk: %s", err)
		// it's broken, we've lost sync
		conn.Wpipe.Close()
		conn.Wpipe = nil
		http.Error(w, "bundle error", http.StatusInternalServerError)
		return
	}
	// Close, not just Flush. Close emits the final deflate block and the
	// adler32 trailer, so the client receives a complete zlib stream rather
	// than one cut off at a sync point.
	if err := zwriter.Close(); err != nil {
		elog.Printf("problem compressing bundle: %s", err)
		http.Error(w, "bundle error", http.StatusInternalServerError)
		return
	}

	pulldata := bigbuf.Bytes()
	w.Header().Set("Content-Type", "application/mercurial-0.1")
	w.Header().Set("Content-Length", strconv.Itoa(len(pulldata)))
	w.Write(pulldata)
}

// hgpassthru serves the hg HTTP wire protocol for the repository named by
// the "reponame" route variable.
//
// The command comes from the "cmd" form value. Its arguments come from the
// X-Hgarg-1..3 headers, which are concatenated and parsed as a query string.
//
// Commands are handled as follows:
//   - getbundle is delegated to pullhandler.
//   - capabilities, listkeys for phases and bookmarks, and a heads/known
//     batch that matches tip are answered locally.
//   - unbundle (push) is refused with 403.
//   - Unknown or malformed commands get 404.
//   - Everything else is forwarded to a wire connection from getwireconn,
//     and its length-prefixed reply is relayed verbatim.
func hgpassthru(w http.ResponseWriter, r *http.Request) {
	reponame := mux.Vars(r)["reponame"]

	r.ParseForm()
	cmd := r.FormValue("cmd")
	// only the first three argument headers are read
	xarg := r.Header.Get("X-Hgarg-1") + r.Header.Get("X-Hgarg-2") + r.Header.Get("X-Hgarg-3")
	args, err := url.ParseQuery(xarg)
	if err != nil {
		elog.Printf("failure to parse x-arg: %s: %s", xarg, err)
		http.NotFound(w, r)
		return
	}

	repo := getgerc(reponame)
	if repo == nil {
		http.NotFound(w, r)
		return
	}
	defer putgerc(repo)

	ilog.Printf("request for %s: '%s' '%s' %d", reponame, cmd, xarg, len(args))

	// reply sends data as the complete mercurial-0.1 response body.
	reply := func(data []byte) {
		w.Header().Set("Content-Type", "application/mercurial-0.1")
		w.Header().Set("Content-Length", strconv.Itoa(len(data)))
		w.Write(data)
	}

	switch cmd {
	case "getbundle":
		pullhandler(w, r, reponame, args)
		return
	case "capabilities":
		reply([]byte("batch branchmap changegroupsubset getbundle known lookup protocaps pushkey streamreqs=generaldelta,revlogv1,sparserevlog unbundle=HG10GZ,HG10BZ,HG10UN unbundlehash httpheader=1024"))
		return
	case "batch":
		if len(args) != 1 {
			elog.Printf("bad batch cmd")
			http.NotFound(w, r)
			return
		}
		cmds := args.Get("cmds")
		// Answer a heads query, optionally combined with known(tip), from
		// the local repo without touching the wire server.
		changes, err := repo.GetChanges(gerc.ChangesArgs{Revisions: "tip"})
		// guard against an empty result before indexing it
		if err == nil && len(changes) > 0 {
			head := changes[0]
			headnode := fmt.Sprintf("%x", head.NodeID[:20])
			expect1 := "heads ;known nodes=" + headnode
			const expect2 = "heads ;known nodes="
			if cmds == expect1 || cmds == expect2 {
				// heads reply, ';' separator, then the known reply:
				// "1" when tip was queried, empty when no node was
				data := []byte(headnode + "\n;")
				if cmds == expect1 {
					data = append(data, '1')
				}
				reply(data)
				return
			}
		}
		// any other batch is forwarded, but only if it contains nothing
		// except heads and known subcommands
		for _, c := range strings.Split(cmds, ";") {
			if c != "heads " && !strings.HasPrefix(c, "known nodes=") {
				elog.Printf("bad batch cmd %s", c)
				http.NotFound(w, r)
				return
			}
		}
	case "listkeys":
		if len(args) != 1 {
			elog.Printf("cmd with missing arg")
			http.NotFound(w, r)
			return
		}
		switch args.Get("namespace") {
		case "phases":
			reply([]byte("publishing\tTrue"))
			return
		case "bookmarks":
			reply([]byte(""))
			return
		}
		// other namespaces are forwarded
	case "hello", "branchmap", "heads":
		if len(args) != 0 {
			elog.Printf("bad zero arg cmd")
			http.NotFound(w, r)
			return
		}
	case "known", "lookup":
		if len(args) != 1 {
			elog.Printf("cmd with missing arg")
			http.NotFound(w, r)
			return
		}
	case "unbundle":
		ilog.Printf("push attempt")
		http.Error(w, "push denied", http.StatusForbidden)
		return
	default:
		ilog.Printf("did not know that command")
		http.NotFound(w, r)
		return
	}

	conn := getwireconn(reponame)
	if conn == nil {
		http.NotFound(w, r)
		return
	}
	defer putbackconn(conn)

	wpipe := conn.Wpipe
	io.WriteString(wpipe, cmd)
	io.WriteString(wpipe, "\n")
	if len(xarg) > 0 {
		switch cmd {
		case "batch", "known":
			// these are sent an empty "*" entry ahead of their named args
			io.WriteString(wpipe, "* 0\n")
		}
		for k, v0 := range args {
			v := v0[0]
			fmt.Fprintf(wpipe, "%s %d\n%s", k, len(v), v)
		}
	}

	// The reply is a decimal length line followed by that many bytes.
	rpipe := conn.Rpipe
	datalen, err := readamt(rpipe)
	if err == nil && datalen < 0 {
		err = fmt.Errorf("negative reply length %d", datalen)
	}
	var data []byte
	if err == nil {
		data = make([]byte, datalen)
		_, err = io.ReadFull(rpipe, data)
	}
	if err != nil {
		elog.Printf("problem reading reply to %s: %s", cmd, err)
		// it's broken, we've lost sync
		conn.Wpipe.Close()
		conn.Wpipe = nil
		http.Error(w, "wire error", http.StatusInternalServerError)
		return
	}
	ilog.Printf("reply: %q", data)
	reply(data)
}
