Example #1
def add_arguments(parser: argparse.ArgumentParser):
    add_parse_args(parser)
    parser.add_argument("--teacher_model_path", type=str, required=True,
                        help="Path to teacher model")
    parser.add_argument("--teacher_model_type", type=str, required=True,
                        choices=TEACHER_TYPES,
                        help="Teacher model class type")
    TeacherStudentDistill.add_args(parser)
Example #2
def add_arguments(parser: argparse.ArgumentParser):
    add_parse_args(parser)
    TeacherStudentDistill.add_args(parser)
    parser.add_argument(
        "--teacher_max_seq_len",
        type=int,
        default=128,
        help="Max sentence length for teacher data loading",
    )
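These add_arguments variants only register command-line flags on an existing parser; add_parse_args and TeacherStudentDistill.add_args supply the remaining shared student and distillation options. Below is a minimal sketch of how such a helper is typically driven from a script entry point; the main() wiring is an assumption for illustration, with do_kd_training referring to the entry point shown later in Example #6:

import argparse

def main():
    parser = argparse.ArgumentParser(description="Token-classification distillation")
    add_arguments(parser)      # registers the student, teacher and KD flags shown above
    args = parser.parse_args()
    do_kd_training(args)       # training entry point, as in Example #6 below

if __name__ == "__main__":
    main()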
Example #3
def add_arguments(parser: argparse.ArgumentParser):
    add_parse_args(parser)
    TeacherStudentDistill.add_args(parser)
    parser.add_argument(
        "--labeled_train_file",
        default="labeled.txt",
        type=str,
        help="The file name containing the labeled training examples")
    parser.add_argument(
        "--unlabeled_train_file",
        default="unlabeled.txt",
        type=str,
        help="The file name containing the unlabeled training examples")
Example #4
def add_arguments(parser: argparse.ArgumentParser):
    add_parse_args(parser)
    TeacherStudentDistill.add_args(parser)
    parser.add_argument(
        "--unlabeled_filename",
        default="unlabeled.txt",
        type=str,
        help="The file name containing the unlabeled training examples",
    )
    parser.add_argument(
        "--parallel_batching",
        action="store_true",
        help="sample labeled/unlabeled batch in parallel",
    )
    parser.add_argument(
        "--teacher_max_seq_len",
        type=int,
        default=128,
        help="Max sentence length for teacher data loading",
    )
Example #5
def add_arguments(parser: argparse.ArgumentParser):
    add_parse_args(parser)
    TeacherStudentDistill.add_args(parser)
Example #6
def do_kd_training(args):
    prepare_output_path(args.output_dir, args.overwrite_output_dir)
    device, n_gpus = setup_backend(args.no_cuda)
    # Set seed
    set_seed(args.seed, n_gpus)
    # prepare data
    processor = TokenClsProcessor(args.data_dir, tag_col=args.tag_col)
    train_ex = processor.get_train_examples()
    dev_ex = processor.get_dev_examples()
    test_ex = processor.get_test_examples()
    vocab = processor.get_vocabulary()
    vocab_size = len(vocab) + 1
    num_labels = len(processor.get_labels()) + 1
    # create an embedder
    embedder_cls = MODEL_TYPE[args.model_type]
    if args.config_file is not None:
        embedder_model = embedder_cls.from_config(vocab_size, num_labels,
                                                  args.config_file)
    else:
        embedder_model = embedder_cls(vocab_size, num_labels)

    # load external word embeddings if present
    if args.embedding_file is not None:
        emb_dict = load_embedding_file(args.embedding_file)
        emb_mat = get_embedding_matrix(emb_dict, vocab)
        emb_mat = torch.tensor(emb_mat, dtype=torch.float)
        embedder_model.load_embeddings(emb_mat)

    classifier = NeuralTagger(embedder_model,
                              word_vocab=vocab,
                              labels=processor.get_labels(),
                              use_crf=args.use_crf,
                              device=device,
                              n_gpus=n_gpus)

    train_batch_size = args.b * max(1, n_gpus)
    train_dataset = classifier.convert_to_tensors(
        train_ex,
        max_seq_length=args.max_sentence_length,
        max_word_length=args.max_word_length)

    teacher = TransformerTokenClassifier.load_model(
        model_path=args.teacher_model_path, model_type=args.teacher_model_type)
    teacher.to(device, n_gpus)
    teacher_dataset = teacher.convert_to_tensors(train_ex,
                                                 args.max_sentence_length,
                                                 False)

    train_dataset = ParallelDataset(train_dataset, teacher_dataset)

    train_sampler = RandomSampler(train_dataset)
    train_dl = DataLoader(train_dataset,
                          sampler=train_sampler,
                          batch_size=train_batch_size)

    dev_dl, test_dl = None, None
    if dev_ex is not None:
        dev_dataset = classifier.convert_to_tensors(
            dev_ex,
            max_seq_length=args.max_sentence_length,
            max_word_length=args.max_word_length)
        dev_sampler = SequentialSampler(dev_dataset)
        dev_dl = DataLoader(dev_dataset,
                            sampler=dev_sampler,
                            batch_size=args.b)

    if test_ex is not None:
        test_dataset = classifier.convert_to_tensors(
            test_ex,
            max_seq_length=args.max_sentence_length,
            max_word_length=args.max_word_length)
        test_sampler = SequentialSampler(test_dataset)
        test_dl = DataLoader(test_dataset,
                             sampler=test_sampler,
                             batch_size=args.b)
    # fall back to the classifier's default optimizer when no learning rate is given
    opt = classifier.get_optimizer(lr=args.lr) if args.lr is not None else None

    distiller = TeacherStudentDistill(teacher, args.kd_temp, args.kd_dist_w,
                                      args.kd_student_w, args.kd_loss_fn)
    classifier.train(train_dl,
                     dev_dl,
                     test_dl,
                     epochs=args.e,
                     batch_size=args.b,
                     logging_steps=args.logging_steps,
                     save_steps=args.save_steps,
                     save_path=args.output_dir,
                     optimizer=opt,
                     distiller=distiller)
    classifier.save_model(args.output_dir)
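The TeacherStudentDistill object above combines the student's own loss with a distillation term computed from the teacher's logits, weighted by kd_student_w and kd_dist_w and softened by the temperature kd_temp. The following is a generic sketch of that kind of loss for orientation only; the KL-divergence formulation and reductions are assumptions, not the library's actual distill_loss implementation:

import torch
import torch.nn.functional as F

def kd_loss(student_loss, s_logits, t_logits, temp=1.0, dist_w=0.1, student_w=1.0):
    # soften both distributions with the temperature, then compare them with KL divergence
    soft_student = F.log_softmax(s_logits / temp, dim=-1)
    soft_teacher = F.softmax(t_logits / temp, dim=-1)
    dist_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temp ** 2)
    return student_w * student_loss + dist_w * dist_loss

# toy usage with fake logits
s, t = torch.randn(4, 10), torch.randn(4, 10)
print(kd_loss(torch.tensor(0.7), s, t, temp=2.0, dist_w=0.5))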
Example #7
def do_kd_pseudo_training(args):
    prepare_output_path(args.output_dir, args.overwrite_output_dir)
    device, n_gpus = setup_backend(args.no_cuda)
    # Set seed
    args.seed = set_seed(args.seed, n_gpus)
    # prepare data
    processor = TokenClsProcessor(
        args.data_dir, tag_col=args.tag_col, ignore_token=args.ignore_token
    )
    train_labeled_ex = processor.get_train_examples(filename=args.train_filename)
    train_unlabeled_ex = processor.get_train_examples(filename=args.unlabeled_filename)
    dev_ex = processor.get_dev_examples(filename=args.dev_filename)
    test_ex = processor.get_test_examples(filename=args.test_filename)
    vocab = processor.get_vocabulary(train_labeled_ex + train_unlabeled_ex + dev_ex + test_ex)
    vocab_size = len(vocab) + 1
    num_labels = len(processor.get_labels()) + 1
    # create an embedder
    embedder_cls = MODEL_TYPE[args.model_type]
    if args.config_file is not None:
        embedder_model = embedder_cls.from_config(vocab_size, num_labels, args.config_file)
    else:
        embedder_model = embedder_cls(vocab_size, num_labels)

    # load external word embeddings if present
    if args.embedding_file is not None:
        emb_dict = load_embedding_file(args.embedding_file, dim=embedder_model.word_embedding_dim)
        emb_mat = get_embedding_matrix(emb_dict, vocab)
        emb_mat = torch.tensor(emb_mat, dtype=torch.float)
        embedder_model.load_embeddings(emb_mat)

    classifier = NeuralTagger(
        embedder_model,
        word_vocab=vocab,
        labels=processor.get_labels(),
        use_crf=args.use_crf,
        device=device,
        n_gpus=n_gpus,
    )

    train_batch_size = args.b * max(1, n_gpus)
    train_labeled_dataset = classifier.convert_to_tensors(
        train_labeled_ex,
        max_seq_length=args.max_sentence_length,
        max_word_length=args.max_word_length,
    )
    train_unlabeled_dataset = classifier.convert_to_tensors(
        train_unlabeled_ex,
        max_seq_length=args.max_sentence_length,
        max_word_length=args.max_word_length,
        include_labels=False,
    )

    if args.parallel_batching:
        # # concat labeled+unlabeled dataset
        # train_dataset = ConcatTensorDataset(train_labeled_dataset, [train_unlabeled_dataset])
        # match sizes of labeled/unlabeled train data for parallel batching
        larger_ds, smaller_ds = (
            (train_labeled_dataset, train_unlabeled_dataset)
            if len(train_labeled_dataset) > len(train_unlabeled_dataset)
            else (train_unlabeled_dataset, train_labeled_dataset)
        )
        concat_smaller_ds = smaller_ds
        while len(concat_smaller_ds) < len(larger_ds):
            concat_smaller_ds = ConcatTensorDataset(concat_smaller_ds, [smaller_ds])
        if len(concat_smaller_ds[0]) == 4:
            train_unlabeled_dataset = concat_smaller_ds
        else:
            train_labeled_dataset = concat_smaller_ds
    else:
        train_dataset = CombinedTensorDataset([train_labeled_dataset, train_unlabeled_dataset])

    # load saved teacher args if exist
    if os.path.exists(args.teacher_model_path + os.sep + "training_args.bin"):
        t_args = torch.load(args.teacher_model_path + os.sep + "training_args.bin")
        t_device, t_n_gpus = setup_backend(t_args.no_cuda)
        teacher = TransformerTokenClassifier.load_model(
            model_path=args.teacher_model_path,
            model_type=args.teacher_model_type,
            config_name=t_args.config_name,
            tokenizer_name=t_args.tokenizer_name,
            do_lower_case=t_args.do_lower_case,
            output_path=t_args.output_dir,
            device=t_device,
            n_gpus=t_n_gpus,
        )
    else:
        teacher = TransformerTokenClassifier.load_model(
            model_path=args.teacher_model_path, model_type=args.teacher_model_type
        )
        teacher.to(device, n_gpus)

    teacher_labeled_dataset = teacher.convert_to_tensors(train_labeled_ex, args.teacher_max_seq_len)
    teacher_unlabeled_dataset = teacher.convert_to_tensors(
        train_unlabeled_ex, args.teacher_max_seq_len, False
    )

    if args.parallel_batching:
        # # concat teacher labeled+unlabeled dataset
        # teacher_dataset = ConcatTensorDataset(teacher_labeled_dataset, [teacher_unlabeled_dataset])
        # match sizes of labeled/unlabeled teacher train data for parallel batching
        larger_ds, smaller_ds = (
            (teacher_labeled_dataset, teacher_unlabeled_dataset)
            if len(teacher_labeled_dataset) > len(teacher_unlabeled_dataset)
            else (teacher_unlabeled_dataset, teacher_labeled_dataset)
        )
        concat_smaller_ds = smaller_ds
        while len(concat_smaller_ds) < len(larger_ds):
            concat_smaller_ds = ConcatTensorDataset(concat_smaller_ds, [smaller_ds])
        if len(concat_smaller_ds[0]) == 4:
            teacher_unlabeled_dataset = concat_smaller_ds
        else:
            teacher_labeled_dataset = concat_smaller_ds

        train_all_dataset = ParallelDataset(
            train_labeled_dataset,
            teacher_labeled_dataset,
            train_unlabeled_dataset,
            teacher_unlabeled_dataset,
        )

        train_all_sampler = RandomSampler(train_all_dataset)
        # this way must use same batch size for both labeled/unlabeled sets
        train_dl = DataLoader(
            train_all_dataset, sampler=train_all_sampler, batch_size=train_batch_size
        )

    else:
        teacher_dataset = CombinedTensorDataset(
            [teacher_labeled_dataset, teacher_unlabeled_dataset]
        )

        train_dataset = ParallelDataset(train_dataset, teacher_dataset)
        train_sampler = RandomSampler(train_dataset)
        train_dl = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_batch_size)

    dev_dl, test_dl = None, None
    if dev_ex is not None:
        dev_dataset = classifier.convert_to_tensors(
            dev_ex, max_seq_length=args.max_sentence_length, max_word_length=args.max_word_length
        )
        dev_sampler = SequentialSampler(dev_dataset)
        dev_dl = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=args.b)

    if test_ex is not None:
        test_dataset = classifier.convert_to_tensors(
            test_ex, max_seq_length=args.max_sentence_length, max_word_length=args.max_word_length
        )
        test_sampler = SequentialSampler(test_dataset)
        test_dl = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.b)
    # fall back to the classifier's default optimizer when no learning rate is given
    opt = classifier.get_optimizer(lr=args.lr) if args.lr is not None else None

    distiller = TeacherStudentDistill(
        teacher, args.kd_temp, args.kd_dist_w, args.kd_student_w, args.kd_loss_fn
    )

    classifier.train(
        train_dl,
        dev_dl,
        test_dl,
        epochs=args.e,
        batch_size=args.b,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_path=args.output_dir,
        optimizer=opt,
        best_result_file=args.best_result_file,
        distiller=distiller,
        word_dropout=args.word_dropout,
    )

    classifier.save_model(args.output_dir)
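When --parallel_batching is set, do_kd_pseudo_training repeatedly concatenates the smaller of the labeled/unlabeled datasets onto itself until it is at least as long as the larger one, so a single sampler can draw one labeled and one unlabeled example for every index of the ParallelDataset. A toy illustration of that size-matching idea using stock torch.utils.data classes (ConcatDataset stands in here for the library's ConcatTensorDataset):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

labeled = TensorDataset(torch.arange(10).unsqueeze(1))   # 10 labeled examples
unlabeled = TensorDataset(torch.arange(3).unsqueeze(1))  # 3 unlabeled examples

larger, smaller = (labeled, unlabeled) if len(labeled) > len(unlabeled) else (unlabeled, labeled)
grown = smaller
while len(grown) < len(larger):
    grown = ConcatDataset([grown, smaller])  # repeat the smaller side until it catches up

print(len(larger), len(grown))  # 10 12 -- both sides can now be zipped index by index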
Example #8
    def train_pseudo(
        self,
        labeled_data_set: DataLoader,
        unlabeled_data_set: DataLoader,
        distiller: TeacherStudentDistill,
        dev_data_set: DataLoader = None,
        test_data_set: DataLoader = None,
        batch_size_l: int = 8,
        batch_size_ul: int = 8,
        epochs: int = 100,
        optimizer=None,
        max_grad_norm: float = 5.0,
        logging_steps: int = 50,
        save_steps: int = 100,
        save_path: str = None,
        save_best: bool = False,
    ):
        """
        Train a tagging model

        Args:
            train_data_set (DataLoader): train examples dataloader. If distiller object is
            provided train examples should contain a tuple of student/teacher data examples.
            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
            batch_size_l (int, optional): batch size for the labeled dataset. Defaults to 8.
            batch_size_ul (int, optional): batch size for the unlabeled dataset. Defaults to 8.
            epochs (int, optional): num of epochs to train. Defaults to 100.
            optimizer (fn, optional): optimizer function. Defaults to default model optimizer.
            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
            logging_steps (int, optional): number of steps between logging. Defaults to 50.
            save_steps (int, optional): number of steps between model saves. Defaults to 100.
            save_path (str, optional): model output path. Defaults to None.
            save_best (str, optional): wether to save model when result is best on dev set
            distiller (TeacherStudentDistill, optional): KD model for training the model using
            a teacher model. Defaults to None.
        """
        if optimizer is None:
            optimizer = self.get_optimizer()
        train_batch_size_l = batch_size_l * max(1, self.n_gpus)
        train_batch_size_ul = batch_size_ul * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num labeled examples = %d",
                    len(labeled_data_set.dataset))
        logger.info("  Num unlabeled examples = %d",
                    len(unlabeled_data_set.dataset))
        logger.info("  Instantaneous labeled batch size per GPU/CPU = %d",
                    batch_size_l)
        logger.info("  Instantaneous unlabeled batch size per GPU/CPU = %d",
                    batch_size_ul)
        logger.info("  Total batch size labeled= %d", train_batch_size_l)
        logger.info("  Total batch size unlabeled= %d", train_batch_size_ul)
        global_step = 0
        self.model.zero_grad()
        avg_loss = 0
        iter_l = iter(labeled_data_set)
        iter_ul = iter(unlabeled_data_set)
        epoch_l = 0
        epoch_ul = 0
        s_idx = -1
        best_dev = 0
        best_test = 0
        while True:
            logger.info("labeled epoch=%d, unlabeled epoch=%d", epoch_l,
                        epoch_ul)
            loss_labeled = 0
            loss_unlabeled = 0
            try:
                batch_l = next(iter_l)
                s_idx += 1
            except StopIteration:
                iter_l = iter(labeled_data_set)
                epoch_l += 1
                batch_l = next(iter_l)
                s_idx = 0
                avg_loss = 0
            try:
                batch_ul = next(iter_ul)
            except StopIteration:
                iter_ul = iter(unlabeled_data_set)
                epoch_ul += 1
                batch_ul = next(iter_ul)
            if epoch_ul > epochs:
                logger.info("Done")
                return
            self.model.train()
            batch_l, t_batch_l = batch_l[:2]
            batch_ul, t_batch_ul = batch_ul[:2]
            t_batch_l = tuple(t.to(self.device) for t in t_batch_l)
            t_batch_ul = tuple(t.to(self.device) for t in t_batch_ul)
            t_logits = distiller.get_teacher_logits(t_batch_l)
            t_logits_ul = distiller.get_teacher_logits(t_batch_ul)
            batch_l = tuple(t.to(self.device) for t in batch_l)
            batch_ul = tuple(t.to(self.device) for t in batch_ul)
            inputs = self.batch_mapper(batch_l)
            inputs_ul = self.batch_mapper(batch_ul)
            logits = self.model(**inputs)
            logits_ul = self.model(**inputs_ul)
            t_labels = torch.argmax(F.log_softmax(t_logits_ul, dim=2), dim=2)
            if self.use_crf:
                loss_labeled = -1.0 * self.crf(
                    logits, inputs["labels"], mask=inputs["mask"] != 0.0)
                loss_unlabeled = -1.0 * self.crf(
                    logits_ul, t_labels, mask=inputs_ul["mask"] != 0.0)
            else:
                loss_fn = CrossEntropyLoss(ignore_index=0)
                loss_labeled = loss_fn(logits.view(-1, self.num_labels),
                                       inputs["labels"].view(-1))
                loss_unlabeled = loss_fn(logits_ul.view(-1, self.num_labels),
                                         t_labels.view(-1))

            if self.n_gpus > 1:
                loss_labeled = loss_labeled.mean()
                loss_unlabeled = loss_unlabeled.mean()

            # add distillation loss
            loss_labeled = distiller.distill_loss(loss_labeled, logits,
                                                  t_logits)
            loss_unlabeled = distiller.distill_loss(loss_unlabeled, logits_ul,
                                                    t_logits_ul)

            # sum labeled and unlabeled losses
            loss = loss_labeled + loss_unlabeled
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_grad_norm)
            optimizer.step()
            # self.model.zero_grad()
            optimizer.zero_grad()
            global_step += 1
            avg_loss += loss.item()
            if global_step % logging_steps == 0:
                if s_idx != 0:
                    logger.info(" global_step = %s, average loss = %s",
                                global_step, avg_loss / s_idx)
                dev = self._get_eval(dev_data_set, "dev")
                test = self._get_eval(test_data_set, "test")
                if dev > best_dev:
                    best_dev = dev
                    best_test = test
                    if save_path is not None and save_best:
                        self.save_model(save_path)
                logger.info("Best result: dev= %s, test= %s", str(best_dev),
                            str(best_test))
            if save_path is not None and global_step % save_steps == 0:
                self.save_model(save_path)
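train_pseudo turns the teacher's logits on each unlabeled batch into hard pseudo labels by taking the argmax over the label dimension; since log_softmax is monotonic per position, it does not change which label wins. A toy check of that step with fake [batch, seq_len, num_labels] logits:

import torch
import torch.nn.functional as F

t_logits_ul = torch.randn(2, 5, 4)  # fake teacher logits: batch=2, seq_len=5, num_labels=4
t_labels = torch.argmax(F.log_softmax(t_logits_ul, dim=2), dim=2)

assert torch.equal(t_labels, t_logits_ul.argmax(dim=2))  # same result as a plain argmax
print(t_labels.shape)  # torch.Size([2, 5]) -- one pseudo label per token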
Example #9
    def train(
        self,
        train_data_set: DataLoader,
        dev_data_set: DataLoader = None,
        test_data_set: DataLoader = None,
        epochs: int = 3,
        batch_size: int = 8,
        optimizer=None,
        max_grad_norm: float = 5.0,
        logging_steps: int = 50,
        save_steps: int = 100,
        save_path: str = None,
        distiller: TeacherStudentDistill = None,
    ):
        """
        Train a tagging model

        Args:
            train_data_set (DataLoader): train examples dataloader. If distiller object is
            provided train examples should contain a tuple of student/teacher data examples.
            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
            epochs (int, optional): num of epochs to train. Defaults to 3.
            batch_size (int, optional): batch size. Defaults to 8.
            optimizer (fn, optional): optimizer function. Defaults to default model optimizer.
            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
            logging_steps (int, optional): number of steps between logging. Defaults to 50.
            save_steps (int, optional): number of steps between model saves. Defaults to 100.
            save_path (str, optional): model output path. Defaults to None.
            distiller (TeacherStudentDistill, optional): KD model for training the model using
            a teacher model. Defaults to None.
        """
        if optimizer is None:
            optimizer = self.get_optimizer()
        train_batch_size = batch_size * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data_set.dataset))
        logger.info("  Num Epochs = %d", epochs)
        logger.info("  Instantaneous batch size per GPU/CPU = %d", batch_size)
        logger.info("  Total batch size = %d", train_batch_size)
        global_step = 0
        self.model.zero_grad()
        epoch_it = trange(epochs, desc="Epoch")
        for _ in epoch_it:
            step_it = tqdm(train_data_set, desc="Train iteration")
            avg_loss = 0
            for step, batch in enumerate(step_it):
                self.model.train()
                if distiller:
                    batch, t_batch = batch[:2]
                    t_batch = tuple(t.to(self.device) for t in t_batch)
                    t_logits = distiller.get_teacher_logits(t_batch)
                batch = tuple(t.to(self.device) for t in batch)
                inputs = self.batch_mapper(batch)
                logits = self.model(**inputs)
                if self.use_crf:
                    loss = -1.0 * self.crf(
                        logits, inputs["labels"], mask=inputs["mask"] != 0.0)
                else:
                    loss_fn = CrossEntropyLoss(ignore_index=0)
                    loss = loss_fn(logits.view(-1, self.num_labels),
                                   inputs["labels"].view(-1))
                if self.n_gpus > 1:
                    loss = loss.mean()

                # add distillation loss if activated
                if distiller:
                    loss = distiller.distill_loss(loss, logits, t_logits)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_grad_norm)

                optimizer.step()
                # self.model.zero_grad()
                optimizer.zero_grad()
                global_step += 1
                avg_loss += loss.item()
                if global_step % logging_steps == 0:
                    if step != 0:
                        logger.info(" global_step = %s, average loss = %s",
                                    global_step, avg_loss / step)
                    self._get_eval(dev_data_set, "dev")
                    self._get_eval(test_data_set, "test")
                if save_path is not None and global_step % save_steps == 0:
                    self.save_model(save_path)
Example #10
    def train(
        self,
        train_data_set: DataLoader,
        dev_data_set: DataLoader = None,
        test_data_set: DataLoader = None,
        epochs: int = 3,
        batch_size: int = 8,
        optimizer=None,
        max_grad_norm: float = 5.0,
        logging_steps: int = 50,
        save_steps: int = 100,
        save_path: str = None,
        distiller: TeacherStudentDistill = None,
        best_result_file: str = None,
        word_dropout: float = 0,
    ):
        """
        Train a tagging model

        Args:
            train_data_set (DataLoader): train examples dataloader.
                - If distiller object is provided train examples should contain a tuple of
                  student/teacher data examples.
            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
            epochs (int, optional): num of epochs to train. Defaults to 3.
            batch_size (int, optional): batch size. Defaults to 8.
            optimizer (fn, optional): optimizer function. Defaults to default model optimizer.
            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
            logging_steps (int, optional): number of steps between logging. Defaults to 50.
            save_steps (int, optional): number of steps between model saves. Defaults to 100.
            save_path (str, optional): model output path. Defaults to None.
            distiller (TeacherStudentDistill, optional): KD model for training the model using
            a teacher model. Defaults to None.
            best_result_file (str, optional): path to save best dev results when it's updated.
            word_dropout (float, optional): whole-word (-> oov) dropout rate. Defaults to 0.
        """
        if optimizer is None:
            optimizer = self.get_optimizer()
        train_batch_size = batch_size * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data_set.dataset))
        logger.info("  Num Epochs = %d", epochs)
        logger.info("  Instantaneous batch size per GPU/CPU = %d", batch_size)
        logger.info("  Total batch size = %d", train_batch_size)
        global_step = 0
        best_dev = 0
        dev_test = 0
        self.model.zero_grad()
        epoch_it = trange(epochs, desc="Epoch")
        for epoch in epoch_it:
            step_it = tqdm(train_data_set, desc="Train iteration")
            avg_loss = 0

            for step, batches in enumerate(step_it):
                self.model.train()

                batch, t_batch = (batches, []) if not distiller else (batches[:2])
                batch = tuple(t.to(self.device) for t in batch)
                inputs = self.batch_mapper(batch)
                logits = self.model(**inputs)

                if distiller:
                    t_batch = tuple(t.to(self.device) for t in t_batch)
                    t_logits = distiller.get_teacher_logits(t_batch)
                    valid_positions = (
                        t_batch[3] != 0.0
                    )  # TODO: implement method to get only valid logits from the model itself
                    valid_t_logits = {}
                    max_seq_len = logits.shape[1]
                    for i in range(len(logits)):  # each example in batch
                        valid_logit_i = t_logits[i][valid_positions[i]]
                        valid_t_logits[i] = (
                            valid_logit_i
                            if valid_logit_i.shape[0] <= max_seq_len
                            else valid_logit_i[:max_seq_len]
                        )  # cut to max len

                    # prepare teacher labels for non-labeled examples
                    t_labels_dict = {}
                    for i in range(len(valid_t_logits.keys())):
                        t_labels_dict[i] = torch.argmax(
                            F.log_softmax(valid_t_logits[i], dim=-1), dim=-1
                        )

                # pseudo labeling
                for i, is_labeled in enumerate(inputs["is_labeled"]):
                    if not is_labeled:
                        t_labels_i = t_labels_dict[i]
                        # add the padded teacher label:
                        inputs["labels"][i] = torch.cat(
                            (
                                t_labels_i,
                                torch.zeros([max_seq_len - len(t_labels_i)], dtype=torch.long).to(
                                    self.device
                                ),
                            ),
                            0,
                        )

                # apply word dropout to the input
                if word_dropout != 0:
                    tokens = inputs["words"]
                    tokens = np.array(tokens.detach().cpu())
                    word_probs = np.random.random(tokens.shape)
                    # drop (replace with the OOV id) each non-padding word with
                    # probability `word_dropout`, per the docstring's dropout-rate semantics
                    drop_indices = np.where(
                        (word_probs < word_dropout) & (tokens != 0)
                    )  # ignore padding indices
                    inputs["words"][drop_indices[0], drop_indices[1]] = self.word_vocab.oov_id

                # loss
                if self.use_crf:
                    loss = -1.0 * self.crf(logits, inputs["labels"], mask=inputs["mask"] != 0.0)
                else:
                    loss_fn = CrossEntropyLoss(ignore_index=0)
                    loss = loss_fn(logits.view(-1, self.num_labels), inputs["labels"].view(-1))

                # for idcnn training - add dropout penalty loss
                module = self.model.module if self.n_gpus > 1 else self.model
                if isinstance(module, IDCNN) and module.drop_penalty != 0:
                    logits_no_drop = self.model(**inputs, no_dropout=True)
                    sub = logits.sub(logits_no_drop)
                    drop_loss = torch.div(torch.sum(torch.pow(sub, 2)), 2)
                    loss += module.drop_penalty * drop_loss

                if self.n_gpus > 1:
                    loss = loss.mean()

                # add distillation loss if activated
                if distiller:
                    # filter masked student logits (no padding)
                    valid_s_logits = {}
                    valid_s_positions = inputs["mask"] != 0.0
                    for i in range(len(logits)):
                        valid_s_logit_i = logits[i][valid_s_positions[i]]
                        valid_s_logits[i] = valid_s_logit_i
                    loss = distiller.distill_loss_dict(loss, valid_s_logits, valid_t_logits)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                avg_loss += loss.item()
                if global_step % logging_steps == 0:
                    if step != 0:
                        logger.info(
                            " global_step = %s, average loss = %s", global_step, avg_loss / step
                        )
                        best_dev, dev_test = self.update_best_model(
                            dev_data_set,
                            test_data_set,
                            best_dev,
                            dev_test,
                            best_result_file,
                            avg_loss / step,
                            epoch,
                            save_path=None,
                        )
                if save_steps != 0 and save_path is not None and global_step % save_steps == 0:
                    self.save_model(save_path)
        self.update_best_model(
            dev_data_set,
            test_data_set,
            best_dev,
            dev_test,
            best_result_file,
            "end_training",
            "end_training",
            save_path=save_path + "/best_dev" if save_path is not None else None,
        )
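The word-dropout step in Example #10 replaces a random subset of non-padding token ids with the vocabulary's OOV id before the loss is computed, which discourages the tagger from memorizing specific words. Below is a standalone toy version of that masking; the oov_id/pad_id values and the reading of word_dropout as a drop probability (per the docstring) are assumptions:

import numpy as np
import torch

def word_dropout(words: torch.Tensor, rate: float, oov_id: int = 1, pad_id: int = 0):
    """Replace each non-padding token id with oov_id with probability `rate`."""
    if rate == 0:
        return words
    probs = np.random.random(tuple(words.shape))
    rows, cols = np.where((probs < rate) & (words.cpu().numpy() != pad_id))
    dropped = words.clone()
    dropped[torch.from_numpy(rows), torch.from_numpy(cols)] = oov_id
    return dropped

batch = torch.tensor([[5, 7, 9, 0, 0],
                      [3, 4, 0, 0, 0]])
print(word_dropout(batch, rate=0.5))  # some non-zero ids become 1 (OOV); padding stays 0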