Exemplo n.º 1
0
    def _step(self, h_ff, states):
        """
        Advance the LSTM by one time step.

        Arguments:
            h_ff (dict): feed-forward contribution for each gate, indexed by the
                         gate names listed in ``self.metadata['gates']``
            states (list): previous ``[h, c]`` pair; ``states[0]`` is fed through
                           the recurrent weights, ``states[1]`` is the cell state

        Returns:
            list: ``[h, c]`` computed for this step
        """
        h_state = states[0]
        c_state = states[1]
        gates = self.metadata['gates']

        # Pre-activation of every gate: feed-forward + recurrent + bias.
        ifog = {
            k: sum([ng.cast_role(h_ff[k], self.out_axes),
                    ng.cast_role(ng.dot(self.W_recur[k], h_state), self.out_axes),
                    self.b[k],
                    ]) for k in gates
        }
        # Compare gate names by value. The previous `k is 'g'` identity test only
        # worked when both strings happened to be interned.
        ifog_act = {k: self.activation(ifog[k]) if k == 'g'
                    else self.gate_activation(ifog[k]) for k in gates}

        c = ifog_act['f'] * c_state + ifog_act['i'] * ifog_act['g']
        # c is the new cell state; the activation is applied only when forming h
        h = ifog_act['o'] * self.activation(c)
        h = ng.cast_role(h, self.out_axes)
        return [h, c]
Exemplo n.º 2
0
    def __call__(self, in_obj, init_state=None):
        """
        Run the forward and backward RNNs on the input and combine their outputs.

        Arguments:
            in_obj (Tensor or sequence of two Tensors): a single object is fed to
                both directions; a length-2 sequence supplies the forward input
                first and the backward input second
            init_state (Tensor or sequence of two Tensors): initial state, split
                between the two directions the same way as in_obj

        Returns:
            if sum_out - rnn_out (Tensor): sum of forward and backward outputs
            elif concat_out - rnn_out (Tensor): outputs concatenated along the
                                                single feature axis
            otherwise - rnn_out (tuple of Tensors): (fwd_out, bwd_out)

        Raises:
            ValueError: if a sequence argument does not have length 2, if its two
                        elements have different axes, or if concat_out is set and
                        the output has more than one feature axis
        """
        # collections.Sequence was removed in Python 3.10; the ABC lives in
        # collections.abc. Fall back to the old location for Python 2.
        try:
            from collections.abc import Sequence
        except ImportError:
            from collections import Sequence

        def split_pair(obj, name):
            """Return the (forward, backward) pair for obj, validating sequences."""
            if not isinstance(obj, Sequence):
                return obj, obj
            if len(obj) != 2:
                raise ValueError(
                    "If {} is a sequence, it must have length 2".format(name))
            if obj[0].axes != obj[1].axes:
                raise ValueError(
                    "If {} is a sequence, each element must have the same axes".format(name))
            return obj[0], obj[1]

        fwd_in, bwd_in = split_pair(in_obj, "in_obj")
        fwd_init, bwd_init = split_pair(init_state, "init_state")

        with ng.metadata(direction="fwd"):
            fwd_out = self.fwd_rnn(fwd_in, fwd_init)
        with ng.metadata(direction="bwd"):
            # give the backward output the forward output's axes so the two can
            # be summed or concatenated
            bwd_out = ng.cast_role(self.bwd_rnn(bwd_in, bwd_init), fwd_out.axes)

        if self.sum_out:
            return fwd_out + bwd_out
        elif self.concat_out:
            ax = fwd_out.axes.feature_axes()
            if len(ax) == 1:
                ax = ax[0]
            else:
                raise ValueError(("Multiple hidden axes: {}. "
                                  "Unable to concatenate automatically").format(ax))
            return ng.concat_along_axis([fwd_out, bwd_out], ax)
        else:
            return fwd_out, bwd_out
Exemplo n.º 3
0
 def _step(self, h_ff, states):
     """
     Compute the recurrent layer output for one time step.

     Arguments:
         h_ff: feed-forward contribution for this step
         states: previous state, multiplied by ``self.W_recur``

     Returns:
         ``self.activation`` applied to recurrent term + feed-forward term + bias
     """
     feedforward_term = ng.cast_role(h_ff, self.out_axes)
     recurrent_product = ng.dot(self.W_recur, states)
     recurrent_term = ng.cast_role(recurrent_product, self.out_axes)
     pre_activation = recurrent_term + feedforward_term + self.b
     return self.activation(pre_activation)
Exemplo n.º 4
0
def unroll_with_attention(cell,
                          num_steps,
                          H_pr,
                          H_hy,
                          init_states=None,
                          reset_cells=True,
                          return_sequence=True,
                          reverse_mode=False,
                          input_data=None):
    """
    Unroll the cell with attention for num_steps steps.

    Arguments:
    ----------
    cell : the cell to unroll (Eg: MatchLSTMCell_withAttention). It is called as
           cell(H_pr, step_input, states, output=..., input_data=...) and must
           return an (output, states) pair
    num_steps: the number of steps needed to unroll
    H_pr : the encoding for the question
    H_hy : the encoding for the passage; it is sliced along its recurrent axis
    init_states: Either None or a dictionary containing states
    reset_cells: when False, state-update ops are sequenced before the output
    return_sequence: when True, stack the outputs of all steps along the
                     recurrent axis; otherwise return only the output of the
                     last unrolled step
    reverse_mode: Set to True if unrolling in the opposite direction is desired
    input_data: the ArrayIterator object for training data
                (contains information of length of each sentence)

    Returns:
    --------
    The stacked per-step outputs, or the output of the last unrolled step when
    return_sequence is False.
    """
    recurrent_axis = H_hy.axes.recurrent_axis()

    if init_states is not None:
        # `out_axes` used to be referenced here without ever being defined, so
        # any non-None init_states raised NameError. Build it the way the cell
        # builds its own out_axes: feature axes followed by the batch axis.
        # NOTE(review): assumes H_hy carries the same batch axis as the cell's
        # LSTM input -- confirm.
        out_axes = cell.feature_axes + H_hy.axes.batch_axis()
        states = {
            k: ng.cast_role(v, out_axes)
            for (k, v) in init_states.items()
        }
    else:
        states = None

    stepped_inputs = get_steps(H_hy, recurrent_axis, backward=reverse_mode)
    stepped_outputs = []

    # There is no previous hidden state before the first step.
    output = None
    for t in range(num_steps):
        with ng.metadata(step=str(t)):
            output, states = cell(H_pr,
                                  stepped_inputs[t],
                                  states,
                                  output=output,
                                  input_data=input_data)
            stepped_outputs.append(output)

    # Steps were produced back-to-front in reverse mode; restore input order
    # when the whole sequence is returned.
    if reverse_mode and return_sequence:
        stepped_outputs.reverse()

    if return_sequence:
        outputs = ng.stack(stepped_outputs, recurrent_axis, pos=1)
    else:
        outputs = stepped_outputs[-1]

    if not reset_cells:
        # NOTE(review): `states` has been rebound to the final states by the
        # loop, so each assign below targets the same tensor it reads from.
        # Persisting the final state into the initial-state variables would
        # need a handle on those variables -- confirm intent before changing.
        update_inits = ng.doall([
            ng.assign(initial, states[name])
            for (name, initial) in states.items()
        ])
        outputs = ng.sequential([update_inits, outputs])

    return outputs
Exemplo n.º 5
0
    def __call__(self,
                 H_concat,
                 states=None,
                 output=None,
                 reset_cells=True,
                 input_data=None):
        """
        Run two attention + LSTM steps over H_concat and collect the attention
        distribution produced at each step.

        Arguments:
        ----------
        H_concat: Concatenated forward and reverse unrolled outputs of the
                 `MatchLSTMCell_withAttention` cell
        states: previous LSTM state, a dict with 'h' and 'c' entries. When None
                it is created with self.initialize_states; a dict passed in is
                updated in place
        output: hidden state from previous timestep (zeros are used when None)
        reset_cells: forwarded to self.initialize_states
        input_data: the ArrayIterator object for training data
                    (contains information of length of each sentence); its
                    'para_len' entry is read to build the mask

        Returns:
        --------
        list with the softmax output of each of the two steps
        """

        rec_axis_pr = H_concat.axes.recurrent_axis()
        const_one = ng.constant(const=1, axes=[self.dummy_axis])

        # Neither the projection of H_concat nor the length mask depends on the
        # step, so both are built once and shared by the two iterations.
        sum_1 = ng.dot(
            self.V_answer,
            ng.cast_axes(H_concat,
                         [self.lstm_feature_new, rec_axis_pr, self.N]))
        sum_1 = ng.cast_axes(
            sum_1, [self.hidden_rows, self.hidden_cols_para, self.N])

        # Additive mask for the logits: log of the 'para_len' entry.
        # NOTE(review): presumably 'para_len' is 0 at padded positions, so the
        # log is -inf there and softmax yields 0 -- confirm against the data
        # iterator.
        mask_loss_new = ng.log(ng.dot(const_one, input_data['para_len']))
        mask_loss_new = ng.axes_with_order(
            ng.cast_axes(mask_loss_new, [self.N, self.hidden_cols_para]),
            [self.hidden_cols_para, self.N])

        b_k_lists = []
        for _ in range(2):
            # Hidden state fed to the attention: zeros until a step has run.
            if output is None:
                h_k_old = ng.constant(axes=[self.F, self.N], const=0)
            else:
                h_k_old = ng.cast_axes(output, [self.F, self.N])

            # No bias term is added to the hidden-state projection.
            int_sum = ng.dot(self.W_a, h_k_old)
            int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

            # Following notations from the paper
            # Compute Attention Vector
            F_i_int = sum_1 + ng.axes_with_order(
                ng.dot(int_sum, self.e_q),
                [self.hidden_rows, self.hidden_cols_para, self.N])

            F_i = ng.tanh(F_i_int)  # Attention Vector

            b_k_sum1 = ng.dot(self.v_lr, F_i)

            # Masked softmax over the logits. The same op is both used to weight
            # H_concat and returned; it used to be built twice.
            b_k = ng.softmax(b_k_sum1 + mask_loss_new)
            b_k_repeated = ng.cast_axes(
                ng.dot(self.e_q2, ng.ExpandDims(b_k, self.dummy_axis, 0)),
                [H_concat.axes[0], rec_axis_pr, self.N])

            # Attention-weighted sum of H_concat over its recurrent axis.
            inputs_lstm = ng.sum(ng.multiply(H_concat, b_k_repeated),
                                 rec_axis_pr)

            # LSTM Cell calculations
            if self.out_axes is None:
                self.out_axes = (self.feature_axes +
                                 inputs_lstm.axes.batch_axis())
            if states is None:
                states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                                reset_cells=reset_cells)
            assert self.out_axes == states['h'].axes

            for gate in self._gate_names:
                transform = self.gate_transform[gate]
                gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](
                    states['h'])
                self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                      self.out_axes)

            states['c'] = (states['c'] * self.gate_output['f'] +
                           self.gate_output['i'] * self.gate_output['g'])
            states['h'] = self.gate_output['o'] * self.activation(states['c'])
            states['h'] = ng.cast_role(states['h'], self.out_axes)

            # The new hidden state drives the attention of the next step.
            output = states['h']

            # append required outputs
            b_k_lists.append(b_k)

        return b_k_lists
Exemplo n.º 6
0
    def __call__(self,
                 H_pr,
                 h_ip,
                 states,
                 output=None,
                 reset_cells=True,
                 input_data=None):
        """
        Run one attention + LSTM step for a single paragraph time step.

        Arguments:
        ----------
        H_pr : Encoding for question
        h_ip: Sliced input of paragraph encoding for a particular time step
        states: State of the LSTM cell, a dict with 'h' and 'c' entries. When
                None it is created with self.initialize_states; a dict passed
                in is updated in place
        output: previous hidden state (zeros are used when None)
        reset_cells: forwarded to self.initialize_states
        input_data: the ArrayIterator object for training data (contains information of
                                                        length of each sentence);
                    its 'question_len' entry is read to build the masks

        Returns:
        --------
        tuple of (states['h'] cast to axes [self.F, self.N], states)
        """
        # get recurrent axis for question
        rec_axis_pr = H_pr.axes.recurrent_axis()
        const_one = ng.constant(const=1, axes=[self.dummy_axis])
        # if first word in a paragraph is encountered, assign the previous LSTM
        # hidden state as zeros
        if output is None:
            h_r_old = ng.constant(axes=[self.F, self.N], const=0)
        else:
            h_r_old = ng.cast_axes(output, [self.F, self.N])

        # Compute attention vector
        # question term: W_q applied to the question encoding
        sum_1 = ng.dot(self.W_q, H_pr)
        sum_1 = ng.cast_axes(sum_1,
                             [self.hidden_rows, self.hidden_cols_ques, self.N])
        # per-step term: paragraph input + previous hidden state + bias
        int_sum1 = ng.dot(self.W_p, h_ip)
        int_sum2 = ng.dot(self.W_r, h_r_old)
        int_sum = int_sum1 + int_sum2 + self.b_p
        int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

        # masking for the attention vector
        # req_mask multiplies the per-step term below
        # NOTE(review): presumably 'question_len' is a 0/1 mask over question
        # positions -- confirm against the data iterator
        req_mask = ng.axes_with_order(
            ng.cast_axes(ng.dot(self.e_q2, input_data['question_len']),
                         [self.hidden_rows, self.N, self.hidden_cols_ques]),
            [self.hidden_rows, self.hidden_cols_ques, self.N])

        # req_mask_2 is added to the attention logits as log(mask) below
        req_mask_2 = ng.axes_with_order(
            ng.cast_axes(ng.dot(const_one, input_data['question_len']),
                         [self.N, self.hidden_cols_ques]),
            [self.hidden_cols_ques, self.N])

        G_i_int = sum_1 + ng.multiply(
            req_mask,
            ng.axes_with_order(
                ng.dot(int_sum, self.e_q),
                [self.hidden_rows, self.hidden_cols_ques, self.N]))

        G_i = ng.tanh(G_i_int)
        # Attention Vector
        at_sum1 = ng.dot(self.w_lr, G_i)
        # masked softmax: log(req_mask_2) is added to the logits
        at = ng.softmax(at_sum1 + ng.log(req_mask_2))
        at_repeated = ng.cast_axes(
            ng.dot(self.e_q2, ng.ExpandDims(at, self.dummy_axis, 0)),
            [self.F, rec_axis_pr, self.N])

        # Combine the 2 vectors as per the equation in the paper
        z1 = h_ip
        # attention-weighted sum of the question encoding over its recurrent axis
        z2 = ng.sum(ng.multiply(H_pr, at_repeated), rec_axis_pr)
        # represents the input to the LSTM cell: z1 and z2 are each projected
        # (self.ZX, self.ZY) and summed rather than concatenated along self.F
        inputs_lstm = ng.dot(self.ZX, z1) + ng.dot(self.ZY, z2)

        # LSTM cell computations (from LSTM branch in ngraph)
        if self.out_axes is None:
            self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
        if states is None:
            states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                            reset_cells=reset_cells)
        assert self.out_axes == states['h'].axes

        # each gate: transform(input-to-hidden + hidden-to-hidden)
        for gate in self._gate_names:
            transform = self.gate_transform[gate]
            gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](
                states['h'])
            self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                  self.out_axes)

        # update the cell state, then the hidden state, in the states dict
        states['c'] = (states['c'] * self.gate_output['f'] +
                       self.gate_output['i'] * self.gate_output['g'])
        states['h'] = self.gate_output['o'] * self.activation(states['c'])
        states['h'] = ng.cast_role(states['h'], self.out_axes)
        # return unrolled output and state of LSTM cell
        return ng.cast_axes(states['h'], axes=[self.F, self.N]), states