コード例 #1
0
ファイル: main.py プロジェクト: chanhee0222/feed2resp
def _latest_checkpoint(output_dir):
    """Return the newest ``checkpointepoch=*.ckpt`` path in *output_dir*, or None.

    "Newest" is the file with the highest integer following ``epoch=`` in its
    name. A plain lexicographic sort would rank ``epoch=10`` before
    ``epoch=9``. Names without a parsable epoch number rank lowest, and ties
    are broken by path so the result is deterministic.
    """
    import re  # Local import: this block cannot see the file's import header.

    candidates = glob.glob(
        os.path.join(output_dir, "checkpointepoch=*.ckpt"), recursive=True)
    if not candidates:
        return None

    def _epoch_then_path(path):
        found = re.search(r"epoch=(\d+)", os.path.basename(path))
        return (int(found.group(1)) if found else -1, path)

    return max(candidates, key=_epoch_then_path)


def main():
    """Parse CLI arguments, build the models and data iterators, call ``train``.

    Steps, in order:

    1. Build a ``Config`` and an argparse namespace, then copy
       ``num_train_epochs`` and ``output_dir`` from the arguments into the
       config and ``config.lr_F`` into ``args.learning_rate``.
    2. Create ``<output_dir>/<YYYYmmdd-HHMMSS>/ckpts`` and pass the timestamped
       folder to ``init_logger``.
    3. Instantiate ``BartSystem``. If ``args.output_dir`` contains a
       ``checkpointepoch=*.ckpt`` file, replace the fresh model with the one
       restored from the highest-epoch checkpoint.
    4. Take the train/val/test data loaders from the model, build the
       ``Discriminator`` and hand everything to ``train``.

    Checkpoint loading is best effort: a failure is logged with its traceback
    and the run continues with the freshly initialised model.
    """
    config = Config()
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # Keep the two settings objects in sync: the epoch count and output
    # location come from the command line, the learning rate from Config.
    config.num_train_epochs = args.num_train_epochs
    config.save_path = args.output_dir
    args.learning_rate = config.lr_F

    # Create a per-run output directory named after the start time.
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    config.save_folder = os.path.join(config.save_path, timestamp)
    os.makedirs(os.path.join(config.save_folder, 'ckpts'))
    init_logger(config.save_folder)
    logger = logging.getLogger(__name__)

    model_F = BartSystem(args).to(config.device)
    # The model is fitted by train() below, not by a trainer object.
    args.do_train = False

    if args.output_dir:
        checkpoint_path = _latest_checkpoint(args.output_dir)
        if checkpoint_path is None:
            logger.info(
                "No checkpoint found in %s; using the freshly initialised model.",
                args.output_dir)
        else:
            try:
                # load_from_checkpoint returns a new model instance, so the
                # result has to be kept; calling it and dropping the return
                # value leaves model_F untouched.
                model_F = BartSystem.load_from_checkpoint(
                    checkpoint_path).to(config.device)
                logger.info("Load checkpoint sucessfully!")
            except Exception:
                # Best effort: keep the fresh model, but record why loading
                # failed instead of hiding the error.
                logger.exception("Failed to load checkpoint!")

    train_iters = model_F.train_dataloader()
    dev_iters = model_F.val_dataloader()
    test_iters = model_F.test_dataloader()
    model_D = Discriminator(config, model_F.tokenizer).to(config.device)

    logger.info(config.discriminator_method)
    logger.info(model_D)

    train(config, model_F, model_D, train_iters, dev_iters, test_iters)
コード例 #2
0
        parser.add_argument(
            "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
        )

        parser.add_argument(
            "--tags", nargs='+', type=str, help="experiment tags for neptune.ai", default=['FT', 'last-layer']
        )


        return parser


# Script entry point: parse arguments, build a GLUETransformer, run
# generic_train on it and, when --do_predict is set, reload a saved checkpoint.
if __name__ == "__main__":
    # Generic arguments first, then the model-specific ones; both receive the
    # current working directory.
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = GLUETransformer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir is not provided, create ./results/<task>_<timestamp>
    # (relative to the current working directory) and use that instead.
    if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)

    model = GLUETransformer(args)
    trainer = generic_train(model, args)

    # Optionally reload the last checkpoint found in output_dir for prediction.
    # NOTE(review): the visible code only reloads the model; the prediction
    # step itself is presumably past the end of this snippet - confirm.
    if args.do_predict:
        # The matches are sorted as plain strings, so "epoch=10" orders before
        # "epoch=9"; checkpoints[-1] is the lexicographically last file, which
        # is not necessarily the highest epoch.
        # NOTE(review): an empty match list makes checkpoints[-1] raise
        # IndexError; the pattern contains no "**", so recursive=True does not
        # make the search descend into subdirectories.
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])