Example #1
    def test_interpolate4_using_scales_without_axes(self, pads_begin, pads_end, input_shape, output_shape, sizes,
                                                   scales):
        graph = build_graph(nodes_attrs=graph_node_attrs_without_axes,
                            edges=graph_edges_without_axes,
                            update_attributes={
                                'input_data': {'shape': input_shape},
                                'sizes': {'shape': int64_array(sizes).shape, 'value': int64_array(sizes)},
                                'sizes_data': {'shape': int64_array(sizes).shape, 'value': int64_array(sizes)},
                                'scales': {'shape': np.array(scales).shape, 'value': np.array(scales)},
                                'scales_data': {'shape': np.array(scales).shape, 'value': np.array(scales)},
                                'interpolate': {'pads_begin': int64_array(pads_begin),
                                                'pads_end': int64_array(pads_end),
                                                'shape_calculation_mode': 'scales'}
                            })

        node = Node(graph, 'interpolate')
        tested_class = Interpolate(graph=graph, attrs=node.attrs())
        tested_class.infer(node)

        msg = "Interpolate-4 infer failed for case: sizes={}, scales={}, pads_begin={}, pads_end={}," \
              " expected_shape={}, actual_shape={}"

        self.assertTrue(np.array_equal(graph.node['interpolate_data']['shape'], int64_array(output_shape)),
                        msg.format(sizes, scales, pads_begin, pads_end, output_shape,
                                   graph.node['interpolate_data']['shape']))
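For reference, a minimal sketch of the shape arithmetic this assertion exercises, assuming the Interpolate-4 'scales' convention of flooring the padded input shape multiplied by the per-axis scales; expected_interpolate4_shape is a hypothetical helper, not part of the test module:

    import numpy as np

    def expected_interpolate4_shape(input_shape, pads_begin, pads_end, scales):
        # Pad each dimension first, then scale and floor per axis
        # (assumed 'scales' shape-calculation rule).
        padded = np.array(input_shape) + np.array(pads_begin) + np.array(pads_end)
        return np.floor(padded * np.array(scales)).astype(np.int64)

    # expected_interpolate4_shape([1, 3, 100, 200], [0]*4, [0]*4, [1.0, 1.0, 0.5, 0.5])
    # -> array([  1,   3,  50, 100])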
Example #2

def get_new_cell(multilayer_cell: Node, number: int):
    cell_class = Op.get_op_class_by_name(multilayer_cell.op)
    new_cell = lambda graph, attrs: cell_class(graph, attrs)
    attrs = multilayer_cell.attrs().copy()
    new_attrs = {
        'num_layers': 1,
        'multilayers': False,
        'name': multilayer_cell.name + '/LayerSplittedLSTM/{}'.format(number),
    }
    attrs.update(new_attrs)
    return new_cell(multilayer_cell.graph, attrs)
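A hedged usage sketch for the helper above: splitting a stacked cell into one single-layer clone per layer. The multilayer_node variable and its num_layers attribute are assumptions based on the attributes the function rewrites:

    # Hypothetical caller: one single-layer cell per original layer.
    single_layer_cells = [get_new_cell(multilayer_node, i)
                          for i in range(multilayer_node.num_layers)]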
Example #3
    def get_new_cell(bidirectional_cell: Node, direction: str):
        assert direction in ['forward', 'reverse']

        cell_class = Op.get_op_class_by_name(bidirectional_cell.op)
        new_cell = lambda graph, attrs: cell_class(graph, attrs)
        attrs = bidirectional_cell.attrs().copy()
        new_attrs = {
            'direction': direction,
            'name': bidirectional_cell.name + '/Split/' + direction,
        }
        attrs.update(new_attrs)
        return new_cell(bidirectional_cell.graph, attrs)
Example #4
    def get_new_cell(bidirectional_cell: Node, direction: str):
        assert direction in ['forward', 'reverse']

        cell_class = Op.get_op_class_by_name(bidirectional_cell.op)
        new_cell = lambda graph, attrs: cell_class(graph, attrs)
        attrs = bidirectional_cell.attrs().copy()
        new_attrs = {
            'direction': direction,
            'name': bidirectional_cell.name + '/Split/' + direction,
        }
        attrs.update(new_attrs)
        # split bidirectional activations
        assert 'activations' in attrs
        if attrs['activations'] is not None and len(attrs['activations']) > 1:
            assert len(attrs['activations']) == 2, 'Bidirectional RNN should have 2 activations'
            activations = attrs['activations']
            attrs['activations'] = [activations[0 if direction == 'forward' else 1]]
        return new_cell(bidirectional_cell.graph, attrs)
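Example #4 extends Example #3 by also narrowing the two-element activations list to the activation belonging to the requested direction. A hedged usage sketch, with bidirectional_node as an assumed input node:

    # Hypothetical caller: materialize both unidirectional clones.
    forward_cell = get_new_cell(bidirectional_node, 'forward')  # keeps activations[0]
    reverse_cell = get_new_cell(bidirectional_node, 'reverse')  # keeps activations[1]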
Example #5
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        input_node = node.in_node(0)
        port = graph.get_edge_data(input_node.id, node.id)[0]['out']
        input_reshape_node = Reshape(
            graph, {
                'name': '/Reshape/' + node.name,
                'axis': 1,
                'infer': Reshape.kaldi_infer
            }).create_node([(input_node, port)])

        convolution_node = Convolution(graph, node.attrs()).create_node(
            [input_reshape_node])

        output_reshape_node = Reshape(
            graph, {
                'name': node.name + '/Reshape/',
                'axis': 1,
                'infer': Reshape.kaldi_infer
            }).create_node([convolution_node])

        return [output_reshape_node.id]
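A minimal NumPy sketch of the wrapping pattern above, assuming the Kaldi convention that 2-D [batch, features] data is lifted to 4-D for the convolution and flattened back afterwards; the concrete shapes are illustrative only:

    import numpy as np

    batch, features = 8, 64
    x2d = np.zeros((batch, features))
    x4d = x2d.reshape(batch, 1, 1, features)  # input Reshape (assumed layout)
    y4d = x4d                                 # stand-in for the Convolution
    y2d = y4d.reshape(batch, -1)              # output Reshape
    assert y2d.shape == (batch, features)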
Example #6
    def test_value_propagation(self, a_shape, a_value, b_shape, b_value,
                               elem_type):
        graph = build_graph(nodes_attrs=graph_nodes_attrs,
                            edges=graph_edges,
                            update_attributes={
                                'A': {
                                    'shape': int64_array(a_shape),
                                    'value': a_value.astype(elem_type)
                                },
                                'A_data': {
                                    'shape': int64_array(a_shape),
                                    'value': a_value.astype(elem_type)
                                },
                                'B': {
                                    'shape': int64_array(b_shape),
                                    'value': b_value.astype(elem_type)
                                },
                                'B_data': {
                                    'shape': int64_array(b_shape),
                                    'value': b_value.astype(elem_type)
                                },
                            })
        node = Node(graph, 'div')
        node['infer'] = Div(graph, node.attrs()).create_node().infer
        node.infer(node)
        node_data = node.out_port(0).get_destination().data.get_value()

        def func_for_ref():
            if np.issubdtype(elem_type, np.integer):
                return lambda a, b: a // b
            else:
                return lambda a, b: a / b

        ref_data = func_for_ref()(a_value, b_value)
        node_data_shape = node_data.shape
        ref_data_shape = ref_data.shape
        msg = "Value propagation for 'div' node is not correct."
        self.assertTrue(
            node_data_shape == ref_data_shape
            and np.all(node_data == ref_data), msg)
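The reference rule in func_for_ref can be checked standalone: integer element types are compared against floor division, floating-point types against true division:

    import numpy as np

    for elem_type in (np.int32, np.float32):
        a = np.array([7, -7], dtype=elem_type)
        b = np.array([2, 2], dtype=elem_type)
        ref = a // b if np.issubdtype(elem_type, np.integer) else a / b
        print(elem_type.__name__, ref)  # int32 [ 3 -4], float32 [ 3.5 -3.5]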
Example #7

def is_output_data_in_correct_layout(node: Node, port_ind: int):
    assert node.soft_get('kind') == 'op', 'The function works with operation nodes only'
    return 'correct_out_data_layout' in node.attrs() and port_ind in node.attrs()['correct_out_data_layout']
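A hedged usage sketch, assuming correct_out_data_layout is a set of output port indices written onto the node by an earlier layout-propagation pass; op_node is a placeholder for any operation node:

    # Hypothetical marking done by a prior pass:
    op_node['correct_out_data_layout'] = {0}
    assert is_output_data_in_correct_layout(op_node, 0)
    assert not is_output_data_in_correct_layout(op_node, 1)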