def _forward(self, emissions):
    """Forward algorithm to calculate the score summed over all paths.

    :param emissions: List[dy.Expression]

    Returns:
        dy.Expression ((1,), B)
    """
    init_alphas = [-1e4] * self.n_tags
    init_alphas[self.start_idx] = 0

    alphas = dy.inputVector(init_alphas)
    transitions = self.transitions
    # len(emissions) == T
    for emission in emissions:
        add_emission = dy.colwise_add(transitions, emission)
        scores = dy.colwise_add(dy.transpose(add_emission), alphas)
        # dy.logsumexp takes a list of dy.Expression and computes logsumexp
        # elementwise across the list, so for example the logsumexp is
        # calculated over element [0] of each expression. This means the
        # transition scores into a given tag need to be in the columns.
        alphas = dy.logsumexp([x for x in scores])
    last_alpha = alphas + dy.pick(transitions, self.end_idx)
    alpha = dy.logsumexp([x for x in last_alpha])
    return alpha
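
# A minimal standalone sketch (not part of the class above) of the
# dy.logsumexp behaviour the comment relies on: given a list of expressions,
# the reduction is elementwise across the list, matching a NumPy sum over
# axis 0 in log space. The toy sizes and random values are assumptions.
import numpy as np
import dynet as dy

dy.renew_cg()
rows_np = np.random.rand(3, 4)
rows = [dy.inputVector(r) for r in rows_np]      # a list of 3 length-4 vectors
reduced = dy.logsumexp(rows)                     # elementwise across the list
expected = np.log(np.exp(rows_np).sum(axis=0))   # same reduction in NumPy
assert np.allclose(reduced.npvalue(), expected, atol=1e-5)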
def logmulexp_vM(vec, Mat):
    '''Calculate log(exp(vec) * exp(Mat)), where * means the matrix is
    left-multiplied by the (row) vector.

    @param vec: a vector expression, dim (d,)
    @param Mat: a matrix expression, dim (d, d)
    @return: a vector expression, dim (d,)
    '''
    vdims, Mdims = vec.dim()[0], Mat.dim()[0]
    assert len(vdims) == 1 and len(Mdims) == 2, "check dimension"
    d = vdims[0]
    # Temp[i][j] = Mat[i][j] + vec[i]
    Temp = dy.colwise_add(Mat, vec)
    rows = [Temp[i] for i in range(d)]
    # ret[j] = log sum_i exp(vec[i] + Mat[i][j])
    ret = dy.logsumexp(rows)
    assert len(ret.dim()[0]) == 1, "check dimension"
    return ret
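
# Hypothetical usage sketch for logmulexp_vM, checking it against the dense
# NumPy computation log(exp(vec) @ exp(Mat)); the toy size is an assumption.
import numpy as np
import dynet as dy

dy.renew_cg()
v_np = np.random.rand(4)
M_np = np.random.rand(4, 4)
out = logmulexp_vM(dy.inputVector(v_np), dy.inputTensor(M_np))
expected = np.log(np.exp(v_np) @ np.exp(M_np))
assert np.allclose(out.npvalue(), expected, atol=1e-5)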
def get_logloss_partition(factorexprs, valid_fes, sentlen):
    logalpha = [None for _ in range(sentlen)]
    # ssum = lossformula(sentlen, len(valid_fes))
    for j in range(sentlen):
        # full-length spans
        spanscores = []
        if not USE_SPAN_CLIP or j <= ALLOWED_SPANLEN:
            spanscores = [factorexprs[Factor(0, j, y)] for y in valid_fes]

        # recursive case
        istart = 0
        if USE_SPAN_CLIP and j > ALLOWED_SPANLEN:
            istart = max(0, j - ALLOWED_SPANLEN - 1)
        for i in range(istart, j):
            facscores = [logalpha[i] + factorexprs[Factor(i + 1, j, y)]
                         for y in valid_fes]
            spanscores.extend(facscores)

        if not USE_SPAN_CLIP and len(spanscores) != len(valid_fes) * (j + 1):
            raise Exception("counting errors")
        logalpha[j] = dy.logsumexp(spanscores)

    return logalpha[sentlen - 1]
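
# A toy sketch of how get_logloss_partition might be driven. Factor,
# USE_SPAN_CLIP and ALLOWED_SPANLEN are module-level names in the original
# code base; the namedtuple stand-in, flag values, and toy scores below are
# assumptions for illustration only.
from collections import namedtuple
import dynet as dy

Factor = namedtuple("Factor", ["begin", "end", "label"])  # hypothetical stand-in
USE_SPAN_CLIP = False
ALLOWED_SPANLEN = 20

dy.renew_cg()
sentlen, valid_fes = 3, [0, 1]
# one (learned, in practice) score per candidate span and frame element
factorexprs = {
    Factor(i, j, y): dy.scalarInput(0.1 * (i + j + y))
    for i in range(sentlen) for j in range(i, sentlen) for y in valid_fes
}
log_partition = get_logloss_partition(factorexprs, valid_fes, sentlen)
print(log_partition.value())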
def process_batch(self, batch, training=False, debug=False):
    if self.args.test_samples > 1 and not training:
        results = []
        for _ in range(self.args.test_samples):
            dynet.renew_cg()
            results.append(
                self.process_batch_internal(batch, training=training, debug=debug))
            results[-1]["loss"] = results[-1]["loss"].npvalue()
        dynet.renew_cg()
        for r in results:
            r["loss"] = dynet.inputTensor(r["loss"], batched=True)
        result = results[0]
        result["loss"] = -(dynet.logsumexp([-r["loss"] for r in results]) -
                           math.log(self.args.test_samples))
        return result
    else:
        return self.process_batch_internal(batch, training=training, debug=debug)
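
# The test-time combination above averages the per-sample likelihoods in log
# space: loss = -log((1/N) * sum_i exp(-loss_i)). A small NumPy check of that
# identity (the sample losses are made up):
import numpy as np

losses = np.array([2.3, 2.9, 2.5])  # per-sample negative log-likelihoods
combined = -(np.logaddexp.reduce(-losses) - np.log(len(losses)))
assert np.isclose(combined, -np.log(np.mean(np.exp(-losses))))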
        indices: Array-like object for selection of cost-augmented logits to be
            maximized."""
    # assumed: either valid_actions or costs, the latter should typically
    # incorporate validity of actions.
    if valid_actions is not None:
        try:
            costs = -np.ones(logits_len) * np.inf
            costs[valid_actions] = 0.
        except Exception as e:
            print("np.ones_like(logits), valid_actions, costs: ",
                  np.ones_like(logits), valid_actions, costs)
            raise e
    if costs is not None:
        if verbose == 2:
            print('Indices, costs: ', indices, costs)
        logits += dy.inputVector(costs)
    log_sum_selected_terms = dy.logsumexp(
        [dy.pick(logits, index=e) for e in indices])
    normalization_term = dy.logsumexp([l for l in logits])
    return log_sum_selected_terms - normalization_term


def log_sum_softmax_loss(indices, logits, logits_len, valid_actions=None,
                         verbose=False):
    """Compute dynamic-oracle negative log loss.

    Args:
        indices: Array-like object for selection of logits to be maximized."""
    # @TODO could just build cost vector from `valid_actions` and call
    # `log_sum_softmax_margin_loss`
    if valid_actions is not None:
        # build cost vector expressing action invalidity
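
# A standalone sketch of the pattern used above (the surrounding function
# definitions are truncated here): invalid actions are pushed to -inf via an
# additive cost vector, and the loss is logsumexp(selected) - logsumexp(all).
# The toy logit length, action set, and gold index are assumptions.
import numpy as np
import dynet as dy

dy.renew_cg()
logits_len = 5
logits = dy.inputVector(np.random.rand(logits_len))
valid_actions, indices = [0, 2, 3], [2]

costs = -np.ones(logits_len) * np.inf
costs[valid_actions] = 0.
logits = logits + dy.inputVector(costs)

log_sum_selected = dy.logsumexp([dy.pick(logits, index=e) for e in indices])
normalization = dy.logsumexp([dy.pick(logits, index=i) for i in range(logits_len)])
loss = -(log_sum_selected - normalization)  # NLL of the selected softmax mass
print(loss.value())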
def process_batch_internal(self, batch, training=False, debug=False):
    self.TRAINING_ITER = training
    self.DROPOUT = self.args.dropout if (
        self.TRAINING_ITER and self.args.dropout > 0) else None
    self.BATCH_SIZE = len(batch)
    self.instantiate_parameters()
    if self.args.use_cache:
        self.initialize_cache(batch)

    sents, masks = self.vocab.batchify(batch)

    # paths represent the different connections within the lattice. paths[i]
    # contains all the state/chunk pairs that end at index i
    paths = [[] for _ in range(len(sents))]
    paths[0] = [(self.rnn.fresh_state(init_to_zero=True), sents[0],
                 dynet.scalarInput(0.0, device=self.args.param_device))]
    for tok_i in range(len(sents) - 1):
        # calculate the total probability of reaching this state
        _, _, lps = zip(*paths[tok_i])
        if len(lps) == 1:
            cum_lp = lps[0]
        else:
            cum_lp = dynet.logsumexp(list(lps))

        # add all previous state/chunk pairs to the tree_lstm
        new_state = self.rnn.fresh_state()
        if self.TRAINING_ITER and self.args.train_with_random and not self.first_time_memory_test:
            raise Exception("bruh")
        else:
            self.first_time_memory_test = False
            for state, c_t, lp in paths[tok_i]:
                x_t = dynet.pick_batch(self.vocab_R, c_t)
                h_t_stack, c_t_stack = state.add_input(x_t)
                new_state.add_history(h_t_stack, c_t_stack, lp)

        # treeLSTM state merging
        new_state.concat_weights()
        if self.args.gumbel_sample:
            new_state.apply_gumbel_noise_to_weights(
                temperature=max(.25, self.args.temperature))
        if not self.TRAINING_ITER or self.args.sample_train:
            new_state.weights_to_argmax()
        # new_state.weights_to_argmax()

        # output of tree_lstm
        y_t = dynet.to_device(new_state.output(), self.args.param_device)
        if self.DROPOUT:
            y_t = dynet.cmult(y_t, self.dropout_mask_y_t)

        # get the list of next tokens to consider
        base_is = sents[tok_i + 1]
        n_ts = [[nt + (i * self.vocab.size) for nt in base_is]
                for i in range(self.args.multi_size)]
        r_t = dynet.affine_transform([
            self.vocab_bias, self.vocab_R,
            dynet.tanh(dynet.affine_transform([self.bias, self.R, y_t]))
        ])
        for n_t in n_ts:
            lp = -dynet.pickneglogsoftmax_batch(r_t, n_t)
            paths[tok_i + 1].append((new_state, n_t, cum_lp + lp))

    ending_masks = [[0.0] * self.BATCH_SIZE for _ in range(len(masks))]
    for sent_i in range(len(batch)):
        ending_masks[batch[sent_i].index(
            self.vocab.end_token.s)][sent_i] = 1.0

    # put together all of the final path states to get the final error
    cum_lp = dynet.scalarInput(0.0, device=self.args.param_device)
    for path, mask in zip(paths, ending_masks):
        if max(mask) == 1:
            assert len(path) != 0
            _, _, lps = zip(*path)
            if len(lps) == 1:
                local_cum_lp = lps[0]
            else:
                local_cum_lp = dynet.logsumexp(list(lps))
            cum_lp += local_cum_lp * dynet.inputTensor(
                mask, batched=True, device=self.args.param_device)

    if debug:
        return paths

    err = -cum_lp
    char_count = [1 + len(self.vocab.pp(sent[1:-1])) for sent in batch]
    word_count = [len(sent[1:]) for sent in batch]
    # word_count = [2+self.lattice_vocab.pp(sent[1:-1]).count(' ') for sent in batch]
    return {"loss": err, "charcount": char_count, "wordcount": word_count}
def process_batch_internal(self, batch, training=False, debug=False):
    self.TRAINING_ITER = training
    self.DROPOUT = self.args.dropout if (
        self.TRAINING_ITER and self.args.dropout > 0) else None
    self.BATCH_SIZE = len(batch)
    self.instantiate_parameters()
    if self.args.use_cache:
        self.initialize_cache(batch)

    sents, masks = self.lattice_vocab.batchify(batch)

    # paths represent the different connections within the lattice. paths[i]
    # contains all the state/chunk pairs that end at index i
    paths = [[] for _ in range(len(sents))]
    paths[0] = [(self.rnn.fresh_state(init_to_zero=True), [sents[0]],
                 dynet.scalarInput(0.0, device=self.args.param_device))]
    for tok_i in range(len(sents) - 1):
        # calculate the total probability of reaching this state
        _, _, lps = zip(*paths[tok_i])
        if len(lps) == 1:
            cum_lp = lps[0]
        else:
            cum_lp = dynet.logsumexp(list(lps))

        # add all previous state/chunk pairs to the tree_lstm
        new_state = self.rnn.fresh_state()
        if self.TRAINING_ITER and self.args.train_with_random and not self.first_time_memory_test:
            state, c_t, lp = random.choice(paths[tok_i])
            if self.args.use_cache:
                x_t = self.cached_embedding_lookup(c_t)
            else:
                x_t = self.get_chunk_embedding(c_t)
            h_t_stack, c_t_stack = state.add_input(x_t)
            new_state.add_history(h_t_stack, c_t_stack, lp)
        else:
            self.first_time_memory_test = False
            for state, c_t, lp in paths[tok_i]:
                if self.args.use_cache:
                    x_t = self.cached_embedding_lookup(c_t)
                else:
                    x_t = self.get_chunk_embedding(c_t)
                h_t_stack, c_t_stack = state.add_input(x_t)
                new_state.add_history(h_t_stack, c_t_stack, lp)

        # treeLSTM state merging
        new_state.concat_weights()
        if self.args.gumbel_sample:
            new_state.apply_gumbel_noise_to_weights(
                temperature=max(.25, self.args.temperature))
        if not self.TRAINING_ITER:
            new_state.weights_to_argmax()
        # new_state.weights_to_argmax()

        # output of tree_lstm
        y_t = new_state.output()
        y_t = dynet.to_device(y_t, self.args.param_device)
        if self.DROPOUT:
            y_t = dynet.cmult(y_t, self.dropout_mask_y_t)

        # based on lattice_size, decide what set of chunks to consider from here
        if self.args.lattice_size < 1:
            end_tok_i = len(sents)
        else:
            end_tok_i = min(tok_i + 1 + self.args.lattice_size, len(sents))
        next_chunks = sents[tok_i + 1:end_tok_i]

        # for each chunk, calculate the probability of that chunk, and then add
        # a pointer to the state/chunk into the place in the sentence where the
        # chunk will end
        assert not (self.args.no_fixed_preds and self.args.no_dynamic_preds)
        if not self.args.no_fixed_preds:
            fixed_chunk_lps, use_dynamic_lp = self.predict_chunks(
                y_t, next_chunks)
        if not self.args.no_dynamic_preds:
            dynamic_chunk_lps = self.predict_chunks_by_tokens(
                y_t, next_chunks)
        for chunk_i, tok_loc in enumerate(range(tok_i + 1, end_tok_i)):
            if self.args.no_fixed_preds:
                lp = dynamic_chunk_lps[chunk_i]
            elif self.args.no_dynamic_preds:
                lp = fixed_chunk_lps[chunk_i]
            else:
                # we are using both fixed & dynamic predictions
                lp = dynet.logsumexp([
                    fixed_chunk_lps[chunk_i],
                    use_dynamic_lp + dynamic_chunk_lps[chunk_i]
                ])
            paths[tok_loc].append(
                (new_state, sents[tok_i + 1:tok_loc + 1], cum_lp + lp))

    ending_masks = [[0.0] * self.BATCH_SIZE for _ in range(len(masks))]
    for sent_i in range(len(batch)):
        ending_masks[batch[sent_i].index(
            self.lattice_vocab.end_token.s)][sent_i] = 1.0

    # put together all of the final path states to get the final error
    cum_lp = dynet.scalarInput(0.0, device=self.args.param_device)
    for path, mask in zip(paths, ending_masks):
        if max(mask) == 1:
            assert len(path) != 0
            _, _, lps = zip(*path)
            if len(lps) == 1:
                local_cum_lp = lps[0]
            else:
                local_cum_lp = dynet.logsumexp(list(lps))
            cum_lp += local_cum_lp * dynet.inputTensor(
                mask, batched=True, device=self.args.param_device)

    if debug:
        return paths

    err = -cum_lp
    char_count = [
        1 + len(self.lattice_vocab.pp(sent[1:-1])) for sent in batch
    ]
    word_count = [len(sent[1:]) for sent in batch]
    # word_count = [2+self.lattice_vocab.pp(sent[1:-1]).count(' ') for sent in batch]
    return {"loss": err, "charcount": char_count, "wordcount": word_count}
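
# A small sketch of the masked, batched accumulation used at the end of both
# process_batch_internal variants above: a batched scalar expression is
# multiplied by a 0/1 mask tensor so only the sentences that actually end at
# this position contribute to cum_lp. The batch size and values are made up.
import numpy as np
import dynet

dynet.renew_cg()
local_cum_lp = dynet.inputTensor(np.array([-1.5, -2.0, -0.5]), batched=True)
mask = [1.0, 0.0, 1.0]                        # only sentences 0 and 2 end here
contribution = local_cum_lp * dynet.inputTensor(mask, batched=True)
print(contribution.npvalue())                 # approx. [-1.5, 0.0, -0.5] across the batch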
def logadd(x, y):
    """Binary logsumexp."""
    return dy.logsumexp([x, y])
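
# Usage sketch: accumulating probabilities in log space with the binary helper,
# checked against NumPy's logaddexp. The values are arbitrary.
import numpy as np
import dynet as dy

dy.renew_cg()
a, b = dy.scalarInput(-1.2), dy.scalarInput(-0.3)
total = logadd(a, b)  # log(exp(-1.2) + exp(-0.3))
assert np.isclose(total.value(), np.logaddexp(-1.2, -0.3), atol=1e-5)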