def test_lift_up_through_eltwise(self):
    """Lift a ReverseChannels node above a two-input ``mul`` node.

    Graph under test::

        placeholder1 -> mul:0
        placeholder2 -> mul:1
        mul -> reverse_channels -> result

    Attributes are set on both placeholders via ``set_graph_attrs``. Then
    ``lift_up_through_eltwise`` is applied to ``mul``. Afterwards the same
    placeholders are checked with ``check_graph_attrs``. What those helpers
    set and verify is defined outside this view (TODO confirm).
    """
    edge_list = [
        *connect('placeholder1', '0:mul'),
        *connect('placeholder2', '1:mul'),
        *connect('mul', 'reverse_channels'),
        *connect('reverse_channels', 'result'),
    ]
    graph = build_graph(nodes, edge_list)
    self.set_graph_attrs(graph, ['placeholder1', 'placeholder2'])

    eltwise_node = Node(graph, 'mul')
    reverse_node = Node(graph, 'reverse_channels')
    ReverseChannelsPropagationUp.lift_up_through_eltwise(eltwise_node, reverse_node)

    self.check_graph_attrs(graph, ['placeholder1', 'placeholder2'])
def test_lift_up_through(self):
    """Lift a ReverseChannels node above a ``pad`` node using ``lift_up_through``.

    Graph under test::

        placeholder -> mul:0,  mul_const -> mul:1
        mul -> pad:0,  pad_const_1 -> pad:1,  pad_const_2 -> pad:2
        pad -> reverse_channels -> result

    Attributes set on ``placeholder`` before the transformation are then
    checked with ``check_graph_attrs``. The helpers' exact semantics live
    outside this view (TODO confirm).
    """
    edge_list = [
        *connect('placeholder', '0:mul'),
        *connect('mul_const', '1:mul'),
        *connect('mul', '0:pad'),
        *connect('pad_const_1', '1:pad'),
        *connect('pad_const_2', '2:pad'),
        *connect('pad', 'reverse_channels'),
        *connect('reverse_channels', 'result'),
    ]
    graph = build_graph(nodes2, edge_list)
    self.set_graph_attrs(graph, ['placeholder'])

    pad_node = Node(graph, 'pad')
    reverse_node = Node(graph, 'reverse_channels')
    ReverseChannelsPropagationUp.lift_up_through(pad_node, reverse_node)

    self.check_graph_attrs(graph, ['placeholder'])
def test_lift_up_through_transpose_negative_axis(self):
    """Lift a ReverseChannels node with a negative axis above a ``transpose``.

    Input graph::

        placeholder -> transpose:0,  transpose_order -> transpose:1
        transpose -> reverse_channels_down -> result

    Expected graph (``graph_ref``): ``reverse_channels_down`` now sits between
    ``placeholder`` and ``transpose``. The reverse node starts with axis
    ``int64_array(-3)``. After lifting, the test expects its axis to equal 3
    and to still be a numpy ndarray. The concrete value follows from the
    transpose order in ``nodes3``, which is not visible here.
    """
    graph = build_graph(nodes3, [
        *connect('placeholder', '0:transpose'),
        *connect('transpose_order', '1:transpose'),
        *connect('transpose', 'reverse_channels_down'),
        *connect('reverse_channels_down', 'result'),
    ])
    graph_ref = build_graph(nodes3, [
        *connect('placeholder', 'reverse_channels_down'),
        *connect('transpose_order', '1:transpose'),
        *connect('reverse_channels_down', 'transpose'),
        *connect('transpose', 'result'),
    ])
    self.set_graph_attrs(graph, ['placeholder'])

    node = Node(graph, 'transpose')
    reverse_channels = Node(graph, 'reverse_channels_down')
    # Start from a negative axis to exercise the negative-axis path.
    reverse_channels.axis = int64_array(-3)

    keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(
        node, reverse_channels)

    # assertIs/assertEqual report the actual values on failure, unlike assertTrue(x == y).
    self.assertIs(keep_moving_up, True)
    self.assertEqual(len(new_reverses), 1)
    self.check_graph_attrs(graph, ['placeholder'])

    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)

    # Re-fetch the node from the graph to inspect its post-transformation axis.
    reverse_channels = Node(graph, 'reverse_channels_down')
    self.assertEqual(reverse_channels.axis, 3)
    # The axis must remain an ndarray rather than degrading to a Python int.
    self.assertIsInstance(reverse_channels.axis, np.ndarray)
def test_lift_up_through_pad2(self):
    """Lift a ReverseChannels node with two consumers above ``pad`` via port 0 only.

    Graph under test::

        placeholder -> mul:0,  mul_const -> mul:1
        mul -> pad:0,  pad_const_1 -> pad:1,  pad_const_2 -> pad:2
        pad -> reverse_channels
        reverse_channels:0 -> result:0
        reverse_channels:0 -> result2:0

    ``lift_up_through_zero_port_only`` is applied to ``pad``. The test expects
    propagation to continue (``keep_moving_up is True``) with exactly one new
    reverse node. It also expects the attributes set on ``placeholder`` to
    still pass ``check_graph_attrs``.
    """
    graph = build_graph(nodes2, [
        *connect('placeholder', '0:mul'),
        *connect('mul_const', '1:mul'),
        *connect('mul', '0:pad'),
        *connect('pad_const_1', '1:pad'),
        *connect('pad_const_2', '2:pad'),
        *connect('pad', 'reverse_channels'),
        *connect('reverse_channels:0', '0:result'),
        *connect('reverse_channels:0', '0:result2'),
    ])
    self.set_graph_attrs(graph, ['placeholder'])

    node = Node(graph, 'pad')
    reverse_channels = Node(graph, 'reverse_channels')

    keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(
        node, reverse_channels)

    # assertIs/assertEqual report the actual values on failure, unlike assertTrue(x == y).
    self.assertIs(keep_moving_up, True)
    self.assertEqual(len(new_reverses), 1)
    self.check_graph_attrs(graph, ['placeholder'])