Example #1
    def __call__(self,
                 H_concat,
                 states=None,
                 output=None,
                 reset_cells=True,
                 input_data=None):
        """
        Arguments:
        ----------
        H_concat: Concatenated forward and reverse unrolled outputs of the
                 `MatchLSTMCell_withAttention` cell
        states: previous LSTM state
        output: hidden state from previous timestep
        reset_cells: flag to reset the LSTM cell state
        input_data: the ArrayIterator object for training data
                    (contains the length of each sentence)

        """

        rec_axis_pr = H_concat.axes.recurrent_axis()
        const_one = ng.constant(const=1, axes=[self.dummy_axis])

        b_k_lists = []
        # unroll for two steps: one pointer distribution for the answer start,
        # one for the answer end
        for i in range(2):
            if output is None:
                h_k_old = ng.constant(axes=[self.F, self.N], const=0)
            else:
                h_k_old = ng.cast_axes(output, [self.F, self.N])

            sum_1 = ng.dot(
                self.V_answer,
                ng.cast_axes(H_concat,
                             [self.lstm_feature_new, rec_axis_pr, self.N]))
            sum_1 = ng.cast_axes(
                sum_1, [self.hidden_rows, self.hidden_cols_para, self.N])

            # bias term self.b_a was disabled in the original code
            int_sum = ng.dot(self.W_a, h_k_old)
            int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

            # Following the notation of the paper:
            # compute the attention vector
            F_i_int = sum_1 + ng.axes_with_order(
                ng.dot(int_sum, self.e_q),
                [self.hidden_rows, self.hidden_cols_para, self.N])

            F_i = ng.tanh(F_i_int)  # Attention Vector

            b_k_sum1 = ng.dot(self.v_lr, F_i)
            # Mask positions beyond each paragraph's true length with -inf
            # (log of a 0/1 mask) so that softmax assigns them zero probability
            mask_loss_new = ng.log(ng.dot(const_one, input_data['para_len']))
            mask_loss_new = ng.axes_with_order(
                ng.cast_axes(mask_loss_new, [self.N, self.hidden_cols_para]),
                [self.hidden_cols_para, self.N])

            # add the mask to the logits before softmax
            b_k = ng.softmax(b_k_sum1 + mask_loss_new)
            b_k_repeated = ng.cast_axes(
                ng.dot(self.e_q2, ng.ExpandDims(b_k, self.dummy_axis, 0)),
                [H_concat.axes[0], rec_axis_pr, self.N])

            inputs_lstm = ng.sum(ng.multiply(H_concat, b_k_repeated),
                                 rec_axis_pr)

            # LSTM Cell calculations
            if self.out_axes is None:
                self.out_axes = (self.feature_axes +
                                 inputs_lstm.axes.batch_axis())
            if states is None:
                states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                                reset_cells=reset_cells)
            assert self.out_axes == states['h'].axes

            for gate in self._gate_names:
                transform = self.gate_transform[gate]
                gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](
                    states['h'])
                self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                      self.out_axes)

            states['c'] = (states['c'] * self.gate_output['f'] +
                           self.gate_output['i'] * self.gate_output['g'])
            states['h'] = self.gate_output['o'] * self.activation(states['c'])
            states['h'] = ng.cast_role(states['h'], self.out_axes)

            output = states['h']

            # collect the pointer distribution for this step
            b_k_lists.append(b_k)

        return b_k_lists
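
The log-mask trick above deserves a closer look: taking ng.log of a 0/1 length mask turns padded positions into -inf, so the softmax that follows assigns them exactly zero probability. A minimal NumPy sketch of the same idea (the logits and mask below are made-up illustration values):

import numpy as np

def masked_softmax(logits, mask):
    # mask is 1.0 for real tokens, 0.0 for padding;
    # log(0) = -inf drives the padded logits to zero probability
    with np.errstate(divide='ignore'):   # silence the log(0) warning
        masked = logits + np.log(mask)
    e = np.exp(masked - masked.max())    # shift for numerical stability
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, 0.3])
mask = np.array([1.0, 1.0, 0.0, 0.0])    # last two positions are padding
print(masked_softmax(logits, mask))      # padded entries get probability 0.0
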
Example #2
import numpy as np

import ngraph as ng
from ngraph.testing import ExecutorFactory

# Tolerances for the gradient check below; the snippet references these
# module-level constants, so representative values are assumed here.
delta = 1e-3
rtol = atol = 1e-2


def test_expand_dims(transformer_factory):
    """Compare ng.ExpandDims against an equivalent zero-stride NumPy view,
    for every insertion position and new-axis length, including gradients."""
    C = ng.make_axis()
    D = ng.make_axis()
    N = ng.make_axis()

    max_new_axis_length = 4

    tests = [{
        'tensor': [[2, 5], [13, 5]],
        'tensor_axes': (N, D),
        'tensor_axes_lengths': (2, 2),
        'new_axis': C,
    }, {
        'tensor': 2,
        'tensor_axes': (),
        'tensor_axes_lengths': (),
        'new_axis': D
    }]

    for test in tests:
        for new_axis_length in range(1, max_new_axis_length + 1):
            tensor_axes = test['tensor_axes']
            tensor_axes_lengths = test['tensor_axes_lengths']

            for dim in range(len(tensor_axes) + 1):
                for axis, length in zip(tensor_axes, tensor_axes_lengths):
                    axis.length = length

                new_axis = test['new_axis']
                new_axis.length = new_axis_length

                tensor_np = np.array(test['tensor'], dtype=np.float32)
                tensor = ng.placeholder(tensor_axes)

                expanded = ng.ExpandDims(tensor, new_axis, dim)
                with ExecutorFactory() as ex:
                    expander_fun = ex.executor(expanded, tensor)
                    num_deriv_fun = ex.numeric_derivative(
                        expanded, tensor, delta)
                    sym_deriv_fun = ex.derivative(expanded, tensor)

                    # Build the expected result as a zero-stride view: the new
                    # axis repeats the same data without copying it
                    expanded_shape = tensor_np.shape[:dim] \
                        + (new_axis.length,) + tensor_np.shape[dim:]
                    expanded_strides = tensor_np.strides[:dim] \
                        + (0,) + tensor_np.strides[dim:]
                    expanded_np = np.ndarray(buffer=tensor_np,
                                             shape=expanded_shape,
                                             strides=expanded_strides,
                                             dtype=tensor_np.dtype)

                    expanded_result = expander_fun(tensor_np)
                    assert np.array_equal(expanded_np, expanded_result)

                    # Test backpropagation
                    numeric_deriv = num_deriv_fun(tensor_np)
                    sym_deriv = sym_deriv_fun(tensor_np)
                    assert ng.testing.allclose(numeric_deriv,
                                               sym_deriv,
                                               rtol=rtol,
                                               atol=atol)
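
The reference result in this test leans on a NumPy stride trick: giving the inserted axis a stride of 0 makes every index along it alias the same underlying data, i.e. a broadcast that copies nothing. A standalone illustration using np.lib.stride_tricks.as_strided, the documented way to build such views:

import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Insert a length-4 axis at position 1 with stride 0: the array reports
# shape (2, 4, 3) but still refers to the original 6 elements.
expanded = as_strided(x,
                      shape=(2, 4, 3),
                      strides=(x.strides[0], 0, x.strides[1]))

assert expanded.shape == (2, 4, 3)
# every slice along the new axis is the same view of x
assert np.array_equal(expanded[:, 0, :], expanded[:, 3, :])
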
Example #3
    def __call__(self,
                 H_pr,
                 h_ip,
                 states,
                 output=None,
                 reset_cells=True,
                 input_data=None):
        """
        Arguments:
        ----------
        H_pr: encoding of the question
        h_ip: sliced input of the paragraph encoding for a particular time step
        states: state of the LSTM cell
        output: previous hidden state
        reset_cells: flag to reset the LSTM cell state
        input_data: the ArrayIterator object for training data
                    (contains the length of each sentence)
        """
        # get recurrent axis for question
        rec_axis_pr = H_pr.axes.recurrent_axis()
        const_one = ng.constant(const=1, axes=[self.dummy_axis])
        # for the first word in a paragraph, initialize the previous LSTM
        # hidden state to zeros
        if output is None:
            h_r_old = ng.constant(axes=[self.F, self.N], const=0)
        else:
            h_r_old = ng.cast_axes(output, [self.F, self.N])

        # Compute attention vector
        sum_1 = ng.dot(self.W_q, H_pr)
        sum_1 = ng.cast_axes(sum_1,
                             [self.hidden_rows, self.hidden_cols_ques, self.N])
        int_sum1 = ng.dot(self.W_p, h_ip)
        int_sum2 = ng.dot(self.W_r, h_r_old)
        int_sum = int_sum1 + int_sum2 + self.b_p
        int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

        # masking for the attention vector
        req_mask = ng.axes_with_order(
            ng.cast_axes(ng.dot(self.e_q2, input_data['question_len']),
                         [self.hidden_rows, self.N, self.hidden_cols_ques]),
            [self.hidden_rows, self.hidden_cols_ques, self.N])

        req_mask_2 = ng.axes_with_order(
            ng.cast_axes(ng.dot(const_one, input_data['question_len']),
                         [self.N, self.hidden_cols_ques]),
            [self.hidden_cols_ques, self.N])

        G_i_int = sum_1 + ng.multiply(
            req_mask,
            ng.axes_with_order(
                ng.dot(int_sum, self.e_q),
                [self.hidden_rows, self.hidden_cols_ques, self.N]))

        G_i = ng.tanh(G_i_int)
        # attention vector: masked softmax over question positions,
        # using the same log-mask trick as in Example #1
        at_sum1 = ng.dot(self.w_lr, G_i)
        at = ng.softmax(at_sum1 + ng.log(req_mask_2))
        at_repeated = ng.cast_axes(
            ng.dot(self.e_q2, ng.ExpandDims(at, self.dummy_axis, 0)),
            [self.F, rec_axis_pr, self.N])

        # Stack the 2 vectors as per the equation in the paper
        z1 = h_ip
        z2 = ng.sum(ng.multiply(H_pr, at_repeated), rec_axis_pr)
        # input to the LSTM cell: the paper's concatenation [z1; z2] is
        # expressed here as two projections (ZX, ZY) summed, in place of an
        # explicit concat_along_axis
        inputs_lstm = ng.dot(self.ZX, z1) + ng.dot(self.ZY, z2)

        # LSTM cell computations (from the LSTM branch in ngraph)
        if self.out_axes is None:
            self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
        if states is None:
            states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                            reset_cells=reset_cells)
        assert self.out_axes == states['h'].axes

        for gate in self._gate_names:
            transform = self.gate_transform[gate]
            gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](
                states['h'])
            self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                  self.out_axes)

        states['c'] = (states['c'] * self.gate_output['f'] +
                       self.gate_output['i'] * self.gate_output['g'])
        states['h'] = self.gate_output['o'] * self.activation(states['c'])
        states['h'] = ng.cast_role(states['h'], self.out_axes)
        # return unrolled output and state of LSTM cell
        return ng.cast_axes(states['h'], axes=[self.F, self.N]), states
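
Both cells rely on the same broadcasting idiom around ng.ExpandDims: insert a singleton dummy axis into a per-timestep term, then dot it with a ones vector (e_q) to tile that term across every question position before adding it to sum_1. In plain NumPy the equivalent tiling is an outer product with a ones vector; the shapes and names below are illustrative only:

import numpy as np

hidden, ques_len, batch = 3, 5, 2

int_sum = np.random.randn(hidden, batch)  # per-timestep term, no question axis
e_q = np.ones(ques_len)                   # ones vector along the question axis

# ExpandDims + dot(e_q): add a singleton axis, then tile it to length ques_len
tiled = int_sum[:, None, :] * e_q[None, :, None]  # -> (hidden, ques_len, batch)

sum_1 = np.random.randn(hidden, ques_len, batch)
G_i = np.tanh(sum_1 + tiled)              # mirrors G_i_int above, minus masking
assert G_i.shape == (hidden, ques_len, batch)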