import unittest

from omega.system import Molecule
from omega.coord.bond import get_bond_list
from omega.coord.fragment import get_fragments, get_fragment_bonds


class FragmentTest(unittest.TestCase):
    """Bond detection and fragment partitioning on small clusters of H atoms.

    Every test builds a four-atom hydrogen Molecule from one of the geometries
    below, runs ``get_bond_list`` on it, and checks the bonds and/or the
    fragments derived from those bonds.
    """

    # Four H atoms on the corners of a square of side 0.8 in the yz plane.
    _SQUARE = (
        (0.0, 0.0, 0.0),
        (0.0, 0.0, 0.8),
        (0.0, 0.8, 0.0),
        (0.0, 0.8, 0.8),
    )
    # Two H-H pairs (atoms 0.8 apart along z) whose pairs sit 1.8 apart along y.
    _PAIRS_SPLIT_ALONG_Y = (
        (0.0, 0.0, 0.0),
        (0.0, 0.0, 0.8),
        (0.0, 1.8, 0.0),
        (0.0, 1.8, 0.8),
    )
    # Two H-H pairs (atoms 0.8 apart along y) whose pairs sit 1.8 apart along z.
    # Atom order is kept as (0,0,0), (0,0,1.8), (0,0.8,0), (0,0.8,1.8).
    _PAIRS_SPLIT_ALONG_Z = (
        (0.0, 0.0, 0.0),
        (0.0, 0.0, 1.8),
        (0.0, 0.8, 0.0),
        (0.0, 0.8, 1.8),
    )

    @staticmethod
    def _build_hydrogens(positions):
        """Return a new Molecule with one atom named "H" at each position.

        Atoms are added in the order given, so atom indices match
        ``positions``.
        """
        mol = Molecule()
        for xyz in positions:
            mol.add(xyz, name="H")
        return mol

    def _assert_two_pairs(self, fragments):
        """Assert that ``fragments`` holds exactly two fragments of two atoms."""
        self.assertEqual(len(fragments), 2)
        self.assertEqual(len(fragments[0]), 2)
        self.assertEqual(len(fragments[1]), 2)

    def test_bond_list1(self):
        mymol = self._build_hydrogens(self._SQUARE)

        blist = get_bond_list(mymol)
        # Exactly four bonds are expected: presumably the square's sides (0.8)
        # but not its diagonals (~1.13).
        self.assertEqual(len(blist), 4)
        for b in blist:
            self.assertEqual(b.bond_type(), "H-H")

    def test_fragment_list1(self):
        mymol = self._build_hydrogens(self._SQUARE)

        blist = get_bond_list(mymol)
        fragments = get_fragments(mymol, blist)
        # The bonded square is expected to form a single four-atom fragment.
        self.assertEqual(len(fragments), 1)
        self.assertEqual(len(fragments[0]), 4)

    def test_fragment_list2(self):
        mymol = self._build_hydrogens(self._PAIRS_SPLIT_ALONG_Y)

        blist = get_bond_list(mymol)
        fragments = get_fragments(mymol, blist)
        self._assert_two_pairs(fragments)

    def test_fragment_list3(self):
        mymol = self._build_hydrogens(self._PAIRS_SPLIT_ALONG_Z)

        blist = get_bond_list(mymol)
        fragments = get_fragments(mymol, blist)
        self._assert_two_pairs(fragments)

    def test_fblist(self):
        mymol = self._build_hydrogens(self._PAIRS_SPLIT_ALONG_Y)

        blist = get_bond_list(mymol)
        fragments = get_fragments(mymol, blist)
        self._assert_two_pairs(fragments)
        # Exactly one bond is expected from get_fragment_bonds for this
        # two-fragment system -- presumably the link joining the two pairs;
        # confirm against omega.coord.fragment.
        blist2 = get_fragment_bonds(mymol, fragments)
        self.assertEqual(len(blist2), 1)


if __name__ == "__main__":
    # Allow running this file directly: discover and run the TestCases above.
    unittest.main()
