def get_diag_loss(self, a_t, t):
    """Sum the entries of ``a_t`` that fall outside a band around index ``t``.

    An entry at position ``i`` is counted when ``i < t - self.diag_loss`` or
    ``i > t + self.diag_loss``.  A negative ``self.diag_loss`` disables the
    term and a zero scalar expression is returned instead.

    :param a_t: iterable of dynet expressions (presumably attention weights
                for the current step -- confirm against the caller)
    :param t: index the band is centred on
    :return: dynet expression holding the summed off-band entries
    """
    band = self.diag_loss
    if band < 0:
        return dy.scalarInput(0)

    lower, upper = t - band, t + band
    # Start from a zero scalar so esum never receives an empty list.
    outside = [dy.scalarInput(0)]
    outside.extend(
        weight for pos, weight in enumerate(a_t) if not lower <= pos <= upper
    )
    return dy.esum(outside)
示例#2
0
 def score_sentence(self, score_vecs, tags):
     """Score one tag sequence as the sum of transition and emission scores.

     The path scored is ``START_TAG, tags[0], ..., tags[-1], END_TAG``: for
     every position the transition score into the tag and the tag's emission
     score are added, followed by the transition into ``END_TAG``.

     :param score_vecs: per-position expressions indexable by tag id
     :param tags: tag ids, one per position; the list is NOT modified
     :return: dynet expression with the total score of the path
     """
     assert(len(score_vecs)==len(tags))
     # Build a local path instead of inserting START_TAG into the caller's
     # list: mutating the argument corrupted the caller's tags and made a
     # second call on the same list fail the length assertion above.
     path = [START_TAG] + list(tags)
     total = dynet.scalarInput(.0)
     for i, obs in enumerate(score_vecs):
         # transition to next from i and emission
         next_tag = path[i + 1]
         total += dynet.pick(self.trans_mat[next_tag], path[i]) + dynet.pick(obs, next_tag)
     total += dynet.pick(self.trans_mat[END_TAG], path[-1])
     return total
示例#3
0
文件: mnnl.py 项目: bplank/bilstm-aux
 def score_sentence(self, score_vecs, tags):
     """Score one tag sequence as the sum of transition and emission scores.

     The path scored is ``START_TAG, tags[0], ..., tags[-1], END_TAG``: for
     every position the transition score into the tag and the tag's emission
     score are added, followed by the transition into ``END_TAG``.

     :param score_vecs: per-position expressions indexable by tag id
     :param tags: tag ids, one per position; the list is NOT modified
     :return: dynet expression with the total score of the path
     """
     assert(len(score_vecs)==len(tags))
     # Build a local path instead of inserting START_TAG into the caller's
     # list: mutating the argument corrupted the caller's tags and made a
     # second call on the same list fail the length assertion above.
     path = [START_TAG] + list(tags)
     total = dynet.scalarInput(.0)
     for i, obs in enumerate(score_vecs):
         # transition to next from i and emission
         next_tag = path[i + 1]
         total += dynet.pick(self.trans_mat[next_tag], path[i]) + dynet.pick(obs, next_tag)
     total += dynet.pick(self.trans_mat[END_TAG], path[-1])
     return total
 def get_coverage(self, a_t, prev_coverage, training=True):
     """Update the running coverage and, when training, its penalty term.

     The return shape depends on the mode:

     * ``self.coverage`` falsy, training: ``(zero scalar, None)``
     * ``self.coverage`` falsy, not training: ``None``
     * ``self.coverage`` truthy, training: ``(penalty, new_coverage)``
     * ``self.coverage`` truthy, not training: ``new_coverage``

     ``new_coverage`` is ``a_t + prev_coverage``; the penalty is the sum over
     the element-wise minimum of ``a_t`` and ``new_coverage``.
     """
     if self.coverage:
         updated = a_t + prev_coverage
         if not training:
             return updated
         # Put both vectors side by side and take the minimum along d=1.
         side_by_side = dy.concatenate([a_t, updated], d=1)
         penalty = dy.sum_elems(dy.min_dim(side_by_side, d=1))
         return penalty, updated

     # Coverage disabled: neutral values with the same arity as above.
     if training:
         return dy.scalarInput(0), None
     return None
示例#5
0
    def __call__(self, x, soft_labels=False, temperature=None):
        """Apply the (optional) MLP layer and the output layer to ``x``.

        :param x: input expression
        :param soft_labels: if truthy together with ``temperature``, return
                            temperature-smoothed soft labels
        :param temperature: divisor applied to the logits before
                            exponentiation; ignored when falsy
        :return: ``self.act(logits)``, or the normalised exponentiated
                 logits when soft labels are requested
        """
        hidden = x
        if self.mlp:
            mlp_weights = dynet.parameter(self.W_mlp)
            mlp_bias = dynet.parameter(self.b_mlp)
            hidden = self.mlp_activation(mlp_weights * x + mlp_bias)

        # from params to expressions
        out_weights = dynet.parameter(self.W)
        out_bias = dynet.parameter(self.b)
        logits = (out_weights * hidden + out_bias) + dynet.scalarInput(1e-15)

        if not (soft_labels and temperature):
            return self.act(logits)

        # soft labels smoothed with the temperature, see
        # "Distilling the Knowledge in a Neural Network"
        scaled = dynet.exp(logits / temperature)
        return dynet.cdiv(scaled, dynet.sum_elems(scaled))
    def calculate_loss(self, sents):
        """Build the summed loss expression for a batch of sentences.

        The computation graph is renewed first.  For every sentence the CRF
        negative log loss is combined with the auto-encoder reconstruction
        loss (summed over tokens, divided by ``self.featsize``); when
        ``self.autoencoder`` is falsy the reconstruction term is zero.

        :param sents: iterable of sentences, each a sequence of
                      ``(chars, word, feats, tag)`` tuples
        :return: dynet expression summing the per-sentence losses
        """
        dy.renew_cg()
        per_sentence = []
        for sentence in sents:
            extracted = self.get_features_for_tagging(sentence, True)
            features, t_features, feat_reconstruct = extracted
            gold_tags = [tag for _chars, _word, _feats, tag in sentence]
            crf_loss = self.crf_module.negative_log_loss(
                features, t_features, gold_tags
            )

            if self.autoencoder:
                recon_terms = []
                for reconstruct, (_chars, _word, feats, _tag) in zip(
                    feat_reconstruct, sentence
                ):
                    recon_terms.append(
                        dy.binary_log_loss(reconstruct, dy.inputTensor(feats))
                    )
            else:
                # no auto-encoder: contribute a neutral zero term
                recon_terms = [dy.scalarInput(0)]

            per_sentence.append(
                crf_loss + (dy.esum(recon_terms) / self.featsize)
            )

        return dy.esum(per_sentence)
示例#7
0
    def fit(self,
            train_X,
            train_Y,
            num_epochs,
            val_X=None,
            val_Y=None,
            patience=2,
            model_path=None,
            seed=None,
            word_dropout_rate=0.25,
            trg_vectors=None,
            unsup_weight=1.0,
            labeled_weight_proportion=1.0):
        """
        Train the model with one update per example.

        :param train_X: training examples; must expose ``shape[0]``
        :param train_Y: one label per example; ``-1`` marks an unlabeled
                        example whose supervised loss is skipped
        :param num_epochs: maximum number of passes over the training data
        :param val_X: validation examples (early stopping needs ``val_X``,
                      ``val_Y`` and ``model_path`` all to be set)
        :param val_Y: validation labels
        :param patience: epochs without F1 improvement before stopping
        :param model_path: where ``save_model`` stores the best model
        :param seed: if truthy, seeds Python's ``random`` module
        :param word_dropout_rate: passed to ``self.predict`` as ``dropout_rate``
        :param trg_vectors: the prediction targets used for the unsupervised
                            loss in temporal ensembling, indexed like the
                            training data
        :param unsup_weight: weight for the unsupervised consistency loss
                             used in temporal ensembling
        :param labeled_weight_proportion: extra factor applied to the
                                          consistency loss of labeled examples
        """
        if seed:
            print(">>> using seed: ", seed, file=sys.stderr)
            random.seed(seed)  # setting random seed

        assert(train_X.shape[0] == len(train_Y)), \
            '# examples %d != # labels %d.' % (train_X.shape[0], len(train_Y))
        train_data = list(zip(train_X, train_Y))
        num_examples = len(train_data)

        print('Starting training for %d epochs...' % num_epochs)
        best_val_f1, epochs_no_improvement = 0., 0
        early_stopping = (val_X is not None and val_Y is not None
                          and model_path is not None)
        if early_stopping:
            print('Using early stopping with patience of %d...' % patience)

        for epoch in range(num_epochs):
            progress = Bar('Training epoch %d/%d...' % (epoch + 1, num_epochs),
                           max=num_examples,
                           flush=True)
            epoch_loss = 0.0

            # visit the examples in a fresh random order every epoch
            order = np.arange(num_examples)
            random.shuffle(order)

            for idx in order:
                x, y = train_data[idx]
                output = self.predict(x,
                                      train=True,
                                      dropout_rate=word_dropout_rate)

                # in temporal ensembling, we assign a dummy label of -1 for
                # unlabeled sequences; we skip the supervised loss for these
                if y == -1:
                    loss = dynet.scalarInput(0)
                else:
                    loss = self.pick_neg_log(output, y)

                if trg_vectors is not None:
                    # the consistency loss in temporal ensembling is used for
                    # both supervised and unsupervised input
                    consistency = dynet.squared_distance(
                        output, dynet.inputVector(trg_vectors[idx]))
                    if y != -1:
                        consistency *= labeled_weight_proportion
                    loss += consistency * unsup_weight

                epoch_loss += loss.value()
                loss.backward()
                self.trainer.update()
                progress.next()

            print(" iter {2} {0:>12}: {1:.2f}".format(
                "total loss", epoch_loss / num_examples, epoch),
                  file=sys.stderr)

            if not early_stopping:
                continue

            # get the best F1 score on the validation set
            val_f1 = self.evaluate(val_X, val_Y)
            if val_f1 > best_val_f1:
                print('F1 %.4f is better than best val F1 %.4f.' %
                      (val_f1, best_val_f1))
                best_val_f1 = val_f1
                epochs_no_improvement = 0
                save_model(self, model_path)
            else:
                print('F1 %.4f is worse than best val F1 %.4f.' %
                      (val_f1, best_val_f1))
                epochs_no_improvement += 1
            if epochs_no_improvement == patience:
                print('No improvement for %d epochs. Early stopping...' %
                      epochs_no_improvement)
                break
示例#8
0
    def fit(self,
            train_dict,
            num_epochs,
            val_X=None,
            val_Y=None,
            patience=2,
            model_path=None,
            seed=None,
            word_dropout_rate=0.25,
            trg_vectors=None,
            unsup_weight=1.0,
            orthogonality_weight=0.0,
            adversarial=False):
        """
        Train the multi-task model with one update per example.

        :param train_dict: maps a task id to a dict with the keys ``"X"``
                           (examples, exposing ``shape[0]``), ``"Y"`` (labels)
                           and ``"domain"`` (domain tags)
        :param num_epochs: maximum number of passes over the training data
        :param val_X: validation examples (early stopping needs ``val_X``,
                      ``val_Y`` and ``model_path`` all to be set)
        :param val_Y: validation labels
        :param patience: epochs without F1 improvement before stopping
        :param model_path: where ``save_mttri_model`` stores the best model
        :param seed: if truthy, seeds Python's ``random`` module
        :param word_dropout_rate: passed to ``self.predict`` as ``dropout_rate``
        :param trg_vectors: the prediction targets used for the unsupervised
                            loss in temporal ensembling
        :param unsup_weight: weight for the unsupervised consistency loss
                             used in temporal ensembling
        :param orthogonality_weight: weight of the constraint returned by
                                     ``self.predict``; ``0.0`` disables it
        :param adversarial: if truthy, the domain id is passed to
                            ``self.predict`` and the returned adversarial
                            loss is added
        """
        if seed:
            print(">>> using seed: ", seed, file=sys.stderr)
            random.seed(seed)  # setting random seed

        # train data is a list of 4-tuples: (example, label, task_id, domain_id)
        train_data = []
        required_keys = ["X", "Y", "domain"]
        for task, task_dict in train_dict.items():
            for key in required_keys:
                assert key in task_dict, "Error: %s is not available." % key
            examples = task_dict["X"]
            labels = task_dict["Y"]
            domain_tags = task_dict["domain"]
            assert examples.shape[0] == len(labels)

            task_column = [task] * len(labels)
            train_data.extend(
                zip(examples, labels, task_column, domain_tags))
        num_examples = len(train_data)

        print('Starting training for %d epochs...' % num_epochs)
        best_val_f1, epochs_no_improvement = 0., 0

        early_stopping = (val_X is not None and val_Y is not None
                          and model_path is not None)
        if early_stopping:
            print('Using early stopping with patience of %d...' % patience)

        use_constraint = orthogonality_weight != 0.0

        for epoch in range(num_epochs):
            progress = Bar('Training epoch %d/%d...' % (epoch + 1, num_epochs),
                           max=num_examples,
                           flush=True)
            total_loss, total_constraint, total_adversarial = 0.0, 0.0, 0.0

            order = np.arange(num_examples)
            random.shuffle(order)

            for idx in order:
                x, y, task_id, domain_id = train_data[idx]

                if task_id == 'src':
                    # we train both F0 and F1 on source data
                    task_ids = ['F0', 'F1']
                elif task_id == 'src_all':
                    # we train F0, F1, and Ft on source data for base training
                    task_ids = ['F0', 'F1', 'Ft']
                else:
                    task_ids = [task_id]

                outputs, constraint, adv = self.predict(
                    x,
                    task_ids,
                    train=True,
                    dropout_rate=word_dropout_rate,
                    orthogonality_weight=orthogonality_weight,
                    domain_id=domain_id if adversarial else None)

                loss = 0
                for output in outputs:
                    # in temporal ensembling, we assign a dummy label of -1
                    # for unlabeled sequences; we skip the supervised loss
                    # for these
                    if y == -1:
                        loss += dynet.scalarInput(0)
                    else:
                        loss += self.pick_neg_log(output, y)

                    if trg_vectors is not None:
                        # the consistency loss in temporal ensembling is used
                        # for both supervised and unsupervised input
                        consistency = dynet.squared_distance(
                            output, dynet.inputVector(trg_vectors[idx]))
                        loss += consistency * unsup_weight

                # the orthogonality weight is the same for every prediction,
                # so we can add it in the end
                if use_constraint:
                    loss += constraint * orthogonality_weight
                    total_constraint += constraint.value()
                if adversarial:
                    total_adversarial += adv.value()
                    loss += adv

                total_loss += loss.value()
                loss.backward()
                self.trainer.update()
                progress.next()

            print(
                "\niter {}. Total loss: {:.3f}, total penalty: {:.3f}, adv: {:.3f}"
                .format(epoch, total_loss / num_examples,
                        total_constraint / num_examples,
                        total_adversarial / num_examples),
                file=sys.stderr)

            if not early_stopping:
                continue

            # get the best F1 score on the validation set
            val_f1 = self.evaluate(val_X, val_Y, 'F0')
            if val_f1 > best_val_f1:
                print('F1 %.4f is better than best val F1 %.4f.' %
                      (val_f1, best_val_f1))
                best_val_f1 = val_f1
                epochs_no_improvement = 0
                save_mttri_model(self, model_path)
            else:
                print('F1 %.4f is worse than best val F1 %.4f.' %
                      (val_f1, best_val_f1))
                epochs_no_improvement += 1
            if epochs_no_improvement == patience:
                print('No improvement for %d epochs. Early stopping...' %
                      epochs_no_improvement)
                break
示例#9
0
    def fit(self,
            train_X,
            train_Y,
            num_epochs,
            val_X=None,
            val_Y=None,
            patience=2,
            model_path=None,
            seed=None,
            word_dropout_rate=0.25,
            trg_vectors=None,
            unsup_weight=1.0,
            variance_weights=None,
            labeled_weight_proportion=1.0):
        """
        train the tagger
        :param train_X: list of ``(word_indices, char_indices)`` examples
        :param train_Y: list of label sequences, one per example; the dummy
                        label ``[0]`` marks an unlabeled sequence
        :param num_epochs: maximum number of passes over the training data
        :param val_X: validation examples (early stopping needs ``val_X``,
                      ``val_Y`` and ``model_path`` all to be set)
        :param val_Y: validation labels
        :param patience: epochs without accuracy improvement before stopping
        :param model_path: where ``save_tagger`` stores the best model
        :param seed: if truthy, seeds Python's ``random`` module
        :param word_dropout_rate: if > 0, words are replaced by ``_UNK`` with
                                  a probability that depends on their count
        :param trg_vectors: the prediction targets used for the unsupervised loss
                            in temporal ensembling; one row per token, in the
                            order of ``train_X``
        :param unsup_weight: weight for the unsupervised consistency loss
                                    used in temporal ensembling
        :param variance_weights: optional per-token weights multiplied with
                                 each token's consistency loss
        :param labeled_weight_proportion: proportion of the unsupervised weight
                                          that should be assigned to labeled
                                          examples
        """
        print("read training data", file=sys.stderr)

        if variance_weights is not None:
            print('First 20 variance weights:', variance_weights[:20])

        if seed:
            print(">>> using seed: ", seed, file=sys.stderr)
            random.seed(seed)  #setting random seed

        # if we use word dropout keep track of counts
        if word_dropout_rate > 0.0:
            widCount = Counter()
            for sentence, _ in train_X:
                widCount.update(sentence)

        assert (len(train_X) == len(train_Y))
        train_data = list(zip(train_X, train_Y))

        # if we use target vectors, keep track of the targets per sentence
        if trg_vectors is not None:
            trg_start_id = 0
            sentence_trg_vectors = []
            sentence_var_weights = []
            for example, _ in train_data:
                trg_end_id = trg_start_id + len(example[0])
                sentence_trg_vectors.append(
                    trg_vectors[trg_start_id:trg_end_id, :])
                if variance_weights is not None:
                    sentence_var_weights.append(
                        variance_weights[trg_start_id:trg_end_id])
                trg_start_id = trg_end_id
            assert trg_start_id == len(trg_vectors),\
                'Error: Idx {} is not at {}.'.format(trg_start_id, len(trg_vectors))
            assert len(sentence_trg_vectors) == len(train_X)
            if variance_weights is not None:
                assert trg_start_id == len(variance_weights)
                assert len(sentence_var_weights) == len(train_X)

        print('Starting training for {} epochs...'.format(num_epochs))
        best_val_acc, epochs_no_improvement = 0., 0
        if val_X is not None and val_Y is not None and model_path is not None:
            print(
                'Using early stopping with patience of {}...'.format(patience))

        for cur_iter in range(num_epochs):
            bar = Bar('Training epoch {}/{}...'.format(cur_iter + 1,
                                                       num_epochs),
                      max=len(train_data),
                      flush=True)
            total_loss = 0.0
            total_tagged = 0.0

            total_other_loss, total_other_loss_weighted = 0.0, 0.0

            random_indices = np.arange(len(train_data))
            random.shuffle(random_indices)

            for idx in random_indices:
                (word_indices, char_indices), y = train_data[idx]

                if word_dropout_rate > 0.0:
                    word_indices = [
                        self.w2i["_UNK"] if
                        (random.random() >
                         (widCount.get(w) /
                          (word_dropout_rate + widCount.get(w)))) else w
                        for w in word_indices
                    ]
                output = self.predict(word_indices, char_indices, train=True)

                # in temporal ensembling, we assign a dummy label of [0] for
                # unlabeled sequences; we skip the supervised loss for these
                # NOTE(review): a labeled one-token sentence whose gold tag id
                # is 0 is indistinguishable from the dummy label.
                is_unlabeled = len(y) == 1 and y[0] == 0
                if is_unlabeled:
                    loss = dynet.scalarInput(0)
                else:
                    loss = dynet.esum([
                        self.pick_neg_log(pred, gold)
                        for pred, gold in zip(output, y)
                    ])

                # keep track for logging (supervised loss only)
                total_loss += loss.value()
                total_tagged += len(word_indices)

                if trg_vectors is not None:
                    # the consistency loss in temporal ensembling is used for
                    # both supervised and unsupervised input
                    targets = sentence_trg_vectors[idx]
                    assert len(output) == len(targets)
                    if variance_weights is not None:
                        var_weights = sentence_var_weights[idx]
                        assert len(output) == len(var_weights)
                        # multiply the normalized mean variance with each loss
                        other_loss = dynet.esum([
                            v * dynet.squared_distance(o, dynet.inputVector(t))
                            for o, t, v in zip(output, targets, var_weights)
                        ])
                    else:
                        other_loss = dynet.esum([
                            dynet.squared_distance(o, dynet.inputVector(t))
                            for o, t in zip(output, targets)
                        ])
                    total_other_loss += other_loss.value()

                    # Scale the consistency loss by its weight.  This used to
                    # be `other_loss += other_loss * weight`, i.e. an
                    # effective weight of (1 + weight), so a weight of 0 did
                    # not switch the term off.
                    if is_unlabeled:
                        consistency_weight = unsup_weight
                    else:
                        # assign the unsupervised weight for labeled examples
                        consistency_weight = (unsup_weight *
                                              labeled_weight_proportion)
                    other_loss = other_loss * consistency_weight
                    total_other_loss_weighted += other_loss.value()

                    # combine losses
                    loss += other_loss

                loss.backward()
                self.trainer.update()
                bar.next()

            # avoid a ZeroDivisionError when nothing was tagged this epoch
            tagged_norm = total_tagged if total_tagged else 1.0
            if trg_vectors is None:
                print("iter {2} {0:>12}: {1:.2f}".format(
                    "total loss", total_loss / tagged_norm, cur_iter),
                      file=sys.stderr)
            else:
                print(
                    "iter {2} {0:>12}: {1:.2f} unsupervised loss: {3:.2f} (weighted: {4:.2f})"
                    .format("supervised loss", total_loss / tagged_norm,
                            cur_iter, total_other_loss / tagged_norm,
                            total_other_loss_weighted / tagged_norm),
                    file=sys.stderr)

            if val_X is not None and val_Y is not None and model_path is not None:
                # get the best accuracy on the validation set
                val_correct, val_total = self.evaluate(val_X, val_Y)
                val_accuracy = val_correct / val_total

                if val_accuracy > best_val_acc:
                    print(
                        'Accuracy {:.4f} is better than best val accuracy {:.4f}'
                        .format(val_accuracy, best_val_acc))
                    best_val_acc = val_accuracy
                    epochs_no_improvement = 0
                    save_tagger(self, model_path)
                else:
                    # the tracked metric is accuracy, not loss
                    print(
                        'Accuracy {:.4f} is worse than best val accuracy {:.4f}.'.
                        format(val_accuracy, best_val_acc))
                    epochs_no_improvement += 1
                if epochs_no_improvement == patience:
                    print('No improvement for {} epochs. Early stopping...'.
                          format(epochs_no_improvement))
                    break
示例#10
0
    def compute_decoder_batch_loss(self, encoded_inputs, input_masks,
                                   output_word_ids, output_masks, batch_size):
        """Run the attention decoder over a batch and return its summed loss.

        The decoder is fed the gold symbol of the previous step together with
        the previous attention output ("input feeding").  The per-step batch
        losses are summed over the time steps and then over the batch.

        :param encoded_inputs: encoder outputs, passed through to ``self.attend``
        :param input_masks: input masks, passed through to ``self.attend``
        :param output_word_ids: one list of gold ids (one per batch element)
                                for every time step
        :param output_masks: per-step masks; a step is masked when the last
                             entry of its mask is not 1.  May be falsy.
        :param batch_size: number of sequences in the batch
        :return: dynet expression with the total batch loss
        """
        # Load the parameters into the graph; they are stored on the instance.
        for name in ('readout', 'bias', 'w_c', 'u_a', 'v_a', 'w_a'):
            setattr(self, name, dn.parameter(self.params[name]))

        # initialize the decoder rnn
        start_state = self.decoder_rnn.initial_state()

        # initial "input feeding" vectors to feed decoder - 3*h
        first_feeding = dn.lookup_batch(self.init_lookup, [0] * batch_size)

        # initial feedback embeddings for the decoder, use begin seq symbol embedding
        begin_ids = [self.y2int[common.BEGIN_SEQ]] * batch_size
        first_feedback = dn.lookup_batch(self.output_lookup, begin_ids)

        state = start_state.add_input(
            dn.concatenate([first_feedback, first_feeding]))

        # loss per timestep
        step_losses = []
        for step, gold_ids in enumerate(output_word_ids):
            rnn_out = state.output()

            # attention context vector for each sequence in the batch
            context, _alphas = self.attend(encoded_inputs, rnn_out,
                                           input_masks)

            # output scores: readout * context + bias
            scores = dn.affine_transform([self.bias, self.readout, context])

            if self.diverse:
                # encourage diversity by punishing highly confident predictions
                # TODO: support batching - esp. w.r.t. scalar inputs
                probs = dn.softmax(dn.tanh(scores))
                gold_nll = dn.pick_batch(-dn.log(probs), gold_ids)
                log_rest = dn.log(
                    dn.scalarInput(1) - dn.pick_batch(probs, gold_ids))
                step_loss = gold_nll - log_rest - dn.log(dn.scalarInput(4))
            else:
                # get batch loss for this timestep
                step_loss = dn.pickneglogsoftmax_batch(scores, gold_ids)

            # mask the loss if at least one sentence is shorter
            if output_masks and output_masks[step][-1] != 1:
                # noinspection PyArgumentList
                mask = dn.reshape(dn.inputVector(output_masks[step]), (1, ),
                                  batch_size)
                step_loss = step_loss * mask
            step_losses.append(step_loss)

            # input feeding: the next decoder input is the gold symbol's
            # embedding concatenated with the attention output
            gold_embeddings = dn.lookup_batch(self.output_lookup, gold_ids)
            state = state.add_input(dn.concatenate([gold_embeddings, context]))

        # sum the loss over the time steps and batch
        return dn.sum_batches(dn.esum(step_losses))
示例#11
0
    def fit(self,
            train_dict,
            num_epochs,
            val_X=None,
            val_Y=None,
            patience=2,
            model_path=None,
            seed=None,
            word_dropout_rate=0.25,
            trg_vectors=None,
            unsup_weight=1.0,
            clip_threshold=5.0,
            orthogonality_weight=0.0,
            adversarial=False,
            adversarial_weight=1.0,
            ignore_src_Ft=False):
        """
        train the tagger
        :param train_dict: a dictionary mapping tasks ("F0", "F1", and "Ft")
                           to a dictionary
                           {"X": list of examples,
                            "Y": list of labels,
                            "domain": list of domain tag (0,1) of example}
        :param num_epochs: maximum number of passes over the training data
        :param val_X: validation examples (early stopping needs ``val_X``,
                      ``val_Y`` and ``model_path`` all to be set)
        :param val_Y: validation labels
        :param patience: epochs without accuracy improvement before stopping
        :param model_path: where ``save_tagger`` stores the best model
        :param seed: if truthy, seeds Python's ``random`` module
        :param word_dropout_rate: if > 0, words are replaced by ``_UNK`` with
                                  a probability that depends on their count
        :param trg_vectors: the prediction targets used for the unsupervised loss
                            in temporal ensembling
        :param unsup_weight: weight for the unsupervised consistency loss
                                    used in temporal ensembling
        :param clip_threshold: accepted for interface compatibility; not used
                               inside this method
        :param orthogonality_weight: weight of the orthogonality constraint
                                     returned by ``self.predict``; 0 disables it
        :param adversarial: note: if we want to use adversarial, we have to
                            call add_adversarial_loss before;
        :param adversarial_weight: 1 by default (do not weigh adv loss)
        :param ignore_src_Ft: if asymm.tri. 2nd stage, do not further train Ft on 'src'

        Three tasks are indexed as "F0", "F1" and "Ft"

        Note: if a task 'src' is given then a single model with three heads is
        trained where all data is given to all outputs
        """
        print("read training data")

        widCount = Counter()
        train_data = []
        for task, task_dict in train_dict.items():  #task: eg. "F0"
            for key in ["X", "Y", "domain"]:
                assert key in task_dict, "Error: %s is not available." % key
            examples, labels, domain_tags = task_dict["X"], task_dict[
                "Y"], task_dict["domain"]
            assert len(examples) == len(labels)
            if word_dropout_rate > 0.0:
                # keep track of the counts for word dropout
                for sentence, _ in examples:
                    widCount.update(sentence)

            # train data is a list of 4-tuples: (example, label, task_id, domain_id)
            train_data += list(
                zip(examples, labels, [task] * len(labels), domain_tags))

        # if we use target vectors, keep track of the targets per sentence
        if trg_vectors is not None:
            trg_start_id = 0
            sentence_trg_vectors = []
            # train_data holds 4-tuples; unpacking them into two names (as the
            # code used to do) raised ValueError whenever trg_vectors was set.
            for example, _label, _task, _domain in train_data:
                trg_end_id = trg_start_id + len(example[0])
                sentence_trg_vectors.append(
                    trg_vectors[trg_start_id:trg_end_id, :])
                trg_start_id = trg_end_id
            assert trg_start_id == len(trg_vectors),\
                'Error: Idx {} is not at {}.'.format(trg_start_id, len(trg_vectors))

        print('Starting training for {} epochs...'.format(num_epochs))
        best_val_acc, epochs_no_improvement = 0., 0
        if val_X is not None and val_Y is not None and model_path is not None:
            print(
                'Using early stopping with patience of {}...'.format(patience))

        if seed:
            random.seed(seed)

        for cur_iter in range(num_epochs):
            bar = Bar('Training epoch {}/{}...'.format(cur_iter + 1,
                                                       num_epochs),
                      max=len(train_data),
                      flush=True)

            random_indices = np.arange(len(train_data))
            random.shuffle(random_indices)

            total_loss, total_tagged, total_constraint, total_adversarial = 0.0, 0.0, 0.0, 0.0
            total_orth_constr = 0  # count how many updates

            # log separate losses
            log_losses = {}
            log_total = {}
            for task_id in self.task_ids:
                log_losses[task_id] = 0.0
                log_total[task_id] = 0

            for idx in random_indices:
                (word_indices,
                 char_indices), y, task_id, domain_id = train_data[idx]
                num_tokens = len(word_indices)

                if word_dropout_rate > 0.0:
                    word_indices = [
                        self.w2i["_UNK"] if
                        (random.random() >
                         (widCount.get(w) /
                          (word_dropout_rate + widCount.get(w)))) else w
                        for w in word_indices
                    ]

                output, constraint, adv = self.predict(
                    word_indices,
                    char_indices,
                    task_id,
                    train=True,
                    orthogonality_weight=orthogonality_weight,
                    domain_id=domain_id if adversarial else None)

                single_head = task_id not in ['src', 'trg']
                if single_head:
                    if len(y) == 1 and y[0] == 0:
                        # in temporal ensembling, we assign a dummy label of [0] for
                        # unlabeled sequences; we skip the supervised loss for these
                        loss = dynet.scalarInput(0)
                    else:
                        loss = dynet.esum([
                            self.pick_neg_log(pred, gold)
                            for pred, gold in zip(output, y)
                        ])

                    if trg_vectors is not None:
                        # the consistency loss in temporal ensembling is used for
                        # both supervised and unsupervised input
                        targets = sentence_trg_vectors[idx]
                        assert len(output) == len(targets)
                        other_loss = unsup_weight * dynet.average([
                            dynet.squared_distance(o, dynet.inputVector(t))
                            for o, t in zip(output, targets)
                        ])
                        loss += other_loss

                    # the orthogonality constraint is not applied to Ft
                    apply_constraint = (orthogonality_weight != 0.0
                                        and task_id != 'Ft')
                else:
                    # bootstrap=False, the output contains list of outputs one for each task
                    assert trg_vectors is None, 'temporal ensembling not implemented for bootstrap=False'
                    # Start from the additive identity; starting from 1
                    # inflated the reported loss by one per example.
                    loss = dynet.scalarInput(0)
                    if ignore_src_Ft:
                        # ignore last = Ft when further training with 'src'
                        output = output[:-1]

                    for t_i, output_t in enumerate(output):
                        # get loss for each task
                        head_loss = dynet.esum([
                            self.pick_neg_log(pred, gold)
                            for pred, gold in zip(output_t, y)
                        ])
                        loss += head_loss
                        # Log this head's own loss and token count instead of
                        # the running epoch totals.
                        head_id = self.task_ids[t_i]
                        log_losses[head_id] += head_loss.value()
                        log_total[head_id] += num_tokens

                    apply_constraint = orthogonality_weight != 0.0

                if apply_constraint:
                    # add the orthogonality constraint to the loss
                    total_constraint += constraint.value(
                    ) * orthogonality_weight
                    total_orth_constr += 1
                    loss += constraint * orthogonality_weight

                if adversarial:
                    total_adversarial += adv.value() * adversarial_weight
                    loss += adv * adversarial_weight

                loss_value = loss.value()  # for output
                total_loss += loss_value
                total_tagged += num_tokens
                if single_head:
                    # per-task log: this example's loss and tokens only
                    log_losses[task_id] += loss_value
                    log_total[task_id] += num_tokens

                loss.backward()
                self.trainer.update()
                bar.next()

            # guard the averages against empty epochs / no constraint updates
            tagged_norm = total_tagged if total_tagged else 1.0
            constraint_norm = total_orth_constr if total_orth_constr else 1
            if adversarial and orthogonality_weight:
                print(
                    "iter {}. Total loss: {:.3f}, total penalty: {:.3f}, total weighted adv loss: {:.3f}"
                    .format(cur_iter, total_loss / tagged_norm,
                            total_constraint / constraint_norm,
                            total_adversarial / tagged_norm),
                    file=sys.stderr)
            elif orthogonality_weight:
                print("iter {}. Total loss: {:.3f}, total penalty: {:.3f}".
                      format(cur_iter, total_loss / tagged_norm,
                             total_constraint / constraint_norm),
                      file=sys.stderr)
            else:
                print("iter {}. Total loss: {:.3f} ".format(
                    cur_iter, total_loss / tagged_norm),
                      file=sys.stderr)

            for task_id in self.task_ids:
                if log_total[task_id] > 0:
                    print("{0}: {1:.3f}".format(
                        task_id, log_losses[task_id] / log_total[task_id]))

            if val_X is not None and val_Y is not None and model_path is not None:
                # get the best accuracy on the validation set
                val_correct, val_total = self.evaluate(val_X, val_Y)
                val_accuracy = val_correct / val_total

                if val_accuracy > best_val_acc:
                    print(
                        'Accuracy {:.4f} is better than best val accuracy {:.4f}.'
                        .format(val_accuracy, best_val_acc))
                    best_val_acc = val_accuracy
                    epochs_no_improvement = 0
                    save_tagger(self, model_path)
                else:
                    # the tracked metric is accuracy, not loss
                    print(
                        'Accuracy {:.4f} is worse than best val accuracy {:.4f}.'.
                        format(val_accuracy, best_val_acc))
                    epochs_no_improvement += 1
                if epochs_no_improvement == patience:
                    print('No improvement for {} epochs. Early stopping...'.
                          format(epochs_no_improvement))
                    break