#!/usr/bin/env python2.7

# obtained from :
# https://github.com/bxlab/bx-python/blob/master/lib/bx/intervals/operations/quicksect.py
# Latest commit f3419d7 on Jun 18 
# modified by Matt Halvorsen (mhalvors1@gmail.com)
# changelog :
# - took out reliance on Interval() class
# - made library entirely self-contained. No other part of bx-python needed
# - changed name of variable 'other' to 'infodict' to more explicitly 
#   communicate that variable is supposed to be a dictionary that stores
#   information regarding the IntervalNode in question.
# - set variable 'infodict' to dict() by default in all new instances of IntervalTree()
# - added import statement for builtin module 'sys' : needed for user input regarding
#   which tests to run
# - added code for brief unit test of IntervalTree and IntervalNode classes

"""
Intersects ... faster.  Supports GenomicInterval datatype and multiple
chromosomes.
"""
from __future__ import print_function

import math
import random
import sys

try:
    from time import process_time
except ImportError:
    # For compatibility with Python < 3.3
    from time import clock as process_time


class IntervalTree(object):
    """Collection of per-chromosome interval trees.

    ``chroms`` maps each chromosome name to the root ``IntervalNode`` of the
    tree holding that chromosome's intervals.
    """

    def __init__(self):
        self.chroms = {}

    def insert(self, chrom, start, end, linenum=0, infodict=None):
        """Add the interval ``start``..``end`` to the tree for ``chrom``.

        :param chrom: key selecting which per-chromosome tree to use
        :param start: interval start
        :param end: interval end
        :param linenum: caller-supplied line number stored on the node
        :param infodict: dictionary of information about the interval; when
            omitted, a new empty dict is created for this interval
        """
        # A fresh dict per call: a ``{}`` default in the signature would be
        # one object shared by every interval inserted without an infodict.
        if infodict is None:
            infodict = {}
        root = self.chroms.get(chrom)
        if root is None:
            self.chroms[chrom] = IntervalNode(start, end, linenum, infodict)
        else:
            # Node insertion may rotate the tree, so keep the returned root.
            self.chroms[chrom] = root.insert(start, end, linenum, infodict)

    def intersect(self, chrom, start, end, report_func):
        """Call ``report_func(node)`` for nodes on ``chrom`` overlapping the query.

        Does nothing when no interval has been inserted for ``chrom``.
        """
        root = self.chroms.get(chrom)
        if root is not None:
            root.intersect(start, end, report_func)

    def traverse(self, func):
        """Run each chromosome tree's ``traverse(func)`` in turn."""
        for root in self.chroms.values():
            root.traverse(func)


class IntervalNode(object):
    """Node of a treap of intervals.

    Nodes are ordered as a binary search tree on ``start`` and kept in heap
    order on a random ``priority`` (a child never outranks its parent), which
    is what balances the tree.  Overlap tests treat intervals as half-open:
    ``[start, end)``.

    Attributes:
        start, end: the interval stored on this node.
        maxend, minend: largest / smallest ``end`` found in the subtree
            rooted at this node (this node included).
        left, right: child nodes, or ``None``.
        linenum: caller-supplied line number.
        infodict: caller-supplied dictionary of information about the interval.
        priority: random rank used to decide rotations.
    """

    def __init__(self, start, end, linenum=0, infodict=None):
        """Create a leaf node for ``start``..``end``.

        ``infodict`` defaults to a new empty dict for every node; a dict
        passed in explicitly is stored as-is (not copied).
        """
        # Turn a uniform sample u into ceil(log2(1 / (1 - u))): a priority
        # above k occurs with probability 2**-k, so each additional level is
        # half as likely as the one before.  The expression relies on the
        # sample staying strictly below 1; u == 1 would divide by zero.
        self.priority = math.ceil((-1.0 / math.log(.5)) * math.log(-1.0 / (random.uniform(0, 1) - 1)))
        self.start = start
        self.end = end
        self.maxend = self.end
        self.minend = self.end
        self.left = None
        self.right = None
        self.linenum = linenum
        # A ``{}`` default in the signature would be a single dict shared by
        # every node created without an explicit infodict.
        self.infodict = {} if infodict is None else infodict

    def _update_bounds(self):
        """Recompute ``maxend``/``minend`` from this node and its children.

        A node without children is reset to its own ``end``; otherwise a node
        that has just become a leaf through a rotation would keep the bounds
        of the subtree it used to have.
        """
        highest = lowest = self.end
        for child in (self.left, self.right):
            if child is not None:
                highest = max(highest, child.maxend)
                lowest = min(lowest, child.minend)
        self.maxend = highest
        self.minend = lowest

    def insert(self, start, end, linenum=0, infodict=None):
        """Insert an interval below this node and return the new subtree root.

        The returned node differs from ``self`` when the new node (or the
        subtree it landed in) has a higher priority and is rotated above
        ``self``; callers must store the return value in place of ``self``.
        """
        root = self
        if start > self.start:
            # Strictly greater starts descend to the right.
            if self.right is not None:
                self.right = self.right.insert(start, end, linenum, infodict)
            else:
                self.right = IntervalNode(start, end, linenum, infodict)
            # Restore heap order on priority.
            if self.priority < self.right.priority:
                root = self.rotateleft()
        else:
            # Equal or smaller starts descend to the left.
            if self.left is not None:
                self.left = self.left.insert(start, end, linenum, infodict)
            else:
                self.left = IntervalNode(start, end, linenum, infodict)
            # Restore heap order on priority.
            if self.priority < self.left.priority:
                root = self.rotateright()
        # A rotation has already refreshed ``self`` (now a child of ``root``);
        # the subtree root itself still needs its bounds brought up to date.
        root._update_bounds()
        return root

    def rotateright(self):
        """Lift the left child above this node and return it.

        Only this node's bounds are refreshed here; the returned node's
        bounds are left for the caller to update.
        """
        root = self.left
        self.left = root.right
        root.right = self
        self._update_bounds()
        return root

    def rotateleft(self):
        """Lift the right child above this node and return it.

        Only this node's bounds are refreshed here; the returned node's
        bounds are left for the caller to update.
        """
        root = self.right
        self.right = root.left
        root.left = self
        self._update_bounds()
        return root

    def intersect(self, start, end, report_func):
        """Call ``report_func(node)`` for every node overlapping the query.

        A node overlaps when ``start < node.end and end > node.start``.
        """
        if start < self.end and end > self.start:
            report_func(self)
        # Nothing on the left ends after ``left.maxend``, so skip that side
        # when the query starts at or beyond it.
        if self.left is not None and start < self.left.maxend:
            self.left.intersect(start, end, report_func)
        # Nodes on the right start no earlier than this node, so they can
        # only overlap a query that ends after ``self.start``.
        if self.right is not None and end > self.start:
            self.right.intersect(start, end, report_func)

    def traverse(self, func):
        """Call ``func(node)`` on every node of the subtree, in-order."""
        if self.left is not None:
            self.left.traverse(func)
        func(self)
        if self.right is not None:
            self.right.traverse(func)


def _print_usage_and_exit():
    """Print the command-line usage line and exit with status 1."""
    print("quicksect.py <test_intersect> " +
          "<test_intervalnode_intervaltree>")
    sys.exit(1)


def _benchmark_intersect():
    """Time tree-based intersection against the linear scan in bad_sect()."""
    tree = None
    intervals = []
    for _ in range(20000):
        start = random.randint(0, 1000000)
        end = start + random.randint(1, 1000)
        if tree is None:
            tree = IntervalNode(start, end)
        else:
            # insert() returns the (possibly new) root after rebalancing.
            tree = tree.insert(start, end)
        intervals.append((start, end))

    # Query starts are drawn from a range ten times wider than the stored
    # intervals' starts.
    started = process_time()
    for _ in range(5000):
        start = random.randint(0, 10000000)
        end = start + random.randint(1, 1000)
        hits = []
        tree.intersect(start, end, lambda node: hits.append(node.linenum))
    print("%f for tree method" % (process_time() - started))

    started = process_time()
    for _ in range(5000):
        start = random.randint(0, 10000000)
        end = start + random.randint(1, 1000)
        bad_sect(intervals, start, end)
    print("%f for linear (bad) method" % (process_time() - started))


def _check_interval_tree():
    """Insert five named intervals, query 22:5-15 and print PASS or FAIL."""
    tree = IntervalTree()
    tree.insert("21", 1, 10, infodict={"GENE": "ABC1"})
    tree.insert("22", 1, 10, infodict={"GENE": "ABC2"})
    tree.insert("22", 5, 10, infodict={"GENE": "ABC3"})
    tree.insert("22", 1, 5, infodict={"GENE": "ABC4"})
    tree.insert("22", 1, 6, infodict={"GENE": "ABC5"})
    result_list = []
    tree.intersect("22", 5, 15,
                   lambda node: result_list.append(node.infodict["GENE"]))
    # Report order depends on the randomly balanced tree shape, so sort
    # before comparing.
    result_list.sort()
    result_list_expected = ["ABC2", "ABC3", "ABC5"]
    print("gene coordinates : ")
    print("ABC1 21:1-10")
    print("ABC2 22:1-10")
    print("ABC3 22:5-10")
    print("ABC4 22:1-5")
    print("ABC5 22:1-6")
    print("search interval : 22:5-15")
    print("Expected gene overlaps : " + ",".join(result_list_expected))
    print("Observed gene overlaps : " + ",".join(result_list))
    if result_list == result_list_expected:
        print("test_intervalnode_intervaltree results : PASS")
    else:
        print("test_intervalnode_intervaltree results : FAIL")


def main():
    """Run the test named by the first command-line argument.

    Accepted names are ``test_intersect`` (timing comparison) and
    ``test_intervalnode_intervaltree`` (small correctness check).  A missing
    or unrecognised name prints the usage line and exits with status 1.
    """
    tests = {
        "test_intersect": _benchmark_intersect,
        "test_intervalnode_intervaltree": _check_interval_tree,
    }
    # Validate explicitly rather than with assert inside a bare except:
    # asserts are stripped under -O and a bare except also traps
    # KeyboardInterrupt.
    if len(sys.argv) < 2 or sys.argv[1] not in tests:
        _print_usage_and_exit()
    tests[sys.argv[1]]()

def test_func(node):
    """Print ``node`` as ``[start, end), maxend``.

    Takes a single node argument, so it can be handed to ``traverse()``.
    """
    fields = (node.start, node.end, node.maxend)
    print("[%d, %d), %d" % fields)


def bad_sect(lst, int_start, int_end):
    """Return the ``(start, end)`` pairs in ``lst`` overlapping the query.

    Brute-force reference: every pair is tested, and a pair overlaps when
    ``int_start < end and int_end > start``.  Input order is preserved.
    """
    return [
        (lo, hi)
        for lo, hi in lst
        if int_start < hi and int_end > lo
    ]


# Run the command-line test driver only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
