Code Example #1
File: exp.py  Project: CMUSTRUDEL/DIRTY
def train(args):
    config = json.loads(_jsonnet.evaluate_file(args["CONFIG_FILE"]))

    if args["--extra-config"]:
        extra_config = args["--extra-config"]
        extra_config = json.loads(extra_config)
        config = util.update(config, extra_config)

    # dataloaders
    batch_size = config["train"]["batch_size"]
    train_set = Dataset(
        config["data"]["train_file"],
        config["data"],
        percent=float(args["--percent"]),
    )
    dev_set = Dataset(config["data"]["dev_file"], config["data"])
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        collate_fn=Dataset.collate_fn,
        num_workers=16,
        pin_memory=True,
    )
    val_loader = DataLoader(
        dev_set,
        batch_size=batch_size,
        collate_fn=Dataset.collate_fn,
        num_workers=8,
        pin_memory=True,
    )

    # model
    model = TypeReconstructionModel(config)

    wandb_logger = WandbLogger(name=args["--expname"],
                               project="dire",
                               log_model=True)
    wandb_logger.log_hyperparams(config)
    resume_from_checkpoint = (args["--eval-ckpt"]
                              if args["--eval-ckpt"] else args["--resume"])
    if resume_from_checkpoint == "":
        resume_from_checkpoint = None
    trainer = pl.Trainer(
        max_epochs=config["train"]["max_epoch"],
        logger=wandb_logger,
        gpus=1 if args["--cuda"] else None,
        auto_select_gpus=True,
        gradient_clip_val=1,
        callbacks=[
            EarlyStopping(
                monitor="val_retype_acc"
                if config["data"]["retype"] else "val_rename_acc",
                mode="max",
                patience=config["train"]["patience"],
            )
        ],
        check_val_every_n_epoch=config["train"]["check_val_every_n_epoch"],
        progress_bar_refresh_rate=10,
        accumulate_grad_batches=config["train"]["grad_accum_step"],
        resume_from_checkpoint=resume_from_checkpoint,
    )
    if args["--eval-ckpt"]:
        # HACK: necessary to make pl test work for IterableDataset
        Dataset.__len__ = lambda self: 1000000
        test_set = Dataset(config["data"]["test_file"], config["data"])
        test_loader = DataLoader(
            test_set,
            batch_size=config["test"]["batch_size"],
            collate_fn=Dataset.collate_fn,
            num_workers=8,
            pin_memory=True,
        )
        trainer.test(model,
                     test_dataloaders=test_loader,
                     ckpt_path=args["--eval-ckpt"])
    else:
        trainer.fit(model, train_loader, val_loader)
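
The args mapping is indexed with docopt-style keys ("CONFIG_FILE", "--extra-config", "--cuda", ...), so train is presumably driven by a docopt-parsed command line. A minimal invocation sketch with hypothetical values (the config path and JSON override below are illustrative, not taken from the project):

args = {
    "CONFIG_FILE": "config.jsonnet",  # hypothetical path to a jsonnet config
    "--extra-config": '{"train": {"batch_size": 32}}',  # JSON merged over the config
    "--percent": "1.0",    # fraction of the training set to load
    "--expname": "debug-run",  # W&B run name
    "--cuda": True,        # train on a single GPU
    "--eval-ckpt": "",     # empty string: train instead of evaluating
    "--resume": "",        # empty string: start from scratch
}
train(args)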
Code Example #2
@classmethod
def build(cls, config):
    # Overlay the user-supplied config on the class's default hyperparameters.
    params = util.update(cls.default_params(), config)
    model = cls(params)
    return model
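
Both build factories rely on util.update to merge a user config into a set of defaults. Its implementation is not shown on this page; below is a hedged sketch of the recursive dict merge such a helper typically performs, not the project's actual code:

def update(defaults, overrides):
    # Sketch only: recursively merge `overrides` into a copy of `defaults`,
    # with values from `overrides` taking precedence on conflicts.
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = update(merged[key], value)
        else:
            merged[key] = value
    return merged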
Code Example #3
@classmethod
def build(cls, config):
    # Overlay the config on XfmrSequentialEncoder's default hyperparameters.
    params = util.update(XfmrSequentialEncoder.default_params(), config)
    return cls(params)
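
Note that, unlike Example #2, this variant hardcodes XfmrSequentialEncoder.default_params() instead of calling cls.default_params(), so a subclass inheriting build would still receive the encoder's defaults. A hypothetical call (the config key is illustrative, not the project's actual schema):

encoder = XfmrSequentialEncoder.build({"hidden_size": 256})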