Example 1
    def _forward(self, emissions):

        """Viterbi forward to calculate all path scores.

        :param emissions: List[dy.Expression]

        Returns:
            dy.Expression ((1,), B)
        """
        init_alphas = [-1e4] * self.n_tags
        init_alphas[self.start_idx] = 0

        alphas = dy.inputVector(init_alphas)
        transitions = self.transitions
        # len(emissions) == T
        for emission in emissions:
            add_emission = dy.colwise_add(transitions, emission)
            scores = dy.colwise_add(dy.transpose(add_emission), alphas)
            # dy.logsumexp takes a list of dy.Expression and computes the logsumexp
            # elementwise across the list (e.g. the result's entry [0] is the
            # logsumexp over entry [0] of each expression). This means the scores
            # for all transitions into a given tag need to line up in the columns.
            alphas = dy.logsumexp([x for x in scores])
        last_alpha = alphas + dy.pick(transitions, self.end_idx)
        alpha = dy.logsumexp([x for x in last_alpha])
        return alpha
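As a sanity check on what this recursion computes, the same forward pass can be written in NumPy and compared against a brute-force sum over all tag paths. This sketch is not part of the original example; `scipy.special.logsumexp` stands in for `dy.logsumexp`, and the tag count, start/end indices, and random scores are made up.

```python
import itertools
import numpy as np
from scipy.special import logsumexp

def forward_log_partition(emissions, transitions, start_idx, end_idx):
    # transitions[i, j] = score of moving from tag j to tag i, as above;
    # alphas[j] holds the log-sum of scores of partial paths ending in tag j.
    n_tags = transitions.shape[0]
    alphas = np.full(n_tags, -1e4)
    alphas[start_idx] = 0.0
    for e in emissions:  # e has shape (n_tags,)
        scores = alphas[None, :] + transitions + e[:, None]
        alphas = logsumexp(scores, axis=1)
    return logsumexp(alphas + transitions[end_idx])

def brute_force_log_partition(emissions, transitions, start_idx, end_idx):
    n_tags, T = transitions.shape[0], len(emissions)
    path_scores = []
    for path in itertools.product(range(n_tags), repeat=T):
        score, prev = 0.0, start_idx
        for t, tag in enumerate(path):
            score += transitions[tag, prev] + emissions[t][tag]
            prev = tag
        path_scores.append(score + transitions[end_idx, prev])
    return logsumexp(path_scores)

rng = np.random.default_rng(0)
n_tags, T = 4, 3
transitions = rng.normal(size=(n_tags, n_tags))
emissions = [rng.normal(size=n_tags) for _ in range(T)]
assert np.isclose(forward_log_partition(emissions, transitions, 0, 1),
                  brute_force_log_partition(emissions, transitions, 0, 1))
```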
Example 2
def logmulexp_vM(vec, Mat):
    '''Calculate log(exp(vec) * exp(Mat)), where * means the matrix is
    left-multiplied by the vector.
    @param vec: a vector expression, dim (d,)
    @param Mat: a matrix expression, dim (d,d)
    @return: a vector expression, dim (d,)
    '''
    vdims, Mdims = vec.dim()[0], Mat.dim()[0]
    assert len(vdims) == 1 and len(Mdims) == 2, "check dimension"
    d = vdims[0]
    Temp = dy.colwise_add(Mat, vec)
    rows = [Temp[i] for i in range(d)]
    ret = dy.logsumexp(rows)
    assert len(ret.dim()[0]) == 1, "check dimension"
    return ret
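For reference, the identity this function computes can be written directly in NumPy: the result equals `log(exp(vec) @ exp(Mat))`, but evaluated in log space so nothing overflows. This sketch is not part of the original example; `scipy.special.logsumexp` plays the role of `dy.logsumexp`.

```python
import numpy as np
from scipy.special import logsumexp

def logmulexp_vM_np(vec, Mat):
    # ret[j] = log( sum_i exp(vec[i]) * exp(Mat[i, j]) ), computed stably
    return logsumexp(vec[:, None] + Mat, axis=0)

vec, Mat = np.random.randn(4), np.random.randn(4, 4)
assert np.allclose(logmulexp_vM_np(vec, Mat), np.log(np.exp(vec) @ np.exp(Mat)))
```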
Example 3
def get_logloss_partition(factorexprs, valid_fes, sentlen):
    logalpha = [None for _ in range(sentlen)]
    # ssum = lossformula(sentlen, len(valid_fes))
    for j in range(sentlen):
        # full length spans
        spanscores = []
        if not USE_SPAN_CLIP or j <= ALLOWED_SPANLEN:
            spanscores = [factorexprs[Factor(0, j, y)] for y in valid_fes]

        # recursive case
        istart = 0
        if USE_SPAN_CLIP and j > ALLOWED_SPANLEN: istart = max(0, j - ALLOWED_SPANLEN - 1)
        for i in range(istart, j):
            facscores = [logalpha[i] + factorexprs[Factor(i + 1, j, y)] for y in valid_fes]
            spanscores.extend(facscores)

        if not USE_SPAN_CLIP and len(spanscores) != len(valid_fes) * (j + 1):
            raise Exception("counting errors")
        logalpha[j] = dy.logsumexp(spanscores)

    return logalpha[sentlen - 1]
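The recursion above is a semi-Markov forward pass: `logalpha[j]` is the log-sum of scores over every way of segmenting and labeling the tokens up to position `j`. Below is a minimal NumPy sketch of the unclipped case (span clipping off); the callable `span_score(i, j, y)` is a hypothetical stand-in for `factorexprs[Factor(i, j, y)]`.

```python
import numpy as np
from scipy.special import logsumexp

def segmentation_log_partition(span_score, sentlen, labels):
    # logalpha[j] = logsumexp over all segmentations/labelings of tokens 0..j
    logalpha = np.empty(sentlen)
    for j in range(sentlen):
        scores = [span_score(0, j, y) for y in labels]           # one span covers 0..j
        for i in range(j):                                       # last span is i+1..j
            scores.extend(logalpha[i] + span_score(i + 1, j, y) for y in labels)
        logalpha[j] = logsumexp(scores)
    return logalpha[sentlen - 1]

# toy usage with a made-up scoring function and two labels
total = segmentation_log_partition(lambda i, j, y: -0.5 * (j - i + 1),
                                   sentlen=4, labels=[0, 1])
```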
    def process_batch(self, batch, training=False, debug=False):
        if self.args.test_samples > 1 and not training:
            results = []
            for _ in range(self.args.test_samples):
                dynet.renew_cg()
                results.append(
                    self.process_batch_internal(batch,
                                                training=training,
                                                debug=debug))
                results[-1]["loss"] = results[-1]["loss"].npvalue()
            dynet.renew_cg()
            for r in results:
                r["loss"] = dynet.inputTensor(r["loss"], batched=True)
            result = results[0]
            result["loss"] = -(dynet.logsumexp([-r["loss"] for r in results]) -
                               math.log(self.args.test_samples))
            return result
        else:
            return self.process_batch_internal(batch,
                                               training=training,
                                               debug=debug)
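The multi-sample branch combines the per-sample losses in log space: the returned loss is the negative log of the mean likelihood over `test_samples` draws, not the mean of the losses themselves. A small NumPy check of that identity, with made-up loss values:

```python
import numpy as np
from scipy.special import logsumexp

losses = np.array([2.3, 2.9, 2.5])   # hypothetical per-sample negative log-likelihoods
n = len(losses)

combined = -(logsumexp(-losses) - np.log(n))      # what the branch above computes
reference = -np.log(np.mean(np.exp(-losses)))     # negative log of the mean likelihood
assert np.isclose(combined, reference)
```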
           indices: Array-like object for selection of cost-augmented logits
               to be maximized."""
    # assumed: either valid_actions or costs, the latter should typically incorporate
    # validity of actions.
    if valid_actions is not None:
        try:
            costs = -np.ones(logits_len) * np.inf
            costs[valid_actions] = 0.
        except Exception as e:
            print("np.ones_like(logits), valid_actions, costs: ", np.ones_like(
                logits), valid_actions, costs)
            raise e
    if costs is not None:
        if verbose == 2: print('Indices, costs: ', indices, costs)
        logits += dy.inputVector(costs)
    log_sum_selected_terms = dy.logsumexp(
        [dy.pick(logits, index=e) for e in indices])
    normalization_term = dy.logsumexp([l for l in logits])
    return log_sum_selected_terms - normalization_term
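The quantity returned here is the log of the probability mass that a softmax over the (cost-augmented) logits assigns to the selected indices. A NumPy sketch of that identity; the logit and index values are made up:

```python
import numpy as np
from scipy.special import logsumexp

logits = np.random.randn(6)
indices = [1, 4]                               # hypothetical selected-action indices

log_selected = logsumexp(logits[indices])      # log-sum over the picked logits
log_norm = logsumexp(logits)                   # log of the full softmax normalizer
log_mass = log_selected - log_norm             # log P(action in indices)

probs = np.exp(logits) / np.exp(logits).sum()
assert np.isclose(log_mass, np.log(probs[indices].sum()))
```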


def log_sum_softmax_loss(indices,
                         logits,
                         logits_len,
                         valid_actions=None,
                         verbose=False):
    """Compute dynamic-oracle negative log loss.
       Args:
           indices: Array-like object for selection of logits to be maximized."""
    # @TODO could just build cost vector from `valid_actions` and call `log_sum_softmax_margin_loss`
    if valid_actions is not None:
        # build cost vector expressing action invalidity
    def process_batch_internal(self, batch, training=False, debug=False):
        self.TRAINING_ITER = training
        self.DROPOUT = self.args.dropout if (
            self.TRAINING_ITER and self.args.dropout > 0) else None
        self.BATCH_SIZE = len(batch)
        self.instantiate_parameters()

        if self.args.use_cache: self.initialize_cache(batch)

        sents, masks = self.vocab.batchify(batch)

        # paths represents the different connections within the lattice:
        # paths[i] contains all the (state, chunk, log-prob) triples that end at index i
        paths = [[] for _ in range(len(sents))]
        paths[0] = [(self.rnn.fresh_state(init_to_zero=True), sents[0],
                     dynet.scalarInput(0.0, device=self.args.param_device))]
        for tok_i in range(len(sents) - 1):
            # calculate the total probability of reaching this state
            _, _, lps = zip(*paths[tok_i])
            if len(lps) == 1:
                cum_lp = lps[0]
            else:
                cum_lp = dynet.logsumexp(list(lps))

            # add all previous state/chunk pairs to the tree_lstm
            new_state = self.rnn.fresh_state()
            if self.TRAINING_ITER and self.args.train_with_random and not self.first_time_memory_test:
                raise Exception("train_with_random is not supported in this model")
            else:
                self.first_time_memory_test = False
                for state, c_t, lp in paths[tok_i]:
                    x_t = dynet.pick_batch(self.vocab_R, c_t)
                    h_t_stack, c_t_stack = state.add_input(x_t)
                    new_state.add_history(h_t_stack, c_t_stack, lp)

            # treeLSTM state merging
            new_state.concat_weights()
            if self.args.gumbel_sample:
                new_state.apply_gumbel_noise_to_weights(
                    temperature=max(.25, self.args.temperature))
                if not self.TRAINING_ITER or self.args.sample_train:
                    new_state.weights_to_argmax()
                # new_state.weights_to_argmax()

            # output of tree_lstm
            y_t = dynet.to_device(new_state.output(), self.args.param_device)
            if self.DROPOUT: y_t = dynet.cmult(y_t, self.dropout_mask_y_t)

            # get the list of next tokens to consider
            base_is = sents[tok_i + 1]
            n_ts = [[nt + (i * self.vocab.size) for nt in base_is]
                    for i in range(self.args.multi_size)]

            r_t = dynet.affine_transform([
                self.vocab_bias, self.vocab_R,
                dynet.tanh(dynet.affine_transform([self.bias, self.R, y_t]))
            ])
            for n_t in n_ts:
                lp = -dynet.pickneglogsoftmax_batch(r_t, n_t)
                paths[tok_i + 1].append((new_state, n_t, cum_lp + lp))

        ending_masks = [[0.0] * self.BATCH_SIZE for _ in range(len(masks))]
        for sent_i in range(len(batch)):
            ending_masks[batch[sent_i].index(
                self.vocab.end_token.s)][sent_i] = 1.0

        # put together all of the final path states to get the final error
        cum_lp = dynet.scalarInput(0.0, device=self.args.param_device)
        for path, mask in zip(paths, ending_masks):
            if max(mask) == 1:
                assert len(path) != 0
                _, _, lps = zip(*path)
                if len(lps) == 1:
                    local_cum_lp = lps[0]
                else:
                    local_cum_lp = dynet.logsumexp(list(lps))
                cum_lp += local_cum_lp * dynet.inputTensor(
                    mask, batched=True, device=self.args.param_device)

        if debug: return paths

        err = -cum_lp
        char_count = [1 + len(self.vocab.pp(sent[1:-1])) for sent in batch]
        word_count = [len(sent[1:]) for sent in batch]
        # word_count = [2+self.lattice_vocab.pp(sent[1:-1]).count(' ') for sent in batch]
        return {"loss": err, "charcount": char_count, "wordcount": word_count}
    def process_batch_internal(self, batch, training=False, debug=False):
        self.TRAINING_ITER = training
        self.DROPOUT = self.args.dropout if (
            self.TRAINING_ITER and self.args.dropout > 0) else None
        self.BATCH_SIZE = len(batch)
        self.instantiate_parameters()

        if self.args.use_cache: self.initialize_cache(batch)

        sents, masks = self.lattice_vocab.batchify(batch)

        # paths represents the different connections within the lattice:
        # paths[i] contains all the (state, chunk, log-prob) triples that end at index i
        paths = [[] for _ in range(len(sents))]
        paths[0] = [(self.rnn.fresh_state(init_to_zero=True), [sents[0]],
                     dynet.scalarInput(0.0, device=self.args.param_device))]
        for tok_i in range(len(sents) - 1):
            # calculate the total probability of reaching this state
            _, _, lps = zip(*paths[tok_i])
            if len(lps) == 1: cum_lp = lps[0]
            else: cum_lp = dynet.logsumexp(list(lps))

            # add all previous state/chunk pairs to the tree_lstm
            new_state = self.rnn.fresh_state()
            if self.TRAINING_ITER and self.args.train_with_random and not self.first_time_memory_test:
                state, c_t, lp = random.choice(paths[tok_i])
                if self.args.use_cache: x_t = self.cached_embedding_lookup(c_t)
                else: x_t = self.get_chunk_embedding(c_t)
                h_t_stack, c_t_stack = state.add_input(x_t)
                new_state.add_history(h_t_stack, c_t_stack, lp)
            else:
                self.first_time_memory_test = False
                for state, c_t, lp in paths[tok_i]:
                    if self.args.use_cache:
                        x_t = self.cached_embedding_lookup(c_t)
                    else:
                        x_t = self.get_chunk_embedding(c_t)
                    h_t_stack, c_t_stack = state.add_input(x_t)
                    new_state.add_history(h_t_stack, c_t_stack, lp)

            # treeLSTM state merging
            new_state.concat_weights()
            if self.args.gumbel_sample:
                new_state.apply_gumbel_noise_to_weights(
                    temperature=max(.25, self.args.temperature))
                if not self.TRAINING_ITER: new_state.weights_to_argmax()
                # new_state.weights_to_argmax()

            # output of tree_lstm
            y_t = new_state.output()
            y_t = dynet.to_device(y_t, self.args.param_device)
            if self.DROPOUT: y_t = dynet.cmult(y_t, self.dropout_mask_y_t)

            # based on lattice_size, decide what set of chunks to consider from here
            if self.args.lattice_size < 1: end_tok_i = len(sents)
            else:
                end_tok_i = min(tok_i + 1 + self.args.lattice_size, len(sents))
            next_chunks = sents[tok_i + 1:end_tok_i]

            # for each chunk, calculate the probability of that chunk, and then add a pointer to the state/chunk into
            #  the place in the sentence where the chunk will end
            assert not (self.args.no_fixed_preds
                        and self.args.no_dynamic_preds)
            if not self.args.no_fixed_preds:
                fixed_chunk_lps, use_dynamic_lp = self.predict_chunks(
                    y_t, next_chunks)
            if not self.args.no_dynamic_preds:
                dynamic_chunk_lps = self.predict_chunks_by_tokens(
                    y_t, next_chunks)
            for chunk_i, tok_loc in enumerate(range(tok_i + 1, end_tok_i)):
                if self.args.no_fixed_preds:
                    lp = dynamic_chunk_lps[chunk_i]
                elif self.args.no_dynamic_preds:
                    lp = fixed_chunk_lps[chunk_i]
                else:  # we are using both fixed & dynamic predictions
                    lp = dynet.logsumexp([
                        fixed_chunk_lps[chunk_i],
                        use_dynamic_lp + dynamic_chunk_lps[chunk_i]
                    ])
                paths[tok_loc].append(
                    (new_state, sents[tok_i + 1:tok_loc + 1], cum_lp + lp))

        ending_masks = [[0.0] * self.BATCH_SIZE for _ in range(len(masks))]
        for sent_i in range(len(batch)):
            ending_masks[batch[sent_i].index(
                self.lattice_vocab.end_token.s)][sent_i] = 1.0

        # put together all of the final path states to get the final error
        cum_lp = dynet.scalarInput(0.0, device=self.args.param_device)
        for path, mask in zip(paths, ending_masks):
            if max(mask) == 1:
                assert len(path) != 0
                _, _, lps = zip(*path)
                if len(lps) == 1: local_cum_lp = lps[0]
                else: local_cum_lp = dynet.logsumexp(list(lps))
                cum_lp += local_cum_lp * dynet.inputTensor(
                    mask, batched=True, device=self.args.param_device)

        if debug: return paths

        err = -cum_lp
        char_count = [
            1 + len(self.lattice_vocab.pp(sent[1:-1])) for sent in batch
        ]
        word_count = [len(sent[1:]) for sent in batch]
        # word_count = [2+self.lattice_vocab.pp(sent[1:-1]).count(' ') for sent in batch]
        return {"loss": err, "charcount": char_count, "wordcount": word_count}
Example 8
def logadd(x, y):
    """Binary logsumexp"""
    return dy.logsumexp([x, y])
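A quick usage note: `logadd` adds two quantities stored as logs without leaving log space, so it stays finite even when the raw probabilities would underflow or overflow. A pure-Python reference using only the standard library:

```python
import math

def logadd_ref(x, y):
    # log(exp(x) + exp(y)), computed so that large magnitudes do not overflow
    m = max(x, y)
    return m + math.log(math.exp(x - m) + math.exp(y - m))

assert abs(logadd_ref(math.log(0.2), math.log(0.3)) - math.log(0.5)) < 1e-12
assert abs(logadd_ref(-1e4, -1e4) - (-1e4 + math.log(2.0))) < 1e-8   # exp(-1e4) alone underflows
```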