def deconvolution_infer(node: Node):
    """Infer and propagate the output shape of a deconvolution node.

    The produced shape is ``[N, node.output, *spatial_out]``, where ``N`` is
    the first dimension of input 0 and ``spatial_out`` is computed from the
    spatial input dims (``shape[2:]``) according to ``node.auto_pad``:

    * ``'valid'``: ``(in - 1) * strides + k_eff``
    * ``'same_upper'`` / ``'same_lower'``: ``in * strides``
    * unset or any other value:
      ``strides * (in - 1) + k_eff - node.pads_begin - node.pads_end``

    ``k_eff`` is ``(kernel - 1) * dilations + 1`` along axes with a non-zero
    dilation and the raw ``kernel`` size elsewhere.

    The shape is stored in ``node['output_shape']`` and assigned to every
    output node of ``node``.
    """
    in_shape = int64_array(node.in_node(0).shape)
    spatial_in = in_shape[2:]
    dilations = int64_array(node.dilations)
    strides = int64_array(node.strides)
    kernel = int64_array(node.kernel)
    out_channels = node.output

    # A zero dilation leaves the kernel extent untouched on that axis.
    effective_kernel = np.where(dilations != 0, (kernel - 1) * dilations + 1, kernel)

    pad_mode = node.auto_pad if node.has_valid('auto_pad') else None
    if pad_mode == 'valid':
        spatial_out = (spatial_in - 1) * strides + effective_kernel
    elif pad_mode in ('same_upper', 'same_lower'):
        spatial_out = spatial_in * strides
    else:
        # Explicit padding: subtract the begin/end pads from the full extent.
        spatial_out = strides * (spatial_in - 1) + effective_kernel - node.pads_begin - node.pads_end

    node['output_shape'] = int64_array(
        [in_shape[0], out_channels] + [np.int64(size) for size in spatial_out]
    )
    for port in node.out_nodes():
        node.out_node(port).shape = node['output_shape']
def backpropdata_infer(op: Node):
    """Check recorded input shapes, then set the output shape and restore the type.

    Collects the current shape of every input node into
    ``op['new_input_shapes']`` and asserts that it matches
    ``op.old_input_shapes`` in length and position by position. The output
    shape is read from ``op.ports`` at index ``len(op.in_nodes())`` and
    written to ``op.out_node()``. Finally, ``op.type`` is reset to
    ``op.old_type``.
    """
    op['new_input_shapes'] = []
    # Fill the stored list through the attribute view, as it is read back below.
    op.new_input_shapes.extend(op.in_node(port).shape for port in op.in_nodes())

    assert len(op.new_input_shapes) == len(op.old_input_shapes)
    assert all(np.array_equal(shape, op.old_input_shapes[idx])
               for idx, shape in enumerate(op.new_input_shapes)), \
        'Something wrong happened while {} shape infer!'.format(op.old_type)

    # NOTE(review): presumably op.ports lists input shapes first, then outputs — confirm.
    op.out_node().shape = int64_array(op.ports[len(op.in_nodes())])
    op.type = op.old_type
# Example #3
# 0
 def const_shape_infer(node: Node):
     """Assign precomputed shapes from ``node.ports`` to each output node.

     Output nodes are visited in ``node.out_nodes()`` order. The k-th one
     (0-based) receives ``int64_array(node.ports[len(node.in_nodes()) + k])``.
     """
     first_output_index = len(node.in_nodes())
     for index, port in enumerate(node.out_nodes(), start=first_output_index):
         node.out_node(port).shape = int64_array(node.ports[index])