def main(config_file: str):
    """Compose a Hydra config from *config_file* plus CLI overrides, then train.

    Args:
        config_file: Path to a YAML config file; its parent directory becomes
            the Hydra config dir and its basename the config name.

    Side effects: parses ``sys.argv`` for ``key=value`` overrides, initializes
    Hydra, and runs a full pytorch-lightning fit (and optionally test).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("overrides", nargs="*", default=[])
    args = parser.parse_args()

    parsed = Path(config_file)
    initialize(config_dir=str(parsed.parent), strict=False)
    cfg = compose(parsed.name, overrides=args.overrides)
    logger.info(f"Training with the following config:\n{cfg.pretty()}")

    # we want to pass in dictionaries as OmegaConf doesn't play nicely with
    # loggers and doesn't allow non-native types
    module = NetworkLightningModule(OmegaConf.to_container(cfg, resolve=True))
    trainer = Trainer(**OmegaConf.to_container(cfg.pl_trainer, resolve=True))
    trainer.fit(
        module,
        train_dataloader=DataLoader(
            Dataset.from_data_dir(cfg.dataset.train.dir_path, transform=True),
            shuffle=True,
            batch_size=cfg.dataset.train.batch_size,
            num_workers=cfg.dataset.train.num_workers,
        ),
        val_dataloaders=DataLoader(
            Dataset.from_data_dir(cfg.dataset.val.dir_path),
            batch_size=cfg.dataset.val.batch_size,
            num_workers=cfg.dataset.val.num_workers,
        ),
    )
    if cfg.train.run_test:
        trainer.test(test_dataloaders=DataLoader(
            Dataset.from_data_dir(cfg.dataset.test.dir_path),
            # BUG FIX: was `cfg.datset.train.batch_size` — `datset` is a typo
            # (AttributeError at runtime) and the test loader should use the
            # test split's batch size, matching its dir_path/num_workers lines.
            batch_size=cfg.dataset.test.batch_size,
            num_workers=cfg.dataset.test.num_workers,
        ))
def main(cfg: DictConfig):
    """Train the network from a Hydra-provided config.

    Either resumes a module from ``cfg.train.resume_checkpoint`` or builds a
    fresh one from the network/train hyperparameters, then fits it on the
    train/val/test data directories and optionally runs the test loop.
    """
    logger.info(f'Training with the following config:\n{cfg.pretty()}')

    checkpoint = cfg.train.resume_checkpoint
    if checkpoint:
        # Hydra changes the working directory, so resolve the path it was
        # given relative to the original cwd.
        module = NetworkLightningModule.load_from_checkpoint(
            hydra.utils.to_absolute_path(checkpoint))
    else:
        hparams = {
            'board_size': cfg.network.board_size,
            'in_channels': cfg.network.in_channels,
            'residual_channels': cfg.network.residual_channels,
            'residual_layers': cfg.network.residual_layers,
            'learning_rate': cfg.train.learning_rate,
        }
        module = NetworkLightningModule(hparams)

    gpus = cfg.train.gpus
    # DDP only makes sense with more than one GPU.
    backend = 'ddp' if gpus is not None and gpus > 1 else None
    trainer = Trainer(
        max_epochs=cfg.train.max_epochs,
        gpus=gpus,
        early_stop_callback=cfg.train.early_stop,
        distributed_backend=backend,
        train_percent_check=cfg.train.train_percent,
        val_percent_check=cfg.train.val_percent,
    )

    batch_size = cfg.train.batch_size
    n_workers = cfg.train.n_data_workers

    # Training data is augmented (transform=True); val/test data is not.
    train_loader = DataLoader(
        Dataset.from_data_dir(
            hydra.utils.to_absolute_path(cfg.dataset.train_dir),
            transform=True),
        shuffle=True,
        batch_size=batch_size,
        num_workers=n_workers,
    )
    val_loader = DataLoader(
        Dataset.from_data_dir(
            hydra.utils.to_absolute_path(cfg.dataset.val_dir)),
        batch_size=batch_size,
        num_workers=n_workers,
    )
    test_loader = DataLoader(
        Dataset.from_data_dir(
            hydra.utils.to_absolute_path(cfg.dataset.test_dir)),
        batch_size=batch_size,
        num_workers=n_workers,
    )

    trainer.fit(
        module,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
        test_dataloaders=test_loader,
    )

    if cfg.train.run_test:
        trainer.test()
def test_train(tmp_path):
    """End-to-end smoke test: one fast_dev_run training pass on test-data."""
    hparams = {
        'board_size': 19,
        'in_channels': 18,
        'residual_channels': 1,
        'residual_layers': 1,
        'learning_rate': 0.05,
    }
    module = NetworkLightningModule(hparams)
    # fast_dev_run runs a single batch through train/val/test;
    # checkpoints go to the pytest-provided temp dir.
    trainer = Trainer(fast_dev_run=True, default_save_path=tmp_path)

    augmented = Dataset.from_data_dir('test-data', transform=True)
    plain = Dataset.from_data_dir('test-data')
    trainer.fit(
        module,
        train_dataloader=DataLoader(augmented, batch_size=2, shuffle=True),
        val_dataloaders=DataLoader(plain, batch_size=2),
        test_dataloaders=DataLoader(plain, batch_size=2),
    )
def test_go_dataset(filenames: List[str], length: int, transform: bool):
    """Check dataset length and the tensor contract of every sample.

    Previously this sampled a single random index, which made the test
    non-deterministic and left all other items unchecked; now every item
    is validated, matching the sibling test's coverage.
    """
    view = Dataset(filenames, transform)
    assert len(view) == length
    for idx in range(len(view)):
        planes, moves, outcome = view[idx]
        # 18 input feature planes over a 19x19 board.
        assert planes.size() == (18, 19, 19)
        # A move is one of the 19*19 board points plus pass.
        assert moves.item() in range(19 * 19 + 1)
        assert moves.dtype == torch.int64
        # Game outcome is a win/loss label; dtype asserted for consistency
        # with the sibling test of the same name.
        assert outcome.item() in (-1, 1)
        assert outcome.dtype == torch.float32
def test_go_dataset(filenames: List[str], length: int, transform: bool):
    """Validate length and the per-sample tensor contract of Dataset."""
    ds = Dataset(filenames, transform)
    assert len(ds) == length

    for sample_idx in range(length):
        planes, moves, outcome = ds[sample_idx]
        # 18 feature planes over a 19x19 board.
        assert planes.size() == (18, 19, 19)
        # Moves index one of 19*19 intersections plus a pass move.
        assert moves.item() in range(19 * 19 + 1)
        assert moves.dtype == torch.int64
        # Outcome is a float win/loss label.
        assert outcome.item() in (-1, 1)
        assert outcome.dtype == torch.float32
def test_train():
    """Smoke-test a single fast_dev_run fit over the bundled test-data."""
    model = NetworkLightningModule({
        'board_size': 19,
        'in_channels': 18,
        'residual_channels': 1,
        'residual_layers': 1,
    })
    data = Dataset.from_data_dir('test-data')

    def loader():
        # Fresh DataLoader per stage over the same tiny dataset.
        return DataLoader(data, batch_size=2)

    trainer = Trainer(fast_dev_run=True)
    trainer.fit(
        model,
        train_dataloader=loader(),
        val_dataloaders=loader(),
        test_dataloaders=loader(),
    )