#!/usr/local/bin/python2.7
"""
    Generates an HTML report from one or more pf log files. At the moment this
    is little more than a capability demonstration - in time, though, it could
    grow into a really useful tool.

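    This program is part of PyOpenBSD. It requires Cubictemp 0.4 or later and
    uses PyGDChart2 for its graphical reports. Note that PyGDChart2 is
    currently imported unconditionally at startup, so in practice it must be
    installed as well. Both can be found at the
    Nullcube website (http://www.nullcube.com).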
"""
import sys, optparse, os.path
import gdchart, openbsd, cubictemp
VERSION = "0.2"  # program version, reported by the --version option in main()

class SkipReport(Exception):
    """
        Raised by a report's constructor to tell Reporter.addReport that the
        report cannot run (e.g. a dependency is missing) and should be left
        out of the output.
    """

class Counter(dict):
    """
        A dictionary mapping each item to the number of times it has been
        passed to add().
    """
    def add(self, item):
        """
            Record one more occurrence of item.
        """
        self[item] = self.get(item, 0) + 1

    def counts(self):
        """
            Returns a list of (count, item) pairs sorted in descending order:
            the most frequent item first, ties broken by descending item.
        """
        # (count, item) tuples sort by count first; items must be mutually
        # comparable for ties to be ordered.
        return sorted(((n, item) for item, n in self.items()), reverse=True)


class Reporter:
    """ 
        This class does two things:

            - It reads the pflog file(s), passing each packet to the .gather
              method of each registered report.
            - It creates the final HTML file (index.html) that constitutes
              the "index" to the report.
    """
    Template = """
        <html>
            <style type="text/css" media=screen>
                <!--
                    body, table, input { 
                        font-family:        verdana,arial,helvetica,sans-serif; 
                        font-size:          large; 
                    }
                    th { 
                        background-color:   #9aabb7; 
                        font-weight:        bold; 
                    }
                    .highlight { 
                        background-color:   #9aabb7; 
                        font-weight:        bold; 
                    }
                    table {
                        width:              50%;
                    }
                -->
            </style>
            <h1> @!title!@ </h1>
            
            <table border="1" cellpadding="5" cellspacing="0">
                <tr>
                    <td class="highlight"> Start: </td>
                    <td> @!first!@ </td>
                </tr>
                <tr>
                    <td class="highlight"> End: </td>
                    <td> @!last!@ </td>
                </tr>
                <tr>
                    <td class="highlight"> Total Packets: </td>
                    <td> @!count!@       </td>
                </tr>
            </table>

            <p/>
            <!--(for i in reports)-->
                @!i!@
                <p/>
            <!--(end)-->
            <p/>
        </html>
    """
    TIMEFMT = "%b %d %H:%M:%S"
    # Shown instead of the Start/End timestamps when no packets were read.
    NOTIME = "n/a"

    def __init__(self):
        self.reports = []
        # Timestamps of the first and last packet seen; None until then.
        self.first, self.last = None, None
        self.count = 0

    def addReport(self, report, *args, **kwargs):
        """
            Instantiate report(*args, **kwargs) and register the result. A
            report whose constructor raises SkipReport is left out.
        """
        try:
            self.reports.append(report(*args, **kwargs))
        except SkipReport:
            pass

    def process(self, logfile):
        """
            Pass every packet in logfile to this object's __call__. The feed
            is closed even if reading fails part-way through.
        """
        feed = openbsd.pcap.Offline(logfile)
        try:
            feed.loop(-1, self, interpret=1)
        finally:
            feed.close()

    def _formatTime(self, tstamp):
        """
            Format a packet timestamp for the index page, or NOTIME if no
            packet was seen.
        """
        if tstamp is None:
            return self.NOTIME
        return tstamp.strftime(self.TIMEFMT)

    def render(self, directory, title):
        """
            Render the index page to directory/index.html, creating directory
            (and any missing parents) if needed. Reports are rendered as part
            of the template, so they may write their own files into directory.
        """
        if not os.path.isdir(directory):
            os.makedirs(directory)
        t = cubictemp.Temp(
                                self.Template,
                                reports=self.reports,
                                title=title,
                                count=self.count,
                                first=self._formatTime(self.first),
                                last=self._formatTime(self.last)
                            )
        # Render before opening the file, so a failing report does not leave
        # a truncated index.html behind.
        html = repr(t)
        with open(os.path.join(directory, "index.html"), "w") as f:
            f.write(html)

    def __call__(self, packet, tstamp, length):
        """
            Per-packet callback handed to feed.loop() in process(): tracks the
            first/last timestamp and packet count, and forwards the packet to
            every report's gather().
        """
        if self.first is None:
            self.first = tstamp
        self.last = tstamp
        self.count += 1
        for i in self.reports:
            i.gather(packet, tstamp, length)


class _ReportBase:
    """
        Base class from which all reports are derived.
    """
    _cubictemp_unescaped = 1
    def __init__(self, directory):
        self.directory = directory

    def gather(self, packet, tstamp, length):
        """
            This method is called for every packet in the logfiles.
        """
        raise NotImplementedError

    def __str__(self):
        """
            This method is called after all packets have been passed to the
            report. It should perform any finalisation required (closing files,
            etc.), and should return None, or an HTML snippet to be embedded in
            the report index page.
        """
        raise NotImplementedError


class OverviewReport(_ReportBase):
    """
        High-level graphical overview of logged traffic: a line chart of the
        number of packets per day, drawn to overview.png.
    """
    LABELFMT = "%d/%m"

    def __init__(self, *args, **kwargs):
        try:
            import gdchart
        except ImportError:
            raise SkipReport
        _ReportBase.__init__(self, *args, **kwargs)
        # Packet counts and x-axis labels for every *completed* day, in
        # calendar order (days without packets are padded with 0 / " ").
        self.data = []
        self.labels = []
        # Day currently being counted (None before the first packet) and the
        # number of packets seen on it so far.
        self.currentday = None
        self.counter = 0

    def gather(self, packet, tstamp, length):
        day = tstamp.date()
        if self.currentday is None:
            self.currentday = day
            self.counter = 1
        elif day == self.currentday:
            self.counter += 1
        else:
            # The current day is over: record it first, then pad any days in
            # between that had no packets, so the series stays in order.
            self.data.append(self.counter)
            self.labels.append(self.currentday.strftime(self.LABELFMT))
            for _ in range((day - self.currentday).days - 1):
                self.data.append(0)
                self.labels.append(" ")
            self.currentday = day
            # This packet is the first one of the new day.
            self.counter = 1

    def _series(self):
        """
            Return (data, labels) including the still-open current day,
            without modifying the report's state.
        """
        if self.currentday is None:
            return [], []
        return (self.data + [self.counter],
                self.labels + [self.currentday.strftime(self.LABELFMT)])

    def __str__(self):
        data, labels = self._series()
        if not data:
            # No packets were gathered: nothing to chart.
            return ""
        g = gdchart.Line()
        g.bg_color = "white"
        g.width = 600
        g.height = 200
        g.setData(data)
        g.setLabels(labels)
        g.plot_color = "red"
        g.draw(os.path.join(self.directory, "overview.png"))
        return "<h2> Traffic Graph </h2><p><img src=\"overview.png\">"


class _TopAddressBase(_ReportBase):
    """
        Base class for top N address reports. Subclasses fill the two %s
        placeholders of Template (heading and column title, in that order)
        and implement gather() to feed addresses into self.c.
    """
    Template = """
        <h2> Top @!n!@ %s Addresses </h2>

        <table border="1" cellpadding="5" cellspacing="0">
            <tr>
                <th></th>
                <th>%s IP</th>
                <th>Occurrences</th>
            </tr>
            <!--(for i in range(len(values)))-->
                <tr>
                    <td>
                        <b>@!i + 1!@</b>
                    </td>
                    <td>
                        @!values[i][1]!@
                    </td>
                    <td>
                        @!values[i][0]!@
                    </td>
                </tr>
            <!--(end)-->
        </table>
    """
    def __init__(self, directory, n):
        _ReportBase.__init__(self, directory)
        # Maximum number of table rows to show.
        self.n = n
        # Occurrence count per address.
        self.c = Counter()

    def __str__(self):
        # (count, address) pairs, most frequent first; the template shows
        # values[i][1] (address) and values[i][0] (count).
        top = self.c.counts()[:self.n]
        t = cubictemp.Temp(self.Template, values=top, n=self.n)
        return repr(t)


class TopSources(_TopAddressBase):
    """
        Report listing the N most frequent source IP addresses.
    """
    Type = "Source"
    Template = _TopAddressBase.Template % ((Type,) * 2)

    def gather(self, packet, tstamp, length):
        # Packets without an "ip" entry are skipped.
        try:
            address = packet["ip"].src
        except KeyError:
            return
        self.c.add(address)


class TopDestinations(_TopAddressBase):
    """
        Report listing the N most frequent destination IP addresses.
    """
    Type = "Destination"
    Template = _TopAddressBase.Template % ((Type,) * 2)

    def gather(self, packet, tstamp, length):
        # Packets without an "ip" entry are skipped.
        try:
            address = packet["ip"].dst
        except KeyError:
            return
        self.c.add(address)


class _TopPortBase(_ReportBase):
    """
        Base class for top N port reports. Subclasses fill the two %s
        placeholders of Template (heading and column title, in that order)
        and implement gather() to feed port numbers into self.c.
    """
    Template = """
        <h2> Top @!n!@ %s Ports </h2>

        <table border="1" cellpadding="5" cellspacing="0">
            <tr>
                <th></th>
                <th>%s Port</th>
                <th>Occurrences</th>
            </tr>
            <!--(for i in range(len(values)))-->
                <tr>
                    <td>
                        <b>@!i + 1!@</b>
                    </td>
                    <td>
                        @!values[i][1]!@
                    </td>
                    <td>
                        @!values[i][0]!@
                    </td>
                </tr>
            <!--(end)-->
        </table>
    """
    def __init__(self, directory, n):
        _ReportBase.__init__(self, directory)
        # Maximum number of table rows to show.
        self.n = n
        # Occurrence count per port.
        self.c = Counter()

    def __str__(self):
        # (count, port) pairs, most frequent first; the template shows
        # values[i][1] (port) and values[i][0] (count).
        top = self.c.counts()[:self.n]
        t = cubictemp.Temp(self.Template, values=top, n=self.n)
        return repr(t)


class TopSourcePorts(_TopPortBase):
    """
        Report listing the N most frequent TCP/UDP source ports.
    """
    Type = "Source"
    Template = _TopPortBase.Template % ((Type,) * 2)

    def gather(self, packet, tstamp, length):
        # TCP takes precedence over UDP; packets with neither are skipped.
        for layer in ("tcp", "udp"):
            try:
                port = packet[layer].srcPort
            except KeyError:
                continue
            self.c.add(port)
            return


class TopDestinationPorts(_TopPortBase):
    """
        Report listing the N most frequent TCP/UDP destination ports.
    """
    Type = "Destination"
    Template = _TopPortBase.Template % ((Type,) * 2)

    def gather(self, packet, tstamp, length):
        # TCP takes precedence over UDP; packets with neither are skipped.
        for layer in ("tcp", "udp"):
            try:
                port = packet[layer].dstPort
            except KeyError:
                continue
            self.c.add(port)
            return


class ProtocolAnalysis(_ReportBase):
    """
        Pie chart of packets broken down by IP protocol (UDP, TCP, ICMP,
        IGMP), drawn to protocols.png. Packets carrying none of these are
        not counted.
    """
    # (packet key, chart label) pairs, checked in this order; a packet is
    # counted under the first key it has.
    Protocols = (
        ("udp", "UDP"),
        ("tcp", "TCP"),
        ("icmp", "ICMP"),
        ("igmp", "IGMP"),
    )

    def __init__(self, directory):
        try:
            import gdchart
        except ImportError:
            raise SkipReport
        _ReportBase.__init__(self, directory)
        # Packet count per protocol label.
        self.c = Counter()

    def gather(self, packet, tstamp, length):
        for key, label in self.Protocols:
            if packet.has_key(key):
                self.c.add(label)
                break

    def __str__(self):
        counts = self.c.counts()
        if not counts:
            # No matching packets: avoid drawing a pie with no data.
            return ""
        data = [n for n, label in counts]
        labels = ["%s - %s"%(label, n) for n, label in counts]
        g = gdchart.Pie()
        g.bg_color = "white"
        g.width = 400
        g.height = 400
        g.setData(*data)
        g.setLabels(labels)
        g.color = ["red", "green", "yellow", "blue"]
        g.draw(os.path.join(self.directory, "protocols.png"))
        return "<h2> Protocol Breakdown </h2><p><img src=\"protocols.png\">"
            

def main():
    """
        Command-line entry point: parse options, run the selected reports over
        the given logfiles, and write the HTML report into destdir.
    """
    parser = optparse.OptionParser(usage="%prog [options] destdir logfile1 logfile2 ...",
                                    version="pfreport v%s"%VERSION)
    parser.add_option("-t", "--title", dest="title", help="Report Title", default="PF Log Report")
    reports = optparse.OptionGroup(parser, "Reports",
                        "The following options define which reports are included in the output.")
    # -g and -p store True when the user asks to *disable* the report.
    reports.add_option("-g",    dest="overview", action="store_true",
                                help="Do not show graphical traffic overview.")
    reports.add_option("-p",    dest="protocol", action="store_true",
                                help="Do not show protocol breakdown.")
    # optparse expands %default in help strings, keeping them in sync with
    # the actual defaults.
    reports.add_option("-s",    dest="sources", metavar="N", type="int",
                                default=20,
                                help="Show top N source IPs. Set to 0 to turn off. (Default: %default)")
    reports.add_option("-d",    dest="destinations", metavar="N", type="int",
                                default=20,
                                help="Show top N destination IPs. Set to 0 to turn off. (Default: %default)")
    reports.add_option("-r",    dest="sourceports", metavar="N", type="int",
                                default=20,
                                help="Show top N source ports. Set to 0 to turn off. (Default: %default)")
    reports.add_option("-e",    dest="destinationports", metavar="N", type="int",
                                default=20,
                                help="Show top N destination ports. Set to 0 to turn off. (Default: %default)")
    parser.add_option_group(reports)

    options, args = parser.parse_args()
    if len(args) < 2:
        parser.error("Not enough arguments.")
    # A negative N would silently turn the top-N slice into "all but N".
    for flag, value in (("-s", options.sources),
                        ("-d", options.destinations),
                        ("-r", options.sourceports),
                        ("-e", options.destinationports)):
        if value < 0:
            parser.error("%s: N must be 0 or greater."%flag)

    directory = args[0]
    logfiles = args[1:]

    r = Reporter()
    if not options.overview:
        r.addReport(OverviewReport, directory)
    if not options.protocol:
        r.addReport(ProtocolAnalysis, directory)
    if options.sources:
        r.addReport(TopSources, directory, options.sources)
    if options.destinations:
        r.addReport(TopDestinations, directory, options.destinations)
    if options.sourceports:
        r.addReport(TopSourcePorts, directory, options.sourceports)
    if options.destinationports:
        r.addReport(TopDestinationPorts, directory, options.destinationports)

    for i in logfiles:
        r.process(i)
    r.render(directory, options.title)


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not
    # when imported as a module.
    main()
