def _step(self, h_ff, states):
    # One LSTM time step: states holds [hidden state, cell state].
    h_state = states[0]
    c_state = states[1]
    ifog = {
        k: sum([ng.cast_role(h_ff[k], self.out_axes),
                ng.cast_role(ng.dot(self.W_recur[k], h_state), self.out_axes),
                self.b[k]])
        for k in self.metadata['gates']
    }
    ifog_act = {
        k: self.activation(ifog[k]) if k == 'g' else self.gate_activation(ifog[k])
        for k in self.metadata['gates']
    }
    # c is the new cell state, before the output activation is applied
    c = ifog_act['f'] * c_state + ifog_act['i'] * ifog_act['g']
    h = ifog_act['o'] * self.activation(c)
    h = ng.cast_role(h, self.out_axes)
    return [h, c]
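# For reference, a minimal NumPy analogue of the gate arithmetic performed by
# _step above. This is a sketch only: the plain-array layout, function names,
# and dict-of-gates convention are assumptions for illustration, not part of
# the ngraph implementation.
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_numpy(h_ff, h_prev, c_prev, W_recur, b):
    # h_ff, W_recur and b are dicts keyed by the gate names 'i', 'f', 'o', 'g';
    # 'g' uses tanh while the gates use the logistic sigmoid, as in _step.
    ifog = {k: h_ff[k] + W_recur[k] @ h_prev + b[k] for k in 'ifog'}
    act = {k: np.tanh(v) if k == 'g' else _sigmoid(v) for k, v in ifog.items()}
    c = act['f'] * c_prev + act['i'] * act['g']
    h = act['o'] * np.tanh(c)
    return h, c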
def __call__(self, in_obj, init_state=None):
    """
    Sets shape-based parameters of this layer from the input and builds the
    forward and backward recurrent outputs.

    Arguments:
        in_obj (Tensor or sequence of two Tensors): input to the layer. If a
            sequence, the first element feeds the forward RNN and the second
            the backward RNN; both must have the same axes.
        init_state (Tensor or sequence of two Tensors): optional initial
            state, handled the same way as in_obj.

    Returns:
        if sum_out or concat_out - rnn_out (Tensor): combined output
        otherwise - a tuple of length 2: (fwd_out, bwd_out)
    """
    if isinstance(in_obj, collections.Sequence):
        if len(in_obj) != 2:
            raise ValueError("If in_obj is a sequence, it must have length 2")
        if in_obj[0].axes != in_obj[1].axes:
            raise ValueError("If in_obj is a sequence, each element must have the same axes")
        fwd_in = in_obj[0]
        bwd_in = in_obj[1]
    else:
        fwd_in = in_obj
        bwd_in = in_obj

    if isinstance(init_state, collections.Sequence):
        if len(init_state) != 2:
            raise ValueError("If init_state is a sequence, it must have length 2")
        if init_state[0].axes != init_state[1].axes:
            raise ValueError("If init_state is a sequence, "
                             "each element must have the same axes")
        fwd_init = init_state[0]
        bwd_init = init_state[1]
    else:
        fwd_init = init_state
        bwd_init = init_state

    with ng.metadata(direction="fwd"):
        fwd_out = self.fwd_rnn(fwd_in, fwd_init)
    with ng.metadata(direction="bwd"):
        bwd_out = ng.cast_role(self.bwd_rnn(bwd_in, bwd_init), fwd_out.axes)

    if self.sum_out:
        return fwd_out + bwd_out
    elif self.concat_out:
        ax = fwd_out.axes.feature_axes()
        if len(ax) == 1:
            ax = ax[0]
        else:
            raise ValueError(("Multiple hidden axes: {}. "
                              "Unable to concatenate automatically").format(ax))
        return ng.concat_along_axis([fwd_out, bwd_out], ax)
    else:
        return fwd_out, bwd_out
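# A small NumPy sketch of how the two directional outputs are merged above:
# sum_out adds them element-wise, concat_out doubles the hidden axis. The
# (hidden, time, batch) layout and the function name are illustrative
# assumptions, not the layer's API.
import numpy as np

def combine_birnn(fwd_out, bwd_out, sum_out=False, concat_out=False):
    if sum_out:
        return fwd_out + bwd_out                        # hidden size unchanged
    if concat_out:
        return np.concatenate([fwd_out, bwd_out], axis=0)   # hidden size doubles
    return fwd_out, bwd_out

fwd = np.zeros((4, 5, 2))
bwd = np.zeros((4, 5, 2))
assert combine_birnn(fwd, bwd, sum_out=True).shape == (4, 5, 2)
assert combine_birnn(fwd, bwd, concat_out=True).shape == (8, 5, 2)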
def _step(self, h_ff, states):
    h_ff = ng.cast_role(h_ff, self.out_axes)
    h_rec = ng.cast_role(ng.dot(self.W_recur, states), self.out_axes)
    return self.activation(h_rec + h_ff + self.b)
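# The same vanilla-RNN recurrence in plain NumPy, i.e.
# h_t = activation(h_ff + W_recur . h_{t-1} + b), where h_ff already holds the
# feed-forward input term as in _step above. Sketch only; shapes and names are
# assumptions.
import numpy as np

def rnn_step_numpy(h_ff, h_prev, W_recur, b, activation=np.tanh):
    return activation(h_ff + W_recur @ h_prev + b)

h = rnn_step_numpy(np.zeros((3, 2)), np.zeros((3, 2)), np.eye(3), np.zeros((3, 1)))
assert h.shape == (3, 2)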
def unroll_with_attention(cell, num_steps, H_pr, H_hy, init_states=None,
                          reset_cells=True, return_sequence=True,
                          reverse_mode=False, input_data=None):
    """
    Unroll the cell with attention for num_steps steps.

    Arguments:
    ----------
    cell: the cell to unroll (e.g. MatchLSTMCell_withAttention)
    num_steps: the number of steps to unroll
    H_pr: the encoding for the question
    H_hy: the encoding for the passage
    init_states: either None or a dictionary containing states
    reset_cells: determines whether the cell state is reset after unrolling
    return_sequence: whether to return the full output sequence or only the
                     last output
    reverse_mode: set to True to unroll in the reverse direction
    input_data: the ArrayIterator object for training data (contains the
                length of each sentence)
    """
    recurrent_axis = H_hy.axes.recurrent_axis()

    if init_states is not None:
        # assumes the cell exposes its output axes as cell.out_axes
        states = {
            k: ng.cast_role(v, cell.out_axes)
            for (k, v) in init_states.items()
        }
    else:
        states = init_states

    stepped_inputs = get_steps(H_hy, recurrent_axis, backward=reverse_mode)
    stepped_outputs = []

    for t in range(num_steps):
        with ng.metadata(step=str(t)):
            if t == 0:
                output, states = cell(H_pr, stepped_inputs[t], states,
                                      output=None, input_data=input_data)
            else:
                output, states = cell(H_pr, stepped_inputs[t], states,
                                      output=output, input_data=input_data)
            stepped_outputs.append(output)

    if reverse_mode:
        if return_sequence:
            stepped_outputs.reverse()

    if return_sequence:
        outputs = ng.stack(stepped_outputs, recurrent_axis, pos=1)
    else:
        outputs = stepped_outputs[-1]

    if not reset_cells:
        update_inits = ng.doall([ng.assign(initial, states[name])
                                 for (name, initial) in states.items()])
        outputs = ng.sequential([update_inits, outputs])

    return outputs
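# The control flow above is the standard "thread the state through time" loop.
# A minimal pure-Python sketch of that pattern, with a hypothetical
# step_fn(x_t, state) -> (output, state) and the ngraph axes/metadata omitted:
def unroll_sketch(step_fn, xs, state, reverse=False):
    seq = list(reversed(xs)) if reverse else list(xs)
    outputs = []
    for x_t in seq:
        out, state = step_fn(x_t, state)
        outputs.append(out)
    if reverse:
        outputs.reverse()   # restore the original time order
    return outputs, state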
def __call__(self, H_concat, states=None, output=None, reset_cells=True,
             input_data=None):
    """
    Arguments:
    ----------
    H_concat: concatenated forward and reverse unrolled outputs of the
              `MatchLSTMCell_withAttention` cell
    states: previous LSTM state
    output: hidden state from the previous timestep
    reset_cells: argument to reset a cell
    input_data: the ArrayIterator object for training data (contains the
                length of each sentence)
    """
    rec_axis_pr = H_concat.axes.recurrent_axis()
    const_one = ng.constant(const=1, axes=[self.dummy_axis])

    b_k_lists = []
    # two iterations: one for the answer start boundary, one for the end
    for i in range(0, 2):
        # on the first step the previous LSTM hidden state is all zeros
        if output is None:
            h_k_old = ng.constant(axes=[self.F, self.N], const=0)
        else:
            h_k_old = ng.cast_axes(output, [self.F, self.N])

        sum_1 = ng.dot(self.V_answer,
                       ng.cast_axes(H_concat,
                                    [self.lstm_feature_new, rec_axis_pr, self.N]))
        sum_1 = ng.cast_axes(sum_1,
                             [self.hidden_rows, self.hidden_cols_para, self.N])

        int_sum2 = ng.dot(self.W_a, h_k_old)
        int_sum = int_sum2  # + self.b_a
        int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

        # Following the notation in the paper, compute the attention vector
        F_i_int = sum_1 + ng.axes_with_order(
            ng.dot(int_sum, self.e_q),
            [self.hidden_rows, self.hidden_cols_para, self.N])
        F_i = ng.tanh(F_i_int)

        # Attention logits
        b_k_sum1 = ng.dot(self.v_lr, F_i)

        # Masking positions beyond each paragraph's actual length with
        # log(0) = -inf ensures they receive zero weight after the softmax
        mask_loss_new = ng.log(ng.dot(const_one, input_data['para_len']))
        mask_loss_new = ng.axes_with_order(
            ng.cast_axes(mask_loss_new, [self.N, self.hidden_cols_para]),
            [self.hidden_cols_para, self.N])

        # Add the mask to the logits and normalize
        b_k = ng.softmax(b_k_sum1 + mask_loss_new)

        b_k_repeated = ng.cast_axes(
            ng.dot(self.e_q2, ng.ExpandDims(b_k, self.dummy_axis, 0)),
            [H_concat.axes[0], rec_axis_pr, self.N])

        inputs_lstm = ng.sum(ng.multiply(H_concat, b_k_repeated), rec_axis_pr)

        # LSTM cell calculations
        if self.out_axes is None:
            self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
        if states is None:
            states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                            reset_cells=reset_cells)
        assert self.out_axes == states['h'].axes

        for gate in self._gate_names:
            transform = self.gate_transform[gate]
            gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](states['h'])
            self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                  self.out_axes)

        states['c'] = (states['c'] * self.gate_output['f'] +
                       self.gate_output['i'] * self.gate_output['g'])
        states['h'] = self.gate_output['o'] * self.activation(states['c'])
        states['h'] = ng.cast_role(states['h'], self.out_axes)

        output = states['h']

        # collect the boundary distribution for this iteration (start, then end)
        b_k_lists.append(b_k)

    return b_k_lists
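# The log-mask trick used above, in NumPy form: adding log(mask) to the logits
# sends padded positions to -inf, so they get exactly zero probability after
# the softmax. Sketch only; the 1-D layout is an assumption.
import numpy as np

def masked_softmax(logits, valid_len):
    mask = (np.arange(logits.shape[-1]) < valid_len).astype(float)
    with np.errstate(divide='ignore'):
        shifted = logits + np.log(mask)
    e = np.exp(shifted - shifted.max())
    return e / e.sum()

p = masked_softmax(np.array([2.0, 1.0, 3.0, 5.0]), valid_len=3)
assert p[3] == 0.0 and np.isclose(p.sum(), 1.0)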
def __call__(self, H_pr, h_ip, states, output=None, reset_cells=True,
             input_data=None):
    """
    Arguments:
    ----------
    H_pr: encoding for the question
    h_ip: sliced input of the paragraph encoding for a particular time step
    states: state of the LSTM cell
    output: previous hidden state
    reset_cells: argument to reset a cell
    input_data: the ArrayIterator object for training data (contains the
                length of each sentence)
    """
    # get recurrent axis for the question
    rec_axis_pr = H_pr.axes.recurrent_axis()
    const_one = ng.constant(const=1, axes=[self.dummy_axis])

    # if the first word of a paragraph is encountered, initialize the previous
    # LSTM hidden state to zeros
    if output is None:
        h_r_old = ng.constant(axes=[self.F, self.N], const=0)
    else:
        h_r_old = ng.cast_axes(output, [self.F, self.N])

    # Compute attention vector
    sum_1 = ng.dot(self.W_q, H_pr)
    sum_1 = ng.cast_axes(sum_1,
                         [self.hidden_rows, self.hidden_cols_ques, self.N])
    int_sum1 = ng.dot(self.W_p, h_ip)
    int_sum2 = ng.dot(self.W_r, h_r_old)
    int_sum = int_sum1 + int_sum2 + self.b_p
    int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

    # masking for the attention vector
    req_mask = ng.axes_with_order(
        ng.cast_axes(ng.dot(self.e_q2, input_data['question_len']),
                     [self.hidden_rows, self.N, self.hidden_cols_ques]),
        [self.hidden_rows, self.hidden_cols_ques, self.N])

    req_mask_2 = ng.axes_with_order(
        ng.cast_axes(ng.dot(const_one, input_data['question_len']),
                     [self.N, self.hidden_cols_ques]),
        [self.hidden_cols_ques, self.N])

    G_i_int = sum_1 + ng.multiply(
        req_mask,
        ng.axes_with_order(ng.dot(int_sum, self.e_q),
                           [self.hidden_rows, self.hidden_cols_ques, self.N]))
    G_i = ng.tanh(G_i_int)

    # Attention vector
    at_sum1 = ng.dot(self.w_lr, G_i)
    at = ng.softmax(at_sum1 + ng.log(req_mask_2))
    at_repeated = ng.cast_axes(
        ng.dot(self.e_q2, ng.ExpandDims(at, self.dummy_axis, 0)),
        [self.F, rec_axis_pr, self.N])

    # Stack the two vectors as in the equation from the paper:
    # z1 is the paragraph word, z2 is the attention-weighted question summary
    z1 = h_ip
    z2 = ng.sum(ng.multiply(H_pr, at_repeated), rec_axis_pr)

    # input to the LSTM cell
    # (equivalent to ng.concat_along_axis([z1, z2], self.F))
    inputs_lstm = ng.dot(self.ZX, z1) + ng.dot(self.ZY, z2)

    # LSTM cell computations (adapted from the LSTM branch in ngraph)
    if self.out_axes is None:
        self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
    if states is None:
        states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                        reset_cells=reset_cells)
    assert self.out_axes == states['h'].axes

    for gate in self._gate_names:
        transform = self.gate_transform[gate]
        gate_input = self.i2h[gate](inputs_lstm) + self.h2h[gate](states['h'])
        self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                              self.out_axes)

    states['c'] = (states['c'] * self.gate_output['f'] +
                   self.gate_output['i'] * self.gate_output['g'])
    states['h'] = self.gate_output['o'] * self.activation(states['c'])
    states['h'] = ng.cast_role(states['h'], self.out_axes)

    # return the unrolled output and the updated state of the LSTM cell
    return ng.cast_axes(states['h'], axes=[self.F, self.N]), states
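# NumPy sketch of the attention block above: the question encoding is scored
# against the current paragraph word and the previous hidden state,
# softmax-normalized, and the attended question summary is stacked with the
# paragraph word to form the LSTM input. Names and shapes are assumptions for
# illustration only; the masking terms are omitted here.
import numpy as np

def match_attention_numpy(H_q, h_p, h_r_prev, W_q, W_p, W_r, b_p, w):
    # H_q: (hidden, Q) question encoding; the rest are matching matrices/vectors
    G = np.tanh(W_q @ H_q + (W_p @ h_p + W_r @ h_r_prev + b_p)[:, None])
    scores = w @ G
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()
    return np.concatenate([h_p, H_q @ alpha])   # [z1; z2]

d, Q = 4, 6
z = match_attention_numpy(np.zeros((d, Q)), np.zeros(d), np.zeros(d),
                          np.zeros((d, d)), np.zeros((d, d)), np.zeros((d, d)),
                          np.zeros(d), np.zeros(d))
assert z.shape == (2 * d,)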