//
// Copyright (c) 2023 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

package main

import (
	"bytes"
	"database/sql"
	"encoding/gob"
	notrand "math/rand"
	"strings"
	"sync"
	"time"

	"humungus.tedunangst.com/r/webs/gate"
)

// Delivery is a batch of outgoing messages to a single recipient, as held
// in (or about to be placed in) the retry queue.
type Delivery struct {
	// ID identifies the queued row; extractdelivery uses it to load and
	// delete that row.
	ID    int64
	// When is the time the queued delivery becomes due; deliveryManager
	// redelivers it once this is in the past.
	When  time.Time
	// Tries counts attempts so far; tryitagain increments it to pick the
	// backoff and abandons the delivery once it exceeds maxTries.
	Tries int64
	// Rcpt is either a recipient to resolve with findInbox or, when it
	// starts with '%', an inbox that has already been resolved.
	Rcpt  string
	// Msgs are the pending messages, delivered in slice order.
	Msgs  []MsgPair
}

// MsgPair is one pending outgoing message.
type MsgPair struct {
	// From names the sender; doDelivery passes it to getseckey to find the
	// signing key for the post.
	From string
	// Data is the payload posted to the recipient's inbox.
	Data []byte
}

// Limits on the retry queue.
const (
	// maxQueueLen is how many of the most recent messages tryitagain keeps
	// for a recipient once a delivery is on its tenth or later attempt.
	maxQueueLen = 40
	// maxTries is the attempt count beyond which tryitagain gives up on a
	// delivery entirely.
	maxTries = 20
)

// unpackMsgs decodes a gob-encoded []MsgPair, the format written by
// repackMsgs. The decoded slice is returned together with any decode error.
func unpackMsgs(data []byte) ([]MsgPair, error) {
	var out []MsgPair
	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&out); err != nil {
		return out, err
	}
	return out, nil
}

// repackMsgs gob-encodes msgs for storage in the delivery queue. The buffer
// contents are returned together with any encode error.
func repackMsgs(msgs []MsgPair) ([]byte, error) {
	buf := new(bytes.Buffer)
	err := gob.NewEncoder(buf).Encode(msgs)
	return buf.Bytes(), err
}

// tryitagain requeues a failed delivery with a growing backoff:
//   - attempts 1-3 wait 5, 10, 15 minutes;
//   - attempts 4-6 wait 1, 2, 3 hours;
//   - attempts 7 through maxTries wait 12 hours.
//
// Each wait gets up to 10% random jitter added. From the tenth attempt on,
// only the last maxQueueLen messages are kept. Past maxTries the delivery is
// abandoned and dbClearFollows is called for the recipient. After queueing,
// deliveryManager is poked without blocking. Nothing happens in develMode.
func tryitagain(delivery Delivery) {
	if develMode {
		return
	}
	delivery.Tries++
	tries := delivery.Tries

	var backoff time.Duration
	switch {
	case tries <= 3: // 5, 10, 15 minutes
		backoff = time.Duration(tries*5) * time.Minute
	case tries <= 6: // 1, 2, 3 hours
		backoff = time.Duration(tries-3) * time.Hour
	case tries <= 9: // 12 hours
		backoff = time.Duration(12) * time.Hour
	case tries <= maxTries: // 12 hours, roughly 7 days in total
		backoff = time.Duration(12) * time.Hour
		if n := len(delivery.Msgs); n > maxQueueLen {
			delivery.Msgs = delivery.Msgs[n-maxQueueLen:]
		}
	default:
		ilog.Printf("delivery has perma failed: %s", delivery.Rcpt)
		dbClearFollows(delivery.Rcpt)
		return
	}

	// jitter in [0, backoff/10) spreads out retries scheduled together
	jitter := time.Duration(notrand.Int63n(int64(backoff / 10)))
	due := time.Now().Add(backoff + jitter)

	blob, err := repackMsgs(delivery.Msgs)
	if err != nil {
		elog.Printf("error encoding delivery: %s", err)
		return
	}
	stamp := due.UTC().Format(dbtimeformat)
	if _, err := stmtAddDelivery.Exec(stamp, delivery.Tries, delivery.Rcpt, blob); err != nil {
		elog.Printf("error saving delivery: %s", err)
	}

	// non-blocking: if a poke is already pending, this one coalesces into it
	select {
	case pokechan <- 0:
	default:
	}
}

// abortdelivery returns a minimum attempt count to apply after a failed post.
// For a DNS "no such host" failure it returns maxTries-2, which doDelivery
// uses to raise delivery.Tries so the delivery is abandoned after a few more
// attempts. Any other error yields 0, which leaves Tries unchanged.
func abortdelivery(err error) int64 {
	if !strings.Contains(err.Error(), "no such host") {
		return 0
	}
	return maxTries - 2
}

// semiacceptable reports whether a post failed with an HTTP 400 or 422
// status. doDelivery skips such a message and moves on instead of requeueing
// the delivery.
func semiacceptable(err error) bool {
	msg := err.Error()
	for _, tolerated := range []string{
		"http post status: 400",
		"http post status: 422",
	} {
		if strings.Contains(msg, tolerated) {
			return true
		}
	}
	return false
}

// dqmtx serializes the read-modify-write of queued delivery rows. delinquent
// holds it while appending to a queued batch, and extractdelivery holds it
// while claiming a batch for redelivery.
var dqmtx sync.Mutex

// delinquent checks whether rcpt already has a queued delivery. If it does,
// the new message is appended to that queued batch instead of being sent
// now, and delinquent returns true. In that case the caller should do
// nothing further.
//
// It returns false when nothing is queued for rcpt, meaning the caller
// should deliver the message itself. It also returns false when the updated
// batch could not be encoded. Then the existing queue is left untouched
// rather than being overwritten with a bad encoding, and the new message is
// not dropped.
//
// A failed lookup or a failed update is logged and reported as true, so the
// caller does not deliver the message.
func delinquent(from string, rcpt string, msg []byte) bool {
	dqmtx.Lock()
	defer dqmtx.Unlock()
	var deliveryid int64
	var data []byte
	err := stmtDelinquentCheck.QueryRow(rcpt).Scan(&deliveryid, &data)
	if err == sql.ErrNoRows {
		return false
	}
	if err != nil {
		elog.Printf("error scanning deliquent check: %s", err)
		return true
	}
	msgs, err := unpackMsgs(data)
	if err != nil {
		// Undecodable queue contents are replaced below by whatever was
		// decoded plus the new message.
		elog.Printf("error unpacking messages: %s", err)
	}
	msgs = append(msgs, MsgPair{From: from, Data: msg})
	data, err = repackMsgs(msgs)
	if err != nil {
		// Writing a failed encoding would clobber every queued message;
		// keep the queue as is and let the caller try this one directly.
		elog.Printf("error repacking messages: %s", err)
		return false
	}
	if _, err := stmtDelinquentUpdate.Exec(data, deliveryid); err != nil {
		elog.Printf("error updating deliquent: %s", err)
	}
	return true
}

// deliver sends msg, signed on behalf of from, to rcpt. If delinquent
// reports the message as handled, nothing more is done here. Normally that
// means it was appended to rcpt's existing queued delivery. Otherwise a
// single-message delivery is attempted right away.
func deliver(from string, rcpt string, msg []byte) {
	if handled := delinquent(from, rcpt, msg); !handled {
		doDelivery(Delivery{
			Rcpt: rcpt,
			Msgs: []MsgPair{{From: from, Data: msg}},
		})
	}
}

// garage throttles outgoing deliveries. doDelivery holds the recipient's key
// in it for the whole delivery. NOTE(review): presumably 40 is the
// concurrency limit; confirm against gate.NewLimiter.
var garage = gate.NewLimiter(40)

// doDelivery posts each message in delivery to the recipient's inbox, in
// order, pausing two seconds between messages. Nothing happens unless
// enableFedi is set.
//
// Some messages are skipped and not retried:
//   - messages whose sender key cannot be found;
//   - messages whose post fails with a semiacceptable error.
//
// On the first other failure, the failed message and everything after it
// are handed to tryitagain for a later retry. If no inbox can be determined
// for the recipient, the delivery is dropped with a log line.
func doDelivery(delivery Delivery) {
	if !enableFedi {
		return
	}
	requestWG.Add(1)
	defer requestWG.Done()
	rcpt := delivery.Rcpt
	garage.StartKey(rcpt)
	defer garage.FinishKey(rcpt)

	var inbox string
	switch {
	case strings.HasPrefix(rcpt, "%"):
		// already did the box indirection
		inbox = rcpt[1:]
	case rcpt != "":
		inbox = findInbox(rcpt, false)
	}
	// Also rejects an empty recipient (rcpt[0] would have panicked) and a
	// bare "%", rather than posting to an empty inbox.
	if inbox == "" {
		ilog.Printf("no inbox for delivery to %s", rcpt)
		return
	}
	for i, msg := range delivery.Msgs {
		if i > 0 {
			time.Sleep(2 * time.Second)
		}
		keyname, seckey := getseckey(msg.From)
		if keyname == "" {
			elog.Printf("lost key for delivery")
			continue
		}
		err := PostMsg(keyname, seckey, inbox, msg.Data)
		if err != nil {
			ilog.Printf("failed to post json to %s: %s", inbox, err)
			// some errors (e.g. DNS failure) fast-forward toward giving up
			if t := abortdelivery(err); t > delivery.Tries {
				delivery.Tries = t
			}
			if semiacceptable(err) {
				continue
			}
			// requeue the failed message and everything after it
			delivery.Msgs = delivery.Msgs[i:]
			tryitagain(delivery)
			return
		}
	}
}

// pokechan wakes deliveryManager early after tryitagain queues a retry.
// Because it has a one-slot buffer and tryitagain sends without blocking,
// repeated pokes coalesce into one.
var pokechan = make(chan int, 1)

// getdeliveries lists the queued deliveries, filling in only ID and When.
// extractdelivery loads the rest of a row once it is due. Rows that fail to
// scan are logged and skipped. If the query itself fails, the error is
// logged and nil is returned after a one-minute pause.
func getdeliveries() []Delivery {
	rows, err := stmtGetDeliveries.Query()
	if err != nil {
		elog.Printf("error querying deliveries: %s", err)
		time.Sleep(1 * time.Minute)
		return nil
	}
	defer rows.Close()
	var deliveries []Delivery
	for rows.Next() {
		var d Delivery
		var dt string
		if err := rows.Scan(&d.ID, &dt); err != nil {
			elog.Printf("error scanning delivery: %s", err)
			continue
		}
		if d.When, err = time.Parse(dbtimeformat, dt); err != nil {
			// d.When is left as the zero time, which is already in the
			// past, so the entry is still picked up for redelivery.
			elog.Printf("error parsing delivery time %q: %s", dt, err)
		}
		deliveries = append(deliveries, d)
	}
	// rows.Next returning false may hide an iteration error
	if err := rows.Err(); err != nil {
		elog.Printf("error reading deliveries: %s", err)
	}
	return deliveries
}

// extractdelivery claims the queued row d.ID. Under dqmtx it loads the row's
// Tries, Rcpt and messages into d, and deletes the row from the queue.
// The row is deleted before its messages are decoded, so a row that cannot
// be decoded is discarded and the decode error is returned.
func extractdelivery(d *Delivery) error {
	dqmtx.Lock()
	defer dqmtx.Unlock()

	var blob []byte
	if err := stmtLoadDelivery.QueryRow(d.ID).Scan(&d.Tries, &d.Rcpt, &blob); err != nil {
		return err
	}
	if _, err := stmtZapDelivery.Exec(d.ID); err != nil {
		return err
	}
	msgs, err := unpackMsgs(blob)
	d.Msgs = msgs
	return err
}

// deliveryManager runs forever, retrying queued deliveries. It wakes when its
// timer fires or when pokechan is signalled (after an extra 5s pause). Then
// it loads the queue, claims and redelivers every entry whose time has
// passed, and sleeps until 5s after the earliest remaining entry. That sleep
// is capped at roughly 24 hours.
func deliveryManager() {
	sleeper := time.NewTimer(5 * time.Second)
	for {
		select {
		case <-pokechan:
			// Stop the pending timer; if it had already fired, drain the
			// value so the Reset below starts from an empty channel.
			if !sleeper.Stop() {
				<-sleeper.C
			}
			time.Sleep(5 * time.Second)
		case <-sleeper.C:
		}

		deliveries := getdeliveries()

		now := time.Now()
		// Default wake-up if nothing is scheduled sooner.
		nexttime := now.Add(24 * time.Hour)
		for _, d := range deliveries {
			if d.When.Before(now) {
				// Claim the row (load it and remove it from the queue) first.
				err := extractdelivery(&d)
				if err != nil {
					elog.Printf("error extracting delivery: %s", err)
					continue
				}
				ilog.Printf("redelivering %s try %d", d.Rcpt, d.Tries)
				// Runs synchronously, so due deliveries go one at a time.
				doDelivery(d)
			} else if d.When.Before(nexttime) {
				nexttime = d.When
			}
		}
		// Redeliveries may have taken a while; measure from the current time.
		now = time.Now()
		dur := 5 * time.Second
		if now.Before(nexttime) {
			dur += nexttime.Sub(now).Round(time.Second)
		}
		sleeper.Reset(dur)
	}
}
