Example #1
def test_softmax(transformer_factory, input_tensor):
    """TODO."""
    p_x = input_tensor
    N = p_x.axes.batch_axes()[0]
    W = p_x.axes.sample_axes()[0]
    # set up some distributions
    u = rng.uniform(0, 1, p_x.axes)
    u = u / np.sum(u, 0).reshape(1, N.length)

    # Put them in pre-softmax form
    x = np.log(u) + rng.uniform(-5000, 5000, ng.make_axes([N])).reshape(
        1, N.length)

    with ExecutorFactory() as ex:
        smax_w_fun = ex.executor(
            ng.softmax(p_x, normalization_axes=ng.make_axes([W])), p_x)
        smax_fun = ex.executor(ng.softmax(p_x), p_x)

        s = smax_w_fun(x)
        ng.testing.assert_allclose(s, u, atol=1e-6, rtol=1e-3)

        x = rng.uniform(-5000, 5000, p_x.axes)
        u = np_softmax(x, 0)
        s = smax_w_fun(x)
        ng.testing.assert_allclose(s, u, atol=1e-6, rtol=1e-3)

        # Test with softmax_axis default
        s = smax_fun(x)
        ng.testing.assert_allclose(s, u, atol=1e-6, rtol=1e-3)
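
The `np_softmax` reference helper used throughout these tests is not shown in the snippets. A minimal NumPy sketch of such a reference (the actual helper in the test suite may differ in details) could look like this:

import numpy as np

def np_softmax_ref(x, axis):
    """Reference softmax: subtract the max for numerical stability,
    exponentiate, then normalize along the given axis."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

# The columns (axis 0) of the result sum to 1, matching the tests above,
# which normalize over the feature axis W.
vals = np.random.uniform(-5.0, 5.0, (3, 10))
assert np.allclose(np_softmax_ref(vals, 0).sum(axis=0), 1.0)
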
Example #2
def test_softmax(transformer_factory):
    """TODO."""
    N = ng.make_axis(name='N', batch=True)
    W = ng.make_axis(name='W')

    W.length = 128
    N.length = 10
    axes = ng.make_axes([W, N])

    # set up some distributions
    u = rng.uniform(0, 1, ng.make_axes([W, N]))
    u = u / np.sum(u, 0).reshape(1, N.length)

    # Put them in pre-softmax form
    x = np.log(u) + rng.uniform(-5000, 5000,
                                ng.make_axes([N])).reshape(1, N.length)
    p_x = ng.placeholder(axes)

    ex = ExecutorFactory()
    smax_w_fun = ex.executor(ng.softmax(p_x, softmax_axes=ng.make_axes([W])), p_x)
    smax_fun = ex.executor(ng.softmax(p_x), p_x)

    s = smax_w_fun(x)
    np.testing.assert_allclose(s, u, atol=1e-6, rtol=1e-3)

    x = rng.uniform(-5000, 5000, ng.make_axes([W, N]))
    u = np_softmax(x, 0)
    s = smax_w_fun(x)
    np.testing.assert_allclose(s, u, atol=1e-6, rtol=1e-3)

    # Test with softmax_axis default
    s = smax_fun(x)
    np.testing.assert_allclose(s, u, atol=1e-6, rtol=1e-3)
Example #3
    def __call__(self, inputs):
        query = ng.cast_axes(inputs['query'], [self.batch_axis, self.sentence_rec_axis])

        # Query embedding [batch, sentence_axis, F]
        q_emb = self.LUT_A(query)

        # Multiply by position encoding and sum
        u_0 = ng.sum(q_emb * self.pos_enc, reduction_axes=[self.sentence_rec_axis])  # [batch, F]

        # Start a list of the internal states of the model.
        # Will be appended to after each memory hop
        u = [u_0]

        for hopn in range(self.nhops):
            keys = ng.cast_axes(inputs['keys'], [self.batch_axis, self.memory_axis,
                                self.sentence_rec_axis])
            value = ng.cast_axes(inputs['values'], [self.batch_axis, self.memory_axis,
                                 self.val_len_axis])

            # Embed keys
            m_emb_A = self.LUT_A(keys)
            m_A = ng.sum(m_emb_A * self.pos_enc,
                         reduction_axes=[self.sentence_rec_axis])  # [batch, memory_axis, F]

            # Compute scalar similarity between internal state and each memory
            # Equivalent to dot product between u[-1] and each memory in m_A
            dotted = ng.sum(u[-1] * m_A, reduction_axes=[self.embedding_axis])

            probs = ng.softmax(dotted, self.memory_axis)  # [batch, memory_axis]

            # Embed values with same embedding as keys, or new LUTs
            if self.use_v_luts:
                m_emb_C = self.LUTs_C[hopn](value)
            else:
                m_emb_C = self.LUT_A(value)

            m_C = ng.sum(m_emb_C * self.pos_enc, reduction_axes=[self.sentence_rec_axis])

            # Compute weighted sum of output embeddings
            o_k = ng.sum(probs * m_C, reduction_axes=[self.memory_axis])  # [batch, F]

            u_k = u[-1] + o_k  # [batch, F]

            # Add new internal state
            u.append(u_k)

        # Compute predicted answer from product of final internal state and final LUT weight matrix
        if self.use_v_luts:
            a_logits = ng.dot(self.LUTs_C[-1].W, u[-1])  # [batch, V]
        else:
            a_logits = ng.dot(self.LUT_A.W, u[-1])  # [batch, V]
        # rename V to vocab_axis to match answer
        a_logits = ng.cast_axes(a_logits, [self.vocab_axis, self.batch_axis])
        a_pred = ng.softmax(a_logits, self.vocab_axis)

        return a_pred, a_logits
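
For reference, the attention arithmetic inside each hop (dot-product similarity, softmax over memories, weighted sum of value embeddings) can be sketched in plain NumPy; the shapes and names below are illustrative, not part of the model code:

import numpy as np

def np_softmax(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, memories, F = 4, 6, 8
u_prev = np.random.randn(batch, F)             # internal state u[-1]
m_A = np.random.randn(batch, memories, F)      # embedded keys
m_C = np.random.randn(batch, memories, F)      # embedded values

dotted = np.sum(u_prev[:, None, :] * m_A, axis=-1)    # [batch, memories]
probs = np_softmax(dotted, axis=-1)                    # attention weights
o_k = np.sum(probs[:, :, None] * m_C, axis=1)          # [batch, F]
u_k = u_prev + o_k                                     # next internal state
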
Example #4
def Softmax(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Compute softmax normalized values for each layer in the batch of the given input."""
    input_ = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)
    if axis == -1:  # Use last dimension
        axis = len(input_.shape) - 1
    return ng.softmax(input_, range(axis, len(input_.shape)))
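
Passing `range(axis, len(input_.shape))` normalizes jointly over every dimension from `axis` onward, which mirrors the ONNX convention of coercing the input to 2-D and applying softmax over the flattened trailing block. A rough NumPy equivalent (illustrative only):

import numpy as np

def onnx_softmax_ref(x, axis=1):
    """Flatten dims [axis:] into one block, softmax over it, restore the shape."""
    shape = x.shape
    flat = x.reshape(int(np.prod(shape[:axis])), -1)
    flat = flat - flat.max(axis=1, keepdims=True)
    e = np.exp(flat)
    return (e / e.sum(axis=1, keepdims=True)).reshape(shape)

x = np.random.randn(2, 3, 4)
assert np.allclose(onnx_softmax_ref(x, axis=1).reshape(2, -1).sum(axis=1), 1.0)
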
Example #5
def create_loss_and_learner(model,
                            labels,
                            learning_rate,
                            momentum_coef=0.0,
                            wdecay=0.0,
                            nesterov=False,
                            gradient_clip_norm=None,
                            gradient_clip_value=None):
    """
    Auxiliary function to create loss function (cross entropy and softmax)
    and trainer using stochastic gradient descent with momentum.

    Arguments:
        model - imported model
        labels - placeholder for one-hot labels array
        learning_rate - learning rate for trainer
        momentum_coef - coefficient of momentum (default 0.0)
        wdecay - amount of weight decay (default 0.0)
        nesterov - use Nesterov accelerated gradient (default False)
        gradient_clip_norm - target gradient norm (default None)
        gradient_clip_value - value to element-wise clip gradients (default None)

    Returns:
        Loss function (mean for batch)
    """
    if model.axes.lengths != labels.axes.lengths:
        labels = ng.Transpose(labels)
    assert model.axes.lengths == labels.axes.lengths
    model = ng.cast_axes(model, axes=labels.axes)

    loss = ng.cross_entropy_multi(ng.softmax(model), labels)
    optimizer = GradientDescentMomentum(learning_rate, momentum_coef, wdecay,
                                        gradient_clip_norm,
                                        gradient_clip_value, nesterov)
    return ng.sequential([optimizer(loss), ng.mean(loss, out_axes=())])
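
For clarity, here is a plain NumPy sketch of the quantity this graph minimizes (softmax followed by multiclass cross entropy, averaged over the batch), assuming a [classes, batch] layout; the function name is illustrative:

import numpy as np

def softmax_xent_mean(logits, onehot_labels):
    """Mean over the batch of -sum(labels * log(softmax(logits)))."""
    z = logits - logits.max(axis=0, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=0, keepdims=True))
    per_example = -(onehot_labels * log_probs).sum(axis=0)
    return per_example.mean()

logits = np.random.randn(10, 32)                      # [classes, batch]
labels = np.eye(10)[np.random.randint(0, 10, 32)].T   # one-hot, [classes, batch]
print(softmax_xent_mean(logits, labels))
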
Example #6
def test_cross_entropy_binary(transformer_factory):
    """TODO."""
    N = ng.make_axis(name='N')
    W = ng.make_axis(name='W')

    delta = .001
    W.length = 20
    N.length = 128
    axes = ng.make_axes([W, N])
    p_u = ng.placeholder(axes)
    u = rng.uniform(-3.0, 3.0, p_u.axes)
    p_v = ng.placeholder(axes)
    v = rng.uniform(-3.0, 3.0, p_u.axes)

    y = ng.sigmoid(p_u)
    t = ng.softmax(p_v)
    val_u = ng.cross_entropy_binary_inner(y, t)

    ex = ExecutorFactory()
    dval_u_num_fun = ex.numeric_derivative(val_u, p_u, delta, p_v)
    dval_u_graph_fun = ex.derivative(val_u, p_u, p_v)

    dval_u_num = dval_u_num_fun(u, v)
    dval_u_graph = dval_u_graph_fun(u, v)
    np.testing.assert_allclose(dval_u_graph, dval_u_num, atol=1e-2, rtol=1e-2)
Example #7
def test_exit_condition(transformer_factory):
    bsz = 16
    class_num = 10

    # Limit the maximum absolute value of tensor elements to 7.9.
    #
    # np.random.randn is used to fill the tensors with random values. It can
    # return any value, but values above 5 are highly improbable and appear
    # very rarely. A limit of 7.9 almost never modifies the tested tensor, yet
    # it prevents occasional random failures when the test runs in a
    # continuous-integration environment. The limit is the approximate upper
    # bound of the range [4, 8); numbers in this range can be represented by
    # flexpoint values with the same dec.
    # Why not 15.9, the approximate limit of the [8, 16) range? Values above 8
    # are highly improbable, and when they do appear they reduce the accuracy
    # of every other number in the tensor and can cause random failures; most
    # values drawn from a normal distribution are close to 0.

    is_flex = is_flex_factory(transformer_factory)
    clip_val = 7.9 if is_flex else 0

    N, Y = ng.make_axis(bsz), ng.make_axis(class_num)
    y_val = rng.randn_abs_clip(ng.make_axes([N, Y]), clip_max=clip_val)
    y = ng.constant(y_val, ng.make_axes([N, Y]))

    likelihood = ng.log(ng.softmax(y, normalization_axes=y.axes[1]))

    with ExecutorFactory() as ex:
        comp = ex.executor(likelihood)

        val1 = comp()
        val2 = comp()
        ng.testing.assert_allclose(val1, val2, atol=0, rtol=0)
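
The `rng.randn_abs_clip` helper is not shown in these snippets. Under the assumption that it draws absolute values of standard-normal samples and clips them at `clip_max` (with 0 meaning no clipping, as the non-flex branch above suggests), a hypothetical sketch would be:

import numpy as np

def randn_abs_clip(shape, clip_max=0):
    # Assumed behaviour, not the library implementation: |randn| values,
    # optionally clipped at clip_max; clip_max == 0 disables clipping.
    vals = np.abs(np.random.randn(*shape))
    if clip_max:
        vals = np.minimum(vals, clip_max)
    return vals
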
Example #8
def test_dynamic_attributes_softmax():
    axis = 2
    data = ng.parameter([1, 2, 3, 4], np.float32, "data_in")
    node = ng.softmax(data, axis)

    assert node.get_axis() == axis
    node.set_axis(3)
    assert node.get_axis() == 3
Example #9
def test_softmax2(input_tensor):
    p_x = input_tensor
    x = rng.uniform(0, 1, p_x.axes)

    compare_f_at_x(ng.softmax(p_x),
                   p_x,
                   lambda x: np_softmax(x, 0),
                   x,
                   rtol=1e-5)
Example #10
    def CrossEntropyWithSoftmax(self, cntk_op, inputs):
        """
        Computes the softmax cross entropy between inputs[0] and inputs[1].

        Arguments:
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op.
        """
        cast_0, cast_1 = self.cast_axes_for_compound_op(inputs)

        if isinstance(cast_0, ng.AssignableTensorOp):
            cast_1 = ng.softmax(cast_1)
        else:
            cast_0 = ng.softmax(cast_0)

        return ng.cross_entropy_multi(cast_0, cast_1).named(cntk_op.uid)
Example #11
def test_softmax_rec(transformer_factory, recurrent_input_tensor):
    p_x = recurrent_input_tensor
    x = rng.uniform(0, 1, p_x.axes)

    compare_f_at_x(ng.softmax(p_x),
                   p_x,
                   lambda x: np_softmax(x, 0),
                   x,
                   rtol=1e-5)
Example #12
    def __call__(self, x):
        """
        Returns the Softmax value.

        Arguments:
            x (Tensor or optree): Input value

        Returns:
            Tensor or optree: Output activation
        """
        return ng.softmax(x)
Example #13
def test_softmax2(transformer_factory):
    N = ng.make_axis(name='N', batch=True)
    W = ng.make_axis(name='W')

    W.length = 3
    N.length = 10
    axes = ng.make_axes([W, N])

    x = rng.uniform(0, 1, axes)
    p_x = ng.placeholder(axes)

    compare_f_at_x(ng.softmax(p_x), p_x, lambda x: np_softmax(x, 0), x, rtol=1e-5)
Example #14
def test_softmax_deriv(transformer_factory):
    N = ng.make_axis(name='N', batch=True)
    W = ng.make_axis(name='W')

    W.length = 3
    N.length = 10
    axes = ng.make_axes([W, N])

    x = rng.uniform(0, 1, axes)
    p_x = ng.placeholder(axes)

    check_derivative(ng.softmax(p_x), p_x, 0.001, x, atol=1e-2, rtol=1e-2)
Example #15
    def _softmax_op(self, cntk_op, inputs):
        """
        Returns softmax of inputs[0].

        Arguments:
            cntk_op: CNTK operation to be imported.
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op.
        """
        return ng.softmax(inputs[0]).named(cntk_op.uid)
Example #16
def test_cross_entropy_rec(transformer_factory, recurrent_input_tensor):
    p_x = recurrent_input_tensor
    p_t = ng.placeholder(p_x.axes)

    cross_entropy_sm_x_t = ng.cross_entropy_multi(ng.softmax(p_x), p_t)

    x = rng.uniform(0, 1, p_x.axes)
    t = np_softmax(rng.uniform(0, 1, p_t.axes), 0)

    def f_np(x, t):
        return np_cross_entropy_multi(np_softmax(x, 0), t, axis=0)

    compare_f_at_x(cross_entropy_sm_x_t, [p_x, p_t], f_np, [x, t], rtol=1e-5)
Example #17
    def SparseSoftmaxCrossEntropyWithLogits(self, tf_node, inputs):
        """
        Computes sparse softmax cross entropy. The input `logits` holds
        unscaled log probabilities and `labels` holds integer class indices.
        Reference: https://goo.gl/z5T2my

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Inputs to tf_node:
            logits, labels, name
        """

        # logits: (N1, Y1), labels: (N2,)
        logits, labels = inputs

        # check input dimension
        try:
            assert len(logits.axes) == 2
            assert len(labels.axes) == 1
            assert logits.axes[0].length == labels.axes[0].length
        except AssertionError:
            raise NotImplementedError("logits' shape must be (N, Y), "
                                      "labels' shape must be (N,), "
                                      "other shapes not supported yet.")
        # get axis
        axis_y = logits.axes[1]

        # labels_one_hot: (Y2, N2)
        labels_one_hot = ng.one_hot(labels, axis=axis_y)

        # predicts: (N1, Y1)
        predicts = ng.softmax(logits, normalization_axes=axis_y)

        # dim-shuffle / cast to (Y1, N1)
        predicts_axes = ng.make_axes(
            [axis for axis in reversed(predicts.axes)])
        predicts = ng.axes_with_order(predicts, axes=predicts_axes)
        labels_one_hot = ng.cast_axes(labels_one_hot, predicts_axes)

        # cross_entropy: (N1,)
        cross_entropy = ng.cross_entropy_multi(predicts,
                                               labels_one_hot,
                                               out_axes=(logits.axes[0], ))

        return cross_entropy
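
A plain NumPy reference for this conversion (one-hot the integer labels, softmax over the class axis, then per-example cross entropy; names are illustrative) could be:

import numpy as np

def sparse_softmax_xent(logits, labels):
    """logits: [N, Y] unscaled scores; labels: [N] integer class ids.
    Returns per-example cross entropy of shape [N]."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

logits = np.random.randn(8, 5)
labels = np.random.randint(0, 5, 8)
print(sparse_softmax_xent(logits, labels))
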
Example #18
def test_cross_entropy_softmax_rec_deriv(transformer_factory, recurrent_input_tensor):
    p_x = recurrent_input_tensor
    p_t = ng.placeholder(p_x.axes)

    x = rng.uniform(0, 1, p_x.axes)
    t = np_softmax(rng.uniform(0, 1, p_t.axes), 0)

    check_derivative(
        ng.cross_entropy_multi(ng.softmax(p_x), p_t),
        p_x, 0.001, x,
        parameters=[p_t],
        parameter_values=[t],
        atol=1e-2, rtol=1e-2
    )
Example #19
    def Softmax(self, cntk_op, inputs):
        """
        Returns softmax of inputs[0].

        Arguments:
            cntk_op: CNTK operation to be imported.
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op.
        """
        assert len(inputs) == 1

        return ng.softmax(inputs[0])
Example #20
def test_cross_entropy_multi_axis_order(transformer_factory, input_tensor):
    """If y and t have different axis orders, it should give the same result"""
    y = input_tensor
    t1 = ng.placeholder(y.axes)

    # Reorder axes
    feature_axis, batch_axis = y.axes
    t2 = ng.placeholder(ng.make_axes([batch_axis, feature_axis]))

    # Set up numpy variables
    np_y = np.random.uniform(0, 1, y.axes.lengths)
    if feature_axis.length > batch_axis.length:
        np_t1 = np.eye(feature_axis.length)[:, :batch_axis.length]
    else:
        np_t1 = np.eye(batch_axis.length)[:feature_axis.length, :]
    np_t2 = np_t1.T

    with ExecutorFactory() as ex:
        f1 = ex.executor(ng.cross_entropy_multi(ng.softmax(y), t1), y, t1)
        f2 = ex.executor(ng.cross_entropy_multi(ng.softmax(y), t2), y, t2)

        out1 = f1(np_y, np_t1)
        out2 = f2(np_y, np_t2)
        ng.testing.assert_allclose(out1.ravel(), out2.ravel(), rtol=1e-5)
Example #21
def test_exit_condition(transformer_factory):
    bsz = 16
    class_num = 10

    N, Y = ng.make_axis(bsz), ng.make_axis(class_num)
    y_val = np.absolute(np.random.randn(bsz, class_num))
    y = ng.constant(y_val, ng.make_axes([N, Y]))

    likelihood = ng.log(ng.softmax(y, normalization_axes=y.axes[1]))

    with ExecutorFactory() as ex:
        comp = ex.executor(likelihood)

        val1 = comp()
        val2 = comp()
        np.testing.assert_allclose(val1, val2, atol=0, rtol=0)
Example #22
def LogSoftmax(
        onnx_node,
        ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Compute logarithm of softmax values for each layer in the batch of the given input.

    :param onnx_node: The ONNX node representing this operation.
    :param ng_inputs: The input tensors.
    :return: The tensor with applied LogSoftmax operation.
    """
    data = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)
    if axis < 0 or axis >= len(data.shape):
        raise ValueError('LogSoftmax node (%s): provided axis attribute is out of input tensor'
                         ' dimensions range.' % onnx_node.name)
    return ng.log(ng.softmax(data, range(axis, len(data.shape))))
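
A NumPy reference for log-softmax over the trailing dimensions (using the usual max-subtraction trick for numerical stability) might look like this; the helper name is illustrative:

import numpy as np

def log_softmax_ref(x, axis=1):
    """Log-softmax over dims [axis:], treated as one flattened block."""
    shape = x.shape
    flat = x.reshape(int(np.prod(shape[:axis])), -1)
    z = flat - flat.max(axis=1, keepdims=True)
    out = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return out.reshape(shape)

x = np.random.randn(2, 3, 4)
assert np.allclose(np.exp(log_softmax_ref(x)).reshape(2, -1).sum(axis=1), 1.0)
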
Example #23
    def Softmax(self, tf_node, inputs):
        """
        Computes softmax activations.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Inputs to tf_node:
            logits, name
        """
        # TODO: only tf.nn.softmax(logits, dim=-1) is supported for now; more cases should be added
        logits = inputs[0]
        return ng.softmax(logits, normalization_axes=logits.axes[1])
Example #24
def test_cross_entropy_softmax_large_input(input_tensor):
    p_x = input_tensor
    p_t = ng.placeholder(p_x.axes)

    cross_entropy_sm_x_t = ng.cross_entropy_multi(ng.softmax(p_x), p_t)

    x = np.eye(3)[np.random.choice(3, 8)].T * rng.uniform(-10, 10,
                                                          p_x.axes) * 25
    t = np.eye(3)[np.random.choice(3, 8)].T

    def f_np(x, t):
        return np_cross_entropy_multi(np_softmax(x, 0), t, axis=0)

    compare_f_at_x(cross_entropy_sm_x_t, [p_x, p_t],
                   f_np, [x, t],
                   atol=1e-7,
                   rtol=1e-4)
Example #25
def Softmax(onnx_node,
            ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Compute softmax normalized values for each layer in the batch of the given input.

    :param onnx_node: The ONNX node representing this operation.
    :param ng_inputs: The input tensors.
    :return: The tensor with applied Softmax operation.
    """
    data = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)
    # Negative axis values are interpreted as counting from the end.
    if axis < 0:
        axis = len(data.shape) + axis
    if axis < 0 or axis >= len(data.shape):
        raise ValueError('Softmax node (%s): provided axis attribute is out of input tensor'
                         ' dimensions range.' % onnx_node.name)
    return ng.softmax(data, range(axis, len(data.shape)))
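
For instance, with a 3-D input and axis=-2 the attribute resolves to axis 1, so normalization runs jointly over the last two dimensions; a quick check of the axis arithmetic:

shape = (2, 3, 4)
axis = -2
if axis < 0:
    axis = len(shape) + axis
assert axis == 1
assert list(range(axis, len(shape))) == [1, 2]  # dimensions normalized over
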
Example #26
def cross_entropy_with_softmax(model, labels):
    """
    Auxiliary function to add cross entropy and softmax (loss function)
    to imported model for training.

    Arguments:
        model - imported model
        labels - placeholder for one-hot labels array

    Returns:
        Loss function (mean for batch)
    """
    if model.axes.lengths != labels.axes.lengths:
        model = ng.Transpose(model)
    assert model.axes.lengths == labels.axes.lengths
    model = ng.cast_axes(model, axes=labels.axes)

    loss = ng.cross_entropy_multi(ng.softmax(model), labels)
    return ng.mean(loss, out_axes=())
Example #27
    def Softmax(self, c2_op, inputs):
        """
        Computes softmax: `exp(x)/sum(exp(x))`.

        Arguments:
            c2_op: NodeDef object, the caffe2 node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the caffe2 node.
        """
        assert 1 == len(inputs)
        # get input
        x = inputs[0]

        # normalization axes
        norm_axes = x.axes[1]

        return ng.softmax(x, normalization_axes=norm_axes).named(c2_op.name)
Example #28
def test_cross_entropy_softmax(transformer_factory):
    N = ng.make_axis(name='N', batch=True)
    W = ng.make_axis(name='W')

    W.length = 3
    N.length = 10
    axes = ng.make_axes([W, N])

    p_x = ng.placeholder(axes)
    p_t = ng.placeholder(axes)

    cross_entropy_sm_x_t = ng.cross_entropy_multi(ng.softmax(p_x), p_t)

    x = rng.uniform(0, 1, axes)
    t = np_softmax(rng.uniform(0, 1, axes), 0)

    def f_np(x, t):
        return np_cross_entropy_multi(np_softmax(x, 0), t, axis=0)

    compare_f_at_x(cross_entropy_sm_x_t, [p_x, p_t], f_np, [x, t], rtol=1e-5)
Example #29
    def CrossEntropyWithSoftmax(self, cntk_op, inputs):
        """
        Computes the softmax cross entropy between inputs[0] and inputs[1].

        Arguments:
            cntk_op: CNTK operation to be imported.
            inputs: List of inputs to this node.

        Returns:
            A ngraph Op.
        """
        cast_0, cast_1 = squeeze_axes(inputs)

        if cast_0.axes.lengths != cast_1.axes.lengths:
            cast_0 = ng.Transpose(cast_0)
        assert cast_0.axes.lengths == cast_1.axes.lengths

        cast_0 = ng.cast_axes(cast_0, axes=cast_1.axes)
        loss = ng.cross_entropy_multi(ng.softmax(cast_0), cast_1)

        return ng.mean(loss, out_axes=()).named(cntk_op.uid)
Example #30
def test_cross_entropy_softmax_deriv(transformer_factory):
    N = ng.make_axis(name='N', batch=True)
    W = ng.make_axis(name='W')

    W.length = 3
    N.length = 10
    axes = ng.make_axes([W, N])

    p_x = ng.placeholder(axes)
    p_t = ng.placeholder(axes)

    x = rng.uniform(0, 1, axes)
    t = np_softmax(rng.uniform(0, 1, axes), 0)

    check_derivative(
        ng.cross_entropy_multi(ng.softmax(p_x), p_t),
        p_x, 0.001, x,
        parameters=[p_t],
        parameter_values=[t],
        atol=1e-2, rtol=1e-2
    )
Example #31
def test_cross_entropy_binary(input_tensor):
    """TODO."""
    p_u = input_tensor
    p_v = ng.placeholder(p_u.axes)

    u = rng.uniform(-3.0, 3.0, p_u.axes)
    v = rng.uniform(-3.0, 3.0, p_u.axes)

    delta = .001

    y = ng.sigmoid(p_u)
    t = ng.softmax(p_v)
    val_u = ng.cross_entropy_binary_inner(y, t)

    with ExecutorFactory() as ex:
        dval_u_num_fun = ex.numeric_derivative(val_u, p_u, delta, p_v)
        dval_u_graph_fun = ex.derivative(val_u, p_u, p_v)

        dval_u_num = dval_u_num_fun(u, v)
        dval_u_graph = dval_u_graph_fun(u, v)
        ng.testing.assert_allclose(dval_u_graph, dval_u_num, atol=1e-2, rtol=1e-2)
Example #32
    def __call__(self, inputs):
        query = ng.cast_axes(
            inputs['user_utt'], [
                self.batch_axis, self.sentence_rec_axis])

        # Query embedding [batch, sentence_axis, F]
        q_emb = self.LUT_A(query)

        # Multiply by position encoding and sum
        u_0 = ng.sum(q_emb, reduction_axes=[self.sentence_rec_axis])

        # Start a list of the internal states of the model. Will be appended to
        # after each memory hop
        u = [u_0]

        for hopn in range(self.nhops):
            story = ng.cast_axes(
                inputs['memory'], [
                    self.batch_axis, self.memory_axis, self.sentence_rec_axis])

            # Re-use the query embedding matrix to embed the memory sentences
            # [batch, memory_axis, sentence_axis, F]
            m_emb_A = self.LUT_A(story)
            m_A = ng.sum(
                m_emb_A, reduction_axes=[
                    self.sentence_rec_axis])  # [batch, memory_axis, F]

            # Compute scalar similarity between internal state and each memory
            # Equivalent to dot product between u[-1] and each memory in m_A
            # [batch, memory_axis]
            dotted = ng.sum(u[-1] * m_A, reduction_axes=[self.embedding_axis])

            # [batch, memory_axis]
            probs = ng.softmax(dotted, self.memory_axis)

            # Renormalize probabilities according to non-empty memories
            probs_masked = probs * inputs['memory_mask']
            renorm_sum = ng.sum(
                probs_masked, reduction_axes=[
                    self.memory_axis]) + self.eps
            probs_renorm = (probs_masked + self.eps) / renorm_sum

            # Compute weighted sum of memory embeddings
            o_k = ng.sum(
                probs_renorm * m_A,
                reduction_axes=[
                    self.memory_axis])  # [batch, F]

            # Add the output back into the internal state and project
            u_k = ng.cast_axes(ng.dot(self.R_proj, o_k), [
                               self.embedding_axis, self.batch_axis]) + u[-1]  # [batch, F_proj]

            # Add new internal state
            u.append(u_k)

        if self.use_match_type:
            # [batch_axis, cand_axis, cand_rec_axis, F]
            self.cands_mat = inputs['cands_mat']

        # Embed all candidate responses using LUT_W
        # [<batch_axis>, cand_axis, cand_rec_axis, F]
        cand_emb_W = self.LUT_W(self.cands_mat)
        # No position encoding added yet
        cands_mat_emb = ng.sum(
            cand_emb_W, reduction_axes=[
                self.cand_rec_axis])  # [<batch_axis>, cand_axis, F]

        # Compute predicted answer from product of final internal state
        # and embedded candidate answers
        # a_logits = ng.dot(cands_mat_emb, u[-1]) # [batch, cand_axis]
        # [batch, cand_axis]
        a_logits = ng.sum(u[-1] * cands_mat_emb,
                          reduction_axes=[self.embedding_axis])

        # rename V to vocab_axis to match answer
        a_logits = ng.cast_axes(a_logits, [self.batch_axis, self.cand_axis])
        a_pred = ng.softmax(a_logits, self.cand_axis)

        return a_pred, probs_renorm
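
The masked renormalization step above keeps attention away from padded (empty) memory slots; the same arithmetic in plain NumPy (names illustrative):

import numpy as np

def renormalize(probs, mask, eps=1e-6):
    """Zero out probabilities of empty memory slots, then rescale so the
    remaining (slightly smoothed) weights sum to roughly 1 per example."""
    masked = probs * mask                            # [batch, memories]
    denom = masked.sum(axis=-1, keepdims=True) + eps
    return (masked + eps) / denom

probs = np.array([[0.2, 0.3, 0.5]])
mask = np.array([[1.0, 1.0, 0.0]])                   # last slot is padding
print(renormalize(probs, mask))                      # ~[[0.4, 0.6, 0.0]]
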