def __call__(self, H_concat, states=None, output=None,
             reset_cells=True, input_data=None):
    """
    Arguments:
    ----------
    H_concat: concatenated forward and reverse unrolled outputs of the
              `MatchLSTMCell_withAttention` cell
    states: previous LSTM state
    output: hidden state from the previous timestep
    reset_cells: argument to reset a cell
    input_data: the ArrayIterator object for training data
                (contains the length of each sentence)
    """
    rec_axis_pr = H_concat.axes.recurrent_axis()
    const_one = ng.constant(const=1, axes=[self.dummy_axis])

    b_k_lists = []
    # run the pointer twice, collecting an attention distribution over
    # paragraph positions on each pass
    for i in range(0, 2):
        if output is None:
            h_k_old = ng.constant(axes=[self.F, self.N], const=0)
        else:
            h_k_old = ng.cast_axes(output, [self.F, self.N])

        sum_1 = ng.dot(
            self.V_answer,
            ng.cast_axes(H_concat,
                         [self.lstm_feature_new, rec_axis_pr, self.N]))
        sum_1 = ng.cast_axes(
            sum_1, [self.hidden_rows, self.hidden_cols_para, self.N])

        int_sum2 = ng.dot(self.W_a, h_k_old)
        int_sum = int_sum2  # + self.b_a
        int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

        # Compute the attention vector, following the notation of the paper
        F_i_int = sum_1 + ng.axes_with_order(
            ng.dot(int_sum, self.e_q),
            [self.hidden_rows, self.hidden_cols_para, self.N])
        F_i = ng.tanh(F_i_int)
        b_k_sum1 = ng.dot(self.v_lr, F_i)

        # Masking the logits with -inf (the log of a 0/1 length mask) for
        # positions beyond the actual paragraph length ensures the softmax
        # assigns those positions a weight of 0
        mask_loss_new = ng.log(ng.dot(const_one, input_data['para_len']))
        mask_loss_new = ng.axes_with_order(
            ng.cast_axes(mask_loss_new, [self.N, self.hidden_cols_para]),
            [self.hidden_cols_para, self.N])

        # Add the mask to the logits and normalize
        b_k = ng.softmax(b_k_sum1 + mask_loss_new)
        b_k_repeated = ng.cast_axes(
            ng.dot(self.e_q2, ng.ExpandDims(b_k, self.dummy_axis, 0)),
            [H_concat.axes[0], rec_axis_pr, self.N])

        # Attention-weighted sum over the paragraph positions feeds the LSTM
        inputs_lstm = ng.sum(ng.multiply(H_concat, b_k_repeated), rec_axis_pr)

        # LSTM cell calculations
        if self.out_axes is None:
            self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
        if states is None:
            states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                            reset_cells=reset_cells)
        assert self.out_axes == states['h'].axes

        for gate in self._gate_names:
            transform = self.gate_transform[gate]
            gate_input = self.i2h[gate](inputs_lstm) + \
                self.h2h[gate](states['h'])
            self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                                  self.out_axes)

        states['c'] = (states['c'] * self.gate_output['f'] +
                       self.gate_output['i'] * self.gate_output['g'])
        states['h'] = self.gate_output['o'] * self.activation(states['c'])
        states['h'] = ng.cast_role(states['h'], self.out_axes)
        output = states['h']

        # collect the attention distribution produced on this pass
        b_k_lists.append(b_k)

    return b_k_lists
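For reference, this cell is the decoder half of the Match-LSTM architecture the code comments cite: each pass of the loop produces one attention distribution `b_k` over paragraph positions. A reconstruction of the equations it implements, in the paper's notation (the mapping from code names, `V_answer` → V, `W_a` → W^a, `v_lr` → v, `H_concat` → H^r, is our reading of the code, not part of the source; note that the bias `b_a` is commented out, and the log of the 0/1 length mask m_(P) replaces the paper's bias term so that padded positions receive zero weight):

```latex
F_k     = \tanh\!\bigl(V H^r + (W^a h^a_{k-1}) \otimes e_{(P)}\bigr) \\
\beta_k = \operatorname{softmax}\!\bigl(v^\top F_k + \log m_{(P)}\bigr) \\
h^a_k   = \operatorname{LSTM}\!\bigl(H^r \beta_k^\top,\; h^a_{k-1}\bigr)
```

Here e_(P) is a vector of ones that broadcasts the recurrence term across paragraph positions, matching the `ng.dot(int_sum, self.e_q)` expansion in the code.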
import numpy as np

import ngraph as ng
from ngraph.testing import ExecutorFactory

# Tolerances for the derivative checks; these are assumed to be defined at
# module level in the original test file
delta = 1e-3
rtol = atol = 1e-2


def test_expand_dims(transformer_factory):
    """Check that ExpandDims matches a zero-stride numpy broadcast view and
    that its symbolic derivative agrees with a numeric estimate."""
    C = ng.make_axis()
    D = ng.make_axis()
    N = ng.make_axis()

    max_new_axis_length = 4

    tests = [
        {
            'tensor': [[2, 5], [13, 5]],
            'tensor_axes': (N, D),
            'tensor_axes_lengths': (2, 2),
            'new_axis': C,
        },
        {
            'tensor': 2,
            'tensor_axes': (),
            'tensor_axes_lengths': (),
            'new_axis': D,
        },
    ]

    for test in tests:
        for new_axis_length in range(1, max_new_axis_length + 1):
            tensor_axes = test['tensor_axes']
            tensor_axes_lengths = test['tensor_axes_lengths']

            # try inserting the new axis at every possible position
            for dim in range(len(tensor_axes) + 1):
                for axis, length in zip(tensor_axes, tensor_axes_lengths):
                    axis.length = length

                new_axis = test['new_axis']
                new_axis.length = new_axis_length

                tensor_np = np.array(test['tensor'], dtype=np.float32)
                tensor = ng.placeholder(tensor_axes)

                expanded = ng.ExpandDims(tensor, new_axis, dim)
                with ExecutorFactory() as ex:
                    expander_fun = ex.executor(expanded, tensor)
                    num_deriv_fun = ex.numeric_derivative(expanded, tensor,
                                                          delta)
                    sym_deriv_fun = ex.derivative(expanded, tensor)

                    # Build the expected result as a zero-stride view that
                    # repeats the tensor along the new axis without copying
                    expanded_shape = tensor_np.shape[:dim] \
                        + (new_axis.length,) + tensor_np.shape[dim:]
                    expanded_strides = tensor_np.strides[:dim] \
                        + (0,) + tensor_np.strides[dim:]
                    expanded_np = np.ndarray(buffer=tensor_np,
                                             shape=expanded_shape,
                                             strides=expanded_strides,
                                             dtype=tensor_np.dtype)

                    expanded_result = expander_fun(tensor_np)
                    assert np.array_equal(expanded_np, expanded_result)

                    # Test backpropagation
                    numeric_deriv = num_deriv_fun(tensor_np)
                    sym_deriv = sym_deriv_fun(tensor_np)
                    assert ng.testing.allclose(numeric_deriv, sym_deriv,
                                               rtol=rtol, atol=atol)
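The expected value in this test is built with numpy's stride trick: inserting an axis whose stride is 0 makes every index along that axis alias the same underlying elements, so the tensor appears repeated without any copy. A minimal standalone sketch of the same trick in plain numpy (variable names here are illustrative, not from the source):

```python
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
dim, new_len = 1, 4

# Insert a length-4 axis at position `dim` with stride 0: moving along the
# new axis never advances through the buffer, so the data is "repeated".
shape = x.shape[:dim] + (new_len,) + x.shape[dim:]
strides = x.strides[:dim] + (0,) + x.strides[dim:]
view = np.ndarray(buffer=x, shape=shape, strides=strides, dtype=x.dtype)

# Equivalent to expanding and repeating, but without copying any data
assert np.array_equal(view,
                      np.repeat(np.expand_dims(x, dim), new_len, axis=dim))
```

This is exactly the reference behavior `ng.ExpandDims` is tested against: a broadcast of the input along the new axis.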
def __call__(self, H_pr, h_ip, states, output=None,
             reset_cells=True, input_data=None):
    """
    Arguments:
    ----------
    H_pr: encoding of the question
    h_ip: sliced input of the paragraph encoding for a particular time step
    states: state of the LSTM cell
    output: previous hidden state
    reset_cells: argument to reset a cell
    input_data: the ArrayIterator object for training data
                (contains the length of each sentence)
    """
    # get the recurrent axis of the question encoding
    rec_axis_pr = H_pr.axes.recurrent_axis()
    const_one = ng.constant(const=1, axes=[self.dummy_axis])

    # at the first word of a paragraph, initialize the previous LSTM
    # hidden state to zeros
    if output is None:
        h_r_old = ng.constant(axes=[self.F, self.N], const=0)
    else:
        h_r_old = ng.cast_axes(output, [self.F, self.N])

    # Compute the attention vector
    sum_1 = ng.dot(self.W_q, H_pr)
    sum_1 = ng.cast_axes(sum_1,
                         [self.hidden_rows, self.hidden_cols_ques, self.N])
    int_sum1 = ng.dot(self.W_p, h_ip)
    int_sum2 = ng.dot(self.W_r, h_r_old)
    int_sum = int_sum1 + int_sum2 + self.b_p
    int_sum = ng.ExpandDims(int_sum, self.dummy_axis, 1)

    # masking for the attention vector: zero out positions beyond the
    # actual question length
    req_mask = ng.axes_with_order(
        ng.cast_axes(ng.dot(self.e_q2, input_data['question_len']),
                     [self.hidden_rows, self.N, self.hidden_cols_ques]),
        [self.hidden_rows, self.hidden_cols_ques, self.N])

    req_mask_2 = ng.axes_with_order(
        ng.cast_axes(ng.dot(const_one, input_data['question_len']),
                     [self.N, self.hidden_cols_ques]),
        [self.hidden_cols_ques, self.N])

    G_i_int = sum_1 + ng.multiply(
        req_mask,
        ng.axes_with_order(ng.dot(int_sum, self.e_q),
                           [self.hidden_rows, self.hidden_cols_ques,
                            self.N]))
    G_i = ng.tanh(G_i_int)

    # Attention vector
    at_sum1 = ng.dot(self.w_lr, G_i)
    at = ng.softmax(at_sum1 + ng.log(req_mask_2))
    at_repeated = ng.cast_axes(
        ng.dot(self.e_q2, ng.ExpandDims(at, self.dummy_axis, 0)),
        [self.F, rec_axis_pr, self.N])

    # Stack the two vectors as in the equation in the paper; applying the
    # two projections ZX and ZY is equivalent to one weight matrix applied
    # to the concatenation ng.concat_along_axis([z1, z2], self.F)
    z1 = h_ip
    z2 = ng.sum(ng.multiply(H_pr, at_repeated), rec_axis_pr)
    inputs_lstm = ng.dot(self.ZX, z1) + ng.dot(self.ZY, z2)

    # LSTM cell computations (from the LSTM branch in ngraph)
    if self.out_axes is None:
        self.out_axes = self.feature_axes + inputs_lstm.axes.batch_axis()
    if states is None:
        states = self.initialize_states(inputs_lstm.axes.batch_axis(),
                                        reset_cells=reset_cells)
    assert self.out_axes == states['h'].axes

    for gate in self._gate_names:
        transform = self.gate_transform[gate]
        gate_input = self.i2h[gate](inputs_lstm) + \
            self.h2h[gate](states['h'])
        self.gate_output[gate] = ng.cast_role(transform(gate_input),
                                              self.out_axes)

    states['c'] = (states['c'] * self.gate_output['f'] +
                   self.gate_output['i'] * self.gate_output['g'])
    states['h'] = self.gate_output['o'] * self.activation(states['c'])
    states['h'] = ng.cast_role(states['h'], self.out_axes)

    # return the unrolled output and the state of the LSTM cell
    return ng.cast_axes(states['h'], axes=[self.F, self.N]), states
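For reference, a reconstruction of the Match-LSTM attention equations this cell implements, in the paper's notation (the mapping from code names, `W_q` → W^q, `W_p` → W^p, `W_r` → W^r, `b_p` → b^p, `w_lr` → w, `H_pr` → H^q, `h_ip` → h^p_i, is our reading of the code, not part of the source; as above, the log of the 0/1 question-length mask m_(Q) stands in for the paper's scalar bias so padded question positions get zero weight):

```latex
G_i      = \tanh\!\bigl(W^q H^q + (W^p h^p_i + W^r h^r_{i-1} + b^p) \otimes e_{(Q)}\bigr) \\
\alpha_i = \operatorname{softmax}\!\bigl(w^\top G_i + \log m_{(Q)}\bigr) \\
z_i      = \bigl[\, h^p_i \;;\; H^q \alpha_i^\top \,\bigr], \qquad
h^r_i    = \operatorname{LSTM}\!\bigl(z_i,\; h^r_{i-1}\bigr)
```

The code avoids materializing the concatenation z_i: since a weight matrix applied to [z1; z2] splits into two blocks, `ng.dot(self.ZX, z1) + ng.dot(self.ZY, z2)` computes the same LSTM input.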