#!/usr/bin/env python

from __future__ import division

import sys
import os
    
def getFile(dir):
    """Return the path of the 'run.stats' file located in directory `dir`."""
    statsPath = os.path.join(dir, 'run.stats')
    return statsPath

def getLastRecord(dir):
    """Return the last complete record found at the end of `dir`'s stats file.

    Only the final 1024 bytes of the file are examined.  Scanning backwards,
    the first line that starts with '(' and ends with ')' is evaluated and
    returned.  If such a line contains a second '(' it is treated as
    corrupted: a warning is written to stderr and only the text from the
    last '(' onwards is evaluated.

    Raises:
        IOError: if the file cannot be opened, or if no complete record is
            present in the inspected tail.
    """
    tailSize = 1024
    path = getFile(dir)
    # Binary mode: text-mode files reject end-relative / arbitrary seeks on
    # Python 3, which previously made this read the *head* of the file.
    with open(path, 'rb') as f:
        f.seek(0, os.SEEK_END)
        f.seek(max(0, f.tell() - tailSize))
        tail = f.read(tailSize)
    if not isinstance(tail, str):
        # Python 3 returns bytes here; Python 2 already gives a native str.
        tail = tail.decode('utf-8', 'replace')

    for ln in reversed(tail.split('\n')):
        ln = ln.strip()
        if ln.startswith('(') and ln.endswith(')'):
            if '(' in ln[1:]:
                sys.stderr.write('WARNING: corrupted line in file, out of disk space?\n')
                # Keep only the last fragment; earlier ones are truncated.
                ln = ln[ln.rindex('('):]
            # NOTE(review): eval() executes whatever expression the file
            # contains -- only point this tool at stats files you trust.
            return eval(ln)
    raise IOError('no complete record found at the end of %s' % path)


class LazyEvalList:
    """Sequence wrapper that evaluates string entries on first access.

    `lines` is stored as-is.  When an indexed entry is still a `str`, it is
    passed to eval() and the result replaces the string in `lines`, so each
    entry is evaluated at most once.  Entries that are not strings are
    returned unchanged.
    """

    def __init__(self, lines):
        self.lines = lines

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        entry = self.lines[index]
        if not isinstance(entry, str):
            # Already evaluated (or never a string): hand it back untouched.
            return entry
        parsed = eval(entry)
        self.lines[index] = parsed
        return parsed


def getMatchedRecord(data,reference,key):
    """Binary-search `data` for the record matching `reference` under `key`.

    Index 0 of `data` is skipped (it is treated as a header).  The search
    returns the first record whose key is strictly greater than
    `key(reference)`; when no record exceeds it, the last record is returned.
    The search assumes keys do not decrease along `data`.
    """
    target = key(reference)
    first = 1  # index 0 is the header
    last = len(data) - 1
    while first < last:
        probe = (first + last) // 2
        if key(data[probe]) <= target:
            # Everything up to and including `probe` is too small.
            first = probe + 1
            continue
        last = probe
    return data[first]


def stripCommonPathPrefix(table, col):
    """Shorten the paths in column `col` of `table` by their shared prefix.

    Each row's path is normalized and split on '/'.  Leading components that
    are identical across all rows are dropped, and every row is replaced in
    place by a new tuple holding the shortened path.  If no component level
    differs, the last level shared by all rows is kept rather than dropped.
    """
    components = [os.path.normpath(row[col]).split('/') for row in table]

    # Index of the first component level at which the rows disagree.
    firstDiff = 0
    for firstDiff, level in enumerate(zip(*components)):
        if len(set(level)) > 1:
            break

    for rowNum, parts in enumerate(components):
        row = table[rowNum]
        shortened = '/'.join(parts[firstDiff:])
        table[rowNum] = row[:col] + (shortened,) + row[col+1:]


def getKeyIndex(keyName,labels):
    """Return the index of the column in `labels` named by `keyName`.

    Matching ignores case and a trailing parenthesized unit, so 'time',
    'Time' and 'Time(s)' all select the 'Time(s)' column, and 'Mem' selects
    'Mem(MB)'.  The first matching label wins.

    Raises:
        ValueError: if no label matches `keyName`.
    """
    def normalizedKey(key):
        # Drop one trailing "(unit)" group, e.g. "(s)", "(%)", "(MB)".
        if key.endswith(')') and '(' in key:
            key = key[:key.rindex('(')]
        return key.lower()

    wanted = normalizedKey(keyName)
    for i, title in enumerate(labels):
        if normalizedKey(title) == wanted:
            return i
    # Call-style raise and repr() are valid on both Python 2 and Python 3.
    raise ValueError('invalid keyName to sort by: %s' % repr(keyName))


def sortTable(table, labels, keyName, ascending=False):
    """Sort the rows of `table` in place by the column named `keyName`.

    `labels` gives the column titles used to resolve `keyName` (see
    getKeyIndex, which raises ValueError for an unknown name).  Rows are
    ordered descending unless `ascending` is true.
    """
    keyIndex = getKeyIndex(keyName,labels)
    # sorted() works on any sequence; the previous range(...).sort() only
    # worked where range() returns a list (Python 2).
    rows = sorted(table, key=lambda row: row[keyIndex])
    if not ascending:
        # Reverse after a stable ascending sort, as before, so rows with
        # equal keys keep the same relative order as they always did.
        rows.reverse()
    table[:] = rows


def printTable(table):
    """Write `table` to stdout as a fixed-width, '|'-separated text table.

    Each entry of `table` is either a sequence of cell values or a false
    value (None / empty), which is rendered as a horizontal rule of dashes.
    Cells are formatted with str(); floats use two decimals and None becomes
    an empty cell.  Rows shorter than the widest row are padded with empty
    cells.  The first column is left-aligned, all others right-aligned.

    The caller's `table` and its rows are not modified.
    """
    def strOrNone(ob):
        if ob is None:
            return ''
        elif isinstance(ob,float):
            return '%.2f'%ob
        else:
            return str(ob)

    # Format every data row into a fresh list of strings; keep None as the
    # marker for a horizontal rule.
    rows = []
    for row in table:
        if row:
            rows.append([strOrNone(elt) for elt in row])
        else:
            rows.append(None)

    dataRows = [r for r in rows if r]
    # `or [0]` keeps a table consisting only of rules from crashing max().
    maxLen = max([len(r) for r in dataRows] or [0])
    for r in dataRows:
        r.extend([''] * (maxLen - len(r)))

    widths = [max([len(r[j]) for r in dataRows]) for j in range(maxLen)]
    # Every cell is rendered as ' text |' (3 extra chars) after one leading '|'.
    rule = '-'*(sum(widths) + maxLen*3 + 1)

    out = sys.stdout
    for r in rows:
        if r is None:
            out.write(rule + '\n')
            continue
        cells = [' %-*s |'%(widths[0],r[0])]
        cells.extend([' %*s |'%(w,elt) for w,elt in zip(widths[1:],r[1:])])
        out.write('|' + ''.join(cells) + '\n')
        

def _percent(part, whole):
    """Return 100*part/whole as a float, or 0.0 when `whole` is zero."""
    if not whole:
        return 0.0
    return 100.*part/whole


def _padRecord(rec):
    """Return `rec` as a tuple of exactly 17 fields (truncated or 0-padded).

    Missing trailing fields are filled with 0 rather than None so the
    arithmetic done on records (percentages, totals) keeps working.
    """
    rec = tuple(rec[:17])
    return rec + (0,)*(17-len(rec))


def _statsRow(Path, rec):
    """Build the full 13-column display row for one 17-field record.

    Column order matches the label tuple used in main(): Path, Instrs,
    Time(s), ICov(%), BCov(%), ICount, Solver(%), States, Mem(MB), Queries,
    AvgQC, Tcex(%), Tfork(%).
    """
    (I,BFull,BPart,BTot,T,St,Mem,QTot,QCon,NObjs,Treal,SCov,SUnc,QT,Ts,Tcex,Tf) = rec

    # special case for straight-line code: report 100% branch coverage
    if BTot == 0:
        BFull = BTot = 1

    Mem = Mem/1024./1024.
    AvgQC = int(QCon/max(1,QTot))
    return (Path, I, Treal, _percent(SCov, SCov+SUnc),
            _percent(2*BFull+BPart, 2.*BTot), SCov+SUnc,
            _percent(Ts, Treal), St, Mem, QTot, AvgQC,
            _percent(Tcex, Treal), _percent(Tf, Treal))


def _loadStats(dir):
    """Read `dir`'s stats file into a LazyEvalList of record lines.

    Only lines that start with '(' and end with ')' are kept (the same test
    getLastRecord applies), so blank lines and a partially written final
    line are ignored.  Raises IOError if the file cannot be opened.
    """
    with open(getFile(dir)) as f:
        lines = [ln.strip() for ln in f]
    return LazyEvalList([ln for ln in lines
                         if ln.startswith('(') and ln.endswith(')')])


def _findRunDirs(roots):
    """Return the run directories named by, or found beneath, `roots`.

    A directory counts as a run directory when it contains an entry called
    'info'.  A root that is not itself a run directory is walked recursively
    and every run directory below it is collected.
    """
    found = []
    for top in roots:
        if os.path.exists(os.path.join(top,'info')):
            found.append(top)
            continue
        for parent,subdirs,_ in os.walk(top):
            for d in subdirs:
                p = os.path.join(parent,d)
                if os.path.exists(os.path.join(p,'info')):
                    found.append(p)
    return found


def main(args):
    """Print a summary table of the statistics of one or more runs.

    `args` is the full argument vector including the program name (i.e.
    sys.argv); args[1:] is parsed for options and directories.  One row is
    printed per run directory, plus a 'Total' row when there are several.

    Exits with status 2 on a usage error (no directories, unknown
    --sort-by / --compare-by key) and status 1 when no statistics could be
    loaded.
    """
    from optparse import OptionParser
    op = OptionParser(usage="usage: %prog [options] directories",
                      epilog="""\
LEGEND
------
Instrs:  Number of executed instructions
Time:    Total wall time (s)
ICov:    Instruction coverage in the LLVM bitcode (%)
BCov:    Branch coverage in the LLVM bitcode (%)
ICount:  Total static instructions in the LLVM bitcode
Solver:  Time spent in the constraint solver (%)
States:  Number of currently active states
Mem:     Megabytes of memory currently used
Queries: Number of queries issued to STP
AvgQC:   Average number of query constructs per query
Tcex:    Time spent in the counterexample caching code (%)
Tfork:   Time spent forking (%)""")

    op.add_option('', '--print-more', dest='printMore',
                  action='store_true', default=False,
                  help='Print extra information (needed when monitoring an ongoing run).')
    op.add_option('', '--print-all', dest='printAll',
                  action='store_true', default=False,
                  help='Print all available information.')
    op.add_option('','--sort-by', dest='sortBy',
                  help='key value to sort by, e.g. --sort-by=Instrs')
    op.add_option('','--ascending', dest='ascending',
                  action='store_true', default=False,
                  help='sort in ascending order (default is descending)')
    op.add_option('','--compare-by', dest='compBy',
                  help="key value on which to compare runs to the reference one (which is the first one).  E.g., --compare-by=Instrs shows how each run compares to the reference run after executing the same number of instructions as the reference run.  If a run hasn't executed as many instructions as the reference one, we simply print the statistics at the end of that run.")
    # Honour the argument vector we were given instead of silently
    # re-reading sys.argv.
    opts,dirs = op.parse_args(args[1:])

    if not dirs:
        op.error("no directories given.")

    allLabels = ('Path','Instrs','Time(s)','ICov(%)','BCov(%)','ICount','Solver(%)', 'States', 'Mem(MB)', 'Queries', 'AvgQC', 'Tcex(%)', 'Tfork(%)')
    if opts.printAll:
        labels = allLabels
    elif opts.printMore:
        labels = allLabels[:9]
    else:
        labels = allLabels[:7]

    # Validate the key options up front so a typo produces a usage error
    # instead of a traceback after all files have been read.
    if opts.sortBy:
        try:
            getKeyIndex(opts.sortBy, labels)
        except ValueError as e:
            op.error(str(e))
    compIndex = None
    if opts.compBy:
        try:
            compIndex = getKeyIndex(opts.compBy, allLabels)
        except ValueError:
            op.error('invalid key for --compare-by: %r' % (opts.compBy,))
        if compIndex == 0:
            op.error('cannot compare runs by path.')

    # Load every run; unreadable or empty ones are reported and skipped.
    datas = []
    for dir in _findRunDirs(dirs):
        try:
            data = _loadStats(dir)
        except IOError:
            sys.stdout.write('Unable to open:  %s\n' % (dir,))
            continue
        if len(data) < 2:
            # Index 0 is the header line, so there is no record to report.
            sys.stderr.write('WARNING: no statistics recorded in %s\n' % (dir,))
            continue
        datas.append((dir,data))

    if not datas:
        sys.exit(1)

    # The reference run is the first one; its final record is the target.
    reference = datas[0][1][-1]

    def compKey(rec):
        # Compare on the value actually shown in the selected column.  The
        # lookup in getMatchedRecord is a binary search, so it is exact only
        # for columns that never decrease over a run (e.g. Instrs, Time).
        return _statsRow(None, _padRecord(rec))[compIndex]

    table = []
    summary = []
    for dir,data in datas:
        if compIndex is None:
            rec = data[len(data)-1]
        else:
            rec = getMatchedRecord(data,reference,compKey)
        rec = _padRecord(rec)
        table.append(_statsRow(dir,rec)[:len(labels)])
        if summary:
            summary = [(a+b) for a,b in zip(summary,rec)]
        else:
            summary = list(rec)

    stripCommonPathPrefix(table, 0)
    if opts.sortBy:
        sortTable(table, labels, opts.sortBy, opts.ascending)
    if len(table)>1:
        # Percentages in the total row are computed from the summed fields.
        total = _statsRow('Total (%d)'%(len(table),),summary)[:len(labels)]
        table.append(None)
        table.append(total)
    table[0:0] = [None,labels,None]
    table.append(None)
    printTable(table)
        
        
# Script entry point: when this file is executed directly (not imported),
# call main() with the process's full argument vector.
if __name__=='__main__':
    main(sys.argv)
