def infer(node):
    """
    Infer the output shape of a SpaceToBatch node.

    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch

    Output batch is ``input_shape[0] * prod(block_size)``; every other output dimension is
    ``(pads_begin + input_shape + pads_end) // block_size`` for that dimension.
    The function returns without setting a shape when the input shape is unknown, the node does not
    have exactly 4 inputs, or any of block_size / pads_begin / pads_end has no constant value.

    Also registers NHWC->NCHW permutation rules for the block_shape, pads_begin and pads_end inputs.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    if len(node.in_nodes()) != 4:
        return

    block_size = node.in_port(1).data.get_value()
    pads_begin = node.in_port(2).data.get_value()
    pads_end = node.in_port(3).data.get_value()
    if block_size is None or pads_begin is None or pads_end is None:
        return

    pads = pads_begin + input_shape + pads_end

    # floor division keeps the computation in exact integer arithmetic instead of round-tripping
    # through float (true division followed by int()), which can lose precision for large values
    node.out_node().shape = int64_array([
        input_shape[0] * np.prod(block_size),
        *[int(x) for x in (pads[1:] // block_size[1:])]
    ])

    # block_shape, pads_begin, pads_end should be permuted during the NHWC->NCHW layout change
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'shape')
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
    PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'shape')
def infer(node):
    """
    Shape inference for SpaceToBatch.

    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch

    All 4 inputs must be connected and the block_size / pads_begin / pads_end values must be known.
    The output batch is ``input_shape[0] * prod(block_size)`` (dynamic when block_size is not fully
    defined); the remaining dimensions are ``(pads_begin + input_shape + pads_end)[1:] // block_size[1:]``.
    """
    data_shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)

    assert len(node.in_nodes()) == 4, \
        'Some inputs are not connected for the operation SpaceToBatch with name {}'.format(node_name)

    block_size, pads_begin, pads_end = (node.in_port(port).data.get_value() for port in (1, 2, 3))
    assert all(value is not None for value in (block_size, pads_begin, pads_end)), \
        'Some inputs are not defined for SpaceToBatch operation with name {}'.format(node_name)

    padded_shape = pads_begin + data_shape + pads_end
    block_volume = np.prod(block_size) if is_fully_defined(block_size) else dynamic_dimension

    node.out_port(0).data.set_shape([data_shape[0] * block_volume, *(padded_shape[1:] // block_size[1:])])

    # the block_shape / pads_begin / pads_end values follow the data layout on the NHWC->NCHW change
    for port in (1, 2, 3):
        PermuteInputs().set_input_permutation(node.in_node(port), node, 'input:0', 'shape')
def infer(node):
    """
    Infer the output shape of a BatchToSpace node.

    Output batch is ``input_shape[0] // prod(block_size)``; every other output dimension is
    ``(block_size * input_shape)[1:] - crops_begin[1:] - crops_end[1:]``.
    The function returns without setting a shape when the input shape is unknown, the node does not
    have exactly 4 inputs, or any of block_size / crops_begin / crops_end has no constant value.

    Also registers NHWC->NCHW permutation rules for the block_shape, crops_begin and crops_end inputs.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    if len(node.in_nodes()) != 4:
        return

    block_size = node.in_port(1).data.get_value()
    crops_begin = node.in_port(2).data.get_value()
    crops_end = node.in_port(3).data.get_value()
    if block_size is None or crops_begin is None or crops_end is None:
        return

    pads = block_size * input_shape

    sizes = pads[1:] - crops_begin[1:] - crops_end[1:]
    # exact integer division instead of float division followed by int()
    batch = int(input_shape[0] // np.prod(block_size))

    node.out_node().shape = int64_array([batch, *sizes])

    # block_shape, crops_begin, crops_end values should be permuted during the NHWC->NCHW layout change
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'shape')
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
    PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'shape')
def infer(node):
    """
    Shape and value inference for Pad.

    The output shape is ``input_shape + pad_beg + pad_end``. When the data input (port 0) is constant, the
    padded value is computed with ``np.pad(mode='constant')``. The pad value comes from the optional 4th input
    and falls back to 0 when that input is absent or has no constant value.

    Also registers NHWC->NCHW permutation rules for the pads_begin and pads_end inputs.
    """
    pad_node_name = node.soft_get('name', node.id)

    assert len(node.in_nodes()) in [3, 4], "The node {} must have 3 or 4 inputs".format(pad_node_name)

    input_shape = node.in_port(0).data.get_shape()
    pad_beg = node.in_port(1).data.get_value()
    pad_end = node.in_port(2).data.get_value()

    assert pad_beg is not None, 'The padding begin value is None for node {}'.format(pad_node_name)
    assert pad_end is not None, 'The padding end value is None for node {}'.format(pad_node_name)
    assert input_shape is not None, 'The input shape is None for node {}'.format(pad_node_name)
    assert len(input_shape) == len(pad_beg), \
        'Length of begin padding "{}" does not correspond to input tensor shape "{}" for node "{}".' \
        ''.format(pad_beg, input_shape, pad_node_name)
    # report the end padding (not the begin padding) in the end-padding check
    assert len(input_shape) == len(pad_end), \
        'Length of end padding "{}" does not correspond to input tensor shape "{}" for node "{}".' \
        ''.format(pad_end, input_shape, pad_node_name)
    assert not node.is_in_port_connected(3) or node.in_port(3).data.get_shape().size == 0, \
        'Optional 3rd input of Pad operation should be scalar, but has shape {} for node {}' \
        ''.format(node.in_port(3).data.get_shape(), pad_node_name)

    node.out_port(0).data.set_shape(input_shape + pad_beg + pad_end)

    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        # interleave into [[beg_0, end_0], [beg_1, end_1], ...] as np.pad expects
        pads = np.insert(pad_end, np.arange(len(pad_end)), pad_beg)
        pads = np.reshape(pads, (len(pad_end), 2))

        pad_val = 0
        if len(node.in_nodes()) == 4:
            # a non-constant pad value must not reach np.pad as None (it would fail on assignment)
            pad_val_input = node.in_port(3).data.get_value()
            if pad_val_input is not None:
                pad_val = pad_val_input

        node.out_port(0).data.set_value(np.pad(input_value, pads, constant_values=pad_val, mode='constant'))

    # pad values should be permuted during the NHWC->NCHW layout change
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'shape')
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')
def find_and_replace_pattern(self, graph: Graph):
    """
    Normalize every StridedSlice node of the graph and attach layout permutation rules to it.

    For each StridedSlice the node is normalized by StridedSliceNormalizer. Its mask attributes are then
    registered to be permuted together with input 0. Finally, 'slice' permutation rules are set on the
    begin/end inputs and, when connected, on the strides input.
    """
    mask_attributes = ('begin_mask',  # but indeed depends from slice_rank
                       'end_mask',
                       'new_axis_mask',
                       'shrink_axis_mask',
                       'ellipsis_mask')

    for strided_slice in graph.get_op_nodes(type='StridedSlice'):
        StridedSliceNormalizer.normalize_strided_slice(graph, strided_slice)
        PermuteAttrs.create_permute_attrs(strided_slice,
                                          attrs=[(mask, 'input:0') for mask in mask_attributes])

        # StridedSliceNormalizer inserted nodes that changed original begin, end, and strides data nodes,
        # so the input permutations can only be set correctly now
        for port in (1, 2):
            PermuteInputs().set_input_permutation(strided_slice.in_node(port), strided_slice,
                                                  'input:{}'.format(port), 'slice', 'dim_size')
        if strided_slice.is_in_port_connected(3):
            PermuteInputs().set_input_permutation(strided_slice.in_node(3), strided_slice,
                                                  'input:3', 'slice', 'dim_size')
def infer(node):
    """
    Shape inference for BatchToSpace.

    The function returns without setting a shape when the input shape is unknown, the node does not
    have exactly 4 inputs, or any of block_size / crops_begin / crops_end has no constant value.
    Otherwise the output batch is ``input_shape[0] // prod(block_size)`` (dynamic when block_size is not fully
    defined). The remaining dimensions are ``(block_size * input_shape)[1:] - crops_begin[1:] - crops_end[1:]``.
    """
    data_shape = node.in_node(0).shape
    if data_shape is None or len(node.in_nodes()) != 4:
        return

    block_size, crops_begin, crops_end = [node.in_port(port).data.get_value() for port in range(1, 4)]
    if any(value is None for value in (block_size, crops_begin, crops_end)):
        return

    scaled_shape = block_size * data_shape
    spatial_dims = scaled_shape[1:] - crops_begin[1:] - crops_end[1:]

    if is_fully_defined(block_size):
        block_volume = np.prod(block_size)
    else:
        block_volume = dynamic_dimension

    node.out_port(0).data.set_shape([data_shape[0] // block_volume, *spatial_dims])

    # block_shape, crops_begin and crops_end values follow the data layout on the NHWC->NCHW change
    for port in range(1, 4):
        PermuteInputs().set_input_permutation(node.in_node(port), node, 'input:0', 'shape')
def infer(node: Node):
    """
    Shape and value inference for Broadcast.

    Supported modes are 'numpy' (uni-directional), 'bidirectional' and 'explicit'; 'explicit' takes the
    axes_mapping from input 2. When the data input is constant and 'stop_value_propagation' is not set,
    the broadcast value is computed; otherwise only the output shape is set.

    :raises Error: for an unsupported broadcasting mode
    """
    node_name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    input_value = node.in_port(0).data.get_value()
    target_shape = node.in_port(1).data.get_value()

    assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
    assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    propagate_value = input_value is not None and not node.has_and_set('stop_value_propagation')
    mode = node.mode
    out_data = node.out_port(0).data

    if mode == 'numpy':
        if propagate_value:
            out_data.set_value(uni_directional_broadcasting(input_value, target_shape))
        else:
            out_data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
    elif mode == 'bidirectional':
        if propagate_value:
            out_data.set_value(bi_directional_broadcasting(input_value, target_shape))
        else:
            out_data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
    elif mode == 'explicit':
        axes_mapping = node.in_port(2).data.get_value()
        assert axes_mapping is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \
                                         'is not supported. Node: `{}`'.format(node_name)
        PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
        # axes_mapping is re-read after its permutation rule has been registered
        axes_mapping = node.in_port(2).data.get_value()
        if propagate_value:
            out_data.set_value(explicit_broadcasting(input_value, target_shape, axes_mapping))
        else:
            broadcast_shape, _ = explicit_shape_broadcasting(input_shape, target_shape, axes_mapping)
            out_data.set_shape(broadcast_shape)
    else:
        raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
def infer(node: Node):
    """
    Shape and value inference for NormalizeL2.

    If both the data input and the axes input are constant, the output value is
    ``input / eps_op(norm, eps)``. Here ``norm`` is the element-wise p-norm ``(sum(|x| ** p)) ** (1 / p)``
    over all listed axes (keepdims), and eps_op is addition ('add') or element-wise maximum ('max').
    Otherwise the output shape equals the input shape. Nothing is done when the input shape is unknown.
    """
    input_shape = node.in_port(0).data.get_shape()
    if input_shape is None:
        return

    input_value = node.in_port(0).data.get_value()
    axes = node.in_port(1).data.get_value()
    if input_value is not None and axes is not None:
        # reduce over every listed axis at once; np.linalg.norm with two axes would compute a matrix
        # (spectral) norm and rejects more than two axes
        reduction_axes = tuple(int(axis) for axis in np.atleast_1d(axes))
        norm_value = np.sum(np.abs(input_value) ** node.p, axis=reduction_axes, keepdims=True) ** (1.0 / node.p)

        if node.eps_mode == 'add':
            norm_value = norm_value + node.eps
        elif node.eps_mode == 'max':
            # element-wise maximum; np.max(norm_value, eps) would interpret eps as the reduction axis
            norm_value = np.maximum(norm_value, node.eps)
        else:
            assert False, 'Unsupported "eps_mode" = {}'.format(node.eps_mode)

        node.out_port(0).data.set_value(input_value / norm_value)
    else:
        node.out_port(0).data.set_shape(input_shape)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def reduce_infer(node: Node):
    """
    Set the output shape of a reduction node from hard-coded values.

    NOTE(review): this variant ignores the values of its inputs. The reduction axes ([1]) and the input shape
    ([1, 6, 5]) are hard-coded; only the port connectivity and the 'keep_dims' attribute are read from the node.
    It looks like a test stub — confirm it is not used for real models.
    """
    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    assert len(connected_in_ports) == 2, \
        "{} node `{}` should have 2 input ports, where 0-input is data input and 1-input represent " \
        "`reduction_indices`".format(node.op, node.id)

    axis = int64_array([1])
    output_shape = np.array([1, 6, 5])
    # np.bool was removed in NumPy 1.24; the builtin bool is the equivalent dtype
    used_dims = np.zeros(len(output_shape), dtype=bool)

    for dim in axis:
        used_dims[dim] = True
        output_shape[dim] = 1

    # In case if keep dims == False, we should remove all 1 dims that was used in reduction
    if not node.keep_dims:
        output_shape = output_shape[np.invert(used_dims)]

    node.out_port(0).data.set_shape(output_shape)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def infer(node: Node):
    """
    Shape and value inference for Broadcast in 'numpy' or 'bidirectional' mode.

    The output shape is always set from the data shape and the target shape (input 1). When the data
    input is constant and 'stop_value_propagation' is not set, the broadcast value is set as well.

    :raises Error: for any other broadcasting mode
    """
    node_name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    input_value = node.in_port(0).data.get_value()
    target_shape = node.in_port(1).data.get_value()

    assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
    assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)

    # mode -> (shape broadcasting function, value broadcasting function)
    broadcasters = {
        'numpy': (uni_directional_shape_broadcasting, uni_directional_broadcasting),
        'bidirectional': (bi_directional_shape_broadcasting, bi_directional_broadcasting),
    }
    if node.mode not in broadcasters:
        raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
    shape_fn, value_fn = broadcasters[node.mode]

    node.out_port(0).data.set_shape(shape_fn(input_shape, target_shape))

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    if input_value is None or node.has_and_set('stop_value_propagation'):
        return
    node.out_port(0).data.set_value(value_fn(input_value, target_shape))
def infer(node: Node):
    """
    Shape and value inference for Gather (inputs: data, indices, axis).

    The output shape is ``data_shape[:axis] + indices_shape + data_shape[axis + 1:]``. When both data and
    indices are constant, the output value is computed with ``np.take`` and keeps the data dtype.
    Registers the 'axis' layout permutation rule for the axis input (port 2).
    """
    name = node.soft_get('name', node.id)

    connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_in_ports) == 3 and 0 in connected_in_ports and 1 in connected_in_ports and \
           2 in connected_in_ports, "Gather should have 3 connected input port, but it doesn't for " \
                                    "node: `{}`. Ports: {}".format(name, connected_in_ports)

    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None
    indices_shape = node.in_port(1).data.get_shape()
    assert indices_shape is not None
    axis = node.in_port(2).data.get_value()
    assert axis is not None
    axis = get_canonical_axis_index(data_shape, axis)

    # we import PermuteInputs locally because it uses Gather inside and we have recursive imports
    from mo.graph.perm_inputs import PermuteInputs
    # the 'axis' rule belongs to the axis input (port 2); applying it to the indices (port 1)
    # would rewrite the gathered indices as if they were an axis
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'axis')

    data_value = node.in_port(0).data.get_value()
    indices_value = node.in_port(1).data.get_value()
    if data_value is not None and indices_value is not None:
        node.out_port(0).data.set_value(np.array(np.take(data_value, int64_array(indices_value), axis),
                                                 dtype=data_value.dtype))
        return

    shape = np.concatenate((data_shape[:axis], indices_shape))
    if axis < len(data_shape) - 1:
        shape = np.concatenate((shape, data_shape[axis + 1:]))

    node.out_port(0).data.set_shape(int64_array(shape))
def infer(node):
    """
    Shape and value inference for Transpose.

    With 'reverse_order' set, only the data input (port 0) may be connected and the dimensions are reversed.
    Otherwise ports 0 and 1 must be connected and port 1 must hold a constant permutation ('order'), for which
    an 'order' layout permutation rule is registered. The output value is the transposed constant input when
    available; otherwise only the output shape ``input_shape[order]`` is set.
    """
    # order parameter calculation and checks
    connected_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    input_shape = node.in_port(0).data.get_shape()

    if node.has_and_set('reverse_order'):
        # check the data port among the *connected* ports, not merely among existing ones
        assert len(connected_ports) == 1 and 0 in connected_ports, \
            'Cannot infer `{}` due to both order and reverse_order was set'.format(node.soft_get('name'))
        order = np.arange(len(input_shape))[::-1]  # Reverse order
    else:
        # we import PermuteInputs locally because it uses Transpose inside and we have recursive imports
        from mo.graph.perm_inputs import PermuteInputs
        assert len(connected_ports) == 2 and 0 in connected_ports and 1 in connected_ports, \
            "{} node `{}` should have 2 input ports, where 0-input is a data input and 1-input represents " \
            "Transpose `order`".format(node.op, node.id)
        order = node.in_port(1).data.get_value()
        assert order is not None, 'Cannot infer `{}` because order is None'.format(node.soft_get('name'))
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'order')

    # setting shape and value if applicable
    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        node.out_port(0).data.set_value(np.transpose(input_value, axes=order))
    else:
        node.out_port(0).data.set_shape(input_shape[order])
def infer(node):
    """
    Shape and value inference for Split (axis from input 1) and AttributedSplit (axis from the 'axis' attribute).

    The input is split into ``num_splits`` equal parts along ``axis``. Every output port present in
    ``node.out_nodes()`` gets the part shape, and also the part value when the data input is constant.
    Registers the axis layout permutation (input rule for Split, attribute rule for AttributedSplit).
    """
    name = node.soft_get('name', node.id)

    op = node.soft_get('op', None)
    assert op is not None and op in ['Split', 'AttributedSplit'], \
        'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

    num_in_ports = 1 if op == 'AttributedSplit' else 2 if op == 'Split' else None
    assert num_in_ports in [1, 2], \
        'SplitBase supports AttributedSplit with 1 input and Split with 2 inputs, but it is {} for {} node {}' \
        ''.format(num_in_ports, op, name)

    connected_inputs = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_inputs) == num_in_ports and all([i in connected_inputs for i in range(num_in_ports)]), \
        "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
        "".format(op, num_in_ports, name, connected_inputs)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None, 'Input shape is unknown for node {}'.format(name)
    assert node.has_valid('num_splits'), 'Parameter `num_splits` is unknown for node {}'.format(name)
    num_splits = node.num_splits

    axis = node.in_port(1).data.get_value() if op == 'Split' else node.soft_get('axis', None)
    assert axis is not None, '{} `axis` is unknown for node {}'.format(op, name)
    # np.ndim also works for a plain Python int (e.g. an 'axis' attribute), which has no `.ndim`
    assert np.ndim(axis) == 0, '{} `axis` should be scalar, but it`s not for node {}'.format(op, name)

    assert input_shape[axis] % num_splits == 0, \
        'Input shape is not evenly divided by `num_splits` of {} node {}. `input_shape`={}, `axis`={}, ' \
        '`num_splits`={}'.format(op, name, input_shape, axis, num_splits)

    out_shape = input_shape.copy()
    # exact integer division; the assert above guarantees there is no remainder
    out_shape[axis] = np.int64(input_shape[axis] // num_splits)

    input_value = node.in_port(0).data.get_value()
    output_value = np.split(input_value.copy(), axis=axis, indices_or_sections=num_splits) \
        if input_value is not None else None

    for idx, port in node.out_ports().items():
        if idx in node.out_nodes():
            port.data.set_shape(out_shape)
            if output_value is not None:
                port.data.set_value(output_value[idx])

    if op == 'Split':
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
    elif op == 'AttributedSplit':
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """
    Shape and value inference for Gather with batch_dims support (inputs: data, indices, axis).

    The output shape is ``data_shape[:axis] + indices_shape[batch_dims:] + data_shape[axis + 1:]``. When data and
    indices are constant (and indices fully defined), the value is gathered per batch index and keeps the
    data dtype. Registers the 'axis' layout permutation rule for the axis input (port 2).
    """
    name = node.soft_get('name', node.id)

    connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_in_ports) == 3 and 0 in connected_in_ports and 1 in connected_in_ports and \
           2 in connected_in_ports, "Gather should have 3 connected input port, but it doesn't for " \
                                    "node: `{}`. Ports: {}".format(name, connected_in_ports)

    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None
    indices_shape = node.in_port(1).data.get_shape()
    assert indices_shape is not None
    axis = node.in_port(2).data.get_value()
    assert axis is not None, 'axis input is undefined'

    assert -len(data_shape) <= axis < len(data_shape), \
        'axis must be within interval [-data_rank, data_rank). Instead got axis = {}, data_rank = {} '.\
        format(axis, len(data_shape))

    batch_dims = node.batch_dims
    assert -len(indices_shape) <= batch_dims <= len(indices_shape), \
        'batch_dims must be within interval [-indices_rank, indices_rank]. Instead got batch_dims = {}, ' \
        'indices_rank = {} '.format(batch_dims, len(indices_shape))

    # normalize to positive values
    axis = axis + len(data_shape) if axis < 0 else axis
    batch_dims = batch_dims + len(indices_shape) if batch_dims < 0 else batch_dims

    assert np.ma.allequal(data_shape[:batch_dims], indices_shape[:batch_dims]), \
        'data and indices inputs must have equal first dimensions until batch_dims'

    # format arguments follow the order of the placeholders (batch_dims first, then axis)
    assert batch_dims <= axis, \
        'normalized batch_dims must be <= axis. Instead got batch_dims = {}, axis = {}'.format(batch_dims, axis)

    # we import PermuteInputs locally because it uses Gather inside and we have recursive imports
    from mo.graph.perm_inputs import PermuteInputs
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'axis')

    batch_dims_range = indices_shape[:batch_dims]
    out_shape = np.concatenate((data_shape[:axis], indices_shape[batch_dims:], data_shape[axis + 1:]))

    data_value = node.in_port(0).data.get_value()
    indices_value = node.in_port(1).data.get_value()
    if data_value is not None and indices_value is not None and is_fully_defined(indices_value):
        if batch_dims == 0:
            node.out_port(0).data.set_value(np.ma.take(data_value, indices_value, axis))
        else:
            # keep the element type of the data; np.empty would otherwise default to float64
            out_value = np.empty(out_shape, dtype=data_value.dtype)
            for batch_idx in np.ndindex(tuple(batch_dims_range)):
                out_value[batch_idx] = np.ma.take(data_value[batch_idx], indices_value[batch_idx],
                                                  axis - batch_dims)
            node.out_port(0).data.set_value(out_value)
    else:
        node.out_port(0).data.set_shape(out_shape)
def tf_reshape_shape_infer(node):
    """
    Compute the output shape of a TF-style Reshape node.

    The target shape is the value of input 1. A ``-1`` entry is replaced by the dimension that keeps the
    total number of elements unchanged. A ``0`` entry is copied from the input shape when 'special_zero' is set.

    :return: int64 output shape, or None if the input shape is unknown
    :raises Error: if the numbers of elements of the input and output shapes differ
    """
    # TODO Make sure that all -1 are handled correctly
    # We cannot simply copy shape argument to the output,
    # because if -1 appears, it should be substituted by a real
    # value from input shape if input shape is completely defined.
    if node.in_node(0).shape is None:
        return None

    assert len(node.in_nodes()) == 2, 'The Reshape operation {} must have 2 inputs'.format(node.name)

    input_shape = node.in_port(0).data.get_shape()
    reshape_output = node.in_port(1).data.get_value()
    # iterating a missing target shape would otherwise fail with an opaque TypeError
    assert reshape_output is not None, \
        'The target shape (input 1) of the Reshape operation {} is not defined'.format(node.name)

    total = 1
    for dim in input_shape:
        total *= dim

    # NOTE(review): a 0 entry is always counted with the input dimension here, while the output below keeps
    # a literal 0 unless 'special_zero' is set — confirm which semantics the TF models rely on.
    res = 1
    for index, x in enumerate(reshape_output):
        if x == 0:
            res *= input_shape[index]
        elif x != -1:
            res *= x

    new_dim = total // res
    output_shape = []
    for index, x in enumerate(reshape_output):
        if x == 0 and node.has_and_set('special_zero'):
            output_shape.append(input_shape[index])
        elif x == -1:
            output_shape.append(new_dim)
        else:
            output_shape.append(x)

    out_shape_total = 1
    for dim in output_shape:
        assert dim != -1
        out_shape_total *= dim

    if total != out_shape_total:
        raise Error(
            "Number of elements in input {} and output {} of reshape node {} mismatch"
            "".format(input_shape, output_shape, node.name))

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    output_shape = int64_array(output_shape)

    return output_shape
def infer(node: Node):
    """
    Shape and value inference for Reshape.

    Both inputs must be connected, and the target shape (input 1) must be constant with at most one ``-1``.
    A ``0`` entry copies the input dimension when 'special_zero' is set; ``-1`` receives the dimension that
    keeps the element count. The reshaped constant value is propagated when the data input is constant.
    """
    name = node.soft_get('name', node.id)

    connected_inputs = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_inputs) == 2 and all([i in connected_inputs for i in range(2)]), \
        "Reshape should have 2 connected input ports, but it doesn't for node: `{}`. Ports: {}" \
        "".format(name, connected_inputs)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None

    new_shape = node.in_port(1).data.get_value()
    assert new_shape is not None, 'Dynamic Reshape second input is not supported. Node {}'.format(name)

    assert np.argwhere(new_shape == -1).size <= 1, \
        'Reshape second input should not have several `-1` values set. ' \
        'Node: {}, reshape second input value {}'.format(name, new_shape)

    special_zero = node.has_and_set('special_zero')

    def _known_factor(index, value):
        # contribution of one target entry to the known element count (-1 contributes nothing)
        if value == 0 and special_zero:
            return input_shape[index]
        if value != -1:
            return value
        return 1

    num_of_input_elements = np.prod(input_shape)
    num_of_output_elements = 1
    for index, value in enumerate(new_shape):
        num_of_output_elements *= _known_factor(index, value)

    undefined_dim = num_of_input_elements // num_of_output_elements

    def _resolved_dim(index, value):
        # final output dimension for one target entry
        if value == 0 and special_zero:
            return input_shape[index]
        if value == -1:
            return undefined_dim
        return value

    output_shape = [_resolved_dim(index, value) for index, value in enumerate(new_shape)]

    assert np.prod(input_shape) == np.prod(output_shape), \
        "Number of elements in input {} and output {} of reshape node {} mismatch" \
        "".format(input_shape, output_shape, name)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    input_value = node.in_port(0).data.get_value()
    if input_value is None:
        node.out_port(0).data.set_shape(output_shape)
    else:
        node.out_port(0).data.set_value(input_value.reshape(output_shape))
def reduce_infer(node: Node):
    """
    Shape and value inference for Reduce* operations (data on input 0, reduction axes on input 1).

    An axes value holding a single None means "reduce over all dimensions"; it is replaced by the full axis
    list and written back to input 1. With a constant data input the reduced value is computed through
    ``reduce_helper(reduce_map[node.op], ...)``; otherwise only the output shape is set. When keep_dims is
    False the reduced dimensions are removed and 'reinterp_shape' is set on the node.
    """
    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    assert len(connected_in_ports) == 2, \
        "{} node `{}` should have 2 input ports, where 0-input is data input and 1-input represent " \
        "`reduction_indices`".format(node.op, node.id)

    in_data = node.in_port(0).data
    in_shape = in_data.get_shape()
    # checked before the axes are expanded below, which needs the input rank
    assert in_shape is not None, "Can not infer {} node `{}`: shape of 0-input unknown".format(node.op, node.id)

    axis = node.in_port(1).data.get_value()
    assert axis is not None, \
        "Can not infer {} node `{}`: value of 1-input (reduction axes) unknown".format(node.op, node.id)

    # If the axis is None then reduce over all the dimensions of the input tensor
    if axis.size == 1 and axis.item() is None:
        axis = int64_array(list(range(len(in_shape))))
        node.in_port(1).data.set_value(axis)

    axis = axis.copy()
    if axis.size == 1:
        axis = int64_array([axis.item()])

    in_value = in_data.get_value()

    if in_value is not None:
        value = reduce_helper(reduce_map[node.op], in_value.copy(), axis=tuple(axis), keepdims=node.keep_dims)
        node.out_port(0).data.set_value(value)
    else:
        # np.bool was removed in NumPy 1.24; the builtin bool is the equivalent dtype
        used_dims = np.zeros(len(in_shape), dtype=bool)
        output_shape = in_shape.copy()

        for dim in axis:
            used_dims[dim] = True
            output_shape[dim] = 1

        # In case if keep dims == False, we should remove all 1 dims that was used in reduction
        if not node.keep_dims:
            output_shape = output_shape[np.invert(used_dims)]

        node.out_port(0).data.set_shape(output_shape)

    # if the operation changes the rank of the output tensor then it is necessary to insert Permute if the input
    # is 4D or 5D
    if not node.keep_dims:
        node['reinterp_shape'] = True

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def infer(node: Node):
    """
    Shape and value inference for broadcasting input 0 to the shape held by the value of input 1.

    The output shape is the value of input 1 (which must be constant). When input 0 is constant and
    'stop_value_propagation' is not set, the value is expanded with ``np.broadcast_to``.
    """
    # TODO Add necessary checks and asserts
    source_value = node.in_port(0).data.get_value()
    target_shape = node.in_port(1).data.get_value()
    assert target_shape is not None

    node.out_port(0).data.set_shape(target_shape)
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    if source_value is None or node.has_and_set('stop_value_propagation'):
        return

    node.out_port(0).data.set_value(np.broadcast_to(source_value, target_shape))
def infer(node: Node):
    """
    Shape and value inference for Squeeze.

    Removes the dimensions listed in input 1; each must be 1 or dynamic. When the list is empty, every
    dimension equal to 1 is removed (TF semantics). The canonical (non-negative) squeeze dims are written back
    to input 1 if it is a Const, and the value is propagated when the data input is constant.

    :raises Error: if the input shape or the squeeze dims are undefined, or a listed dimension is not 1
    """
    real_squeeze_dims = int64_array([])
    input_shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)
    if input_shape is None:
        raise Error('Input shape is not defined for node {}'.format(node_name))

    output_shape = input_shape.copy()
    assert len(node.in_nodes()) == 2, 'The Squeeze node {} must have 2 inputs'.format(node_name)

    squeeze_dims = node.in_port(1).data.get_value()
    # a non-constant squeeze dims input would otherwise fail with an AttributeError on `.ndim`
    if squeeze_dims is None:
        raise Error('The dimensions to squeeze are not defined for node "{}"'.format(node_name))

    # TODO remove the following 'if' statement when IE start support 0D tensors
    if squeeze_dims.ndim == 0:
        squeeze_dims = squeeze_dims.reshape([1])

    for dim in squeeze_dims:
        if output_shape[dim] == 1 or output_shape[dim] is dynamic_dimension:
            real_squeeze_dims = np.ma.append(real_squeeze_dims, get_canonical_axis_index(output_shape, dim))
        else:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node_name))

    # if squeeze_dims empty then all 1s should be removed (tf specification of Squeeze op)
    if squeeze_dims.size == 0:
        for i in range(output_shape.size):
            if output_shape[i] == 1:
                real_squeeze_dims = np.ma.append(real_squeeze_dims, get_canonical_axis_index(output_shape, i))

    assert is_fully_defined(real_squeeze_dims), \
        'Squeeze dimension(s) is not defined for op "{}"'.format(node_name)
    output_shape = shape_delete(output_shape, real_squeeze_dims)
    node.out_port(0).data.set_shape(output_shape)

    # make dimensions positive to correctly translate from NHWC to NCHW layout
    if node.in_port(1).get_source().node.op == 'Const':
        node.in_port(1).data.set_value(real_squeeze_dims)

    if node.in_port(0).data.get_value() is not None:
        node.out_port(0).data.set_value(node.in_port(0).data.get_value().reshape(output_shape))

    # the squeeze_dim attribute will be converted to the second input in the end of the Middle phase
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def infer(node: 'Node'):
    """
    Shape inference for MVN: the output shape equals the input shape.

    Requires the 'eps', 'eps_mode' and 'normalize_variance' attributes to be set, and registers the 'axis'
    layout permutation rule for input 1.
    """
    # `'Node'` is a string annotation (the original `None` annotation was wrong); it is not evaluated at
    # definition time, so it does not require Node to be imported here
    name = node.soft_get('name', node.id)
    # has_valid checks presence and non-None, so a missing attribute yields the assertion message below
    # instead of an attribute-lookup error
    assert node.has_valid('eps'), 'MVN required attribute `eps` unspecified for node {}'.format(name)
    assert node.has_valid('eps_mode'), 'MVN required attribute `eps_mode` unspecified for node {}'.format(name)
    assert node.has_valid('normalize_variance'), \
        'MVN required attribute `normalize_variance` unspecified for node {}'.format(name)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
    copy_shape_infer(node)
def infer(node: Node):
    """
    Shape inference for RandomUniform: the output shape is the value of input 0.

    Requires 'output_type' to be set. The min (input 1) and max (input 2) data nodes are marked with
    'correct_data_type' so their type is kept, and input 0 gets the output 'shape' permutation rule.
    """
    assert node.has_valid('output_type')

    target_shape = node.in_port(0).data.get_value()
    node.out_port(0).data.set_shape(target_shape)

    # The min and max values must keep the type given by the output_type attribute; 'correct_data_type'
    # prevents the data node type from being changed when the IR data type differs from it.
    for bound_port in (1, 2):
        node.in_node(bound_port)['correct_data_type'] = True

    PermuteInputs().set_input_permutation(node.in_node(0), node, 'output:0', 'shape')
def infer_for_opset4(node: Node):
    """
    Shape inference for Interpolate-4.

    The output shape starts as ``src_shape + pads_begin + pads_end``, with the pads produced by ``correct_pad`` and
    stored back on the node. For each axis listed in input 3 (all axes when input 3 is not connected):
    - in 'sizes' mode the dimension is taken from input 1;
    - otherwise it is ``floor(scale * dim + 1e-5)`` with the scale from input 2.
    A dynamic dimension or scale gives a dynamic output dimension.
    """
    node_name = node.soft_get('name', node.id)
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) in [3, 4], \
        "Interpolate-4 node {} must have 3 or 4 inputs".format(node_name)
    assert node.has_valid('mode')
    assert node.has_valid('shape_calculation_mode')
    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None

    input_rank = len(src_shape)

    pads_begin = correct_pad(node.soft_get('pads_begin', [0]), input_rank)
    pads_end = correct_pad(node.soft_get('pads_end', [0]), input_rank)
    node['pads_begin'] = pads_begin
    node['pads_end'] = pads_end

    # the axes input is optional; decide by connectivity (consistent with the permutation check below)
    # rather than by the number of ports, which may include a disconnected port 3
    if not node.is_in_port_connected(3):
        axes = list(range(0, input_rank))
    else:
        axes = node.in_port(3).get_source().data.get_value()
        assert axes is not None, \
            "Interpolate-4 node with name {} has None as 'axes' input".format(node_name)

    axes = int64_array(axes)
    output_shape = src_shape + pads_begin + pads_end

    if node.shape_calculation_mode == 'sizes':
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None
        correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
        for i, axis in enumerate(axes):
            output_shape[axis] = dst_shape[i]
    else:
        scales = node.in_port(2).data.get_value()
        assert scales is not None
        for i, axis in enumerate(axes):
            if output_shape[axis] is not dynamic_dimension and scales[i] is not dynamic_dimension:
                output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)
            else:
                output_shape[axis] = dynamic_dimension_value

    if node.is_in_port_connected(3):
        PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'axis')

    node.out_port(0).data.set_shape(output_shape)
def infer(node: Node):
    """
    Shape and value inference for Squeeze (legacy variant).

    Does nothing when the input shape is unknown. Otherwise it removes every dimension listed in input 1,
    each of which must equal 1. The canonical squeeze dims are written back to input 1 when it is a Const,
    and the value is propagated when the data input is constant.

    :raises Error: if a listed dimension is not equal to 1
    """
    canonical_dims = int64_array([])
    data_shape = node.in_node().shape
    if data_shape is None:
        return

    squeezed_shape = data_shape.copy()
    assert len(node.in_nodes()) == 2, 'The Squeeze node {} must have 2 inputs'.format(node.soft_get('name'))

    requested_dims = node.in_port(1).data.get_value()
    # TODO remove the following 'if' statement when IE start support 0D tensors
    if requested_dims.ndim == 0:
        requested_dims = requested_dims.reshape([1])

    for dim in requested_dims:
        if squeezed_shape[dim] == 1:
            canonical_dims = np.append(canonical_dims, get_canonical_axis_index(squeezed_shape, dim))
        else:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node.soft_get('name')))

    squeezed_shape = np.delete(squeezed_shape, canonical_dims)
    node.out_node().shape = squeezed_shape

    # store non-negative dims so the NHWC -> NCHW translation works on them
    if node.in_port(1).get_source().node.op == 'Const':
        node.in_port(1).data.set_value(canonical_dims)

    if node.in_port(0).data.get_value() is not None:
        node.out_port(0).data.set_value(node.in_port(0).data.get_value().reshape(squeezed_shape))

    # the squeeze_dim attribute will be converted to the second input in the end of the Middle phase
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def infer(node):
    """
    Shape and value inference for Unsqueeze.

    Inserts a dimension of size 1 at every position listed in input 1. The positions index the OUTPUT tensor,
    so negative values are normalized with the output rank (input rank + number of inserted dims). The
    normalized dims are written back to input 1 when it is a Const. The value is propagated when the data
    input is constant and the output shape is fully defined.

    :raises Error: if input 1 is missing or its value is not constant
    """
    if len(node.in_nodes()) <= 1:
        raise Error('There is no input with unsqueeze dims for the node {}'.format(node.soft_get('name')))
    unsqueeze_dims = node.in_port(1).data.get_value()
    if unsqueeze_dims is None:
        raise Error('The dimensions to unsqueeze are not defined for the node {}'.format(node.soft_get('name')))
    unsqueeze_dims = int64_array(unsqueeze_dims)

    input_value = node.in_port(0).data.get_value()
    input_shape = node.in_port(0).data.get_shape()

    # TODO remove the following line when the Inference Engine plugins support 0D tensors
    if unsqueeze_dims.ndim == 0:
        unsqueeze_dims = int64_array([unsqueeze_dims.item()])

    # make dimensions positive to correctly translate from NHWC to NCHW layout; negative positions are
    # relative to the output rank (for a single dim this equals input rank + 1)
    output_rank = len(input_shape) + len(unsqueeze_dims)
    unsqueeze_dims = int64_array([dim + output_rank if dim < 0 else dim for dim in unsqueeze_dims])
    if node.in_port(1).get_source().node.op == 'Const':
        node.in_port(1).data.set_value(unsqueeze_dims)

    output_shape = input_shape.copy()
    # insert in ascending order so that every position refers to the final output layout
    for dim in sorted(unsqueeze_dims):
        output_shape = shape_insert(output_shape, dim, 1)

    if input_value is not None and is_fully_defined(output_shape):
        node.out_port(0).data.set_value(input_value.reshape(output_shape))
    else:
        node.out_port(0).data.set_shape(output_shape)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def infer(node: Node):
    """
    Shape and value inference for Tile.

    The data shape and the `repeats` value (input 1) are aligned to the same rank by prepending 1s to the
    shorter one. The output shape is their element-wise product, and the value is tiled with ``np.tile``
    when the data is constant and both arrays are fully defined.
    """
    name = node.soft_get('name', node.id)

    connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
        "Tile should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
        "".format(name, connected_in_ports)

    shape = node.in_port(0).data.get_shape()
    assert shape is not None, "Undefined input shape for Tile node '{}'.".format(name)
    tile_array = node.in_port(1).data.get_value()
    assert tile_array is not None, "Undefined `repeats` (1st port input value) of Tile node '{}'".format(name)

    # align ranks of the tile_array tensor and input shape node
    rank_gap = tile_array.size - shape.size
    if rank_gap > 0:
        shape = shape_insert(shape, 0, [1] * rank_gap)
    elif rank_gap < 0:
        tile_array = shape_insert(tile_array, 0, [1] * -rank_gap)

    input_value = node.in_port(0).data.get_value()
    can_fold = input_value is not None and is_fully_defined(shape) and is_fully_defined(tile_array)
    if not can_fold:
        node.out_port(0).data.set_shape(shape * tile_array)
    else:
        node.out_port(0).data.set_value(np.tile(input_value.reshape(shape), tile_array))

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'shape')
def insert_permute_inputs_before_dynamic_weights_subgraph(dynamic_subgraphs: Set[Node] = None):
    """
    The function inserts permutations on input nodes in the weights subgraph.

    For every node of `dynamic_subgraphs` except Const, Parameter and ShapeOf nodes, each
    connected input port whose producer lies outside `dynamic_subgraphs` gets a
    'transpose_nchw_to_nhwc' input permutation.

    :param dynamic_subgraphs: Set of Nodes belonging to weight path subgraphs; None is treated as empty
    :return: dict mapping each node that received at least one permutation to the list of its
             input port indices that were permuted
    """
    dynamic_in_nodes = dict()
    if dynamic_subgraphs is None:
        return dynamic_in_nodes

    for node in dynamic_subgraphs:
        if node.soft_get('type') in ['Const', 'Parameter', 'ShapeOf']:
            continue
        # input ports fed from outside the subgraph are the entry points of the weights path
        external_idxs = [idx for idx, port in node.in_ports().items()
                         if not port.disconnected() and port.get_source().node not in dynamic_subgraphs]
        for idx in external_idxs:
            PermuteInputs().set_input_permutation(node.in_node(idx), node, 'input:{}'.format(idx),
                                                  'transpose_nchw_to_nhwc')
        if external_idxs:
            dynamic_in_nodes[node] = external_idxs
    return dynamic_in_nodes
def infer(node: Node):
    """
    Deconvolution has an input argument that explicitly determines output shape, so in contrast
    to the forward Conv2d we shouldn't infer output shape. We just use this output shape as
    an input shape and pass it to our utilities that computes numeric values for padding.
    They also deliver output shape that is interpreted here as input shape for convolution.
    We need to check that the real input shape and shape inferred by those utility functions match.

    Returns early (leaving outputs untouched) when the output-shape value on port 2, the input
    shape on port 0, the kernel shape on port 1, `spatial_dims` or `stride` is unknown.
    """
    output_shape_value = node.in_node(2).value
    input_shape = node.in_node(0).shape
    kernel_shape = node.in_node(1).shape
    node['kernel_shape'] = kernel_shape
    # Check before building the output shape: np.array(None) is never None, so checking afterwards
    # could not catch an undefined port-2 value, and indexing it would crash.
    if output_shape_value is None or input_shape is None or kernel_shape is None or \
            node.spatial_dims is None or node.stride is None:
        return

    # np.array copies, so overriding the batch dimension does not modify the port-2 value itself
    output_shape = np.array(output_shape_value)
    output_shape[0] = np.array(input_shape)[0]

    if not node.has_valid('kernel_spatial_idx'):
        node['kernel_spatial_idx'] = np.delete([x for x in range(len(kernel_shape))],
                                               (node.input_feature_channel, node.output_feature_channel))

    if not node.has_valid('dilation'):
        node['dilation'] = np.full([len(output_shape)], 1, dtype=np.int64)

    spatial_dims = node.spatial_dims
    output_spatial = np.array(output_shape[spatial_dims])
    stride_spatial = np.array(node.stride[spatial_dims])
    node['kernel_spatial'] = np.array(kernel_shape[node.kernel_spatial_idx])
    node.pad_spatial_shape, input_spatial_for_check = tf_window_op_pad_infer(
        output_spatial, node.kernel_spatial, stride_spatial, node.auto_pad)

    assert all(input_spatial_for_check == input_shape[spatial_dims]), \
        'Input spatial shape {} of Deconv2D node {} does not match the one {} deduced from the output shape' \
        ''.format(input_shape[spatial_dims], node.soft_get('name', node.id), input_spatial_for_check)

    pad = np.zeros((len(output_shape), 2), dtype=np.int64)
    pad[spatial_dims] = node.pad_spatial_shape
    node.pad = pad

    node.output = output_shape[node.channel_dims][0]
    node.output_shape = output_shape
    node.out_node().shape = output_shape

    mark_input_bins(node, ['weights'], 1)
    assign_dims_to_weights(node.in_node(1), node.kernel_spatial_idx, node.input_feature_channel,
                           node.output_feature_channel, len(kernel_shape))

    # OK, now we are sure this is a supported Deconvolution layer
    node.type = 'Deconvolution'
    node.op = 'Deconv2D'

    # Add permute_attrs
    PermuteAttrs.create_permute_attrs(node, attrs=[('pad', 'input:0'),
                                                   ('stride', 'input:0'),
                                                   ('output_shape', 'input:0'),
                                                   ('batch_dims', 'input:0'),
                                                   ('channel_dims', 'input:0'),
                                                   ('spatial_dims', 'input:0'),

                                                   ('kernel_shape', 'input:1'),
                                                   ('kernel_spatial_idx', 'input:1'),
                                                   ('input_feature_channel', 'input:1'),
                                                   ('output_feature_channel', 'input:1'),
                                                   ])

    PermuteAttrs.set_permutation(node.in_node(1), node,
                                 node.get_weights_permute if node.has_valid('get_weights_permute') else None)
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'shape')

    # the output-shape input (port 2) must be kept as int64
    node['force_precision_in_ports'] = {2: 'int64'}
def infer(node: Node):
    """
    Infer the output of a Reshape node.

    Port 0 is the data, port 1 the target shape (must be a known value). A `-1` entry (at most
    one) is deduced from the element count; with `special_zero` set, a `0` entry copies the
    corresponding input dimension. Dynamic dimensions are propagated where they cannot be
    deduced. The value is constant-folded when the input value is known and the output shape is
    fully defined; otherwise only the output shape is set.

    :param node: the Reshape node to infer
    """
    name = node.soft_get('name', node.id)

    connected_inputs = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_inputs) == 2 and all(i in connected_inputs for i in range(2)), \
        "Reshape should have 2 connected input ports, but it doesn't for node: `{}`. Ports: {}" \
        "".format(name, connected_inputs)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None

    new_shape = node.in_port(1).data.get_value()
    assert new_shape is not None, 'Dynamic Reshape second input is not supported. Node {}'.format(name)

    assert np.argwhere(new_shape == -1).size <= 1, \
        'Reshape second input should not have several `-1` values set. ' \
        'Node: {}, reshape second input value {}'.format(name, new_shape)

    num_of_input_elements = np.prod(input_shape)
    num_of_output_elements = 1
    for index, x in enumerate(new_shape):
        if x is dynamic_dimension:
            # NOTE(review): this stores dynamic_dimension_value, while the check below compares against
            # dynamic_dimension; it is harmless only because that check also requires a fully defined new_shape
            num_of_output_elements = dynamic_dimension_value
        elif x == 0 and node.has_and_set('special_zero'):
            if input_shape[index] is not dynamic_dimension:
                num_of_output_elements *= input_shape[index]
        elif x != -1:
            num_of_output_elements *= x

    # input_shape = [dynamic, 5, 6], new_shape = [0, -1] => output_shape [dynamic, 30]
    # marker that no dynamic input dimensions or all of them are copied with "0" magic value
    all_dynamic_dimension_are_copied = True
    if not is_fully_defined(input_shape):
        for index, x in enumerate(input_shape):
            if x is dynamic_dimension:
                if index >= len(new_shape) or new_shape[index] != 0:
                    all_dynamic_dimension_are_copied = False

    # value substituted for the `-1` entry of new_shape
    undefined_dim = dynamic_dimension
    if num_of_output_elements is not dynamic_dimension and all_dynamic_dimension_are_copied and \
            is_fully_defined(new_shape):
        undefined_dim = num_of_input_elements // num_of_output_elements

    output_shape = []
    for index, x in enumerate(new_shape):
        if x == 0 and node.has_and_set('special_zero'):
            output_shape.append(input_shape[index])
        elif x == -1:
            output_shape.append(undefined_dim)
        else:
            output_shape.append(x)

    # even if the new_shape contains some dynamic values we can calculate the actual value by deducing it from the
    # input shape if it is static: input_shape = [5, 3, 8], new_shape = [4, d] => output_shape = [4, 30]
    if is_fully_defined(input_shape) and not is_fully_defined(new_shape):
        dynamic_indices = np.argwhere([item is dynamic_dimension for item in new_shape])
        num_of_output_elements = 1
        if dynamic_indices.size == 1:
            for index, x in enumerate(new_shape):
                if x == 0 and node.has_and_set('special_zero'):
                    num_of_output_elements *= input_shape[index]
                elif x is not dynamic_dimension and x != -1:
                    num_of_output_elements *= x
            assert num_of_input_elements % num_of_output_elements == 0, \
                'Incorrect number of output elements deduced for node {}: {} input elements are not divisible ' \
                'by the {} elements of the known output dimensions'.format(name, num_of_input_elements,
                                                                           num_of_output_elements)
            output_shape[dynamic_indices[0][0]] = num_of_input_elements // num_of_output_elements

    assert not is_fully_defined(input_shape) or not is_fully_defined(output_shape) or \
        np.prod(input_shape) == np.prod(output_shape), \
        "Number of elements in input {} and output {} of reshape node {} mismatch" \
        "".format(input_shape, output_shape, name)

    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    if node.in_port(0).data.get_value() is not None and is_fully_defined(output_shape):
        node.out_port(0).data.set_value(node.in_port(0).data.get_value().reshape(output_shape))
    else:
        node.out_port(0).data.set_shape(output_shape)
def infer(node):
    """
    Infer the outputs of a VariadicSplit (3 inputs: data, axis, split_lengths) or an
    AttributedVariadicSplit (1 input; `axis` and `split_lengths` are node attributes) node.

    At most one split length may be -1; it absorbs the remaining elements along `axis`.
    Zero-length pieces whose output port is disconnected are dropped, together with that port.
    Output shapes are set for connected ports, and values are propagated when the input value
    is known.

    :param node: the split node to infer
    """
    name = node.soft_get('name', node.id)

    op = node.soft_get('op', None)

    assert op is not None and op in ['VariadicSplit', 'AttributedVariadicSplit'], \
        'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

    num_in_ports = 1 if op == 'AttributedVariadicSplit' else 3 if op == 'VariadicSplit' else None

    assert num_in_ports in [1, 3], \
        'VariadicSplitBase supports AttributedVariadicSplit with 1 input and VariadicSplit with 3 inputs, ' \
        'but it is {} for {} node {}'.format(num_in_ports, op, name)

    connected_inputs = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_inputs) == num_in_ports and all(i in connected_inputs for i in range(num_in_ports)), \
        "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
        "".format(op, num_in_ports, name, connected_inputs)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None

    axis = node.in_port(1).data.get_value() if op == 'VariadicSplit' else node.soft_get('axis', None)
    assert axis is not None, '{} `axis` is unknown for node {}'.format(op, name)
    assert axis.ndim == 0, '{} `axis` should be scalar, but it`s not for node {}'.format(op, name)

    split_lengths = node.in_port(2).data.get_value() if op == 'VariadicSplit' \
        else node.soft_get('split_lengths', None)
    assert split_lengths is not None, '{} `split_lengths` is unknown for node {}'.format(op, name)

    undefined_elements = np.argwhere(split_lengths == -1).flatten()
    assert undefined_elements.size <= 1, \
        '{} split_lengths=`{}` is a list with output sizes, only one of which could be -1. Node: {}' \
        ''.format(op, split_lengths, name)

    input_elements = input_shape[axis]
    assert undefined_elements.size != 0 or input_elements == np.sum(split_lengths), \
        'The sum of split_lengths=`{}` must match data.shape[axis]=`{}`. Node: {}' \
        ''.format(split_lengths, input_elements, name)

    assert len(split_lengths) >= len([port for i, port in node.out_ports().items() if not port.disconnected()]), \
        'Number of split_lengths=`{}` is less than connected output ports. Node: {}'.format(split_lengths, name)

    # in split_lengths some value can be 0, in this case we will ignore it:
    #     * remove according branch
    #     * remove 0 from split_lengths
    # iterate in reverse so that deleting an entry/port does not shift the indices still to be visited
    for i in reversed(range(len(split_lengths))):
        if split_lengths[i] == 0:
            if node.out_port(i).disconnected():
                split_lengths = np.delete(int64_array(split_lengths), i)
                if op == 'VariadicSplit':
                    node.in_port(2).data.set_value(split_lengths)
                else:
                    node['split_lengths'] = split_lengths
                delete_out_port(i, node)
            else:
                log.error("Zero dimension on {} branch after Split node {}".format(i, name))
                return

    # shape propagation
    idxs, curr_pos = [], 0
    for i, piece in enumerate(split_lengths):
        assert piece >= -1, 'VariadicSplit split_lengths=`{}` should be non-negative'.format(split_lengths)
        out_shape = input_shape.copy()

        # np.sum includes the -1 entry itself, hence the "+ 1" to get the sum of the known pieces
        split_length = piece if piece > -1 else input_elements - (np.sum(split_lengths) + 1)
        out_shape[axis] = split_length
        curr_pos = curr_pos + split_length
        idxs.append(curr_pos)

        if not node.out_port(i).disconnected():
            node.out_port(i).data.set_shape(out_shape)

    # value propagation
    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        split = np.split(input_value, idxs[:-1], axis)
        for i, port in node.out_ports().items():
            if not port.disconnected():
                port.data.set_value(split[i])

    if op == 'VariadicSplit':
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
    elif op == 'AttributedVariadicSplit':
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """
    Infers shape of convolution node as it is done in ONNX.
    It is very similar to one that Caffe does, but slightly different.
    We made a complete fork of this function because they are supposed to be
    supported differently by different people.

    Handles Convolution, BinaryConvolution, Deconvolution and DeformableConvolution node types.
    Returns early without setting outputs when the input shape is unknown, when a kernel reshape
    is requested but required attributes are missing, or when a Deconvolution with a given
    output_spatial_shape has no `get_pad` callable.

    Args:
        node: graph convolution node
    Raises:
        Error: if the weights cannot be reshaped to the kernel shape, or the node type is
            unsupported and the output shape cannot be determined
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    # bias_term cannot be deduced earlier for frameworks that represent
    # convolution weights/biases as regular inputs; so the number of inputs
    # is being checked here and restore correct value for bias_term to
    # have the rest of the code unchanged. It will be used after we merge
    # several infer functions for convolution in different FWs to a single one.
    if not node.has_valid('bias_term'):
        node['bias_term'] = len(node.in_nodes()) == 3

    weights_index = node.weights_index if node.has_valid('weights_index') else 1

    # Reshape weights kernel to original shape
    # In case of caffe or MXNet framework, values for weights have no structured shape like OIHW
    # so we have to reshape weights to normal shape
    # For this case, Convolution node should have attribute reshape_kernel = True
    if node.has_valid('reshape_kernel') and node.reshape_kernel:
        if not (node.has_valid('output') and node.has_valid('channel_dims') and node.has_valid(
                'group') and node.has_valid('kernel_spatial')):
            log.error('Cannot reshape kernel due to not all required attrs was set to {} node'.format(node.id))
            return
        # layout for Convolution weights is OIHW; the input-channel count per group is an integer
        kernel_shape = int64_array([node.output, input_shape[node.channel_dims].item() // node.group,
                                    *[node.kernel_spatial[i] for i in range(len(node.kernel_spatial))]])
        if node.type == 'Deconvolution':  # layout for Deconvolution weights is IOHW
            kernel_shape[[0, 1]] = kernel_shape[[1, 0]]

        if np.prod(kernel_shape) != np.prod(node.in_node(weights_index).value.shape):
            log.error("Size of weights {} does not match kernel shape: {}\n".format(
                np.prod(node.in_node(weights_index).value.shape), kernel_shape) +
                " Possible reason is wrong channel number in input shape\n")
            raise Error("Cannot reshape weights to kernel shape")

        node.in_node(weights_index).shape = np.array(kernel_shape)
        node.in_node(weights_index).value = np.reshape(node.in_node(weights_index).value, kernel_shape)
        node.reshape_kernel = False

    # Pass weights shape to node attribute kernel_shape
    kernel_shape = node.in_node(weights_index).shape
    node['kernel_shape'] = kernel_shape
    # Calculate kernel_spatial_idx and spatial_dims if it is not specified
    # It is necessary for ONNX due to convolution can be 1D/2D/3D
    if not node.has_valid('kernel_spatial_idx'):
        node['kernel_spatial_idx'] = np.delete([x for x in range(len(kernel_shape))],
                                               (node.input_feature_channel, node.output_feature_channel))

    if not node.has_valid('spatial_dims'):
        node['spatial_dims'] = np.delete([x for x in range(len(input_shape))],
                                         (node.channel_dims[0], node.batch_dims[0]))

    node['kernel_spatial'] = kernel_shape[node.kernel_spatial_idx]

    if not node.has_valid('output'):
        # restore the number of output feature maps from the second argument that is weights
        if node.type in ['Convolution', 'Deconvolution', 'DeformableConvolution', 'BinaryConvolution']:
            node['output'] = kernel_shape[node.output_feature_channel]
        else:
            raise Error(
                'Convolution infer function was called for a node {} with unsupported type {}',
                node.soft_get('name'),
                node.type
            )

    # Set default values for dilation, strides and pads if not set
    if not node.has_valid('dilation'):
        node['dilation'] = np.full([len(input_shape)], 1, dtype=np.int64)
    if not node.has_valid('stride'):
        node['stride'] = np.full([len(input_shape)], 1, dtype=np.int64)
    if not node.has_valid('pad'):
        node['pad'] = int64_array([[0, 0]] * len(input_shape))
    node['pad_spatial_shape'] = node.pad[node.spatial_dims]

    if not node.has_valid('output_padding'):
        node['output_padding'] = np.full([len(input_shape)], 0, dtype=np.int64)

    # extend a shorter output_padding with trailing zeros up to the input rank
    if node.has_valid('output_padding') and len(input_shape) > len(node['output_padding']):
        output_padding = np.zeros(len(input_shape), dtype=np.int64)
        for i in range(len(node['output_padding'])):
            output_padding[i] = node['output_padding'][i]
        node['output_padding'] = output_padding

    input_spatial_shape = input_shape[node.spatial_dims]
    stride_spatial_shape = node.stride[node.spatial_dims]

    kernel_extent = node.dilation[node.spatial_dims] * (node.kernel_spatial - 1) + 1
    # TensorFlow always has auto_pad attribute that can be either valid or same_upper
    # In ONNX auto_pad attribute is deprecated but appears in some models (could be valid, same_upper or same_lower)
    # Caffe do not use auto_pad attribute
    if node.has_valid('auto_pad') and node.auto_pad != 'explicit' and not node.has_valid('output_spatial_shape'):
        node['pad_spatial_shape'], node['output_spatial_shape'] = tf_window_op_pad_infer(input_spatial_shape,
                                                                                        kernel_extent,
                                                                                        stride_spatial_shape,
                                                                                        node.auto_pad,
                                                                                        node.type == 'Deconvolution')

        pad = np.zeros((len(input_shape), 2), dtype=np.int64)
        pad[node.spatial_dims] = node.pad_spatial_shape
        node.pad = pad
    else:
        # total (begin + end) padding per spatial dimension
        pad_spatial_shape = np.add.reduce(node.pad_spatial_shape, axis=1)
        if node.type in ('Convolution', 'BinaryConvolution'):
            float_spatial = Convolution.calc_convolution(input_spatial_shape, stride_spatial_shape,
                                                         pad_spatial_shape,
                                                         kernel_extent)
            node['output_spatial_shape'] = int64_array(float_spatial)
        elif node.type == 'Deconvolution':
            # In case of given output_spatial_shape we calculate pads spatial
            if node.has_valid('output_spatial_shape'):
                if node.has_valid('get_pad'):
                    node['pad'] = node.get_pad(node, input_shape, kernel_shape)
                else:
                    log.debug('Can\'t calculate paddings due to missing lambda get_pad in {} node'.format(node.id))
                    return
            else:
                output_padding = node.output_padding[node.spatial_dims] if node.has_valid('output_padding') else None
                if output_padding is not None and any(output_padding):
                    pad_spatial_shape -= output_padding
                    for dim in range(len(pad_spatial_shape)):
                        node.pad_spatial_shape[dim][1] -= pad_spatial_shape[dim]

                float_spatial = Convolution.calc_deconvolution(node, input_spatial_shape, pad_spatial_shape,
                                                               kernel_extent)
                node['output_spatial_shape'] = int64_array(float_spatial)
        elif node.type == 'DeformableConvolution':
            # get the output spatial shape from the second input with offsets
            node['output_spatial_shape'] = int64_array([node.in_node(1).shape[2:4]])
        elif not node.has_valid('output_spatial_shape'):
            # Previously `assert '<non-empty string>'`, which can never fail. Without an output spatial shape the
            # output shape below cannot be built, so report the unsupported type explicitly.
            raise Error('Unsupported layer type "{}"'.format(node.type))

    # For cases when group attribute wasn't set in extractor we should specify get_group attribute
    # this attribute should store lambda node: ... (check tf convolution extractor)
    if node.has_valid('get_group'):
        node['group'] = node.get_group(node)
    output_shape = np.full_like(input_shape, -1, dtype=np.int64)
    output_shape[node.batch_dims] = input_shape[node.batch_dims]  # pylint: disable=unsupported-assignment-operation
    output_shape[node.spatial_dims] = node.output_spatial_shape  # pylint: disable=unsupported-assignment-operation

    # For cases when output attribute wasn't set in extractor we should specify get_output_feature_dim attribute
    # this attribute should store lambda node: ... (check tf convolution extractor)
    if node.has_valid('get_output_feature_dim'):
        node['output'] = node.get_output_feature_dim(node)
    output_shape[node.channel_dims] = node.output  # pylint: disable=unsupported-assignment-operation
    node['output_shape'] = output_shape

    for n in node.out_nodes():
        node.out_node(n).shape = output_shape

    mark_input_bins(node, start_port=1 if node.type != 'DeformableConvolution' else 2)
    assign_dims_to_weights(node.in_node(weights_index), node.kernel_spatial_idx, node.input_feature_channel,
                           node.output_feature_channel, len(kernel_shape))

    PermuteAttrs.create_permute_attrs(node, attrs=[('pad', 'input:0'),
                                                   ('stride', 'input:0'),
                                                   ('dilation', 'input:0'),
                                                   ('output_shape', 'input:0'),
                                                   ('batch_dims', 'input:0'),
                                                   ('channel_dims', 'input:0'),
                                                   ('spatial_dims', 'input:0'),

                                                   ('kernel_shape', 'input:{}'.format(weights_index)),
                                                   ('kernel_spatial_idx', 'input:{}'.format(weights_index)),
                                                   ('input_feature_channel', 'input:{}'.format(weights_index)),
                                                   ('output_feature_channel', 'input:{}'.format(weights_index)),
                                                   ])

    PermuteAttrs.set_permutation(node.in_node(weights_index), node, node.soft_get('get_weights_permute', None))
    PermuteInputs().set_input_permutation(node.in_node(weights_index), node, 'input:{}'.format(weights_index),
                                          'transpose')