Exemplo n.º 1
0
    def pool(cls, node, pool_type=None, **kwargs):
        """Build a 2D pooling node for ``node`` and wire it into the graph.

        Reads the ``kernel_shape``, ``strides`` and ``dilations`` attributes of
        ``node``. When the kernel, after adding padding, spans the whole
        spatial extent of the input and every stride is 1, a
        GlobalPoolParameters reducing over the h/w axes (1, 2 of c/h/w) is
        emitted. Otherwise a windowed PoolingParameters is emitted.

        Args:
            node: source operator; only kernels with two spatial dims are
                accepted.
            pool_type: forwarded unchanged to the created parameters.
            **kwargs: must contain ``all_nodes`` (maps an output name to a
                ``(node, out_idx, provisional_dim)`` tuple), ``G`` (the graph
                that receives the new edge) and ``valid_name`` (name given to
                the new node).

        Returns:
            The created pooling parameters node, which is also registered in
            ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: for kernels that are not 2D or any dilation above 1.
        """
        all_nodes = kwargs['all_nodes']
        graph = kwargs['G']
        name = kwargs['valid_name']

        inputs = [all_nodes[inp] for inp in node.input]
        src = inputs[0]
        in_shape = src[2].shape
        in_c = in_shape[1]

        kernel_shape = node.attrs["kernel_shape"]
        n_spatial = len(kernel_shape)
        if n_spatial != 2:
            # Report the full tensor rank: spatial dims plus the two leading dims.
            raise ValueError(name + " with {}D input".format(n_spatial + 2))

        h, w = in_shape[2], in_shape[3]

        strides = node.attrs.get("strides", [1] * n_spatial)
        unit_stride = all(s == 1 for s in strides)
        dilations = node.attrs.get("dilations", [1] * n_spatial)
        if any(d > 1 for d in dilations):
            raise ValueError(name + " with dilation not supported")

        # NOTE(review): ceil_mode is never read here -- confirm that the output
        # size rounding done by PoolingParameters is what callers expect.
        pad_dim = cls.calc_pad_dim(node, n_spatial)
        # If dilation support is ever added, this coverage test must use the
        # dilated kernel extent instead of the raw kernel size.
        pads = [pad_dim.h, pad_dim.w]
        covers_input = all(
            k_dim >= feat_dim + pad
            for k_dim, feat_dim, pad in zip(kernel_shape, in_shape[2:], pads))

        if covers_input and unit_stride:
            # The window spans the whole h/w plane, so reduce it globally.
            params = GlobalPoolParameters(name,
                                          pool_type=pool_type,
                                          axis=[1, 2],
                                          keep_dims=True,
                                          in_dims_hint=[['c', 'h', 'w']],
                                          out_dims_hint=[['c', 'h', 'w']])
        else:
            params = PoolingParameters(
                name,
                filt=PoolFilterDim(kernel_shape[0], kernel_shape[1]),
                stride=StrideDim(strides[0], strides[1]),
                padding=pad_dim,
                pool_type=pool_type,
                in_dims_hint=[['c', 'h', 'w']],
                out_dims_hint=[['c', 'h', 'w']])

        chw = Dim.named_ordered(c=in_c, h=h, w=w)
        out_shape = params.get_output_size([chw])[0].shape
        # Re-attach the input's leading entry in front of the c/h/w result.
        pout_dims = ProvisionalDim([in_shape[0]] + out_shape)
        graph.add_edge(
            NNEdge(from_node=src[0], to_node=params, from_idx=src[1], to_idx=0))
        all_nodes[node.output[0]] = (params, 0, pout_dims)
        return params
Exemplo n.º 2
0
    def pool2d(cls, node, pool_type=None, **kwargs):
        """Convert a 2D pooling operator into a graph node.

        Options are read through ``Pool2DOptions``. When the filter exactly
        matches the input's spatial size and both strides are 1, the operator
        becomes a GlobalPoolParameters over the h/w axes (0, 1 of h/w/c).
        Otherwise a windowed PoolingParameters is created. The options are
        then passed to ``cls.fuse_activation``.

        Args:
            node: source operator. After ``remove_known_batch_dimension`` its
                first input's shape must have exactly four entries, unpacked
                as (b, h, w, c).
            pool_type: forwarded unchanged to the created parameters.
            **kwargs: must contain ``all_nodes``, ``G`` and ``opts``; the whole
                mapping is also forwarded to ``cls.fuse_activation``.

        Returns:
            The node returned by ``cls.fuse_activation``. The same node is
            registered in ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: if the input shape does not have exactly four entries.
        """
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        opts = kwargs['opts']
        node_opts = node.get_options(Pool2DOptions)

        inputs = [all_nodes[inp] for inp in node.input]
        x = inputs[0]
        x = cls.remove_known_batch_dimension(G, x, node)
        x_shape = tuple(x[2].shape)
        # Layout is (b, h, w, c): index 1 is the height, not the channel count.
        if len(x_shape) != 4:
            raise ValueError(
                "{}: expected a 4D (b, h, w, c) input, got shape {}".format(
                    node.name, x_shape))
        in_b, h, w, in_c = x_shape

        filt_h = node_opts.FilterHeight()
        filt_w = node_opts.FilterWidth()
        stride_h = node_opts.StrideH()
        stride_w = node_opts.StrideW()

        pad = cls.get_tf_padding(node_opts.Padding())

        filter_matches_input = h == filt_h and w == filt_w
        stride_is_one = stride_h == 1 and stride_w == 1

        if filter_matches_input and stride_is_one:
            # The filter covers the whole h/w plane: use a global reduction.
            params = GlobalPoolParameters(node.name,
                                          pool_type=pool_type,
                                          axis=[0, 1],
                                          keep_dims=True,
                                          in_dims_hint=[['h', 'w', 'c']],
                                          out_dims_hint=[['h', 'w', 'c']])
        else:
            params = PoolingParameters(node.name,
                                       filt=PoolFilterDim(filt_h, filt_w),
                                       stride=StrideDim(stride_h, stride_w),
                                       padding=pad,
                                       pool_type=pool_type,
                                       in_dims_hint=[['h', 'w', 'c']],
                                       out_dims_hint=[['h', 'w', 'c']])

        if opts.get('load_quantization'):
            # Quantization is attached to the pooling node itself, before any
            # activation fusion below.
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)

        in_dim = Dim.named_ordered(h=h, w=w, c=in_c)
        out_dims = params.get_output_size([in_dim])
        pout_dims = ProvisionalDim([in_b] + out_dims[0].shape)
        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        # The node returned by fuse_activation is what downstream consumers of
        # this output will be connected to.
        params = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (params, 0, pout_dims)
        return params
Exemplo n.º 3
0
    def _common(cls, node, **kwargs):
        """Convert a reduction operator into a global-pool (or no-op) node.

        The reduction axes come from the constant second input. Negative axes
        are normalised and duplicates are collapsed. Axes whose input dimension
        is unknown (``None``) are not passed to the pool. If no reduced axis has
        a known dimension, a NoOPParameters node is created instead. The output
        provisional shape honours the operator's ``KeepDims`` option.

        Args:
            node: source operator; options are read via ``ReducerOptions``.
            **kwargs: must contain ``all_nodes``, ``G``, ``reduce_type`` (used
                as the pool type) and ``opts``.

        Returns:
            The created parameters node, which is also registered in
            ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: if an axis lies outside the input's rank.
        """
        node_opts = node.get_options(ReducerOptions)
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        reduce_type = kwargs['reduce_type']
        opts = kwargs['opts']

        inputs = [all_nodes[inp] for inp in node.input]
        x = inputs[0]

        x_shape = x[2].shape
        x_rank = len(x_shape)

        axes = cls._verify_constant(inputs[1])
        node.input[1].used = True

        # A 0-d constant holds a single axis; cast to plain ints either way.
        if len(axes.shape) == 0:
            axes = [int(axes)]
        else:
            axes = [int(ax) for ax in axes]

        # Convert negative axes to their true value and collapse duplicates
        # (e.g. 1 and -3 on a rank-4 input name the same axis).
        axes = sorted({ax if ax >= 0 else x_rank + ax for ax in axes})
        if any(ax < 0 or ax >= x_rank for ax in axes):
            raise ValueError("{} axis out of bounds".format(node.name))
        keep_dims = node_opts.KeepDims()

        # Only axes with a known (non-None) dimension are handed to the pool.
        stripped_axes = [ax for ax in axes if x_shape[ax] is not None]

        if not stripped_axes:
            params = NoOPParameters(node.name)
            pout_shape = x_shape.copy()
        else:
            if keep_dims:
                # Reduced axes become 1; unknown dims stay None because the
                # pool is never asked to reduce them.
                pout_shape = [
                    1 if idx in axes and dim is not None else dim
                    for idx, dim in enumerate(x_shape)
                ]
            else:
                pout_shape = [
                    dim for idx, dim in enumerate(x_shape) if idx not in axes
                ]
                if all(dim is None for dim in pout_shape):
                    pout_shape.append(1)
            # Shift each axis down by the number of None dims preceding it;
            # presumably the pool sees the shape with None entries removed.
            pool_axes = [
                ax - sum(1 for dim in x_shape[:ax] if dim is None)
                for ax in stripped_axes
            ]
            params = GlobalPoolParameters(node.name,
                                          pool_type=reduce_type,
                                          axis=tuple(pool_axes),
                                          keep_dims=keep_dims)

        if opts.get('load_quantization'):
            G.quantization[NodeId(
                params)] = BackendHandler.load_tf_quantization([node.input[0]],
                                                               node.output)

        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))

        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(pout_shape))
        return params
Exemplo n.º 4
0
    def _common(cls, node, **kwargs):
        """Convert a reduction operator into a global-pool (or no-op) node.

        Reads the ``axes`` and ``keepdims`` attributes (``keepdims`` defaults
        to 1). When ``axes`` is absent and the node has a single input, every
        dimension is reduced. Negative axes are normalised and duplicates are
        collapsed. Axes whose input dimension is unknown (``None``) are not
        passed to the pool. If no reduced axis has a known dimension, a
        NoOPParameters node is created instead.

        Args:
            node: source operator.
            **kwargs: must contain ``all_nodes``, ``G``, ``valid_name`` (name
                for the new node) and ``reduce_type`` (used as the pool type).

        Returns:
            The created parameters node, which is also registered in
            ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: if an axis lies outside the input's rank, or if
                ``axes`` is missing while the node has more than one input.
        """
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        valid_name = kwargs['valid_name']
        reduce_type = kwargs['reduce_type']
        inputs = [all_nodes[inp] for inp in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        x_rank = len(x_shape)

        axes = node.attrs.get('axes')
        if axes is None:
            if len(inputs) > 1:
                # Some operator versions supply axes as a second input tensor,
                # which this handler does not read; refuse rather than guess.
                raise ValueError(
                    valid_name + " with axes given as an input is not supported")
            # The attribute is optional; its default is to reduce all dims.
            axes = range(x_rank)

        # Convert negative axes to their true value and collapse duplicates.
        axes = sorted({elem if elem >= 0 else x_rank + elem for elem in axes})
        if any(axis < 0 or axis >= x_rank for axis in axes):
            raise ValueError(valid_name + " axis out of bounds")
        keep_dims = node.attrs.get('keepdims', 1)

        # Only axes with a known (non-None) dimension are handed to the pool.
        stripped_axes = [axis for axis in axes if x_shape[axis] is not None]

        if not stripped_axes:
            params = NoOPParameters(valid_name)
            pout_shape = x_shape.copy()
        else:
            if keep_dims:
                # Reduced axes become 1; unknown dims stay None because the
                # pool is never asked to reduce them.
                pout_shape = [
                    1 if idx in axes and dim is not None else dim
                    for idx, dim in enumerate(x_shape)
                ]
            else:
                pout_shape = [
                    dim for idx, dim in enumerate(x_shape) if idx not in axes
                ]
                if all(dim is None for dim in pout_shape):
                    pout_shape.append(1)

            # Shift each axis down by the number of None dims preceding it;
            # presumably the pool sees the shape with None entries removed.
            pool_axes = [
                ax - sum(1 for dim in x_shape[:ax] if dim is None)
                for ax in stripped_axes
            ]
            params = GlobalPoolParameters(valid_name,
                                          pool_type=reduce_type,
                                          axis=tuple(pool_axes),
                                          keep_dims=keep_dims)

        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))

        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(pout_shape))
        return params