Example 1
def main(config_file, update, _run, _log, _config):
    working_dir = _run.observers[0].dir
    # load the config file
    config = yaml_load(config_file)
    recursive_update(config, update)
    yaml_dump(config, path.join(working_dir, 'config.yaml'))
    _config = config
    print(_config)
    print(working_dir)
    dataset = _config['dataset']
    dataset['device'] = update['device']
    # load the dataset and the vocab
    train_loader, dev_loader, test_loader, vocab = load_data(**dataset)
    vocab_size = len(vocab.itos)

    # model
    _config['hidden']['features'][0] = vocab_size

    # trainer batch
    test_sample = _config['trainer_batch']['test_sample']
    _config['trainer_batch']['test_sample'] = 1

    config = extend_config_reference(_config)
    trainer = config['trainer']
    # scale evaluate_interval by the number of batches per epoch
    trainer['evaluate_interval'] = len(train_loader) * trainer['evaluate_interval']
    trainer['save_checkpoint_interval'] = trainer['evaluate_interval']
    trainer['base_dir'] = working_dir
    yaml_dump(trainer, path.join(working_dir, 'trainer.yaml'))
    trainer['train_iterator'] = train_loader
    trainer['dev_iterator'] = dev_loader
    trainer['test_iterator'] = None
    callback = EvaluationCallback(working_dir,
                                  vocab,
                                  corpus_dir=path.join(dataset['data_dir'],
                                                       'corpus'),
                                  **config['callback'])
    trainer['callbacks'] = callback
    trainer['logger'] = _log

    print(config)
    trainer = Trainer.from_config(trainer)
    _log.info("model architecture")
    print(trainer.trainer_batch.model)

    # train the model
    trainer.train()

    # testing and save results
    trainer.dev_iterator = test_loader
    # restore the full number of test samples (only 1 was used during dev evaluation)
    trainer.trainer_batch.test_sample = test_sample
    trainer.restore_from_basedir(best=True)
    stat = trainer._evaluate_epoch().get_dict()
    callback.evaluate_topic_coherence()  # topic coherence of best checkpoint
    stat.update(callback.get_dict())
    yaml_dump(stat, path.join(working_dir, 'result.yaml'))
    _log.info('test result of the best checkpoint: {}'.format(stat))
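Note: `recursive_update`, `yaml_load`/`yaml_dump` and `extend_config_reference` are project utilities that are not shown in this listing. Purely as an illustration, a minimal sketch of a nested-dict merge matching the call `recursive_update(config, update)` above (an assumption, not the project's actual implementation):

def recursive_update(base: dict, override: dict) -> None:
    # merge ``override`` into ``base`` in place, descending into nested
    # dicts instead of replacing them wholesale
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            recursive_update(base[key], value)
        else:
            base[key] = value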
Example 2
    def run_training_batch(self, batch, batch_idx):
        """

        :param batch: dict; contains three keys: input_ids, attention_mask, decoder_input_ids
            Example for 'batch':
                batch: {'input_ids': tensor([[  0,  36, 230,  ...,   8,  41,   2]]),
                'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]]),
                'decoder_input_ids': tensor([[    0,   287,    10,  2107,   111, 10468,   226, 47385, 11579,  1012,
                                                2156,     5,  5302, 47385,   281, 47385, 10003,   255, 47385,   347,
                                                111,  2107, 47385,   574, 47385,  1000, 47385,   398, 47385,   245,
                                                16,    10,   205,  1374, 12576,   479,   646,  1000,  1215,  3388,
                                                510,   742,    85,   128,   579,    65,     9,     5,   357,  3092,
                                                23,    63,  1836,    11,     5,  3555,   111,   672,  2156, 26180,
                                                47385,   642,   111,  3547,  4120,   479,   646,  1000,  1215,  3388,
                                                510,   742,  7192,  8806, 10262,  3444,  7951,  2170,  1318,     2]])}
        :param batch_idx: index of the current batch
        :return: an ``AttributeDict`` with the training signal, gradient norms,
            batch log metrics and the training-step output for the epoch end
        """
        # load tokenizer
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        # load config for GSM
        config = yaml_load(f"{self.default_root_dir}/data/config/gsm.yaml")
        # load dict
        dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))
        # special [SEP] tokens to strip before topic modeling
        sep_list = [
            '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]', '[SEP_5]',
            '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]', '<S_SEP>'
        ]
        # vocab size for topic modeling
        vocab_size = len(dictionary)
        # model
        config['hidden']['features'][0] = vocab_size

        # trainer batch
        config['trainer_batch']['test_sample'] = 1
        config = extend_config_reference(config)
        gsm_trainer = config['GSMtrainer']
        gsm_trainer['base_dir'] = (
            f"{self.default_root_dir}/log/bart-large-cnn-finetune")
        gsm_trainer = GSMTrainer.from_config(gsm_trainer)

        # number of topics
        K = config['gsmtopic']['k']

        # yaml_dump(gsm_trainer,
        #           os.path.join(f"{self.default_root_dir}/log/bart-large-cnn-finetune", "gsm_trainer.yaml"))

        # -----------------------------------------
        # Topic Modeling - GSM
        # -----------------------------------------
        batch_size = batch['input_ids'].size()[0]

        docs = []
        for batch_num in range(batch_size):
            # extract the batch_sentence
            batch_sentence = tokenizer.decode(
                batch['input_ids'][batch_num].tolist(),
                skip_special_tokens=True)
            # split the decoded sentence into tokens
            batch_sentence_list = batch_sentence.split(" ")
            # remove the [SEP] markers
            batch_sentence_list_nosep = [
                item for item in batch_sentence_list if item not in sep_list
            ]
            text = ' '.join(batch_sentence_list_nosep)
            # strip subword markers and lowercase
            fine_text = text.replace(' ##', '').lower()
            # drop punctuation; batch_sentence now holds the cleaned text
            batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
            # convert to the bag-of-words format expected by the topic model
            gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
            docs.append(gsm_data_bow)

        # gsm_data: data for topic modeling
        gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                              batch_size=config['dataset']['batch_size'],
                              drop_last=False,
                              num_workers=0)

        gsm_trainer.__dict__['train_iterator'] = gsm_data

        gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size, training=True)

        del gsm_data

        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        batch_callback_metrics = []

        # track metrics to log
        batch_log_metrics = []

        if batch is None:
            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

        # Batch start events
        with self.profiler.profile('on_batch_start'):
            # callbacks
            self.on_batch_start()
            # hooks
            if self.is_function_implemented('on_batch_start'):
                response = self.get_model().on_batch_start(batch)
                if response == -1:
                    return AttributeDict(signal=-1,
                                         grad_norm_dic=grad_norm_dic)

        splits = [batch]
        if self.truncated_bptt_steps is not None:
            model_ref = self.get_model()
            with self.profiler.profile('tbptt_split_batch'):
                splits = model_ref.tbptt_split_batch(batch,
                                                     self.truncated_bptt_steps)

        self.hiddens = None
        for split_idx, split_batch in enumerate(splits):
            self.split_idx = split_idx

            for opt_idx, optimizer in self._get_optimizers_iterable():
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.optimizers) > 1:
                    for param in self.get_model().parameters():
                        param.requires_grad = False
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.requires_grad = True

                # -------------------
                # calculate loss
                # -------------------
                beta = 0.01
                opt_closure_result = self.optimizer_closure(
                    split_batch,
                    batch_idx,
                    opt_idx,
                    optimizer,
                    self.hiddens,
                    gsm_p,  # topic distribution
                    gsm_loss,  # loss for topic modeling
                    K,  # number of topics
                    beta,
                )

                # ------------------------------
                # POST forward bookkeeping
                # ------------------------------
                batch_callback_metrics.append(
                    opt_closure_result.training_step_output.callback_metrics)
                batch_log_metrics.append(
                    opt_closure_result.training_step_output.log_metrics)

                self.add_progress_bar_metrics(
                    opt_closure_result.training_step_output.pbar_on_batch_end)

                # track hiddens
                self.hiddens = opt_closure_result.hiddens

                # check if loss or model weights are nan
                if self.terminate_on_nan:
                    self.detect_nan_tensors(opt_closure_result.loss)

                # track total loss for logging (avoid mem leaks)
                self.batch_loss_value.append(opt_closure_result.loss)

                # ------------------------------
                # BACKWARD PASS
                # ------------------------------
                # gradient update with accumulated gradients
                if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                    # backward
                    grad_norm_dic = self.run_batch_backward_pass(
                        split_batch, batch_idx, opt_idx, optimizer)

                    # calculate running loss for display
                    self.running_loss.append(self.batch_loss_value.mean())

                    # reset for next set of accumulated grads
                    self.batch_loss_value.reset()

        # Batch end events
        with self.profiler.profile('on_batch_end'):
            # callbacks
            self.on_batch_end()
            # model hooks
            if self.is_function_implemented('on_batch_end'):
                self.get_model().on_batch_end()

        # collapse all metrics into one dict
        batch_log_metrics = {
            k: v
            for d in batch_log_metrics for k, v in d.items()
        }

        # track all metrics for callbacks
        self.callback_metrics.update(
            {k: v
             for d in batch_callback_metrics for k, v in d.items()})

        result = AttributeDict(
            signal=0,
            grad_norm_dic=grad_norm_dic,
            batch_log_metrics=batch_log_metrics,
            training_step_output_for_epoch_end=(
                opt_closure_result.training_step_output_for_epoch_end))
        return result
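The value returned above is a Lightning `AttributeDict`, a small `dict` subclass whose keys can also be read as attributes. A brief usage sketch (illustrative only; in the Lightning versions this code targets it can be imported from `pytorch_lightning.utilities.parsing`):

from pytorch_lightning.utilities.parsing import AttributeDict

result = AttributeDict(signal=0, grad_norm_dic={})
# key access and attribute access are interchangeable
assert result['signal'] == result.signal == 0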
Example 3
    def _evaluate(self,
                  model: LightningModule,
                  dataloaders: List[DataLoader],
                  max_batches: Union[int, List[int]],
                  test_mode: bool = False):
        """Run evaluation code.
        Args:
            model: The model to evaluate.
            dataloaders: A list of PyTorch dataloaders.
            max_batches: An integer or list of integers with length of the number of dataloaders. Each
                entry is the number of batches to process in the corresponding dataloader.
            test_mode: If True, run the test-loop hooks; otherwise run the
                validation-loop hooks.
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # bookkeeping
        outputs = []

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device(self.tpu_id)
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            # each dataloader has a max num batches
            dl_max_batches = max_batches[dataloader_idx]

            # load tokenizer
            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

            # load dict
            dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))

            # vocab size for topic modeling
            vocab_size = len(dictionary)

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= dl_max_batches:
                    break

                # -----------------
                # Topic Modeling
                # -----------------
                # load config for GSM
                config = yaml_load(
                    f"{self.default_root_dir}/data/config/gsm.yaml")

                # special [SEP] tokens to strip before topic modeling
                sep_list = [
                    '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]',
                    '[SEP_5]', '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]'
                ]

                # model
                config['hidden']['features'][0] = vocab_size

                # trainer batch
                config['trainer_batch']['test_sample'] = 1
                config = extend_config_reference(config)
                gsm_trainer = config['GSMtrainer']
                gsm_trainer['base_dir'] = (
                    f"{self.default_root_dir}/log/bart-large-cnn-finetune")
                gsm_trainer = GSMTrainer.from_config(gsm_trainer)

                # -----------------------------------------
                # Topic Modeling - GSM
                # -----------------------------------------
                batch_size = batch['input_ids'].size()[0]

                docs = []
                for batch_num in range(batch_size):
                    # extract the batch_sentence
                    batch_sentence = tokenizer.decode(
                        batch['input_ids'][batch_num].tolist(),
                        skip_special_tokens=True)
                    # split the decoded sentence into tokens
                    batch_sentence_list = batch_sentence.split(" ")
                    # remove the [SEP] markers
                    batch_sentence_list_nosep = [
                        item for item in batch_sentence_list
                        if item not in sep_list
                    ]
                    text = ' '.join(batch_sentence_list_nosep)
                    # strip subword markers and lowercase
                    fine_text = text.replace(' ##', '').lower()
                    # drop punctuation; batch_sentence now holds the cleaned text
                    batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
                    # convert to the bag-of-words format expected by the topic model
                    gsm_data_bow = dictionary.doc2bow(
                        batch_sentence.split(" "))
                    docs.append(gsm_data_bow)
                # gsm_data: data for topic modeling
                gsm_data = DataLoader(
                    DocDataset(docs, len(dictionary), device='cuda'),
                    batch_size=config['dataset']['batch_size'],
                    drop_last=False,
                    num_workers=0)

                gsm_trainer.__dict__['train_iterator'] = gsm_data

                gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size,
                                                       training=False)

                del gsm_data

                topic_p = Variable(gsm_p.data, requires_grad=False)
                # tm_loss = Variable(gsm_loss.data, requires_grad=False)

                # disable gradients to save memory
                torch.set_grad_enabled(False)

                # callbacks
                if test_mode:
                    self.on_test_batch_start()
                else:
                    self.on_validation_batch_start()

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                beta = 0.01
                if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                    with torch.cuda.amp.autocast():
                        output = self.evaluation_forward(
                            model, batch, batch_idx, dataloader_idx, topic_p,
                            gsm_loss, beta, test_mode)
                else:
                    output = self.evaluation_forward(model, batch, batch_idx,
                                                     dataloader_idx, topic_p,
                                                     gsm_loss, beta, test_mode)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overridden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                    self.on_test_batch_end()
                else:
                    if self.is_overridden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)
                    self.on_validation_batch_end()

                # track outputs for collation
                dl_outputs.append(output)

                # re-enable gradients for subsequent training
                torch.set_grad_enabled(True)

            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give the model a chance to post-process the outputs (if such a method is defined)
        if isinstance(
                model,
            (LightningDistributedDataParallel, LightningDataParallel)):
            model = model.module

        if test_mode:
            if self.is_overridden('test_end', model=model):
                eval_results = model.test_end(outputs)
                rank_zero_warn(
                    'Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                    ' Use `test_epoch_end` instead.', DeprecationWarning)

            elif self.is_overridden('test_epoch_end', model=model):
                eval_results = model.test_epoch_end(outputs)

        else:
            if self.is_overridden('validation_end', model=model):
                eval_results = model.validation_end(outputs)
                rank_zero_warn(
                    'Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                    ' Use `validation_epoch_end` instead.', DeprecationWarning)

            elif self.is_overridden('validation_epoch_end', model=model):
                eval_results = model.validation_epoch_end(outputs)

        # enable train mode again
        model.train()

        return eval_results
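The explicit `torch.set_grad_enabled(False)` / `torch.set_grad_enabled(True)` pair around the evaluation step could also be written with the context-manager form, which restores whatever grad mode was active before the block instead of unconditionally re-enabling it. A minimal sketch with a generic forward call (not the project's code):

import torch

with torch.no_grad():            # gradients disabled only inside this block
    output = model(**batch)      # evaluation forward pass; no autograd graph is built
# the previous grad mode is restored automatically on exit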
Example 4
def generate_summaries(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16: bool = True,
    task: str = "summarization",
    decoder_start_token_id=None,
    finetune_flag: int = 0,
    checkpoint_path: str = "",
    **gen_kwargs,
) -> None:
    fout = Path(out_file).open("w", encoding="utf-8")

    # initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # if our goal is to evaluate the original checkpoint
    if finetune_flag < 1:
        # initialize the model checkpoints
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    # if our goal is to evaluate our fine-tuned checkpoint
    else:
        # load the finetuned checkpoints
        model = AutoModelForSeq2SeqLM.from_pretrained(
            f"{checkpoint_path}/best_tfmr").to(device)

    if fp16:
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

    # update config with summarization specific params
    use_task_specific_params(model, task)

    for batch in tqdm(list(chunks(examples, batch_size))):
        batch = tokenizer(batch,
                          return_tensors="pt",
                          truncation=True,
                          padding="max_length").to(device)
        input_ids, attention_mask = trim_batch(
            **batch, pad_token_id=tokenizer.pad_token_id)

        # -----------------------------------------
        # Topic Modeling - GSM
        # -----------------------------------------
        docs = []
        # load dict
        dictionary = Dictionary.load(datapath('dict-www-cnndm-unigram'))
        # special [SEP] tokens to strip before topic modeling
        sep_list = [
            '[SEP_0]', '[SEP_1]', '[SEP_2]', '[SEP_3]', '[SEP_4]', '[SEP_5]',
            '[SEP_6]', '[SEP_7]', '[SEP_8]', '[SEP_9]'
        ]
        # vocab size for topic modeling
        vocab_size = len(dictionary)
        # load config for GSM
        config = yaml_load(f"data/config/gsm.yaml")
        # model
        config['hidden']['features'][0] = vocab_size

        # trainer batch
        config['trainer_batch']['test_sample'] = 1
        config = extend_config_reference(config)
        gsm_trainer = config['GSMtrainer']
        gsm_trainer['base_dir'] = f"log/bart-large-cnn-finetune"
        gsm_trainer = GSMTrainer.from_config(gsm_trainer)

        total_sample = len(batch['input_ids'])

        for batch_num in range(total_sample):
            # extract the batch_sentence
            batch_sentence = tokenizer.decode(
                batch['input_ids'][batch_num].tolist(),
                skip_special_tokens=True)
            # split the decoded sentence into tokens
            batch_sentence_list = batch_sentence.split(" ")
            # remove the [SEP] markers
            batch_sentence_list_nosep = [
                item for item in batch_sentence_list if item not in sep_list
            ]
            text = ' '.join(batch_sentence_list_nosep)
            # strip subword markers and lowercase
            fine_text = text.replace(' ##', '').lower()
            # drop punctuation; batch_sentence now holds the cleaned text
            batch_sentence = re.sub(r'[^\w\s]', '', fine_text)
            # convert to the bag-of-words format expected by the topic model
            gsm_data_bow = dictionary.doc2bow(batch_sentence.split(" "))
            docs.append(gsm_data_bow)

        # gsm_data: data for topic modeling
        gsm_data = DataLoader(DocDataset(docs, len(dictionary), device='cuda'),
                              batch_size=config['dataset']['batch_size'],
                              drop_last=False,
                              num_workers=0)

        gsm_trainer.__dict__['train_iterator'] = gsm_data

        gsm_loss, gsm_p = gsm_trainer.co_train(vocab_size=vocab_size,
                                               training=False)

        del gsm_data

        topic_p = gsm_p.cuda()

        summaries = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=decoder_start_token_id,
            topic_p=topic_p,
            **gen_kwargs,
        )
        dec = tokenizer.batch_decode(summaries,
                                     skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
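A hypothetical invocation of `generate_summaries`; the file names, model id and checkpoint directory below are placeholders, not values taken from the repository:

source_lines = [line.rstrip("\n") for line in open("test.source", encoding="utf-8")]
generate_summaries(
    examples=source_lines,
    out_file="test_generations.txt",
    model_name="facebook/bart-large-cnn",
    batch_size=8,
    fp16=True,
    finetune_flag=1,                                # evaluate a fine-tuned checkpoint
    checkpoint_path="log/bart-large-cnn-finetune",  # directory containing best_tfmr/
    num_beams=4,                                    # forwarded to model.generate via **gen_kwargs
)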