def decode_probs(self, context, state, emb):
    '''
    Get the probability of the next word. Used in beam search and sampling.

    :type context: theano variable
    :param context: the context vectors

    :type state: theano variable
    :param state: the last hidden state

    :type emb: theano variable
    :param emb: the embedding of the last generated word
    '''
    # attention: score each source position against the current hidden state
    att_c = tools.dot3d(context, self.att_context)
    att_before = tensor.dot(state, self.att_hidden)  # size: (batch_size, dim)
    energy = tensor.dot(tensor.tanh(att_c + att_before.dimshuffle('x', 0, 1)),
                        self.att).reshape((context.shape[0], context.shape[1]))  # size: (length, batch_size)
    energy = tensor.exp(energy)
    normalizer = energy.sum(axis=0)
    attention = energy / normalizer  # size: (length, batch_size)
    c = (context * attention.dimshuffle(0, 1, 'x')).sum(axis=0)  # size: (batch_size, dim_c)

    # readout: combine the last word embedding, the context vector and the hidden state
    readout = tensor.dot(emb, self.readout_emb) + \
              tensor.dot(c, self.readout_context) + \
              tensor.dot(state, self.readout_hidden)
    readout += self.readout_offset
    maxout = tools.maxout(readout, self.maxout)
    outenergy = tensor.dot(maxout, self.probs_emb)
    outenergy = tensor.dot(outenergy, self.probs)
    outenergy += self.probs_offset

    return outenergy, c
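# --- Hedged sketch (not part of the model): the attention arithmetic used in
# decode_probs, replayed in NumPy with made-up dimensions and random weights.
# The names att_context, att_hidden and att below are stand-ins for the real
# learned parameters; the point is only the softmax-over-source-positions math.
def _attention_sketch_numpy():
    import numpy as np
    rng = np.random.RandomState(0)
    length, batch_size, dim_c, dim = 5, 2, 8, 4
    context = rng.randn(length, batch_size, dim_c)   # annotation vectors
    state = rng.randn(batch_size, dim)               # last hidden state
    att_context = rng.randn(dim_c, dim)              # hypothetical weights
    att_hidden = rng.randn(dim, dim)
    att = rng.randn(dim)

    att_c = context @ att_context                    # (length, batch, dim)
    att_before = state @ att_hidden                  # (batch, dim)
    energy = np.exp(np.tanh(att_c + att_before[None]) @ att)  # (length, batch)
    attention = energy / energy.sum(axis=0)          # softmax over positions
    c = (context * attention[:, :, None]).sum(axis=0)  # (batch, dim_c)
    assert np.allclose(attention.sum(axis=0), 1.0)   # weights sum to 1 per column
    return c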
def forward(self, emb_in, length, context, state_init, parent_t_seq,
            batch_size=1, mask=None, cmask=None):
    '''
    Build the computational graph which computes the hidden states.

    :type emb_in: theano variable
    :param emb_in: the input word embeddings

    :type length: theano variable
    :param length: the length of the input

    :type context: theano variable
    :param context: the context vectors

    :type state_init: theano variable
    :param state_init: the initial states

    :type parent_t_seq: theano variable
    :param parent_t_seq: for each time step, the time step of its parent, used to
                         look up the parent's hidden state in the state history

    :type batch_size: int
    :param batch_size: the batch size

    :type mask: theano variable
    :param mask: indicates the length of each sequence in one batch

    :type cmask: theano variable
    :param cmask: indicates the length of each context sequence in one batch
    '''
    # calculate the input vectors for the inputter, updater and resetter
    att_c = tools.dot3d(context, self.att_context)  # size: (length_annot, batch_size, dim)
    state_in = (tensor.dot(emb_in, self.input_emb) + self.input_emb_offset).reshape((length, batch_size, self.dim))
    gate_in = tensor.dot(emb_in, self.gate_emb).reshape((length, batch_size, self.dim))
    reset_in = tensor.dot(emb_in, self.reset_emb).reshape((length, batch_size, self.dim))

    # history of hidden states so each step can read its parent's state:
    # slot 0 stays empty, slot 1 holds the initial state (time 1), and the
    # scanned steps fill slots 2 .. length + 1 (hence time_steps starts at 2)
    length_annot, batch_size, dim = att_c.shape[0], att_c.shape[1], att_c.shape[2]
    state_hist = tools.alloc_zeros_matrix(length + 2, batch_size, dim)
    state_hist = tensor.set_subtensor(state_hist[1, :, :], state_init)
    time_steps = tensor.arange(length, dtype='int64') + 2

    # mask is either None or a symbolic variable, so test identity rather than truth
    if mask is not None:
        scan_inp = [time_steps, state_in, gate_in, reset_in, mask, parent_t_seq]
        scan_func = lambda t, x, g, r, m, par_t, h, s_hist, c, attc, cm: \
            self.forward_step(t, h, s_hist, x, g, r, par_t, c, attc, m, cm)
        non_seqs = [context, att_c, cmask]
    else:
        scan_inp = [time_steps, state_in, gate_in, reset_in, parent_t_seq]
        scan_func = lambda t, x, g, r, par_t, h, s_hist, c, attc: \
            self.forward_step(t, h, s_hist, x, g, r, par_t, c, attc)
        # the unmasked step takes no cmask, so it must not appear in non_sequences
        non_seqs = [context, att_c]

    if self.verbose:
        outputs_info = [state_init, state_hist] + [None] * 12
    else:
        outputs_info = [state_init, state_hist, None, None]

    # calculate hidden states
    hiddens, updates = theano.scan(scan_func,
                                   sequences=scan_inp,
                                   outputs_info=outputs_info,
                                   non_sequences=non_seqs,
                                   n_steps=length)
    c = hiddens[2]
    attentions = hiddens[3]

    # add the initial state and discard the last hidden state
    state_before = tensor.concatenate((state_init.reshape((1, state_init.shape[0], state_init.shape[1])),
                                       hiddens[0][:-1]))
    state_in_prev = tensor.dot(emb_in, self.readout_emb).reshape((length, batch_size, self.dim))

    # calculate the energy for each word
    readout_c = tensor.dot(c, self.readout_context)
    readout_h = tensor.dot(state_before, self.readout_hidden)
    readout_h += self.readout_offset
    state_in_prev = tools.shift_one(state_in_prev)
    readout = readout_c + readout_h + state_in_prev
    readout = readout.reshape((readout.shape[0] * readout.shape[1], readout.shape[2]))
    maxout = tools.maxout(readout, self.maxout)
    outenergy = tensor.dot(maxout, self.probs_emb)
    outenergy_1 = outenergy
    outenergy = tensor.dot(outenergy, self.probs)
    outenergy_2 = outenergy
    outenergy += self.probs_offset

    if self.verbose:
        return hiddens, outenergy, state_in, gate_in, reset_in, state_in_prev, readout, maxout, outenergy_1, outenergy_2
    else:
        return hiddens, outenergy, attentions
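# --- Hedged sketch (not part of the model): how the state-history buffer and
# parent_t_seq could interact. forward_step is not shown in this file, so the
# exact indexing convention is an assumption; here slot t of state_hist holds
# the state produced at time t, slot 1 holds state_init and slot 0 stays empty,
# matching time_steps = arange(length) + 2 above.
def _parent_lookup_sketch_numpy():
    import numpy as np
    length, batch_size, dim = 4, 2, 3
    state_init = np.zeros((batch_size, dim))
    state_hist = np.zeros((length + 2, batch_size, dim))
    state_hist[1] = state_init
    # hypothetical parent times; each entry must point at an already-filled slot
    parent_t_seq = np.array([[1, 1], [2, 1], [2, 2], [3, 2]])
    for step, t in enumerate(range(2, length + 2)):
        # gather each batch element's parent state from the history buffer
        parent_state = state_hist[parent_t_seq[step], np.arange(batch_size)]
        state_hist[t] = np.tanh(parent_state + 0.1)  # stand-in for the GRU step
    return state_hist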
def forward(self, emb_in, length, context, state_init,
            batch_size=1, mask=None, cmask=None):
    '''
    Build the computational graph which computes the hidden states.

    :type emb_in: theano variable
    :param emb_in: the input word embeddings

    :type length: theano variable
    :param length: the length of the input

    :type context: theano variable
    :param context: the context vectors

    :type state_init: theano variable
    :param state_init: the initial states

    :type batch_size: int
    :param batch_size: the batch size

    :type mask: theano variable
    :param mask: indicates the length of each sequence in one batch

    :type cmask: theano variable
    :param cmask: indicates the length of each context sequence in one batch
    '''
    # calculate the input vectors for the inputter, updater and resetter
    att_c = tools.dot3d(context, self.att_context)  # size: (length, batch_size, dim)
    state_in = (tensor.dot(emb_in, self.input_emb) + self.input_emb_offset).reshape((length, batch_size, self.dim))
    gate_in = tensor.dot(emb_in, self.gate_emb).reshape((length, batch_size, self.dim))
    reset_in = tensor.dot(emb_in, self.reset_emb).reshape((length, batch_size, self.dim))

    # mask is either None or a symbolic variable, so test identity rather than truth
    if mask is not None:
        scan_inp = [state_in, gate_in, reset_in, mask]
        scan_func = lambda x, g, r, m, h, c, attc, cm: \
            self.forward_step(h, x, g, r, c, attc, m, cm)
        non_seqs = [context, att_c, cmask]
    else:
        scan_inp = [state_in, gate_in, reset_in]
        scan_func = lambda x, g, r, h, c, attc: \
            self.forward_step(h, x, g, r, c, attc)
        # the unmasked step takes no cmask, so it must not appear in non_sequences
        non_seqs = [context, att_c]

    if self.verbose:
        outputs_info = [state_init] + [None] * 12
    else:
        outputs_info = [state_init, None, None]

    # calculate hidden states
    hiddens, updates = theano.scan(scan_func,
                                   sequences=scan_inp,
                                   outputs_info=outputs_info,
                                   non_sequences=non_seqs,
                                   n_steps=length)
    c = hiddens[1]
    attentions = hiddens[2]

    # add the initial state and discard the last hidden state
    state_before = tensor.concatenate((state_init.reshape((1, state_init.shape[0], state_init.shape[1])),
                                       hiddens[0][:-1]))
    state_in_prev = tensor.dot(emb_in, self.readout_emb).reshape((length, batch_size, self.dim))

    # calculate the energy for each word
    readout_c = tensor.dot(c, self.readout_context)
    readout_h = tensor.dot(state_before, self.readout_hidden)
    readout_h += self.readout_offset
    state_in_prev = tools.shift_one(state_in_prev)
    readout = readout_c + readout_h + state_in_prev
    readout = readout.reshape((readout.shape[0] * readout.shape[1], readout.shape[2]))
    maxout = tools.maxout(readout, self.maxout)
    outenergy = tensor.dot(maxout, self.probs_emb)
    outenergy_1 = outenergy
    outenergy = tensor.dot(outenergy, self.probs)
    outenergy_2 = outenergy
    outenergy += self.probs_offset

    if self.verbose:
        return hiddens, outenergy, state_in, gate_in, reset_in, state_in_prev, readout, maxout, outenergy_1, outenergy_2
    else:
        return hiddens, outenergy, attentions
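# --- Hedged sketch (not part of the model): tools.shift_one is not shown here,
# but from its use in both forward variants it presumably delays the
# previous-word embeddings by one time step, so the readout at step t
# conditions on word t - 1 (zeros at t = 0). A minimal version under that
# assumption:
def _shift_one_sketch_numpy(seq):
    import numpy as np
    shifted = np.zeros_like(seq)  # out[0] = 0, out[t] = seq[t - 1]
    shifted[1:] = seq[:-1]
    return shifted
# e.g. a (length, batch, dim) sequence [e0, e1, e2] becomes [0, e0, e1], which
# mirrors how state_before pairs state_init with the first target word above.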