def align_frame_time(graph: Graph, node: Node, frame_time_max):
    """Bring every connected input of ``node`` to the frame time ``frame_time_max``.

    For each connected input port whose producer lags behind
    ``frame_time_max`` (Const producers are skipped):

    * if the producer is already a MemoryOffset, its ``t`` and ``frame_time``
      are overwritten with the lag (``frame_time - frame_time_max``);
    * otherwise a new MemoryOffset with ``t`` equal to the lag is inserted
      between the producer and ``node`` (for a Parameter producer its
      ``shape`` is stored as the new node's ``element_size``).

    A MemoryOffset producer whose frame time already equals
    ``frame_time_max`` is bypassed and removed from ``graph``.

    :param graph: graph owning ``node``; new nodes are created in it
    :param node: node whose inputs are aligned
    :param frame_time_max: frame time every input is aligned to
    """
    for idx in node.in_ports():
        dst_port = node.in_port(idx)
        if dst_port.disconnected():
            continue
        src_port = dst_port.get_source()
        src_node = src_port.node

        # Adding MemoryOffset for Const does not make sense
        if src_node.frame_time < frame_time_max and src_node.op != 'Const':
            shift = src_node.frame_time - frame_time_max
            if src_node.op == 'MemoryOffset':
                # retune the existing MemoryOffset instead of stacking a new one
                src_node.t = shift
                src_node.frame_time = src_node.t
                continue

            align_name = graph.unique_id("align_" + node.id)
            align_attrs = {
                'id': align_name,
                'name': align_name,
                'pair_name': align_name + "_pair",
                't': shift,
                'splitted': False,
            }
            align_node = MemoryOffset(graph, attrs=align_attrs).create_node()
            if src_node.op == 'Parameter':
                # element_size lets the MemoryOffset infer without a known input shape
                align_node['element_size'] = src_node.shape
            dst_port.get_connection().set_source(align_node.out_port(0))
            align_node.in_port(0).connect(src_port)
            align_node['frame_time'] = align_node.t
        elif src_node.frame_time == frame_time_max and src_node.op == 'MemoryOffset':
            # this MemoryOffset already provides the maximum delay: bypass and drop it
            src_port.get_connection().set_source(src_node.in_port(0).get_source())
            graph.remove_node(src_node.id)
# Example #2
    def find_and_replace_pattern(self, graph: Graph):
        """Split every MemoryOffset node that is not split yet into a pair.

        For each MemoryOffset with ``splitted=False`` in ``graph``:

        * a paired MemoryOffset named after the original's ``pair_name``
          (with ``pair_name`` pointing back at the original's id) takes over
          all consumers of the original;
        * the original is marked ``splitted`` and gets a Result attached to
          its output;
        * the pair gets ``element_size``: the original's value when valid,
          otherwise dimension 1 of the original's input shape.
        """
        for mem_offset in graph.get_op_nodes(op='MemoryOffset', splitted=False):
            pair_attrs = {
                'name': mem_offset.pair_name,
                'splitted': True,
                'pair_name': mem_offset.id,
                't': mem_offset.t,
                'has_default': mem_offset.has_default,
            }
            pair = MemoryOffset(graph, pair_attrs).create_node()
            mem_offset['splitted'] = True
            mem_offset.out_port(0).get_connection().set_source(pair.out_port(0))
            result = Result(graph, {'name': mem_offset.id + "_output"}).create_node()
            mem_offset.out_port(0).connect(result.in_port(0))

            if mem_offset.has_valid('element_size'):
                # element_size was copied earlier from a Parameter or from a node with a defined dim
                element_size = mem_offset['element_size']
            else:
                # Copy shape from previous node. Typically (but not always) for TDNN blocks this is the case
                element_size = mem_offset.in_port(0).data.get_shape()[1]
            pair['element_size'] = element_size
# Example #3
 def replace_pattern(graph: Graph, match: dict):
     """Split the matched MemoryOffset (``match['mem_offset']``) into a pair.

     A paired MemoryOffset named after the original's ``pair_name`` takes over
     the original's consumers; the original is marked ``splitted`` and its
     output is connected to a new Result named ``<id>_output``.
     """
     mem_offset = match['mem_offset']
     pair_attrs = {
         'name': mem_offset.pair_name,
         'splitted': True,
         'pair_name': mem_offset.id,
         't': mem_offset.t,
         'has_default': mem_offset.has_default,
     }
     pair = MemoryOffset(graph, pair_attrs).create_node()
     mem_offset['splitted'] = True
     mem_offset.out_port(0).get_connection().set_source(pair.out_port(0))
     sink = Result(graph, {'name': mem_offset.id + "_output"}).create_node()
     mem_offset.out_port(0).connect(sink.in_port(0))
# Example #4
    def extract(cls, node):
        """Copy MemoryOffset attributes from ``node.parameters`` onto ``node``.

        ``pair_name``, ``t`` and ``has_default`` are required keys;
        ``element_size`` is copied only when present. The node is marked as
        not split yet. Returns ``cls.enabled``.
        """
        params = node.parameters
        attrs = {key: params[key] for key in ('pair_name', 't', 'has_default')}
        attrs['splitted'] = False
        if 'element_size' in params:
            attrs['element_size'] = params['element_size']

        MemoryOffset.update_node_stat(node, attrs)
        return cls.enabled
# Example #5
    def replace_tdnn(self, graph: Graph, tdnn_node: Node):
        """Replace a TDNN component with Concat(MemoryOffset, ...) -> FullyConnected.

        For each entry ``t`` of ``tdnn_node['time_offsets']`` the TDNN input is
        fed into a new Concat (axis 1): directly when ``t == 0``, otherwise
        through a MemoryOffset with delay ``t``. The Concat output feeds a
        FullyConnected built from ``tdnn_node['weights']`` (and ``biases`` when
        valid), which takes over the TDNN node's consumers. The Concat inherits
        the TDNN node's name; the TDNN node is renamed with a
        ``/to_be_removed`` suffix and its input is disconnected, but it is not
        removed from ``graph`` here.

        :param graph: graph owning ``tdnn_node``
        :param tdnn_node: TDNN node to replace
        """
        tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)

        concat_node = Concat(graph, {'axis': 1}).create_node()
        rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'),
                      (concat_node, tdnn_name)])

        time_offsets = list(tdnn_node['time_offsets'])
        for offset_ind, t in enumerate(time_offsets):
            concat_node.add_input_port(offset_ind)
            if t != 0:
                suffix = str(abs(t))
                if t > 0 and -t in time_offsets:
                    # t and -t would otherwise share the same name and pair_name; the
                    # paired node created when a MemoryOffset is split is named after
                    # pair_name, so the two pairs would collide
                    suffix = '+' + suffix
                memory_name = tdnn_name + '/MemoryOffset/' + suffix
                memoryoffset_node = MemoryOffset(
                    graph, {
                        'name': memory_name,
                        't': t,
                        'pair_name': memory_name + '_out',
                        'has_default': False,
                        'splitted': False
                    }).create_node()

                tdnn_node.in_port(0).get_source().connect(
                    memoryoffset_node.in_port(0))
                memoryoffset_node.out_port(0).connect(
                    concat_node.in_port(offset_ind))
            else:
                # 0 time delay is not allowed in IE, it's meaningless
                # if time offset is 0 then connect input of tdnncomponent directly to Concat without memoryoffset
                tdnn_node.in_port(0).get_source().connect(
                    concat_node.in_port(offset_ind))

        weights = tdnn_node['weights']
        fc_inputs = {1: weights}

        bias_term = False
        if tdnn_node.has_valid('biases'):
            # one bias per output row of the weights matrix
            assert len(tdnn_node['biases']) == weights.shape[0]
            fc_inputs.update({2: tdnn_node['biases']})
            bias_term = True

        fc_node = create_op_with_const_inputs(
            graph, FullyConnected, fc_inputs, {
                'name': tdnn_name + '/FC',
                'out-size': weights.shape[0],
                'transpose_weights': True,
                'bias_term': bias_term
            })

        concat_node.out_port(0).connect(fc_node.in_port(0))
        tdnn_node.in_port(0).disconnect()
        tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
# Example #6
 def split_offset(offset_node: Node):
     """Split the MemoryOffset ``offset_node`` into a pair plus a Result.

     The paired MemoryOffset is named after ``offset_node.pair_name``, points
     back at ``offset_node.id`` and copies ``element_size``, ``t`` and
     ``has_default``; it takes over all consumers of ``offset_node``. The
     original is marked ``splitted`` and its output is connected to a new
     Result named ``<id>_output``.
     """
     graph = offset_node.graph
     pair_attrs = {
         'name': offset_node.pair_name,
         'splitted': True,
         'pair_name': offset_node.id,
         'element_size': offset_node['element_size'],
         't': offset_node.t,
         'has_default': offset_node.has_default,
     }
     pair = MemoryOffset(graph, pair_attrs).create_node()
     offset_node['splitted'] = True
     offset_node.out_port(0).get_connection().set_source(pair.out_port(0))
     sink = Result(graph, {'name': offset_node.id + '_output'}).create_node()
     offset_node.out_port(0).connect(sink.in_port(0))
    def replace_timeheightconv(self, graph: Graph, node: Node):
        """Replace a TimeHeightConvolution node with MemoryOffsets + Concat + Convolution.

        ``node.offsets`` is indexed as a 2-D array whose column 0 holds time
        offsets and column 1 height offsets. One MemoryOffset per distinct
        time offset (in ascending order) gathers temporal context from the
        node's input; a default value is requested only for offsets not listed
        in ``node.time_offsets``. The MemoryOffsets feed a Concat, which feeds
        a Convolution whose kernel, padding, stride and dilation are derived
        from ``offsets``, ``height_in``, ``height_out`` and
        ``height_subsample``. The weights constant on input port 1 is re-laid
        out in place from OHWI to OIHW order, and ``node`` is removed from
        ``graph``.

        :param graph: graph owning ``node``
        :param node: TimeHeightConvolution node to replace
        """
        req_time_offsets = node.soft_get('time_offsets')
        # NOTE(review): the [[]] fallback cannot be column-indexed below, so a
        # missing 'offsets' attribute still fails at the slicing step.
        offsets = node.soft_get("offsets", [[]])
        time_offsets_col = offsets[:, 0]
        height_offsets_col = offsets[:, 1]
        all_time_offsets = sorted(set(time_offsets_col))
        in_name = node.soft_get('name', node.id)
        rename_node(node, in_name + '/to_delete')

        # create memoryoffsets for context gathering
        # we need concat if time offsets more than 1
        concat = Concat(graph,
                        attrs={
                            'name': in_name + '/Concat',
                            'in_ports_count': len(all_time_offsets)
                        }).create_node()
        for i, t in enumerate(all_time_offsets):
            # if time offset included in required_time_offsets we don't need default value
            has_default = t not in req_time_offsets
            memoff = MemoryOffset(graph,
                                  attrs={
                                      'name': in_name + '/MemoryOffset_' + str(i),
                                      't': t,
                                      'has_default': has_default,
                                      'splitted': False,
                                      'pair_name': in_name + '/MemoryOffset_pair_' + str(i)
                                  }).create_node()
            concat.in_port(i).connect(memoff.out_port(0))
            memoff.in_port(0).connect(node.in_port(0).get_source())

        stride = node.soft_get("height_subsample", 1)

        kernel = int64_array([0, 0])
        kernel[0] = len(all_time_offsets)
        kernel[1] = len(set(height_offsets_col))

        pad_h = int64_array([0, 0])
        pad_h[0] = -min(height_offsets_col) if min(height_offsets_col) < 0 else 0
        pad_h[1] = stride * node.height_out - (node.height_in -
                                               max([max(height_offsets_col), 0]))

        # dilation = span of the offsets / number of gaps between taps; a
        # dimension with a single tap has no gaps, so its dilation is 1.
        # Each dimension must be guarded by its OWN tap count, otherwise a
        # single-tap dimension divides by zero.
        dilation_t = (max(time_offsets_col) - min(time_offsets_col)) / (
            kernel[0] - 1) if kernel[0] > 1 else 1
        dilation_h = (max(height_offsets_col) - min(height_offsets_col)) / (
            kernel[1] - 1) if kernel[1] > 1 else 1

        conv_attrs = {
            'name': in_name,
            'output': node['out_channels'],
            'height_in': node.height_in,
            'bias_term': None,
            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
            'kernel': int64_array(
                [node.out_channels, node.in_channels, kernel[0], kernel[1]]),
            'stride': int64_array([1, 1, 1, stride]),
            'kernel_spatial': kernel,
            'input_feature_channel': 1,
            'output_feature_channel': 0,
            'channel_dims': int64_array([1]),
            'spatial_dims': int64_array([2, 3]),
            'batch_dims': int64_array([0]),
            'kernel_spatial_idx': int64_array([2, 3]),
            'group': 1,
            'reshape_kernel': True,
            'bias_addable': True,
        }
        conv = Convolution(graph, attrs=conv_attrs).create_node()
        conv.in_port(0).connect(concat.out_port(0))
        conv.in_port(1).connect(node.in_port(1).get_source())

        # change layout for weights from OHWI to OIHW
        # in future should be replaced by common Permute mechanics
        weights = conv.in_port(1).get_source().node.value
        weights = weights.reshape(
            int64_array([node.out_channels, -1, node.in_channels]))
        weights = weights.transpose(int64_array([0, 2, 1]))
        weights = weights.flatten()
        conv.in_port(1).get_source().node.value = weights

        conv.in_port(2).connect(node.in_port(2).get_source())
        node.out_port(0).get_connection().set_source(conv.out_port(0))
        graph.remove_node(node.id)