(* SwiftSurf
 * Sebastien Ailleret *)

open Unix
open Pervasives

open Activebuffer
open Conf
open Types
open Utils

(**************************************************************)
(* understand: parse and interpret the client request *)
(**************************************************************)

(* Start connecting [conn] to the socket address [inet].
 * A fresh TCP socket is stored in [conn.server] and passed through
 * set_nonblock before the connect is attempted.  [conn.state] becomes
 * ALIVE when the connect completes at once and CONNECTING when it is
 * reported as still in progress (EINPROGRESS).  Any other Unix_error is
 * left to the caller and [conn.state] is then not modified. *)
let connect_to_inet inet conn =
  let sock = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in
  conn.server <- sock;
  set_nonblock sock;
  conn.state <-
    (try
      Unix.connect sock inet;
      ALIVE
    with Unix_error (EINPROGRESS, "connect", "") ->
      CONNECTING)

(* Resolve [host] and start connecting [conn] to it on [port].
 * - address returned by Dns.my_gethostbyname: connect_to_inet is called;
 * - Dns.Dns_queue raised: the connection is parked
 *   (state_req = WAIT_DNS, state = DNS);
 * - Not_found raised: converted into
 *   Unix_error (EUNKNOWNERR 0, "gethostbyname", ""). *)
let my_connect conn (host, port) =
  try
    let addr = ADDR_INET (Dns.my_gethostbyname conn host, port) in
    connect_to_inet addr conn
  with
  | Dns.Dns_queue ->
      conn.state_req <- WAIT_DNS;
      conn.state <- DNS
  | Not_found ->
      raise (Unix_error (EUNKNOWNERR 0, "gethostbyname", ""))

(* Log [url] on stdout according to the level stored in Types.stats:
 *   0     -> nothing is printed;
 *   1     -> [url] alone;
 *   other -> date, time, [conn.from] and [url].
 * stdout is flushed after every message. *)
let make_log conn url =
  let level = !Types.stats in
  if level = 1 then begin
    print_string url;
    flush stdout
  end else if level <> 0 then begin
    let now = Unix.localtime (Unix.time ()) in
    (* minutes and seconds are always printed on two digits *)
    let two_digits n = string_of_int (n / 10) ^ string_of_int (n mod 10) in
    Printf.printf "%d-%d-%d, %d:%s:%s (%s) : %s"
      now.Unix.tm_mday (now.Unix.tm_mon + 1) (1900 + now.Unix.tm_year)
      now.Unix.tm_hour (two_digits now.Unix.tm_min)
      (two_digits now.Unix.tm_sec) conn.from url;
    flush stdout
  end

(* Tell whether the method of [conn] is allowed by its profile.
 * CONNECT needs either [allCONNECT], or [canCONNECT] together with a
 * target port of 443.  Any other method is accepted when the profile's
 * method list is empty or contains it. *)
let ok_meth conn =
  let prof = conn.prof in
  if conn.meth = "CONNECT" then
    prof.allCONNECT || (prof.canCONNECT && snd conn.host = 443)
  else
    prof.methods = [] || List.mem conn.meth prof.methods

(* Called at the end of the request headers.
 * Selects the next request state (CONTENT when a body is announced by
 * [conn.len_post], END otherwise), loads the profile matching
 * [conn.auth], optionally dumps the request on stdout, then either starts
 * the connection to the server (or to the parent proxy) or answers with
 * an error through [finish].  Auth_failed raised anywhere in that
 * sequence is answered with need_auth. *)
let goto_content conn =
  conn.state_req <- (if conn.len_post > 0 then CONTENT else END);
  (* print a debugging dump and flush it immediately *)
  let echo text = print_string text; flush stdout in
  (* log the refusal, then send the matching answer *)
  let reject msg answer = make_log conn msg; finish conn answer in
  try
    conn.prof <- Conf.get_profile conn.auth;
    if conn.prof.req_in then
      echo (String.sub conn.read_req.buffer 0 conn.read_req.pos_fin);
    if conn.prof.req_1 then
      echo (Activebuffer.contents conn.write_req);
    begin match ok_url conn.url conn.prof with
    | Some reason ->
        (* forbidden url *)
        reject (Printf.sprintf "no %s%s\n" conn.proto_str conn.url) reason
    | None ->
        if not (ok_meth conn) then
          (* bad protocol *)
          reject "forbidden method\n" secure_forbid
        else begin
          make_log conn (Printf.sprintf "ok %s%s\n" conn.proto_str conn.url);
          match conn.need_proxy with
          | Some (parent, _) -> connect_to_inet parent conn
          | None -> my_connect conn conn.host
        end
    end
  with Auth_failed ->
    reject "authentication failed\n" need_auth

(* Check whether the header starting at offset [deb] (length [len]) of
 * [conn.read_req.buffer] is a "Proxy-Authorization" header.
 * If it is, the text found after the header name and 6 more characters
 * (NOTE(review): presumably the "Basic " scheme -- confirm), minus the
 * last 2 characters, is decoded with decode64 into [conn.auth] and the
 * result is [false].  Any other header gives [true].
 * Any exception raised along the way also gives [false]. *)
let verif_auth conn deb len =
  let prefix = "proxy-authorization: " in
  let prefix_len = String.length prefix in
  try
    let raw = conn.read_req.buffer in
    (* case-insensitive comparison, stopping at the first mismatch *)
    let i = ref 0 in
    while !i < prefix_len
        && prefix.[!i] = Char.lowercase raw.[deb + !i] do
      incr i
    done;
    if !i < prefix_len then true
    else begin
      conn.auth <-
        decode64 (String.sub raw (deb + prefix_len + 6) (len - prefix_len - 8));
      false
    end
  with _ -> false

(* Handle a standard (non-CONNECT) request line.
 * [cmd] is the whole request line and [deb_url] the index of the first
 * character of the URL inside it.  Only absolute "http://" URLs are
 * accepted: the URL without its scheme goes to [conn.url], the host and
 * port (80 when none is given) to [conn.host], and [conn.need_proxy] is
 * computed from the host name.  The request line is then queued on
 * [conn.write_req]:
 *  - no parent proxy: rewritten so that only the path part of the URL
 *    is sent;
 *  - parent proxy: unchanged, followed by the extra string attached to
 *    that proxy when there is one (NOTE(review): presumably additional
 *    header text -- confirm in Conf);
 * followed by "Connection: close" and the configured extra headers.
 * Any other request line is logged and answered with invalid_req.
 *
 * Raises Not_found when no space follows the URL or when the URL has no
 * '/' after the host, and Failure "int_of_string" when the port is not a
 * number. *)
let gere_std cmd conn deb_url =
  let fin_url = String.index_from cmd deb_url ' ' in
  let scheme = "http://" in
  let scheme_len = String.length scheme in
  (* Check the length first: String.sub raises Invalid_argument on a
   * request line too short to hold the scheme (e.g. "GET / H\n"), which
   * is not one of the parse errors listed above. *)
  if String.length cmd - deb_url >= scheme_len
      && String.sub cmd deb_url scheme_len = scheme then
    (* this line looks good *)
    (conn.url <-
       String.sub cmd (deb_url + scheme_len) (fin_url - deb_url - scheme_len);
     (* conn.proto_str <- "http://"; *) (* default value *)
     conn.state_req <- HEADERS;
     let fin_host = String.index conn.url '/' in
     let total_host = String.sub conn.url 0 fin_host in
     conn.host <-
       (try
         let deb_port = 1 + String.index total_host ':' in
         let str_port = String.sub total_host deb_port
             ((String.length total_host) - deb_port) in
         let host = String.sub total_host 0 (deb_port - 1) in
         host, int_of_string str_port
       with Not_found ->
         (* no port given, use 80 *)
         total_host, 80);
     conn.need_proxy <- need_proxy (fst conn.host);
     (match conn.need_proxy with
     | None ->
         (* direct connection: "<method> " + path + " <protocol>..." *)
         Activebuffer.add_substring conn.write_req cmd 0 deb_url;
         Activebuffer.add_substring conn.write_req conn.url fin_host
           ((String.length conn.url) - fin_host);
         Activebuffer.add_substring conn.write_req
           cmd fin_url ((String.length cmd) - fin_url);
     | Some (_, None) ->
         Activebuffer.add_string conn.write_req cmd
     | Some (_, Some c) ->
         Activebuffer.add_string conn.write_req cmd;
         Activebuffer.add_string conn.write_req c);
    Activebuffer.add_string conn.write_req "Connection: close\r\n";
    Activebuffer.add_string conn.write_req !Conf.add_req_headers)
  else (* not an http url *)
    (* cannot understand anything *)
    (make_log conn "Invalid request\n";
     finish conn invalid_req)

(* Handle a CONNECT request line (usually for https).
 * [cmd] is the whole request line and [deb_url] the index of the first
 * character of the "host[:port]" target.  The target goes to
 * [conn.url]; [conn.host] receives the host and port, falling back to
 * the whole target and port 443 when no usable ":port" is found.
 * [conn.need_proxy] is computed from the host name:
 *  - no parent proxy: the "200 Connection established" answer is queued
 *    on [conn.write_ans], the input is no longer limited and the
 *    remaining request headers are handled in state HEADERSCONNECT;
 *  - parent proxy: the CONNECT line (plus the extra string attached to
 *    that proxy, if any) is queued on [conn.write_req].
 *
 * Raises Not_found when no space follows the target. *)
let gere_connect cmd conn deb_url =
  let fin_url = String.index_from cmd deb_url ' ' in
  conn.url <- String.sub cmd deb_url (fin_url - deb_url);
  conn.proto_str <- "https:";
  conn.host <-
    (try
      let deb_port = 1 + String.index_from cmd deb_url ':' in
      let port = int_of_string (String.sub cmd deb_port (fin_url-deb_port)) in
      String.sub cmd deb_url (deb_port-1-deb_url), port
    with Not_found | Failure _ | Invalid_argument _ ->
      (* No usable ":port" suffix: the host is the whole target.  The
       * previous fallback passed [fin_url] (an end index) where
       * String.sub expects a length, producing a wrong host name or an
       * Invalid_argument. *)
      conn.url, 443);
  conn.need_proxy <- need_proxy (fst conn.host);
  match conn.need_proxy with
  | None ->
      let ans_200 = "HTTP/1.0 200 Connection established\r\n\r\n" in
      conn.len_post <- max_int; (* no limit for the input *)
      conn.state_req <- HEADERSCONNECT;
      conn.state_ans <- CONTENT;
      Activebuffer.add_string conn.write_ans ans_200
  | Some (_, None) ->
      Activebuffer.add_string conn.write_req cmd;
      conn.state_req <- HEADERS
  | Some (_, Some c) ->
      Activebuffer.add_string conn.write_req cmd;
      Activebuffer.add_string conn.write_req c;
      conn.state_req <- HEADERS

(* Try to understand the command line.
 * [cmd] is the first line of the request.  The method (the text before
 * the first space) is stored in [conn.meth]; CONNECT is handed to
 * gere_connect and every other method to gere_std, both receiving the
 * index of the character following that space.
 * A line that cannot be parsed (Not_found, a non-numeric port, or an
 * out-of-range string access) is logged and answered with invalid_req. *)
let gere_cmdline cmd conn =
  try
    let deb_url = String.index cmd ' ' in
    let mthd = String.sub cmd 0 deb_url in
    (match mthd with
    | "CONNECT" ->
        conn.meth <- "CONNECT";
        gere_connect cmd conn (deb_url+1)
    | x ->
        conn.meth <- x;
        gere_std cmd conn (deb_url+1))
  with
  | Not_found
  | Failure "int_of_string"
  (* A truncated or malformed line can make the String.sub /
   * String.index_from calls of the parsers raise Invalid_argument;
   * reject such a request instead of letting the exception escape. *)
  | Invalid_argument _ ->
      (* cannot understand anything *)
      make_log conn "Invalid request\n";
      finish conn invalid_req

(* Extract the value of a Content-Length header line.
 * Returns [Some n] when [header] starts (case-insensitively) with
 * "content-length:", followed by optional spaces or tabs and at least
 * one decimal digit; returns [None] otherwise. *)
let content_length_of_header header =
  let name = "content-length:" in
  let name_len = String.length name in
  let hlen = String.length header in
  if hlen < name_len
      || String.lowercase (String.sub header 0 name_len) <> name
  then None
  else begin
    let pos = ref name_len in
    (* the whitespace between the colon and the value is optional *)
    while !pos < hlen && (header.[!pos] = ' ' || header.[!pos] = '\t') do
      incr pos
    done;
    let first_digit = !pos in
    let value = ref 0 in
    while !pos < hlen && header.[!pos] >= '0' && header.[!pos] <= '9' do
      value := 10 * !value + Char.code header.[!pos] - Char.code '0';
      incr pos
    done;
    if !pos > first_digit then Some !value else None
  end

(* Try to understand what we received.
 * Consumes as much of [conn.read_req] as the current [conn.state_req]
 * allows:
 *  - CMD_LINE: waits for a complete first line and parses it;
 *  - WAIT_DNS: expects [conn.state] to be DNSDONE and starts the
 *    connection to the resolved address;
 *  - HEADERS: handles one header line per call, forwarding accepted
 *    lines on [conn.write_req] and recording Content-Length;
 *  - HEADERSCONNECT: same, but headers are dropped (only the proxy
 *    authentication is looked at);
 *  - CONTENT: forwards at most [conn.len_post] bytes of body;
 *  - END: nothing left to do.
 * A failure to resolve or to connect to the target is answered with
 * host_unreach.  Any unexpected exception is printed and then triggers
 * [assert false]. *)
let rec compute_read conn =
  match conn.state_req with
  | CMD_LINE ->
      (try
        let pos = index conn.read_req '\n' in
        let cmd = String.sub conn.read_req.buffer 0 (pos+1) in
        conn.read_req.pos_deb <- pos+1;
        gere_cmdline cmd conn;
        match conn.state_req with
        | HEADERS | HEADERSCONNECT ->
            compute_read conn
        | _ -> ()
      with
      | Unix_error (_, "connect", _) ->
          finish conn host_unreach
      | Not_found ->
        (* the command line is not finished *)
          ()
      | e -> print_string (Printexc.to_string e); flush stdout;
          assert false)
  | WAIT_DNS ->
      (try
        let ip_addr =
          match conn.state with
          | DNSDONE ip -> ip
          | _ -> assert false in
        connect_to_inet (ADDR_INET (ip_addr, snd conn.host)) conn;
        conn.state_req <-
          if conn.len_post > 0 then CONTENT else END;
        compute_read conn
      with
      | Unix_error (_, "connect", "") ->
          finish conn host_unreach
      | e -> print_string (Printexc.to_string e); flush stdout;
          assert false)
  | HEADERS ->
      (try
        let pos = index conn.read_req '\n' in
        let len = pos + 1 - conn.read_req.pos_deb in
        let header = String.sub conn.read_req.buffer
            conn.read_req.pos_deb len in
        let deb = conn.read_req.pos_deb in
        conn.read_req.pos_deb <- pos+1;
        if len <= 2 then
          (* last line of headers *)
          (goto_content conn;
           Activebuffer.add_string conn.write_req header)
        else
          (if verif_auth conn deb len && ok_req_header header then
            ((match content_length_of_header header with
              | Some n -> conn.len_post <- n
              | None -> ());
             Activebuffer.add_string conn.write_req header));
        if conn.state_req != WAIT_DNS
        then compute_read conn
      with
      (* my_connect reports an unknown host as a "gethostbyname" error:
       * answer it like a failed connect instead of reaching the
       * catch-all below *)
      | Unix_error (_, ("connect" | "gethostbyname"), _) ->
          finish conn host_unreach
      | Not_found ->
        (* this line of headers is not finished *)
          ()
      | e -> print_string (Printexc.to_string e); flush stdout;
          assert false)
  | HEADERSCONNECT -> (* forget them *)
      (try
        let pos = index conn.read_req '\n' in
        let len = pos + 1 - conn.read_req.pos_deb in
        let deb = conn.read_req.pos_deb in
        conn.read_req.pos_deb <- pos+1;
        if len <= 2 then
          (* last line of headers *)
          goto_content conn
        else
          ignore (verif_auth conn deb len);
        if conn.state_req != WAIT_DNS
        then compute_read conn
      with
      (* same as in HEADERS: an unknown host is a host_unreach answer *)
      | Unix_error (_, ("connect" | "gethostbyname"), _) ->
          finish conn host_unreach
      | Not_found ->
        (* this line of headers is not finished *)
          ()
      | e -> print_string (Printexc.to_string e); flush stdout;
          assert false)
  | CONTENT ->
      let ab = conn.read_req in
      let lab = Activebuffer.length ab in
      (* forward what is available, but never more than what is left of
       * the announced body *)
      let len =
        if lab < conn.len_post
        then
          (conn.len_post <- conn.len_post - lab; lab)
        else
          (conn.state_req <- END; conn.len_post) in
      if conn.prof.req_2 then
        (output stdout ab.buffer ab.pos_deb len; flush stdout);
      Activebuffer.add_subbuffer conn.write_req ab 0 len;
      ab.pos_deb <- ab.pos_deb + len;
  | END -> ()

(**************************************************************)
(* read and write *)
(**************************************************************)

(* Update every connection.
 * [time] is compared with each connection's [timeout] and used to
 * refresh it after a successful read or write.
 * [active_read] and [active_write] are the descriptors on which a read,
 * respectively a write, is attempted (NOTE(review): presumably the
 * result of a select call -- confirm with the caller).
 * [conns] is the list of connections to process.
 * Returns the connections that are kept; they are accumulated with
 * [::], so their order is the reverse of the traversal order.
 * Connections that timed out or were closed by the client are passed to
 * close_connexion and dropped from the result. *)
let gere_conns time active_read active_write conns =
  (* [conns] is the accumulator of kept connections, the anonymous
   * argument is the list still to process *)
  let rec gere_aux conns = function
    | [] -> conns
    | conn::l when conn.timeout <= time -> (* timeout => close and forget *)
        (* it is easier to manage timeout here than in before select *)
        close_connexion conn;
        gere_aux conns l
    | conn::l ->
        match conn.state with
        | DNS ->
            (* is there sth to read *)
            if List.mem conn.client active_read
                && not (manage_read conn) then
              (* connexion closed by client *)
              (close_connexion conn;
               gere_aux conns l)
            else gere_aux (conn::conns) l
        | DNSDONE _ ->
            (* the request is processed further before looking at the
             * client socket *)
            compute_read conn;
            (* is there sth to read *)
            if List.mem conn.client active_read
                && not (manage_read conn) then
              (* connexion closed by client *)
              (close_connexion conn;
               gere_aux conns l)
            else gere_aux (conn::conns) l
        | CONNECTING ->
            (* the server socket being in [active_write] is taken as the
             * sign that the pending connect has completed *)
            let connecting () =
              if List.mem conn.server active_write then
                (* We're connected *)
                (conn.state <- ALIVE;
                 manage_write conn conns l)
              else
                (* not connected yet *)
                gere_aux (conn::conns) l in
            if List.mem conn.client active_read then
              if manage_read conn then
                (* Nothing more to read now *)
                (if conn.read_req.pos_fin <> conn.read_req.pos_deb
                 then compute_read conn;
                 (* see if we can write *)
                 connecting ())
              else
                (* Connexion closed by client *)
                (close_connexion conn;
                 gere_aux conns l)
            else
              (* see if we can write something *)
              connecting ()
        | STARTING
        | ALIVE ->
            if List.mem conn.client active_read then
              if manage_read conn then
                (* Nothing more to read now *)
                (if conn.read_req.pos_fin <> conn.read_req.pos_deb
                 then compute_read conn;
                 (* try to write what must be *)
                 if conn.state == ALIVE then
                   manage_write conn conns l
                 else
                   gere_aux (conn::conns) l)
              else
                (* Connexion closed by client *)
                (close_connexion conn;
                 gere_aux conns l)
            else
              (* try to write what must be *)
              if conn.state == ALIVE then
                manage_write conn conns l
              else
                gere_aux (conn::conns) l
        | FINISHING ->
            (* kept as is: nothing is read or written here *)
            gere_aux (conn::conns) l

  (* write what can be and is ready
   * The pending content of [conn.write_req] is written to
   * [conn.server]; [conn] is always kept and the traversal continues
   * with [l]. *)
  and manage_write conn conns l =
    (try
      let len = length conn.write_req in
      let str, pos = buffer conn.write_req in
      let nb = Unix.write conn.server str pos len in
      if nb > 0 then
        (if conn.prof.req_out then
           print_string (String.sub str pos nb);
         (* drop the [nb] bytes that were just written *)
         sub conn.write_req nb (len - nb);
         (* [size_req] grows by what was written, capped at buf_size *)
         conn.size_req <- min (conn.size_req + nb) buf_size;
         conn.timeout <- time +. !Types.timeout);
      gere_aux (conn::conns) l
    with
    | Unix_error (EAGAIN, "write", _) ->
        (* Cannot write any more now *)
        gere_aux (conn::conns) l
    | Unix_error _  ->
        (* connexion closed by server *)
        finish conn "";
        gere_aux (conn::conns) l);

  (* read all available data
   * return true if the connexion is still alive
   *        false if the connexion is closed
   * At most [conn.size_req] bytes are read from [conn.client]; when
   * [conn.size_req] is 0 nothing is read and the result is true. *)
  and manage_read conn =
    try
      conn.size_req == 0 || (* just read if allowed *)
      let str, pos = before_read conn.read_req conn.size_req in
      match Unix.read conn.client str pos conn.size_req with
      | 0 -> (* connexion closed *)
          false
      | nb ->
          if conn.prof.req_in then
            print_string (String.sub str pos nb);
          after_read conn.read_req nb;
          conn.timeout <- time +. !Types.timeout;
          (* what was read is deducted from the amount still allowed *)
          conn.size_req <- conn.size_req - nb;
          true
    with
    | Unix_error (EAGAIN, "read", "") -> (* Nothing more to read now *)
        true
    | Unix_error _ -> (* connexion closed *)
        false
  in

  gere_aux [] conns
