def backpropdata_infer(op: Node):
    """Validate the input shapes of ``op`` and restore its output shape and type.

    The current shape of every input node is stored in
    ``op['new_input_shapes']`` and compared, element by element, with
    ``op.old_input_shapes``. Afterwards the shape of the output node is set
    from the ``op.ports`` entry whose index equals the number of inputs, and
    ``op.type`` is reset to ``op.old_type``.

    :param op: node being inferred; it is expected to carry the
        ``old_input_shapes``, ``old_type`` and ``ports`` attributes.
    :raises AssertionError: if the number of inputs, or any single input
        shape, differs from the previously recorded one.
    """
    op['new_input_shapes'] = [op.in_node(port).shape for port in op.in_nodes()]
    assert len(op.new_input_shapes) == len(op.old_input_shapes)

    for idx, current_shape in enumerate(op.new_input_shapes):
        assert np.array_equal(current_shape, op.old_input_shapes[idx]), \
            'Something wrong happened while {} shape infer!'.format(op.old_type)

    # NOTE(review): presumably ``op.ports`` lists the input entries first, so
    # the entry right after them describes the output -- confirm with callers.
    op.out_node().shape = int64_array(op.ports[len(op.in_nodes())])
    op.type = op.old_type
def backpropdata_infer(op: Node):
    """Validate the input shapes of ``op`` and delegate output shape inference.

    The current shape of every input node is stored in
    ``op['new_input_shapes']`` and compared, element by element, with
    ``op.old_input_shapes``. The output shapes are then assigned by
    ``Extender.const_shape_infer``.

    :param op: node being inferred; it is expected to carry the
        ``old_input_shapes`` attribute.
    :raises AssertionError: if the number of inputs, or any single input
        shape, differs from the previously recorded one.
    """
    op['new_input_shapes'] = [op.in_node(port).shape for port in op.in_nodes()]
    assert len(op.new_input_shapes) == len(op.old_input_shapes)

    for idx, current_shape in enumerate(op.new_input_shapes):
        assert np.array_equal(current_shape, op.old_input_shapes[idx]), \
            'Something wrong happened while {} shape infer with type {}!'.format(op.name, op.type)

    Extender.const_shape_infer(op)
# Example no. 3
# 0
 def const_shape_infer(node: Node):
     """Assign a shape to every output node of ``node`` from ``node.ports``.

     The entries of ``node.ports`` are consumed consecutively, starting at the
     index equal to the number of input nodes; each one is converted with
     ``int64_array`` and stored as the ``shape`` of the matching output node.

     :param node: node whose output shapes are being set.
     """
     first_output_index = len(node.in_nodes())
     for port_index, out_key in enumerate(node.out_nodes(), start=first_output_index):
         node.out_node(out_key).shape = int64_array(node.ports[port_index])