Esempio n. 1
0
    def test_fuse_swapin_ops(self, identity, ctrl_dep, connect, sgv):
        """Exercise ``_fuse_swapin_ops`` with ordered and unordered ops.

        The four mock arguments are injected by patch decorators that are
        not visible in this view.  The assertions below expect that:

        * ops whose mocked order is negative are returned to the caller;
        * every other op is connected to one shared swap-in op, using the
          input index reported by the ``sgv`` mock (123);
        * the swap-in op is recorded in ``_excl_ops``;
        * the control dependency is requested once, against the op with
          the smallest non-negative order.
        """
        # Topological-sort stand-in: each mock op reports its own "order".
        topo = mock.Mock()
        topo.get_order = lambda x: x.order
        topo.size = 200
        lms_test = lms.LMS({'s1'}, graph=mock.Mock(), lb=5, ub=500)
        lms_test._topo_sort = topo

        swap_in = mock.Mock(name='swapin_ts', op=mock.Mock(name='swapin_op'))
        identity.return_value = swap_in
        sgv.return_value.input_index.return_value = 123

        src_op = mock.Mock(name='src_op')
        swapout_op = mock.Mock(name='swapout_op')
        ts0 = mock.Mock(name='ts0')
        op_specs = (('op1', 15), ('op2', -1), ('op3', 10),
                    ('op4-earliest', 5), ('op5', 8), ('op6', -2))
        bw_fr_ops = [mock.Mock(name=label, order=position)
                     for label, position in op_specs]
        earliest_op = bw_fr_ops[3]
        unordered = {bw_fr_ops[1], bw_fr_ops[5]}
        fused = [bw_fr_ops[i] for i in (0, 2, 3, 4)]

        ret = lms_test._fuse_swapin_ops(src_op, swapout_op, set(bw_fr_ops),
                                        ts0)

        self.assertEqual(ret, unordered)
        expected_connects = [mock.call(swapout_op, swap_in.op)]
        expected_connects.extend(
            mock.call(swap_in.op, dest, remap_inputs=True, idx=123)
            for dest in fused)
        connect.assert_has_calls(expected_connects, any_order=True)
        self.assertEqual(lms_test._excl_ops, {swap_in.op})
        ctrl_dep.assert_called_once_with(src_op, earliest_op, swap_in.op)
Esempio n. 2
0
    def test_filter_scopes_and_types(self, get_name_scope_ops):
        """Check scope/type filtering and its error reporting.

        ``get_name_scope_ops`` is a mock injected by a patch decorator that
        is not visible in this view.  Covered scenarios: a successful
        filter on two scopes and two types, a scope that yields no ops, a
        type that matches no op, and atomic types that are allowed to be
        absent.
        """
        op1, op2, op3, op4, op5 = [mock.Mock(type=kind)
                                   for kind in ('a', 'b', 'a', 'd', 'c')]
        within_ops = {op1, op2, op3, op4, op5}
        lms_test = lms.LMS({'s1'})

        # Every scope lookup yields op4; types 'a' and 'c' match op1/op3/op5.
        get_name_scope_ops.return_value = [op4]
        ret = lms_test._filter_scopes_and_types(within_ops, {'s1', 's2'},
                                                {'a', 'c'})
        scope_lookups = [mock.call(within_ops, scope)
                         for scope in ('s1', 's2')]
        get_name_scope_ops.assert_has_calls(scope_lookups, any_order=True)
        assertCountEqual(self, ret, {op1, op3, op4, op5})

        # Test no ops found for scope (test more than 1 scope): the second
        # scope lookup comes back empty.
        get_name_scope_ops.reset_mock()
        get_name_scope_ops.return_value = None
        get_name_scope_ops.side_effect = [[op4], []]
        with self.assertRaisesRegex(ValueError,
                                    'No operations were found with scope'):
            lms_test._filter_scopes_and_types(within_ops, {'s1', 's2'}, {})

        # Test no ops found for type (test more than 1 type): no op has
        # type 'z'.
        with self.assertRaisesRegex(ValueError,
                                    'No operations were found with types: '):
            lms_test._filter_scopes_and_types(within_ops, {},
                                              {'a', 'c', 'z'})

        # Test atomic types. The method should not throw errors
        # if an atomic type operation is not found.
        requested_types = {'a', 'c'} | lms.lms.ATOMIC_TYPES
        ret = lms_test._filter_scopes_and_types(within_ops, {},
                                                requested_types)
        assertCountEqual(self, ret, {op1, op3, op5})
Esempio n. 3
0
    def test_get_forward_walk_ops(self, fwd_walk):
        """Check ``_get_forward_walk_ops`` for known and unknown ops.

        ``fwd_walk`` is a mock injected by a patch decorator that is not
        visible in this view.  For an op present in ``_ops_dict`` the stored
        list is expected back; for any other op the ``fwd_walk`` result is
        expected.  With ``inclusive=False`` the op itself is expected to be
        left out of the result.
        """
        tag = mock.sentinel
        known_walks = {tag.op1: [tag.op1, tag.op10],
                       tag.op2: [tag.op2, tag.op20],
                       tag.op3: [tag.op3, tag.op30]}
        lms_test = lms.LMS({'s1'})
        lms_test._ops_dict = known_walks

        # Test op in ops_dict
        self.assertEqual(lms_test._get_forward_walk_ops(tag.op1),
                         [tag.op1, tag.op10])
        # Test inclusive=False
        self.assertEqual(
            lms_test._get_forward_walk_ops(tag.op1, inclusive=False),
            [tag.op10])

        # Test op not in ops_dict
        fwd_walk.return_value = [tag.op, tag.ret]
        self.assertEqual(lms_test._get_forward_walk_ops(tag.op),
                         [tag.op, tag.ret])
        # Test inclusive=False
        lms_test._ops_dict = known_walks
        fwd_walk.return_value = [tag.op, tag.ret]
        self.assertEqual(
            lms_test._get_forward_walk_ops(tag.op, inclusive=False),
            [tag.ret])
Esempio n. 4
0
    def test_connect_ops(self, sgv, connect):
        """Check ``_connect_ops`` with and without input/output remapping.

        ``sgv`` and ``connect`` are mocks injected by patch decorators that
        are not visible in this view.
        """
        graph = mock.Mock()
        lms_test = lms.LMS({'s1'}, graph=graph)
        src_op, dest_op, dcf = mock.Mock(), mock.Mock(), mock.Mock()

        # Without remapping, the two views built by sgv go straight to
        # connect.
        sgv.side_effect = ['a', 'b']
        lms_test._connect_ops(src_op, dest_op, remap_inputs=False,
                              remap_outputs=False, disconnect_first=dcf)
        sgv.assert_has_calls([mock.call(op, graph=graph)
                              for op in (src_op, dest_op)])
        connect.assert_called_once_with('a', 'b', dcf)

        # Test remaps
        connect.reset_mock()
        sgv.reset_mock()
        src_sgv, dest_sgv = mock.Mock(), mock.Mock()
        sgv.side_effect = [src_sgv, dest_sgv]
        lms_test._connect_ops(src_op, dest_op, remap_inputs=True,
                              remap_outputs=True, idx=2, disconnect_first=dcf)
        src_sgv.remap_outputs.assert_called_once_with([2])
        dest_sgv.remap_inputs.assert_called_once_with([2])
        remapped_src = src_sgv.remap_outputs.return_value
        remapped_dest = dest_sgv.remap_inputs.return_value
        connect.assert_called_once_with(remapped_src, remapped_dest, dcf)
Esempio n. 5
0
    def test_do_action(self, get_cons, swap):
        """Check how often ``_do_action`` asks for swap nodes.

        ``get_cons`` and ``swap`` are mocks injected by patch decorators
        that are not visible in this view.  The first scenario swaps
        without a tensor limit; the second caps the run at ``n_tensors=7``.
        """
        grad_op1 = mock.Mock()
        src_ops = [mock.Mock(), mock.Mock()]
        for seed in src_ops:
            seed.outputs = ['o', 'o']

        # First-level consumers, one list per seed output.  dup_op shows up
        # under two different outputs.
        dup_op = mock.Mock()
        first_level = [[mock.Mock() for _ in range(3)] + [dup_op],
                       [mock.Mock() for _ in range(2)] + [dup_op],
                       [mock.Mock() for _ in range(2)],
                       [mock.Mock() for _ in range(3)]]
        for consumers in first_level:
            for consumer in consumers:
                consumer.outputs = ['a']
        # The eleven remaining lookups all report the gradient op as the
        # only consumer.
        get_cons_side_effect = first_level + [[grad_op1] for _ in range(11)]

        lms_test = lms.LMS({'s1'})
        lms_test._grad_ops = {grad_op1}
        get_cons.side_effect = get_cons_side_effect
        lms_test._do_action(src_ops)
        # There should be 13 calls to _insert_swap_nodes.  The original 2
        # nodes, the 10 from the side_effect ranges above and ONE call on the
        # dup_op
        self.assertEqual(swap.call_count, 13)
        self.assertEqual(get_cons.call_count, 15)

        # Test when NOT swapping all possible tensors
        swap.reset_mock()
        get_cons.reset_mock()
        lms_test = lms.LMS(optimizer_scopes={'s1'}, n_tensors=7)

        def fake_swap(op):
            # Stand in for the real swap: count every output as swapped.
            lms_test._incpu_count += len(op.outputs)

        swap.side_effect = fake_swap
        lms_test._grad_ops = {grad_op1}
        get_cons.side_effect = get_cons_side_effect
        lms_test._do_action(src_ops)

        # There should only be 5 calls to insert swap nodes.  The first
        # seed ops will swap 2 tensors each, and 3 other ops will each swap 1
        self.assertEqual(swap.call_count, 5)
Esempio n. 6
0
    def test_run(self, grad, seed, fwd_walk, tf_fwd_walk, filter, build, action):
        """Drive ``LMS.run`` through its main path and its early exits.

        All arguments after ``self`` are mocks injected by patch decorators
        that are not visible in this view (``filter`` shadows the builtin
        only inside this method).  Scenarios covered: the mainline run, a
        graph passed to ``run`` overriding the constructor graph, a missing
        graph, ``n_tensors=0`` and ``n_tensors=-1``.
        """
        # Test mainline through
        seed_ops = [mock.Mock() for _ in range(5)]
        grad_ops = [mock.MagicMock() for _ in range(6)]
        fwd_walk.return_value = [mock.MagicMock() for _ in range(3)] + grad_ops
        seed.return_value = seed_ops
        lms_test = lms.LMS({'s1'}, graph=mock.Mock())

        def fake_build_gradient_ops():
            # Stand in for the patched gradient builder by populating the
            # attribute directly.
            lms_test._grad_ops = set(grad_ops)

        grad.side_effect = fake_build_gradient_ops
        lms_test.topos = mock.Mock()
        lms_test.run()
        self.assertTrue(grad.called)
        self.assertTrue(seed.called)
        # assertTrue(a, b) would treat b as the failure message and only
        # check that a is truthy; compare the values explicitly instead.
        self.assertEqual(fwd_walk.call_count, len(seed_ops))
        reachable = set(fwd_walk.return_value) - set(grad_ops)
        filter.assert_has_calls([mock.call(reachable, mock.ANY, mock.ANY),
                                 mock.call(reachable, mock.ANY, mock.ANY)])
        self.assertTrue(build.called)
        action.assert_called_once_with(seed_ops)

        # Test passing a graph in run and verify it overwrites a graph passed
        # on the constructor
        new_graph = mock.Mock()
        lms_test.run(graph=new_graph)
        self.assertEqual(lms_test._graph, new_graph)

        # Test no graph
        grad.reset_mock()
        lms_test = lms.LMS({'s1'})
        self.assertRaises(ValueError, lms_test.run)

        # Test n_tensors = 0
        lms_test = lms.LMS({'s1'}, n_tensors=0)
        lms_test.run(new_graph)
        self.assertFalse(grad.called)

        # Test n_tensors = -1
        action.reset_mock()
        lms_test = lms.LMS({'s1'}, n_tensors=-1)
        lms_test.run(new_graph)
        self.assertTrue(action.called)
        # n_tensors gets set to 0 when passed as -1
        self.assertEqual(lms_test._n_tensors, 0)
Esempio n. 7
0
 def test_get_branch_ops(self):
     """Check which ops ``_get_branch_ops`` selects for a threshold of 3.

     Six ops with mocked orders 1 through 6 are passed in; only 'f5' and
     'f6' are expected back.
     """
     op_names = ['f1', 'f2', 'f3', 'f4', 'f5', 'f6']
     common_orders = {label: position
                      for position, label in enumerate(op_names, start=1)}
     # Topological-sort stand-in: unknown ops report an order of -1.
     topo = mock.Mock()
     topo.size = 6
     topo.get_order = lambda x: common_orders.get(x, -1)
     lms_test = lms.LMS({'s1'})
     lms_test._topo_sort = topo

     ret = lms_test._get_branch_ops(set(op_names), 3)

     self.assertEqual(ret, {'f5', 'f6'})
Esempio n. 8
0
    def test_build_gradient_ops(self, filter_ops, make_list):
        """Check that gradient ops are gathered across optimizer scopes.

        ``filter_ops`` and ``make_list`` are mocks injected by patch
        decorators that are not visible in this view.  With two scopes,
        both mocks are expected to be used twice and the filtered ops
        merged into ``_grad_ops``; a scope that yields nothing is expected
        to raise ``ValueError``.
        """
        graph = mock.Mock()
        lms_test = lms.LMS({'s1', 's2'}, graph=graph)
        filter_ops.side_effect = [['a', 'b', 'c'], ['d']]

        lms_test._build_gradient_ops()

        self.assertEqual(lms_test._grad_ops, {'a', 'b', 'c', 'd'})
        for patched in (make_list, filter_ops):
            self.assertEqual(2, patched.call_count)

        # Test ops found for 's1' but not 's2'
        filter_ops.reset_mock()
        make_list.reset_mock()

        def filter_fake(x, y):
            # Only a pattern mentioning 's1' produces any ops.
            return ['a', 'b', 'c'] if 's1' in y else set()

        filter_ops.side_effect = filter_fake
        lms_test = lms.LMS({'s1', 's2'}, graph=graph)
        with self.assertRaisesRegex(ValueError, 'optimizer scope s2'):
            lms_test._build_gradient_ops()
Esempio n. 9
0
    def test_find_new_src_op(self, get_cons):
        """Check which frontier op ``_find_new_src_op`` picks.

        ``get_cons`` is a mock injected by a patch decorator that is not
        visible in this view.  The original op feeds two consumers whose
        mocked order is -1.  Only the one that leads on to a consumer with
        a real order (45) is expected in the result.
        """
        # Order and consumer lookups simply read attributes off the mocks.
        topo = mock.Mock(bw_starting_order=5)
        topo.get_order.side_effect = lambda x: x.order
        get_cons.side_effect = lambda x: x.consumers
        lms_test = lms.LMS({'s1'})
        lms_test._topo_sort = topo

        ordered_consumer = mock.Mock(name='op_w_order', order=45, outputs=[])
        dead_end = mock.Mock(name='fwd1', order=-1, outputs=[])
        candidate_out = mock.Mock(name='fwd2_out',
                                  consumers=[ordered_consumer])
        candidate = mock.Mock(name='fwd2', order=-1, outputs=[candidate_out])
        start_out = mock.Mock(name='orig_out',
                              consumers=[dead_end, candidate])
        start_op = mock.Mock(name='original_op', outputs=[start_out])

        self.assertEqual(lms_test._find_new_src_op(start_op), {candidate})
Esempio n. 10
0
 def test_do_direct_order(self, get_fwd_walk):
     """Check the op and order returned by ``_do_direct_order``.

     ``get_fwd_walk`` is a mock injected by a patch decorator that is not
     visible in this view.  The topological sort's ``get_ops`` yields three
     successive batches; only ``expected_ret`` (in the third batch) has a
     forward walk that contains ``src_op``, and the method is expected to
     return it together with the order 44.
     """
     lms_test = lms.LMS({'s1'})
     lms_test._topo_sort = mock.Mock()
     # Mock get_order to return the "order" value from the mock op
     lms_test._topo_sort.get_order.side_effect = lambda x: x.order
     expected_ret = mock.Mock(name='expected_op')
     expected_ret.name = "expected_op_name"
     get_ops = mock.Mock(name='get_ops')
     get_ops.side_effect = [[1, 2, 3], [4, 5], [6, expected_ret]]
     lms_test._topo_sort.get_ops = get_ops
     fw_op = mock.Mock(name='fwd_op', order=5)
     src_op = mock.Mock(name='src_op', order=50)
     # Only the walk starting at expected_ret reaches src_op.
     get_fwd_walk.side_effect = lambda x: ({'a', src_op}
                                           if x is expected_ret
                                           else {'b', 'c'})
     ret = lms_test._do_direct_order(fw_op, src_op, 3, 100)
     self.assertEqual(ret, (expected_ret, 44))
Esempio n. 11
0
 def test_add_swapout(self, identity, sgv, connect_ops):
     """Check that ``_add_swapout`` creates and wires a swap-out op.

     ``identity``, ``sgv`` and ``connect_ops`` are mocks injected by patch
     decorators that are not visible in this view.  The swap-out op is
     expected to be named 'lms/swapout', connected to the source op at the
     output index reported by the ``sgv`` mock (123), returned to the
     caller and recorded in ``_excl_ops``.
     """
     graph = mock.Mock()
     src_op, ts0, swap_out = mock.Mock(), mock.Mock(), mock.Mock()
     sgv.return_value = mock.Mock(**{'output_index.return_value': 123})
     identity.return_value = swap_out
     lms_modifier = lms.LMS(graph=graph,
                            optimizer_scopes={'s1'})

     ret = lms_modifier._add_swapout(src_op, ts0)

     identity.assert_called_once_with(ts0, name='lms/swapout')
     sgv.assert_called_once_with(src_op, graph=graph)
     connect_ops.assert_called_once_with(src_op, swap_out.op,
                                         remap_outputs=True, idx=123)
     self.assertEqual(ret, swap_out.op)
     self.assertEqual(lms_modifier._excl_ops, {swap_out.op})
Esempio n. 12
0
    def test_do_chain_rule(self, direct_order, cons_ops):
        """Check ``_do_chain_rule`` delegation and a one-layer walk.

        ``direct_order`` and ``cons_ops`` are mocks injected by patch
        decorators that are not visible in this view.  The first scenario
        expects the work to be handed to the ``direct_order`` mock; the
        second builds a small mock graph and expects the gradient op with
        its order (7) to be returned.
        """
        lms_test = lms.LMS({'s1'})

        # Test calling _do_direct_order when the bw_op is close to the
        # boundary
        near_boundary = mock.Mock(bw_starting_order=10)
        near_boundary.get_order.side_effect = [1, 5]
        lms_test._topo_sort = near_boundary
        fwd_op = mock.Mock(name='fwdop')
        bw_op = mock.Mock(name='bwop')
        lms_test._do_chain_rule(fwd_op, bw_op, 4, 10)
        direct_order.assert_called_once_with(fwd_op, bw_op, 4, 10)

        # Test going through one layer
        topo = mock.Mock()
        # Mock get_order to return the "order" value from the mock op
        topo.get_order = lambda x: x.order
        lms_test._topo_sort = topo
        # Mock get_consuming_ops to return the mock outputs' consumers
        cons_ops.side_effect = lambda x: x.consumers

        grad_op = mock.Mock(name='grad', order=7)
        grad_op.name = 'gradop_name'
        lms_test._grad_ops = {grad_op}

        # Second layer: only l2a has an output, and that output is consumed
        # by the gradient op.
        second_to_last = mock.Mock(name='secondToLast', outputs=['a'],
                                   consumers=[grad_op])
        l2a = mock.Mock(name='l2a', outputs=[second_to_last])
        l2b = mock.Mock(name='l2b', outputs=[])
        layer2 = [l2a, l2b]
        l2a.name = 'l2a_name'
        l2b.name = 'l2b_name'
        layer2b = mock.Mock(outputs=[])

        # First layer: l1a feeds the second layer, l1b feeds an op that has
        # no outputs.
        l1a = mock.Mock(name='l1a', outputs=[mock.Mock(consumers=layer2)])
        l1b = mock.Mock(name='l1b',
                        outputs=[mock.Mock(consumers=[layer2b])])
        layer1 = [l1a, l1b]

        fwd_out1 = mock.Mock(name='fwdop_out1', consumers=layer1)
        fwd_out2 = mock.Mock(name='fwdop_out2', consumers=[], outputs=[])
        fwd_op = mock.Mock(name='fwdop', order=5,
                           outputs=[fwd_out1, fwd_out2])
        bw_op = mock.Mock(name='bwop', order=50)
        topo.bw_starting_order = 1

        ret = lms_test._do_chain_rule(fwd_op, bw_op, 1, 10)
        self.assertEqual(ret, (grad_op, 7))
Esempio n. 13
0
 def test_add_swapin(self, identity, sgv, connect_ops):
     """Check that ``_add_swapin`` creates and wires a swap-in op.

     ``identity``, ``sgv`` and ``connect_ops`` are mocks injected by patch
     decorators that are not visible in this view.  The swap-in op is
     expected to be named 'lms/swapin', fed by the swap-out op, connected
     to the destination op at the input index reported by the ``sgv`` mock
     (123), returned to the caller and recorded in ``_excl_ops``.
     """
     graph = mock.Mock()
     dest_op, swapout_op, ts0, swapin = (mock.Mock(), mock.Mock(),
                                         mock.Mock(), mock.Mock())
     sgv.return_value = mock.Mock(**{'input_index.return_value': 123})
     identity.return_value = swapin
     lms_modifier = lms.LMS({'s1'}, graph=graph)
     lms_modifier._topo_sort = mock.Mock()

     ret = lms_modifier._add_swapin(swapout_op, dest_op, ts0)

     identity.assert_called_once_with(ts0, name='lms/swapin')
     sgv.assert_called_once_with(dest_op, graph=graph)
     connect_ops.assert_has_calls(
         [mock.call(swapout_op, swapin.op),
          mock.call(swapin.op, dest_op, remap_inputs=True, idx=123)])
     self.assertEqual(ret, swapin.op)
     self.assertEqual(lms_modifier._excl_ops, {swapin.op})
Esempio n. 14
0
    def test_add_control_dependency(self, do_chain, do_direct, add_ctrl_input):
        """Check strategy selection in ``_add_control_dependency``.

        ``do_chain``, ``do_direct`` and ``add_ctrl_input`` are mocks
        injected by patch decorators that are not visible in this view.
        Scenarios covered: chain rule with the lower bound reset to 1,
        chain rule with the configured bounds, the direct-order strategy,
        and the fallback to direct order when ``fw_op`` is a gradient op.
        In every scenario the op chosen by the strategy is expected to be
        added as a control input of the swap-in op.
        """
        # Test when lb is reset and chain rule
        lms_test = lms.LMS({'s1'}, ctrld_strategy="chain_rule", lb=10, ub=20)

        lms_test._topo_sort = mock.Mock()
        lms_test._topo_sort.get_order.side_effect = [24, 15]
        fw_op = mock.sentinel.fw_op
        bw_op = mock.sentinel.orig_bw_op
        swapin_op = mock.Mock()
        do_chain.return_value = [mock.sentinel.ctl_op, 123]
        lms_test._add_control_dependency(fw_op, bw_op, swapin_op)
        do_chain.assert_called_once_with(fw_op, bw_op, 1, 20)
        add_ctrl_input.assert_called_once_with(swapin_op,
                                               mock.sentinel.ctl_op)

        # Test chain rule
        lms_test._topo_sort.get_order.side_effect = [26, 15]
        do_chain.reset_mock()
        add_ctrl_input.reset_mock()
        lms_test._add_control_dependency(fw_op, bw_op, swapin_op)
        do_chain.assert_called_once_with(fw_op, bw_op,
                                         10, 20)
        add_ctrl_input.assert_called_once_with(swapin_op,
                                               mock.sentinel.ctl_op)

        # Test with direct_order
        do_chain.reset_mock()
        add_ctrl_input.reset_mock()
        lms_test = lms.LMS({'s1'}, ctrld_strategy="direct_order", lb=10, ub=20)

        lms_test._topo_sort = mock.Mock()
        lms_test._topo_sort.get_order.side_effect = [26, 15]
        do_chain.return_value = None
        do_direct.return_value = [mock.sentinel.ctl_op, 567]
        lms_test._add_control_dependency(fw_op, bw_op, swapin_op)
        add_ctrl_input.assert_called_once_with(swapin_op,
                                               mock.sentinel.ctl_op)
        do_direct.assert_called_once_with(fw_op, mock.sentinel.orig_bw_op,
                                          10, 20)

        # Test direct order when fw_op is a gradient op
        do_direct.reset_mock()
        do_chain.reset_mock()
        add_ctrl_input.reset_mock()
        lms_test = lms.LMS({'s1'}, ctrld_strategy="chain_rule", lb=10, ub=20)
        lms_test._grad_ops = {fw_op}
        lms_test._topo_sort = mock.Mock()
        lms_test._topo_sort.get_order.side_effect = [26, 15]
        do_chain.return_value = None
        do_direct.return_value = [mock.sentinel.ctl_op, 567]
        lms_test._add_control_dependency(fw_op, bw_op, swapin_op)
        add_ctrl_input.assert_called_once_with(swapin_op,
                                               mock.sentinel.ctl_op)
        # Note we expect _do_direct_order to be called even though the
        # control dependency strategy was set to "chain_rule" because fw_op is
        # a gradient op.
        do_direct.assert_called_once_with(fw_op, mock.sentinel.orig_bw_op,
                                          10, 20)
        self.assertEqual(do_chain.call_count, 0)
Esempio n. 15
0
    def test_get_seed_ops(self, get_consuming, get_fwd_walk, make_list,
                          filter_ops):
        """Check every way ``_get_seed_ops`` can pick its seed ops.

        The four mock arguments are injected by patch decorators that are
        not visible in this view.  Scenarios covered: a starting scope,
        starting op names, both together, a scope or name that matches
        nothing (``ValueError``), and the fallback that derives seeds by
        walking the graph.
        """
        graph = mock.Mock()
        filter_ops.return_value = [mock.Mock()]

        # Test with starting scope
        lms_test = lms.LMS({'s1'}, graph=graph, starting_scope='sc')
        ret = lms_test._get_seed_ops()
        filter_ops.assert_called_once_with(make_list.return_value, "^sc")
        make_list.assert_called_once_with(graph)
        self.assertEqual(ret, filter_ops.return_value)

        # Test with starting op names
        filter_ops.reset_mock()
        make_list.reset_mock()
        name_ops = [mock.Mock(name='op1'), mock.Mock(name='op2')]
        filter_ops.side_effect = [[name_ops[0]], [name_ops[1]]]
        lms_test = lms.LMS({'s1'}, graph=graph, starting_op_names={'a', 'b'})
        ret = lms_test._get_seed_ops()
        filter_ops.assert_has_calls([mock.call(make_list.return_value, "^a$"),
                                     mock.call(make_list.return_value, "^b$")],
                                    any_order=True)
        make_list.assert_called_once_with(graph)
        assertCountEqual(self, ret, name_ops)

        # Set up graph operations and their outputs; they are used by the
        # graph traversal scenario at the end of this test.
        graph_ops = [mock.Mock() for _ in range(6)]
        for op in graph_ops:
            op.outputs = ['mock_output']
        graph.get_operations.return_value = graph_ops

        # Test when both starting scope and starting op names are passed
        filter_ops.reset_mock()
        make_list.reset_mock()
        name_ops = [mock.Mock(name='op1'), mock.Mock(name='op2'),
                    mock.Mock(name='op3')]
        filter_ops.side_effect = [[name_ops[0]], [name_ops[1]], [name_ops[2]]]
        lms_test = lms.LMS({'s1'}, graph=graph, starting_op_names={'a', 'b'},
                           starting_scope='sc')
        ret = lms_test._get_seed_ops()
        filter_ops.assert_has_calls([mock.call(make_list.return_value, "^a$"),
                                     mock.call(make_list.return_value, "^b$"),
                                     mock.call(make_list.return_value, "^sc")],
                                    any_order=True)
        make_list.assert_called_once_with(graph)
        assertCountEqual(self, ret, name_ops)

        # Test no ops for scope
        filter_ops.side_effect = [[]]
        lms_test = lms.LMS({'s1'}, graph=graph, starting_scope='sc')
        self.assertRaisesRegex(ValueError, 'starting scope sc',
                               lms_test._get_seed_ops)

        # Test no ops for names
        # make filter_ops return an op for op 'b' but not op 'a'
        filter_ops.side_effect = lambda x, y: {"^b$": ['abc']}.get(y, [])
        lms_test = lms.LMS({'s1'}, graph=graph, starting_op_names={'a', 'b'})
        self.assertRaisesRegex(ValueError, 'No starting operation was found '
                               'with name a', lms_test._get_seed_ops)

        # Test building seed ops with graph traversal.
        # Consuming ops for the graph operations, make most of them have a
        # gradient op consuming.
        grad_ops = {mock.sentinel.gradient1, mock.sentinel.gradient2}
        non_grad_ops = graph_ops
        get_consuming.side_effect = [[mock.sentinel.gradient1],
                                     [mock.sentinel.gradient1],
                                     [mock.sentinel.gradient2],
                                     ['nothing'],
                                     [mock.sentinel.gradient2],
                                     ['nothing2']]
        get_fwd_walk.side_effect = [[],
                                    [graph_ops[0], graph_ops[2], graph_ops[4]],
                                    [graph_ops[0], graph_ops[1], graph_ops[4]],
                                    [graph_ops[1], graph_ops[2]]]
        lms_test = lms.LMS({'s1'}, graph=graph)
        lms_test._grad_ops = grad_ops
        ret = lms_test._get_seed_ops()
        self.assertEqual(get_consuming.call_count, 6)
        # A forward walk is expected only for the four ops whose consumer
        # is a gradient op (indices 0, 1, 2 and 4).
        walk_calls = [mock.call(graph_ops[idx], within_ops=non_grad_ops,
                                inclusive=False)
                      for idx in (0, 1, 2, 4)]
        get_fwd_walk.assert_has_calls(walk_calls, any_order=True)
        # There will be two seed ops. However, since sets are used in
        # _get_seed_ops and we are mocking get_forward_walk_ops, we can't
        # check for specific mock operations.
        self.assertEqual(len(ret), 2)
Esempio n. 16
0
    def test_insert_swap_nodes(self, consuming_ops, fwd_walk_ops, swapout,
                               swapin, ctrldep, fuse_swapins, find_new_src):
        """Walk ``_insert_swap_nodes`` through its main scenarios.

        All arguments after ``self`` are mocks injected by patch decorators
        that are not visible in this view.  Scenarios, in order: an excluded
        op, an op missing from the include list, an included op, creation
        of swap-out/swap-in nodes, the hand-off to the ``find_new_src``
        mock, fused swap-ins, and stopping once the tensor limit is hit.

        The scenarios share the same mocks and only some of them reset the
        recorded calls, so statement order matters throughout.
        """
        graph = mock.Mock()
        lms_test = lms.LMS({'s1'}, graph=graph, fuse_swapins=False)
        # Test op is excluded.
        # _insert_swap_nodes should return before accessing methods on
        # src_op which would blow up the test
        lms_test._excl_ops = {'a'}
        lms_test._insert_swap_nodes('a')
        # Note: this resets the attribute to an empty dict literal, not an
        # empty set.
        lms_test._excl_ops = {}

        # Test included ops but op is not included.
        # _insert_swap_nodes should return before accessing methods on
        # src_op which would blow up the test
        inc_op = mock.Mock()
        lms_test._incl_ops = {inc_op}
        lms_test._insert_swap_nodes('a')

        # Test included ops and op is included.
        # _insert_swap_nodes should start iterating on the src outputs
        inc_op = mock.Mock()
        inc_op.outputs = ['a']
        lms_test._incl_ops = {inc_op}
        lms_test._grad_ops = {'b', 'c'}
        consuming_ops.return_value = {'b'}
        fwd_walk_ops.return_value = {}
        lms_test._insert_swap_nodes(inc_op)
        consuming_ops.assert_called_once_with('a')
        fwd_walk_ops.assert_called_once_with('b', inclusive=False)
        lms_test._incl_ops = {}
        consuming_ops.reset_mock()
        fwd_walk_ops.reset_mock()

        # Test creating swap out and swap in nodes
        # Two outputs: 'a' is consumed by 'b', 'f2', 'g2' and 'z' by 'c',
        # 'g1', 'f3'.  Only 'b', 'g2', 'c' and 'g1' are in _grad_ops.
        src_op = mock.Mock()
        src_op.outputs = ['a', 'z']
        lms_test._grad_ops = {'b', 'c', 'g1', 'g2'}
        consuming_ops.side_effect = [{'b', 'f2', 'g2'}, {'c', 'g1', 'f3'}]
        swapout.side_effect = ['swapout_op1', 'swapout_op2']

        # Swap-in ops are keyed by destination op because the consumers
        # come from sets, so the call order is not fixed.
        swap_in_node_map = {'b': 'swapin_op1',
                            'g2': 'swapin_op2',
                            'c': 'swapin_op3',
                            'g1': 'swapin_op4'}

        def swap_in_fake(source_op, destination_op, tensor):
            return swap_in_node_map[destination_op]

        swapin.side_effect = swap_in_fake
        fwd_walk_ops.return_value = {'d'}
        lms_test._topo_sort = mock.Mock()
        # Every op reports order 0 in this scenario.
        lms_test._topo_sort.get_order.side_effect = lambda x: 0
        lms_test._insert_swap_nodes(src_op)
        consuming_ops.assert_has_calls([mock.call('a'), mock.call('z')])
        fwd_calls = [mock.call('b', inclusive=False),
                     mock.call('g2', inclusive=False)]
        fwd_walk_ops.assert_has_calls(fwd_calls, any_order=True)
        swapout_calls = [mock.call(src_op, 'a'),
                         mock.call(src_op, 'z')]
        swapout.assert_has_calls(swapout_calls)
        swapin_calls = [mock.call('swapout_op1', 'b', 'a'),
                        mock.call('swapout_op1', 'g2', 'a'),
                        mock.call('swapout_op2', 'c', 'z'),
                        mock.call('swapout_op2', 'g1', 'z')]
        swapin.assert_has_calls(swapin_calls, any_order=True)
        ctrldep_calls = [mock.call(src_op, 'b', 'swapin_op1'),
                         mock.call(src_op, 'g2', 'swapin_op2'),
                         mock.call(src_op, 'c', 'swapin_op3'),
                         mock.call(src_op, 'g1', 'swapin_op4')]
        ctrldep.assert_has_calls(ctrldep_calls, any_order=True)
        # One tensor counted per source output.
        self.assertEqual(lms_test._incpu_count, 2)

        # Test calling _find_new_src_op
        # NOTE(review): none of the mocks are reset before this scenario,
        # so the assert_has_calls checks below also see the calls recorded
        # by the previous scenario.
        lms_test = lms.LMS({'s1'}, graph=graph, fuse_swapins=False)
        src_op = mock.Mock()
        src_op.outputs = ['a', 'z']
        lms_test._grad_ops = {'b', 'c', 'g1', 'g2'}
        consuming_ops.side_effect = [{'b', 'f2', 'g2'}, {'c', 'g1', 'f3'}]
        swapout.side_effect = ['swapout_op1', 'swapout_op2']
        swapin.side_effect = ['swapin_op1', 'swapin_op4']
        fwd_walk_ops.return_value = {'d'}
        lms_test._topo_sort = mock.Mock()

        # 'g2' and 'c' report order -1; these are the ops expected to be
        # passed to the find_new_src mock below.
        op_orders = {'b': 1, 'g2': -1, 'c': -1, 'g1': 1}
        lms_test._topo_sort.get_order.side_effect = (lambda x:
                                                     op_orders.get(x, 1))
        # New src ops for recursion
        find_new_src.side_effect = [{'z1'}, {'z2'}]
        # Put ops 'z1' and 'z2' into excluded ops so the recursive call to
        # _insert_swap_nodes stops early.
        lms_test._excl_ops = {'z1', 'z2'}
        lms_test._insert_swap_nodes(src_op)
        consuming_ops.assert_has_calls([mock.call('a'), mock.call('z')])
        fwd_calls = [mock.call('b', inclusive=False),
                     mock.call('g2', inclusive=False)]
        fwd_walk_ops.assert_has_calls(fwd_calls, any_order=True)
        swapout_calls = [mock.call(src_op, 'a'),
                         mock.call(src_op, 'z')]
        swapout.assert_has_calls(swapout_calls)
        swapin_calls = [mock.call('swapout_op1', 'b', 'a'),
                        mock.call('swapout_op2', 'g1', 'z')]
        swapin.assert_has_calls(swapin_calls, any_order=True)
        ctrldep_calls = [mock.call(src_op, 'b', 'swapin_op1'),
                         mock.call(src_op, 'g1', 'swapin_op4')]
        ctrldep.assert_has_calls(ctrldep_calls, any_order=True)
        find_new_src.assert_has_calls([mock.call('g2'),
                                       mock.call('c')], any_order=True)
        self.assertEqual(lms_test._incpu_count, 2)

        # Test calling fuse_swapins
        # src_op is reused from the previous scenario (outputs 'a', 'z').
        consuming_ops.reset_mock()
        swapout.reset_mock()
        swapin.reset_mock()
        fwd_walk_ops.reset_mock()
        graph = mock.Mock()
        lms_test = lms.LMS({'s1'}, graph=graph, fuse_swapins=True)
        lms_test._topo_sort = mock.Mock()
        lms_test._topo_sort.get_order.side_effect = lambda x: 0
        lms_test._grad_ops = {'b', 'c', 'g1', 'g2'}
        consuming_ops.side_effect = [{'b', 'f2', 'g2'}, {'c', 'g1', 'f3'}]
        swapout.side_effect = ['swapout_op1', 'swapout_op2']
        swapin.side_effect = ['swapin_op1', 'swapin_op2', 'swapin_op3',
                              'swapin_op4']
        fwd_walk_ops.return_value = {'d'}
        fuse_swapins.return_value = ['fuse1', 'fuse2']
        lms_test._insert_swap_nodes(src_op)
        # One fuse call per source output, each with that output's
        # gradient consumers.
        fuse_calls = [mock.call(src_op, "swapout_op1", {'b', 'g2'}, 'a'),
                      mock.call(src_op, "swapout_op2", {'c', 'g1'}, 'z')]
        fuse_swapins.assert_has_calls(fuse_calls)

        # Test stop swapping out once max number of tensors to swap is hit
        consuming_ops.reset_mock()
        lms_test = lms.LMS({'s1'}, graph=graph, fuse_swapins=False)
        # The in-CPU count already equals the limit, so no consumer lookup
        # is expected.
        lms_test._n_tensors = 10
        lms_test._incpu_count = 10
        lms_test._insert_swap_nodes(src_op)
        self.assertFalse(consuming_ops.called)