#!/usr/bin/python
# you need to install the following packages:
# (i) python-gtk2 (python-gtk2-dev may be needed as well)
# (ii) python-matplotlib

import os,sqlite3,sys,time
import optparse

import gtk

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg as FigureCanvas
from matplotlib.backends.backend_gtkagg import NavigationToolbar2GTKAgg as NavigationToolbar

def Ws(s):
    """
    write string s to the standard output as it is
    (no newline is appended)
    """
    out = sys.stdout
    out.write(s)
def Es(s):
    """
    write string s to the standard error as it is
    (no newline is appended)
    """
    err = sys.stderr
    err.write(s)

def sql_execute(co, sql):
    """
    run the statement sql on connection co and return whatever
    co.execute returns.  the statement and the time spent in
    co.execute (plus logging it) are reported on the standard output.
    """
    started = time.time()
    Ws("[SQL] %s\n" % sql)
    cursor = co.execute(sql)
    elapsed = time.time() - started
    Ws("took %f sec\n" % elapsed)
    return cursor

class dr_database:
    """
    dr_database facilitates reading dr database.
    every method takes an open sqlite connection (co) and
    runs its query through sql_execute.
    """
    def get_finish_t(self, co):
        """
        return max(end_t) over the nodes table
        (None when the table has no rows, as SQL max yields NULL)
        """
        for finish_t, in sql_execute(co, "select max(end_t) from nodes"):
            pass
        return finish_t

    def get_start_clock(self, co):
        """
        return start_clock of the basics table.
        when the table has several rows, the value of the
        last row returned is used.
        """
        for start_clock, in sql_execute(co, "select start_clock from basics"):
            pass
        return start_clock

    def get_pp_events(self, co):
        """
        generate a list of (t,r,e,c,cc,wc)s, ordered by t.

        each row is a change that takes effect at time t:
        r is the change of "running", and e,c,cc,wc the changes of
        the "end", "create", "create cont" and "wait cont" values.
        every node with cur_node_count = 1 contributes four rows:
          +x at first_ready_t and -x at last_start_t, where x is
             t_ready_* / (last_start_t - first_ready_t)
          +x at start_t and -x at end_t, where x is
             t_1 / (end_t - start_t)

        nodes whose interval has zero length are left out of the
        corresponding pair of rows.  sqlite evaluates a division by
        zero to NULL, which would reach the caller as None instead
        of a number; such an interval adds nothing over time anyway.
        """
        ready_plus = """
        select 
        first_ready_t                                        as t,
        0.0                                                  as running,
        t_ready_end        /(last_start_t-first_ready_t+0.0) as end,
        t_ready_create     /(last_start_t-first_ready_t+0.0) as creat,
        t_ready_create_cont/(last_start_t-first_ready_t+0.0) as create_cont,
        t_ready_wait_cont  /(last_start_t-first_ready_t+0.0) as wait_cont
        from nodes
        where cur_node_count = 1 and last_start_t <> first_ready_t
        """
        ready_minus = """
        select 
        last_start_t                                          as t,
        0.0                                                   as running,
        -t_ready_end        /(last_start_t-first_ready_t+0.0) as end,
        -t_ready_create     /(last_start_t-first_ready_t+0.0) as creat,
        -t_ready_create_cont/(last_start_t-first_ready_t+0.0) as create_cont,
        -t_ready_wait_cont  /(last_start_t-first_ready_t+0.0) as wait_cont
        from nodes
        where cur_node_count = 1 and last_start_t <> first_ready_t
        """
        running_plus = """
        select 
        start_t                 as t,
        t_1/(end_t-start_t+0.0) as running,
        0.0                     as end,
        0.0                     as creat,
        0.0                     as create_cont,
        0.0                     as wait_cont
        from nodes
        where cur_node_count = 1 and end_t <> start_t
        """
        running_minus = """
        select 
        end_t                    as t,
        -t_1/(end_t-start_t+0.0) as running,
        0.0                      as end,
        0.0                      as creat,
        0.0                      as create_cont,
        0.0                      as wait_cont
        from nodes
        where cur_node_count = 1 and end_t <> start_t
        """
        # merge the four kinds of events into a single stream sorted by time
        cmd_all = ("""
        select * from 
        (%s union all %s union all %s union all %s) 
        order by t""" 
                   % (ready_plus, ready_minus, 
                      running_plus, running_minus))
        return sql_execute(co, cmd_all)

class drview:
    def __init__(self, opts, args):
        self.opts = opts
        self.args = args
        self.dr_sqlite_co = None
        self.search_paths = [ "." ]
        # widgets set in setup()
        self.win = None
        self.fig = None
        self.parallelism_ax = None
        self.table_widget = None
        self.scrolled_text = None
        self.textview = None
        self.n_lines = 0
        self.line_clicked = 0

    def open_sqlite_file(self):
        co = sqlite3.connect(self.args[0])
        self.dr_sqlite_co = co
        return co

    def close_sqlite_file(self):
        self.dr_sqlite_co.close()

    def setup(self):
        """
        open the database and build the whole window:
        a graph with its navigation bar on the left, a table in the
        middle and a text pane on the right.  the widgets are kept in
        attributes of self and the window is shown at the end.
        """
        # register the search paths given on the command line via
        # add_search_path, so that non-existent ones are reported.
        # paths already in the list (e.g. the default ".") are skipped;
        # a duplicate could never yield a different search result.
        for p in self.opts.search_paths:
            if p not in self.search_paths:
                self.add_search_path(p)
        self.dr_sqlite_co = self.open_sqlite_file()
        # main window
        win = gtk.Window()
        win.move(100, 100)
        win.set_size_request(500,500)
        win.set_title("%s" % self.__class__.__name__)
        win.connect("destroy", self.destroy)
        self.win = win
        # +----------+----------+----------+
        # | graph    | table    | code     |
        # |          |          |          |
        # |          |          |          |
        # +----------+          |          |
        # |navi bar  |          |          |
        # +----------+----------+----------+
        # hpane
        hpane = gtk.HPaned()
        win.add(hpane)
        # vbox
        vbox = gtk.VBox()
        hpane.pack1(vbox, resize=1, shrink=1)
        # fig
        fig = plt.figure() # figsize=(5,4)
        self.fig = fig
        # graph (load_parallelism_data draws into self.parallelism_ax,
        # so the attribute must be set before the call)
        parallelism_ax = fig.add_subplot(1, 1, 1)
        parallelism_ax.grid(1)
        self.parallelism_ax = parallelism_ax
        self.load_parallelism_data()
        canvas = FigureCanvas(fig)
        vbox.pack_start(canvas)
        # toolbar
        toolbar = NavigationToolbar(canvas, win)
        vbox.pack_start(toolbar, expand=False, fill=False)
        # hpane of table and code
        hpane_right = gtk.HPaned()
        hpane_right.show()
        hpane.pack2(hpane_right, resize=1, shrink=1)
        # scroll window for the table
        scrolled_table = gtk.ScrolledWindow()
        hpane_right.pack1(scrolled_table, resize=1, shrink=1)
        # table inside scroll window (load_table_data_1 fills the model
        # of self.table_widget, so the attribute must be set first)
        table_widget = self.make_table_1()
        self.table_widget = table_widget
        self.load_table_data_1(None)
        table_widget.connect("row-activated", self.on_row_activated_1)
        scrolled_table.add_with_viewport(table_widget)
        # textview
        scrolled_text = gtk.ScrolledWindow()
        hpane_right.pack2(scrolled_text, resize=1, shrink=1)
        textview = gtk.TextView()
        textbuffer = textview.get_buffer()
        textbuffer.set_text("")
        scrolled_text.add_with_viewport(textview)
        self.scrolled_text = scrolled_text
        self.textview = textview
        self.n_lines = 0
        self.line_clicked = 0
        # event handlers
        self.set_event_handlers()
        # show all
        win.show_all()

    def add_search_path(self, p):
        if os.path.exists(p):
            self.search_paths.append(p)
        else:
            Es("warning: search path %s does not exist (ignored)\n" % p)

    def search_file(self, filename):
        for d in self.search_paths:
            if filename[:1] == "/":
                p = filename
            else:
                p = "%s/%s" % (d, filename)
            if os.path.exists(p):
                return p
        return None

    def on_row_activated_1(self, table_widget, path, view_column):
        """
        callback when the table pane is clicked.
        open the file of the clicked location and color its line.

        columns 0 and 1 of the row hold "file:line" locations.
        clicking the "start" column selects the first one, any other
        column the second one.  the selected line is colored cyan
        (start) or pink (end); when the other location is in the same
        file its line is colored too, and a single magenta line is
        used when both locations are the same line.
        when the file is not found, a "<... not found>" message is
        shown instead.
        """
        model = table_widget.get_model()
        row, = path
        col = view_column.get_title() # start,end,t,count
        if col == "start":
            clicked_column = 0
        else:
            clicked_column = 1
        colors = [ "cyan", "pink" ]
        locs = [ model[row][0], model[row][1] ]
        files_and_lines = [ loc.rsplit(":", 1) for loc in locs ]
        [ file_clicked,line_clicked ] = files_and_lines[clicked_column]
        [ the_other_file,the_other_line ] = files_and_lines[1-clicked_column]
        line_clicked = int(line_clicked)
        the_other_line = int(the_other_line)
        file_found = self.search_file(file_clicked)
        if file_found is None:
            content = "<%s not found>" % file_clicked
            n_lines = 0
        else:
            # with closes the file even if reading fails
            with open(file_found) as fp:
                content = fp.read()
            n_lines = len(content.split("\n"))
        textbuffer = self.textview.get_buffer()
        textbuffer.set_text(content)
        # record the scroll target in both cases.  scroll_text also runs
        # on every size-allocate of the text view; with n_lines = 0 it
        # does nothing, instead of scrolling the "not found" message to
        # the position left by the previous click
        self.n_lines = n_lines
        self.line_clicked = line_clicked
        # color begin and/or end 
        if file_found is not None:
            same_file = (the_other_file == file_clicked)
            if same_file and the_other_line == line_clicked:
                self._highlight_line(textbuffer, line_clicked, "magenta")
            else:
                self._highlight_line(textbuffer, line_clicked,
                                     colors[clicked_column])
                if same_file:
                    self._highlight_line(textbuffer, the_other_line,
                                         colors[1-clicked_column])
            Ws("%s %s\n" % (self.line_clicked, self.n_lines))
            self.scroll_text(None, None)
        self.win.queue_draw()

    def _highlight_line(self, textbuffer, line_no, color):
        """
        give line line_no (counted from 1) of textbuffer
        the background color
        """
        # buffer lines are counted from 0
        it0 = textbuffer.get_iter_at_line(line_no-1)
        it1 = textbuffer.get_iter_at_line(line_no)
        tag = textbuffer.create_tag()
        tag.set_property("background", color)
        textbuffer.apply_tag(tag, it0, it1)

    def scroll_text(self, widget, event, data=None):
        c = self.line_clicked
        n = self.n_lines
        if n > 0:
            adj = self.scrolled_text.get_vadjustment()
            l,u = adj.get_lower(), adj.get_upper()
            v = (u * c + l * (n - c))/n
            print l,u,v
            adj.set_value(v - adj.get_page_size() * 0.5)

    def on_row_activated_2(self, table_widget, path, view_column):
        """
        callback when the table pane is clicked.
        open the corresponding file and put a 
        red mark at the line

        column 0 of the row holds a "file:line" location.  the file
        is looked up with search_file, like on_row_activated_1 does.
        when it is not found, a "<... not found>" message is shown
        and nothing is marked.
        """
        model = table_widget.get_model()
        row, = path
        loc = model[row][0]
        [ filename,line ] = loc.rsplit(":", 1)
        line = int(line)
        file_found = self.search_file(filename)
        if file_found is None:
            content0 = "<%s not found>" % filename
        else:
            with open(file_found) as fp:
                content0 = fp.read()
        textview = self.textview
        textbuffer = textview.get_buffer()
        textbuffer.set_text(content0)
        if file_found is not None:
            # buffer lines are counted from 0 whereas the line of the
            # location is taken to count from 1, as in
            # on_row_activated_1 -- NOTE(review): confirm against the
            # database that provides line_no
            it0 = textbuffer.get_iter_at_line(line-1)
            it1 = textbuffer.get_iter_at_line_offset(line-1, 1)
            tag = textbuffer.create_tag()
            tag.set_property("background", "red")
            textbuffer.apply_tag(tag, it0, it1)
        self.win.queue_draw()

    def on_xylim_changed(self, ax):
        """
        called when the visible region of the graph changes
        (e.g., a rectangular region is selected).
        matplotlib redraws the graph by itself; here the table
        is recomputed for the horizontal range now visible.
        """
        x_lo, _y_lo, x_span, _y_span = ax.viewLim.bounds
        visible = (x_lo, x_lo + x_span)
        self.load_table_data_1(visible)
        self.table_widget.queue_draw()
        
    def set_event_handlers(self):
        """
        connect the callbacks: changes of the x or y limits of the
        graph go to on_xylim_changed, and size allocation of the
        text view goes to scroll_text
        """
        callbacks = self.parallelism_ax.callbacks
        for signal in ("xlim_changed", "ylim_changed"):
            callbacks.connect(signal, self.on_xylim_changed)
        self.textview.connect("size-allocate", self.scroll_text)

    def destroy(self, x):
        """
        handler of the "destroy" signal of the main window
        (connected in setup): leave the gtk main loop.
        x is whatever the signal passes; it is handed on to
        gtk.main_quit.
        """
        quit_main_loop = gtk.main_quit
        quit_main_loop(x)

    def make_table_1(self):
        """
        make the widget for the table pane:
        a tree view over a list store of (str, str, long, long),
        with one text column per field, titled
        start, end, t and count
        """
        store = gtk.ListStore(str, str, long, long)
        tree_view = gtk.TreeView(store)
        titles = [ "start", "end", "t", "count" ]
        for index in range(len(titles)):
            # the column shows field number index of each row
            renderer = gtk.CellRendererText()
            column = gtk.TreeViewColumn(titles[index], renderer, text=index)
            tree_view.append_column(column)
        return tree_view

    def make_table_2(self):
        """
        make the widget for the table pane:
        a tree view over a list store of (str, str, long), which
        load_table_data_2 fills with (file:line_no, fun, count).
        there is one text column per field, so that each title
        sits above the field it names.
        """
        table = gtk.ListStore(str, str, long) # file:line_no fun, count
        table_widget = gtk.TreeView(table)
        for i,label in enumerate([ "point", "fun", "count" ]):
            vc = gtk.TreeViewColumn(label, gtk.CellRendererText(), 
                                    text=i)
            table_widget.append_column(vc)
        return table_widget

    def load_table_data_1(self, xlim):
        """
        fill the table from the database kept in self.dr_sqlite_co.

        xlim : a pair (from, to) of times, or None for the range
               from 0 to the finish time of the database.

        nodes with cur_node_count = 1 that overlap the range are
        grouped by their start and end locations; each group becomes
        a row (start file:line, end file:line, sum of end_t - start_t,
        number of nodes), rows with the largest sum first.
        """
        co = self.dr_sqlite_co
        if xlim is None:
            xlim = (0, dr_database().get_finish_t(co))
        t_to, t_from = xlim[1], xlim[0]
        # gtk.ListStore(str, str, long, long)
        table_store = self.table_widget.get_model()
        sql = ("""select 
        s1.file,start_line,s2.file,end_line,
        sum(end_t - start_t) as t,
        count(*) as count
        from nodes,strings s1,strings s2 
        on s1.idx = start_file and s2.idx = end_file 
        where cur_node_count = 1 and start_t <= %s and end_t >= %s
        group by start_file,start_line,end_file,end_line 
        order by t desc""" % (t_to, t_from))
        # throw away the previous rows before loading the new ones
        table_store.clear()
        for row in sql_execute(co, sql):
            start_f, start_l, end_f, end_l, t, count = row
            table_store.append(("%s:%d" % (start_f, start_l),
                                "%s:%d" % (end_f, end_l),
                                t, count))

    def load_table_data_2(self, dr_sqlite_file, bt_sqlite_file, xlim):
        """
        fill the table with sample counts read from bt_sqlite_file.

        dr_sqlite_file : sqlite file providing the finish time and
                         the start clock
        bt_sqlite_file : sqlite file with samples, frames and locs
        xlim           : a pair (from, to) of times, or None for the
                         range from 0 to the finish time.  the start
                         clock is added to both ends before they are
                         compared with tsc.

        each row is (filename:line_no, fun, count).
        both connections are closed on return, also when an
        error is raised on the way.
        """
        co_bt = sqlite3.connect(bt_sqlite_file)
        try:
            co_dr = sqlite3.connect(dr_sqlite_file)
            try:
                dr = dr_database()
                if xlim is None:
                    xlim = (0, dr.get_finish_t(co_dr))
                start_clock = dr.get_start_clock(co_dr)
            finally:
                # the dr database is not needed any further
                co_dr.close()
            # gtk.ListStore(str, str, long)
            table_store = self.table_widget.get_model()
            sql = ("""select filename,line_no,fun,count(*) from 
        samples natural join frames natural join locs 
        where depth = 2 
        and   idepth = 0
        and   %d <= tsc 
        and   tsc < %d
        group by addr
        """ % (start_clock + xlim[0], start_clock + xlim[1]))
            table_store.clear()
            for filename,line_no,fun,count in co_bt.execute(sql):
                loc = "%s:%s" % (filename,line_no)
                table_store.append((loc, fun, count))
        finally:
            co_bt.close()

    def load_parallelism_data(self):
        """
        draw the stacked graph into self.parallelism_ax.

        the events of the database (see dr_database.get_pp_events) are
        changes of five values over time.  the time axis is cut into
        about target_samples intervals; for each interval the average
        of every value over the interval is drawn as a flat segment.

        when the database has nothing to draw, the axes are
        left empty.
        """
        co = self.dr_sqlite_co
        db = dr_database()
        finish_t = db.get_finish_t(co)
        if finish_t is None:
            # no node in the database
            return
        target_samples = 10000
        interval = (0, finish_t / (target_samples - 1))
        # values between (t_last,t)
        t_last = 0
        y_last = [ 0.0, 0.0, 0.0, 0.0, 0.0 ]
        # integral values between (interval[0],t)
        Y_interval = [ 0.0, 0.0, 0.0, 0.0, 0.0 ]
        # FIX: add "other"
        Y_labels = [ "running", "end", "create", "create cont", "wait cont" ]
        Y = [ [], [], [], [], [] ]
        T = []
        n = 0                   # samples already taken
        for t,r,e,c,cc,wc in db.get_pp_events(co):
            y = (r,e,c,cc,wc)
            dt = t - t_last
            for i,dy in enumerate(y):
                Y_interval[i] += y_last[i] * dt
                y_last[i] += dy
            # t > interval[0] keeps the interval from having no width
            # (the average below divides by it).  this happens only at
            # the very beginning, when finish_t is so small that the
            # first interval is (0, 0) and an event occurs at time 0
            if t >= interval[1] and t > interval[0]:
                width = t - interval[0]
                # two points with the same value make a flat segment
                T.append(interval[0])
                T.append(t)
                for i,integral in enumerate(Y_interval):
                    v = integral / width
                    Y[i].append(v)
                    Y[i].append(v)
                n += 1
                if n >= target_samples - 1:
                    interval = (t, finish_t + 1)
                else:
                    # spread the remaining samples over the remaining time
                    interval = (t, t + max(1, (finish_t - t) / (target_samples - 1 - n)))
                Y_interval = [ 0.0, 0.0, 0.0, 0.0, 0.0 ]
            t_last = t
        if not T:
            # no sample was taken; max() below would fail on
            # empty sequences
            return
        max_y = 1.1 * max((sum(ys) for ys in zip(*Y)))
        ax = self.parallelism_ax
        ax.set_ylim((0, max_y))
        ax.set_xlim((0, max(T) * 1.1))
        # colors are given to stackplot itself rather than through
        # ax.set_color_cycle, which matplotlib has deprecated
        sp = ax.stackplot(T, Y, linewidth=0,
                          colors=["red", "lightgreen", "blue", "magenta", "cyan" ])
        # here is how to add legends 
        # http://stackoverflow.com/questions/20336881/matplotlib-stackplot-legend-error
        proxy_rects = [plt.Rectangle((0,0),1,1,
                                     fc=p.get_facecolor()[0]) for p in sp ]
        ax.legend(proxy_rects, Y_labels)

def parse_cmdline(cmdline_args):
    """
    parse the command line arguments (without the program name).

    -p DIR / --path DIR may be given any number of times; the DIRs
    are collected in options.search_paths after the default ".".

    return the pair (options, args), args being the positional
    arguments left.
    """
    parser = optparse.OptionParser()
    path_option = dict(action="append",
                       dest="search_paths",
                       default=[ "." ],
                       help="add search path")
    parser.add_option("-p", "--path", **path_option)
    return parser.parse_args(cmdline_args)

def main():
    """
    parse the command line, build the viewer for the database given
    as the first positional argument and run the gtk main loop.

    return the drview object once the main loop has finished.
    exit with status 1 after a usage message when no database
    is given.
    """
    opt,args = parse_cmdline(sys.argv[1:])
    if not args:
        # drview opens args[0]; say what is missing instead of
        # failing later with an IndexError
        Es("usage: %s [-p SEARCH_PATH]... DR_SQLITE_FILE\n" % sys.argv[0])
        sys.exit(1)
    drv = drview(opt, args)
    drv.setup()
    gtk.main()
    return drv

# run the viewer.  note that this is executed whenever the file is
# loaded (there is no __main__ guard).  pv is the drview object
# returned by main() after the gtk main loop has finished.
pv = main()

