Example #1
    def train(
        self,
        train_data_set: DataLoader,
        dev_data_set: DataLoader = None,
        test_data_set: DataLoader = None,
        epochs: int = 3,
        batch_size: int = 8,
        optimizer=None,
        max_grad_norm: float = 5.0,
        logging_steps: int = 50,
        save_steps: int = 100,
        save_path: str = None,
        distiller: TeacherStudentDistill = None,
    ):
        """
        Train a tagging model

        Args:
            train_data_set (DataLoader): train examples dataloader. If a distiller
                object is provided, train examples should contain a tuple of
                student/teacher data examples.
            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
            epochs (int, optional): num of epochs to train. Defaults to 3.
            batch_size (int, optional): batch size. Defaults to 8.
            optimizer (fn, optional): optimizer function. Defaults to the model's default optimizer.
            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
            logging_steps (int, optional): number of steps between logging. Defaults to 50.
            save_steps (int, optional): number of steps between model saves. Defaults to 100.
            save_path (str, optional): model output path. Defaults to None.
            distiller (TeacherStudentDistill, optional): KD model for training the
                model using a teacher model. Defaults to None.
        """
        if optimizer is None:
            optimizer = self.get_optimizer()
        train_batch_size = batch_size * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data_set.dataset))
        logger.info("  Num Epochs = %d", epochs)
        logger.info("  Instantaneous batch size per GPU/CPU = %d", batch_size)
        logger.info("  Total batch size = %d", train_batch_size)
        global_step = 0
        self.model.zero_grad()
        epoch_it = trange(epochs, desc="Epoch")
        for _ in epoch_it:
            step_it = tqdm(train_data_set, desc="Train iteration")
            avg_loss = 0
            for step, batch in enumerate(step_it):
                self.model.train()
                if distiller:
                    batch, t_batch = batch[:2]
                    t_batch = tuple(t.to(self.device) for t in t_batch)
                    t_logits = distiller.get_teacher_logits(t_batch)
                batch = tuple(t.to(self.device) for t in batch)
                inputs = self.batch_mapper(batch)
                logits = self.model(**inputs)
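                # the CRF layer returns a log-likelihood, so it is negated to form a
                # loss to minimize; the non-CRF path uses token-level cross-entropy
                # and ignores the padding label (index 0)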
                if self.use_crf:
                    loss = -1.0 * self.crf(
                        logits, inputs["labels"], mask=inputs["mask"] != 0.0)
                else:
                    loss_fn = CrossEntropyLoss(ignore_index=0)
                    loss = loss_fn(logits.view(-1, self.num_labels),
                                   inputs["labels"].view(-1))
                if self.n_gpus > 1:
                    loss = loss.mean()

                # add distillation loss if activated
                if distiller:
                    loss = distiller.distill_loss(loss, logits, t_logits)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                avg_loss += loss.item()
                if global_step % logging_steps == 0:
                    if step != 0:
                        logger.info(" global_step = %s, average loss = %s",
                                    global_step, avg_loss / step)
                    self._get_eval(dev_data_set, "dev")
                    self._get_eval(test_data_set, "test")
                if save_path is not None and global_step % save_steps == 0:
                    self.save_model(save_path)
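The distillation branch above defers to TeacherStudentDistill.distill_loss, whose internals are not shown in these examples. The sketch below is a minimal, self-contained version of the standard knowledge-distillation objective such a method typically implements: the hard-label loss blended with a temperature-scaled KL divergence between student and teacher distributions. The names alpha and temperature are illustrative assumptions, not the library's actual API.

import torch
import torch.nn.functional as F

def distill_loss_sketch(hard_loss, s_logits, t_logits, temperature=2.0, alpha=0.5):
    """Hard-label loss blended with a soft KD loss (a sketch, not the library code)."""
    # soften both distributions; kl_div expects log-probabilities on the input side
    log_p_student = F.log_softmax(s_logits / temperature, dim=-1)
    p_teacher = F.softmax(t_logits / temperature, dim=-1)
    # the T^2 factor keeps soft-loss gradients on the same scale as the hard loss
    kd = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    return alpha * hard_loss + (1.0 - alpha) * kd

# toy check: 4 tokens, 5 labels
s_logits = torch.randn(4, 5, requires_grad=True)
t_logits = torch.randn(4, 5)
hard = F.cross_entropy(s_logits, torch.tensor([1, 0, 3, 2]))
distill_loss_sketch(hard, s_logits, t_logits).backward()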
Example #2
    def train_pseudo(
        self,
        labeled_data_set: DataLoader,
        unlabeled_data_set: DataLoader,
        distiller: TeacherStudentDistill,
        dev_data_set: DataLoader = None,
        test_data_set: DataLoader = None,
        batch_size_l: int = 8,
        batch_size_ul: int = 8,
        epochs: int = 100,
        optimizer=None,
        max_grad_norm: float = 5.0,
        logging_steps: int = 50,
        save_steps: int = 100,
        save_path: str = None,
        save_best: bool = False,
    ):
        """
        Train a tagging model on labeled and unlabeled data, using teacher-generated
        pseudo-labels for the unlabeled examples

        Args:
            labeled_data_set (DataLoader): labeled train examples dataloader. Examples
                should contain a tuple of student/teacher data examples.
            unlabeled_data_set (DataLoader): unlabeled train examples dataloader.
                Examples should contain a tuple of student/teacher data examples.
            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
            batch_size_l (int, optional): batch size for the labeled dataset. Defaults to 8.
            batch_size_ul (int, optional): batch size for the unlabeled dataset. Defaults to 8.
            epochs (int, optional): num of epochs to train. Defaults to 100.
            optimizer (fn, optional): optimizer function. Defaults to the model's default optimizer.
            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
            logging_steps (int, optional): number of steps between logging. Defaults to 50.
            save_steps (int, optional): number of steps between model saves. Defaults to 100.
            save_path (str, optional): model output path. Defaults to None.
            save_best (bool, optional): whether to save the model when the dev-set
                result is the best so far. Defaults to False.
            distiller (TeacherStudentDistill): KD model for training the model using
                a teacher model.
        """
        if optimizer is None:
            optimizer = self.get_optimizer()
        train_batch_size_l = batch_size_l * max(1, self.n_gpus)
        train_batch_size_ul = batch_size_ul * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num labeled examples = %d",
                    len(labeled_data_set.dataset))
        logger.info("  Num unlabeled examples = %d",
                    len(unlabeled_data_set.dataset))
        logger.info("  Instantaneous labeled batch size per GPU/CPU = %d",
                    batch_size_l)
        logger.info("  Instantaneous unlabeled batch size per GPU/CPU = %d",
                    batch_size_ul)
        logger.info("  Total batch size labeled= %d", train_batch_size_l)
        logger.info("  Total batch size unlabeled= %d", train_batch_size_ul)
        global_step = 0
        self.model.zero_grad()
        avg_loss = 0
        iter_l = iter(labeled_data_set)
        iter_ul = iter(unlabeled_data_set)
        epoch_l = 0
        epoch_ul = 0
        s_idx = -1
        best_dev = 0
        best_test = 0
        while True:
            logger.info("labeled epoch=%d, unlabeled epoch=%d", epoch_l,
                        epoch_ul)
            loss_labeled = 0
            loss_unlabeled = 0
            try:
                batch_l = next(iter_l)
                s_idx += 1
            except StopIteration:
                iter_l = iter(labeled_data_set)
                epoch_l += 1
                batch_l = next(iter_l)
                s_idx = 0
                avg_loss = 0
            try:
                batch_ul = next(iter_ul)
            except StopIteration:
                iter_ul = iter(unlabeled_data_set)
                epoch_ul += 1
                batch_ul = next(iter_ul)
            if epoch_ul > epochs:
                logger.info("Done")
                return
            self.model.train()
            batch_l, t_batch_l = batch_l[:2]
            batch_ul, t_batch_ul = batch_ul[:2]
            t_batch_l = tuple(t.to(self.device) for t in t_batch_l)
            t_batch_ul = tuple(t.to(self.device) for t in t_batch_ul)
            t_logits = distiller.get_teacher_logits(t_batch_l)
            t_logits_ul = distiller.get_teacher_logits(t_batch_ul)
            batch_l = tuple(t.to(self.device) for t in batch_l)
            batch_ul = tuple(t.to(self.device) for t in batch_ul)
            inputs = self.batch_mapper(batch_l)
            inputs_ul = self.batch_mapper(batch_ul)
            logits = self.model(**inputs)
            logits_ul = self.model(**inputs_ul)
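            # pseudo-labels for the unlabeled batch: the teacher's most likely label
            # per token (log_softmax is monotonic, so this equals argmax over logits)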
            t_labels = torch.argmax(F.log_softmax(t_logits_ul, dim=2), dim=2)
            if self.use_crf:
                loss_labeled = -1.0 * self.crf(
                    logits, inputs["labels"], mask=inputs["mask"] != 0.0)
                loss_unlabeled = -1.0 * self.crf(
                    logits_ul, t_labels, mask=inputs_ul["mask"] != 0.0)
            else:
                loss_fn = CrossEntropyLoss(ignore_index=0)
                loss_labeled = loss_fn(logits.view(-1, self.num_labels),
                                       inputs["labels"].view(-1))
                loss_unlabeled = loss_fn(logits_ul.view(-1, self.num_labels),
                                         t_labels.view(-1))

            if self.n_gpus > 1:
                loss_labeled = loss_labeled.mean()
                loss_unlabeled = loss_unlabeled.mean()

            # add distillation loss
            loss_labeled = distiller.distill_loss(loss_labeled, logits,
                                                  t_logits)
            loss_unlabeled = distiller.distill_loss(loss_unlabeled, logits_ul,
                                                    t_logits_ul)

            # sum labeled and unlabeled losses
            loss = loss_labeled + loss_unlabeled
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            avg_loss += loss.item()
            if global_step % logging_steps == 0:
                if s_idx != 0:
                    logger.info(" global_step = %s, average loss = %s",
                                global_step, avg_loss / s_idx)
                dev = self._get_eval(dev_data_set, "dev")
                test = self._get_eval(test_data_set, "test")
                if dev > best_dev:
                    best_dev = dev
                    best_test = test
                    if save_path is not None and save_best:
                        self.save_model(save_path)
                logger.info("Best result: dev= %s, test= %s", str(best_dev),
                            str(best_test))
            if save_path is not None and global_step % save_steps == 0:
                self.save_model(save_path)
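The core of train_pseudo is turning the teacher's predictions on the unlabeled batch into training targets (the t_labels = torch.argmax(...) line above). A minimal, self-contained sketch of that step; the confidence filter at the end is an optional extension, not part of the method above.

import torch
import torch.nn.functional as F

# teacher logits for a batch of 2 sequences, 6 tokens, 5 labels
t_logits_ul = torch.randn(2, 6, 5)

# pseudo-labels: the teacher's most likely label per token; softmax is
# monotonic, so this equals argmax over the raw logits
t_labels = torch.argmax(F.log_softmax(t_logits_ul, dim=2), dim=2)  # shape (2, 6)

# optional extension (not used above): train only on confident pseudo-labels
probs = F.softmax(t_logits_ul, dim=2)
confident_mask = probs.max(dim=2).values > 0.9  # boolean, shape (2, 6)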
Example #3
    def train(
        self,
        train_data_set: DataLoader,
        dev_data_set: DataLoader = None,
        test_data_set: DataLoader = None,
        epochs: int = 3,
        batch_size: int = 8,
        optimizer=None,
        max_grad_norm: float = 5.0,
        logging_steps: int = 50,
        save_steps: int = 100,
        save_path: str = None,
        distiller: TeacherStudentDistill = None,
        best_result_file: str = None,
        word_dropout: float = 0,
    ):
        """
        Train a tagging model

        Args:
            train_data_set (DataLoader): train examples dataloader.
                - If distiller object is provided train examples should contain a tuple of
                  student/teacher data examples.
            dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
            test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
            epochs (int, optional): num of epochs to train. Defaults to 3.
            batch_size (int, optional): batch size. Defaults to 8.
            optimizer (fn, optional): optimizer function. Defaults to the model's default optimizer.
            max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
            logging_steps (int, optional): number of steps between logging. Defaults to 50.
            save_steps (int, optional): number of steps between model saves. Defaults to 100.
            save_path (str, optional): model output path. Defaults to None.
            distiller (TeacherStudentDistill, optional): KD model for training the
                model using a teacher model. Defaults to None.
            best_result_file (str, optional): file path for recording the best dev
                result whenever it is updated. Defaults to None.
            word_dropout (float, optional): whole-word (-> oov) dropout rate. Defaults to 0.
        """
        if optimizer is None:
            optimizer = self.get_optimizer()
        train_batch_size = batch_size * max(1, self.n_gpus)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data_set.dataset))
        logger.info("  Num Epochs = %d", epochs)
        logger.info("  Instantaneous batch size per GPU/CPU = %d", batch_size)
        logger.info("  Total batch size = %d", train_batch_size)
        global_step = 0
        best_dev = 0
        dev_test = 0
        self.model.zero_grad()
        epoch_it = trange(epochs, desc="Epoch")
        for epoch in epoch_it:
            step_it = tqdm(train_data_set, desc="Train iteration")
            avg_loss = 0

            for step, batches in enumerate(step_it):
                self.model.train()

                batch, t_batch = (batches, []) if not distiller else (batches[:2])
                batch = tuple(t.to(self.device) for t in batch)
                inputs = self.batch_mapper(batch)
                logits = self.model(**inputs)

                if distiller:
                    t_batch = tuple(t.to(self.device) for t in t_batch)
                    t_logits = distiller.get_teacher_logits(t_batch)
                    valid_positions = (
                        t_batch[3] != 0.0
                    )  # TODO: implement method to get only valid logits from the model itself
                    valid_t_logits = {}
                    max_seq_len = logits.shape[1]
                    for i in range(len(logits)):  # each example in batch
                        valid_logit_i = t_logits[i][valid_positions[i]]
                        valid_t_logits[i] = (
                            valid_logit_i
                            if valid_logit_i.shape[0] <= max_seq_len
                            else valid_logit_i[:max_seq_len]
                        )  # truncate to the student's max sequence length

                    # prepare teacher labels for non-labeled examples
                    t_labels_dict = {}
                    for i in range(len(valid_t_logits)):
                        t_labels_dict[i] = torch.argmax(
                            F.log_softmax(valid_t_logits[i], dim=-1), dim=-1
                        )

                # pseudo labeling
                for i, is_labeled in enumerate(inputs["is_labeled"]):
                    if not is_labeled:
                        t_labels_i = t_labels_dict[i]
                        # add the padded teacher label:
                        inputs["labels"][i] = torch.cat(
                            (
                                t_labels_i,
                                torch.zeros([max_seq_len - len(t_labels_i)], dtype=torch.long).to(
                                    self.device
                                ),
                            ),
                            0,
                        )

                # apply word dropout to the input
                if word_dropout != 0:
                    tokens = inputs["words"]
                    tokens = np.array(tokens.detach().cpu())
                    word_probs = np.random.random(tokens.shape)
                    drop_indices = np.where(
                        (word_probs < word_dropout) & (tokens != 0)
                    )  # drop tokens whose draw falls below the rate; never drop padding (id 0)
                    inputs["words"][drop_indices[0], drop_indices[1]] = self.word_vocab.oov_id

                # loss
                if self.use_crf:
                    loss = -1.0 * self.crf(logits, inputs["labels"], mask=inputs["mask"] != 0.0)
                else:
                    loss_fn = CrossEntropyLoss(ignore_index=0)
                    loss = loss_fn(logits.view(-1, self.num_labels), inputs["labels"].view(-1))

                # for IDCNN training: add a dropout-consistency penalty (half the
                # squared L2 distance between logits with and without dropout)
                module = self.model.module if self.n_gpus > 1 else self.model
                if isinstance(module, IDCNN) and module.drop_penalty != 0:
                    logits_no_drop = self.model(**inputs, no_dropout=True)
                    sub = logits.sub(logits_no_drop)
                    drop_loss = torch.div(torch.sum(torch.pow(sub, 2)), 2)
                    loss += module.drop_penalty * drop_loss

                if self.n_gpus > 1:
                    loss = loss.mean()

                # add distillation loss if activated
                if distiller:
                    # filter masked student logits (no padding)
                    valid_s_logits = {}
                    valid_s_positions = inputs["mask"] != 0.0
                    for i in range(len(logits)):
                        valid_s_logit_i = logits[i][valid_s_positions[i]]
                        valid_s_logits[i] = valid_s_logit_i
                    loss = distiller.distill_loss_dict(loss, valid_s_logits, valid_t_logits)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                avg_loss += loss.item()
                if global_step % logging_steps == 0:
                    if step != 0:
                        logger.info(
                            " global_step = %s, average loss = %s", global_step, avg_loss / step
                        )
                        best_dev, dev_test = self.update_best_model(
                            dev_data_set,
                            test_data_set,
                            best_dev,
                            dev_test,
                            best_result_file,
                            avg_loss / step,
                            epoch,
                            save_path=None,
                        )
                if save_steps != 0 and save_path is not None and global_step % save_steps == 0:
                    self.save_model(save_path)
        self.update_best_model(
            dev_data_set,
            test_data_set,
            best_dev,
            dev_test,
            best_result_file,
            "end_training",
            "end_training",
            save_path=save_path + "/best_dev",
        )
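Whole-word dropout (the word_dropout parameter above) replaces a fraction of the input token ids with the OOV id so the tagger cannot over-rely on memorized word identities. A self-contained sketch of the same idea, assuming id 0 is padding; oov_id and the tensor shapes are illustrative, not the library's API.

import numpy as np
import torch

def word_dropout(words, rate, oov_id):
    """Replace each non-padding token id with oov_id with probability `rate` (a sketch)."""
    if rate == 0:
        return words
    tokens = words.detach().cpu().numpy()
    draws = np.random.random(tokens.shape)
    # drop where the draw falls below the rate; padding (id 0) is never dropped
    rows, cols = np.where((draws < rate) & (tokens != 0))
    dropped = words.clone()
    dropped[rows, cols] = oov_id
    return dropped

batch = torch.tensor([[5, 9, 2, 0, 0],
                      [7, 3, 8, 4, 0]])
print(word_dropout(batch, rate=0.2, oov_id=1))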