Пример #1
0
def test_conv_flatten_deriv(n4_hw12_c3_5x5):
    """
    Check a convolution followed by a slice and a flatten.

    The filter variable is declared on renamed ("primed") R, S, C, K axes. It
    is cast onto the real filter axes, given the T axis at position 0, and
    reordered to ``cf.ax_f`` before being fed to ``ng.convolution``.

    Checks:
      * forward value: with all-ones input and filter, every output element
        equals the product of all filter axis lengths except the last;
      * numeric vs. symbolic derivative of the summed output w.r.t. the filter;
      * numeric vs. symbolic derivative of the summed output w.r.t. the input.
    """
    conv_cfg = ConvParams(**n4_hw12_c3_5x5)
    f_axes, o_axes = conv_cfg.ax_f, conv_cfg.ax_o

    # filter axes picked in R, S, C, K order, plus renamed copies of them
    rsck = ng.make_axes([f_axes[idx] for idx in (2, 3, 0, -1)])
    rsck_primed = ng.make_axes([
        ng.make_axis(name=axis.name + 'p', length=axis.length) for axis in rsck
    ])
    # output axes permuted to N, M, P, Q, K (ax_o's last axis moved to the front)
    nmpqk = ng.make_axes([o_axes[idx] for idx in (-1, 1, 2, 3, 0)])

    # all-ones input
    input_var = ng.variable(conv_cfg.ax_i).named('input')
    input_val = np.ones(input_var.axes.lengths)

    # filter: primed RSCK variable -> cast to RSCK -> insert T at 0 -> reorder to ax_f
    filter_var = ng.variable(rsck_primed).named('filter')
    filter_val = np.ones(filter_var.axes.lengths)
    conv_filter = ng.cast_axes(filter_var, rsck).named('frsck')
    conv_filter = ng.expand_dims(conv_filter, f_axes[1], 0).named('ftrsck')
    conv_filter = ng.axes_with_order(conv_filter, axes=f_axes).named('ctrsk')

    # convolution, then permute the output axes
    conv_out = ng.convolution(conv_cfg.conv_params, input_var, conv_filter,
                              axes=o_axes)
    conv_out = ng.axes_with_order(conv_out, axes=nmpqk)

    # index the axis at position 1 (the output depth, oD) at 0 to drop it
    sliced = ng.tensor_slice(conv_out, [slice(None), 0] + [slice(None)] * 3)
    output = ng.flatten_at(sliced, idx=1)

    cost = ng.sum(output, out_axes=())
    expected_conv = np.prod(f_axes.lengths[:-1])

    with ExecutorFactory() as factory:
        conv_comp = factory.executor(output, filter_var, input_var)
        # (numeric, symbolic) derivative computations, built in the same order
        grad_filter = (
            factory.numeric_derivative(cost, filter_var, 1.0, input_var),
            factory.derivative(cost, filter_var, input_var),
        )
        grad_input = (
            factory.numeric_derivative(cost, input_var, 1.0, filter_var),
            factory.derivative(cost, input_var, filter_var),
        )

        conv_val = conv_comp(filter_val, input_val)
        ng.testing.assert_allclose(conv_val,
                                   np.full_like(conv_val, expected_conv))

        num_val, sym_val = (comp(filter_val, input_val) for comp in grad_filter)
        ng.testing.assert_allclose(num_val, sym_val)

        num_val, sym_val = (comp(input_val, filter_val) for comp in grad_input)
        ng.testing.assert_allclose(num_val, sym_val)
Пример #2
0
def test_conv_flatten_deriv(transformer_factory):
    """
    Convolve an all-ones 3x28x28 image batch with an all-ones 5x5 filter bank,
    drop the length-1 output depth axis, flatten, and check both the
    convolution value and the gradient of the summed result w.r.t. the filter.
    """
    # shapes
    C, D, H, W, N = 3, 1, 28, 28, 8
    T, R, S, K = 1, 5, 5, 32

    # input, filter and output axes
    ax_i = ng.make_axes([ax.C, ax.D, ax.H, ax.W, ax.N])
    ax_f = ng.make_axes([ax.C, ax.T, ax.R, ax.S, ax.K])
    out_spec = [(32, ar.Channel), (1, ar.Depth), (24, ar.Height),
                (24, ar.Width)]
    ax_o = ng.make_axes(
        [ng.make_axis(length, roles=[role]) for length, role in out_spec]
        + [ax.N])
    ax_i.set_shape((C, D, H, W, N))
    ax_f.set_shape((C, T, R, S, K))
    params = dict(pad_d=0, pad_h=0, pad_w=0, str_d=1, str_h=1, str_w=1)

    rsck = ng.make_axes([ax.R, ax.S, ax.C, ax.K])
    rsck_anon = ng.make_axes([ng.make_axis(length) for length in rsck.lengths])

    # all-ones image; filter variable cast to RSCK, T inserted, reordered to ax_f
    image = ng.constant(np.ones(ax_i.lengths), ax_i)
    weights = ng.variable(rsck_anon, initial_value=np.ones((R, S, C, K)))
    kernel = ng.axes_with_order(
        ng.expand_dims(ng.cast_axes(weights, rsck), ax.T, 0), axes=ax_f)

    # convolution, output reordered from (C, D, H, W, N) to (N, D, H, W, C)
    conv_out = ng.convolution(params, image, kernel, axes=ax_o)
    out_c, out_d, out_h, out_w, out_n = conv_out.axes
    conv_out = ng.axes_with_order(
        conv_out, axes=ng.make_axes([out_n, out_d, out_h, out_w, out_c]))

    # index the depth axis (position 1) at 0 to drop it, then flatten
    conv = ng.Slice(conv_out, [slice(None), 0] + [slice(None)] * 3)
    flat = ng.flatten_at(conv, idx=1)

    cost = ng.sum(flat, reduction_axes=flat.axes)
    grad = ng.deriv(cost, weights)

    conv_val, grad_val = executor([conv, grad])()

    # 75 = 3 * 5 * 5 ones summed per output; 4608 = 24 * 24 * 8 output positions
    assert np.allclose(conv_val, np.zeros_like(conv_val) + 75.)
    assert np.allclose(grad_val, np.zeros_like(grad_val) + 4608.)
Пример #3
0
def reorder_spatial_axes(tensor):
    """
    Put ``tensor`` into channel, depth, height, width, batch axis order.

    Assumes we are getting a C, H, N, or C, H, W, N, or C, D, H, W, N tensor.
    Missing axes are filled in with length-1 axes:

      * no channel axis: a length-1 axis named 'C' is added;
      * one spatial axis: a length-1 width axis is added after it;
      * two spatial axes (including the case above): a length-1 depth axis is
        added before them.

    Raises:
        ValueError: if the tensor has no spatial axes, more than 3 spatial
            axes, or no batch axis.
    """
    spatial_axes = tensor.axes.spatial_axes()
    batch_axes = tensor.axes.batch_axes()

    if len(spatial_axes) == 0 or len(spatial_axes) > 3:
        # note the trailing space: the two literals are concatenated
        raise ValueError(
            'spatial ops can only operate on tensors with 1, 2, or 3 spatial axes. '
            'Found {}'.format(spatial_axes))

    if not batch_axes:
        raise ValueError('spatial ops require a batch axis')

    if not tensor.axes.channel_axis():
        c = ng.make_axis(length=1, name='C')
        tensor = ng.expand_dims(tensor, c, 0)
    channel_axes = ng.make_axes(tensor.axes.channel_axis())

    if len(spatial_axes) == 1:
        w = ng.make_axis(length=1, name=_WIDTH)
        tensor = ng.expand_dims(tensor, w, 0)
        # W goes after the existing (height) axis
        spatial_axes = spatial_axes + ng.make_axes([w])

    if len(spatial_axes) == 2:
        d = ng.make_axis(length=1, name=_DEPTH)
        tensor = ng.expand_dims(tensor, d, 0)
        # D goes before the existing spatial axes
        spatial_axes = ng.make_axes([d]) + spatial_axes

    new_axes = channel_axes + spatial_axes + batch_axes
    return ng.axes_with_order(tensor, new_axes)
Пример #4
0
    def train_outputs(self, in_obj):
        """
        Build the lookup-table output for ``in_obj``.

        The axis with short name 'time' gets the ``ar.time`` role and is marked
        recurrent. The input is reordered by ``self.role_order`` and flattened.
        A weight variable ``W`` on axes (V, F) is created from ``self.lut_init``
        and indexed with ``ng.lookuptable``.

        Arguments:
            in_obj (Tensor): object that provides the lookup indices

        Returns:
            The unflattened lookup result, reordered to ``self.o_axes``.
        """
        time_axis = in_obj.axes.find_by_short_name('time')[0]
        time_axis.add_role(ar.time)
        time_axis.is_recurrent = True

        flat = ng.flatten(ng.axes_with_role_order(in_obj, self.role_order))
        in_axes = flat.axes

        self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
        self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')

        self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])
        self.lut_o_axes = in_axes + ng.make_axes([self.lut_f_axis])
        self.o_axes = ng.make_axes([self.lut_f_axis]) + in_axes[0].axes

        initial = self.lut_init(self.w_axes, self.lut_v_axis, self.pad_idx)
        self.W = ng.variable(axes=self.w_axes, initial_value=initial).named('W')

        lut_result = ng.lookuptable(self.W, flat, self.lut_o_axes,
                                    update=self.update, pad_idx=self.pad_idx)
        return ng.axes_with_order(ng.unflatten(lut_result), self.o_axes)
Пример #5
0
    def __call__(self, in_obj, **kwargs):
        """
        Look up rows of the table ``W`` for the indices in ``in_obj``.

        The input is reordered to (recurrent, batch) and flattened before the
        lookup. ``W`` (axes V, F) is created only on the first call, when
        ``self.initialized`` is false.

        Arguments:
            in_obj (Tensor): object that provides the lookup indices

        Returns:
            The lookup result, unflattened, with the shadow vocabulary axis
            mapped back via ``self.axes_map``, in ``self.o_axes`` order.
        """
        LABELS = {"weight": "weight", "bias": "bias"}

        order = ng.make_axes([in_obj.axes.recurrent_axis(),
                              in_obj.axes.batch_axis()])
        flat = ng.flatten(ng.axes_with_order(in_obj, order))
        in_axes = flat.axes

        # label lut_v_axis as shadow axis for initializers ... once #1158 is
        # in, shadow axis will do more than just determine fan in/out for
        # initializers.
        self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
        self.axes_map = shadow_axes_map([self.lut_v_axis])
        self.lut_v_axis = list(self.axes_map.values())[0]

        self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')

        self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])
        self.lut_o_axes = in_axes | ng.make_axes([self.lut_f_axis])
        self.o_axes = ng.make_axes([self.lut_f_axis]) | in_axes[0].axes

        if not self.initialized:
            init_val = self.lut_init(self.w_axes, self.lut_v_axis,
                                     self.pad_idx)
            self.W = ng.variable(
                axes=self.w_axes,
                initial_value=init_val,
                metadata={"label": LABELS["weight"]},
            ).named('LutW')

        lut_result = ng.lookuptable(self.W, flat, self.lut_o_axes,
                                    update=self.update, pad_idx=self.pad_idx)
        remapped = ng.map_roles(ng.unflatten(lut_result), self.axes_map)
        return ng.axes_with_order(remapped, self.o_axes)
Пример #6
0
    def run_inference(self, out_axes, init_states, **kwargs):
        """
        Unroll the network one time step at a time, feeding each step's output
        back in as the next step's input.

        Arguments:
            out_axes (Axes): axes of the returned tensor. Its recurrent axis
                sets the number of steps, and its batch axis is used for the
                per-step tensors. The first remaining axis is used as the
                feature axis.
            init_states (list): initial state for each recurrent layer. For
                'LSTM' cells each entry is paired with a zero cell state.
            **kwargs: not used.

        Returns:
            The per-step outputs stacked along the time axis and reordered to
            ``out_axes``.
        """
        if self.celltype == 'LSTM':
            # LSTM layers take (hidden, cell) pairs; the cell state starts at 0
            init_states = [(state, ng.constant(0., state.axes))
                           for state in init_states]

        one_time_axis = ng.make_axis(1, name="REC")
        time_axis = out_axes.recurrent_axis()
        batch_axis = out_axes.batch_axis()
        feature_axis = (out_axes - [time_axis, batch_axis])[0]

        # seed the feedback loop with an all-zero "previous output"
        outputs = [ng.constant(0., [batch_axis, one_time_axis, feature_axis])]
        hidden_states = init_states

        for _ in range(time_axis.length):
            in_obj = outputs[-1]

            # Compute the next states of the recurrent layers (every layer
            # except the last one is called with an init_state).
            next_hidden_states = []
            for i, layer in enumerate(self.layers[:-1]):
                init_state = hidden_states[i] if i < len(hidden_states) else None

                if self.celltype == 'LSTM':
                    h, c = layer(in_obj,
                                 init_state=init_state,
                                 return_cell_state=True)
                    in_obj = h

                    # drop the length-1 time axis before storing the state
                    h = ng.slice_along_axis(h, one_time_axis, 0)
                    c = ng.slice_along_axis(c, one_time_axis, 0)
                    next_hidden_states.append((h, c))
                else:
                    h = layer(in_obj, init_state=init_state)
                    in_obj = h

                    # Non-LSTM cells carry only a hidden state; there is no
                    # cell state `c` in this branch (appending one raised
                    # NameError).
                    next_hidden_states.append(
                        ng.slice_along_axis(h, one_time_axis, 0))
            hidden_states = next_hidden_states

            # the final layer produces this step's output
            in_obj = self.layers[-1](in_obj)
            outputs.append(in_obj)

        # drop the zero seed and the length-1 time axis, then stack along time
        outputs = [
            ng.slice_along_axis(output, one_time_axis, 0)
            for output in outputs[1:]
        ]
        outputs = ng.stack(outputs, time_axis)
        return ng.axes_with_order(outputs, out_axes)
Пример #7
0
    def test_dimshuffle_bprop(self, x, A, B):
        """
        Check the derivative of reordering the 2-D tensor ``x`` to (B, A)
        against numeric differentiation.
        """
        # random input values from rng.uniform(-1, 1, ...)
        x_value = rng.uniform(-1, 1, x.axes)

        shuffled = ng.axes_with_order(x, [B, A])
        check_derivative(shuffled, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
Пример #8
0
    def NHWC2NCHW(self, c2_op, inputs):
        """
        Convert a single NHWC input to NCHW layout.

        Arguments:
            c2_op: the op being converted (not used by this method).
            inputs: list holding exactly one ngraph op, X.

        Returns:
            X with its axes reordered from (N, H, W, C) to (N, C, H, W), with
            ``order`` set to 'NCHW'.

        Raises:
            ValueError: if X has an ``order`` attribute other than 'NHWC'.
        """
        assert 1 == len(inputs)
        X = inputs[0]

        # inputs without an explicit ``order`` attribute are assumed to be NHWC
        order = X.order if hasattr(X, 'order') else 'NHWC'
        if 'NHWC' != order:
            raise ValueError("NHWC2NCHW accepts only NHWC input format.")

        # positions 0, 1, 2, 3 are N, H, W, C; emit N, C, H, W
        Y = ng.axes_with_order(
            X, axes=ng.make_axes([X.axes[0], X.axes[3], X.axes[1], X.axes[2]]))
        Y.order = 'NCHW'
        return Y
Пример #9
0
    def __call__(self, *args, **kwargs):
        """
        Run the parent model's ``__call__``.

        Only when ``self.to_ctc`` is exactly True, the output is reordered to
        (recurrent, batch, feature...) axes and wrapped in ``ng.ContiguousOp``
        to prepare it for warp-ctc.
        """
        output = super(Deepspeech, self).__call__(*args, **kwargs)

        # TODO: This should be handled in a graph pass
        if self.to_ctc is not True:
            return output

        # prepare activations/gradients for warp-ctc
        axes = output.axes
        warp_axes = ng.make_axes([axes.recurrent_axis(), axes.batch_axis()])
        warp_axes = warp_axes | axes.feature_axes()
        return ng.ContiguousOp(ng.axes_with_order(output, warp_axes))
Пример #10
0
    def SparseSoftmaxCrossEntropyWithLogits(self, tf_node, inputs):
        """
        Computes sparse softmax cross entropy. The inputs `logits` are
        unscaled log probabilities. Each entry of `labels` is a class index
        into the second axis of `logits`; the labels are one-hot encoded here.
        Reference: https://goo.gl/z5T2my

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Raises:
            NotImplementedError: unless logits is 2-D (N, Y), labels is 1-D
                (N,), and their N lengths match.

        Inputs to tf_node:
            logits, labels, name
        """

        # logits: (N1, Y1), labels: (N2,)
        logits, labels = inputs

        # Check input dimension. Explicit checks instead of asserts, which
        # vanish under ``python -O``, and no bare ``except``.
        if (len(logits.axes) != 2 or len(labels.axes) != 1
                or logits.axes[0].length != labels.axes[0].length):
            raise NotImplementedError("logits' shape must be (N, Y), "
                                      "labels' shape must be (N,), "
                                      "other shapes not supported yet.")
        # get axis
        axis_y = logits.axes[1]

        # labels_one_hot: (Y2, N2)
        labels_one_hot = ng.one_hot(labels, axis=axis_y)

        # predicts: (N1, Y1)
        predicts = ng.softmax(logits, normalization_axes=axis_y)

        # dim-shuffle / cast to (Y1, N1)
        predicts_axes = ng.make_axes(
            [axis for axis in reversed(predicts.axes)])
        predicts = ng.axes_with_order(predicts, axes=predicts_axes)
        labels_one_hot = ng.cast_axes(labels_one_hot, predicts_axes)

        # cross_entropy: (N1,)
        cross_entropy = ng.cross_entropy_multi(predicts,
                                               labels_one_hot,
                                               out_axes=(logits.axes[0], ))

        return cross_entropy
Пример #11
0
    def test_dimshuffle_fprop(self, x, A, B):
        """
        Reorder the 2-D tensor ``x`` to (B, A) and check that the forward pass
        equals the numpy transpose of the input.
        """
        # dim-shuffle with the graph
        shuffled = ng.axes_with_order(x, [B, A])
        assert shuffled.axes == ng.make_axes([B, A])

        # random input values
        x_value = rng.uniform(-1, 1, x.axes)
        expected = x_value.T

        with executor(shuffled, x) as run:
            actual = run(x_value)

        ng.testing.assert_allclose(actual, expected)
Пример #12
0
def test_dimshuffle_bprop(transformer_factory):
    """
    Check the derivative of reordering a (2, 3) placeholder to (3, 2)
    against numeric differentiation.
    """
    A, B = ng.make_axis(2), ng.make_axis(3)
    x = ng.placeholder(ng.make_axes([A, B]))

    # random input values
    x_value = rng.uniform(-1, 1, x.axes)

    swapped = ng.axes_with_order(x, [B, A])
    check_derivative(swapped, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
Пример #13
0
def test_idempotent_axes_c():
    """
    Test axes transformations with autodiff, case c: broadcast, slice, cast
    and dim-shuffle.

    A (3, 1) variable of ones goes through two identical branches that are
    added together, so the summed cost is 6 and the gradient is all 2s.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        result_axes = [ng.make_axis(length=axis.length) for axis in axes]

        w = ng.variable(axes, initial_value=np.ones((3, 1)))

        # broadcast both branches onto the same axes
        left, right = ng.broadcast(w, axes), ng.broadcast(w, axes)

        # full slice along every axis; only the right branch is then cast
        full = [slice(None, None, None)] * 2
        left = ng.tensor_slice(left, full)
        right = ng.cast_axes(ng.tensor_slice(right, full), axes)

        # add, then cast / dim-shuffle onto fresh axes
        total = ng.add(left, right)
        total = ng.axes_with_order(ng.cast_axes(total, result_axes),
                                   result_axes)

        cost = ng.sum(total, reduction_axes=total.axes)
        grad = ng.deriv(cost, w)

        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)

        cost_val = cost_comp()
        grad_val = grad_comp()
        assert cost_val == 6.0
        assert np.array_equal(grad_val, np.full((3, 1), 2.))
Пример #14
0
def test_shuffled_deriv(transformer_factory):
    """
    Regression check: this graph gets the axes of a delta in a
    generate_add_delta in a different order than the value being updated.
    The test passes if the derivative computation runs.
    """
    C, T, R, S = (ng.make_axis(length=n) for n in (3, 1, 5, 5))

    rsc_axes = [R, S, C]
    v = ng.variable([ng.make_axis(axis.length) for axis in rsc_axes])

    # v cast to (R, S, C), T inserted at 0, reordered to (C, T, R, S)
    ctrs = ng.axes_with_order(
        ng.expand_dims(ng.cast_axes(v, rsc_axes), T, 0), axes=[C, T, R, S])
    cost = ng.sum(ctrs, out_axes=None)
    grad = ng.deriv(cost, v)

    with ExecutorFactory() as ex:
        ex.executor(grad)()
Пример #15
0
def reorder_pos_axes(x, prefix=POS_AXIS_PREFIX):
    """
    Reorder x's axes to descending positional axes. E.g.
    x's axes: [POS_1, POS_2, POS_0] => [POS_2, POS_1, POS_0]

    Args:
        x: ngraph op whose axis names are ``prefix`` followed by an integer
            position.
        prefix: name prefix that marks a positional axis.

    Returns:
        x reordered to descending positional axes, or x itself if its axes
        are already in that order.

    Raises:
        ValueError: if an axis name lacks ``prefix``, or the positions are not
            exactly 0 .. len(x.axes) - 1.
    """
    # get axes names
    axes_names = [axis.name for axis in x.axes]
    num_axes = len(axes_names)

    # check axes names are valid
    for name in axes_names:
        if not name.startswith(prefix):
            raise ValueError("axis {} is not a valid positional axes, "
                             "to be valid, must have prefix {}".format(
                                 name, prefix))

    axes_positions = [int(name[len(prefix):]) for name in axes_names]
    if sorted(axes_positions) != list(range(num_axes)):
        raise ValueError("axes positions {} must be continuous integers "
                         "starting from 0".format(axes_positions))

    # positions in the target (descending) order, as a real list: comparing a
    # list with a bare ``reversed(...)`` iterator is always False
    descending = list(range(num_axes - 1, -1, -1))

    # special case, x is already in a good order
    if axes_positions == descending:
        return x

    # position -> axis length
    map_pos_length = dict(zip(axes_positions, x.axes.lengths))

    # axis lengths after reordering
    new_shapes = [map_pos_length[pos] for pos in descending]

    return ng.axes_with_order(x, axes=make_pos_axes(new_shapes))
Пример #16
0
def test_shuffled_deriv():
    """
    Regression check: this graph gets the axes of a delta in a
    generate_add_delta in a different order than the value being updated.
    The test passes if the derivative computation runs.
    """
    ax = ng.make_name_scope("ax")
    ax.C = ng.make_axis(3)
    ax.T = ng.make_axis(1)
    ax.R = ng.make_axis(5)
    ax.S = ng.make_axis(5)

    axes = [ax.R, ax.S, ax.C]
    v = ng.variable([ng.make_axis(axis.length) for axis in axes])
    rsc = ng.cast_axes(v, axes)
    trsc = ng.expand_dims(rsc, ax.T, 0)
    ctrs = ng.axes_with_order(trsc, axes=[ax.C, ax.T, ax.R, ax.S])
    cost = ng.sum(ctrs, out_axes=None)
    grad = ng.deriv(cost, v)

    # Use the factory as a context manager (as the other tests do) so it is
    # cleaned up even if the computation fails.
    with ExecutorFactory() as ex:
        d_fun = ex.executor(grad)
        d_fun()
Пример #17
0
def test_dimshuffle_fprop(transformer_factory):
    """
    Reorder a (2, 3) placeholder to (3, 2) and check that the forward pass
    equals the numpy transpose of the input.
    """
    A, B = ng.make_axis(2), ng.make_axis(3)
    x = ng.placeholder(ng.make_axes([A, B]))

    # dim-shuffle with the graph
    shuffled = ng.axes_with_order(x, [B, A])
    assert shuffled.axes == ng.make_axes([B, A])

    # random input values
    x_value = rng.uniform(-1, 1, x.axes)
    expected = x_value.T

    actual = executor(shuffled, x)(x_value)
    np.testing.assert_allclose(actual, expected)
Пример #18
0
def sparse_softmax_cross_entropy_with_logits(labels=None,
                                             logits=None,
                                             name=None):
    """
    Computes sparse softmax cross entropy. The inputs `logits` are unscaled
    log probabilities. Each entry of `labels` is a class index into the Y
    axis of `logits`; the labels are one-hot encoded here.

    Args:
        labels: of axis (N,) for (POS_0,)
        logits: of axis (N, Y) for (POS_1, POS_0)
        name: name of the ngraph op

    Returns:
        Per-example cross entropy on axis (N,), cast to positional axes and
        named ``name``.

    Raises:
        ValueError: if ``labels`` or ``logits`` is not given.
        NotImplementedError: unless logits is 2-D, labels is 1-D, and their
            N lengths match.
    """
    if labels is None or logits is None:
        raise ValueError("Both labels and logits must be provided.")

    # Check input dimension
    #         (    N,     Y),         (    N)
    # logits: (pos_1, pos_0), labels: (pos_0)
    # Explicit checks instead of asserts (which vanish under ``python -O``)
    # wrapped in a bare ``except``.
    if (len(logits.axes) != 2 or len(labels.axes) != 1
            or logits.axes[0].length != labels.axes[0].length):
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")
    # get axis
    axis_n, axis_y = logits.axes

    # convert labels to one-hot labels
    labels = ng.cast_axes(labels, ng.make_axes(axis_n))
    labels = ng.one_hot(labels, axis=axis_y)
    labels = ng.axes_with_order(labels, axes=logits.axes)

    # predicts: (N, Y)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # cross_entropy: (N)
    res = ng.cross_entropy_multi(predicts, labels, out_axes=(axis_n, ))
    return cast_to_pos_axes(res).named(name)
Пример #19
0
 def test_fail_on_missing_and_extra_axis(self, x, A, C):
     """
     ng.axes_with_order must raise ValueError for the requested order [A, C]
     (per the test name: an order with one of x's axes missing and an extra
     axis added).
     """
     requested = [A, C]
     with pytest.raises(ValueError):
         ng.axes_with_order(x, requested)
Пример #20
0
    def MaxPool(self, tf_node, inputs):
        """
        Performs the max pooling on the input.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Raises:
            NotImplementedError: if data_format is not NHWC, if ksize or
                strides is not 1 on the batch (N) and channel (C) axes, or if
                the required padding is asymmetric.
            ValueError: if ksize or strides does not have exactly 4 entries.

        Inputs to tf_node:
            input

        TODO: assume default tensorflow layout NHWC, RSCK,
              need to support NCHW as well
              need to clean up / merge with conv2d

        Axes:
                      Tensorflow          Ngraph
            in       (N, H, W, C)     (C, D, H, W, N)
            out      (N, P, Q, K)     (K, M, P, Q, N)

        Notes on output shape:
            https://www.tensorflow.org/api_docs/python/nn.html#convolution
        """
        image = inputs[0]

        # TODO: currently NHWC only. Raise (as Conv2D does) rather than
        # assert, so the check is not stripped under ``python -O``.
        if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
            raise NotImplementedError("Only supports NHWC import for now.")

        # new axes
        C, D, H, W, K, M, P, Q = [ng.make_axis() for _ in range(8)]
        N = ng.make_axis(name='N')
        D.length, M.length = 1, 1  # only supports 2D conv for now

        # tf's input axes; set_shape is what gives N, H, W, C their lengths
        ax_i_tf = ng.make_axes([N, H, W, C])
        ax_i_tf.set_shape(image.axes.lengths)

        # ksize params
        tf_ksize = [int(s) for s in list(tf_node.attr['ksize'].list.i)]
        if len(tf_ksize) != 4:
            raise ValueError("Length of ksize must be 4.")
        if tf_ksize[0] != 1:
            raise NotImplementedError('Ksize on batch axis (N) must be 1.')
        if tf_ksize[3] != 1:
            raise NotImplementedError('Ksize on channel axis (C) must be 1. '
                                      'Cross map pooling to be implemented.')
        R_length, S_length = tf_ksize[1:3]
        T_length = J_length = 1

        # strides params
        tf_strides = [int(s) for s in list(tf_node.attr['strides'].list.i)]
        if len(tf_strides) != 4:
            raise ValueError("Length of strides must be 4.")
        if tf_strides[0] != 1:
            raise NotImplementedError('Strides on batch axis (N) must be 1.')
        if tf_strides[3] != 1:
            raise NotImplementedError('Strides on channel axis (C) must be 1.')
        str_h, str_w = tf_strides[1], tf_strides[2]

        # padding params; the pooling window is passed to the shared helpers
        # as an (R, S, C, C) "filter", i.e. with as many outputs as inputs
        padding = tf_node.attr['padding'].s.decode("ascii")
        pad_t, pad_b, pad_l, pad_r = common_conv2d_pool_padding(
            image.axes.lengths, (R_length, S_length, C.length, C.length),
            tf_strides, padding)
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))
        # pooling params
        params = dict(op='max',
                      pad_d=0, pad_h=pad_t, pad_w=pad_l, pad_c=0,
                      str_d=1, str_h=str_h, str_w=str_w, str_c=1,
                      J=J_length, T=T_length, R=R_length, S=S_length)

        # tf's output axes; set_shape gives N, P, Q, K their lengths
        ax_o_tf = ng.make_axes([N, P, Q, K])
        ax_o_tf.set_shape(common_conv2d_pool_output_shape(image.axes.lengths,
                                                          (R_length, S_length,
                                                           C.length, C.length),
                                                          tf_strides, padding))

        # ngraph's i, f, o axes
        ax_i = ng.make_axes([C, D, H, W, N])
        ax_o = ng.make_axes([K, M, P, Q, N])

        # image NHWC -> CDHWN
        image = ng.cast_axes(image, ng.make_axes([N, H, W, C]))
        image = ng.expand_dims(image, D, 1)  # NHWC -> NDHWC
        image = ng.axes_with_order(image, ax_i)  # NDHWC -> CDHWN

        # pooling
        output = ng.pooling(params, image, axes=ax_o)

        # output KMPQN -> NPQK
        # KMPQN -> NMPQK
        output = ng.axes_with_order(output, ng.make_axes(
            [N, M, P, Q, K]))
        # NMPQK -> NPQK
        output = ng.tensor_slice(output, [slice(None), 0, slice(None),
                                          slice(None), slice(None)])

        return output
Пример #21
0
    def Conv2D(self, tf_node, inputs):
        """
        Computes a 2-D convolution given 4D input and filter tensors.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Raises:
            NotImplementedError: if data_format is not NHWC, if strides is not
                1 on the batch (N) and channel (C) axes, or if the required
                padding is asymmetric.
            ValueError: if the image and filter C lengths differ, or strides
                does not have exactly 4 entries.

        Inputs to tf_node:
            input, filter

        TODO: assume default tensorflow layout NHWC, RSCK,
              need to support NCHW as well
              need to clean up / merge with maxpool

        Axes:
                      Tensorflow          Ngraph
            in       (N, H, W, C)     (C, D, H, W, N)
            filter   (R, S, C, K)     (C, T, R, S, K)
            out      (N, P, Q, K)     (K, M, P, Q, N)

        Notes on output shape:
            https://www.tensorflow.org/api_docs/python/nn.html#convolution
        """
        image, weight = inputs

        # TODO: currently NHWC only
        if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
            raise NotImplementedError("Only supports NHWC import for now.")

        # check in_C == f_C
        if image.axes.lengths[3] != weight.axes.lengths[2]:
            raise ValueError("Image's C dimension (%s) must be equal to "
                             "filter's C dimension (%s)."
                             % (image.axes.lengths[3], weight.axes.lengths[2]))

        # strides params
        tf_strides = [int(s) for s in list(tf_node.attr['strides'].list.i)]
        if len(tf_strides) != 4:
            raise ValueError("Length of strides must be 4.")
        if tf_strides[0] != 1:
            raise NotImplementedError('Strides on batch axis (N) must be 1.')
        if tf_strides[3] != 1:
            raise NotImplementedError('Strides on channel axis (C) must be 1.')
        str_h, str_w = tf_strides[1], tf_strides[2]

        # padding params
        padding = tf_node.attr['padding'].s.decode("ascii")
        pad_t, pad_b, pad_l, pad_r = common_conv2d_pool_padding(
            image.axes.lengths, weight.axes.lengths, tf_strides, padding)
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # conv params
        params = dict(pad_d=0, pad_h=pad_t, pad_w=pad_l,
                      str_d=1, str_h=str_h, str_w=str_w,
                      dil_d=1, dil_h=1, dil_w=1)

        # new axes
        C, D, H, W, T, R, S, K, M, P, Q = [ng.make_axis() for _ in range(11)]
        N = ng.make_axis(name='N')
        D.length, T.length, M.length = 1, 1, 1  # only supports 2D conv for now

        # tf's i, f, o axes; these are only used for set_shape, which gives
        # the freshly made axes their lengths
        ax_i_tf = ng.make_axes([N, H, W, C])
        ax_f_tf = ng.make_axes([R, S, C, K])
        ax_o_tf = ng.make_axes([N, P, Q, K])
        ax_i_tf.set_shape(image.axes.lengths)
        ax_f_tf.set_shape(weight.axes.lengths)
        ax_o_tf.set_shape(common_conv2d_pool_output_shape(image.axes.lengths,
                                                          weight.axes.lengths,
                                                          tf_strides, padding))

        # ngraph's i, f, o axes
        ax_i = ng.make_axes([C, D, H, W, N])
        ax_f = ng.make_axes([C, T, R, S, K])
        ax_o = ng.make_axes([K, M, P, Q, N])

        # image NHWC -> CDHWN
        image = ng.cast_axes(image, ng.make_axes([N, H, W, C]))
        image = ng.expand_dims(image, D, 1)  # NHWC -> NDHWC
        image = ng.axes_with_order(image, ax_i)  # NDHWC -> CDHWN

        # weights RSCK -> CTRSK
        weight = ng.cast_axes(weight, ng.make_axes([R, S, C, K]))
        weight = ng.expand_dims(weight, T, 0)  # RSCK -> TRSCK
        weight = ng.axes_with_order(weight, ax_f)  # TRSCK -> CTRSK

        # convolution
        output = ng.convolution(params, image, weight, axes=ax_o)

        # output KMPQN -> NPQK
        # KMPQN -> NMPQK
        output = ng.axes_with_order(output, ng.make_axes([N, M, P, Q, K]))
        # NMPQK -> NPQK
        output = ng.tensor_slice(output, [slice(None), 0, slice(None),
                                          slice(None), slice(None)])

        return output
    def __call__(self,
                 in_obj,
                 channel_axes="C",
                 spatial_axes=("D", "H", "W"),
                 **kwargs):
        """
        Arguments:
            in_obj (Op): Input op
            channel_axes (str): name of the expected channel axis type - defaults to "C"
            spatial_axes (tuple): names of expected depth, height and width axis types - defaults
                                  to "D", "H", and "W"
        """
        if isinstance(spatial_axes, dict):
            spatial_axes = tuple(
                spatial_axes.get(name, name) for name in ("D", "H", "W"))
        elif isinstance(spatial_axes, tuple):
            if len(spatial_axes) < 3:
                raise ValueError(
                    "spatial_axes must have length 3 (e.g. ('D', 'H', 'W'))")
            spatial_axes = tuple(
                name if name else default
                for name, default in zip(spatial_axes, ("D", "H", "W")))

        orig_axes = in_obj.axes
        in_obj = reorder_spatial_axes(in_obj, channel_axes, spatial_axes)
        channel_axes = in_obj.axes.get_by_names(channel_axes)
        spatial_axes = in_obj.axes.get_by_names(*spatial_axes)

        filter_axes = self._filter_axes(channel_axes, spatial_axes)

        # mark 'K' as a shadow axis for the initializers.
        axes_map = shadow_axes_map(filter_axes.find_by_name('K'))
        filter_axes = ng.make_axes([
            axis if axis.name != 'K' else list(axes_map.keys())[0]
            for axis in filter_axes
        ])

        if not self.initialized:
            if not self.weight_norm:
                self.W = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("W")
            else:
                self.v = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("v")
                out_axes = ng.make_axes(
                    [filter_axes.get_by_names("K__NG_SHADOW")])
                v_norm = ng.mean(ng.square(self.v), out_axes=out_axes)
                self.g = ng.variable(axes=out_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("g")
                self.W = self.g * self.v * ng.reciprocal(
                    ng.sqrt(v_norm + 1e-3))
        else:
            if filter_axes != self.W.axes:
                raise ValueError(
                    ("{layer_name} layer has already been initialized with an "
                     "input object which has resulted in filter axes: "
                     "{existing_filter_axes}. This new input object has axes: "
                     "{input_axes}, which implies the need for filter axes: "
                     "{new_filter_axes} which are different than the existing "
                     "filter axes.").format(
                         layer_name=self.name,
                         existing_filter_axes=self.W.axes,
                         input_axes=in_obj.axes,
                         new_filter_axes=filter_axes,
                     ))

        output = ng.map_roles(
            self._conv_op(in_obj, channel_axes, spatial_axes), axes_map)
        # Reorder the output to match the input order
        output_axis_order = ng.make_axes(
            [output.axes.find_by_name(ax.name)[0] for ax in orig_axes])
        # Remove introduced axes. If their length is > 1, then perhaps they should be kept
        slices = [
            0 if (ax not in orig_axes) and ax.length == 1 else slice(None)
            for ax in output.axes
        ]
        output = ng.tensor_slice(output, slices)
        # New axes with length > 1 may have been introduced. Add them to the end.
        output_axis_order = output_axis_order | output.axes
        return ng.axes_with_order(output, output_axis_order)
Пример #23
0
    def __call__(self,
                 H_pr,
                 h_ip,
                 states,
                 output=None,
                 reset_cells=True,
                 input_data=None):
        """
        Run one step of the attention-augmented (match) LSTM cell.

        Arguments:
        ----------
        H_pr: encoding of the question; attention is taken over its
              recurrent axis
        h_ip: slice of the paragraph encoding for the current time step
        states: LSTM state dict with 'h' and 'c' entries, or None to build a
                fresh one via ``self.initialize_states``
        output: hidden state of the previous step, or None for the first word
                of a paragraph (a zero state is used then)
        reset_cells: forwarded to ``self.initialize_states`` when ``states``
                     is None
        input_data: the ArrayIterator object for training data; its
                    'question_len' entry supplies the question-length mask

        Returns:
        --------
        tuple: (new hidden state cast to axes [F, N], updated states dict)
        """
        ques_time_axis = H_pr.axes.recurrent_axis()
        ones_dummy = ng.constant(const=1, axes=[self.dummy_axis])

        # The very first word of a paragraph has no previous hidden output,
        # so a zero state stands in for it.
        prev_h = (ng.constant(axes=[self.F, self.N], const=0)
                  if output is None
                  else ng.cast_axes(output, [self.F, self.N]))

        ques_axes = [self.hidden_rows, self.hidden_cols_ques, self.N]

        # Question term: W_q applied to the question encoding, relabelled
        # onto the hidden attention axes.
        ques_term = ng.cast_axes(ng.dot(self.W_q, H_pr), ques_axes)

        # Step term: W_p h_p + W_r h_r + b_p, with a dummy axis inserted at
        # position 1 so it can be spread over the question positions below.
        step_term = ng.dot(self.W_p, h_ip) + ng.dot(self.W_r, prev_h) + self.b_p
        step_term = ng.ExpandDims(step_term, self.dummy_axis, 1)

        # Question-length masks: one laid out like the hidden attention
        # tensor, one laid out like the attention logits.
        tiled_mask = ng.axes_with_order(
            ng.cast_axes(ng.dot(self.e_q2, input_data['question_len']),
                         [self.hidden_rows, self.N, self.hidden_cols_ques]),
            ques_axes)
        logit_mask = ng.axes_with_order(
            ng.cast_axes(ng.dot(ones_dummy, input_data['question_len']),
                         [self.N, self.hidden_cols_ques]),
            [self.hidden_cols_ques, self.N])

        spread_step = ng.axes_with_order(ng.dot(step_term, self.e_q), ques_axes)
        G_i = ng.tanh(ques_term + ng.multiply(tiled_mask, spread_step))

        # Attention weights over the question. NOTE(review): assumes the mask
        # holds 0/1 values so that log() drives padded positions to -inf.
        attn = ng.softmax(ng.dot(self.w_lr, G_i) + ng.log(logit_mask))
        attn_tiled = ng.cast_axes(
            ng.dot(self.e_q2, ng.ExpandDims(attn, self.dummy_axis, 0)),
            [self.F, ques_time_axis, self.N])

        # LSTM input mixes the current paragraph slice with the
        # attention-weighted sum of the question encoding.
        attended_ques = ng.sum(ng.multiply(H_pr, attn_tiled), ques_time_axis)
        lstm_in = ng.dot(self.ZX, h_ip) + ng.dot(self.ZY, attended_ques)

        # Standard LSTM cell update.
        batch_axis = lstm_in.axes.batch_axis()
        if self.out_axes is None:
            self.out_axes = self.feature_axes + batch_axis
        if states is None:
            states = self.initialize_states(batch_axis, reset_cells=reset_cells)
        assert self.out_axes == states['h'].axes

        prev_hidden = states['h']
        for name in self._gate_names:
            pre_activation = self.i2h[name](lstm_in) + self.h2h[name](prev_hidden)
            self.gate_output[name] = ng.cast_role(
                self.gate_transform[name](pre_activation), self.out_axes)

        gates = self.gate_output
        cell = states['c'] * gates['f'] + gates['i'] * gates['g']
        states['c'] = cell
        hidden = ng.cast_role(gates['o'] * self.activation(cell), self.out_axes)
        states['h'] = hidden
        # Hand back the step output on [F, N] together with the cell state.
        return ng.cast_axes(hidden, axes=[self.F, self.N]), states
Пример #24
0
    def MaxPool(self, tf_node, inputs):
        """
        Performs max pooling on the input.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node, laid out as
            NHWC like the input.

        Raises:
            NotImplementedError: if the data format is not NHWC, if ksize or
                strides is not 1 on the batch or channel axis, or if the
                computed padding is asymmetric.
            ValueError: if ksize or strides does not have exactly 4 entries.

        Inputs to tf_node:
            input

        TODO: assume default tensorflow layout NHWC, RSCK,
              need to support NCHW as well
              need to clean up / merge with conv2d

        Notes on output shape:
            https://www.tensorflow.org/api_docs/python/nn.html#convolution
        """
        image = inputs[0]

        # Validate every node attribute before any axes are built so that an
        # unsupported configuration fails fast. An explicit raise (instead of
        # ``assert``) keeps the layout check active under ``python -O``.
        # TODO: currently NHWC only
        data_format = tf_node.attr['data_format'].s.decode("ascii")
        if data_format != "NHWC":
            raise NotImplementedError(
                "Unsupported data_format %s, only NHWC is supported." % data_format)

        # ksize params
        tf_ksize = [int(s) for s in tf_node.attr['ksize'].list.i]
        if len(tf_ksize) != 4:
            raise ValueError("Length of ksize must be 4.")
        if tf_ksize[0] != 1:
            raise NotImplementedError('Ksize on batch axis (N) must be 1.')
        if tf_ksize[3] != 1:
            raise NotImplementedError('Ksize on channel axis (C) must be 1. '
                                      'Cross map pooling to be implemented.')
        R, S = tf_ksize[1:3]
        T = J = 1  # window size 1 along the depth (D) and channel (C) axes

        # strides params
        tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
        if len(tf_strides) != 4:
            raise ValueError("Length of strides must be 4.")
        if tf_strides[0] != 1:
            raise NotImplementedError('Strides on batch axis (N) must be 1.')
        if tf_strides[3] != 1:
            raise NotImplementedError('Strides on channel axis (C) must be 1.')
        str_h, str_w = tf_strides[1], tf_strides[2]

        padding = tf_node.attr['padding'].s.decode("ascii")

        # set axes shape; the input is 2-D, so a unit-length depth axis is added
        ax_N = ng.make_axis(batch=True)
        ax_C = ng.make_axis(roles=[ar.Channel])
        ax_D = ng.make_axis(roles=[ar.Depth])
        ax_H = ng.make_axis(roles=[ar.Height])
        ax_W = ng.make_axis(roles=[ar.Width])
        ng.make_axes([ax_N, ax_H, ax_W, ax_C]).set_shape(image.axes.lengths)
        ax_D.length = 1

        # padding params; the pooling window is described to the padding
        # helper as an (R, S, C, C) filter shape
        pad_t, pad_b, pad_l, pad_r = tf_conv2d_pool_padding(
            image.axes.lengths, (R, S, ax_C.length, ax_C.length), tf_strides,
            padding)
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # pooling params
        params = dict(op='max',
                      pad_d=0,
                      pad_h=pad_t,
                      pad_w=pad_l,
                      pad_c=0,
                      str_d=1,
                      str_h=str_h,
                      str_w=str_w,
                      str_c=1,
                      J=J,
                      T=T,
                      R=R,
                      S=S)

        # i, f, o axes
        ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
        ax_o = ng.make_axes([
            spatial_axis(ax_i, J, params['pad_c'], params['str_c'],
                         ar.Channel),
            spatial_axis(ax_i, T, params['pad_d'], params['str_d'], ar.Depth),
            spatial_axis(ax_i, R, params['pad_h'], params['str_h'], ar.Height),
            spatial_axis(ax_i, S, params['pad_w'], params['str_w'], ar.Width),
            ax_N
        ])

        # broadcast input / filter axes
        image = ng.cast_axes(image, ng.make_axes([ax_N, ax_H, ax_W, ax_C]))
        image = ng.expand_dims(image, ax_D, 1)  # NHWC -> NDHWC
        image = ng.axes_with_order(image, axes=ax_i)  # NDHWC -> CDHWN

        # pooling
        output = ng.pooling(params, image, axes=ax_o)

        # cast back to NHWC
        oC, oD, oH, oW, oN = output.axes
        output = ng.broadcast(output, ng.make_axes([oN, oD, oH, oW, oC]))

        # slice away the oD
        out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
        output = ng.Slice(output, out_slicing)

        return output
Пример #25
0
    def Conv2D(self, tf_node, inputs):
        """
        Computes a 2-D convolution given 4D input and filter tensors.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node, laid out as
            NHWC like the input.

        Raises:
            NotImplementedError: if the data format is not NHWC, if strides is
                not 1 on the batch or channel axis, or if the computed padding
                is asymmetric.
            ValueError: if strides does not have exactly 4 entries.

        Inputs to tf_node:
            input, filter

        TODO: assume default tensorflow layout NHWC, RSCK,
              need to support NCHW as well
              need to clean up / merge with maxpool

        Notes on output shape:
            https://www.tensorflow.org/api_docs/python/nn.html#convolution
        """
        image, weight = inputs

        # Validate every node attribute before any axes are built so that an
        # unsupported configuration fails fast. An explicit raise (instead of
        # ``assert``) keeps the layout check active under ``python -O``.
        # TODO: currently NHWC only
        data_format = tf_node.attr['data_format'].s.decode("ascii")
        if data_format != "NHWC":
            raise NotImplementedError(
                "Unsupported data_format %s, only NHWC is supported." % data_format)

        # strides params
        tf_strides = [int(s) for s in tf_node.attr['strides'].list.i]
        if len(tf_strides) != 4:
            raise ValueError("Length of strides must be 4.")
        if tf_strides[0] != 1:
            raise NotImplementedError('Strides on batch axis (N) must be 1.')
        if tf_strides[3] != 1:
            raise NotImplementedError('Strides on channel axis (C) must be 1.')
        str_h, str_w = tf_strides[1], tf_strides[2]

        padding = tf_node.attr['padding'].s.decode("ascii")

        # set axes shape; input and filter are 2-D, so unit-length depth axes
        # (D for the image, T for the filter) are added
        ax_N = ng.make_axis(batch=True)
        ax_C = ng.make_axis(roles=[ar.Channel])
        ax_D = ng.make_axis(roles=[ar.Depth])
        ax_H = ng.make_axis(roles=[ar.Height])
        ax_W = ng.make_axis(roles=[ar.Width])

        ax_T = ng.make_axis(roles=[ar.Depth])
        ax_R = ng.make_axis(roles=[ar.Height])
        ax_S = ng.make_axis(roles=[ar.Width])
        ax_K = ng.make_axis(roles=[ar.Channelout])

        ng.make_axes([ax_N, ax_H, ax_W, ax_C]).set_shape(image.axes.lengths)
        ng.make_axes([ax_R, ax_S, ax_C, ax_K]).set_shape(weight.axes.lengths)
        ax_D.length = 1
        ax_T.length = 1

        # padding params
        pad_t, pad_b, pad_l, pad_r = tf_conv2d_pool_padding(
            image.axes.lengths, weight.axes.lengths, tf_strides, padding)
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # conv params
        params = dict(pad_d=0,
                      pad_h=pad_t,
                      pad_w=pad_l,
                      str_d=1,
                      str_h=str_h,
                      str_w=str_w)

        # i, f, o axes
        ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
        ax_f = ng.make_axes([ax_C, ax_T, ax_R, ax_S, ax_K])
        ax_o = ng.make_axes([
            ng.make_axis(ax_K.length, name='C', roles=[ar.Channel]),
            spatial_axis(ax_i, ax_f, params['pad_d'], params['str_d'],
                         ar.Depth),
            spatial_axis(ax_i, ax_f, params['pad_h'], params['str_h'],
                         ar.Height),
            spatial_axis(ax_i, ax_f, params['pad_w'], params['str_w'],
                         ar.Width), ax_N
        ])

        # broadcast input / filter axes
        image = ng.cast_axes(image, ng.make_axes([ax_N, ax_H, ax_W, ax_C]))
        image = ng.expand_dims(image, ax_D, 1)  # NHWC -> NDHWC
        image = ng.axes_with_order(image, axes=ax_i)  # NDHWC -> CDHWN
        weight = ng.cast_axes(weight, ng.make_axes([ax_R, ax_S, ax_C, ax_K]))
        weight = ng.expand_dims(weight, ax_T, 0)  # RSCK -> TRSCK
        weight = ng.axes_with_order(weight, axes=ax_f)  # TRSCK -> CTRSK

        # convolution
        output = ng.convolution(params, image, weight, axes=ax_o)

        # cast back to NHWC
        oC, oD, oH, oW, oN = output.axes
        output = ng.broadcast(output, ng.make_axes([oN, oD, oH, oW, oC]))

        # slice away the oD
        out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
        output = ng.Slice(output, out_slicing)

        return output
Пример #26
0
 def test_fail_on_axis_reuse(self, x, A, B):
     """Requesting the order [A, B, B], which repeats B, must raise ValueError."""
     repeated_order = [A, B, B]
     with pytest.raises(ValueError):
         ng.axes_with_order(x, repeated_order)
Пример #27
0
# Set up dropout layers.
# The keep value is read from inputs['dropout_val'] at index 0 along N.
# NOTE(review): assumes N is the batch axis and every batch entry carries the
# same keep value -- confirm.
dropout_val = ng.slice_along_axis(inputs['dropout_val'], N, 0)
dropout_1 = Dropout_Modified(keep=dropout_val)
dropout_2 = Dropout_Modified(keep=dropout_val)
# dropout_3 / dropout_4 use the same keep value, clamped so it is never
# below 0.8.
drop_pointer = ng.maximum(dropout_val, ng.constant(const=0.8, axes=[]))
dropout_3 = Dropout_Modified(keep=drop_pointer)
dropout_4 = Dropout_Modified(keep=drop_pointer)

# Constants required for masking: all-ones tensors over
# (F / ax.Y / F_embed, dummy_axis). Below they are dotted with masks that
# carry dummy_axis (presumably of length 1), which copies each mask along
# F / F_embed.
const_LSTM = ng.constant(axes=[F, dummy_axis], const=1)
const_loss = ng.constant(axes=[ax.Y, dummy_axis], const=1)
const_LSTM_embed = ng.constant(axes=[F_embed, dummy_axis], const=1)

# Create masks: reorder each length input to
# (dummy_axis, <its axis at position 2>, N).
# NOTE(review): assumes inputs['para_len'] / inputs['question_len'] consist of
# dummy_axis, N and one length axis at position 2 -- confirm.
reorder_para_mask = ng.axes_with_order(
    inputs['para_len'], axes=[
        dummy_axis, inputs['para_len'].axes[2], N])

reorder_ques_mask = ng.axes_with_order(
    inputs['question_len'], axes=[
        dummy_axis, inputs['question_len'].axes[2], N])

# Masks for question and para after encoding layer (the question mask's
# length axis is relabelled to REC first).
mask_para = ng.dot(const_LSTM, reorder_para_mask)
mask_question = ng.dot(const_LSTM,
                       ng.cast_axes(reorder_ques_mask, [dummy_axis, REC, N]))

# Masks for question and para after embedding/LookupTable layer
mask_para_embed = ng.dot(const_LSTM_embed, reorder_para_mask)
mask_question_embed = ng.dot(
    const_LSTM_embed, ng.cast_axes(
Пример #28
0
 def test_fail_on_missing(self, x, B):
     """The order [B, B] (per the test name, one of x's axes is left out) must raise ValueError."""
     incomplete_order = [B, B]
     with pytest.raises(ValueError):
         ng.axes_with_order(x, incomplete_order)
Пример #29
0
    def __call__(self,
                 H_concat,
                 states=None,
                 output=None,
                 reset_cells=True,
                 input_data=None):
        """
        Run the answer-pointer LSTM for two steps over the paragraph.

        Arguments:
        ----------
        H_concat: Concatenated forward and reverse unrolled outputs of the
                 `MatchLSTMCell_withAttention` cell
        states: previous LSTM state (dict with 'h' and 'c'), or None to create
                one through ``self.initialize_states``
        output: hidden state from previous timestep, or None to start from a
                zero state
        reset_cells: argument to reset a cell, forwarded to
                     ``self.initialize_states``
        input_data: the ArrayIterator object for training data
                    (contains information of length of each sentence); its
                    'para_len' entry masks the attention logits

        Returns:
        --------
        list: the two attention distributions ``b_k`` over the paragraph's
              recurrent axis, one per step, in step order.
        """
        rec_axis_pr = H_concat.axes.recurrent_axis()
        const_one = ng.constant(const=1, axes=[self.dummy_axis])

        # V_answer applied to the paragraph encoding. It does not depend on the
        # step, so it is built once and shared by both steps.
        sum_1 = ng.dot(
            self.V_answer,
            ng.cast_axes(H_concat,
                         [self.lstm_feature_new, rec_axis_pr, self.N]))
        sum_1 = ng.cast_axes(
            sum_1, [self.hidden_rows, self.hidden_cols_para, self.N])

        # Paragraph-length mask for the logits (also step independent). The
        # log maps masked positions to -inf, so the softmax gives them 0.
        mask_loss_new = ng.log(ng.dot(const_one, input_data['para_len']))
        mask_loss_new = ng.axes_with_order(
            ng.cast_axes(mask_loss_new, [self.N, self.hidden_cols_para]),
            [self.hidden_cols_para, self.N])

        b_k_lists = []
        for _ in range(2):
            if output is None:
                h_k_old = ng.constant(axes=[self.F, self.N], const=0)
            else:
                h_k_old = ng.cast_axes(output, [self.F, self.N])

            # W_a h_{k-1} with a dummy axis at position 1.
            # NOTE(review): no bias term b_a is added here -- confirm intent.
            int_sum = ng.ExpandDims(ng.dot(self.W_a, h_k_old),
                                    self.dummy_axis, 1)

            # Following notations from the paper
            # Compute Attention Vector
            F_i = ng.tanh(sum_1 + ng.axes_with_order(
                ng.dot(int_sum, self.e_q),
                [self.hidden_rows, self.hidden_cols_para, self.N]))

            # One softmax serves both as the returned attention output and as
            # the weights that form the LSTM input.
            b_k = ng.softmax(ng.dot(self.v_lr, F_i) + mask_loss_new)
            b_k_repeated = ng.cast_axes(
                ng.dot(self.e_q2, ng.ExpandDims(b_k, self.dummy_axis, 0)),
                [H_concat.axes[0], rec_axis_pr, self.N])

            inputs_lstm = ng.sum(ng.multiply(H_concat, b_k_repeated),
                                 rec_axis_pr)

            # LSTM Cell calculations
            batch_axis = inputs_lstm.axes.batch_axis()
            if self.out_axes is None:
                self.out_axes = self.feature_axes + batch_axis
            if states is None:
                states = self.initialize_states(batch_axis,
                                                reset_cells=reset_cells)
            assert self.out_axes == states['h'].axes

            for gate in self._gate_names:
                transform = self.gate_transform[gate]
                gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](
                    states['h'])
                self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                      self.out_axes)

            states['c'] = (states['c'] * self.gate_output['f'] +
                           self.gate_output['i'] * self.gate_output['g'])
            states['h'] = self.gate_output['o'] * self.activation(states['c'])
            states['h'] = ng.cast_role(states['h'], self.out_axes)

            # The next step attends using this step's hidden state.
            output = states['h']

            b_k_lists.append(b_k)

        return b_k_lists
Пример #30
0
    def Conv(self, c2_op, inputs):
        """
        Computes a 2-D convolution given 4D input and filter tensors, then
        adds a 1-D bias along the output channel axis.

        Arguments:
            c2_op: the caffe2 operator to convert; its ``arg`` list supplies
                ``order`` and ``stride``, and is also handed to
                ``_c2_padding`` for the padding settings.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the caffe2 node, in the same layout
            (NHWC or NCHW) as the input.

        Raises:
            ValueError: if ``order`` is missing or given more than once, if the
                bias is not 1-D or its length differs from the number of output
                feature maps, or if ``stride`` is not given exactly once.
            NotImplementedError: for a layout other than NHWC / NCHW, or when
                the computed padding is asymmetric.

        Inputs to c2_op:
            input, weights, bias

        Supports caffe2's layout NHWC and NCHW as well.
        """
        X, W, bias = inputs

        order_args = [val.s for val in c2_op.arg if val.name == "order"]
        if not order_args:
            raise ValueError("Missing order value in convolution")
        if len(order_args) > 1:
            raise ValueError("Multiple order values in convolution")
        order = order_args[0]
        # String arguments can arrive as bytes (on Python 3 a protobuf bytes
        # field is not equal to str), so normalize before comparing with and
        # indexing by the str layout names below.
        if isinstance(order, bytes):
            order = order.decode("ascii", "replace")

        if order not in ("NHWC", "NCHW"):
            raise NotImplementedError(
                "Unsupported order in convolution: {}".format(order))

        # set input axes shape
        ax_N = ng.make_axis(name='N')
        ax_C = ng.make_axis()
        ax_D = ng.make_axis(length=1)
        ax_H = ng.make_axis()
        ax_W = ng.make_axis()

        # set kernel axes shape
        ax_kernel_D = ng.make_axis(length=1)
        ax_kernel_H = ng.make_axis()
        ax_kernel_W = ng.make_axis()
        ax_kernel_ofm = ng.make_axis()

        # create placeholders for output axes
        oC = ng.make_axis(name='C')
        oD = ng.make_axis(name='D', length=1)
        oH = ng.make_axis(name='H')
        oW = ng.make_axis(name='W')

        axes_order = {
            'NCHW': {
                'X': [ax_N, ax_C, ax_H, ax_W],
                'W': [ax_kernel_ofm, ax_C, ax_kernel_H, ax_kernel_W]
            },
            'NHWC': {
                'X': [ax_N, ax_H, ax_W, ax_C],
                'W': [ax_kernel_ofm, ax_kernel_H, ax_kernel_W, ax_C]
            },
        }

        ng.make_axes(axes_order[order]['X']).set_shape(X.axes.lengths)
        ng.make_axes(axes_order[order]['W']).set_shape(W.axes.lengths)

        if 1 != len(bias.axes):
            raise ValueError("Bias's must be 1D.")
        if ax_kernel_ofm.length != bias.axes.lengths[0]:
            raise ValueError(
                "Bias's length must equal to number of output feature maps.")

        # strides params
        stride_size = [int(val.i) for val in c2_op.arg if val.name == "stride"]
        if len(stride_size) != 1:
            raise ValueError("Stride size must be scalar value")
        str_h = str_w = stride_size[0]

        # padding params
        pad_t, pad_b, pad_l, pad_r = \
            _c2_padding(c2_op,
                        in_NHWC=[ax_N.length, ax_H.length, ax_W.length, ax_C.length],
                        kernel_HWIO=[ax_kernel_H.length, ax_kernel_W.length,
                                     ax_C.length, ax_kernel_ofm.length],
                        stride_NHWC=[1, str_h, str_w, 1])

        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # conv params
        params = dict(pad_d=0,
                      pad_h=pad_t,
                      pad_w=pad_l,
                      str_d=1,
                      str_h=str_h,
                      str_w=str_w,
                      dil_d=1,
                      dil_h=1,
                      dil_w=1)

        # input, weight, output axes
        internal_ax_dict = {
            'X':
            ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N]),
            'W':
            ng.make_axes(
                [ax_C, ax_kernel_D, ax_kernel_H, ax_kernel_W, ax_kernel_ofm])
        }

        oC.length = ax_kernel_ofm.length
        oH.length = output_dim(ax_H.length, ax_kernel_H.length,
                               params['pad_h'], params['str_h'])
        oW.length = output_dim(ax_W.length, ax_kernel_W.length,
                               params['pad_w'], params['str_w'])
        internal_ax_dict['Y'] = ng.make_axes([oC, oD, oH, oW, ax_N])

        # broadcast input / filter axes
        # flow for NHWC order:                   |  flow for NCHW order:
        # input:                                 |  input:
        #   expand dims: NHWC -> NDHWC           |    expand dims: NCHW -> NDCHW
        #   reorder:     NDHWC -> CDHWN          |    reorder:     NDCHW -> CDHWN
        # weights:                               |  weights:
        #   expand dims: (ofm)HWC -> D(ofm)HWC   |    expand dims: (ofm)CHW -> D(ofm)CHW
        #   reorder:     D(ofm)HWC -> CDHW(ofm)  |    reorder:     D(ofm)CHW -> CDHW(ofm)

        X = ng.cast_axes(X, ng.make_axes(axes_order[order]['X']))
        X = ng.expand_dims(X, ax_D, 1)
        X = ng.axes_with_order(X, axes=internal_ax_dict['X'])
        W = ng.cast_axes(W, ng.make_axes(axes_order[order]['W']))
        W = ng.expand_dims(W, ax_kernel_D, 0)
        W = ng.axes_with_order(W, axes=internal_ax_dict['W'])

        # convolution
        Y = ng.convolution(params, X, W, axes=internal_ax_dict['Y'])

        # cast back to proper format
        Y = ng.broadcast(Y, ng.make_axes([ax_N, oD, oH, oW, oC])) if "NHWC" == order \
            else ng.broadcast(Y, ng.make_axes([ax_N, oD, oC, oH, oW]))  # NCHW

        # slice away the oD
        out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
        Y = ng.tensor_slice(Y, out_slicing)

        # Add the bias along the output channel axis, which sits at index 1
        # (NCHW) or index 3 (NHWC) once the depth axis has been sliced away.
        channel_axis = Y.axes[1 if 'NCHW' == order else 3]
        bias = ng.cast_axes(bias, axes=ng.make_axes([channel_axis]))
        return ng.Add(Y, bias)
Пример #31
0
    def Pool(self, c2_op, inputs):
        """
        Performs max or average pooling on the input.

        Arguments:
            c2_op: the caffe2 operator to convert; ``c2_op.type`` selects the
                pooling kind ('MaxPool' or 'AveragePool') and its ``arg`` list
                supplies ``kernel``, ``stride`` and optionally ``order``.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the c2_op node, laid out as NCHW.

        Raises:
            NotImplementedError: for an unsupported pooling type, an explicit
                ``order`` other than NCHW, or asymmetric padding.
            ValueError: if ``kernel`` or ``stride`` is not given exactly once.

        Inputs to c2_op:
            input
        """
        supported_pooling = {'MaxPool': 'max', 'AveragePool': 'avg'}
        if c2_op.type not in supported_pooling:
            raise NotImplementedError(
                "Unsupported pooling type: {}".format(c2_op.type))

        # The axis handling below is written for NCHW input; reject an
        # explicit different layout instead of silently misreading the axes.
        for val in c2_op.arg:
            if val.name == "order":
                order = val.s
                if isinstance(order, bytes):
                    order = order.decode("ascii", "replace")
                if order != "NCHW":
                    raise NotImplementedError(
                        "Unsupported order in pooling: {}".format(order))

        # spatial kernel size
        kernel_size = [int(val.i) for val in c2_op.arg if val.name == "kernel"]
        if len(kernel_size) != 1:
            raise ValueError("Kernel size must be scalar value")
        # kernel is square
        kernel_h = kernel_w = kernel_size[0]
        kernel_d = kernel_c = 1

        # strides params
        stride_size = [int(val.i) for val in c2_op.arg if val.name == "stride"]
        if len(stride_size) != 1:
            raise ValueError("Stride size must be scalar value")
        stride_h = stride_w = stride_size[0]

        image = inputs[0]

        # set input axes shape
        ax_N = ng.make_axis(name='N')
        ax_C = ng.make_axis()
        ax_D = ng.make_axis(length=1)
        ax_H = ng.make_axis()
        ax_W = ng.make_axis()
        ng.make_axes([ax_N, ax_C, ax_H, ax_W]).set_shape(image.axes.lengths)

        # create placeholders for output axes
        oC = ng.make_axis(name='C')
        oD = ng.make_axis(length=1, name='D')
        oH = ng.make_axis(name='H')
        oW = ng.make_axis(name='W')

        # padding params
        pad_t, pad_b, pad_l, pad_r = \
            _c2_padding(c2_op,
                        in_NHWC=[ax_N.length, ax_H.length, ax_W.length, ax_C.length],
                        kernel_HWIO=[kernel_h, kernel_w, ax_C.length, ax_C.length],
                        stride_NHWC=[1, stride_h, stride_w, 1])
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # pooling params
        params = dict(op=supported_pooling[c2_op.type],
                      pad_d=0,
                      pad_h=pad_t,
                      pad_w=pad_l,
                      pad_c=0,
                      str_d=1,
                      str_h=stride_h,
                      str_w=stride_w,
                      str_c=1,
                      J=kernel_c,
                      T=kernel_d,
                      R=kernel_h,
                      S=kernel_w)

        # i, o axes
        oC.length = output_dim(ax_C.length, kernel_c, params['pad_c'],
                               params['str_c'])
        oD.length = output_dim(ax_D.length, kernel_d, params['pad_d'],
                               params['str_d'])
        oH.length = output_dim(ax_H.length, kernel_h, params['pad_h'],
                               params['str_h'])
        oW.length = output_dim(ax_W.length, kernel_w, params['pad_w'],
                               params['str_w'])
        ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
        ax_o = ng.make_axes([oC, oD, oH, oW, ax_N])

        # broadcast input / filter axes
        image = ng.cast_axes(image, ng.make_axes([ax_N, ax_C, ax_H, ax_W]))
        image = ng.expand_dims(image, ax_D, 1)  # NCHW -> NDCHW
        image = ng.axes_with_order(image, axes=ax_i)  # NDCHW -> CDHWN

        # pooling
        output = ng.pooling(params, image, axes=ax_o)

        # cast back to NDCHW
        output = ng.broadcast(output, ng.make_axes([ax_N, oD, oC, oH, oW]))

        # slice away the oD
        out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
        output = ng.tensor_slice(output, out_slicing)

        return output