def test_size(self):
     """`size` reports how many levels the topological sort holds."""
     levels = {
         0: {'a', 'b'},
         1: {'c', 'd', 'g1'},
         2: {'e', 'f', 'g1', 'g2'},
         3: {'g1', 'g2'},
     }
     topo_test = topos.TOPOS({}, {})
     topo_test._topo_sort = levels
     self.assertEqual(topo_test.size, 4)
 def test_get_ops(self):
     """`get_ops(i)` returns the set of ops stored at level ``i``."""
     level_two = {'e', 'f', 'g1', 'g2'}
     topo_test = topos.TOPOS({}, {})
     topo_test._topo_sort = {
         0: {'a', 'b'},
         1: {'c', 'd', 'g1'},
         2: set(level_two),
         3: {'g1', 'g2'},
     }
     self.assertEqual(topo_test.get_ops(2), level_two)
 def test_reindex(self):
     """`_reindex` drops empty levels and renumbers the rest contiguously.

     Empty levels are written as ``set()`` because every level is a set of
     ops; the literal ``{}`` would be an empty dict, not an empty set.
     """
     topo_test = topos.TOPOS({}, {})
     topo_test._topo_sort = {
         0: {'a', 'b'},
         1: set(),
         2: {'e', 'f', 'g1', 'g2'},
         3: set(),
         4: {'a'}
     }
     topo_test._reindex()
     # Levels 1 and 3 vanish; levels 2 and 4 shift down to 1 and 2.
     expected = {0: {'a', 'b'}, 1: {'e', 'f', 'g1', 'g2'}, 2: {'a'}}
     self.assertDictEqual(expected, topo_test._topo_sort)
 def test_build(self, build_dep, clean_bw, clean_update, reindex,
                build_order, tps):
     """`build` invokes each patched helper and records the backward start.

     The patched toposort yields four levels; gradient ops 6 and 7 first
     appear in level 2, which is the expected `_bw_starting_order`.
     """
     levels = [{1, 2, 3}, {4, 5}, {6, 7, 8, 9}, {10, 11, 12, 13}]
     tps.return_value = iter(levels)
     topo_test = topos.TOPOS({}, {6, 7})
     topo_test.build()
     for patched_helper in (clean_bw, clean_update, reindex, build_order):
         self.assertTrue(patched_helper.called)
     self.assertEqual(topo_test._bw_starting_order, 2)
 def test_get_order(self):
     """`get_order` returns an op's level, or -1 for an unknown op."""
     topo_test = topos.TOPOS({}, {})
     topo_test._orders = dict(a=0, b=0, c=1, d=1, e=2, f=2, g1=3, g2=3)
     cases = (('a', 0), ('f', 2), ('g1', 3), ('asdf', -1))
     for op_name, expected_level in cases:
         self.assertEqual(topo_test.get_order(op_name), expected_level)
    def test_build_dependency_dict(self, intersec, get_gen, get_cons):
        """`_build_dependency_dict` maps each op to the set of ops it needs.

        Mock graph (tensor names in brackets)::

            op0 -[ts0]-> op1 -[ts2]-> op4
            op0 -[ts0]-> op2 -[ts3]-> op4
            op0 -[ts1]-> op3
        """
        # ``name=`` only labels each mock's repr for readable failure output;
        # it does not set a ``.name`` attribute on the mock.
        ops = [
            mock.Mock(name=('op%s' % x),
                      control_inputs=set(),
                      inputs=set(),
                      outputs=set()) for x in range(5)
        ]
        tensors = [
            mock.Mock(name=('ts%s' % x),
                      consuming_ops=set(),
                      generating_ops=set()) for x in range(4)
        ]
        # Build mock graph
        seed_ops = {ops[0]}
        ops[0].outputs = {tensors[0], tensors[1]}
        tensors[0].consuming_ops = [ops[1], ops[2]]
        tensors[0].generating_ops = [ops[0]]
        tensors[1].consuming_ops = [ops[3]]
        tensors[1].generating_ops = [ops[0]]

        ops[1].inputs = {tensors[0]}
        ops[1].outputs = {tensors[2]}
        ops[2].inputs = {tensors[0]}
        ops[2].outputs = {tensors[3]}
        tensors[2].generating_ops = [ops[1]]
        tensors[3].generating_ops = [ops[2]]
        tensors[2].consuming_ops = [ops[4]]
        tensors[3].consuming_ops = [ops[4]]
        ops[3].inputs = {tensors[1]}

        ops[4].inputs = {tensors[2], tensors[3]}

        # No gradient ops in this graph. Use an empty *set*: ``{}`` would be
        # an empty dict, unlike the set type grad_ops has everywhere else.
        grad_ops = set()

        # Route the patched graph helpers to the mock attributes above.
        get_cons.side_effect = lambda x: x.consuming_ops
        get_gen.side_effect = lambda x: x.generating_ops
        intersec.return_value = ops
        topo_test = topos.TOPOS(seed_ops, grad_ops)
        ret = topo_test._build_dependency_dict()
        expected_dict = {
            ops[0]: set(),
            ops[1]: {ops[0]},
            ops[2]: {ops[0]},
            ops[3]: {ops[0]},
            ops[4]: {ops[1], ops[2]}
        }
        self.assertDictEqual(expected_dict, ret)
 def test_clean_bw_ops(self):
     """`_clean_bw_ops` strips gradient ops out of the mixed levels.

     Levels 1 and 2 lose g1/g2, while level 3 (made only of gradient ops)
     is left untouched.
     """
     grad_ops = {'g1', 'g2'}
     topo_test = topos.TOPOS({}, grad_ops)
     topo_test._topo_sort = {
         0: {'a', 'b'},
         1: {'c', 'd', 'g1'},
         2: {'e', 'f', 'g1', 'g2'},
         3: {'g1', 'g2'},
     }
     topo_test._clean_bw_ops()
     self.assertDictEqual(topo_test._topo_sort, {
         0: {'a', 'b'},
         1: {'c', 'd'},
         2: {'e', 'f'},
         3: {'g1', 'g2'},
     })
 def test_clean_update_ops(self, fwd_walk):
     """`_clean_update_ops` removes the ops reported by the forward walk."""
     def fake_forward_walk(x, inclusive=False):
         # Every walk (patched) reports the same ops: b, c and f.
         return {'b', 'c', 'f'}

     grad_ops = {'g1', 'g2'}
     topo_test = topos.TOPOS({}, grad_ops)
     topo_test._topo_sort = {
         0: {'a', 'b'},
         1: {'c', 'd', 'g1'},
         2: {'e', 'f', 'g1', 'g2'},
         3: {'g1', 'g2'},
     }
     fwd_walk.side_effect = fake_forward_walk
     topo_test._clean_update_ops()
     self.assertDictEqual({
         0: {'a'},
         1: {'d', 'g1'},
         2: {'e', 'g1', 'g2'},
         3: {'g1', 'g2'},
     }, topo_test._topo_sort)
 def test_build_order_dict(self):
     """`_build_order_dict` inverts the level map into op -> level."""
     topo_test = topos.TOPOS({}, {'g1', 'g2'})
     topo_test._topo_sort = {
         0: {'a', 'b'},
         1: {'c', 'd'},
         2: {'e', 'f'},
         3: {'g1', 'g2'},
     }
     topo_test._build_order_dict()
     expected_val = dict(a=0, b=0, c=1, d=1, e=2, f=2, g1=3, g2=3)
     self.assertDictEqual(expected_val, topo_test._orders)
    def run(self, graph=None):
        """Edit the graph by adding swapin and swapout ops.

        Swapin and swapout ops are in the host.

        The graph is modified in-place.

        Return:
          a set of added ops.
        """
        if graph:
            self._graph = graph

        if self._n_tensors == 0:
            self._log_info("LMS is disabled and will not modify the model.")
            return  # turn off LMS
        elif self._n_tensors < 0:
            self._n_tensors = 0  # swap all tensors (default)

        if not self._graph:
            raise ValueError('The dataflow graph is required but has not been'
                             ' provided.')

        self._log_info("Editing model for LMS")
        self._print_configuration()
        start_time = time.time()

        self._build_gradient_ops()
        seed_ops = self._get_seed_ops()

        self._log_info(
            "Starting ops: {}".format(
                [(op.name, op.type) for op in seed_ops]), 1)

        reachable_ops = set()
        for seed_op in seed_ops:
            reachable_ops |= set(self._get_forward_walk_ops(seed_op))

        for op in reachable_ops:
            if 'lms/swap' in op.name:
                self._log_info('This model has already been updated with LMS '
                               'swap operations. LMS will not re-process it.')
                return
        # exclusive ops
        self._excl_ops = self._filter_scopes_and_types(reachable_ops,
                                                       self._excl_scopes,
                                                       self._excl_types)
        # inclusive ops
        self._incl_ops = self._filter_scopes_and_types(reachable_ops,
                                                       self._incl_scopes,
                                                       self._incl_types)

        reachable_ops -= self._grad_ops

        # build a topological sort
        self._topo_sort = topos.TOPOS(seed_ops, self._grad_ops)
        self._topo_sort.build()
        for i in range(0, self._topo_sort.size):
            self._log_info("[{}]: {}".format(
                i, [op.name for op in self._topo_sort.get_ops(i)]), 1)

        self._do_action(seed_ops)

        # check the validation of the new model
        new_reachable_ops = set()
        for seed_op in seed_ops:
            new_reachable_ops |= set(ge.get_forward_walk_ops(seed_op))
        new_reachable_ops -= self._grad_ops
        if (new_reachable_ops >= reachable_ops):
            self._log_info("Edited model is valid and logically equivalent to the original one")
            self._log_info("Added {} ops into the model".format(len(new_reachable_ops - reachable_ops)))
        else:
            self._log_info("Edited model is invalid. Running this may produce unexpected result")

        self._log_info("Editing model for LMS, took: {} ms".format(
            (time.time()-start_time)*1000))
        self._log_info(
            "{} tensors will be swapped out(in) to(from) the host".format(
                self._incpu_count))
        return (new_reachable_ops - reachable_ops)
 def test_bw_starting_order(self):
     """The `bw_starting_order` property exposes `_bw_starting_order`."""
     expected_order = 100
     topo_test = topos.TOPOS({}, {})
     topo_test._bw_starting_order = expected_order
     self.assertEqual(topo_test.bw_starting_order, expected_order)