#!/usr/bin/env python3
#
# Copyright 2021, Julian Catchen <jcatchen@illinois.edu>
#
# This file is part of Stacks.
#
# Stacks is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Stacks is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Stacks.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import argparse
import os
import gzip
from datetime import datetime
from enum import Enum
from operator import itemgetter, attrgetter

version     = '2.68'
stacks_path = ""
bam_path    = ""
out_path    = ""

min_pct_id  = 0.6  # Minimum BLAST-style percent identity to keep an alignment.
min_aln_cov = 0.6  # Minimum fraction of the read participating in the alignment.
min_mapq    = 20   # Minimum mapping quality required to keep an alignment.
verbose     = False

def parse_cigar(c_str):
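    """
    Parse a SAM CIGAR string into a list of (operation, length) tuples.

    Example (illustrative): parse_cigar("5S120M2I10D") returns
    [('S', 5), ('M', 120), ('I', 2), ('D', 10)].
    """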
    cigar = []

    cigar_len = len(c_str)
    i         = 0
    while i < cigar_len:
        j = i
        while c_str[j].isdigit():
            j += 1
        cigar.append( (c_str[j], int(c_str[i:j])) )
        i = j + 1

    return cigar


class AlnType(Enum):
    """SAM FLAG bits used to classify alignments (0x4, 0x100, 0x800)."""
    unmapped      = 4
    primary       = 2     # 0x2 is the 'properly paired' bit; it serves only as a label here --
                          # primary alignments are identified by the absence of the other bits.
    secondary     = 256
    supplementary = 2048


class Aln:
    """
    Object to hold the alignment of one read to the reference.
    """
    def __init__(self, chr, bp, strand, alntype, mapq, cigar, edit_dist):
        self.chr       = chr
        self.bp        = bp
        self.type      = alntype
        self.strand    = strand  # Strand alignment is mapped to; 1 = positive, 0 = negative.
        self.mapq      = mapq    # Mapping quality score.
        self.edit_dist = edit_dist  # Edit distance of the query to the reference (from the NM tag).
        self.q_len     = 0       # Length of the query read.
        self.pct_id    = 0.0     # BLAST-style percent identity between the query and reference.
        self.aln_pct   = 0.0     # Percentage of the query length participating in the alignment.
        self.cigar_str = cigar
        self.cigar     = parse_cigar(cigar)

        #
        # Calculate alignment metrics from the CIGAR string. This assumes the aligner
        # reports matches/mismatches with 'M' operations rather than '='/'X'.
        #
        query_len     = 0.0
        aln_blk_len   = 0.0
        query_aln_len = 0.0
        ref_aln_len   = 0
        for cig_op, op_len in self.cigar:
            if cig_op == "H" or cig_op == "S":
                query_len     += op_len
            elif cig_op == "M":
                query_len     += op_len
                aln_blk_len   += op_len
                query_aln_len += op_len
                ref_aln_len   += op_len
            elif cig_op == "I":
                query_len     += op_len
                aln_blk_len   += op_len
                query_aln_len += op_len
            elif cig_op == "D":
                aln_blk_len   += op_len
                ref_aln_len   += op_len

        self.q_len   = query_len
        self.r_len   = ref_aln_len
        self.pct_id  = (aln_blk_len - float(edit_dist)) / aln_blk_len
        self.aln_pct = query_aln_len / query_len
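        # Worked example (illustrative): for CIGAR "5S90M1I4D" with an edit distance
        # (NM) of 3, query_len = 96, aln_blk_len = 95, query_aln_len = 91, and
        # ref_aln_len = 94, giving pct_id = (95 - 3) / 95 ~ 0.968 and
        # aln_pct = 91 / 96 ~ 0.948.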

class Locus:
    def __init__(self, id):
        self.id_old   = id
        self.id_new   = -1
        self.pri_aln  = None
        self.num_samples = 0
        self.contig      = ""
        self.sequence    = ""

def parse_command_line():
    global min_pct_id, min_aln_cov, min_mapq
    global stacks_path, bam_path, out_path
    global verbose, version

    desc = """Extracts the coordinates of the RAD loci from the given BAM file into a \
    'locus_coordinates.tsv' table, then rewrites the 'catalog.fa.gz' and \
    'catalog.calls' files so that they include the genomic coordinates given in the \
    input BAM file."""

    p = argparse.ArgumentParser(description=desc)

    #
    # Add options.
    #
    p.add_argument("-P","--in-path", type=str, dest="stacks_path", required=True, metavar="path",
                   help="Path to a directory containing Stacks ouput files.")
    p.add_argument("-B","--bam-path", type=str, dest="bam_path", required=True, metavar="path",
                   help="Path to a SAM or BAM file containing alignment of de novo catalog loci to a reference genome.")
    p.add_argument("-O","--out-path", type=str, dest="out_path", required=True, metavar="path",
                   help="Path to write the integrated ouput files.")
    p.add_argument("-q", "--min_mapq", type=int,
                   help="Minimum mapping quality as listed in the BAM file (default 20).")
    p.add_argument("-a", "--min_alncov", type=float,
                   help="Minimum fraction of the de novo catalog locus that must participate in the alignment (default 0.6).")
    p.add_argument("-p", "--min_pctid", type=float,
                   help="Minimum BLAST-style percent identity of the largest alignment fragment for a de novo catalog locus (default 0.6).")
    p.add_argument("--verbose", action="store_true", dest="verbose",
                   help="Provide verbose output.")
    p.add_argument("--version", action='version', version=version)

    #
    # Parse the command line
    #
    args = p.parse_args()

    if args.min_pctid is not None:
        min_pct_id = args.min_pctid
        if min_pct_id > 1 and min_pct_id <= 100:
            min_pct_id = min_pct_id / 100
    if args.min_alncov is not None:
        min_aln_cov = args.min_alncov
        if min_aln_cov > 1 and min_aln_cov <= 100:
            min_aln_cov = min_aln_cov / 100
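    # Note: thresholds given on a 0-100 scale are rescaled to fractions above,
    # e.g. "--min_pctid 85" is interpreted as 0.85.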
    if args.min_mapq is not None:
        min_mapq = args.min_mapq
    if args.stacks_path is not None:
        stacks_path = args.stacks_path
    if args.bam_path is not None:
        bam_path = args.bam_path
    if args.out_path is not None:
        out_path = args.out_path
    if args.verbose is not None:
        verbose = args.verbose

    if len(stacks_path) == 0 or not os.path.exists(stacks_path):
        print("Error: you must specify a valid path to the de novo Stacks output directory.", file=sys.stderr)
        p.print_help()
        sys.exit(1)

    if len(bam_path) == 0 or not os.path.exists(bam_path):
        print("Error: you must specify a valid path to a SAM or BAM file containing alignments of de novo catalog loci to a reference genome.", file=sys.stderr)
        p.print_help()
        sys.exit(1)

    if len(out_path) == 0 or not os.path.exists(out_path):
        print("Error: you must specify a valid/existing path to a directory to write the integrated output files.", file=sys.stderr)
        p.print_help()
        sys.exit(1)

    if out_path[-1] != "/":
        out_path += "/"

    if stacks_path[-1] != "/":
        stacks_path += "/"


def find_max_locus_id(stacks_path):

    fh = gzip.open(stacks_path + "catalog.fa.gz", 'rt')

    max_locus_id = 0
    
    for line in fh:
        line = line.strip('\n')

        if len(line) == 0 or line[0] != ">":
            continue

        parts    = line.split(' ')
        locus_id = int(parts[0][1:])
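        # e.g. a FASTA header beginning ">5213 " yields locus_id 5213.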

        if locus_id > max_locus_id:
            max_locus_id = locus_id
    fh.close()

    return max_locus_id


def parse_bam_file(path, loci, chrs):
    #
    # Pull a list of chromosomes out of the BAM header.
    #
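    # @SQ header lines look like "@SQ\tSN:<name>\tLN:<length>"; the code below assumes
    # SN and LN are the first two tags, as written by common aligners.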
    cmd = "samtools view -H " + path
    fh = os.popen(cmd, "r")

    for line in fh:
        line  = line.strip("\n")
        parts = line.split("\t")

        if parts[0] != "@SQ":
            continue

        chrs.append( (parts[1][3:], int(parts[2][3:])) )
    fh.close()

    #
    # Sort chromosomes by size.
    #
    chrs.sort(key=itemgetter(1), reverse=True)

    tot_loc = {}
    tot_aln = 0
    pri_cnt = 0
    sec_cnt = 0
    sup_cnt = 0
    unm_cnt = 0
    
    cmd = "samtools view " + path
    fh = os.popen(cmd, "r")

    for line in fh:
        line  = line.strip("\n")
        parts = line.split("\t")

        loc_id = int(parts[0])
        chr    = parts[2]
        bp     = int(parts[3])
        mapq   = int(parts[4])
        cigar  = parts[5]

        edit_dist = 0
        for part in parts:
            if part[0:5] == "NM:i:":
                edit_dist = int(part[5:])

        tot_aln += 1

        if loc_id not in tot_loc:
            tot_loc[loc_id] = 0

        #
        # Unmapped, primary, secondary, or supplementary alignment?
        #
        flag = int(parts[1])
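        # e.g. a FLAG of 272 (= 256 + 16) marks a secondary alignment on the negative strand.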

        if ((flag & AlnType.secondary.value) == AlnType.secondary.value):
            atype = AlnType.secondary
            sec_cnt += 1
            tot_loc[loc_id] += 1
            continue
        elif ((flag & AlnType.supplementary.value) == AlnType.supplementary.value):
            atype = AlnType.supplementary
            sup_cnt += 1
            tot_loc[loc_id] += 1
            continue
        elif ((flag & AlnType.unmapped.value) == AlnType.unmapped.value):
            unm_cnt += 1
            continue
        else:
            atype = AlnType.primary
            pri_cnt += 1
            tot_loc[loc_id] += 1

        #
        # Mapped to the negative strand?
        #
        if ((flag & 16) == 16):
            strand = 0
        else:
            strand = 1

        aln = Aln(chr, bp, strand, atype, mapq, cigar, edit_dist)

        loc = Locus(loc_id)
        loc.pri_aln = aln
        if loc_id in loci:
            print("  Warning: locus {} has more than one primary alignment.".format(locus), file=sys.stderr)
        else:
            loci[loc_id] = loc
        
    fh.close()

    loc_unmapped = 0
    loc_onemap   = 0
    loc_multimap = 0
    for loc_id in tot_loc:
        if tot_loc[loc_id] == 0:
            loc_unmapped += 1
        elif tot_loc[loc_id] == 1:
            loc_onemap += 1
        else:
            loc_multimap += 1
            
    print("  Read {} total alignments: {} mapped; {} unmapped; {} secondary alignments; {} supplementary alignments.".format(
        tot_aln, pri_cnt, unm_cnt, sec_cnt, sup_cnt), file=sys.stderr)
    print("  {} total de novo catalog loci seen; {} had multiple alignments; {} had one alignment; {} were unmapped.".format(
        len(tot_loc), loc_multimap, loc_onemap, loc_unmapped), file=sys.stderr)


def filter_alns(loci):
    #
    # Filter the alignments by MAPQ, percent identity, and alignment coverage.
    #
    mapq_filt   = 0
    pctid_filt  = 0
    alncov_filt = 0
    loc_rem     = 0
    for loc_id in loci:

        if loci[loc_id].pri_aln.mapq < min_mapq:
            mapq_filt += 1
            loci[loc_id].pri_aln = None
        elif loci[loc_id].pri_aln.aln_pct < min_aln_cov:
            alncov_filt += 1
            loci[loc_id].pri_aln = None
        elif loci[loc_id].pri_aln.pct_id < min_pct_id:
            pctid_filt += 1
            loci[loc_id].pri_aln = None
        else:
            loc_rem += 1

    tot_filt = mapq_filt + pctid_filt + alncov_filt

    print("  {} loci filtered; {} due to mapping quality (mapq); {} due to alignment coverage; {} due to percent identity; {} loci remain.".format(
        tot_filt, mapq_filt, alncov_filt, pctid_filt, loc_rem), file=sys.stderr)


def check_locus_bounds(loci, chrs):
    #
    # Translate coordinates on the negative strand and check locus bounds.
    #
    chr_key = {}
    for chr, clen in chrs:
        chr_key[chr] = clen

    art_corrected = 0
    ref_too_small = 0
    
    for loc_id in loci:
        loc = loci[loc_id]

        if loc.pri_aln is None:
            continue

        #
        # Adjust the alignment position of negative-strand alignments to the RAD cutsite -- the 3' end.
        #
        if loc.pri_aln.strand == 0:
            loc.pri_aln.bp += loc.pri_aln.r_len
            position = "{}:{}:-".format(loc.pri_aln.chr, int(loc.pri_aln.bp))
        else:
            position = "{}:{}:+".format(loc.pri_aln.chr, int(loc.pri_aln.bp))
        loc.position = position
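        # Example (illustrative): a negative-strand alignment with a leftmost position (bp)
        # of 1000 and a reference alignment span (r_len) of 140 is reported at the cutsite
        # end as "<chr>:1140:-".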

        #
        # Check that locus does not reach boundaries of the reference scaffold/chromosome
        #
        if loc.pri_aln.strand == 0:
            if loc.pri_aln.bp - loc.pri_aln.q_len < 0:
                #
                # The aligned locus extends prior to the start of the reference scaffold. Without
                # correction, this could result in Stacks translating SNP coordinates into this
                # 'negative' space.
                #
                loc.pri_aln.bp = loc.pri_aln.bp + (loc.pri_aln.q_len - loc.pri_aln.bp)
                if loc.pri_aln.bp > chr_key[loc.pri_aln.chr]:
                    if verbose:
                        print("Locus {}, aligned to {} [CIGAR {}] extends past the starting bound; reference is too small to correct, removing alignment.".format(
                            loc.id_old, position, loc.pri_aln.cigar_str), file=sys.stderr)
                    loc.pri_aln = None
                    ref_too_small += 1
                    continue
                else:
                    art_corrected += 1
                    loc.position   = "{}:{}:-".format(loc.pri_aln.chr, int(loc.pri_aln.bp))
                    if verbose:
                        print("Locus {}, aligned to {} [CIGAR {}] extends past the starting bound; artifically adjusting to {} to prevent SNPs with negative basepair coordinates.".format(
                            loc.id_old, position, loc.pri_aln.cigar_str, loc.position), file=sys.stderr)
        else:
            if loc.pri_aln.bp + loc.pri_aln.q_len > chr_key[loc.pri_aln.chr]:
                #
                # The aligned locus extends out past the end of the reference scaffold. Without
                # correction, this could result in Stacks translating SNP coordinates into this
                # 'positive' space.
                #
                loc.pri_aln.bp = loc.pri_aln.bp - (loc.pri_aln.q_len - loc.pri_aln.r_len) + 1
                if loc.pri_aln.bp < 0:
                    if verbose:
                        print("Locus {}, aligned to {} [CIGAR {}] extends past the end bound; reference is too small to correct, removing alignment.".format(
                            loc.id_old, position, loc.pri_aln.cigar_str), file=sys.stderr)
                    loc.pri_aln = None
                    ref_too_small += 1
                    continue
                else:
                    art_corrected += 1
                    loc.position   = "{}:{}:+".format(loc.pri_aln.chr, int(loc.pri_aln.bp))
                    if verbose:
                        print("Locus {}, aligned to {} [CIGAR {}] extends past the end bound; artifically adjusting to {} to prevent SNPs with non-existant basepair coordinates.".format(
                            loc.id_old, position, loc.pri_aln.cigar_str, loc.position), file=sys.stderr)

    print("  Artificially altered the alignment positions of {} loci; removed {} loci that could not be altered properly.".format(art_corrected, ref_too_small), file=sys.stderr)


def write_catalog(loci, chrs, datestamp, new_locus_id, ordered_loci):
    #
    # Read in existing catalog sequences.
    #
    fh = gzip.open(stacks_path + "catalog.fa.gz", 'rt')

    eof  = False
    line = fh.readline()
    if len(line) == 0:
        return
    line = line.strip('\n ')

    while len(line) == 0 or line[0] == "#":
        line = fh.readline()
        if len(line) == 0:
            return
        line = line.strip('\n ')

    while not eof:
        seq = ""

        if line[0] == ">":
            # Parse and store the comments from the ID line
            parts  = line.split(' ')
            loc_id = int(parts[0][1:])
            ns     = parts[1]
            contig = ""
            if len(parts) > 2:
                contig = parts[2]
            line   = ""

        while len(line) == 0 or line[0] != ">":
            line = fh.readline()
            if len(line) == 0:
                eof = True
                break
            line = line.strip('\n ')
            if len(line) == 0 or line[0] == "#":
                continue
            if line[0] != ">":
                seq += line
        #
        # Record the final sequence.
        #
        if loc_id in loci:
            loci[loc_id].num_samples = ns
            loci[loc_id].contig      = contig
            loci[loc_id].sequence    = seq

    fh.close()
    
    #
    # Sort the loci by chromosome, output from largest to smallest chromosome.
    #
    for loc_id in loci:
        loc = loci[loc_id]

        if loc.pri_aln is None:
            continue

        if loc.pri_aln.chr not in ordered_loci:
            ordered_loci[loc.pri_aln.chr] = []
        ordered_loci[loc.pri_aln.chr].append(loc)

    for chr in ordered_loci:
        ordered_loci[chr].sort(key=attrgetter("pri_aln.bp"))

    #
    # Assign a new locus ID ordered by alignment location.
    #
    new_id = new_locus_id
    for chr, chrlen in chrs:
        if chr not in ordered_loci:
            continue
        
        for loc in ordered_loci[chr]:

            if loc.pri_aln is None:
                continue

            loc.id_new = new_id
            loci[loc.id_old].id_new = new_id
            new_id += 1
    
    #
    # Re-write the catalog FASTA file. Also write a file of chromosomes used.
    #
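    # Each record is written as ">{new_id} pos={chr}:{bp}:{strand} {num_samples} {contig}",
    # followed by the locus sequence; num_samples and contig are carried over verbatim
    # from the original catalog.fa.gz headers.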
    o_fa    = out_path + "catalog.fa.gz"
    out_fh  = gzip.open(o_fa, "wt")
    out_fh.write("# Generated by stacks-integrate-alignments, version {}; date {}\n# {}\n".format(version, datestamp, " ".join(sys.argv)))
    
    chrs_fh = open(out_path + "catalog.chrs.tsv", "w")
    chrs_fh.write("# Generated by stacks-integrate-alignments, version {}; date {}\n# {}\n".format(version, datestamp, " ".join(sys.argv)))
    chrs_fh.write("# Chrom\tLength\n")
    
    for chr, chrlen in chrs:
        if chr not in ordered_loci:
            continue

        chrs_fh.write("{}\t{}\n".format(chr, chrlen))
        
        for loc in ordered_loci[chr]:
            out_fh.write(">{} pos={} {} {}\n{}\n".format(loc.id_new, loc.position, loc.num_samples, loc.contig, loc.sequence))

    out_fh.close()
    chrs_fh.close()
    return o_fa


def write_locus_coords(ordered_loci, chrs, datestamp, out_path):
    #
    # Write a table mapping the new, position-ordered locus IDs to the original de novo IDs and their alignment coordinates.
    #
    o_coords = out_path + "locus_coordinates.tsv"

    out_fh = open(o_coords, "w")

    out_fh.write("# Generated by stacks-integrate-alignments, version {}; date {}\n# {}\n".format(version, datestamp, " ".join(sys.argv)))
    out_fh.write("# id_new\tid_old\taln_pos\n")

    for chr, chrlen in chrs:
        if chr not in ordered_loci:
            continue

        for loc in ordered_loci[chr]:
            out_fh.write("{}\t{}\t{}\n".format(loc.id_new, loc.id_old, loc.position))

    out_fh.close()
    return o_coords

    
def write_catalog_calls(loci, chrs, stacks_path, out_path, datestamp):
    #
    # Read in the existing catalog calls (VCF-like genotype) file.
    #
    fh = gzip.open(stacks_path + "catalog.calls", 'rt')

    #
    # Write the header.
    #
    out_fh = gzip.open(out_path + "catalog.calls", 'wt')

    for line in fh:
        if line[0:10] == "##filedate":
            out_fh.write("##filedate={}\n".format(datestamp))
            continue
        elif line[0:8] == "##source":
            out_fh.write("##source=\"Stacks v{}; {}\"\n".format(version, " ".join(sys.argv)))
            continue
        elif line[0] == "#":
            out_fh.write(line)
            continue
        else:
            break

    out_fh.close()
    
    #
    # Open the output pipe to sort and compress.
    #
    o_vcf   = out_path + "catalog.calls"
    cmd     = "sort -k 1,1n -k 2,2n - | gzip >> " + o_vcf
    out_fh  = os.popen(cmd, mode="w")

    #
    # Write the last line read above
    #
    line   = line.strip('\n')
    parts  = line.split('\t')
    loc_id = int(parts[0])

    pre_loc = loc_id
    loc_cnt = 0
    col_cnt = 50

    if loc_id in loci and loci[loc_id].pri_aln is not None:
        out_fh.write("{}\t{}\n".format(loci[loc_id].id_new, "\t".join(parts[1:])))

    #
    # Read from the input file, translate the locus ID, output to the pipe to sort and compress.
    #
    for line in fh:
        line   = line.strip('\n')
        parts  = line.split('\t')
        loc_id = int(parts[0])
        
        if loc_id not in loci:
            continue

        if loci[loc_id].pri_aln is None:
            continue

        if loc_id != pre_loc:
            pre_loc  = loc_id
            if loc_cnt % 1000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
                if col_cnt % 100 == 0:
                    print("\n  ", file=sys.stderr, end="", flush=True)
                col_cnt += 1
            loc_cnt += 1

        out_fh.write("{}\t{}\n".format(loci[loc_id].id_new, "\t".join(parts[1:])))

    print("\n  Waiting for sort and gzip to finish...", file=sys.stderr)

    fh.close()
    out_fh.close()

    return o_vcf


##
##------------------------------------------------------------------------------------------------##
##

loci = {}
chrs = []

parse_command_line()

#
# Find the maximum locus ID in the existing de novo catalog, and set our starting ID number for the integrated loci.
#
print("Finding the highest current locus ID... ", file=sys.stderr, end="")
max_locus_id = find_max_locus_id(stacks_path)
new_locus_id = max_locus_id + 1
print(max_locus_id, file=sys.stderr)

print("\nExtracting locus coordinates...", file=sys.stderr)
parse_bam_file(bam_path, loci, chrs)

print("\nFiltering alignments...", file=sys.stderr)
filter_alns(loci)

check_locus_bounds(loci, chrs)

datestamp = datetime.now().strftime("%Y%m%d")

ordered_loci = {}

print("\nRewriting locus catalog sequences and information...", file=sys.stderr)
o_fa = write_catalog(loci, chrs, datestamp, new_locus_id, ordered_loci)
print("  Wrote " + o_fa, file=sys.stderr)

o_coords = write_locus_coords(ordered_loci, chrs, datestamp, out_path)
print("  Wrote " + o_coords, file=sys.stderr)

print("\nReading, translating, and writing the catalog calls", file=sys.stderr, end="", flush=True)
o_vcf = write_catalog_calls(loci, chrs, stacks_path, out_path, datestamp)
print("  Wrote " + o_vcf, file=sys.stderr)

basename = os.path.basename(sys.argv[0])
print("\n{} is done.".format(basename))
