def replace_timeheightconv(self, graph: Graph, node: Node):
        req_time_offsets = node.soft_get('time_offsets')
        offsets = node.soft_get("offsets", [[]])
        all_time_offsets = list(set(offsets[:, 0]))
        all_time_offsets.sort()
        in_name = node.soft_get('name', node.id)
        rename_node(node, in_name + '/to_delete')
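        # the original node is removed at the end of this method; rename it so that the
        # replacement Convolution created below can take over the original name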

        # create MemoryOffset nodes that gather the temporal context;
        # a Concat is needed when there is more than one time offset
        concat = Concat(graph,
                        attrs={
                            'name': in_name + '/Concat',
                            'in_ports_count': len(all_time_offsets)
                        }).create_node()
        for i, t in enumerate(all_time_offsets):
            # no default value is needed when the time offset is one of the req_time_offsets
            has_default = t not in req_time_offsets
            memoff = MemoryOffset(graph,
                                  attrs={
                                      'name': in_name + '/MemoryOffset_' + str(i),
                                      't': t,
                                      'has_default': has_default,
                                      'splitted': False,
                                      'pair_name': in_name + '/MemoryOffset_pair_' + str(i)
                                  }).create_node()
            concat.in_port(i).connect(memoff.out_port(0))
            memoff.in_port(0).connect(node.in_port(0).get_source())

        stride = node.soft_get("height_subsample", 1)

        # kernel spatial size: one tap per distinct time offset and per distinct height offset
        kernel = int64_array([0, 0])
        kernel[0] = len(set(offsets[:, 0]))
        kernel[1] = len(set(offsets[:, 1]))

        # pad along the height axis: the left pad covers negative height offsets,
        # the right pad ensures height_out rows can be produced from height_in
        pad_h = int64_array([0, 0])
        pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
        pad_h[1] = stride * node.height_out - (node.height_in - max([max(offsets[:, 1]), 0]))

        # dilation: span of the offsets divided by the number of gaps between kernel taps
        dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
        dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1
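        # for example, time offsets {-3, 0, 3} give kernel[0] == 3 and dilation_t == (3 - (-3)) / (3 - 1) == 3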

        # the replacement Convolution works on a 4D [batch, channel, time, height] layout
        conv_attrs = {
            'name': in_name,
            'output': node['out_channels'],
            'height_in': node.height_in,
            'bias_term': None,
            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
            'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
            'stride': int64_array([1, 1, 1, stride]),
            'kernel_spatial': kernel,
            'input_feature_channel': 1,
            'output_feature_channel': 0,
            'channel_dims': int64_array([1]),
            'spatial_dims': int64_array([2, 3]),
            'batch_dims': int64_array([0]),
            'kernel_spatial_idx': int64_array([2, 3]),
            'group': 1,
            'reshape_kernel': True,
            'bias_addable': True,
        }
        conv = Convolution(graph, attrs=conv_attrs).create_node()
        conv.in_port(0).connect(concat.out_port(0))
        conv.in_port(1).connect(node.in_port(1).get_source())

        # change the weights layout from OHWI to OIHW;
        # in the future this should be replaced by the common Permute mechanics
        weights = conv.in_port(1).get_source().node.value
        weights = weights.reshape(
            int64_array([node.out_channels, -1, node.in_channels]))
        weights = weights.transpose(int64_array([0, 2, 1]))
        weights = weights.flatten()
        conv.in_port(1).get_source().node.value = weights
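        # the weights stay flattened here; the 'reshape_kernel' flag set in conv_attrs is
        # expected to restore the declared 4D kernel shape later (assumption about Model Optimizer internals)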

        conv.in_port(2).connect(node.in_port(2).get_source())
        node.out_port(0).get_connection().set_source(conv.out_port(0))
        graph.remove_node(node.id)
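

# A minimal sketch of how a replacement method like the one above is usually wired into
# the Model Optimizer front phase, assuming the standard FrontReplacementPattern API.
# The class name, the 'timeheightconvolutioncomponent' op string and the legacy mo.*
# import paths are assumptions, not taken from the example above.
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph


class ReplaceTimeHeightConvolutionPattern(FrontReplacementPattern):
    enabled = True

    def find_and_replace_pattern(self, graph: Graph):
        # replace every time-height convolution component found in the graph;
        # replace_timeheightconv is the method shown above, assumed to be defined on this class
        for node in graph.get_op_nodes(op='timeheightconvolutioncomponent'):
            self.replace_timeheightconv(graph, node)
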
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Converts specific for NasNet topology subgraph Pad->StridedSlice->AvgPool to Conv->Crop->AvgPool
        """
        input = match['input']

        pad_node = match['pad_op']
        pad_node_name = pad_node.soft_get('name', pad_node.id)

        sslice_node = match['sslice']
        begin = []
        end = []
        stride = []
        for s in sslice_node.slices:
            begin.append(s.start)
            end.append(s.stop)
            stride.append(s.step)

        pads_begin = pad_node.in_port(1).data.get_value()
        pads_end = pad_node.in_port(2).data.get_value()
        if pads_begin is None or pads_end is None:
            log.error('Pad values for node "{}" are not constants'.format(
                pad_node_name))
            return

        if not np.array_equal(pads_begin, int64_array([0, 0, 0, 0])):
            log.error('Pad begin values do not match the expected [0, 0, 0, 0] for node {}!'.format(
                pad_node_name))
            return

        if not np.array_equal(pads_end, int64_array([0, 1, 1, 0])):
            log.error('Pad end values do not match the expected [0, 1, 1, 0] for node {}!'.format(
                pad_node_name))
            return

        if not np.array_equal(begin, int64_array([0, 1, 1, 0])):
            log.error("StridedSlice has wrong begin")
            return

        if not np.array_equal(sslice_node.end_mask, int64_array([0, 0, 0, 0])) or \
                not np.array_equal(sslice_node.begin_mask, int64_array([0, 1, 1, 0])):
            log.error("StridedSlice has wrong masks")
            return

        # Pad -> Conv: a 1x1 depthwise Convolution (group = number of channels) with all-ones
        # weights is an identity per channel, and its asymmetric [0, 1] spatial padding
        # reproduces the original Pad
        conv_name = graph.unique_id(pad_node.name + '/Conv_')
        conv_weights_name = graph.unique_id(pad_node.name + '/ConvW_')
        conv_weights = np.ones((input.shape[3], 1, 1, 1))
        output_shape = int64_array([
            input.shape[0], input.shape[1] + 1, input.shape[2] + 1,
            input.shape[3]
        ])

        conv_node = Convolution(
            graph,
            dict(
                name=conv_name,
                stride=int64_array([1, 1, 1, 1]),
                dilation=int64_array([1, 1, 1, 1]),
                group=input.shape[3],
                bias_addable=True,
                bias_term=False,
                spatial_dims=int64_array([1, 2]),
                kernel_spatial=int64_array([1, 1]),
                pad=int64_array([[0, 0], [0, 1], [0, 1], [0, 0]]),
                output_shape=output_shape,
                batch_dims=int64_array([0]),
                channel_dims=int64_array([3]),
                output=input.shape[3],
                input_feature_channel=1,
                output_feature_channel=0,
            )).create_node()

        weights_const_node = Const(
            graph,
            dict(name=conv_weights_name,
                 value=conv_weights,
                 shape=int64_array(conv_weights.shape))).create_node()

        # StridedSlice -> Crop: dropping the leading row and column (begin = [0, 1, 1, 0])
        # becomes a Crop with offset [1, 1] and the padded spatial dims reduced by one
        crop_node = Crop(
            graph,
            dict(name=sslice_node.name + '/Crop_',
                 axis=int64_array([1, 2]),
                 dim=int64_array([output_shape[1] - 1, output_shape[2] - 1]),
                 offset=int64_array([1, 1]))).create_node()

        # Connect nodes
        pad_node.in_port(0).get_connection().set_destination(
            conv_node.in_port(0))
        weights_const_node.out_port(0).connect(conv_node.in_port(1))
        conv_node.out_port(0).connect(crop_node.in_port(0))
        sslice_node.out_port(0).get_connection().set_source(
            crop_node.out_port(0))

        conv_node.in_port(1).bin = 'weights'

        # Remove Pad and StridedSlice nodes from graph
        graph.remove_node(pad_node.id)
        graph.remove_node(sslice_node.id)
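
    # A plausible sketch of the pattern() method that would produce the 'input', 'pad_op'
    # and 'sslice' keys consumed by replace_pattern() above, assuming the standard Model
    # Optimizer pattern format with interleaved op and data nodes; the exact node
    # attributes, the extra data/AvgPool nodes and the edge port numbers are assumptions.
    def pattern(self):
        return dict(
            nodes=[
                ('input', dict(kind='data')),
                ('pad_op', dict(kind='op', op='Pad')),
                ('pad_out', dict(kind='data')),
                ('sslice', dict(kind='op', op='StridedSlice')),
                ('sslice_out', dict(kind='data')),
                ('avg_pool', dict(kind='op', op='AvgPool')),
            ],
            edges=[
                ('input', 'pad_op', {'in': 0}),
                ('pad_op', 'pad_out'),
                ('pad_out', 'sslice', {'in': 0}),
                ('sslice', 'sslice_out'),
                ('sslice_out', 'avg_pool', {'in': 0}),
            ])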