Example #1
0
    def _pooling_op(self, cntk_op, inputs):
        """
        Imports a CNTK pooling node as an ngraph pooling op.

        Only the first entry of `inputs` is used; it is expanded so that it
        unpacks into five axes (C, D, H, W, N) before pooling.

        Arguments:
            cntk_op: CNTK function to be imported.
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op: the ng.pooling result passed through squeeze_axes.
        """
        operand = self._expand_input_axes(inputs[0])
        C, D, H, W, N = operand.axes
        M, P, Q = self._make_out_axes(cntk_op.shape)

        # poolingType 0 selects max pooling; any other value is averaged.
        pool_type = 'max' if cntk_op.attributes['poolingType'] == 0 else 'avg'

        strides = self._make_strides(cntk_op.attributes['strides'])
        kernel = self._make_kernel(cntk_op.attributes['poolingWindowShape'])

        # Spatial lengths are given in (D, H, W, C) order for input and output.
        in_lengths = (D.length, H.length, W.length, C.length)
        out_lengths = (M.length, P.length, Q.length, C.length)
        pad = self._make_padding('pool', cntk_op.attributes['autoPadding'],
                                 in_lengths, kernel, out_lengths, strides)

        params = {
            'op': pool_type,
            'pad_d': pad[0],
            'pad_h': pad[1],
            'pad_w': pad[2],
            'pad_c': pad[3],
            'str_d': strides[0],
            'str_h': strides[1],
            'str_w': strides[2],
            'str_c': strides[3],
            # Kernel entries, indexed in the same order as the lengths above.
            'T': kernel[0],
            'R': kernel[1],
            'S': kernel[2],
            'J': kernel[3],
        }

        # The channel axis C is carried through unchanged to the output.
        out_axes = [C, M, P, Q, N]
        pooled = ng.pooling(params, operand, out_axes)
        return squeeze_axes([pooled])[0]
Example #2
0
    def _convolution_op(self, cntk_op, inputs):
        """
        Imports a CNTK convolution node as an ngraph convolution op.

        Arguments:
            cntk_op: CNTK operation to be imported.
            inputs: Two-element list: the filter tensor first, then the
                operand that is convolved.

        Returns:
            A ngraph Op: the ng.convolution result passed through
            squeeze_axes.
        """
        weights, operand = inputs

        operand = self._expand_input_axes(operand)
        C, D, H, W, N = operand.axes

        # Filters are expanded against the operand's first axis C.
        weights = self._expand_filters_axes(weights, C)
        _, T, M1, M2, O = weights.axes

        M, oH, oW = self._make_out_axes(cntk_op.shape)

        strides = self._make_strides(cntk_op.attributes['strides'])

        # Lengths are given in (depth, height, width, channel) order.
        in_lengths = (D.length, H.length, W.length, C.length)
        filter_lengths = (T.length, M1.length, M2.length, C.length)
        out_lengths = (M.length, oH.length, oW.length, O.length)
        pad = self._make_padding('conv', cntk_op.attributes['autoPadding'],
                                 in_lengths, filter_lengths, out_lengths,
                                 strides)

        params = {
            'pad_d': pad[0],
            'pad_h': pad[1],
            'pad_w': pad[2],
            'pad_c': pad[3],
            'str_d': strides[0],
            'str_h': strides[1],
            'str_w': strides[2],
            'str_c': strides[3],
            # No dilation is applied.
            'dil_d': 1,
            'dil_h': 1,
            'dil_w': 1,
        }

        # The filter's last axis O becomes the leading output axis.
        convolved = ng.convolution(params, operand, weights,
                                   [O, M, oH, oW, N])
        return squeeze_axes([convolved])[0]
Example #3
0
    def CrossEntropyWithSoftmax(self, cntk_op, inputs):
        """
        Computes the softmax cross entropy between inputs[0] and inputs[1].

        Softmax is applied to inputs[0], which is then scored against
        inputs[1] with ng.cross_entropy_multi. The result is reduced with
        ng.mean(out_axes=()) and named after the CNTK op's uid.

        Arguments:
            cntk_op: CNTK operation to be imported.
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op.

        Raises:
            ValueError: if the axis lengths of the two inputs cannot be
                aligned, even after transposing inputs[0].
        """
        cast_0, cast_1 = squeeze_axes(inputs)

        # The operands may arrive with their axes in opposite order;
        # a single transpose of inputs[0] is the only realignment attempted.
        if cast_0.axes.lengths != cast_1.axes.lengths:
            cast_0 = ng.Transpose(cast_0)

        # Explicit check instead of `assert`, so the validation survives
        # `python -O` and reports the offending shapes.
        if cast_0.axes.lengths != cast_1.axes.lengths:
            raise ValueError(
                "CrossEntropyWithSoftmax: incompatible input shapes {} and {}"
                .format(cast_0.axes.lengths, cast_1.axes.lengths))

        # Relabel inputs[0] with inputs[1]'s axes so the elementwise ops line up.
        cast_0 = ng.cast_axes(cast_0, axes=cast_1.axes)
        loss = ng.cross_entropy_multi(ng.softmax(cast_0), cast_1)

        return ng.mean(loss, out_axes=()).named(cntk_op.uid)
Example #4
0
    def _cast_for_binary_op(self, inputs):
        """
        Aligns the axes of two operands ahead of a binary op.

        Both inputs are first passed through squeeze_axes. The operand with
        more axes (the first one when the counts are equal) is then re-cast
        with ng.cast_axes onto axes produced by self._match_axes against the
        other operand's axes. The other operand is returned unchanged.

        Arguments:
            inputs: List of exactly two inputs to be casted.

        Returns:
            Tuple of the two (possibly casted) inputs, in the original order.
        """
        assert len(inputs) == 2

        lhs, rhs = squeeze_axes(inputs)

        if len(lhs.axes) < len(rhs.axes):
            # Second operand is strictly larger: recast it to match the first.
            matched = self._match_axes(rhs.axes, lhs.axes)
            rhs = ng.cast_axes(rhs, matched)
        else:
            # First operand is larger or equal in rank: recast the first.
            matched = self._match_axes(lhs.axes, rhs.axes)
            lhs = ng.cast_axes(lhs, matched)

        return lhs, rhs