def deconvolution_infer(node: Node):
    """Infer and propagate the output shape of a deconvolution node.

    The output shape is ``[N, node.output, *spatial_out]``:

    * ``N`` is the first dimension of the shape on input port 0.
    * ``spatial_out`` is computed per spatial axis (input dims from index 2
      onward) from the strides, the dilated kernel extent and the padding mode:

      - ``auto_pad == 'valid'``: ``(in - 1) * stride + kernel_ext``
      - ``auto_pad in ('same_upper', 'same_lower')``: ``in * stride``
      - otherwise: ``stride * (in - 1) + kernel_ext - pads_begin - pads_end``

    The result is stored in ``node['output_shape']`` and assigned as the
    ``shape`` of every output data node.
    """
    in_shape = int64_array(node.in_node(0).shape)
    dilations = int64_array(node.dilations)
    strides = int64_array(node.strides)
    batch = in_shape[0]
    kernel = int64_array(node.kernel)
    # Effective kernel extent after dilation; a dilation of 0 keeps the raw size.
    kernel_ext = np.where(dilations != 0, (kernel - 1) * dilations + 1, kernel)
    out_channels = node.output
    spatial_in = in_shape[2:]

    pad_mode = node.auto_pad if node.has_valid('auto_pad') else None
    if pad_mode == 'valid':
        spatial_out = (spatial_in - 1) * strides + kernel_ext
    elif pad_mode in ('same_upper', 'same_lower'):
        spatial_out = spatial_in * strides
    else:
        # Explicit padding. Assumes pads_begin/pads_end are per-spatial-axis and
        # broadcast against the spatial dims (TODO confirm against producers).
        spatial_out = strides * (spatial_in - 1) + kernel_ext - node.pads_begin - node.pads_end

    node['output_shape'] = int64_array([batch, out_channels] + [np.int64(extent) for extent in spatial_out])
    for port in node.out_nodes():
        node.out_node(port).shape = node['output_shape']
def backpropdata_infer(op: Node):
    """Check a node's input shapes against recorded ones and set its output shape.

    The current shape of every input data node is recorded in
    ``op['new_input_shapes']`` and compared one-to-one with
    ``op.old_input_shapes``. If they all match:

    * the output data node's shape is set from ``op.ports[len(op.in_nodes())]``;
    * ``op.type`` is set to ``op.old_type``.

    Raises:
        AssertionError: if the number of inputs, or any input shape, differs
            from ``op.old_input_shapes``.
    """
    new_shapes = [op.in_node(port).shape for port in op.in_nodes()]
    op['new_input_shapes'] = new_shapes
    old_shapes = op.old_input_shapes

    # Explicit raises instead of ``assert`` so the validation still runs under
    # ``python -O``. AssertionError is kept so callers see the same exception type.
    if len(new_shapes) != len(old_shapes):
        raise AssertionError('Input count mismatch during {} shape infer: expected {}, got {}'.format(
            op.old_type, len(old_shapes), len(new_shapes)))

    for new_shape, old_shape in zip(new_shapes, old_shapes):
        if not np.array_equal(new_shape, old_shape):
            raise AssertionError('Something wrong happened while {} shape infer!'.format(op.old_type))

    # The output shape is stored right after the input entries in ``op.ports``.
    output_shape = op.ports[len(op.in_nodes())]
    op.out_node().shape = int64_array(output_shape)
    op.type = op.old_type
def const_shape_infer(node: Node):
    """Assign output data node shapes from entries in ``node.ports``.

    Output entries in ``node.ports`` are indexed after the inputs. The k-th
    output node, in ``node.out_nodes()`` iteration order, receives
    ``int64_array(node.ports[len(node.in_nodes()) + k])`` as its shape.
    """
    first_out_idx = len(node.in_nodes())
    for port_idx, out_port in enumerate(node.out_nodes(), start=first_out_idx):
        node.out_node(out_port).shape = int64_array(node.ports[port_idx])