def replace_timeheightconv(self, graph: Graph, node: Node):
    """Replace a TimeHeightConvolution node with MemoryOffset -> Concat -> Convolution.

    One MemoryOffset is created per distinct time offset. Each one reads the
    node's data input, and their outputs are concatenated and fed into a
    regular Convolution. The kernel size, padding and dilation of that
    Convolution are derived from the node's ``offsets`` attribute. The
    original node is renamed to ``<name>/to_delete``, its consumers are
    rewired to the Convolution, and it is then removed from ``graph``.

    :param graph: graph being transformed.
    :param node: node to replace. The code reads its attributes
        ``time_offsets``, ``offsets``, ``height_subsample``, ``height_in``,
        ``height_out``, ``in_channels`` and ``out_channels``. Input port 0
        feeds the MemoryOffsets, port 1 supplies the weights (re-laid-out in
        place), and port 2 is forwarded to the Convolution's port 2
        (presumably biases -- confirm).
    """
    req_time_offsets = node.soft_get('time_offsets')
    # Assumes ``offsets`` is a 2-D array of (time, height) offset pairs:
    # column 0 drives the MemoryOffsets, column 1 the height padding.
    offsets = node.soft_get("offsets", [[]])
    time_offs = offsets[:, 0]
    height_offs = offsets[:, 1]
    all_time_offsets = sorted(set(time_offs))

    in_name = node.soft_get('name', node.id)
    rename_node(node, in_name + '/to_delete')

    # Gather the temporal context: one MemoryOffset per distinct time offset,
    # all joined by a single Concat (created even for a single offset).
    concat = Concat(graph, attrs={
        'name': in_name + '/Concat',
        'in_ports_count': len(all_time_offsets)
    }).create_node()
    for i, t in enumerate(all_time_offsets):
        memoff = MemoryOffset(graph, attrs={
            'name': in_name + '/MemoryOffset_' + str(i),
            't': t,
            # An offset listed in the required time offsets needs no default value.
            'has_default': t not in req_time_offsets,
            'splitted': False,
            'pair_name': in_name + '/MemoryOffset_pair_' + str(i)
        }).create_node()
        concat.in_port(i).connect(memoff.out_port(0))
        memoff.in_port(0).connect(node.in_port(0).get_source())

    stride = node.soft_get("height_subsample", 1)

    # Kernel size = number of distinct offsets along (time, height).
    kernel = int64_array([len(all_time_offsets), len(set(height_offs))])

    min_h = min(height_offs)
    max_h = max(height_offs)
    pad_h = int64_array([0, 0])
    pad_h[0] = -min_h if min_h < 0 else 0
    pad_h[1] = stride * node.height_out - (node.height_in - max([max_h, 0]))

    # Dilation spreads the kernel taps between the min and max offset of an
    # axis (this assumes the offsets are evenly spaced). An axis with a single
    # tap uses dilation 1. Each axis must be guarded by its own kernel size;
    # guarding the height axis by kernel[0] divided by zero whenever there was
    # one height offset but several time offsets.
    dilation_t = (max(time_offs) - min(time_offs)) / (kernel[0] - 1) if kernel[0] > 1 else 1
    dilation_h = (max_h - min_h) / (kernel[1] - 1) if kernel[1] > 1 else 1

    conv_attrs = {
        'name': in_name,
        'output': node['out_channels'],
        'height_in': node.height_in,
        'bias_term': None,
        'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
        'pad_spatial_shape': int64_array([[0, 0], pad_h]),
        'dilation': int64_array([1, 1, dilation_t, dilation_h]),
        'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
        'stride': int64_array([1, 1, 1, stride]),
        'kernel_spatial': kernel,
        'input_feature_channel': 1,
        'output_feature_channel': 0,
        'channel_dims': int64_array([1]),
        'spatial_dims': int64_array([2, 3]),
        'batch_dims': int64_array([0]),
        'kernel_spatial_idx': int64_array([2, 3]),
        'group': 1,
        'reshape_kernel': True,
        'bias_addable': True,
    }
    conv = Convolution(graph, attrs=conv_attrs).create_node()
    conv.in_port(0).connect(concat.out_port(0))
    conv.in_port(1).connect(node.in_port(1).get_source())

    # Change the weights layout from OHWI to OIHW:
    # - view the weights as (out_channels, -1, in_channels);
    # - swap the last two axes;
    # - flatten the result back.
    # In future this should be replaced by the common Permute mechanics.
    weights_node = conv.in_port(1).get_source().node
    weights = weights_node.value.reshape(int64_array([node.out_channels, -1, node.in_channels]))
    weights_node.value = weights.transpose(int64_array([0, 2, 1])).flatten()

    conv.in_port(2).connect(node.in_port(2).get_source())
    node.out_port(0).get_connection().set_source(conv.out_port(0))
    graph.remove_node(node.id)
def replace_pattern(self, graph: Graph, match: dict):
    """
    Converts specific for NasNet topology subgraph Pad->StridedSlice->AvgPool to Conv->Crop->AvgPool

    The Pad is replaced by a depthwise (group = channel count) 1x1 Convolution
    with all-ones weights. It is padded by one element at the end of axes 1
    and 2, which presumably reproduces a zero-valued Pad -- confirm. Its
    expected parameters are pads_begin [0, 0, 0, 0] and pads_end [0, 1, 1, 0].

    The StridedSlice is replaced by a Crop that drops the first element of
    axes 1 and 2. Its expected parameters are begin [0, 1, 1, 0], begin_mask
    [0, 1, 1, 0] and end_mask [0, 0, 0, 0].

    The Convolution attributes treat axis 3 as channels and axes 1-2 as
    spatial (NHWC). If any expected parameter differs, an error is logged and
    the graph is left unmodified.

    :param graph: graph being transformed.
    :param match: matched nodes under the keys 'input', 'pad_op' and 'sslice'.
    """
    input_node = match['input']

    pad_node = match['pad_op']
    pad_node_name = pad_node.soft_get('name', pad_node.id)

    sslice_node = match['sslice']
    sslice_node_name = sslice_node.soft_get('name', sslice_node.id)

    # Only the slice starts are validated; stops and steps are not inspected.
    begin = [s.start for s in sslice_node.slices]

    pads_begin = pad_node.in_port(1).data.get_value()
    pads_end = pad_node.in_port(2).data.get_value()
    if pads_begin is None or pads_end is None:
        log.error('Pad values for node "{}" are not constants'.format(pad_node_name))
        return
    if not np.array_equal(pads_begin, int64_array([0, 0, 0, 0])):
        log.error("Pad begin values doesn't match for node {}!".format(pad_node_name))
        return
    if not np.array_equal(pads_end, int64_array([0, 1, 1, 0])):
        log.error("Pad end values doesn't match for node {}!".format(pad_node_name))
        return
    if not np.array_equal(begin, int64_array([0, 1, 1, 0])):
        log.error("StridedSlice has wrong begin")
        return
    if (not np.array_equal(sslice_node.end_mask, int64_array([0, 0, 0, 0]))
            or not np.array_equal(sslice_node.begin_mask, int64_array([0, 1, 1, 0]))):
        log.error("StridedSlice has wrong masks")
        return

    channels = input_node.shape[3]

    # Pad -> Conv
    conv_name = graph.unique_id(pad_node_name + '/Conv_')
    conv_weights_name = graph.unique_id(pad_node_name + '/ConvW_')
    conv_weights = np.ones((channels, 1, 1, 1))
    # Spatial axes 1 and 2 grow by one because of the end padding.
    output_shape = int64_array([
        input_node.shape[0], input_node.shape[1] + 1, input_node.shape[2] + 1, channels
    ])

    conv_node = Convolution(graph, dict(
        name=conv_name,
        stride=int64_array([1, 1, 1, 1]),
        dilation=int64_array([1, 1, 1, 1]),
        group=channels,
        bias_addable=True,
        bias_term=False,
        spatial_dims=int64_array([1, 2]),
        kernel_spatial=int64_array([1, 1]),
        pad=int64_array([[0, 0], [0, 1], [0, 1], [0, 0]]),
        output_shape=output_shape,
        batch_dims=int64_array([0]),
        channel_dims=int64_array([3]),
        output=channels,
        input_feature_channel=1,
        output_feature_channel=0,
    )).create_node()

    weights_const_node = Const(graph, dict(
        name=conv_weights_name,
        value=conv_weights,
        shape=int64_array(conv_weights.shape),
    )).create_node()

    # StridedSlice -> Crop: drop the first row/column, restoring the input's
    # spatial size.
    crop_node = Crop(graph, dict(
        name=sslice_node_name + '/Crop_',
        axis=int64_array([1, 2]),
        dim=int64_array([output_shape[1] - 1, output_shape[2] - 1]),
        offset=int64_array([1, 1]),
    )).create_node()

    # Connect nodes
    pad_node.in_port(0).get_connection().set_destination(conv_node.in_port(0))
    weights_const_node.out_port(0).connect(conv_node.in_port(1))
    conv_node.out_port(0).connect(crop_node.in_port(0))
    sslice_node.out_port(0).get_connection().set_source(crop_node.out_port(0))

    # Tag the Convolution's port 1 as its weights input.
    conv_node.in_port(1).bin = 'weights'

    # Remove Pad and StridedSlice nodes from graph
    graph.remove_node(pad_node.id)
    graph.remove_node(sslice_node.id)