def generate(
    model_dir: Path,
    model_iteration: Optional[int],
    model_config: Optional[Path],
    output_dir: Path,
    use_gpu: bool,
):
    # fall back to the config stored next to the trained model
    if model_config is None:
        model_config = model_dir / "config.yaml"

    output_dir.mkdir(exist_ok=True)
    save_arguments(output_dir / "arguments.yaml", generate, locals())

    config = Config.from_dict(yaml.safe_load(model_config.open()))

    # pick the predictor snapshot for the requested iteration (or the latest one)
    model_path = _get_predictor_model_path(
        model_dir=model_dir,
        iteration=model_iteration,
    )
    generator = Generator(
        config=config,
        predictor=model_path,
        use_gpu=use_gpu,
    )

    # run the generator over the test split
    dataset = create_dataset(config.dataset)["test"]
    for data in tqdm(dataset, desc="generate"):
        target = data["target"]
        output = generator.generate(data["feature"])
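# A minimal command-line wrapper for `generate`, assuming flag names that simply
# mirror the keyword arguments above; the project's real entry point may differ.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=Path, required=True)
    parser.add_argument("--model_iteration", type=int, default=None)
    parser.add_argument("--model_config", type=Path, default=None)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--use_gpu", action="store_true")
    generate(**vars(parser.parse_args()))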
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda") if config.train.use_gpu else torch.device("cpu")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = make_optimizer(config_dict=config.train.optimizer, model=model)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.eval_iteration / 10
        )
    else:
        saving_model_num = 10

    ext = extensions.snapshot_object(
        networks.predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
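# Sketch of a training entry point built on `create_trainer`, assuming a
# chainer/pytorch-trainer style Trainer whose `run()` method drives the loop.
# The YAML loading shown here is illustrative and not taken from the project.
def train(config_yaml_path: Path, output: Path):
    with config_yaml_path.open() as f:
        config_dict = yaml.safe_load(f)

    trainer = create_trainer(config_dict=config_dict, output=output)
    trainer.run()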
def test_equal_base_config_and_reconstructed(train_config_path: Path):
    with train_config_path.open() as f:
        d = yaml.load(f, SafeLoader)

    base = Config.from_dict(d)
    base_re = Config.from_dict(base.to_dict())
    assert base == base_re
def test_to_dict(train_config_path: Path):
    with train_config_path.open() as f:
        d = yaml.load(f, SafeLoader)

    Config.from_dict(d).to_dict()
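# Both tests above depend on a `train_config_path` fixture that is not shown
# here. A minimal conftest.py sketch, assuming the sample config sits beside
# the tests as "data/train_config.yaml" (the actual location is a guess):
import pytest


@pytest.fixture
def train_config_path() -> Path:
    return Path(__file__).parent / "data" / "train_config.yaml"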