# Imports used across the snippets below; paths follow the OpenVINO Model
# Optimizer (openvino.tools.mo) layout and may differ between MO versions.
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, Node, rename_node, rename_nodes
from openvino.tools.mo.ops.MatMul import FullyConnected
from openvino.tools.mo.ops.concat import Concat
from openvino.tools.mo.ops.convolution import Convolution
from openvino.tools.mo.ops.memoryoffset import MemoryOffset
from openvino.tools.mo.ops.result import Result


def align_frame_time(graph: Graph, node: Node, frame_time_max):
    for inp in node.in_ports():
        if node.in_port(inp).disconnected():
            continue
        in_node = node.in_port(inp).get_source().node
        in_node_out_port = node.in_port(inp).get_source()
        in_port = node.in_port(inp)
        # Adding a MemoryOffset for a Const does not make sense
        if in_node.frame_time < frame_time_max and in_node.op != 'Const':
            # Change the existing MemoryOffset instead of adding a new one
            if in_node.op == 'MemoryOffset':
                in_node.t = in_node.frame_time - frame_time_max
                in_node.frame_time = in_node.t
            else:
                mem_name = graph.unique_id("align_" + node.id)
                memory_align = MemoryOffset(graph, attrs={'id': mem_name,
                                                          'name': mem_name,
                                                          'pair_name': mem_name + "_pair",
                                                          't': in_node.frame_time - frame_time_max,
                                                          'splitted': False}).create_node()
                # add element_size for a MemoryOffset placed after Parameter so that shape inference works
                if in_node.op == 'Parameter':
                    memory_align['element_size'] = in_node.shape
                in_port.get_connection().set_source(memory_align.out_port(0))
                memory_align.in_port(0).connect(in_node_out_port)
                memory_align['frame_time'] = memory_align.t
        # remove a MemoryOffset that already has the maximum delay
        elif in_node.frame_time == frame_time_max and in_node.op == 'MemoryOffset':
            in_node_out_port.get_connection().set_source(in_node.in_port(0).get_source())
            graph.remove_node(in_node.id)
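# The arithmetic above in isolation: a minimal sketch in plain Python (an
# illustration only, not using the MO graph API). Each input branch that lags
# behind the most-delayed branch gets an extra offset
# t = frame_time - frame_time_max (t < 0), compensating the difference in
# delay so that all inputs of the node line up on the same frame.
def toy_align_deltas(frame_times):
    frame_time_max = max(frame_times)
    # a delta of 0 means no MemoryOffset is needed for that branch
    return [ft - frame_time_max for ft in frame_times]

assert toy_align_deltas([0, 2, 5]) == [-5, -3, 0]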
def find_and_replace_pattern(self, graph: Graph):
    for offset_node in graph.get_op_nodes(op='MemoryOffset', splitted=False):
        paired_node = MemoryOffset(graph, {'name': offset_node.pair_name,
                                           'splitted': True,
                                           'pair_name': offset_node.id,
                                           't': offset_node.t,
                                           'has_default': offset_node.has_default}).create_node()
        offset_node['splitted'] = True
        offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0))
        res_node = Result(graph, {'name': offset_node.id + "_output"}).create_node()
        offset_node.out_port(0).connect(res_node.in_port(0))
        # 'element_size' was previously copied from a Parameter or from a node with a defined dim
        if offset_node.has_valid('element_size'):
            paired_node['element_size'] = offset_node['element_size']
        # otherwise copy the shape from the previous node; typically (but not always) the case for TDNN blocks
        else:
            paired_node['element_size'] = offset_node.in_port(0).data.get_shape()[1]
def replace_pattern(graph: Graph, match: dict):
    offset_node = match['mem_offset']
    paired_node = MemoryOffset(graph, {'name': offset_node.pair_name,
                                       'splitted': True,
                                       'pair_name': offset_node.id,
                                       't': offset_node.t,
                                       'has_default': offset_node.has_default}).create_node()
    offset_node['splitted'] = True
    offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0))
    res_node = Result(graph, {'name': offset_node.id + "_output"}).create_node()
    offset_node.out_port(0).connect(res_node.in_port(0))
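# A toy model of what the split in the two functions above produces (plain
# Python, an illustration only; the real state handling lives in the runtime
# plugins): the original MemoryOffset becomes a state *write* terminated by a
# Result, and its newly created pair becomes a state *read* that returns the
# value written abs(t) invocations earlier.
from collections import deque

class ToyMemoryOffsetPair:
    def __init__(self, t: int, default=0.0):
        self.buf = deque([default] * abs(t), maxlen=abs(t))

    def read(self):      # role of the paired node feeding the consumers
        return self.buf[0]

    def write(self, x):  # role of the original node, now ending in Result
        self.buf.append(x)

pair = ToyMemoryOffsetPair(t=-2)
outs = []
for frame in [1.0, 2.0, 3.0, 4.0]:
    outs.append(pair.read())
    pair.write(frame)
assert outs == [0.0, 0.0, 1.0, 2.0]  # each frame delayed by 2 steps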
def extract(cls, node):
    pb = node.parameters
    mapping_rule = {
        'pair_name': pb['pair_name'],
        't': pb['t'],
        'has_default': pb['has_default'],
        'splitted': False,
    }
    if 'element_size' in pb:
        mapping_rule['element_size'] = pb['element_size']
    MemoryOffset.update_node_stat(node, mapping_rule)
    return cls.enabled
def replace_tdnn(self, graph: Graph, tdnn_node: Node):
    tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)

    concat_node = Concat(graph, {'axis': 1}).create_node()
    rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'), (concat_node, tdnn_name)])

    for offset_ind, t in enumerate(tdnn_node['time_offsets']):
        concat_node.add_input_port(offset_ind)
        if t != 0:
            memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t))
            memoryoffset_node = MemoryOffset(graph, {'name': memory_name,
                                                     't': t,
                                                     'pair_name': memory_name + '_out',
                                                     'has_default': False,
                                                     'splitted': False}).create_node()
            tdnn_node.in_port(0).get_source().connect(memoryoffset_node.in_port(0))
            memoryoffset_node.out_port(0).connect(concat_node.in_port(offset_ind))
        else:
            # a zero time delay is meaningless and not allowed in IE, so connect
            # the TdnnComponent input directly to the Concat, without a MemoryOffset
            tdnn_node.in_port(0).get_source().connect(concat_node.in_port(offset_ind))

    weights = tdnn_node['weights']
    fc_inputs = {1: weights}

    bias_term = False
    if tdnn_node.has_valid('biases'):
        assert len(tdnn_node['biases']) == weights.shape[0]
        fc_inputs.update({2: tdnn_node['biases']})
        bias_term = True

    fc_node = create_op_with_const_inputs(graph, FullyConnected, fc_inputs,
                                          {'name': tdnn_name + '/FC',
                                           'out-size': weights.shape[0],
                                           'transpose_weights': True,
                                           'bias_term': bias_term})
    concat_node.out_port(0).connect(fc_node.in_port(0))
    tdnn_node.in_port(0).disconnect()
    tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
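# What the replacement computes, as a minimal NumPy sketch (toy shapes and
# names are assumptions for illustration; edge frames are clamped here,
# whereas the real graph relies on MemoryOffset state): per frame, the inputs
# at t + offset are concatenated along the feature axis (the Concat) and
# passed through a single affine layer (the FullyConnected).
import numpy as np

def toy_tdnn(x, weights, biases, time_offsets):
    # x: [frames, in_dim]; weights: [out_dim, in_dim * len(time_offsets)]
    frames = x.shape[0]
    context = [np.concatenate([x[np.clip(t + off, 0, frames - 1)]
                               for off in time_offsets]) for t in range(frames)]
    return np.stack(context) @ weights.T + biases  # transpose_weights=True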
def split_offset(offset_node: Node):
    paired_node = MemoryOffset(offset_node.graph, {'name': offset_node.pair_name,
                                                   'splitted': True,
                                                   'pair_name': offset_node.id,
                                                   'element_size': offset_node['element_size'],
                                                   't': offset_node.t,
                                                   'has_default': offset_node.has_default}).create_node()
    offset_node['splitted'] = True
    offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0))
    res_node = Result(offset_node.graph, {'name': offset_node.id + '_output'}).create_node()
    offset_node.out_port(0).connect(res_node.in_port(0))
def replace_timeheightconv(self, graph: Graph, node: Node):
    req_time_offsets = node.soft_get('time_offsets')
    offsets = node.soft_get("offsets", [[]])
    all_time_offsets = list(set(offsets[:, 0]))
    all_time_offsets.sort()
    in_name = node.soft_get('name', node.id)
    rename_node(node, in_name + '/to_delete')

    # create MemoryOffsets for context gathering;
    # a Concat is needed when there is more than one time offset
    concat = Concat(graph, attrs={'name': in_name + '/Concat',
                                  'in_ports_count': len(all_time_offsets)}).create_node()
    for i, t in enumerate(all_time_offsets):
        # if the time offset is in required_time_offsets, no default value is needed
        has_default = t not in req_time_offsets
        memoff = MemoryOffset(graph, attrs={'name': in_name + '/MemoryOffset_' + str(i),
                                            't': t,
                                            'has_default': has_default,
                                            'splitted': False,
                                            'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
        concat.in_port(i).connect(memoff.out_port(0))
        memoff.in_port(0).connect(node.in_port(0).get_source())

    stride = node.soft_get("height_subsample", 1)

    kernel = int64_array([0, 0])
    kernel[0] = len(set(offsets[:, 0]))
    kernel[1] = len(set(offsets[:, 1]))

    pad_h = int64_array([0, 0])
    pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
    pad_h[1] = stride * node.height_out - (node.height_in - max([max(offsets[:, 1]), 0]))

    dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
    # the height dilation must be guarded by kernel[1], not kernel[0],
    # to avoid division by zero when there is a single height offset
    dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1

    conv_attrs = {
        'name': in_name,
        'output': node['out_channels'],
        'height_in': node.height_in,
        'bias_term': None,
        'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
        'pad_spatial_shape': int64_array([[0, 0], pad_h]),
        'dilation': int64_array([1, 1, dilation_t, dilation_h]),
        'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
        'stride': int64_array([1, 1, 1, stride]),
        'kernel_spatial': kernel,
        'input_feature_channel': 1,
        'output_feature_channel': 0,
        'channel_dims': int64_array([1]),
        'spatial_dims': int64_array([2, 3]),
        'batch_dims': int64_array([0]),
        'kernel_spatial_idx': int64_array([2, 3]),
        'group': 1,
        'reshape_kernel': True,
        'bias_addable': True,
    }
    conv = Convolution(graph, attrs=conv_attrs).create_node()
    conv.in_port(0).connect(concat.out_port(0))
    conv.in_port(1).connect(node.in_port(1).get_source())

    # change the weights layout from OHWI to OIHW;
    # in the future this should be replaced by the common Permute mechanics
    weights = conv.in_port(1).get_source().node.value
    weights = weights.reshape(int64_array([node.out_channels, -1, node.in_channels]))
    weights = weights.transpose(int64_array([0, 2, 1]))
    weights = weights.flatten()
    conv.in_port(1).get_source().node.value = weights

    conv.in_port(2).connect(node.in_port(2).get_source())
    node.out_port(0).get_connection().set_source(conv.out_port(0))
    graph.remove_node(node.id)
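# How the convolution geometry falls out of the (time, height) offset pairs,
# shown on toy values (the offsets below are an assumption for illustration):
# the kernel size per axis is the number of distinct offsets, and the dilation
# is the offset span divided by (kernel - 1).
import numpy as np

offsets = np.array([[-1, -3], [-1, 0], [-1, 3],
                    [0, -3], [0, 0], [0, 3]])
kernel_t = len(set(offsets[:, 0]))  # 2 distinct time offsets
kernel_h = len(set(offsets[:, 1]))  # 3 distinct height offsets
dilation_t = (offsets[:, 0].max() - offsets[:, 0].min()) // (kernel_t - 1)  # span 1 over 1 gap
dilation_h = (offsets[:, 1].max() - offsets[:, 1].min()) // (kernel_h - 1)  # span 6 over 2 gaps
assert (kernel_t, kernel_h, dilation_t, dilation_h) == (2, 3, 1, 3)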