Exemple #1
0
    def forward(self, words, labels=None):
        """Return the negative log-likelihood of ``words`` under the LSTM LM.

        The sentence is wrapped in START/STOP and each LSTM output scores the
        following token.  When ``labels`` is given, an auxiliary classifier
        ``self.f_label`` is also trained on the LSTM outputs (presumably one
        label per word), and ``self.correct`` / ``self.predicted`` are updated
        as a cheap accuracy tracker.

        Args:
            words: list of tokens; unkified first when ``self.training``.
            labels: optional list of labels for the auxiliary task.

        Returns:
            Scalar DyNet expression: word NLL, plus label NLL if ``labels``.
        """
        if self.training:
            words = self.word_vocab.unkify(words)

        state = self.rnn_builder.initial_state()

        ids = [self.word_vocab.index_or_unk(token)
               for token in [START] + words + [STOP]]
        inputs, targets = ids[:-1], ids[1:]

        hidden = state.transduce([self.embeddings[i] for i in inputs])
        word_scores = self.out(dy.concatenate_to_batch(hidden))
        word_nll = dy.sum_batches(
            dy.pickneglogsoftmax_batch(word_scores, targets))

        if labels is None:
            return word_nll

        gold = [self.label_vocab.index(lab) for lab in labels]
        # hidden[0] is the output after reading START; the rest follow the words.
        label_scores = self.f_label(dy.concatenate_to_batch(hidden[1:]))
        label_nll = dy.sum_batches(
            dy.pickneglogsoftmax_batch(label_scores, gold))

        # cheap running accuracy for monitoring the auxiliary task
        guesses = np.argmax(label_scores.npvalue(), axis=0)
        self.correct += np.sum(guesses == gold)
        self.predicted += len(gold)

        return word_nll + label_nll
Exemple #2
0
def calc_sent_loss(sent):
    """Skip-gram negative-sampling loss for one sentence (fresh graph).

    Every position is scored against its ``N`` left and ``N`` right context
    words (``S`` past the sentence edges) and against ``K`` negatives per
    context slot, sampled from ``word_probabilities``.

    Returns:
        Scalar DyNet expression summed over all positions.
    """
    dy.renew_cg()

    word_vecs = [W_w_p[w] for w in sent]
    window = 2 * N
    negs_per_pos = K * window

    # Draw every negative in one call, then hand out one slice per position.
    negatives = np.random.choice(nwords,
                                 size=window * K * len(word_vecs),
                                 replace=True,
                                 p=word_probabilities)

    per_position = []
    for pos, vec in enumerate(word_vecs):
        neg_ids = negatives[pos * negs_per_pos:(pos + 1) * negs_per_pos]
        left = [sent[j] if j >= 0 else S for j in range(pos - N, pos)]
        right = [sent[j] if j < len(sent) else S
                 for j in range(pos + 1, pos + N + 1)]
        ctx_ids = left + right

        neg_scores = dy.dot_product(vec, dy.lookup_batch(W_c_p, neg_ids))
        ctx_scores = dy.dot_product(vec, dy.lookup_batch(W_c_p, ctx_ids))
        neg_term = -dy.log(dy.logistic(-neg_scores))
        ctx_term = -dy.log(dy.logistic(ctx_scores))
        per_position.append(dy.sum_batches(neg_term) + dy.sum_batches(ctx_term))
    return dy.esum(per_position)
Exemple #3
0
    def compute(self, comb_method: str = "sum") -> Tuple[dy.Expression, Dict]:
        """Compute the loss by summing over factors and batch elements.

        Args:
          comb_method: how to combine the loss across batch elements:
            ``'sum'``, or ``'avg'`` (batch sum divided by the batch size).

        Returns:
          ``(loss, loss_data)``: ``loss`` is a scalar DyNet expression and
          ``loss_data`` maps each factor name to
          ``(batch-summed loss value, units)``.

        Raises:
          ValueError: if ``comb_method`` is neither ``'sum'`` nor ``'avg'``.
        """
        # Validate first so no graph evaluation happens for a bad argument.
        if comb_method not in ("sum", "avg"):
            raise ValueError(
                f"Unknown batch combination method '{comb_method}', expected 'sum' or 'avg'."
            )

        loss_exprs = 0
        loss_data = {}
        for name, loss_expr in self.expr_factors.items():
            expr, units = loss_expr.loss_value()
            loss_exprs += expr
            # .value() forces a forward pass to record this factor's value
            loss_data[name] = dy.sum_batches(expr).value(), units

        total = dy.sum_batches(loss_exprs)
        if comb_method == "avg":
            # dim()[1] is the batch size of the un-summed expression
            total = total * (1.0 / loss_exprs.dim()[1])
        return total, loss_data
Exemple #4
0
 def _combine_batches(self, batched_expr, comb_method: str = "sum"):
   """Reduce a (possibly batched) DyNet expression over its batch dimension.

   Args:
     batched_expr: DyNet expression.
     comb_method: ``'sum'`` adds the batch elements; ``'avg'`` divides that
       sum by the batch size.

   Returns:
     The reduced DyNet expression.

   Raises:
     ValueError: for an unknown ``comb_method``.
   """
   if comb_method == "sum":
     return dy.sum_batches(batched_expr)
   if comb_method == "avg":
     # dim()[1] is the batch size
     return dy.sum_batches(batched_expr) * (1.0 / batched_expr.dim()[1])
   raise ValueError(f"Unknown batch combination method '{comb_method}', expected 'sum' or 'avg'.")
Exemple #5
0
    def forward(self, words, spans=None):
        """Language-model NLL of ``words``, optionally with span labelling.

        ``spans`` is an iterable of ``(left, right, label)``.  Span features
        are "LSTM-minus" differences ``hidden[right] - hidden[left]`` and are
        classified by ``self.f_label``.  With ``self.predict_all_spans`` every
        span of the sentence is classified, and spans absent from ``spans``
        get the extra null class ``self.label_vocab.size``.  Otherwise only
        the given spans are classified.  ``self.correct`` / ``self.predicted``
        are updated as a cheap accuracy proxy.

        Returns:
            Scalar DyNet expression: word NLL (+ span label NLL).
        """
        if self.training:
            words = self.word_vocab.unkify(words)

        state = self.rnn_builder.initial_state()

        ids = [self.word_vocab.index_or_unk(token)
               for token in [START] + words + [STOP]]
        hidden = state.transduce([self.embeddings[i] for i in ids[:-1]])
        word_scores = self.out(dy.concatenate_to_batch(hidden))
        word_nll = dy.sum_batches(
            dy.pickneglogsoftmax_batch(word_scores, ids[1:]))

        if spans is None:
            return word_nll

        n = len(words)
        if self.predict_all_spans:
            gold = {}
            for left, right, label in spans:
                gold[(left, right)] = self.label_vocab.index(label)
            candidates = [(start, start + width)
                          for width in range(1, n + 1)
                          for start in range(n + 1 - width)]
            # the index one past the real labels is the null label
            targets = [gold.get(span, self.label_vocab.size)
                       for span in candidates]
        else:
            targets = [self.label_vocab.index(label) for _, _, label in spans]
            candidates = [(left, right) for left, right, _ in spans]

        # 'lstm minus' features, same as those of the crf parser
        features = [hidden[right] - hidden[left] for left, right in candidates]

        label_scores = self.f_label(dy.concatenate_to_batch(features))
        label_nll = dy.sum_batches(
            dy.pickneglogsoftmax_batch(label_scores, targets))

        guesses = np.argmax(label_scores.npvalue(), axis=0)
        self.correct += np.sum(guesses == targets)
        self.predicted += len(targets)

        return word_nll + label_nll
 def calc_loss(self, policy):
     """Weighted negative-entropy loss ``weight * sum_t sum_b sum_k exp(ll) * ll``.

     Each element of ``policy`` is assumed to hold batched log-probabilities
     ``ll``, so ``exp(ll) * ll`` summed over the vector is the negative
     entropy of that distribution.

     Args:
       policy: sequence of batched DyNet expressions, one per step.

     Returns:
       Scalar DyNet expression, or ``None`` when ``self.weight`` < 1e-8.
     """
     if self.weight < 1e-8:
         return None
     neg_entropy = []
     for i, ll in enumerate(policy):
         if self.valid_pos is not None:
             # keep only the batch elements that are valid at step i
             ll = dy.pick_batch_elems(ll, self.valid_pos[i])
         # one sum_batches already yields an unbatched scalar
         neg_entropy.append(dy.sum_batches(dy.sum_elems(dy.cmult(dy.exp(ll), ll))))
     return self.weight * dy.esum(neg_entropy)
Exemple #7
0
    def decode_loss(self, src_encodings, tgt_seqs):
        """Batched dependency arc + label cross-entropy loss.

        :param src_encodings: encoder outputs, passed to ``self.cal_scores``.
        :param tgt_seqs: ``(tgt_heads, tgt_labels)``; each is a list
            (length=batch_size) of per-token sequences of length src_len.
        :return: scalar expression, summed loss divided by the batch size.

        Label scores are read at each token's argmax-scoring head, not at the
        gold head.
        """
        # todo(NOTE): Sentences should start with empty token (as root of dependency tree)!
        heads, labels = tgt_seqs
        n_tokens = len(heads[0])
        n_sents = len(heads)
        flat_batch = n_tokens * n_sents

        gold_heads = np.array(heads).flatten()  # (src_len * batch_size)
        gold_labels = np.array(labels).flatten()

        # (src_len, src_len, bs), ([(src_len, src_len, bs)])
        s_arc, s_label = self.cal_scores(src_encodings)

        best_heads = np.argmax(s_arc.npvalue(), axis=0).transpose().flatten()  # (src_len * batch_size)

        label_columns = []
        for score in s_label:
            per_token = dy.reshape(score, (n_tokens,), batch_size=flat_batch)
            label_columns.append(dy.pick_batch(per_token, best_heads))
        label_scores = dy.concatenate(label_columns, d=0)  # n_labels, src_len * batch_size

        arc_scores = dy.reshape(s_arc, (n_tokens,), batch_size=flat_batch)
        arc_nll = dy.pickneglogsoftmax_batch(arc_scores, gold_heads)
        label_nll = dy.pickneglogsoftmax_batch(label_scores, gold_labels)

        return dy.sum_batches(arc_nll + label_nll) / n_sents
Exemple #8
0
    def decode_loss(self, src_encodings, tgt_seqs):
        """Batched dependency arc + label cross-entropy loss.

        :param src_encodings: encoder outputs, passed to ``self.cal_scores``.
        :param tgt_seqs: ``(tgt_heads, tgt_labels)``; each is a list
            (length=batch_size) of per-token sequences of length src_len.
        :return: scalar expression, summed loss divided by the batch size.

        Label scores are read at each token's argmax-scoring head, not at the
        gold head.
        """
        # todo(NOTE): Sentences should start with empty token (as root of dependency tree)!
        heads, labels = tgt_seqs
        n_tokens = len(heads[0])
        n_sents = len(heads)
        flat_batch = n_tokens * n_sents

        gold_heads = np.array(heads).flatten()  # (src_len * batch_size)
        gold_labels = np.array(labels).flatten()

        # (src_len, src_len, bs), ([(src_len, src_len, bs)])
        s_arc, s_label = self.cal_scores(src_encodings)

        best_heads = np.argmax(s_arc.npvalue(), axis=0).transpose().flatten()  # (src_len * batch_size)

        label_columns = []
        for score in s_label:
            per_token = dy.reshape(score, (n_tokens,), batch_size=flat_batch)
            label_columns.append(dy.pick_batch(per_token, best_heads))
        label_scores = dy.concatenate(label_columns, d=0)  # n_labels, src_len * batch_size

        arc_scores = dy.reshape(s_arc, (n_tokens,), batch_size=flat_batch)
        arc_nll = dy.pickneglogsoftmax_batch(arc_scores, gold_heads)
        label_nll = dy.pickneglogsoftmax_batch(label_scores, gold_labels)

        return dy.sum_batches(arc_nll + label_nll) / n_sents
    def fit(self,
            X_train=None,
            y_train=None,
            X_test=None,
            y_test=None,
            epochs=None):
        """Train with Adam, evaluating on the test split after every epoch.

        Args:
            X_train, y_train, X_test, y_test: data, passed through
                ``self._prepare_data``.
            epochs: number of passes over the training data (required).

        Returns:
            List with the mean per-example training loss of each epoch.

        Raises:
            ValueError: if ``epochs`` is None.
        """
        # Fail fast: range(None) would otherwise crash after model setup.
        if epochs is None:
            raise ValueError("fit() requires an explicit number of epochs")

        X_train, y_train, X_test, y_test = self._prepare_data(
            X_train, y_train, X_test, y_test)
        self._initialize_model()

        print('Starting training...')

        self.trainer = dy.AdamTrainer(self.model)
        losses = []

        for _ in tqdm(range(epochs)):
            curr_loss = 0.0
            n_examples = 0

            for X, y in zip(X_train, y_train):
                y_prob = self._predict_proba(X)

                loss = dy.sum_batches(dy.pickneglogsoftmax_batch(y_prob, y))
                curr_loss += loss.value()
                n_examples += 1

                loss.backward()
                self.trainer.update()

            # max(..., 1) guards against an empty training set
            losses.append(curr_loss / max(n_examples, 1))

            print('Train Loss:', losses[-1])
            self.evaluate(X_test, y_test)
            print()

        print('Done training')
        return losses
Exemple #10
0
 def compose(
         self, embeds: Union[dy.Expression,
                             List[dy.Expression]]) -> dy.Expression:
     """Sum embeddings into a single expression.

     A list of expressions is summed with ``dy.esum``; a single (batched)
     expression is summed over its batch dimension with ``dy.sum_batches``.
     """
     if isinstance(embeds, list):
         return dy.esum(embeds)
     return dy.sum_batches(embeds)
Exemple #11
0
  def calc_loss(self, src, db_idx, src_mask=None, trg_mask=None):
    """Hinge retrieval loss between pooled source encodings and database targets.

    Args:
      src: source batch, embedded with ``src_mask`` and encoded.
      db_idx: indices into ``self.database``. Source i is scored against
        every retrieved target, and target i is the gold match.
      src_mask: mask passed to the source embedder.
      trg_mask: NOTE(review): ignored. The mask stored with the target batch
        in ``self.database`` is used instead.

    Returns:
      Scalar loss expression.

    Raises:
      RuntimeError: if ``self.loss_direction`` is not "forward" or
        "bidirectional".
    """
    src_embeddings = self.src_embedder.embed_sent(src, mask=src_mask)
    self.src_encoder.set_input(src)
    src_encodings = self.exprseq_pooling(self.src_encoder.transduce(src_embeddings))
    # The database supplies its own target mask (it overrides the argument).
    trg_batch, db_trg_mask = self.database[db_idx]
    trg_encodings = self.encode_trg_example(trg_batch, mask=db_trg_mask)
    dim = trg_encodings.dim()
    # move the batch dimension into columns: (hidden, batch) matrix
    trg_reshaped = dy.reshape(trg_encodings, (dim[0][0], dim[1]))
    # score each source encoding against every target column
    prod = dy.transpose(src_encodings) * trg_reshaped
    # gold pairing: source i <-> target i
    id_range = list(range(len(db_idx)))
    if self.loss_direction == "forward":
      prod = dy.transpose(prod)
      loss = dy.sum_batches(dy.hinge_batch(prod, id_range))
    elif self.loss_direction == "bidirectional":
      prod = dy.reshape(prod, (len(db_idx), len(db_idx)))
      loss = dy.sum_elems(
        dy.hinge_dim(prod, id_range, d=0) + dy.hinge_dim(prod, id_range, d=1))
    else:
      raise RuntimeError("Illegal loss direction {}".format(self.loss_direction))

    return loss
Exemple #12
0
 def test_inputTensor_batched_list(self):
     """A list of per-example arrays becomes one batched DyNet expression.

     For each test shape the last axis is the batch axis. The expression
     must have the remaining axes as its dimension (``(1,)`` for the 1-D
     case), reproduce the input values, and have the expected squared norm.
     """
     for i in range(4):
         dy.renew_cg()
         input_tensor = self.input_vals.reshape(self.shapes[i])
         # one array per batch element (the last axis of input_tensor)
         xb = dy.inputTensor([np.asarray(x).transpose()
                              for x in input_tensor.transpose()])
         self.assertEqual(
             xb.dim()[0],
             (self.shapes[i][:-1] if i > 0 else (1,)),
             msg="Dimension mismatch"
         )
         self.assertEqual(
             xb.dim()[1],
             self.shapes[i][-1],
             msg="Dimension mismatch"
         )
         self.assertTrue(
             np.allclose(xb.npvalue(), input_tensor),
             msg="Expression value different from initial value"
         )
         # Compare the computed norm with a tolerance, like the allclose
         # check above; exact float equality is brittle across precisions.
         self.assertTrue(
             np.isclose(dy.sum_batches(dy.squared_norm(xb)).scalar_value(),
                        self.squared_norm),
             msg="Value mismatch"
         )
Exemple #13
0
    def calc_loss(self, bisents):
        """
        :param bisents: List of (batch size) parallel sentences, each a
            ``(source, target)`` pair.
        :return: ``(loss, num_words)``. ``loss`` is the masked per-token loss
            summed over steps and batch elements (not averaged).
            ``num_words`` is the count reported by ``prepare_masks``.
        """
        dy.renew_cg()
        sources = [pair[0] for pair in bisents]
        targets = [pair[1] for pair in bisents]
        batch = len(bisents)

        self.encode(sources)
        bridge = dy.affine_transform([
            dy.parameter(self.b_bridge),
            dy.parameter(self.W_bridge),
            self.encoder.final_state(),
        ])
        self.decoder.init(bridge)

        words_by_step, masks, num_words = prepare_masks(
            targets, self.tgt_vocab.eos)

        step_losses = []
        prev_words = words_by_step[0]
        for next_words, mask in zip(words_by_step[1:], masks):
            # decoder output for the previous time step
            scores = self.decode(prev_words)
            token_loss = self.cross_entropy_loss(scores, next_words)
            # (1,) x batch mask; presumably 0 marks padded positions
            weights = dy.reshape(dy.inputVector(mask), (1, ), batch)
            step_losses.append(token_loss * weights)
            prev_words = next_words
        return dy.sum_batches(dy.esum(step_losses)), num_words
Exemple #14
0
    def get_loss_batch(self, sent_array):
        """Batched RNN language-model loss for a list of word-id sentences.

        Sentences are transposed into per-time-step columns. Slots past the
        end of a sentence are filled with id 3 and masked out of the loss.
        The RNN is started with id 2 ("<s>").

        Args:
            sent_array: list of sentences, each a list of word ids.

        Returns:
            ``(loss, tot_words)``. ``loss`` is a scalar expression summed over
            steps and batch elements. ``tot_words`` is the number of real
            (unmasked) words.
        """
        renew_cg()
        init_state = self.builder.initial_state()

        R = parameter(self.R)
        bias = parameter(self.bias)
        batch_size = len(sent_array)

        # wids[t][b]: word t of sentence b (or filler id 3).
        # masks[t][b]: 1 for a real word, 0 for filler.
        # Use the longest sentence so an unsorted batch is not truncated.
        max_len = max(len(sent) for sent in sent_array)
        wids = []
        masks = []
        tot_words = 0
        for i in range(max_len):
            wids.append([(sent[i] if len(sent) > i else 3)
                         for sent in sent_array])
            mask = [(1 if len(sent) > i else 0) for sent in sent_array]
            masks.append(mask)
            tot_words += sum(mask)

        # start the rnn by inputting "<s>"
        init_ids = [2] * batch_size
        s = init_state.add_input(dy.lookup_batch(self.lookup, init_ids))

        # feed word vectors into the RNN and predict the next word
        losses = []
        for wid, mask in zip(wids, masks):
            score = dy.affine_transform([bias, R, s.output()])
            loss = dy.pickneglogsoftmax_batch(score, wid)
            # mask whenever any sentence has already ended at this step
            # (no assumption that the batch is sorted by length)
            if not all(mask):
                mask_expr = dy.reshape(dy.inputVector(mask), (1, ), batch_size)
                loss = loss * mask_expr
            losses.append(loss)
            # update the state of the RNN with the current words
            s = s.add_input(dy.lookup_batch(self.lookup, wid))

        return dy.sum_batches(dy.esum(losses)), tot_words
Exemple #15
0
 def calc_loss(self, policy_reward, only_final_reward=True):
   """REINFORCE loss with a baseline, returned as a ``FactoredLossExpr``.

   Factors are added in this order: "rl_baseline", "rl_confpen" (only when
   ``self.confidence_penalty`` is set), and "rl_reinf".

   Args:
     policy_reward: the reward. When ``only_final_reward`` it is subtracted
       from every baseline prediction; otherwise it is a per-step sequence
       zipped with the baseline predictions.
     only_final_reward: see ``policy_reward``; also forwarded to
       ``self.calc_baseline_loss``.
   """
   result = losses.FactoredLossExpr()

   # Baseline: advantage = reward - predicted reward
   baseline_preds, baseline_loss = self.calc_baseline_loss(policy_reward, only_final_reward)
   if only_final_reward:
     advantages = [policy_reward - pred for pred in baseline_preds]
   else:
     advantages = [actual - pred for actual, pred in zip(policy_reward, baseline_preds)]
   result.add_loss("rl_baseline", baseline_loss)

   # Optional z-normalisation; mean/std are numpy constants (no gradient)
   advantages = dy.concatenate(advantages, d=0)
   if self.z_normalization:
     values = advantages.value()
     mean = np.mean(values)
     std = np.std(values) + 1e-10
     advantages = (advantages - mean) / std

   if self.confidence_penalty:
     result.add_loss("rl_confpen", self.confidence_penalty.calc_loss(self.policy_lls))

   # One REINFORCE term per action in the sequence
   terms = []
   for step, (policy, action) in enumerate(zip(self.policy_lls, self.actions)):
     adv = dy.pick(advantages, step)
     log_prob = dy.pick_batch(policy, action)
     if self.valid_pos is not None:
       log_prob = dy.pick_batch_elems(log_prob, self.valid_pos[step])
       adv = dy.pick_batch_elems(adv, self.valid_pos[step])
     terms.append(dy.sum_batches(log_prob * adv))
   result.add_loss("rl_reinf", -self.weight * dy.esum(terms))
   return result
Exemple #16
0
def get_loss(x, y):
    """
    Get loss -log(softmax(score[y])), summed over the batch and divided by
    the batch size ``x.shape[0]``.
    """
    scores = run_MLP(x)
    n_examples = x.shape[0]
    total = dy.sum_batches(dy.pickneglogsoftmax_batch(scores, y))
    return total / n_examples
Exemple #17
0
 def add_loss(self, loss_name, loss_expr):
     """Record a named loss node, or merge every node of another LossBuilder.

     A batched expression (batch size > 1) is summed over its batch
     dimension before it is stored as ``(loss_name, expr)``. When
     ``loss_expr`` is a ``LossBuilder``, ``loss_name`` is ignored and that
     builder's nodes are appended as-is.
     """
     if isinstance(loss_expr, LossBuilder):
         self.loss_nodes.extend(loss_expr.loss_nodes)
         return
     if loss_expr.dim()[1] > 1:
         loss_expr = dy.sum_batches(loss_expr)
     self.loss_nodes.append((loss_name, loss_expr))
Exemple #18
0
def get_loss(x, y):
    """
    Get loss -log(softmax(score[y])), summed over the batch and divided by
    the batch size (first axis of the 2-D input ``x``).
    """
    scores = run_IRNN(x)
    n_examples, _seq_len = x.shape  # unpacking also requires x to be 2-D
    total = dy.sum_batches(dy.pickneglogsoftmax_batch(scores, y))
    return total / n_examples
Exemple #19
0
    def calc_loss(self, src_seqs, trg_seqs, training=True):
        """Batched attentional encoder-decoder loss with teacher forcing.

        Targets that have already ended repeat their last token as input and
        reference. A per-step 0/1 mask zeroes their loss.

        Args:
            src_seqs: batch of source sequences.
            trg_seqs: batch of target sequences (same batch order).
            training: forwarded to the encoder, decoder and embedder.

        Returns:
            Scalar DyNet expression: masked loss summed over steps and batch.
        """
        batch_size = len(src_seqs)
        src_encodings = self.encoder.encode(src_seqs, training=training)
        src_enc_all = dy.concatenate_cols(src_encodings)
        src_trans_att = self.attender.get_src_transformation(src_enc_all)
        state = self.decoder.initialize(src_encodings, training=training)
        ctx_tm1 = dy.vecInput(self.encoder.state_dim)
        losses = []

        max_len = max(map(len, trg_seqs))
        # `range` instead of the Python-2-only `xrange` (NameError on Python 3)
        for i in range(1, max_len):
            y_tm1 = [trg_seq[i - 1] if i < len(trg_seq) else trg_seq[-1] for trg_seq in trg_seqs]
            ref_y_t = [trg_seq[i] if i < len(trg_seq) else trg_seq[-1] for trg_seq in trg_seqs]
            y_tm1_embed = self.decoder.embedder.embed_item(y_tm1, training=training)

            # previous word embedding + previous attention context
            x = dy.concatenate([y_tm1_embed, ctx_tm1])
            state = state.add_input(x)
            h_t = state.output()
            ctx_t, _alpha_t = self.attender.calc_context(src_enc_all, src_trans_att, h_t)

            loss_t = self.decoder.calc_loss(h_t, ctx_t, ref_y_t, training=training)

            # zero the loss of targets that have already ended
            mask = dy.inputVector([1 if i < len(trg_seq) else 0 for trg_seq in trg_seqs])
            mask = dy.reshape(mask, (1,), batch_size)
            loss_t = dy.sum_batches(loss_t * mask)

            ctx_tm1 = ctx_t
            losses.append(loss_t)

        return dy.esum(losses)
    def step_batch(self, instances):
        """Batched loss of the BiLSTM-attention encoder-decoder.

        Sets dropout 0.2 on the left-to-right, right-to-left and decoder RNN
        builders. Encodes the padded source batch in both directions, then
        decodes the padded target batch, feeding the gold words.

        Args:
            instances: list of ``(source_sentence, target_sentence)`` pairs.

        Returns:
            ``(loss, num_words)``. ``loss`` is the masked loss summed over
            steps and batch elements. ``num_words`` is the count returned by
            ``self.__mask``.
        """
        dy.renew_cg()
        self.l2r_builder.set_dropout(0.2)
        self.r2l_builder.set_dropout(0.2)
        self.dec_builder.set_dropout(0.2)

        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        src_sents = [x[0] for x in instances]
        padded_src = self.__pad_batch(src_sents, True)
        src_cws = np.transpose(padded_src)
        tgt_sents = [x[1] for x in instances]
        padded_tgt = self.__pad_batch(tgt_sents, False)
        masks_tgt, num_words = self.__mask(tgt_sents)
        masks_tgt = np.transpose(masks_tgt)
        padded_tgt = np.transpose(padded_tgt)
        instance_size = len(instances)
        src_cws_rev = list(reversed(src_cws))
        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for (cws_l2r, cws_r2l) in zip(src_cws, src_cws_rev):
            l2r_state = l2r_state.add_input(dy.lookup_batch(self.src_lookup, cws_l2r))
            r2l_state = r2l_state.add_input(dy.lookup_batch(self.src_lookup, cws_r2l))
            l2r_contexts.append(l2r_state.output()) #[<S>, x_1, x_2, ..., </S>]
            r2l_contexts.append(r2l_state.output()) #[</S> x_n, x_{n-1}, ... <S>]
        r2l_contexts.reverse() #[<S>, x_1, x_2, ..., </S>]
        # Combine the left and right representations for every word
        h_fs = []
        for (l2r_i, r2l_i) in zip(l2r_contexts, r2l_contexts):
            h_fs.append(dy.concatenate([l2r_i, r2l_i]))
        h_fs_matrix = dy.concatenate_cols(h_fs)
        losses = []

        # Decoder
        c_t = dy.vecInput(self.hidden_size * 2)
        start = dy.concatenate([dy.lookup_batch(self.tgt_lookup, len(tgt_sents) * [self.tgt_token_to_id[self.src_pad]]), c_t])
        dec_state = self.dec_builder.initial_state().add_input(start)

        for (cws, nws, mask) in zip(padded_tgt, padded_tgt[1:], masks_tgt):
            h_e = dec_state.output()
            c_t = self.__attention_mlp(h_fs_matrix, h_e)
            # Get the embedding for the current target word
            embed_t = dy.lookup_batch(self.tgt_lookup, cws)
            # Create input vector to the decoder
            x_t = dy.concatenate([embed_t, c_t])
            dec_state = dec_state.add_input(x_t)
            y_star = b_y + W_y * dec_state.output()
            loss = dy.pickneglogsoftmax_batch(y_star, nws)
            # Mask whenever any batch element is padding at this step; this
            # does not assume the batch is sorted by target length.
            if not all(mask):
                mask_loss = dy.reshape(dy.inputVector(mask), (1,), instance_size)
                losses.append(loss * mask_loss)
            else:
                losses.append(loss)
        return dy.sum_batches(dy.esum(losses)), num_words
    def compute_decoder_batch_loss(self, encoded_inputs, input_masks, output_word_ids, output_masks, batch_size):
        """Run the attentional decoder over a batch and return its total loss.

        Input feeding: at each step the attention output vector is fed back
        into the decoder RNN with the embedding of the current gold word.

        Args:
            encoded_inputs: encoder outputs, passed to ``self.attend``.
            input_masks: source masks, passed to ``self.attend``.
            output_word_ids: per-time-step lists of gold output ids, one id
                per batch element.
            output_masks: per-time-step 0/1 masks in the same layout, or a
                falsy value to disable masking.
            batch_size: number of sequences in the batch.

        Returns:
            Scalar DyNet expression: loss summed over time steps and batch.

        Side effects:
            Stores the graph parameters on ``self`` (``readout``, ``bias``,
            ``w_c``, ``u_a``, ``v_a``, ``w_a``), presumably for use by
            ``self.attend``.
        """
        self.readout = dn.parameter(self.params['readout'])
        self.bias = dn.parameter(self.params['bias'])
        self.w_c = dn.parameter(self.params['w_c'])
        self.u_a = dn.parameter(self.params['u_a'])
        self.v_a = dn.parameter(self.params['v_a'])
        self.w_a = dn.parameter(self.params['w_a'])

        # initialize the decoder rnn
        s_0 = self.decoder_rnn.initial_state()

        # initial "input feeding" vectors to feed decoder - 3*h
        init_input_feeding = dn.lookup_batch(self.init_lookup, [0] * batch_size)

        # initial feedback embeddings for the decoder, use begin seq symbol embedding
        init_feedback = dn.lookup_batch(self.output_lookup, [self.y2int[common.BEGIN_SEQ]] * batch_size)

        # init decoder rnn
        decoder_init = dn.concatenate([init_feedback, init_input_feeding])
        s = s_0.add_input(decoder_init)

        # loss per timestep
        losses = []

        for i, step_word_ids in enumerate(output_word_ids):
            # returns h x batch size matrix
            decoder_rnn_output = s.output()

            # attention context for each sequence (2h x batch size)
            attention_output_vector, _alphas = self.attend(encoded_inputs, decoder_rnn_output, input_masks)

            # output scores (vocab_size x batch size): readout * attention_output_vector + bias
            h = dn.affine_transform([self.bias, self.readout, attention_output_vector])

            # batch loss for this timestep
            batch_loss = dn.pickneglogsoftmax_batch(h, step_word_ids)

            # mask whenever any sequence is padding at this step; this does
            # not assume the batch is sorted by length
            if output_masks and not all(output_masks[i]):
                mask_expr = dn.inputVector(output_masks[i])
                # noinspection PyArgumentList
                mask_expr = dn.reshape(mask_expr, (1,), batch_size)
                batch_loss = batch_loss * mask_expr

            # input feeding: the next decoder input is the gold word embedding
            # plus the attention output vector
            feedback_embeddings = dn.lookup_batch(self.output_lookup, step_word_ids)
            decoder_input = dn.concatenate([feedback_embeddings, attention_output_vector])
            s = s.add_input(decoder_input)

            losses.append(batch_loss)

        # sum the loss over the time steps and batch
        return dn.sum_batches(dn.esum(losses))
    def step_batch(self, instances):
        """Batched loss of the BiLSTM-attention encoder-decoder.

        Args:
            instances: list of ``(source_sentence, target_sentence)`` pairs.
                Target tokens must be keys of ``self.tgt_token_to_id``.

        Returns:
            ``(loss, num_words)``. ``loss`` is the masked loss, divided by the
            number of decoding steps and summed over the batch.
            ``num_words`` is that number of decoding steps.
        """
        dy.renew_cg()
        W_y = dy.parameter(self.W_y)
        b_y = dy.parameter(self.b_y)
        batch_size = len(instances)

        src_sents = [x[0] for x in instances]
        padded_src = self.__pad_batch(src_sents)
        src_cws = np.transpose(padded_src)  # row t: ids of step t across the batch
        # The right-to-left encoder must read the *time steps* in reverse.
        # Reversing padded_src itself would only permute the batch.
        src_cws_rev = list(reversed(src_cws))

        # Time-major target ids, padded to the longest target. Padded slots
        # reuse the '<S>' id and are masked out of the loss. A list (not a
        # lazy map) so it can be sliced on Python 3.
        tgt_sents = [x[1] for x in instances]
        tgt_seqs = [[self.tgt_token_to_id[x] for x in sent] for sent in tgt_sents]
        bos_id = self.tgt_token_to_id['<S>']
        max_tgt_len = max(len(seq) for seq in tgt_seqs)
        tgt_ids = [[seq[t] if t < len(seq) else bos_id for seq in tgt_seqs]
                   for t in range(max_tgt_len)]
        tgt_masks = [[1 if t < len(seq) else 0 for seq in tgt_seqs]
                     for t in range(max_tgt_len)]

        # Bidirectional representations
        l2r_state = self.l2r_builder.initial_state()
        r2l_state = self.r2l_builder.initial_state()
        l2r_contexts = []
        r2l_contexts = []
        for (cws_l2r, cws_r2l) in zip(src_cws, src_cws_rev):
            l2r_state = l2r_state.add_input(dy.lookup_batch(self.src_lookup, cws_l2r))
            r2l_state = r2l_state.add_input(dy.lookup_batch(self.src_lookup, cws_r2l))
            l2r_contexts.append(l2r_state.output()) #[<S>, x_1, x_2, ..., </S>]
            r2l_contexts.append(r2l_state.output()) #[</S> x_n, x_{n-1}, ... <S>]
        r2l_contexts.reverse() #[<S>, x_1, x_2, ..., </S>]
        # Combine the left and right representations for every word
        h_fs = [dy.concatenate([l2r_i, r2l_i])
                for l2r_i, r2l_i in zip(l2r_contexts, r2l_contexts)]
        h_fs_matrix = dy.concatenate_cols(h_fs)
        losses = []
        num_words = 0

        # Decoder
        c_t = dy.vecInput(self.hidden_size * 2)
        start = dy.concatenate([dy.lookup_batch(self.tgt_lookup, batch_size * [bos_id]), c_t])
        dec_state = self.dec_builder.initial_state().add_input(start)

        for t, (cws, nws) in enumerate(zip(tgt_ids, tgt_ids[1:])):
            h_e = dec_state.output()
            c_t = self.__attention_mlp(h_fs_matrix, h_e)
            # Get the embedding for the current target word
            embed_t = dy.lookup_batch(self.tgt_lookup, cws)
            # Create input vector to the decoder
            x_t = dy.concatenate([embed_t, c_t])
            dec_state = dec_state.add_input(x_t)
            y_star = b_y + W_y * dec_state.output()
            loss = dy.pickneglogsoftmax_batch(y_star, nws)
            # Mask with the *target* mask of the predicted position t + 1,
            # using the real batch size (the last batch may be smaller).
            mask = tgt_masks[t + 1]
            if not all(mask):
                mask_loss = dy.reshape(dy.inputVector(mask), (1,), batch_size)
                loss = loss * mask_loss
            losses.append(loss)
            num_words += 1
        return dy.sum_batches(dy.esum(losses)/num_words), num_words
def calc_loss(sents):
    """Batched encoder-decoder loss for a list of ``(source, target)`` pairs.

    ``LSTM_SRC_BUILDER`` reads the source batch. Its final output (and the
    tanh of it) seeds the ``LSTM_TRG_BUILDER`` state. The decoder is fed
    the gold previous target words. Targets shorter than the longest are
    padded with ``eos_trg``.

    NOTE(review): ``masks[i]`` (position i is a real token) gates the
    prediction of position ``i + 1``. So for every target shorter than the
    longest, the first padding ``eos_trg`` is still predicted. This is
    presumably the intended end-of-sentence target; confirm.

    NOTE(review): sources are not padded. All sources in a batch must have
    the same length, or a shorter one raises IndexError.

    Args:
        sents: list of ``(source_ids, target_ids)`` pairs.

    Returns:
        ``(loss, num_words)``. ``loss`` is a scalar expression summed over
        steps and batch. ``num_words`` is the number of real target tokens.
    """
    dy.renew_cg()

    src_sents = [x[0] for x in sents]
    tgt_sents = [x[1] for x in sents]

    # time-major source columns: src_cws[i][b] is word i of source b
    max_src_len = np.max([len(sent) for sent in src_sents])
    src_cws = [[sent[i] for sent in src_sents] for i in range(max_src_len)]

    # encode the source batch and keep the final LSTM output
    init_state_src = LSTM_SRC_BUILDER.initial_state()
    src_output = init_state_src.add_inputs(
        [dy.lookup_batch(LOOKUP_SRC, cws) for cws in src_cws])[-1].output()

    # Time-major, eos-padded target columns and their 0/1 masks. The length
    # must come from tgt_sents: each element of `sents` is a
    # (source, target) pair, so its len() is always 2.
    max_tgt_len = np.max([len(sent) for sent in tgt_sents])
    tgt_cws = []
    masks = []
    num_words = 0
    for i in range(max_tgt_len):
        tgt_cws.append(
            [sent[i] if len(sent) > i else eos_trg for sent in tgt_sents])
        mask = [(1 if len(sent) > i else 0) for sent in tgt_sents]
        masks.append(mask)
        num_words += sum(mask)

    current_state = LSTM_TRG_BUILDER.initial_state().set_s(
        [src_output, dy.tanh(src_output)])
    prev_words = tgt_cws[0]
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    all_losses = []
    for next_words, mask in zip(tgt_cws[1:], masks):
        # feed the previous gold words, score the next ones
        current_state = current_state.add_input(
            dy.lookup_batch(LOOKUP_TRG, prev_words))
        output_embedding = current_state.output()

        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        loss = dy.pickneglogsoftmax_batch(s, next_words)
        mask_expr = dy.reshape(dy.inputVector(mask), (1, ), len(sents))
        all_losses.append(loss * mask_expr)
        prev_words = next_words
    return dy.sum_batches(dy.esum(all_losses)), num_words
def calc_loss(sents):
    """Batched encoder-decoder loss for a list of ``(source, target)`` pairs.

    ``LSTM_SRC_BUILDER`` reads the source batch. Its final output (and the
    tanh of it) seeds the ``LSTM_TRG_BUILDER`` state. The decoder is fed
    the gold previous target words. Targets shorter than the longest are
    padded with ``eos_trg``.

    NOTE(review): ``masks[i]`` (position i is a real token) gates the
    prediction of position ``i + 1``. So for every target shorter than the
    longest, the first padding ``eos_trg`` is still predicted. This is
    presumably the intended end-of-sentence target; confirm.

    NOTE(review): sources are not padded. All sources in a batch must have
    the same length, or a shorter one raises IndexError.

    Args:
        sents: list of ``(source_ids, target_ids)`` pairs.

    Returns:
        ``(loss, num_words)``. ``loss`` is a scalar expression summed over
        steps and batch. ``num_words`` is the number of real target tokens.
    """
    dy.renew_cg()

    src_sents = [x[0] for x in sents]
    tgt_sents = [x[1] for x in sents]

    # time-major source columns: src_cws[i][b] is word i of source b
    max_src_len = np.max([len(sent) for sent in src_sents])
    src_cws = [[sent[i] for sent in src_sents] for i in range(max_src_len)]

    # encode the source batch and keep the final LSTM output
    init_state_src = LSTM_SRC_BUILDER.initial_state()
    src_output = init_state_src.add_inputs(
        [dy.lookup_batch(LOOKUP_SRC, cws) for cws in src_cws])[-1].output()

    # Time-major, eos-padded target columns and their 0/1 masks. The length
    # must come from tgt_sents: each element of `sents` is a
    # (source, target) pair, so its len() is always 2.
    max_tgt_len = np.max([len(sent) for sent in tgt_sents])
    tgt_cws = []
    masks = []
    num_words = 0
    for i in range(max_tgt_len):
        tgt_cws.append(
            [sent[i] if len(sent) > i else eos_trg for sent in tgt_sents])
        mask = [(1 if len(sent) > i else 0) for sent in tgt_sents]
        masks.append(mask)
        num_words += sum(mask)

    current_state = LSTM_TRG_BUILDER.initial_state().set_s(
        [src_output, dy.tanh(src_output)])
    prev_words = tgt_cws[0]
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    all_losses = []
    for next_words, mask in zip(tgt_cws[1:], masks):
        # feed the previous gold words, score the next ones
        current_state = current_state.add_input(
            dy.lookup_batch(LOOKUP_TRG, prev_words))
        output_embedding = current_state.output()

        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        loss = dy.pickneglogsoftmax_batch(s, next_words)
        mask_expr = dy.reshape(dy.inputVector(mask), (1, ), len(sents))
        all_losses.append(loss * mask_expr)
        prev_words = next_words
    return dy.sum_batches(dy.esum(all_losses)), num_words
    def compute_loss(self, words, extwords, tags, true_arcs, true_labels):
        """Sum of the arc and relation cross-entropy losses.

        Each token becomes one batch element. The arc loss scores all
        ``len(true_arcs)`` candidate heads. The relation loss is read at the
        token's gold head.
        """
        arc_logits, rel_logits = self.forward(words, extwords, tags, True)
        n = len(true_arcs)
        gold_heads = dynet_flatten_numpy(true_arcs)

        arc_scores = dy.reshape(arc_logits, (n, ), n)
        arc_loss = dy.sum_batches(
            dy.pickneglogsoftmax_batch(arc_scores, gold_heads))

        rel_scores = dy.reshape(rel_logits, (n, self.rel_size), n)
        rel_at_gold_head = dy.pick_batch(rel_scores, gold_heads)
        gold_rels = dynet_flatten_numpy(true_labels)
        rel_loss = dy.sum_batches(
            dy.pickneglogsoftmax_batch(rel_at_gold_head, gold_rels))

        return arc_loss + rel_loss
Exemple #26
0
    def calc_loss(self, src, db_idx):
        """Hinge retrieval loss of ``src`` against database entries ``db_idx``.

        Each source in the batch is scored against every retrieved target,
        and the i-th target is the gold match for the i-th source.

        Returns:
            Scalar DyNet expression (hinge loss summed over the batch).
        """
        src_embeddings = self.src_embedder.embed_sent(src)
        src_encodings = self.exprseq_pooling(
            self.src_encoder.transduce(src_embeddings))
        trg_encodings = self.encode_trg_example(self.database[db_idx])

        prod = dy.transpose(dy.transpose(src_encodings) * trg_encodings)
        return dy.sum_batches(
            dy.hinge_batch(prod, list(six.moves.range(len(db_idx)))))
def calc_sent_loss(sent):
  """Return the skip-gram negative-sampling loss of one sentence.

  For each position, the centre word's embedding (from W_w_p) is scored
  against two sets of context vectors from W_c_p:

  * the N words on each side, with S used past the sentence edges;
  * 2*N*K negatives drawn from ``word_probabilities``.

  Positive contexts contribute -log(sigmoid(score)); negatives contribute
  -log(sigmoid(-score)).
  """
  # Fresh computation graph for this sentence.
  dy.renew_cg()

  centre_embs = [W_w_p[word] for word in sent]

  # Draw every negative sample for the sentence in a single call:
  # K negatives for each of the 2*N predicted context words at each position.
  per_position = K * 2 * N
  negatives = np.random.choice(nwords, size=2*N*K*len(centre_embs), replace=True, p=word_probabilities)

  sentence_losses = []
  for pos, centre in enumerate(centre_embs):
    neg_ids = negatives[pos * per_position:(pos + 1) * per_position]
    left = [sent[j] if j >= 0 else S for j in range(pos - N, pos)]
    right = [sent[j] if j < len(sent) else S for j in range(pos + 1, pos + N + 1)]
    neg_score = dy.dot_product(centre, dy.lookup_batch(W_c_p, neg_ids))
    neg_nll = -dy.log(dy.logistic(-neg_score))
    pos_score = dy.dot_product(centre, dy.lookup_batch(W_c_p, left + right))
    pos_nll = -dy.log(dy.logistic(pos_score))
    sentence_losses.append(dy.sum_batches(neg_nll) + dy.sum_batches(pos_nll))
  return dy.esum(sentence_losses)
Exemple #28
0
    def forward(self, inputs, expected_output, droput1=0.0, droput2=0.0):
        """Score ``inputs`` and return ``(summed loss, predicted tags)``.

        ``droput1`` and ``droput2`` (names kept as-is for callers) are passed
        positionally to ``self(...)``. Each item of ``expected_output`` is
        mapped to an id through ``self._T2I``. Predicted ids are mapped back
        to tags through ``self._I2T``.
        """
        out = self(inputs, droput1, droput2)
        gold_ids = np.array([self._T2I[tag] for tag in expected_output])
        loss = dy.sum_batches(dy.pickneglogsoftmax_batch(out, gold_ids))

        scores = out.npvalue()
        if scores.ndim == 2:
            # One column per batch element: best id of each column.
            best_ids = np.argmax(scores, axis=0)
        else:
            # A single score vector yields a single prediction.
            best_ids = np.array([np.argmax(scores)])
        tags = np.array([self._I2T[idx] for idx in best_ids])
        return loss, tags
Exemple #29
0
 def test_inputTensor_batched_list(self):
     """A batched inputTensor built from a list keeps its shape, values and norm.

     For each of the first four fixture shapes, the batch is built from the
     slices ``input_tensor[..., k]``. The test checks the head and batch
     dimensions, the element values, and the squared norm summed over the batch.
     """
     for i in range(4):
         dy.renew_cg()
         shape = self.shapes[i]
         expected = self.input_vals.reshape(shape)
         # Element k of the list is expected[..., k] (sliced along the last axis).
         elements = [np.asarray(piece).transpose() for piece in expected.transpose()]
         xb = dy.inputTensor(elements)
         dims = xb.dim()
         expected_head = (1,) if i == 0 else shape[:-1]
         self.assertEqual(dims[0], expected_head,
                          msg="Dimension mismatch")
         self.assertEqual(dims[1], shape[-1],
                          msg="Dimension mismatch")
         self.assertTrue(np.allclose(xb.npvalue(), expected),
                         msg="Expression value different from initial value")
         batch_norm = dy.sum_batches(dy.squared_norm(xb)).scalar_value()
         self.assertEqual(batch_norm, self.squared_norm, msg="Value mismatch")
Exemple #30
0
    def fit_partial(self, instances):
        """Train for one epoch over ``instances`` with mini-batch updates.

        How the epoch runs:

        * ``instances`` is shuffled in place.
        * Each instance is unpacked into ``self.model.loss``, and its returned
          loss expressions are accumulated.
        * Every ``self.batch_size`` instances, the accumulated losses are
          summed, back-propagated and applied with ``self.opt``.
        * A trailing partial batch is trained on as well.

        Afterwards ``self.loss`` holds the epoch's mean value per loss term, and
        the learning rate is multiplied by ``self.lr_decay``.

        Args:
            instances: mutable sequence of argument tuples for ``self.model.loss``.
        """
        random.shuffle(instances)
        self.iter += 1

        def _step(batch):
            # Sum one mini-batch of loss expressions, back-propagate, update
            # the parameters and start a fresh graph. Returns the summed value.
            loss = dy.sum_batches(dy.concatenate_to_batch(batch))
            value = loss.value()
            loss.backward()
            self.opt.update()
            dy.renew_cg()
            return value

        losses = []
        dy.renew_cg()

        total_loss, total_size = 0., 0
        prog = tqdm(desc="Epoch {}".format(self.iter), ncols=80, total=len(instances) + 1)
        for i, ins in enumerate(instances, 1):
            losses.extend(list(self.model.loss(*ins)))
            # Require at least one collected loss term. This avoids dividing by
            # zero below and building an empty batch.
            if i % self.batch_size == 0 and losses:
                batch_loss = _step(losses)
                total_loss += batch_loss
                total_size += len(losses)
                prog.set_postfix(loss=batch_loss / len(losses))
                losses = []

            prog.update()

        # Train on the trailing partial batch, if any.
        if losses:
            total_loss += _step(losses)
            total_size += len(losses)
            prog.update()

        # Record the epoch mean unconditionally. Previously self.loss was left
        # stale whenever len(instances) was an exact multiple of batch_size.
        if total_size:
            self.loss = total_loss / total_size
            prog.set_postfix(loss=self.loss)

        self.opt.learning_rate *= self.lr_decay
        prog.close()
Exemple #31
0
def sum(x, dim=None, include_batch_dim=False):
    """Sum a DyNet expression, or add up a list of expressions.

    Args:
        x: an expression, or a list of expressions combined with ``dy.esum``.
            For a list, ``dim`` and ``include_batch_dim`` are ignored.
        dim: axis of the expression's head shape to reduce. ``None`` sums every
            element. Negative values count from the last axis, so ``-1`` is the
            last axis and ``-2`` the one before it.
        include_batch_dim: also reduce over the batch dimension.

    Returns:
        The reduced DyNet expression.
    """
    if isinstance(x, list):
        return dy.esum(x)
    head_shape, batch_size = x.dim()
    if dim is None:
        total = dy.sum_elems(x)
        # Only reduce across the batch when there is more than one element.
        if include_batch_dim and batch_size > 1:
            return dy.sum_batches(total)
        return total
    if dim < 0:
        # Generalizes the former ``dim == -1`` special case to any negative axis.
        dim += len(head_shape)
    return dy.sum_dim(x, d=[dim], b=include_batch_dim)
Exemple #32
0
def calc_sent_loss(sent, dropout=0.0):
  """Return the summed n-gram LM loss of ``sent`` followed by the symbol S.

  Each target word is paired with the window of N words before it. The
  window starts as N copies of S and slides one word per step. All windows
  are scored in one batched call to calc_score_of_histories, and the
  per-word negative log-likelihoods are summed over the batch.
  """
  # Fresh computation graph for this sentence.
  dy.renew_cg()
  context = [S] * N
  histories, targets = [], []
  # The trailing S makes the model also predict end of sentence.
  for word in sent + [S]:
    histories.append(context[:])
    targets.append(word)
    context = context[1:] + [word]
  scores = calc_score_of_histories(histories, dropout=dropout)
  return dy.sum_batches(dy.pickneglogsoftmax_batch(scores, targets))
Exemple #33
0
def calc_sent_loss(sent, dropout=0.0):
  """Score every word of ``sent`` (plus a final S) from its N-word history.

  Histories are left-padded with S. They are scored together by
  calc_score_of_histories, and the batched negative log-likelihoods are
  summed into one scalar loss.
  """
  dy.renew_cg()
  targets = sent + [S]
  window = [S] * N
  histories = []
  for word in targets:
    # Snapshot the current window before sliding it past `word`.
    histories.append(list(window))
    window = window[1:] + [word]
  return dy.sum_batches(
      dy.pickneglogsoftmax_batch(
          calc_score_of_histories(histories, dropout=dropout), targets))
Exemple #34
0
    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence],
                 trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
        """Compute the decoder loss of ``trg`` given ``src`` with teacher forcing.

        The source is encoded once. For every target position ``i``:

        * the reference word at ``i`` is scored by ``self.decoder.calc_loss``
          from the current decoder state;
        * that same reference word is then fed as the decoder input for
          position ``i + 1``.

        Args:
            src: source sentence or batch. For a ``CompoundBatch`` only its
                first sub-batch is used.
            trg: reference target sentence or batch.

        Returns:
            The sum of the per-position losses. How they are combined depends
            on ``self.truncate_dec_batches``:

            * when set, each per-position loss is first summed over its batch;
            * otherwise the per-position losses (possibly still batched) are
              summed directly. For batched input, padded positions are then
              handled through ``trg_mask``.
        """
        event_trigger.start_sent(src)
        if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
        # Encode the sentence
        initial_state = self._encode_src(src)

        dec_state = initial_state
        trg_mask = trg.mask if batchers.is_batched(trg) else None
        cur_losses = []
        seq_len = trg.sent_len()

        # Optional sanity checks on batched input. Every target must have the
        # same length, and exactly one unmasked end-of-sentence (ES) token.
        # NOTE: these are plain asserts, so they are skipped under `python -O`.
        if settings.CHECK_VALIDITY and batchers.is_batched(src):
            for j, single_trg in enumerate(trg):
                assert single_trg.sent_len(
                ) == seq_len  # assert consistent length
                assert 1 == len([
                    i for i in range(seq_len)
                    if (trg_mask is None or trg_mask.np_arr[j, i] == 0)
                    and single_trg[i] == vocabs.Vocab.ES
                ])  # assert exactly one unmasked ES token

        # No input word at the first step: the decoder starts from the state
        # returned by the encoder.
        input_word = None
        for i in range(seq_len):
            ref_word = DefaultTranslator._select_ref_words(
                trg, i, truncate_masked=self.truncate_dec_batches)
            if self.truncate_dec_batches and batchers.is_batched(ref_word):
                # Presumably drops batch elements that are masked at this step
                # from both the RNN state and the reference words.
                # NOTE(review): input_word (from the previous step) is not
                # truncated here; confirm its batch size still matches.
                dec_state.rnn_state, ref_word = batchers.truncate_batches(
                    dec_state.rnn_state, ref_word)

            if input_word is not None:
                # Teacher forcing: feed the previous step's reference word.
                dec_state = self.decoder.add_input(
                    dec_state, self.trg_embedder.embed(input_word))
            rnn_output = dec_state.rnn_state.output()
            dec_state.context = self.attender.calc_context(rnn_output)
            word_loss = self.decoder.calc_loss(dec_state, ref_word)

            if not self.truncate_dec_batches and batchers.is_batched(
                    src) and trg_mask is not None:
                # Without truncation, apply the target mask at timestep i.
                # The exact meaning of inverse=True (presumably zeroing padded
                # positions) is defined in trg_mask; TODO confirm.
                word_loss = trg_mask.cmult_by_timestep_expr(word_loss,
                                                            i,
                                                            inverse=True)
            cur_losses.append(word_loss)
            input_word = ref_word

        if self.truncate_dec_batches:
            # Truncated steps may carry different batch sizes, so each one is
            # reduced over its batch before the losses are added.
            loss_expr = dy.esum([dy.sum_batches(wl) for wl in cur_losses])
        else:
            loss_expr = dy.esum(cur_losses)
        return loss_expr
Exemple #35
0
 def calc_baseline_loss(self, reward, only_final_reward):
   """Return baseline reward predictions and the baseline's regression loss.

   For each entry of ``self.states``:

   * ``self.baseline`` predicts a reward from the detached state.
   * The prediction is stored detached (``dy.nobackprop``) in the returned list.
   * The target is ``reward`` itself when ``only_final_reward`` is true,
     otherwise ``reward[i]``.
   * When ``self.valid_pos`` is set, prediction and target are restricted to
     the batch elements in ``self.valid_pos[i]``.
   * The loss term is their squared distance summed over the batch, with
     gradients blocked through the target.

   Returns:
     (list of detached predictions, sum of the per-state losses)
   """
   predictions = []
   step_losses = []
   for step, state in enumerate(self.states):
     predicted = self.baseline.transform(dy.nobackprop(state))
     predictions.append(dy.nobackprop(predicted))
     target = reward if only_final_reward else reward[step]
     if self.valid_pos is None:
       actual = target
     else:
       keep = self.valid_pos[step]
       predicted = dy.pick_batch_elems(predicted, keep)
       actual = dy.pick_batch_elems(target, keep)
     step_losses.append(
       dy.sum_batches(dy.squared_distance(predicted, dy.nobackprop(actual))))
   return predictions, dy.esum(step_losses)
Exemple #36
0
    def learn(self, batch_size):
        """Sample a replay batch, take one TD-learning step and return its loss.

        Steps:

        1. Sample experiences from ``self.memory`` (prioritized when
           ``self.prioritized``).
        2. Compute TD targets as constants in a throw-away graph.
        3. Regress the online ``self.network``'s values for the taken actions
           towards those targets.
        4. Decay epsilon, anneal beta (prioritized only), and periodically sync
           the target network.

        Args:
            batch_size: number of experiences to sample.

        Returns:
            The loss value from ``loss.npvalue()``. When prioritized replay is
            enabled and the memory is not yet full, returns ``-np.inf`` and
            makes no update.
        """
        if self.prioritized:
            # Prioritized replay only starts learning once the buffer is full.
            if not self.memory.is_full(): return -np.inf
            indices, exps, weights = self.memory.sample(batch_size, self.beta)
        else:
            exps = self.memory.sample(batch_size)
        obss, actions, rewards, obs_nexts, dones = self._process(exps)

        # First graph: compute the TD targets and pull them out as numpy
        # arrays, so they act as constants in the update graph below.
        dy.renew_cg()
        # NOTE(review): with use_double_dqn the next-state values come from the
        # target network, otherwise from the online network. In both cases the
        # same network both selects (max) and evaluates the next action, which
        # differs from the usual Double DQN formulation. Confirm intent.
        target_network = self.target_network if self.use_double_dqn else self.network
        if self.dueling:
            # Assumes the dueling network returns (per-action values, state value v).
            target_values, v = target_network(obs_nexts, batched=True)
            target_values = target_values.npvalue() + v.npvalue()
        else:
            target_values = target_network(obs_nexts, batched=True)
            target_values = target_values.npvalue()
        # Max over axis 0: presumably the action axis, with batch on axis 1 (TODO confirm).
        target_values = np.max(target_values, axis=0)
        # (1 - dones) drops the bootstrap term for terminal transitions
        # (assumes dones is 0/1).
        target_values = rewards + self.reward_decay * (target_values * (1 - dones))

        # Second graph: the differentiable prediction for the actions taken.
        dy.renew_cg()
        if self.dueling:
            all_values_expr, v = self.network(obss, batched=True)
        else:
            all_values_expr = self.network(obss, batched=True)
        picked_values = dy.pick_batch(all_values_expr, actions)
        diff = (picked_values + v if self.dueling else picked_values) - dy.inputTensor(target_values, batched=True)
        if self.prioritized:
            # New priorities are the absolute TD errors.
            self.memory.update(indices, np.transpose(np.abs(diff.npvalue())))
        # Element-wise square of the TD error.
        losses = dy.pow(diff, dy.constant(1, 2))
        if self.prioritized:
            # Per-sample weights from the memory (presumably importance-sampling weights).
            losses = dy.cmult(losses, dy.inputTensor(weights, batched=True))
        loss = dy.sum_batches(losses)
        loss_value = loss.npvalue()
        loss.backward()
        self.trainer.update()

        # Exploration and prioritization schedules.
        self.epsilon = max(self.epsilon - self.epsilon_decrease, self.epsilon_lower)
        if self.prioritized:
            self.beta = min(self.beta + self.beta_increase, 1.)

        self.learn_step += 1
        # Every n_replace_target learning steps, update the target network
        # from the online network.
        if self.use_double_dqn and self.learn_step % self.n_replace_target == 0:
            self.target_network.update(self.network)
        return loss_value
Exemple #37
0
    def BuildLMGraph(self, sents):
        """Build the batched, masked character-LM loss for ``sents``.

        Each symbol is mapped through the global ``vocab.w2i``. The first symbol
        of every sentence is fed to the RNN as the start input, and every later
        symbol is predicted from the previous RNN output. Positions past the end
        of a shorter sentence are padded with ``<s>`` and masked out of the loss.

        The batch no longer has to be sorted by decreasing length:

        * the number of steps comes from the longest sentence, not ``sents[0]``;
        * masking is applied whenever any entry of the mask is 0, not only
          when the last one is.

        Args:
            sents: list of symbol sequences.

        Returns:
            ``(summed loss expression, tot_chars)``. ``tot_chars`` counts every
            non-padding symbol in the batch, including each sentence's first
            symbol, which is fed as input rather than predicted.
        """
        dy.renew_cg()
        # initialize the RNN
        init_state = self.builder.initial_state()
        # parameters -> expressions
        R = dy.parameter(self.R)
        bias = dy.parameter(self.bias)

        S = vocab.w2i["<s>"]
        # get the cids and masks for each step
        tot_chars = 0
        cids = []
        masks = []

        max_len = max(len(sent) for sent in sents)
        for i in range(max_len):
            cids.append([(vocab.w2i[sent[i]] if len(sent) > i else S) for sent in sents])
            mask = [(1 if len(sent) > i else 0) for sent in sents]
            masks.append(mask)
            tot_chars += sum(mask)

        # start the rnn with each sentence's first symbol (presumably "<s>")
        init_ids = cids[0]
        s = init_state.add_input(dy.lookup_batch(self.lookup, init_ids))

        losses = []

        # feed char vectors into the RNN and predict the next char
        for cid, mask in zip(cids[1:], masks[1:]):
            score = dy.affine_transform([bias, R, s.output()])
            loss = dy.pickneglogsoftmax_batch(score, cid)
            # mask the loss if any sentence has already ended at this step
            if not all(mask):
                mask_expr = dy.inputVector(mask)
                mask_expr = dy.reshape(mask_expr, (1,), len(sents))
                loss = loss * mask_expr

            losses.append(loss)
            # update the state of the RNN
            cemb = dy.lookup_batch(self.lookup, cid)
            s = s.add_input(cemb)

        return dy.sum_batches(dy.esum(losses)), tot_chars
Exemple #38
0
def calc_lm_loss(sents):
    """Return the batched, masked RNN language-model loss for ``sents``.

    The RNN is started with ``S`` for every batch element and then predicts
    every word of every sentence. Shorter sentences are padded with ``S``, and
    the padding is masked out of the loss.

    The batch no longer has to be sorted by decreasing length:

    * the number of steps comes from the longest sentence, not ``sents[0]``;
    * masking is applied whenever any entry of the mask is 0, not only when
      the last one is.

    Args:
        sents: list of word-id sequences.

    Returns:
        ``(summed loss expression, number of non-padding words in the batch)``.
    """
    dy.renew_cg()

    # initialize the RNN
    f_init = RNN.initial_state()

    # get the wids and masks for each step, up to the longest sentence
    tot_words = 0
    wids = []
    masks = []
    max_len = max(len(sent) for sent in sents)
    for i in range(max_len):
        wids.append([(sent[i] if len(sent) > i else S) for sent in sents])
        mask = [(1 if len(sent) > i else 0) for sent in sents]
        masks.append(mask)
        tot_words += sum(mask)

    # start the rnn by inputting S for every batch element
    init_ids = [S] * len(sents)
    s = f_init.add_input(dy.lookup_batch(WORDS_LOOKUP, init_ids))

    # feed word vectors into the RNN and predict the next word
    losses = []
    for wid, mask in zip(wids, masks):
        # calculate the softmax and loss
        score = dy.affine_transform([b_exp, W_exp, s.output()])
        loss = dy.pickneglogsoftmax_batch(score, wid)
        # mask the loss if any sentence has already ended at this step
        if not all(mask):
            mask_expr = dy.inputVector(mask)
            mask_expr = dy.reshape(mask_expr, (1,), len(sents))
            loss = loss * mask_expr
        losses.append(loss)
        # update the state of the RNN
        wemb = dy.lookup_batch(WORDS_LOOKUP, wid)
        s = s.add_input(wemb)

    return dy.sum_batches(dy.esum(losses)), tot_words
def calc_loss(sents):
    """Return the masked attentional encoder-decoder loss for a batch.

    Two fixes against the previous version:

    * The target lengths are measured on the target sentences. They were
      measured on the (source, target) pairs, whose length is always 2, which
      cut decoding to a single prediction step.
    * The step that predicts target position ``t + 1`` now uses the mask of
      position ``t + 1``. It previously used the mask of ``t``, which counted
      one padding prediction per shorter sentence.

    Args:
        sents: list of (source, target) pairs of word-id sequences.

    Returns:
        ``(summed loss expression, num_words)``. ``num_words`` counts every
        target token in the batch, including each sentence's first token,
        which is fed to the decoder as input rather than predicted.
    """
    dy.renew_cg()

    # Transduce all batch elements with an LSTM
    src_sents = [x[0] for x in sents]
    tgt_sents = [x[1] for x in sents]
    src_cws = []

    src_len = [len(sent) for sent in src_sents]
    max_src_len = np.max(src_len)
    num_words = 0

    # NOTE(review): the source side is not padded, so every source sentence in
    # the batch must have the same length, or sent[i] raises IndexError.
    for i in range(max_src_len):
        src_cws.append([sent[i] for sent in src_sents])

    # get the outputs of the first LSTM; the two states of each pair
    # (presumably forward/backward) are concatenated per position
    src_outputs = [dy.concatenate([x.output(), y.output()]) for x, y in LSTM_SRC.add_inputs([dy.lookup_batch(LOOKUP_SRC, cws) for cws in src_cws])]
    src_output = src_outputs[-1]

    # gets the parameters for the attention
    src_output_matrix = dy.concatenate_cols(src_outputs)
    w1_att_src = dy.parameter(w1_att_src_p)
    fixed_attentional_component = w1_att_src * src_output_matrix

    # now decode
    all_losses = []

    # Decoder: pad shorter targets with eos_trg and mask the padding out
    tgt_cws = []
    tgt_len = [len(sent) for sent in tgt_sents]
    max_tgt_len = np.max(tgt_len)
    masks = []

    for i in range(max_tgt_len):
        tgt_cws.append([sent[i] if len(sent) > i else eos_trg for sent in tgt_sents])
        mask = [(1 if len(sent) > i else 0) for sent in tgt_sents]
        masks.append(mask)
        num_words += sum(mask)

    current_state = LSTM_TRG_BUILDER.initial_state().set_s([src_output, dy.tanh(src_output)])
    prev_words = tgt_cws[0]
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    W_m = dy.parameter(W_m_p)
    b_m = dy.parameter(b_m_p)

    # Step t predicts tgt_cws[t + 1], so it pairs with masks[t + 1].
    for next_words, mask in zip(tgt_cws[1:], masks[1:]):
        # feed the previous words into the decoder LSTM
        current_state = current_state.add_input(dy.lookup_batch(LOOKUP_TRG, prev_words))
        output_embedding = current_state.output()
        att_output, _ = calc_attention(src_output_matrix, output_embedding, fixed_attentional_component)
        middle_expr = dy.tanh(dy.affine_transform([b_m, W_m, dy.concatenate([output_embedding, att_output])]))
        s = dy.affine_transform([b_sm, W_sm, middle_expr])
        loss = dy.pickneglogsoftmax_batch(s, next_words)
        mask_expr = dy.inputVector(mask)
        mask_expr = dy.reshape(mask_expr, (1,), len(sents))
        mask_loss = loss * mask_expr
        all_losses.append(mask_loss)
        prev_words = next_words
    return dy.sum_batches(dy.esum(all_losses)), num_words