#include "dns.h"
#include "str.h"
#include "byte.h"
#include "uint16.h"
#include <unistd.h>
#include "error.h"
#include "case.h"
#include "base32.h"
#include "hexparse.h"
#include "strerr.h"

/* Binary (16-byte) form of the 32-hex-digit extension label; written by
   curveresolve_parse through hexparse(binext,16,...). */
char unsigned binext[16];

/* File-scope scratch pointers for wire-format domain names filled by
   dns_packet_getname in the packet parsers of this file:
   d  - target name of a CNAME/NS record,
   d1 - the query name of a response,
   d2 - the owner name of an authority-section record. */
static char *d = 0;
static char *d1 = 0;
static char *d2 = 0;

/*
 * txtparse: decode backslash escapes of sa in place.
 *
 *   "\D", "\DD", "\DDD" (octal digits 0-7, at most three) -> byte of that value
 *   "\x" for any other character x                         -> x
 *   a lone backslash at the very end                       -> dropped
 *
 * Decoding never lengthens the data, so it is written back over the input;
 * sa->len is set to the decoded length.
 */
void txtparse(stralloc *sa)
{
  unsigned int rd = 0;   /* read cursor */
  unsigned int wr = 0;   /* write cursor, never ahead of rd */
  unsigned int ndigits;
  char c;

  while (rd < sa->len) {
    c = sa->s[rd++];
    if (c == '\\') {
      if (rd == sa->len) break;
      c = sa->s[rd++];
      if (c >= '0' && c <= '7') {
        c -= '0';
        /* up to two further octal digits extend the value */
        for (ndigits = 1;
             ndigits < 3 && rd < sa->len && sa->s[rd] >= '0' && sa->s[rd] <= '7';
             ++ndigits)
          c = (c << 3) + (sa->s[rd++] - '0');
      }
    }
    sa->s[wr++] = c;
  }
  sa->len = wr;
}


/*
 * curveresolve_parse: scan the dotted name s for a public-key label and the
 * extension label that follows it.
 *
 *  - Public key: a label of exactly 54 characters starting with "uz7"
 *    (case-insensitive); the 51 characters after the prefix are passed to
 *    base32_decodehex, which fills key.
 *  - Extension: the first label seen while key holds 64 bytes and ext is
 *    still empty. It has the form "<32 hex digits>" or
 *    "<32 hex digits>/<service letters>". The hex part is stored in ext
 *    (NUL-terminated) and its binary form in the global binext.
 *  - sel: when non-zero and a "/<service letters>" suffix is present, the key
 *    is kept only if sel occurs in that suffix (case-insensitive). Otherwise
 *    encryption is disabled for the service: key and ext are cleared and the
 *    extension is not decoded.
 *
 * Returns 1 when a 64-byte key and a 32-character extension were found.
 * Returns 0 otherwise; key and ext are then emptied.
 * Returns -1 on allocation or base32 decoding failure.
 */
int curveresolve_parse(stralloc *key, stralloc *ext, const char *s, const char sel){

    unsigned int i;
    unsigned int j;
    unsigned int u;
    unsigned int uu;
    char ch;
    int flagstart = 1;
    int flagenabled;

    if (!stralloc_copys(key,"")) return -1;
    if (!stralloc_copys(ext,"")) return -1;

    for (i = 0; s[i]; ++i){

        if (s[i] == '.'){
            flagstart = 1;
            continue;
        }
        /* only the first character of each label starts any processing */
        if (!flagstart) continue;
        flagstart = 0;

        /* u = length of the label starting at s + i */
        u = str_chr(s + i, '.');

        /* extension: first label following a complete public key */
        if (key->len == 64 && ext->len == 0){
            if (!stralloc_copyb(ext, s + i, u)) return -1;
            uu = byte_chr(ext->s, ext->len, '/');

            flagenabled = 1;
            if (sel && ext->len > uu){
                /* "/<services>" suffix present: is sel one of them? */
                flagenabled = 0;
                for (j = uu + 1; j < ext->len; ++j){
                    ch = ext->s[j];
                    if (!case_diffb(&ch,1,&sel)){
                        flagenabled = 1;
                        break;
                    }
                }
            }

            /* keep only the part before '/', NUL-terminated for hexparse */
            ext->len = uu;
            if (!stralloc_0(ext)) return -1;
            --ext->len;

            /* A disabled service must not be revived later in the name, so the
               extension is dropped rather than decoded. A wrong length is
               rejected before hexparse can touch binext. */
            if (!flagenabled || ext->len != 32 || !hexparse(binext,16,ext->s)){
                key->len = 0;
                ext->len = 0;
            }
        }

        /* public key: "uz7" + 51 characters */
        if (u == 54 && case_starts(s + i,"uz7")){
            if (base32_decodehex(key, (unsigned char *)(s+i+3), 51) == -1) return -1;
        }
    }

    if (key->len != 64 || ext->len != 32){
        key->len = 0;
        ext->len = 0;
        return 0;
    }
    return 1;
}

/* Scratch buffer: curveresolve_xparse stores the dotted, unescaped form of a
   domain name here before handing it to curveresolve_parse; its contents are
   overwritten on every such call. */
stralloc tmp = {0};
/* NOTE(review): not referenced in the code visible here; confirm whether it is
   used elsewhere before removing it. */
stralloc tmpout = {0};

/*
 * curveresolve_xparse: convert the wire-format domain dd to dotted text in the
 * scratch buffer tmp, decode its backslash escapes (txtparse), NUL-terminate it
 * and pass it to curveresolve_parse.
 *
 * Returns curveresolve_parse's result, or -1 if building the text fails.
 */
int curveresolve_xparse(stralloc *key, stralloc *ext, const char *dd, const char sel){

    stralloc *text = &tmp;

    if (!stralloc_copys(text,"") || !dns_domain_todot_cat(text,dd)) return -1;
    txtparse(text);
    return stralloc_0(text) ? curveresolve_parse(key, ext, text->s, sel) : -1;
}

/*
 * curveresolve_packet2: walk the answer section of the DNS response in
 * buf/len and pass the target of every IN NS record to curveresolve_xparse
 * until one of them returns non-zero.
 *
 * out is only cleared; no addresses are collected here.
 * Returns 1 if a key/extension was found, 0 if not, or -1 on a malformed
 * packet or allocation failure.
 */
int curveresolve_packet2(stralloc *out, stralloc *key, stralloc *ext, char *buf,unsigned int len, const char sel){

  uint16 numanswers;
  unsigned int pos;
  char data[12];
  char misc[10];
  uint16 datalen;
  unsigned int newpos;
  int flaguz7 = 0;

  if (!stralloc_copys(out,"")) return -1;

  /* header: answer count is at offset 6 */
  pos = dns_packet_copy(buf,len,0,data,12); if (!pos) return -1;
  uint16_unpack_big(data + 6,&numanswers);
  pos = dns_packet_skipname(buf,len,pos); if (!pos) return -1;
  pos += 4; /* skip QTYPE and QCLASS */

  while (numanswers--) {
    pos = dns_packet_skipname(buf,len,pos); if (!pos) return -1;
    /* TYPE(2) CLASS(2) TTL(4) RDLENGTH(2) */
    pos = dns_packet_copy(buf,len,pos,misc,10); if (!pos) return -1;
    uint16_unpack_big(misc + 8,&datalen);
    newpos = pos + datalen;

    if (byte_equal(misc,2,DNS_T_NS) && byte_equal(misc + 2,2,DNS_C_IN)){
        if (!dns_packet_getname(buf,len,pos,&d)) return -1;
        if (!flaguz7) flaguz7 = curveresolve_xparse(key, ext, d, sel);
        dns_domain_free(&d);
        if (flaguz7 == -1) return -1;
    }

    /* always resume right after this record's RDATA; previously a non-IN NS
       record left pos inside its RDATA and desynchronised the parse */
    pos = newpos;
  }

  dns_sortip(out->s,out->len);
  return flaguz7;
}


/*
 * curveresolve_packet: parse the DNS response in buf/len.
 *
 *  - The query name, then each IN CNAME/NS target in the answer section, then
 *    each IN NS target in the authority section whose owner equals the query
 *    name, are passed to curveresolve_xparse until one returns non-zero.
 *  - The 4-byte RDATA of every IN A answer record is appended to out, which is
 *    finally ordered with dns_sortip.
 *  - *flagauthorityns is set to 1 when the authority section holds an IN NS
 *    record owned by the query name, otherwise 0.
 *
 * Returns 1 if a key/extension was found, 0 if not, or -1 on a malformed
 * packet or allocation failure.
 */
int curveresolve_packet(stralloc *out, stralloc *key, stralloc *ext, char *buf,unsigned int len, const char sel, int *flagauthorityns){

  uint16 numanswers;
  uint16 numauthority;
  unsigned int pos;
  char data[12];
  char misc[10];
  uint16 datalen;
  unsigned int newpos;
  int flaguz7;

  *flagauthorityns = 0;
  if (!stralloc_copys(out,"")) return -1;

  /* header: answer count at offset 6, authority count at offset 8 */
  pos = dns_packet_copy(buf,len,0,data,12); if (!pos) return -1;
  uint16_unpack_big(data + 6,&numanswers);
  uint16_unpack_big(data + 8,&numauthority);
  pos = dns_packet_getname(buf,len,pos,&d1); if (!pos) return -1;
  pos += 4; /* skip QTYPE and QCLASS */

  /* the query name itself may already carry the key */
  flaguz7 = curveresolve_xparse(key, ext, d1, sel);
  if (flaguz7 == -1) return -1;

  while (numanswers--) {
    pos = dns_packet_skipname(buf,len,pos); if (!pos) return -1;
    /* TYPE(2) CLASS(2) TTL(4) RDLENGTH(2) */
    pos = dns_packet_copy(buf,len,pos,misc,10); if (!pos) return -1;
    uint16_unpack_big(misc + 8,&datalen);
    newpos = pos + datalen;

    if (byte_equal(misc + 2,2,DNS_C_IN)){
      if (byte_equal(misc,2,DNS_T_CNAME) || byte_equal(misc,2,DNS_T_NS)){
        if (!dns_packet_getname(buf,len,pos,&d)) return -1;
        if (!flaguz7) flaguz7 = curveresolve_xparse(key, ext, d, sel);
        dns_domain_free(&d);
        if (flaguz7 == -1) return -1;
      }
      else if (byte_equal(misc,2,DNS_T_A) && datalen == 4){
        if (!dns_packet_copy(buf,len,pos,misc,4)) return -1;
        if (!stralloc_catb(out,misc,4)) return -1;
      }
    }

    /* Always resume right after this record's RDATA. Previously, non-IN
       records and A records with datalen != 4 left pos inside the RDATA and
       desynchronised the rest of the parse. */
    pos = newpos;
  }

  while (numauthority--) {
    pos = dns_packet_getname(buf,len,pos,&d2); if (!pos) return -1;
    pos = dns_packet_copy(buf,len,pos,misc,10); if (!pos) return -1;
    uint16_unpack_big(misc + 8,&datalen);
    newpos = pos + datalen;

    if (byte_equal(misc,2,DNS_T_NS) && byte_equal(misc + 2,2,DNS_C_IN)){
      if (!dns_packet_getname(buf,len,pos,&d)) return -1;
      /* only NS records for the queried name itself are considered */
      if (dns_domain_equal(d1,d2)){
        *flagauthorityns = 1;
        if (!flaguz7) flaguz7 = curveresolve_xparse(key, ext, d, sel);
      }
      dns_domain_free(&d);
      if (flaguz7 == -1) return -1;
    }

    pos = newpos;
  }
  dns_domain_free(&d1);
  dns_domain_free(&d2);

  dns_sortip(out->s,out->len);
  return flaguz7;
}

static char *q = 0;
static char servers[64];

/*
 * curveresolve: resolve fqdn to IPv4 addresses and look for a public key and
 * extension (see curveresolve_packet).
 *
 * If fqdn consists only of digits and dots ('[' and ']' are ignored), it is
 * taken as literal IPv4 address(es). One byte is produced per dot-separated
 * field, out is truncated to a multiple of 4 bytes, no query is sent, and 0 is
 * returned.
 *
 * Otherwise an A query is sent. out receives 4 bytes per address, and the
 * result of curveresolve_packet is returned: 1 key found, 0 not found,
 * -1 error.
 *
 * With EXTRAQUERYHACK defined, and when no key was found but addresses were
 * returned without an authority NS record for the name, an extra NS query is
 * made and curveresolve_packet2's result is returned.
 */
int curveresolve(stralloc *out, stralloc *key, stralloc *ext, stralloc *fqdn, const char sel){

    unsigned int i;
    char code;
    char ch;
    int r = 0;
    int flagauthorityns = 0;
#ifdef EXTRAQUERYHACK
    /* Dedicated output buffer for curveresolve_packet2. The global tmp is
       curveresolve_xparse's scratch buffer, so using it here made packet2's
       output alias the buffer xparse overwrites. */
    static stralloc nsout = {0};
#endif

    /* NOTE(review): servers is not used further in this function; the call
       seems to serve as an early failure when the resolver configuration
       cannot be read. */
    if (dns_resolvconfip(servers) == -1) return -1;

    if (!stralloc_copys(out,"")) return -1;
    code = 0;
    for (i = 0;i <= fqdn->len;++i) {
        /* a virtual '.' after the last character flushes the final field */
        if (i < fqdn->len)
            ch = fqdn->s[i];
        else
            ch = '.';

        if ((ch == '[') || (ch == ']')) continue;
        if (ch == '.') {
            if (!stralloc_append(out,&code)) return -1;
            code = 0;
            continue;
        }
        if ((ch >= '0') && (ch <= '9')) {
            code *= 10;
            code += ch - '0';
            continue;
        }

        /* not a literal address: query the DNS */
        if (!dns_domain_fromdot(&q,fqdn->s,fqdn->len)) return -1;
        if (dns_resolve(q,DNS_T_A) == -1){dns_domain_free(&q); return -1; }
        r = curveresolve_packet(out, key, ext, dns_resolve_tx.packet,dns_resolve_tx.packetlen, sel, &flagauthorityns);
        dns_transmit_free(&dns_resolve_tx);
        dns_domain_free(&q);
#ifdef EXTRAQUERYHACK
        if (r != 0) return r;
        if (!out->len) return r;
        if (flagauthorityns) return r;
        strerr_warn1("curveresolve.c: warning: extra query for NS record !!!",0);
        if (!dns_domain_fromdot(&q,fqdn->s,fqdn->len)) return -1;
        if (dns_resolve(q,DNS_T_NS) == -1){dns_domain_free(&q); return -1; }
        r = curveresolve_packet2(&nsout, key, ext, dns_resolve_tx.packet,dns_resolve_tx.packetlen, sel);
        dns_transmit_free(&dns_resolve_tx);
        dns_domain_free(&q);
#endif
        return r;
    }

    /* literal address(es): keep whole 4-byte groups only */
    out->len &= ~3;
    return r;
}


/*
 * doit: apply one rewrite rule to work.
 *
 * A rule is "<op><suffix>:<replacement>", where op is one of '?', '=', '*'
 * or '-'. The rule applies when work ends with suffix (case-insensitive) and:
 *   '=' : work equals suffix exactly;
 *   '?' : the part before suffix contains no '.', '[' or ']';
 *   '*' : always;
 *   '-' : always; the whole of work is then replaced.
 * On a match the suffix (or, for '-', everything) is replaced by replacement.
 *
 * Returns 1 if the rule does not apply or is malformed; otherwise returns the
 * result of stralloc_cats (0 on allocation failure).
 */
static int doit(stralloc *work,const char *rule)
{
  const char op = rule[0];
  const char *pattern = rule + 1;
  unsigned int patlen;
  unsigned int keep;

  switch (op) {
    case '?':
    case '=':
    case '*':
    case '-':
      break;
    default:
      return 1;
  }

  patlen = str_chr(pattern,':');
  if (pattern[patlen] == 0) return 1;   /* no ':' separator */
  if (patlen > work->len) return 1;

  keep = work->len - patlen;            /* bytes before the suffix */
  if (op == '=' && keep != 0) return 1;
  if (case_diffb(pattern,patlen,work->s + keep)) return 1;
  if (op == '?'
      && (byte_chr(work->s,keep,'.') < keep
          || byte_chr(work->s,keep,'[') < keep
          || byte_chr(work->s,keep,']') < keep))
    return 1;

  work->len = (op == '-') ? 0 : keep;
  return stralloc_cats(work,pattern + patlen + 1);
}


/* rewrite rules from dns_resolvconfrewrite, and the working copy of the name */
static stralloc rules = {0};
static stralloc fqdn = {0};

/*
 * curveresolveq: qualify the name in, then resolve it with curveresolve.
 *
 * The rewrite rules (a sequence of NUL-terminated strings) are applied in
 * order with doit. If the rewritten name contains '+', the text after the
 * first '+' is a '+'-separated list of alternative suffixes. Each candidate
 * "<prefix><suffix>" is tried in turn until one yields addresses in out.
 *
 * Returns the result of the last curveresolve call, or -1 on error.
 */
int curveresolveq(stralloc *out, stralloc *key, stralloc *ext, stralloc *in, const char sel){

    unsigned int start;
    unsigned int pos;
    unsigned int plus;
    unsigned int total;
    unsigned int seglen;
    int r;

    if (dns_resolvconfrewrite(&rules) == -1) return -1;
    if (!stralloc_copy(&fqdn,in)) return -1;
    if (!stralloc_readyplus(&fqdn, rules.len)) return -1;

    start = 0;
    for (pos = 0; pos < rules.len; ++pos) {
        if (rules.s[pos]) continue;
        if (!doit(&fqdn,rules.s + start)) return -1;
        start = pos + 1;
    }

    total = fqdn.len;
    plus = byte_chr(fqdn.s,total,'+');
    if (plus >= total)
        return curveresolve(out, key, ext, &fqdn, sel);

    /* copy each suffix down over the '+' position and try the result */
    start = plus + 1;
    for (;;) {
        seglen = byte_chr(fqdn.s + start,total - start,'+');
        byte_copy(fqdn.s + plus,seglen,fqdn.s + start);
        fqdn.len = plus + seglen;
        r = curveresolve(out, key, ext, &fqdn, sel);
        if (r == -1 || out->len) return r;
        start += seglen;
        if (start >= total) return r;
        ++start;   /* skip the '+' separator */
    }
}

