import unittest
import numpy

from harmony import random


class RandomTest(unittest.TestCase):
    """Statistical tests for ``harmony.random.get_importance``.

    Each test draws ``NSAMPLE`` outcomes and checks that the empirical
    frequency of every outcome index is within ``TOL`` of its expected
    probability.
    """

    # Number of draws per test.
    NSAMPLE = 3000
    # Maximum allowed absolute gap between an empirical frequency and its
    # expected probability. For p = 0.5 the standard error of the empirical
    # frequency is sqrt(0.25 / 3000) ~= 0.0091, so this is roughly a
    # 3.3-sigma bound. Nothing here seeds the RNG, so these tests can,
    # rarely, fail by chance.
    TOL = 0.03

    def _sample_frequencies(self, probs, nsample):
        """Return the empirical frequency of each index drawn from *probs*.

        Calls ``random.get_importance(probs)`` ``nsample`` times and tallies
        element ``[0]`` of each result.

        Returns:
            A float array of length ``len(probs)``; for ``nsample > 0`` its
            entries sum to 1.

        Fails the test if any draw is not an integral index in
        ``[0, len(probs))``.
        """
        n_outcomes = len(probs)
        counts = numpy.zeros(n_outcomes)
        for _ in range(nsample):
            x = random.get_importance(probs)[0]
            # Accept anything equal to a valid index: Python ints, numpy
            # integers, and integral floats such as 1.0.
            if not (0 <= x < n_outcomes and x == int(x)):
                self.fail("Bad output: {!r}".format(x))
            counts[int(x)] += 1
        return counts / nsample

    def _assert_frequencies(self, probs, expected):
        """Sample indices from *probs* and compare them with *expected*.

        Args:
            probs: weights passed to ``random.get_importance``.
            expected: expected probability of each index, in the same order
                and of the same length as *probs*.
        """
        freqs = self._sample_frequencies(probs, self.NSAMPLE)
        for idx, (freq, target) in enumerate(zip(freqs, expected)):
            diff = abs(freq - target)
            self.assertLess(
                diff, self.TOL, "index {}: difference: {}".format(idx, diff))

    def test_importance(self):
        """Weights that already sum to 1 are sampled in proportion."""
        probs = numpy.asarray([0.1, 0.5, 0.4])
        self._assert_frequencies(probs, probs)

    def test_importance_gen(self):
        """Unnormalized weights are sampled in proportion to their share.

        The expected distribution is ``probs / probs.sum()``. Here that is
        0.5 * probs, because the weights sum to 2.0. Computing the factor
        instead of hard-coding it keeps the expectation correct if the
        weights are edited.
        """
        probs = numpy.asarray([0.2, 1.0, 0.8])
        self._assert_frequencies(probs, probs / probs.sum())


if __name__ == "__main__":
    # Run the TestCase classes defined in this module when executed as a
    # script; unittest.main exits the process with the test result status.
    unittest.main(module="__main__")
