コード例 #1
0
    def test_lift_up_through_eltwise(self):
        """Lift the ``reverse_channels`` node above the two-input ``mul`` node.

        Graph under test: ``placeholder1`` -> ``mul`` port 0,
        ``placeholder2`` -> ``mul`` port 1, ``mul`` -> ``reverse_channels``
        -> ``result``. After ``lift_up_through_eltwise`` is applied to ``mul``,
        ``check_graph_attrs`` (a helper of this test class) is run for both
        placeholders.
        """
        edges = [
            *connect('placeholder1', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'reverse_channels'),
            *connect('reverse_channels', 'result'),
        ]
        graph = build_graph(nodes, edges)
        self.set_graph_attrs(graph, ['placeholder1', 'placeholder2'])

        eltwise_node = Node(graph, 'mul')
        reverse_node = Node(graph, 'reverse_channels')
        ReverseChannelsPropagationUp.lift_up_through_eltwise(eltwise_node, reverse_node)

        self.check_graph_attrs(graph, ['placeholder1', 'placeholder2'])
コード例 #2
0
    def test_lift_up_through(self):
        """Lift the ``reverse_channels`` node above the ``pad`` node.

        Graph under test: ``placeholder`` -> ``mul`` (second input
        ``mul_const``) -> ``pad`` (extra inputs ``pad_const_1`` and
        ``pad_const_2``) -> ``reverse_channels`` -> ``result``.
        ``lift_up_through`` is applied to ``pad``, then ``check_graph_attrs``
        (a helper of this test class) is run for ``placeholder``.
        """
        edges = [
            *connect('placeholder', '0:mul'),
            *connect('mul_const', '1:mul'),
            *connect('mul', '0:pad'),
            *connect('pad_const_1', '1:pad'),
            *connect('pad_const_2', '2:pad'),
            *connect('pad', 'reverse_channels'),
            *connect('reverse_channels', 'result'),
        ]
        graph = build_graph(nodes2, edges)
        self.set_graph_attrs(graph, ['placeholder'])

        pad_node = Node(graph, 'pad')
        reverse_node = Node(graph, 'reverse_channels')
        ReverseChannelsPropagationUp.lift_up_through(pad_node, reverse_node)

        self.check_graph_attrs(graph, ['placeholder'])
コード例 #3
0
    def test_lift_up_through_transpose_negative_axis(self):
        """Lift ``reverse_channels_down`` above ``transpose`` when its axis is negative.

        Starting graph: ``placeholder`` -> ``transpose`` (order input
        ``transpose_order``) -> ``reverse_channels_down`` -> ``result``.
        The reference graph has the reverse node moved between ``placeholder``
        and ``transpose``. The reverse node's axis is set to -3 before the lift;
        afterwards the test expects it to equal 3 and to be a numpy array.
        """
        graph = build_graph(nodes3, [
            *connect('placeholder', '0:transpose'),
            *connect('transpose_order', '1:transpose'),
            *connect('transpose', 'reverse_channels_down'),
            *connect('reverse_channels_down', 'result'),
        ])
        graph_ref = build_graph(nodes3, [
            *connect('placeholder', 'reverse_channels_down'),
            *connect('transpose_order', '1:transpose'),
            *connect('reverse_channels_down', 'transpose'),
            *connect('transpose', 'result'),
        ])
        self.set_graph_attrs(graph, ['placeholder'])

        node = Node(graph, 'transpose')
        reverse_channels = Node(graph, 'reverse_channels_down')
        # Negative axis on purpose; the expected post-lift value is checked below.
        reverse_channels.axis = int64_array(-3)

        keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_transpose(
            node, reverse_channels)
        # assertIs/assertEqual print the actual values on failure, whereas
        # assertTrue(<comparison>) only reports "False is not true".
        self.assertIs(keep_moving_up, True)
        self.assertEqual(len(new_reverses), 1)
        self.check_graph_attrs(graph, ['placeholder'])
        flag, resp = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # Look the node up again by id and check the axis left by the lift.
        reverse_channels = Node(graph, 'reverse_channels_down')
        self.assertEqual(reverse_channels.axis, 3)
        self.assertIsInstance(reverse_channels.axis, np.ndarray)
コード例 #4
0
    def test_lift_up_through_pad2(self):
        """Lift ``reverse_channels`` above ``pad`` when its output feeds two results.

        Graph under test: ``placeholder`` -> ``mul`` (second input
        ``mul_const``) -> ``pad`` (extra inputs ``pad_const_1`` and
        ``pad_const_2``) -> ``reverse_channels``, whose output port 0 is
        connected to both ``result`` and ``result2``.
        ``lift_up_through_zero_port_only`` is applied to ``pad``; the test
        expects ``keep_moving_up`` to be True and exactly one new reverse node.
        """
        graph = build_graph(nodes2, [
            *connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
            *connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
            *connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
            *connect('reverse_channels:0', '0:result'),
            *connect('reverse_channels:0', '0:result2'),
        ])
        self.set_graph_attrs(graph, ['placeholder'])

        node = Node(graph, 'pad')
        reverse_channels = Node(graph, 'reverse_channels')

        keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_zero_port_only(
            node, reverse_channels)
        # assertIs/assertEqual print the actual values on failure, whereas
        # assertTrue(<comparison>) only reports "False is not true".
        self.assertIs(keep_moving_up, True)
        self.assertEqual(len(new_reverses), 1)
        self.check_graph_attrs(graph, ['placeholder'])