def transduce(self, embeds):
    expr_seq = []
    seq_len = embeds.dim()[0][1]
    for i in range(seq_len):
        expr_seq.append(dy.max_dim(dy.select_cols(embeds, [i]), 1))
    encodings = self.seq_transducer.transduce(ExpressionSequence(expr_seq))
    return self.seq_transducer.get_final_states()[-1].main_expr()
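# Hedged standalone sketch (not part of the original class): how the loop above turns a
# (d, seq_len) embedding matrix into a list of per-position column vectors. dy.select_cols
# keeps each column as a (d, 1) matrix; dy.max_dim over dimension 1 collapses the singleton
# column dimension so each element can be fed to an ExpressionSequence.
import dynet as dy
import numpy as np

dy.renew_cg()
embeds = dy.inputTensor(np.random.rand(5, 7))   # d=5, seq_len=7
seq_len = embeds.dim()[0][1]                    # 7
expr_seq = [dy.max_dim(dy.select_cols(embeds, [i]), 1) for i in range(seq_len)]
print(len(expr_seq), expr_seq[0].dim())         # 7 column expressions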
def translate(self, x, beam_size=1):
    """Translate a source sentence

    Translate a single source sentence by decoding using beam search

    Arguments:
        x (list): Source sentence (list of indices)

    Keyword Arguments:
        beam_size (int): Size of the beam for beam search. A value of 1 means greedy decoding (default: 1)

    Returns:
        list: Generated translation (list of indices)
    """
    dy.renew_cg()
    input_len = len(x)
    encodings = self.encode([x], test=True)
    # Decode
    # Add parameters to the graph
    Wp, bp = self.Wp_p.expr(), self.bp_p.expr()
    Wo, bo = self.Wo_p.expr(), self.bo_p.expr()
    D, b = dy.transpose(dy.parameter(self.MT_p)), self.b_p.expr()
    # Initialize decoder with last encoding
    last_enc = dy.select_cols(encodings, [encodings.dim()[0][-1] - 1])
    init_state = dy.affine_transform([bp, Wp, last_enc])
    ds = self.dec.initial_state([init_state, dy.zeroes((self.dh, ))])
    # Initialize context
    context = dy.zeroes((self.enc_dim, ))
    # Initialize beam
    beam = [(ds, context, [self.trg_sos], 0.0)]
    # Loop
    for i in range(int(min(self.max_len, input_len * 1.5))):
        new_beam = []
        for ds, pc, pw, logprob in beam:
            embs = dy.lookup(self.MT_p, pw[-1])
            # Run LSTM
            ds = ds.add_input(dy.concatenate([embs, pc]))
            h = ds.output()
            # Compute next context
            context, _ = self.attend(encodings, h)
            # Compute output with residual connections
            output = dy.affine_transform(
                [bo, Wo, dy.concatenate([h, context, embs])])
            # Score
            s = dy.affine_transform([b, D, output])
            # Probabilities
            p = dy.softmax(s).npvalue().flatten()
            # Careful of float error
            p = p / p.sum()
            kbest = np.argsort(p)
            for nw in kbest[-beam_size:]:
                new_beam.append(
                    (ds, context, pw + [nw], logprob + np.log(p[nw])))
        beam = sorted(new_beam, key=lambda x: x[-1])[-beam_size:]
        if beam[-1][2][-1] == self.trg_eos:
            break
    return beam[-1][2]
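# Hedged toy illustration (pure numpy, synthetic probabilities) of the top-k selection used
# above: np.argsort sorts ascending, so the last `beam_size` indices are the most probable
# next words for each beam hypothesis.
import numpy as np

p = np.array([0.1, 0.5, 0.15, 0.25])
beam_size = 2
kbest = np.argsort(p)[-beam_size:]
assert list(kbest) == [3, 1]   # indices of the two highest probabilities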
def decode_loss(self, src_encodings, masks, tgt_seqs, sents_len):
    """
    :param tgt_seqs: (tgt_heads, tgt_labels): list (length=batch_size) of (src_len)
    """
    tgt_heads, tgt_labels = tgt_seqs

    src_len = len(tgt_heads[0])
    batch_size = len(tgt_heads)
    np_tgt_heads = np.array(tgt_heads).flatten()  # (src_len * batch_size)
    np_tgt_labels = np.array(tgt_labels).flatten()
    np_masks = np.array(masks).transpose().flatten()
    masks_expr = dy.inputVector(np_masks)
    masks_expr = dy.reshape(masks_expr, (1, ), batch_size=src_len * batch_size)

    s_arc, s_label = self.cal_scores(
        src_encodings)  # (src_len, src_len, bs), ([(src_len, src_len, bs)])
    s_arc = dy.select_cols(s_arc, range(1, src_len + 1))
    s_label = [
        dy.select_cols(label, range(1, src_len + 1)) for label in s_label
    ]

    s_arc_value = s_arc.npvalue()
    s_arc_choice = np.argmax(
        s_arc_value, axis=0).transpose().flatten()  # (src_len * batch_size)

    s_pick_labels = [
        dy.pick_batch(
            dy.reshape(score, (src_len + 1, ),
                       batch_size=src_len * batch_size),
            s_arc_choice) for score in s_label
    ]
    s_argmax_labels = dy.concatenate(s_pick_labels, d=0)  # (n_labels, src_len * batch_size)

    reshape_s_arc = dy.reshape(
        s_arc, (src_len + 1, ), batch_size=src_len * batch_size)
    arc_loss = dy.pickneglogsoftmax_batch(reshape_s_arc, np_tgt_heads)
    arc_loss = arc_loss * masks_expr
    label_loss = dy.pickneglogsoftmax_batch(s_argmax_labels, np_tgt_labels)
    label_loss = label_loss * masks_expr
    loss = dy.sum_batches(arc_loss + label_loss)
    return loss
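# Hedged sketch (synthetic shapes, not the parser's real data) of the reshape + pick_batch
# trick used above: a (n, m) score matrix is re-interpreted as a batch of m column vectors
# of size (n,), and dy.pick_batch then selects one entry per batch element.
import dynet as dy
import numpy as np

dy.renew_cg()
scores = dy.inputTensor(np.arange(12.0).reshape(3, 4))   # (3, 4) matrix
batched = dy.reshape(scores, (3, ), batch_size=4)        # 4 batch elements of size (3,)
choices = [0, 2, 1, 0]                                    # one index per batch element
picked = dy.pick_batch(batched, choices)                  # one scalar per batch element
print(picked.npvalue())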
def split(x, dim=1):
    head_shape, batch_size = x.dim()
    res = []
    if dim == 0:
        for i in range(head_shape[0]):
            res.append(dy.select_rows(x, [i]))
    elif dim == 1:
        for i in range(head_shape[1]):
            res.append(dy.select_cols(x, [i]))
    return res
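# Hedged usage sketch for the helper above (assumes the `split` defined here): splitting
# along dim=1 yields one (rows, 1) expression per column, splitting along dim=0 one
# (1, cols) expression per row.
import dynet as dy
import numpy as np

dy.renew_cg()
m = dy.inputTensor(np.ones((2, 3)))
cols = split(m, dim=1)
rows = split(m, dim=0)
assert len(cols) == 3 and cols[0].dim() == ((2, 1), 1)
assert len(rows) == 2 and rows[0].dim() == ((1, 3), 1)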
def decode_loss(self, encodings, trg, test=False):
    """Compute the negative conditional log likelihood of the target sentence
    given the encoding of the source sentence

    Arguments:
        encodings (dynet.Expression): Source sentence encodings obtained with self.encode
        trg (list): List of target sentences

    Keyword Arguments:
        test (bool): Switch used for things like dropout where the behaviour is different at test time (default: False)

    Returns:
        dynet.Expression: Expression of the loss averaged on the minibatch
    """
    y, masksy = self.prepare_batch(trg, self.trg_eos)
    slen, bsize = y.shape
    # Add parameters to the graph
    Wp, bp = self.Wp_p.expr(), self.bp_p.expr()
    Wo, bo = self.Wo_p.expr(), self.bo_p.expr()
    D, b = dy.transpose(dy.parameter(self.MT_p)), self.b_p.expr()
    # Initialize decoder with last encoding
    last_enc = dy.select_cols(encodings, [encodings.dim()[0][-1] - 1])
    init_state = dy.affine_transform([bp, Wp, last_enc])
    ds = self.dec.initial_state(
        [init_state, dy.zeroes((self.dh, ), batch_size=bsize)])
    # Initialize context
    context = dy.zeroes((self.enc_dim, ), batch_size=bsize)
    # Start decoding
    errs = []
    for cw, nw, mask in zip(y, y[1:], masksy[1:]):
        embs = dy.lookup_batch(self.MT_p, cw)
        # Run LSTM
        ds = ds.add_input(dy.concatenate([embs, context]))
        h = ds.output()
        # Compute next context
        context, _ = self.attend(encodings, h)
        # Compute output with residual connections
        output = dy.affine_transform(
            [bo, Wo, dy.concatenate([h, context, embs])])
        if not test:
            output = dy.dropout(output, self.dr)
        # Score
        s = dy.affine_transform([b, D, output])
        masksy_e = dy.inputTensor(mask, batched=True)
        # Loss
        err = dy.cmult(dy.pickneglogsoftmax_batch(s, nw), masksy_e)
        errs.append(err)
    # Add all losses together
    err = dy.sum_batches(dy.esum(errs)) / float(bsize)
    return err
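# Hedged sketch (toy vocabulary and batch, not the model's real tensors) of the per-timestep
# masking used above: a batched scalar mask zeroes out the loss of padded positions before
# the losses are summed over the batch.
import dynet as dy
import numpy as np

dy.renew_cg()
logits = dy.inputTensor(np.random.rand(10, 4), batched=True)   # vocab=10, batch=4
targets = [1, 3, 0, 2]                                          # gold word id per sentence
mask = dy.inputTensor([1.0, 1.0, 0.0, 1.0], batched=True)       # third sentence is padding here
loss = dy.cmult(dy.pickneglogsoftmax_batch(logits, targets), mask)
total = dy.sum_batches(loss)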
def _get_scores(self, sentence, is_train, elmo_embeddings):
    lstm_outputs = self._featurize_sentence(
        sentence, is_train=is_train, elmo_embeddings=elmo_embeddings)
    other_encodings = []
    single_word_encodings = []
    temporary_span_to_index = {}
    for left in range(len(sentence)):
        for right in range(left + 1, len(sentence) + 1):
            encoding = self._get_span_encoding(left, right, lstm_outputs)
            span = (left, right)
            if right - left == 1:
                temporary_span_to_index[span] = len(single_word_encodings)
                single_word_encodings.append(encoding)
            else:
                temporary_span_to_index[span] = len(other_encodings)
                other_encodings.append(encoding)

    encodings = single_word_encodings + other_encodings
    span_to_index = {}
    for span, index in temporary_span_to_index.items():
        if span[1] - span[0] == 1:
            new_index = index
        else:
            new_index = index + len(single_word_encodings)
        span_to_index[span] = new_index

    span_encodings = dy.rectify(
        dy.reshape(self.f_encoding(dy.concatenate_to_batch(encodings)),
                   (self.hidden_dim, len(encodings))))
    label_scores = self.f_label(span_encodings)
    label_scores_reshaped = dy.reshape(
        label_scores, (self.label_vocab.size, len(encodings)))
    label_log_probabilities = dy.log_softmax(label_scores_reshaped)
    single_word_span_encodings = dy.select_cols(
        span_encodings, list(range(len(single_word_encodings))))
    tag_scores = self.f_tag(single_word_span_encodings)
    tag_scores_reshaped = dy.reshape(
        tag_scores, (self.tag_vocab.size, len(single_word_encodings)))
    tag_log_probabilities = dy.log_softmax(tag_scores_reshaped)
    return label_log_probabilities, tag_log_probabilities, span_to_index
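# Hedged sketch (toy dimensions, plain identity "transform") of the batching pattern above:
# a list of vectors is packed into one batched expression with dy.concatenate_to_batch,
# processed once, folded back into a (dim, n) matrix with dy.reshape, and individual
# columns are then read out with dy.select_cols.
import dynet as dy
import numpy as np

dy.renew_cg()
vecs = [dy.inputTensor(np.random.rand(8)) for _ in range(5)]
batched = dy.concatenate_to_batch(vecs)        # (8,) expression with batch size 5
matrix = dy.reshape(batched, (8, 5))           # fold the batch into columns of one matrix
first_two = dy.select_cols(matrix, [0, 1])     # (8, 2) slice for the first two items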
def predict(self, text):
    dy.renew_cg()
    output_mgc = []
    last_mgc = self.start_lookup[0]

    x = self._make_input(text)
    x_speaker = self._get_speaker_embedding()
    x_fw = self.encoder_fw.initial_state().transduce(x)
    x_bw = self.encoder_bw.initial_state().transduce(reversed(x))
    encoder = [
        dy.concatenate([fw, bw, x_speaker])
        for fw, bw in zip(x_fw, reversed(x_bw))
    ]
    encoder = dy.concatenate(encoder, 1)
    hidden_encoder = self.att_w1 * encoder

    decoder = self.decoder.initial_state().add_input(
        self.decoder_start_lookup[0])
    last_att_pos = 0
    warm_up = 5
    finish = 5
    for k in range(5 * len(text)):
        attention_weights = (self.att_v * dy.tanh(
            hidden_encoder + self.att_w2 * decoder.output()))[0]
        current_pos = np.argmax(attention_weights.npvalue())
        if current_pos < last_att_pos:
            current_pos = last_att_pos
        if current_pos > last_att_pos:
            current_pos = last_att_pos + 1
        last_att_pos = current_pos
        att = dy.select_cols(encoder, [current_pos])

        if warm_up > 0:
            last_att_pos = 0
            warm_up -= 1

        if last_att_pos >= len(text):
            if finish > 0:
                finish -= 1
            else:
                break

        mgc_proj = dy.tanh(self.last_mgc_proj_w * last_mgc + self.last_mgc_proj_b)
        decoder = decoder.add_input(dy.concatenate([mgc_proj, att]))
        hidden = dy.tanh(self.hid_w * decoder.output() + self.hid_b)

        highway_hidden = self.highway_w * att
        output_mgc.append(
            dy.logistic(highway_hidden + self.proj_w_1 * hidden + self.proj_b_1))
        output_mgc.append(
            dy.logistic(highway_hidden + self.proj_w_2 * hidden + self.proj_b_2))
        output_mgc.append(
            dy.logistic(highway_hidden + self.proj_w_3 * hidden + self.proj_b_3))
        last_mgc = output_mgc[-1]

        '''
        The stop layer finishes decoding early in about 40% of cases, on average 3 steps sooner.
        However, it introduces another matmul, and further testing showed that decoding is faster
        without it. That testing was done at the sentence level; results may be better with
        word-level parallelization, so the layer is not being removed yet.
        '''
        # output_stop = dy.tanh(self.stop_w * decoder.output() + self.stop_b)
        # if output_stop.value() < -0.5:
        #     break
    return output_mgc
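# Hedged toy illustration (pure numpy, hypothetical helper name) of the monotonic attention
# constraint applied above: the attended position may never move backwards and may advance
# by at most one step per decoded frame.
import numpy as np

def next_position(attention_weights, last_att_pos):
    current_pos = int(np.argmax(attention_weights))
    if current_pos < last_att_pos:
        current_pos = last_att_pos        # never move backwards
    if current_pos > last_att_pos:
        current_pos = last_att_pos + 1    # advance by at most one step
    return current_pos

assert next_position(np.array([0.1, 0.2, 0.9]), last_att_pos=0) == 1
assert next_position(np.array([0.9, 0.1, 0.1]), last_att_pos=2) == 2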
def split_cols(matrix):
    # dim() returns the shape tuple and the batch size
    (rows, cols), bt = matrix.dim()
    assert bt == 1
    return [dy.reshape(dy.select_cols(matrix, [i]), (rows, ), batch_size=bt)
            for i in range(cols)]
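# Hedged usage sketch for split_cols (assumes the version defined above): each returned
# expression is a plain (rows,) vector because of the final reshape.
import dynet as dy
import numpy as np

dy.renew_cg()
m = dy.inputTensor(np.zeros((4, 6)))
columns = split_cols(m)
assert len(columns) == 6
assert columns[0].dim() == ((4, ), 1)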
def _span_parser_predict_pos(self, sentence, is_train, elmo_embeddings, gold=None):
    if gold is not None:
        assert isinstance(gold, ParseNode)
    lstm_outputs = self._featurize_sentence(sentence, is_train=is_train,
                                            elmo_embeddings=elmo_embeddings)
    other_encodings = []
    single_word_encodings = []
    temporary_span_to_index = {}
    for left in range(len(sentence)):
        for right in range(left + 1, len(sentence) + 1):
            encoding = self._get_span_encoding(left, right, lstm_outputs)
            span = (left, right)
            if right - left == 1:
                temporary_span_to_index[span] = len(single_word_encodings)
                single_word_encodings.append(encoding)
            else:
                temporary_span_to_index[span] = len(other_encodings)
                other_encodings.append(encoding)

    encodings = single_word_encodings + other_encodings
    span_to_index = {}
    for span, index in temporary_span_to_index.items():
        if span[1] - span[0] == 1:
            new_index = index
        else:
            new_index = index + len(single_word_encodings)
        span_to_index[span] = new_index

    span_encodings = dy.rectify(
        dy.reshape(self.f_encoding(dy.concatenate_to_batch(encodings)),
                   (self.hidden_dim, len(encodings))))
    label_scores = self.f_label(span_encodings)
    label_scores_reshaped = dy.reshape(label_scores,
                                       (self.label_vocab.size, len(encodings)))
    label_log_probabilities = dy.log_softmax(label_scores_reshaped)
    single_word_span_encodings = dy.select_cols(
        span_encodings, list(range(len(single_word_encodings))))
    tag_scores = self.f_tag(single_word_span_encodings)
    tag_scores_reshaped = dy.reshape(tag_scores,
                                     (self.tag_vocab.size, len(single_word_encodings)))
    tag_log_probabilities = dy.log_softmax(tag_scores_reshaped)

    if is_train:
        total_loss = dy.zeros(1)
        span_to_gold_label = get_all_spans(gold)
        for span, oracle_label in span_to_gold_label.items():
            oracle_label_index = self.label_vocab.index(oracle_label)
            index = span_to_index[span]
            if span[1] - span[0] == 1:
                oracle_tag = sentence[span[0]][0]
                total_loss -= tag_log_probabilities[self.tag_vocab.index(oracle_tag)][index]
            total_loss -= label_log_probabilities[oracle_label_index][index]
        return total_loss
    else:
        label_log_probabilities_np = label_log_probabilities.npvalue()
        tag_log_probabilities_np = tag_log_probabilities.npvalue()
        sentence_with_tags = []
        num_correct = 0
        total = 0
        # print('output has gold pos tags')
        for word_index, (oracle_tag, word) in enumerate(sentence):
            # tag_index = np.argmax(tag_log_probabilities_np[:, word_index])
            # tag = self.tag_vocab.value(tag_index)
            # oracle_tag_is_deletable = oracle_tag in deletable_tags
            # predicted_tag_is_deletable = tag in deletable_tags
            # if oracle_tag is not None and oracle_tag not in self.tag_vocab.indices:
            #     print(oracle_tag, 'not in tag vocab')
            #     oracle_tag = None
            # if oracle_tag is not None:
            #     oracle_tag_index = self.tag_vocab.index(oracle_tag)
            #     if oracle_tag_index == tag_index and tag != oracle_tag:
            #         if oracle_tag[0] != '-':
            #             print(tag, oracle_tag)
            #         tag = oracle_tag
            #     num_correct += tag_index == oracle_tag_index
            # if oracle_tag is not None and oracle_tag_is_deletable != predicted_tag_is_deletable:
            #     # print('falling back on gold tag', oracle_tag, tag)
            #     sentence_with_tags.append((oracle_tag, word))
            # else:
            #     sentence_with_tags.append((tag, word))
            assert oracle_tag is not None
            sentence_with_tags.append((oracle_tag, word))
            total += 1
        tree, additional_info = optimal_parser(label_log_probabilities_np,
                                               span_to_index,
                                               sentence_with_tags,
                                               self.empty_label_index,
                                               self.label_vocab,
                                               gold)
        return tree, (additional_info, num_correct, total)