import json

import _jsonnet
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

# Dataset, TypeReconstructionModel, and util are project-local modules,
# assumed to be imported elsewhere in this repo.


def train(args):
    # load the base config and merge optional JSON overrides from the CLI
    config = json.loads(_jsonnet.evaluate_file(args["CONFIG_FILE"]))
    if args["--extra-config"]:
        extra_config = args["--extra-config"]
        extra_config = json.loads(extra_config)
        config = util.update(config, extra_config)

    # dataloaders
    batch_size = config["train"]["batch_size"]
    train_set = Dataset(
        config["data"]["train_file"],
        config["data"],
        percent=float(args["--percent"]),
    )
    dev_set = Dataset(config["data"]["dev_file"], config["data"])
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        collate_fn=Dataset.collate_fn,
        num_workers=16,
        pin_memory=True,
    )
    val_loader = DataLoader(
        dev_set,
        batch_size=batch_size,
        collate_fn=Dataset.collate_fn,
        num_workers=8,
        pin_memory=True,
    )

    # model
    model = TypeReconstructionModel(config)

    # wandb experiment tracking
    wandb_logger = WandbLogger(name=args["--expname"], project="dire", log_model=True)
    wandb_logger.log_hyperparams(config)

    # prefer the eval checkpoint if given, otherwise fall back to --resume
    resume_from_checkpoint = (
        args["--eval-ckpt"] if args["--eval-ckpt"] else args["--resume"]
    )
    if resume_from_checkpoint == "":
        resume_from_checkpoint = None

    trainer = pl.Trainer(
        max_epochs=config["train"]["max_epoch"],
        logger=wandb_logger,
        gpus=1 if args["--cuda"] else None,
        auto_select_gpus=True,
        gradient_clip_val=1,
        callbacks=[
            EarlyStopping(
                monitor="val_retype_acc"
                if config["data"]["retype"]
                else "val_rename_acc",
                mode="max",
                patience=config["train"]["patience"],
            )
        ],
        check_val_every_n_epoch=config["train"]["check_val_every_n_epoch"],
        progress_bar_refresh_rate=10,
        accumulate_grad_batches=config["train"]["grad_accum_step"],
        resume_from_checkpoint=resume_from_checkpoint,
    )

    if args["--eval-ckpt"]:
        # evaluation-only mode
        # HACK: necessary to make pl test work for IterableDataset
        Dataset.__len__ = lambda self: 1000000
        test_set = Dataset(config["data"]["test_file"], config["data"])
        test_loader = DataLoader(
            test_set,
            batch_size=config["test"]["batch_size"],
            collate_fn=Dataset.collate_fn,
            num_workers=8,
            pin_memory=True,
        )
        trainer.test(model, test_dataloaders=test_loader, ckpt_path=args["--eval-ckpt"])
    else:
        trainer.fit(model, train_loader, val_loader)
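# Usage sketch (hypothetical paths and values): `train` expects a docopt-style
# dictionary mapping positionals and raw flags to strings/booleans, e.g.
#
#   train({
#       "CONFIG_FILE": "config/xfmr.jsonnet",  # hypothetical config path
#       "--extra-config": "",     # optional JSON string merged over the config
#       "--percent": "1.0",       # fraction of the training set to load
#       "--expname": "debug-run", # wandb run name
#       "--cuda": False,
#       "--eval-ckpt": "",        # non-empty => run trainer.test() on this checkpoint
#       "--resume": "",           # checkpoint path to resume training from
#   })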
@classmethod
def build(cls, config):
    # merge the user-supplied config over the class defaults, then
    # instantiate the model with the merged parameters
    params = util.update(cls.default_params(), config)
    model = cls(params)
    return model
@classmethod
def build(cls, config):
    # merge the user config over XfmrSequentialEncoder's defaults; note that
    # default_params() is read from XfmrSequentialEncoder directly, so a
    # subclass overriding default_params() is not picked up here
    params = util.update(XfmrSequentialEncoder.default_params(), config)
    return cls(params)
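# Usage sketch for the build() factories above (the config key below is
# hypothetical):
#
#   encoder = XfmrSequentialEncoder.build({"hidden_size": 512})
#
# util.update merges the passed config over default_params(), so any key
# supplied by the caller overrides the corresponding default.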