Example 1
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    if args.PreLoad_Model:
        checkpoint = torch.load(args.path)
        model.load_state_dict(checkpoint['state_dict'])
        print("model loaded with custom weights from {}".format(args.path))

    trainer = generic_train(model, args)
    # Interactive debugging breakpoint left in the original script; uncomment to pause here.
    # import pdb; pdb.set_trace()
    checkpoint_path = "/Users/byronwallace/code/RoboSum/weights/pl_title_/pl_title_2048.ckpt"
    model = model.load_from_checkpoint(checkpoint_path)
    trainer.test(model)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        #checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        #model = model.load_from_checkpoint(checkpoints[-1])
        checkpoints = os.path.join(args.output_dir,
                                   (args.output_dir.split('/')[-1] + ".ckpt"))
        print(checkpoints)
        model = model.load_from_checkpoint(checkpoints)
        trainer.test(model)
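
A side note on the two loading styles used above: torch.load plus load_state_dict restores weights into an already constructed module (hyperparameters come from the caller), while Lightning's load_from_checkpoint rebuilds the module from the hyperparameters stored inside the .ckpt file. A minimal sketch, reusing SummarizationTrainer and args from the example; the path is a placeholder, not the one hard-coded above:

import torch

ckpt_path = "weights/some_run.ckpt"  # placeholder path
raw_ckpt = torch.load(ckpt_path, map_location="cpu")
model = SummarizationTrainer(args)
model.load_state_dict(raw_ckpt["state_dict"])  # weights only, args supplied by the caller

model = SummarizationTrainer.load_from_checkpoint(ckpt_path)  # weights plus the saved hparams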
Example 2
def main():
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = GLUETransformer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir not provided, a folder will be generated in pwd
    if args.output_dir is None:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)

    model = GLUETransformer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(
            sorted(
                glob.glob(os.path.join(args.output_dir,
                                       "checkpointepoch=*.ckpt"),
                          recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        return trainer.test(model)
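
A caveat on the sorted(glob.glob(...)) idiom used here and in several later examples: the sort is lexicographic, so checkpointepoch=10.ckpt orders before checkpointepoch=2.ckpt and the "last" checkpoint is not necessarily the latest. A small sketch of a numeric sort key; the filename pattern is the one these examples assume, so adjust the regex if your checkpoint callback names files differently:

import glob
import os
import re

def latest_checkpoint(output_dir):
    """Return the checkpoint with the highest epoch number rather than the lexicographically last one."""
    paths = glob.glob(os.path.join(output_dir, "checkpointepoch=*.ckpt"))

    def epoch_of(path):
        match = re.search(r"epoch=(\d+)", os.path.basename(path))
        return int(match.group(1)) if match else -1

    return max(paths, key=epoch_of)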
Example 3
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))

    # summarization model (reuse a caller-supplied module if one was passed in)
    if model is None:
        model: SummarizationModule = SummarizationModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric,
                                                  args.early_stopping_patience)
    else:
        es_callback = False

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric,
                                                    args.save_top_k),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(
            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                      recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
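
The get_early_stopping_callback and get_checkpoint_callback helpers come from the surrounding seq2seq utilities and are not shown in this listing. As a rough sketch of what such wrappers typically configure (an assumption about their behavior, not their actual source; parameter names follow recent PyTorch Lightning releases):

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

def sketch_early_stopping_callback(metric, patience):
    # loss should decrease; ROUGE/BLEU-style metrics should increase
    mode = "min" if "loss" in metric else "max"
    return EarlyStopping(monitor=f"val_{metric}", mode=mode, patience=patience, verbose=True)

def sketch_checkpoint_callback(output_dir, metric, save_top_k=1):
    mode = "min" if "loss" in metric else "max"
    return ModelCheckpoint(
        dirpath=output_dir,
        filename="{epoch}-{val_" + metric + ":.4f}",
        monitor=f"val_{metric}",
        mode=mode,
        save_top_k=save_top_k,
    )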
Example 4
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        checkpoints = list(
            sorted(
                glob.glob(os.path.join(args.output_dir,
                                       "checkpointepoch=*.ckpt"),
                          recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
Example 5
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        # checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        # model = model.load_from_checkpoint(checkpoints[-1])
        # trainer.test(model)
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1]).cuda()
        m = {}
        m1 = {}
        m2 = {}

        for line in open('/dccstor/tuhinstor/tuhin/NQ-amr-qc/val.source'):
            q = line.split('-----------')[0].rstrip()
            supp = line.split('-----------')[1].lstrip()
            m[q] = line.strip()
        for x,y in zip(open('/dccstor/tuhinstor/tuhin/NQ-rank/val.source'),open('/dccstor/tuhinstor/tuhin/NQ-amr-qc/val.source')):
            q1 = x.split('-----------')[0].rstrip()
            q2 = y.split('-----------')[0].rstrip()
            m2[q2] = q1
        
        for line in open('/dccstor/tuhinstor/tuhin/newdata/data1/gold_tokens2.jsonl'):
            line = json.loads(line.strip())
            m1[line['q']] = line['cand']
        count = 0
        f = open('/dccstor/tuhinstor/tuhin/likelihood_amr_qc.txt','w')
        for line in open('/dccstor/tuhinstor/tuhin/NQ-amr-qc/val.source'):
            q = line.split('-----------')[0].rstrip()
            source = m[q]
            logger.info("Doing "+q)
            la_candidates = []
            cand = []
            batch_arr = []
            for c in m1[m2[q]]:
                cand = c[2]
                if cand=='':
                    continue
                # for cand_sent in nltk.sent_tokenize(cand): # , 'parent': cand,
                batch_arr.append({'text': cand,'target_ids': model.tokenizer.batch_encode_plus([cand], max_length=1024, return_tensors='pt')['input_ids'].cuda()})
            source_id = model.tokenizer.batch_encode_plus([source], max_length=1024, return_tensors='pt')['input_ids']
            ans = model.likelihood(batch_arr,source_id.cuda(),torch.cuda.LongTensor([[1]*len(source_id.cpu().tolist()[0])]))
            try:
                if ans[-1] == '\n':
                    f.write(q + ' [SEP] ' + ans)
                else:
                    f.write(q + ' [SEP] ' + ans + '\n')
            except Exception:
                print("Failed for ", q)
Example 6
def main(args=None, model=None) -> GenerativeQAModule:

    Path(args.output_dir).mkdir(exist_ok=True)

    # named_actors = []
    # args.actor_handles = named_actors
    # assert args.actor_handles == named_actors

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name

    data_module = Seq2SeqDataModule(model.tokenizer, args)

    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger
        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name,
                                      project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger
        training_logger = WandbLogger(name=model.output_dir.name,
                                      project=f"hf_{dataset}")

    elif args.logger_name == "tb-logs":
        from pytorch_lightning.loggers import TensorBoardLogger
        training_logger = TensorBoardLogger('tb_logs', name='my_model')

    es_callback = (get_early_stopping_callback(model.val_metric,
                                               args.early_stopping_patience)
                   if args.early_stopping_patience >= 0 else False)

    trainer: pl.Trainer = generic_train(
        model,
        args,
        data_module,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        accelerator=CustomAccel() if args.gpus > 1 else None,
        profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")

    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
Example 7
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    check_output_dir(args, expected_items=3)
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        from pytorch_lightning.loggers import CSVLogger
        logger = CSVLogger('chen_logs', name='SCHWEIGEN')  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
Example 8
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    if args.checkpoint_model:
        model = model.load_from_checkpoint(args.checkpoint_model)
        logger.info("args.data_dir: %s", args.data_dir)
        model.dataset_kwargs: dict = dict(
            data_dir=args.data_dir,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        model.hparams = args
    #trainer = generic_train(model, args)

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric,
                                                  args.early_stopping_patience)
    else:
        es_callback = False

    trainer = generic_train(model,
                            args,
                            checkpoint_callback=get_checkpoint_callback(
                                args.output_dir, model.val_metric),
                            early_stopping_callback=es_callback)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        if args.checkpoint_model:
            trainer.test(model)
        else:
            checkpoints = list(
                sorted(
                    glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                              recursive=True)))
            if checkpoints:
                print('Loading weights from {}'.format(checkpoints[-1]))
                model = model.load_from_checkpoint(checkpoints[-1])
                model.dataset_kwargs: dict = dict(
                    data_dir=args.data_dir,
                    max_source_length=args.max_source_length,
                    max_target_length=args.max_target_length,
                )
                model.hparams = args
            trainer.test(model)
Example 9
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=dataset)

    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(
            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                      recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(
        model
    )  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model
Example 10
def main(args=None, model=None) -> GenerativeQAModule:

    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
    parser = GenerativeQAModule.add_retriever_specific_args(parser)

    args = args or parser.parse_args()

    Path(args.output_dir).mkdir(exist_ok=True)
    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")

    es_callback = (get_early_stopping_callback(model.val_metric,
                                               args.early_stopping_patience)
                   if args.early_stopping_patience >= 0 else False)

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric),
        early_stopping_callback=es_callback,
        logger=logger,
        accelerator=CustomAccel() if args.gpus > 1 else None,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")

    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
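
Because the parser in Example 10 is assembled from pl.Trainer.add_argparse_args plus the module's own argument groups, Trainer flags and model flags share one command line, and the args or parser.parse_args() fallback also lets a caller pass a pre-built namespace directly. A hypothetical programmatic invocation; apart from --gpus (added by the Trainer), the flag names below are assumptions mirroring the attributes used above:

import argparse
import os
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
args = parser.parse_args([
    "--data_dir", "data/nq",        # assumed flag, mirrors args.data_dir
    "--output_dir", "results/rag",  # assumed flag, mirrors args.output_dir
    "--gpus", "1",
])
main(args)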
Example 11
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    sd = model.model.state_dict()
    shorter_pos_embeds = sd['model.encoder.embed_positions.weight']
    new_config = model.config
    new_config.max_position_embeddings = 3076
    new_model = BartForConditionalGeneration(new_config)
    correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight.cuda()
    correctly_shaped_pos_weight[:shorter_pos_embeds.shape[0]] = shorter_pos_embeds.cuda()
    correctly_shaped_pos_weight[shorter_pos_embeds.shape[0]:2052] = shorter_pos_embeds.cuda()
    correctly_shaped_pos_weight[2052:] = shorter_pos_embeds.cuda()
    sd['model.decoder.embed_positions.weight'] = correctly_shaped_pos_weight
    sd['model.encoder.embed_positions.weight'] = correctly_shaped_pos_weight
    new_model.load_state_dict(sd, strict=True)
    model.model = new_model.cuda()
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        checkpoints = list(
            sorted(
                glob.glob(os.path.join(args.output_dir,
                                       "checkpointepoch=*.ckpt"),
                          recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
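
Example 11 grows BART's learned position-embedding table to cover 3076 positions by tiling the original table into a larger one before reloading the state dict (the 2052 split point appears to correspond to two copies of the original table, which carries two extra offset rows beyond the 1024-token limit). A generic sketch of the same tiling idea, with the row counts left as parameters:

import torch

def tile_position_embeddings(old_weight, new_num_rows):
    """Repeat a learned position-embedding table until it covers new_num_rows positions."""
    old_rows, _ = old_weight.shape
    repeats = -(-new_num_rows // old_rows)  # ceiling division
    return old_weight.repeat(repeats, 1)[:new_num_rows].clone()

# e.g. sd["model.encoder.embed_positions.weight"] = tile_position_embeddings(shorter_pos_embeds, 3078)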
Example 12
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
    exp_dir = ckpt_path.parent
    if dest_dir is None:
        dest_dir = exp_dir
    clash = list(dest_dir.glob("test_generations*"))
    if clash:
        print(f"SKIPPING to avoid overwriting {clash}")
        return
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if "hparams" in ckpt:
        args = argparse.Namespace(**ckpt["hparams"])
    else:
        args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
    args.resume_from_checkpoint = str(ckpt_path)
    args.do_train = False
    args.output_dir = str(dest_dir)
    args.n_gpu = 1
    args.eval_batch_size = 16
    Path(args.output_dir).mkdir(exist_ok=True)
    model = create_module(args)
    trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
    trainer.test(model)
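
evaluate_checkpoint recovers the run's hyperparameters either from the checkpoint itself or from the hparams.pkl written by the training entry points above, then runs only the test loop. A hypothetical call, with placeholder paths:

from pathlib import Path

# re-evaluate a finished run without retraining; generations land next to the checkpoint
evaluate_checkpoint(Path("results/xsum_run/epoch=2.ckpt"), dest_dir=Path("results/xsum_run/eval"))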
Example 13
            help="Overwrite the cached training and evaluation sets")

        return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser = add_program_args(parser)
    parser = GLUETransformer.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    # If output_dir not provided, a folder will be generated in pwd
    if args.output_dir is None:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)

    wandb_logger = WandbLogger(project="transformers")
    early_stopping = EarlyStopping(monitor='val_loss', patience=3)

    model = GLUETransformer(args)

    generic_train(model,
                  args,
                  logger=wandb_logger,
                  early_stop_callback=early_stopping)
Example 14
def main(args, model=None) -> SummarizationModule:
    print(args)
    Path(args.output_dir).mkdir(exist_ok=True)

    if model is None:
        if "summarization" in args.task:
            ### Define BART model
            # Config from "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json"
            # Vocab modified to 50265 to be consistent with facebook/bart-large default
            config = BartConfig(**json.load(open(args.config_path, "r")))
            config.fp16 = args.fp16

            if args.distill:  # if distilling, start from finetuned checkpoint
                if Path(args.data_dir).name == "cnn_dm":
                    checkpoint = 'facebook/bart-large-cnn'
                else:
                    checkpoint = 'facebook/bart-large-xsum'
            else:
                checkpoint = 'facebook/bart-large'  #Start from pretrained checkpoint otherwise

            if args.resume_from_checkpoint:
                print(
                    "Resuming from checkpoint, make sure checkpoint is finetuned for best results"
                )
                if ".ckpt" in args.resume_from_checkpoint:
                    checkpoint = args.resume_from_checkpoint
                    if args.distill:  # set resume from checkpoint to None (state dict is different)
                        args.resume_from_checkpoint = None
                else:
                    checkpoints = list(
                        sorted(
                            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                                      recursive=True)))
                    if len(checkpoints) > 0:  # a previous checkpoint exists to resume from
                        checkpoint = checkpoints[-1]
                        args.resume_from_checkpoint = checkpoint
                    else:
                        args.resume_from_checkpoint = None
                        print("No valid checkpoint to resume from. Using ",
                              checkpoint)

            print("Loading BART model checkpoint using ", checkpoint)
            model = BartForConditionalGeneration.from_pretrained(checkpoint,
                                                                 config=config)

            if args.distill == "sft":
                model = distill_sft(model)

            tokenizer = BartTokenizer.from_pretrained(
                'facebook/bart-large'
            )  # Downloads vocab and merges file automatically
            model: SummarizationModule = SummarizationModule(
                args, model=model, config=config, tokenizer=tokenizer)
        else:
            raise ValueError("Translation not supported at this time")
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")

    # if args.early_stopping_patience >= 0:
    #     extra_callbacks = [get_early_stopping_callback(f"val_{model.val_metric}", args.early_stopping_patience)]
    # else:
    #     extra_callbacks = []
    extra_callbacks = [
        CheckpointEveryNSteps(args.output_dir, args.max_steps - 1)
    ]

    lower_is_better = args.val_metric == "loss"

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric,
                                                    args.save_top_k,
                                                    lower_is_better),
        extra_callbacks=extra_callbacks,
        logger=logger,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if args.do_predict and not args.do_train:
        # Testing from a checkpoint
        trainer.test(model)
    elif args.do_predict and args.do_train:
        # test() without a model tests using the best checkpoint automatically
        model.hparams.test_checkpoint = ""
        checkpoints = list(
            sorted(
                glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                          recursive=True)))
        if checkpoints:
            model.hparams.test_checkpoint = checkpoints[-1]
            trainer.resume_from_checkpoint = checkpoints[-1]
        trainer.logger.log_hyperparams(model.hparams)
        trainer.test()
    return model
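
CheckpointEveryNSteps in Example 14 is a project-specific callback that is not part of this listing. A minimal sketch of what a "save every N training steps" callback could look like with PyTorch Lightning's Callback hooks; the class below is an assumption for illustration, not the original implementation:

import os
import pytorch_lightning as pl

class SketchCheckpointEveryNSteps(pl.Callback):
    """Save a full .ckpt every `every_n_steps` optimizer steps, independent of epoch-end checkpoints."""

    def __init__(self, output_dir, every_n_steps):
        self.output_dir = output_dir
        self.every_n_steps = max(1, every_n_steps)

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        step = trainer.global_step
        if step > 0 and step % self.every_n_steps == 0:
            trainer.save_checkpoint(os.path.join(self.output_dir, f"step={step}.ckpt"))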
Example 15
def main(args=None, model=None) -> GenerativeQAModule:
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
    parser = GenerativeQAModule.add_retriever_specific_args(parser)
    args = args or parser.parse_args()

    Path(args.output_dir).mkdir(exist_ok=True)
    Path(args.output_dir + "/dpr_ctx_checkpoint").mkdir(
        exist_ok=True)  # save the dpr_context encoder separately for future use
    print(args.shard_dir)
    if os.path.exists(
            args.shard_dir
    ):  # we do not need previous kb shards used in dataset re-encoding and re-indexing
        shutil.rmtree(args.shard_dir)
    Path(args.shard_dir).mkdir(exist_ok=True)

    if os.path.exists(
            args.cache_dir
    ):  # we do not need previous cache files used in dataset re-encoding and re-indexing
        shutil.rmtree(args.cache_dir)
    Path(args.cache_dir).mkdir(exist_ok=True)

    named_actors = []
    if args.distributed_retriever == "ray" and args.gpus > 1:
        if not is_ray_available():
            raise RuntimeError("Please install Ray to use the Ray "
                               "distributed retriever.")
        # Connect to an existing Ray cluster.
        try:
            ray.init(address=args.ray_address)
        except (ConnectionError, ValueError):
            logger.warning(
                "Connection to Ray cluster failed. Make sure a Ray "
                "cluster is running by either using Ray's cluster "
                "launcher (`ray up`) or by manually starting Ray on "
                "each node via `ray start --head` for the head node "
                "and `ray start --address='<ip address>:6379'` for "
                "additional nodes. See "
                "https://docs.ray.io/en/master/cluster/index.html "
                "for more info.")
            raise

        # Create Ray actors only for rank 0. Environment variables are strings,
        # so the ranks are compared against "0" rather than the integer 0.
        if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"]
                == "0") and ("NODE_RANK" not in os.environ
                             or os.environ["NODE_RANK"] == "0"):
            remote_cls = ray.remote(RayRetriever)
            named_actors = [
                remote_cls.options(
                    name="retrieval_worker_{}".format(i)).remote()
                for i in range(args.num_retrieval_workers)
            ]
        else:
            logger.info(
                "Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
                    os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]))
            named_actors = [
                ray.get_actor("retrieval_worker_{}".format(i))
                for i in range(args.num_retrieval_workers)
            ]
    args.actor_handles = named_actors
    assert args.actor_handles == named_actors

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name,
                                      project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name,
                                      project=f"hf_{dataset}")

    es_callback = (get_early_stopping_callback(model.val_metric,
                                               args.early_stopping_patience)
                   if args.early_stopping_patience >= 0 else False)

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
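
The Ray block in Example 15 relies on named actors so that only global rank 0 creates the retrieval workers while every other process attaches to them by name. The same pattern in isolation, with a trivial stand-in for RayRetriever (the worker class, count, and rank handling here are illustrative only):

import os
import ray

class RetrievalWorker:  # stand-in for the example's RayRetriever
    def retrieve(self, query):
        return []

ray.init()  # or ray.init(address=...) to join an existing cluster

remote_cls = ray.remote(RetrievalWorker)
num_workers = 4
is_rank_zero = (os.environ.get("LOCAL_RANK", "0") == "0"
                and os.environ.get("NODE_RANK", "0") == "0")
if is_rank_zero:
    # rank 0 creates the named actors once ...
    workers = [remote_cls.options(name=f"retrieval_worker_{i}").remote() for i in range(num_workers)]
else:
    # ... and the other ranks just look them up by name
    workers = [ray.get_actor(f"retrieval_worker_{i}") for i in range(num_workers)]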
Example 16
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    # add atomic relation tokens
    if args.atomic:
        print("Special tokens are added.")

        additional_tokens_list = [
            "AtLocation",
            "CapableOf",
            "Causes",
            "CausesDesire",
            "CreatedBy",
            "DefinedAs",
            "DesireOf",
            "Desires",
            "HasA",
            "HasFirstSubevent",
            "HasLastSubevent",
            "HasPainCharacter",
            "HasPainIntensity",
            "HasPrerequisite",
            "HasProperty",
            "HasSubEvent",
            "HasSubevent",
            "HinderedBy",
            "InheritsFrom",
            "InstanceOf",
            "IsA",
            "LocatedNear",
            "LocationOfAction",
            "MadeOf",
            "MadeUpOf",
            "MotivatedByGoal",
            "NotCapableOf",
            "NotDesires",
            "NotHasA",
            "NotHasProperty",
            "NotIsA",
            "NotMadeOf",
            "ObjectUse",
            "PartOf",
            "ReceivesAction",
            "RelatedTo",
            "SymbolOf",
            "UsedFor",
            "isAfter",
            "isBefore",
            "isFilledBy",
            "oEffect",
            "oReact",
            "oWant",
            "xAttr",
            "xEffect",
            "xIntent",
            "xNeed",
            "xReact",
            "xReason",
            "xWant",
        ]

        num_added_toks = model.tokenizer.add_tokens(additional_tokens_list)
        model.model.resize_token_embeddings(len(model.tokenizer))

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=dataset)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(
            name=model.output_dir.name, project=f"hf_{dataset}")

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric),
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    trainer.test(model)
    return model
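
Example 16's only substantive change is registering the ATOMIC relation names as extra tokens and resizing the embedding matrix to match. The same two calls on a plain Hugging Face tokenizer/model pair, shown in isolation (facebook/bart-base and the three tokens are just illustrative):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

num_added = tokenizer.add_tokens(["AtLocation", "xIntent", "oEffect"])  # subset of the list above
model.resize_token_embeddings(len(tokenizer))  # new rows are randomly initialized
print(f"added {num_added} tokens; vocab size is now {len(tokenizer)}")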
Example 17
def main(args, model=None) -> SummarizationModule:

    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(f"Output directory ({args.output_dir}) already exists and is not empty.")

    if model is None:
        model = SummarizationModule(args)

    #add unlikelihood parameters - with logr weights
    set_ul_params(model, args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    trainer = None

    if args.do_train:

        lower_is_better = args.val_metric == "loss"
        save_top_k = args.max_epochs

        trainer: pl.Trainer = generic_train(
            model,
            args,
            logging_callback=Seq2SeqLoggingCallback(),
            checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, save_top_k, lower_is_better),
            early_stopping_callback=es_callback,
            logger=logger,
        )
        pickle_save(model.hparams, model.output_dir / "hparams.pkl")

        #now write loss logs into the same directory
        with open(os.path.join(args.output_dir, 'loss_logs.json'), 'w') as f:
            f.write(json.dumps(model.losses, indent=2))

    if args.do_generate:

        if args.generate_epoch > -1:
            model = BartForConditionalGeneration.from_pretrained(join(args.output_dir, f'best_tfmr-{args.generate_epoch}'))
        else:
            print("********* using fresh model *********")
            args.generate_epoch = 'no-train'
            model = BartForConditionalGeneration.from_pretrained(args.model_name_or_path)
 
        tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
        abstracts = list(open(join(args.data_dir, f'{args.generate_input_prefix}.source')).readlines())
        pls = list(open(join(args.data_dir, f'{args.generate_input_prefix}.target')).readlines())
        dois = list(open(join(args.data_dir, f'{args.generate_input_prefix}.doi')).readlines())

        if args.generate_start_index == 'none' and args.generate_end_index != 'none':
            abstracts = abstracts[:int(args.generate_end_index)]
            pls = pls[:int(args.generate_end_index)]
            dois = dois[:int(args.generate_end_index)]
        elif args.generate_start_index != 'none' and args.generate_end_index == 'none':
            abstracts = abstracts[int(args.generate_start_index):]
            pls = pls[int(args.generate_start_index):]
            dois = dois[int(args.generate_start_index):]
        elif args.generate_start_index != 'none' and args.generate_end_index != 'none':
            abstracts = abstracts[int(args.generate_start_index):int(args.generate_end_index)]
            pls = pls[int(args.generate_start_index):int(args.generate_end_index)]
            dois = dois[int(args.generate_start_index):int(args.generate_end_index)]

        abstracts_final = []
        dois_final = []
        pls_final = []
        gen_final = []
 
        batch = tokenizer(abstracts, padding='max_length', max_length=args.max_source_length, truncation=True, return_tensors='pt')
        input_ids = batch['input_ids']

        fname_prefix = f'gen_{args.decode_method}_{args.generate_input_prefix}_{args.generate_epoch}_{args.generate_start_index}-{args.generate_end_index}'
        fname_text = fname_prefix + '_text_only.txt'

        logs_list = []
        for i,d,a,p in zip(range(len(dois)), dois, abstracts, pls):
            ids = input_ids[i]

            logs = None
            if args.decode_method=='greedy':
                gen_ids = model.generate(ids.unsqueeze(0), 
                                         do_sample=False,
                                         max_length=args.max_target_length, 
                                         early_stopping=False, 
                                         num_return_sequences=1, 
                                         decoder_start_token_id=model.config.pad_token_id)
            elif args.decode_method=='beam':
                gen_ids = model.generate(ids.unsqueeze(0), 
                                         do_sample=False,
                                         num_beams=args.decode_num_beams,
                                         max_length=args.max_target_length, 
                                         early_stopping=False, 
                                         num_return_sequences=1, 
                                         decoder_start_token_id=model.config.pad_token_id)
            else:
                gen_ids = model.generate(ids.unsqueeze(0),
                                         do_sample=True,
                                         top_p=args.decode_p,
                                         max_length=args.max_target_length, 
                                         early_stopping=False, 
                                         num_return_sequences=1, 
                                         decoder_start_token_id=model.config.pad_token_id)
            
            gen_text = tokenizer.decode(gen_ids.squeeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=False)
             
            dois_final.append(dois[i])
            abstracts_final.append(abstracts[i])
            pls_final.append(pls[i])
            gen_final.append(gen_text)

            if logs is not None:
                logs_list.append(logs)

            with open(join(args.output_dir, fname_text), 'a+') as f:
                f.write(gen_text + '\n----------------------------------------\n')
                f.flush()
            
            print(gen_text + '\n----------------------------------------\n')

        output = [{'doi': d.strip(), 'abstract': a.strip(), 'pls': p.strip(), 'gen': g.strip()} for d,a,p,g in zip(dois_final, abstracts_final, pls_final, gen_final)]

        fname_json = fname_prefix + '.json'
        open(join(args.output_dir, fname_json), 'w').write(json.dumps(output, indent=2))

        if len(logs_list) > 0:
            fname_logs = fname_prefix + '_log.json'
            open(join(args.output_dir, fname_logs), 'w').write(json.dumps(logs_list, indent=2))

    return model
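
The decode_method branches in Example 17 differ only in the keyword arguments passed to generate. A condensed sketch of the three decoding modes on a plain BART checkpoint (facebook/bart-base and the input text are placeholders; the kwargs mirror the example):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
input_ids = tokenizer("An abstract to summarize.", return_tensors="pt").input_ids

greedy_ids = model.generate(input_ids, do_sample=False, max_length=64)             # greedy
beam_ids = model.generate(input_ids, do_sample=False, num_beams=4, max_length=64)  # beam search
sampled_ids = model.generate(input_ids, do_sample=True, top_p=0.9, max_length=64)  # nucleus sampling

print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))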
Example 18
        labels = batch[1]

        logits = self(tokens)
        loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))
        out_label_ids = labels.detach().cpu().numpy()
        out_pred = logits.detach().cpu().numpy()
        return {"val_loss": loss, "pred": out_pred, "target": out_label_ids}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser)
    parser = KimCNN.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    args.embed_num = 512
    args.embed_dim = 768
    args.class_num = 2
    args.kernel_num = 3
    args.kernel_sizes = [2, 3, 4]
    args.dropout = 0.5
    args.static = True

    model = KimCNN(args)
    trainer = generic_train(model, args)
    print("Training done")
    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1], **args.__dict__)
        trainer.test(model)
Example 19
def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        # checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        # model = model.load_from_checkpoint(checkpoints[-1])
        # trainer.test(model)
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1]).cuda()
        m = {}
        m1 = {}
        m2 = {}
        ansq = {}

        for line,line1 in zip(open('/dccstor/tuhinstor/tuhin/NQ-rank1/val.source'),open('/dccstor/tuhinstor/tuhin/NQ-rank1/val.target')):
            q = line.split('-----------')[0].rstrip()
            supp = line.split('-----------')[1].lstrip()
            m[q] = line.strip()
            ansq[q] = line1.strip()
        for line in open('/dccstor/tuhinstor/tuhin/newdata/data1/gold_tokens2.jsonl'):
            line = json.loads(line.strip())
            m1[line['q']] = line['cand']
        count = 0
        fw = open('/dccstor/tuhinstor/tuhin/likelihood_ranked_amr_val.txt','w')
        cou = 1
        for line in open('/dccstor/tuhinstor/tuhin/NQ-rank1/val.source'):
            q = line.split('-----------')[0].rstrip()
            source = m[q]
            logger.info("Doing "+q)
            la_candidates = []
            cand = []
            batch_arr = []
            for c in m1[q]:
                cand = c[2]
                if cand=='':
                    continue
                for cand_sent in nltk.sent_tokenize(cand): # ,
                    batch_arr.append({'text': cand_sent, 'parent': cand,'target_ids': model.tokenizer.batch_encode_plus([cand_sent], max_length=1024, return_tensors='pt')['input_ids'].cuda()})
            source_id = model.tokenizer.batch_encode_plus([source], max_length=1024, return_tensors='pt')['input_ids']
            ans = model.likelihood(batch_arr,source_id.cuda(),torch.cuda.LongTensor([[1]*len(source_id.cpu().tolist()[0])]))
            ans = sorted(ans,key=lambda tup: tup[0])
            countp = {}
            for res in ans:
                # keep at most 3 outputs per parent candidate; .get avoids a
                # KeyError the first time a parent is seen
                if countp.get(res[2], 0) < 3:
                    if ansq[q] in res[2]:
                        fw.write(str(cou)+'\t'+ansq[q]+'\n')
                        ansq[q] = "garbage"+ansq[q]
                    else:
                        fw.write(str(cou)+'\t'+res[1])
                    countp[res[2]] = countp.get(res[2], 0) + 1
            cou = cou+1