Example #1
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))

    # reuse a passed-in model if given, otherwise build a fresh summarization model
    if model is None:
        model: SummarizationModule = SummarizationModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")

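    # a patience value below zero disables early stopping entirely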
    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric,
                                                  args.early_stopping_patience)
    else:
        es_callback = False

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric,
                                                    args.save_top_k),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
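    # use the last *.ckpt in sorted order both as the test checkpoint and as the resume point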
    checkpoints = list(
        sorted(
            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                      recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
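
These main() variants are all meant to be driven from an argparse entry point. A minimal sketch of such a wrapper follows; it is not part of the snippet above and assumes that SummarizationModule exposes an add_model_specific_args helper and that pl.Trainer.add_argparse_args is available (pre-2.0 Lightning), as the code suggests.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Trainer flags (gpus, max_epochs, ...) plus the module-specific flags
    parser = pl.Trainer.add_argparse_args(parser)  # pre-2.0 Lightning API (assumed)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())  # assumed helper
    args = parser.parse_args()

    main(args)
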
Example #2
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
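    # refuse to train into a non-empty output_dir (same guard as the inline check in Example #1)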
    check_output_dir(args, expected_items=3)
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        from pytorch_lightning.loggers import CSVLogger
        logger = CSVLogger('chen_logs', name='SCHWEIGEN')  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

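    # validation loss is minimized; other metrics (e.g. ROUGE) are maximized when ranking checkpoints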
    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
Example #3
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    if args.checkpoint_model:
        model = model.load_from_checkpoint(args.checkpoint_model)
        logger.info("args.data_dir: %s", args.data_dir)
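        # override the checkpoint's stored dataset settings with the current command-line args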
        model.dataset_kwargs: dict = dict(
            data_dir=args.data_dir,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        model.hparams = args
    # trainer = generic_train(model, args)

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric,
                                                  args.early_stopping_patience)
    else:
        es_callback = False

    trainer = generic_train(model,
                            args,
                            checkpoint_callback=get_checkpoint_callback(
                                args.output_dir, model.val_metric),
                            early_stopping_callback=es_callback)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        if args.checkpoint_model:
            trainer.test(model)
        else:
            checkpoints = list(
                sorted(
                    glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                              recursive=True)))
            if checkpoints:
                print('Loading weights from {}'.format(checkpoints[-1]))
                model = model.load_from_checkpoint(checkpoints[-1])
                model.dataset_kwargs: dict = dict(
                    data_dir=args.data_dir,
                    max_source_length=args.max_source_length,
                    max_target_length=args.max_target_length,
                )
                model.hparams = args
            trainer.test(model)
Example #4
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=dataset)

    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(
            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                      recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(
        model
    )  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model
Example #5
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    # add atomic relation tokens
    if args.atomic:
        print("Special tokens are added.")

        additional_tokens_list = [
            "AtLocation",
            "CapableOf",
            "Causes",
            "CausesDesire",
            "CreatedBy",
            "DefinedAs",
            "DesireOf",
            "Desires",
            "HasA",
            "HasFirstSubevent",
            "HasLastSubevent",
            "HasPainCharacter",
            "HasPainIntensity",
            "HasPrerequisite",
            "HasProperty",
            "HasSubEvent",
            "HasSubevent",
            "HinderedBy",
            "InheritsFrom",
            "InstanceOf",
            "IsA",
            "LocatedNear",
            "LocationOfAction",
            "MadeOf",
            "MadeUpOf",
            "MotivatedByGoal",
            "NotCapableOf",
            "NotDesires",
            "NotHasA",
            "NotHasProperty",
            "NotIsA",
            "NotMadeOf",
            "ObjectUse",
            "PartOf",
            "ReceivesAction",
            "RelatedTo",
            "SymbolOf",
            "UsedFor",
            "isAfter",
            "isBefore",
            "isFilledBy",
            "oEffect",
            "oReact",
            "oWant",
            "xAttr",
            "xEffect",
            "xIntent",
            "xNeed",
            "xReact",
            "xReason",
            "xWant",
        ]

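        # register the relation names as new tokens and grow the embeddings to match the enlarged vocab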
        num_added_toks = model.tokenizer.add_tokens(additional_tokens_list)
        model.model.resize_token_embeddings(len(model.tokenizer))

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=dataset)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(
            name=model.output_dir.name, project=f"hf_{dataset}")

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric),
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    trainer.test(model)
    return model
Example #6
def main(args, model=None) -> SummarizationModule:

    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(f"Output directory ({args.output_dir}) already exists and is not empty.")

    if model is None:
        model = SummarizationModule(args)

    # add unlikelihood parameters - with logr weights
    set_ul_params(model, args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    trainer = None

    if args.do_train:

        lower_is_better = args.val_metric == "loss"
        save_top_k = args.max_epochs

        trainer: pl.Trainer = generic_train(
            model,
            args,
            logging_callback=Seq2SeqLoggingCallback(),
            checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, save_top_k, lower_is_better),
            early_stopping_callback=es_callback,
            logger=logger,
        )
        pickle_save(model.hparams, model.output_dir / "hparams.pkl")

        # now write loss logs into the same directory
        with open(os.path.join(args.output_dir, 'loss_logs.json'), 'w') as f:
            f.write(json.dumps(model.losses, indent=2))

    if args.do_generate:

        if args.generate_epoch > -1:
            model = BartForConditionalGeneration.from_pretrained(join(args.output_dir, f'best_tfmr-{args.generate_epoch}'))
        else:
            print("********* using fresh model *********")
            args.generate_epoch = 'no-train'
            model = BartForConditionalGeneration.from_pretrained(args.model_name_or_path)
 
        tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
        abstracts = list(open(join(args.data_dir, f'{args.generate_input_prefix}.source')).readlines())
        pls = list(open(join(args.data_dir, f'{args.generate_input_prefix}.target')).readlines())
        dois = list(open(join(args.data_dir, f'{args.generate_input_prefix}.doi')).readlines())

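        # optionally restrict generation to a start/end slice of the loaded examples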
        if args.generate_start_index == 'none' and args.generate_end_index != 'none':
            abstracts = abstracts[:int(args.generate_end_index)]
            pls = pls[:int(args.generate_end_index)]
            dois = dois[:int(args.generate_end_index)]
        elif args.generate_start_index != 'none' and args.generate_end_index == 'none':
            abstracts = abstracts[int(args.generate_start_index):]
            pls = pls[int(args.generate_start_index):]
            dois = dois[int(args.generate_start_index):]
        elif args.generate_start_index != 'none' and args.generate_end_index != 'none':
            abstracts = abstracts[int(args.generate_start_index):int(args.generate_end_index)]
            pls = pls[int(args.generate_start_index):int(args.generate_end_index)]
            dois = dois[int(args.generate_start_index):int(args.generate_end_index)]

        abstracts_final = []
        dois_final = []
        pls_final = []
        gen_final = []
 
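        # tokenize every abstract up front; generate() below consumes one row of input_ids at a time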
        batch = tokenizer(abstracts, padding='max_length', max_length=args.max_source_length, truncation=True, return_tensors='pt')
        input_ids = batch['input_ids']

        fname_prefix = f'gen_{args.decode_method}_{args.generate_input_prefix}_{args.generate_epoch}_{args.generate_start_index}-{args.generate_end_index}'
        fname_text = fname_prefix + '_text_only.txt'

        logs_list = []
        for i,d,a,p in zip(range(len(dois)), dois, abstracts, pls):
            ids = input_ids[i]

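            # choose the decoding strategy: greedy, beam search, or nucleus (top-p) sampling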
            logs = None
            if args.decode_method=='greedy':
                gen_ids = model.generate(ids.unsqueeze(0), 
                                         do_sample=False,
                                         max_length=args.max_target_length, 
                                         early_stopping=False, 
                                         num_return_sequences=1, 
                                         decoder_start_token_id=model.config.pad_token_id)
            elif args.decode_method=='beam':
                gen_ids = model.generate(ids.unsqueeze(0), 
                                         do_sample=False,
                                         num_beams=args.decode_num_beams,
                                         max_length=args.max_target_length, 
                                         early_stopping=False, 
                                         num_return_sequences=1, 
                                         decoder_start_token_id=model.config.pad_token_id)
            else:
                gen_ids = model.generate(ids.unsqueeze(0),
                                         do_sample=True,
                                         top_p=args.decode_p,
                                         max_length=args.max_target_length, 
                                         early_stopping=False, 
                                         num_return_sequences=1, 
                                         decoder_start_token_id=model.config.pad_token_id)
            
            gen_text = tokenizer.decode(gen_ids.squeeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=False)
             
            dois_final.append(dois[i])
            abstracts_final.append(abstracts[i])
            pls_final.append(pls[i])
            gen_final.append(gen_text)

            if logs is not None:
                logs_list.append(logs)

            with open(join(args.output_dir, fname_text), 'a+') as f:
                f.write(gen_text + '\n----------------------------------------\n')
                f.flush()
            
            print(gen_text + '\n----------------------------------------\n')

        output = [{'doi': d.strip(), 'abstract': a.strip(), 'pls': p.strip(), 'gen': g.strip()} for d,a,p,g in zip(dois_final, abstracts_final, pls_final, gen_final)]

        fname_json = fname_prefix + '.json'
        open(join(args.output_dir, fname_json), 'w').write(json.dumps(output, indent=2))

        if len(logs_list) > 0:
            fname_logs = fname_prefix + '_log.json'
            open(join(args.output_dir, fname_logs), 'w').write(json.dumps(logs_list, indent=2))

    return model