Example 1
    def learn(self, model: EmmentalModel,
              dataloaders: List[EmmentalDataLoader]) -> None:
        """Learning procedure of emmental MTL.

        Args:
          model: The emmental model that needs to learn.
          dataloaders: A list of dataloaders used to learn the model.
        """
        # Generate the list of dataloaders for learning process
        start_time = time.time()

        train_split = Meta.config["learner_config"]["train_split"]
        if isinstance(train_split, str):
            train_split = [train_split]

        train_dataloaders = [
            dataloader for dataloader in dataloaders
            if dataloader.split in train_split
        ]

        if not train_dataloaders:
            raise ValueError(
                f"Cannot find the specified train_split "
                f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
            )

        # Set up task_scheduler
        self._set_task_scheduler()

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = self.task_scheduler.get_num_batches(
            train_dataloaders)

        # Set up logging manager
        self._set_logging_manager()
        # Set up optimizer
        self._set_optimizer(model)
        # Set up lr_scheduler
        self._set_lr_scheduler(model)

        if Meta.config["learner_config"]["fp16"]:
            try:
                from apex import amp  # type: ignore
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to "
                    "use fp16 training.")
            logger.info(
                f"Modeling training with 16-bit (mixed) precision "
                f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
            )
            model, self.optimizer = amp.initialize(
                model,
                self.optimizer,
                opt_level=Meta.config["learner_config"]["fp16_opt_level"],
            )

        # Multi-gpu training (after apex fp16 initialization)
        if (Meta.config["learner_config"]["local_rank"] == -1
                and Meta.config["model_config"]["dataparallel"]):
            model._to_dataparallel()

        # Distributed training (after apex fp16 initialization)
        if Meta.config["learner_config"]["local_rank"] != -1:
            model._to_distributed_dataparallel()

        # Set to training mode
        model.train()

        if Meta.config["meta_config"]["verbose"]:
            logger.info("Start learning...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        # Set gradients of all model parameters to zero
        self.optimizer.zero_grad()

        for epoch_num in range(Meta.config["learner_config"]["n_epochs"]):
            batches = tqdm(
                enumerate(
                    self.task_scheduler.get_batches(train_dataloaders, model)),
                total=self.n_batches_per_epoch,
                disable=(not Meta.config["meta_config"]["verbose"]
                         or Meta.config["learner_config"]["local_rank"]
                         not in [-1, 0]),
                desc=f"Epoch {epoch_num}:",
            )

            for batch_num, batch in batches:
                # Convert single batch into a batch list
                if not isinstance(batch, list):
                    batch = [batch]

                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                batch_size = 0

                for uids, X_dict, Y_dict, task_to_label_dict, data_name, split in batch:
                    batch_size += len(next(iter(Y_dict.values())))

                    # Perform forward pass and calculate the loss and count
                    uid_dict, loss_dict, prob_dict, gold_dict = model(
                        uids, X_dict, Y_dict, task_to_label_dict)

                    # Update running loss and count
                    for task_name in uid_dict.keys():
                        identifier = f"{task_name}/{data_name}/{split}"
                        self.running_uids[identifier].extend(
                            uid_dict[task_name])
                        self.running_losses[identifier] += (
                            loss_dict[task_name].item() *
                            len(uid_dict[task_name])
                            if len(loss_dict[task_name].size()) == 0 else
                            torch.sum(loss_dict[task_name]).item())
                        self.running_probs[identifier].extend(
                            prob_dict[task_name])
                        self.running_golds[identifier].extend(
                            gold_dict[task_name])

                    # Skip the backward pass if no loss is calculated
                    if not loss_dict:
                        continue

                    # Calculate the average loss
                    loss = sum([
                        model.weights[task_name] *
                        task_loss if len(task_loss.size()) == 0 else
                        torch.mean(model.weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ])

                    # Perform backward pass to calculate gradients
                    if Meta.config["learner_config"]["fp16"]:
                        with amp.scale_loss(loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()  # type: ignore

                if (total_batch_num +
                        1) % Meta.config["learner_config"]["optimizer_config"][
                            "gradient_accumulation_steps"] == 0 or (
                                batch_num + 1 == self.n_batches_per_epoch
                                and epoch_num + 1
                                == Meta.config["learner_config"]["n_epochs"]):
                    # Clip gradient norm
                    if Meta.config["learner_config"]["optimizer_config"][
                            "grad_clip"]:
                        if Meta.config["learner_config"]["fp16"]:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )

                    # Update the parameters
                    self.optimizer.step()

                    # Set gradients of all model parameters to zero
                    self.optimizer.zero_grad()

                if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                    self.metrics.update(
                        self._logging(model, dataloaders, batch_size))

                    batches.set_postfix(self.metrics)

                # Update lr using lr scheduler
                self._update_lr_scheduler(model, total_batch_num, self.metrics)

        if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
            model = self.logging_manager.close(model)
        logger.info(
            f"Total learning time: {time.time() - start_time} seconds.")
Example 2
    def learn(self, model: EmmentalModel,
              dataloaders: List[EmmentalDataLoader]) -> None:
        """Learning procedure of emmental MTL.

        Args:
          model: The emmental model that needs to learn.
          dataloaders: A list of dataloaders used to learn the model.
        """
        start_time = time.time()

        # Generate the list of dataloaders for learning process
        train_split = Meta.config["learner_config"]["train_split"]
        if isinstance(train_split, str):
            train_split = [train_split]

        train_dataloaders = [
            dataloader for dataloader in dataloaders
            if dataloader.split in train_split
        ]

        if not train_dataloaders:
            raise ValueError(
                f"Cannot find the specified train_split "
                f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
            )

        # Set up task_scheduler
        self._set_task_scheduler()

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch: int = self.task_scheduler.get_num_batches(
            train_dataloaders)
        if self.n_batches_per_epoch == 0:
            logger.info("No batches in training dataloaders, existing...")
            return

        # Set up learning counter
        self._set_learning_counter()
        # Set up logging manager
        self._set_logging_manager()
        # Set up wandb watch model
        if (Meta.config["logging_config"]["writer_config"]["writer"] == "wandb"
                and Meta.config["logging_config"]["writer_config"]
            ["wandb_watch_model"]):
            if Meta.config["logging_config"]["writer_config"][
                    "wandb_model_watch_freq"]:
                wandb.watch(
                    model,
                    log_freq=Meta.config["logging_config"]["writer_config"]
                    ["wandb_model_watch_freq"],
                )
            else:
                wandb.watch(model)
        # Set up optimizer
        self._set_optimizer(model)
        # Set up lr_scheduler
        self._set_lr_scheduler(model)

        if Meta.config["learner_config"]["fp16"]:
            try:
                from apex import amp  # type: ignore
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to "
                    "use fp16 training.")
            logger.info(
                f"Modeling training with 16-bit (mixed) precision "
                f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
            )
            model, self.optimizer = amp.initialize(
                model,
                self.optimizer,
                opt_level=Meta.config["learner_config"]["fp16_opt_level"],
            )

        # Multi-gpu training (after apex fp16 initialization)
        if (Meta.config["learner_config"]["local_rank"] == -1
                and Meta.config["model_config"]["dataparallel"]):
            model._to_dataparallel()

        # Distributed training (after apex fp16 initialization)
        if Meta.config["learner_config"]["local_rank"] != -1:
            model._to_distributed_dataparallel()

        # Set to training mode
        model.train()

        if Meta.config["meta_config"]["verbose"]:
            logger.info("Start learning...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        # Set gradients of all model parameters to zero
        self.optimizer.zero_grad()

        batch_iterator = self.task_scheduler.get_batches(
            train_dataloaders, model)
        for epoch_num in range(self.start_epoch, self.end_epoch):
            for train_dataloader in train_dataloaders:
                # Set epoch for distributed sampler
                if isinstance(train_dataloader, DataLoader) and isinstance(
                        train_dataloader.sampler, DistributedSampler):
                    train_dataloader.sampler.set_epoch(epoch_num)
            step_pbar = tqdm(
                range(self.start_step, self.end_step),
                desc=f"Step {self.start_step + 1}/{self.end_step}"
                if self.use_step_base_counter else
                f"Epoch {epoch_num + 1}/{self.end_epoch}",
                disable=not Meta.config["meta_config"]["verbose"]
                or Meta.config["learner_config"]["local_rank"] not in [-1, 0],
            )
            for step_num in step_pbar:
                if self.use_step_base_counter:
                    step_pbar.set_description(
                        f"Step {step_num + 1}/{self.total_steps}")
                    step_pbar.refresh()
                try:
                    batch = next(batch_iterator)
                except StopIteration:
                    batch_iterator = self.task_scheduler.get_batches(
                        train_dataloaders, model)
                    batch = next(batch_iterator)

                # Check if skip the current batch
                if epoch_num < self.start_train_epoch or (
                        epoch_num == self.start_train_epoch
                        and step_num < self.start_train_step):
                    continue

                # Convert single batch into a batch list
                if not isinstance(batch, list):
                    batch = [batch]

                total_step_num = epoch_num * self.n_batches_per_epoch + step_num
                batch_size = 0

                for _batch in batch:
                    batch_size += len(_batch.uids)
                    # Perform forward pass and calculate the loss and count
                    uid_dict, loss_dict, prob_dict, gold_dict = model(
                        _batch.uids,
                        _batch.X_dict,
                        _batch.Y_dict,
                        _batch.task_to_label_dict,
                        return_probs=Meta.config["learner_config"]
                        ["online_eval"],
                        return_action_outputs=False,
                    )

                    # Update running loss and count
                    for task_name in uid_dict.keys():
                        identifier = f"{task_name}/{_batch.data_name}/{_batch.split}"
                        self.running_uids[identifier].extend(
                            uid_dict[task_name])
                        self.running_losses[identifier] += (
                            loss_dict[task_name].item() *
                            len(uid_dict[task_name])
                            if len(loss_dict[task_name].size()) == 0 else
                            torch.sum(loss_dict[task_name]).item()
                        ) * model.task_weights[task_name]
                        if (Meta.config["learner_config"]["online_eval"]
                                and prob_dict and gold_dict):
                            self.running_probs[identifier].extend(
                                prob_dict[task_name])
                            self.running_golds[identifier].extend(
                                gold_dict[task_name])

                    # Calculate the average loss
                    loss = sum([
                        model.task_weights[task_name] *
                        task_loss if len(task_loss.size()) == 0 else
                        torch.mean(model.task_weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ])

                    # Perform backward pass to calculate gradients
                    if Meta.config["learner_config"]["fp16"]:
                        with amp.scale_loss(loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()  # type: ignore

                if (total_step_num +
                        1) % Meta.config["learner_config"]["optimizer_config"][
                            "gradient_accumulation_steps"] == 0 or (
                                step_num + 1 == self.end_step
                                and epoch_num + 1 == self.end_epoch):
                    # Clip gradient norm
                    if Meta.config["learner_config"]["optimizer_config"][
                            "grad_clip"]:
                        if Meta.config["learner_config"]["fp16"]:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )

                    # Update the parameters
                    self.optimizer.step()

                    # Set gradients of all model parameters to zero
                    self.optimizer.zero_grad()

                if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                    self.metrics.update(
                        self._logging(model, dataloaders, batch_size))

                    step_pbar.set_postfix(self.metrics)

                # Update lr using lr scheduler
                self._update_lr_scheduler(model, total_step_num, self.metrics)
            step_pbar.close()

        if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
            model = self.logging_manager.close(model)
        logger.info(
            f"Total learning time: {time.time() - start_time} seconds.")
Example 3
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:

    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if use_weak_label is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        with open(final_result_file, "wb") as output:
            for file in all_result_files:
                with open(file, "rb") as part:
                    shutil.copyfileobj(part, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for npfile in all_out_emb_files:
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
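
The embedding merge at the end of Example 3 boils down to loading each per-batch .npy dump and concatenating the arrays before a single np.save. Below is a minimal sketch of that step with toy data; the paths and shapes are made up for illustration and are not Bootleg's real outputs.

# Stand-alone sketch of merging per-batch .npy embedding dumps into one file.
import os
import tempfile

import numpy as np

tmp_dir = tempfile.mkdtemp()
batch_files = []
for i in range(3):
    path = os.path.join(tmp_dir, f"batch_{i}.npy")
    np.save(path, np.random.rand(5, 8))  # 5 embeddings of dimension 8 per batch
    batch_files.append(path)

merged = np.concatenate([np.load(path) for path in batch_files])
np.save(os.path.join(tmp_dir, "merged.npy"), merged)
print(merged.shape)  # (15, 8)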