Ejemplo n.º 1
0
    def ReduceElements(self, cntk_op, inputs):
        """
        Returns a reduction operation (max, min, mean, sum, prod) or a calculation which matches
        CNTK's LogSum reduction (`reduce_log_sum_exp` function).

        Arguments:
            cntk_op: CNTK operation to be imported.
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op.

        Raises:
            NotImplementedError: if the CNTK 'reductionOpName' attribute names an
                unsupported reduction.
        """
        assert len(inputs) == 1

        # Dispatch table from CNTK reduction name to an ngraph builder. The lambdas
        # defer every `ng` lookup until the chosen reduction is actually built.
        reducers = {
            'Max': lambda x, axes: ng.max(x, reduction_axes=axes),
            'Min': lambda x, axes: ng.min(x, reduction_axes=axes),
            'Mean': lambda x, axes: ng.mean(x, reduction_axes=axes),
            'Sum': lambda x, axes: ng.sum(x, reduction_axes=axes),
            'Prod': lambda x, axes: ng.prod(x, reduction_axes=axes),
            # log(sum(exp(x))) over the reduction axes.
            'LogSum': lambda x, axes: ng.log(ng.sum(ng.exp(x), reduction_axes=axes)),
        }

        reduction_op_name = cntk_op.attributes.get('reductionOpName')
        reducer = reducers.get(reduction_op_name)
        if reducer is None:
            # Format the message here: passing the name as a second exception
            # argument would leave a literal '%s' in the error text.
            raise NotImplementedError(
                'CNTKImporter: ReduceElements does not support operation %s'
                % (reduction_op_name,))

        # CNTK API defines a reductionKeepDimensions flag, but we currently don't use it
        # keep_dimensions = cntk_op.attributes.get('reductionKeepDimensions', False)

        cntk_op_attribute_axes = []
        if cntk_op.attributes.get('axisVec'):
            cntk_op_attribute_axes.extend(cntk_op.attributes.get('axisVec'))
        elif cntk_op.attributes.get('axis'):
            cntk_op_attribute_axes.append(cntk_op.attributes.get('axis'))

        # CNTK axes are numbered in reverse order: the last axis is labeled 0, the previous 1, etc.
        # Each CNTK axis entry is unpacked as a 3-tuple whose last element is that index.
        input_rank = len(inputs[0].axes)
        reduction_axes_indexes = {input_rank - 1 - i for (_, _, i) in cntk_op_attribute_axes}
        reduction_ng_axes = ng.Axes(axes=[axis for (i, axis) in enumerate(inputs[0].axes)
                                          if i in reduction_axes_indexes])

        return reducer(inputs[0], reduction_ng_axes).named(cntk_op.uid)
Ejemplo n.º 2
0
def test_prod_deriv(
        prod_deriv_arrays):  # Argon Transformer error - TODO triage
    """
    Test reduce product's gradient over every subset of the input's axes,
    including the empty subset (and therefore including scalar inputs).
    """
    def power_set(lst):
        """
        power_set([0, 1, 2]) is:
        [[], [0], [1], [0, 1], [2], [0, 2], [1, 2], [0, 1, 2]]
        """
        result = [[]]
        for x in lst:
            result.extend([subset + [x] for subset in result])
        return result

    def get_all_reduction_axes(axes):
        """
        Get all possible reduction axes: one Axes object per subset of `axes`.
        """
        # No scalar special case: power_set(range(0)) is [[]], which yields a
        # single empty Axes. Returning the (empty) `axes` object directly would
        # make the caller's loop run zero times and skip scalar inputs entirely.
        return [ng.make_axes([axes[index] for index in indices])
                for indices in power_set(range(len(axes.lengths)))]

    def shape_to_axes(shape):
        """
        Convert shape to axes
        """
        if not shape:
            return ng.make_axes()
        axes = ng.make_axes([ng.make_axis(length=s) for s in shape])
        return axes

    x_val = prod_deriv_arrays
    axes = shape_to_axes(x_val.shape)
    all_reduction_axes = get_all_reduction_axes(axes)
    for reduction_axes in all_reduction_axes:
        x = ng.placeholder(axes=axes)
        x_prod = ng.prod(x, reduction_axes)
        check_derivative(x_prod, x, 0.001, x_val, atol=1e-3, rtol=1e-3)
Ejemplo n.º 3
0
def test_prod_constant(transformer_factory):
    """
    Test reduce product of constants over every non-empty subset of three axes.
    """
    A0 = ng.make_axis(length=2)
    A1 = ng.make_axis(length=3)
    A2 = ng.make_axis(length=4)

    # Each case pairs the ngraph reduction axes with the equivalent numpy `axis`.
    cases = [
        ([A0], 0),
        ([A1], 1),
        ([A2], 2),
        ([A0, A1], (0, 1)),
        ([A0, A2], (0, 2)),
        ([A1, A2], (1, 2)),
        ([A0, A1, A2], (0, 1, 2)),
    ]

    # ngraph ops: the scalar 2.0 broadcast over (A0, A1, A2), reduced per case
    const_3d = ng.broadcast(ng.constant(2., axes=[]), axes=[A0, A1, A2])
    ng_prods = [ng.prod(const_3d, reduction_axes=ng_axes) for ng_axes, _ in cases]

    # numpy reference results on a 2x3x4 array of 2.0
    np_const_3d = np.ones((2, 3, 4)) * 2.
    np_results = [np.prod(np_const_3d, axis=np_axis) for _, np_axis in cases]

    # evaluate all reductions in a single computation
    with ExecutorFactory() as ex:
        comps = ex.executor(ng_prods)
        ng_results = list(comps())

    # the executor must return exactly one result per requested op
    assert len(ng_results) == len(np_results)
    for np_result, ng_result in zip(np_results, ng_results):
        np.testing.assert_allclose(np_result, ng_result)
Ejemplo n.º 4
0
def test_prod_constant(prod_constant):
    """
    Test reduce product of a constant-filled tensor against numpy's np.prod.

    `prod_constant` supplies (numpy axis argument, ngraph reduction axes,
    axes to broadcast the constant over).
    """
    np_axis, ng_axis, axes_values = prod_constant

    # ngraph graph: broadcast the scalar 2.0 over axes_values, then reduce by product
    scalar_two = ng.constant(2., axes=[])
    reduced = ng.prod(ng.broadcast(scalar_two, axes=axes_values), reduction_axes=ng_axis)

    # numpy reference on a 2x3x4 array of 2.0
    # NOTE(review): assumes axes_values describes a (2, 3, 4) shape — confirm in the fixture
    expected = np.prod(np.full((2, 3, 4), 2.), axis=np_axis)

    with ExecutorFactory() as ex:
        actual = ex.executor(reduced)()

    np.testing.assert_allclose(expected, actual)
Ejemplo n.º 5
0
def test_prod_deriv(transformer_factory):
    """
    Test reduce product's gradient over every subset of each input's axes,
    including the empty subset (and therefore including scalar inputs).
    """
    def power_set(lst):
        """
        power_set([0, 1, 2]) is:
        [[], [0], [1], [0, 1], [2], [0, 2], [1, 2], [0, 1, 2]]
        """
        result = [[]]
        for x in lst:
            result.extend([subset + [x] for subset in result])
        return result

    def get_all_reduction_axes(axes):
        """
        Get all possible reduction axes: one Axes object per subset of `axes`.
        """
        # No scalar special case: power_set(range(0)) is [[]], which yields a
        # single empty Axes. Returning the (empty) `axes` object directly would
        # make the loop below run zero times, so the scalar test cases
        # (np.array(0.), np.array(2.)) would never be checked.
        return [ng.make_axes([axes[index] for index in indices])
                for indices in power_set(range(len(axes.lengths)))]

    def shape_to_axes(shape):
        """
        Convert shape to axes
        """
        if not shape:
            return ng.make_axes()
        axes = ng.make_axes([ng.make_axis(length=s) for s in shape])
        return axes

    # test cases: inputs of rank 3 down to 0, including zeros (where the
    # product's gradient is degenerate)
    test_cases = [
        np.array([[[1., 2., 3.], [4., 5., 0.], [0., 6., 0.]],
                  [[1., 2., 3.], [4., 5., 6.], [7., 8., 0.]]]),
        np.array([[1., 2., 3.], [4., 5., 0.], [0., 6., 0.]]),
        np.array([1., 2., 3.]),
        np.array([0., 2., 3.]),
        np.array([0., 0., 3.]),
        np.array([0., 0., 0.]),
        np.array([0.]),
        np.array([2.]),
        np.array(0.),
        np.array(2.),
    ]

    for x_val in test_cases:
        axes = shape_to_axes(x_val.shape)
        all_reduction_axes = get_all_reduction_axes(axes)
        for reduction_axes in all_reduction_axes:
            x = ng.placeholder(axes=axes)
            x_prod = ng.prod(x, reduction_axes)
            check_derivative(x_prod, x, 0.001, x_val, atol=1e-3, rtol=1e-3)
Ejemplo n.º 6
0
 def __call__(self, iteration):
     """
     Return the learning rate for `iteration`.

     Every entry of `self.schedule` that `iteration` has reached contributes a
     factor of `self.gamma`; entries not yet reached contribute a factor of 1.
     The result is `self.base_lr` times the product of those factors.
     NOTE(review): relies on the boolean comparison results acting as 0/1 in
     the arithmetic below — confirm for the tensor type of `iteration`.
     """
     gamma_where_reached = (iteration >= self.schedule) * self.gamma
     ones_where_pending = iteration < self.schedule
     per_step_factor = gamma_where_reached + ones_where_pending
     return self.base_lr * ng.prod(per_step_factor, out_axes=())