def test_savefun_and_writer_exclusive(self):
    """Specifying both ``savefun`` and ``writer`` must raise ``TypeError``."""
    def never_called(*args, **kwargs):
        assert False

    simple_writer = extensions.snapshot_writers.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=never_called, writer=simple_writer)

    fake_trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(
            fake_trainer, savefun=never_called, writer=simple_writer)
def test_clean_up_tempdir(self):
    """Saving a snapshot leaves no 'tmpmyfile.dat*' temp files behind."""
    ext = extensions.snapshot_object(self.trainer, 'myfile.dat')
    ext(self.trainer)
    leftovers = [
        name for name in os.listdir('.')
        if name.startswith('tmpmyfile.dat')
    ]
    self.assertEqual(len(leftovers), 0)
def test_save_file(self):
    """snapshot_object with an explicit writer creates the target file."""
    writer = extensions.snapshot_writers.SimpleWriter()
    ext = extensions.snapshot_object(
        self.trainer, 'myfile.dat', writer=writer)
    ext(self.trainer)
    self.assertTrue(os.path.exists('myfile.dat'))
def test_on_error(self):
    """With snapshot_on_error=True the file is written only after a failure."""
    class _SentinelError(Exception):
        pass

    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def raise_sentinel(trainer):
        raise _SentinelError()

    self.trainer.extend(raise_sentinel)
    ext = extensions.snapshot_object(
        self.trainer, self.filename, snapshot_on_error=True)
    self.trainer.extend(ext)

    # No snapshot exists before the failing run...
    self.assertFalse(os.path.exists(self.filename))
    with self.assertRaises(_SentinelError):
        self.trainer.run()
    # ...but one has been written once the error fired.
    self.assertTrue(os.path.exists(self.filename))
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a fully wired Trainer from a raw config dict.

    The validated config is dumped to ``output/config.yaml``; logs and
    snapshots are written under *output*.
    """
    # config: parse, attach git info, validate, persist a copy
    config = Config.from_dict(config_dict)
    config.add_git_info()
    assert_config(config)

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )
    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True, for_eval=False)
    test_iter = _create_iterator(datasets["test"], for_train=False, for_eval=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False, for_eval=True)

    # optional validation split
    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"], for_train=False, for_eval=True)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer: 'name' selects the class, remaining keys become kwargs
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    # optional LR schedule; step_shift overrides linear_shift when both are set
    shift_ext = None
    if config.train.linear_shift is not None:
        shift_ext = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_ext = extensions.StepShift(**config.train.step_shift)
    if shift_ext is not None:
        trainer.extend(shift_ext)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    # generation-based evaluation (e.g. MCD) on the eval/valid splits
    generator = Generator(
        config=config,
        predictor=predictor,
        use_gpu=True,
        max_batch_size=(config.train.eval_batchsize
                        if config.train.eval_batchsize is not None
                        else config.train.batchsize),
        use_fast_inference=False,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.time_length_evaluate,
        local_padding_time_length=config.dataset.local_padding_time_length_evaluate,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)
    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter, generate_evaluator, device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    # number of best-predictor snapshots to retain
    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    # save the predictor whenever eval MCD reaches a new low
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorboardReport(writer=SummaryWriter(Path(output))),
                   trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # resumable trainer snapshot; autoload restores the latest one on restart
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a fully wired Trainer from a raw config dict.

    The config is dumped to ``output/config.yaml``; logs and snapshots are
    written under *output*.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    predictor = create_predictor(config.network)
    model = Model(
        model_config=config.model,
        predictor=predictor,
        local_padding_length=config.dataset.local_padding_length,
    )
    init_weights(model, "orthogonal")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )
    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True, for_eval=False)
    test_iter = _create_iterator(datasets["test"], for_train=False, for_eval=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False, for_eval=True)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer: 'name' selects the class, remaining keys become kwargs
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater: when use_amp is unset, fall back to AMP availability
    use_amp = config.train.use_amp if config.train.use_amp is not None else amp_exist
    if use_amp:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    writer = SummaryWriter(Path(output))
    # NOTE: a writer.add_graph(...) call on a sample batch used to live here
    # but was disabled because it errored (at randint); intentionally omitted.

    if config.train.multistep_shift is not None:
        trainer.extend(
            extensions.MultistepShift(**config.train.multistep_shift))
    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    # generation-based evaluation (MCD) on the eval split
    generator = Generator(
        config=config,
        noise_schedule_config=NoiseScheduleModelConfig(start=1e-4,
                                                       stop=0.05,
                                                       num=50),
        predictor=predictor,
        sampling_rate=config.dataset.sampling_rate,
        use_gpu=True,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        local_padding_time_second=config.dataset.evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    # number of best-predictor snapshots to retain
    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    # save the predictor whenever eval MCD reaches a new low
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    # FIX: arguments were swapped — the original passed the TensorboardReport
    # extension as `trigger=` while re-registering the snapshot ext. Register
    # the report extension on the logging trigger, matching the sibling
    # create_trainer variants.
    trainer.extend(TensorboardReport(writer=writer), trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # resumable trainer snapshot; autoload restores the latest one on restart
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a Trainer for the networks described by *config_dict*.

    The config is dumped to ``output/config.yaml``; logs and snapshots are
    written under *output*.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda") if config.train.use_gpu else torch.device(
        "cpu")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )
    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    # NOTE(review): eval_iter is built but never attached to an evaluator below.
    eval_iter = _create_iterator(datasets["eval"], for_train=False)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = make_optimizer(config_dict=config.train.optimizer, model=model)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    # number of best-predictor snapshots to retain
    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        networks.predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    # save the predictor whenever test loss reaches a new low
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # resumable trainer snapshot; autoload restores the latest one on restart
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a Trainer from a raw config dict.

    The config is dumped to ``output/config.yaml``; logs and snapshots are
    written under *output*.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(model_config=config.model, predictor=predictor)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )
    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer: 'name' selects the class, remaining keys become kwargs
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater; custom converter for list-structured batches
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        converter=list_concat,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    # NOTE(review): trigger_eval is driven by snapshot_iteration in this variant.
    trigger_eval = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    writer = SummaryWriter(Path(output))

    # log the model graph once using the first training sample
    sample_data = datasets["train"][0]
    writer.add_graph(
        model,
        input_to_model=(
            [sample_data["f0"].to(device)],
            [sample_data["phoneme"].to(device)],
            [sample_data["phoneme_list"].to(device)],
            ([sample_data["speaker_id"].to(device)]
             if predictor.with_speaker else None),
        ),
    )

    ext = extensions.Evaluator(test_iter,
                               model,
                               converter=list_concat,
                               device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    # number of best-predictor snapshots to retain
    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.snapshot_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    # save the predictor whenever test loss reaches a new low
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=writer)
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # resumable trainer snapshot; autoload restores the latest one on restart
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a Trainer (optionally adversarial) from a raw config dict.

    A discriminator model/optimizer is wired in only when
    ``config.model.discriminator_input_type`` is set.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    # NOTE(review): no exist_ok here — this raises if *output* already exists.
    output.mkdir(parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    networks = create_network(config.network)
    model = Model(
        model_config=config.model,
        networks=networks,
        local_padding_length=config.dataset.local_padding_length,
    )
    model.to(device)

    # optional discriminator
    if config.model.discriminator_input_type is not None:
        discriminator_model = DiscriminatorModel(
            model_config=config.model,
            networks=networks,
            local_padding_length=config.dataset.local_padding_length,
        )
        discriminator_model.to(device)
    else:
        discriminator_model = None

    # dataset
    def _create_iterator(dataset, for_train: bool):
        # repeat/shuffle only for the training split
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=300,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    test_eval_iter = _create_iterator(datasets["test_eval"], for_train=False)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = create_optimizer(config.train.optimizer, model)
    if config.train.discriminator_optimizer is not None:
        discriminator_optimizer = create_optimizer(
            config.train.discriminator_optimizer, discriminator_model)
    else:
        discriminator_optimizer = None

    # updater
    updater = Updater(
        iterator=train_iter,
        optimizer=optimizer,
        discriminator_model=discriminator_model,
        model=model,
        discriminator_optimizer=discriminator_optimizer,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)
    if discriminator_model is not None:
        # NOTE(review): registered under the same name "test" as the evaluator
        # above — confirm the framework tolerates duplicate extension names.
        ext = extensions.Evaluator(test_iter,
                                   discriminator_model,
                                   device=device)
        trainer.extend(ext, name="test", trigger=trigger_log)

    # generation-based evaluation on the test_eval split
    generator = Generator(config=config,
                          predictor=networks.predictor,
                          use_gpu=True)
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.evaluate_time_second,
        local_padding_time_length=config.dataset.evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(test_eval_iter,
                               generate_evaluator,
                               device=device)
    trainer.extend(ext, name="eval", trigger=trigger_snapshot)

    ext = extensions.snapshot_object(
        networks.predictor, filename="predictor_{.updater.iteration}.pth")
    trainer.extend(ext, trigger=trigger_snapshot)
    # ext = extensions.snapshot_object(
    #     trainer, filename="trainer_{.updater.iteration}.pth"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)
    # if networks.discriminator is not None:
    #     ext = extensions.snapshot_object(
    #         networks.discriminator, filename="discriminator_{.updater.iteration}.pth"
    #     )
    #     trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))
    if discriminator_model is not None:
        (output / "discriminator_struct.txt").write_text(
            repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
def train_phase(predictor, train, valid, args):
    """Train a segmentation classifier around *predictor*.

    Sets up iterators, optimizer, early stopping, visualization/snapshot
    extensions, then runs the trainer until early stopping or
    ``args.iteration`` iterations.
    """
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    class_weight = None  # NOTE: please set if you have..
    lossfun = partial(softmax_cross_entropy,
                      normalize=False,
                      class_weight=class_weight)

    device = torch.device(args.gpu)
    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer (negative decay is clamped to 0)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup a trainer with early stopping on validation loss
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    # FIX: the original built a throwaway Trainer here that was immediately
    # replaced by the early-stopping Trainer below; the dead construction
    # has been removed.
    frequency = max(args.iteration // 20,
                    1) if args.frequency == -1 else max(1, args.frequency)
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer: argmax over class axis for predictions
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }
    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}
    clims = {'x': 'minmax', 'y': None, 't': None}
    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps,
                                 clims=clims)

    # setup a validator that also renders example predictions
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy'
    ]
    trainer.extend(LogReport(keys=log_keys))

    # setup log ploter: one figure per metric family (loss / accuracy)
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
def add_snapshot_object(target, name):
    """Register a snapshot of *target* saved as ``<name>_<iteration>.pth``."""
    filename = "{}_{{.updater.iteration}}.pth".format(name)
    snapshot_ext = extensions.snapshot_object(target, filename=filename)
    trainer.extend(snapshot_ext, trigger=trigger_snapshot)
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a Trainer from a raw config dict.

    The config is dumped to ``output/config.yaml``; logs and snapshots are
    written under *output*.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    # NOTE(review): no exist_ok here — this raises if *output* already exists.
    output.mkdir(parents=True)
    with (output / 'config.yaml').open(mode='w') as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)

    device = torch.device('cuda')
    model.to(device)

    # dataset
    def _create_iterator(dataset, for_train: bool):
        # repeat/shuffle only for the training split
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=60,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets['train'], for_train=True)
    test_iter = _create_iterator(datasets['test'], for_train=False)
    train_test_iter = _create_iterator(datasets['train_test'], for_train=False)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter('error', MultiprocessIterator.TimeoutWarning)

    # optimizer: 'name' selects the class, remaining keys become kwargs
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop('name').lower()
    if n == 'adam':
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == 'sgd':
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, 'iteration')
    trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
    trigger_stop = (
        config.train.stop_iteration,
        'iteration') if config.train.stop_iteration is not None else None
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name='test', trigger=trigger_log)
    # evaluator over the 'train_test' split, reported under the 'train' prefix
    ext = extensions.Evaluator(train_test_iter, model, device=device)
    trainer.extend(ext, name='train', trigger=trigger_log)

    # periodic predictor snapshots (npz format in this variant)
    ext = extensions.snapshot_object(
        networks.predictor, filename='predictor_{.updater.iteration}.npz')
    trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(extensions.PrintReport(
        ['iteration', 'main/loss', 'test/main/loss']), trigger=trigger_log)
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / 'struct.txt').write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    """Build a Trainer from a raw config dict.

    Each network component is snapshotted on a new best phoneme accuracy;
    the trainer itself is snapshotted for resuming.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(config=config.model, networks=networks)
    init_orthogonal(model)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )
    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True, for_eval=False)
    test_iter = _create_iterator(datasets["test"], for_train=False, for_eval=False)

    # optional validation split
    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"], for_train=False, for_eval=True)

    # escalate iterator timeouts from warnings to hard errors
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer: 'name' selects the class, remaining keys become kwargs
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    elif n == "ranger":
        optimizer = Ranger(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = ((config.train.stop_iteration, "iteration")
                    if config.train.stop_iteration is not None else None)
    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)
    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter, model, device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    # number of best-model snapshots to retain
    if config.train.stop_iteration is not None:
        saving_model_num = int(config.train.stop_iteration /
                               config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    # snapshot every Networks field on a new best phoneme accuracy
    # (valid split preferred when available, otherwise test)
    for field in dataclasses.fields(Networks):
        ext = extensions.snapshot_object(
            getattr(networks, field.name),
            filename=field.name + "_{.updater.iteration}.pth",
            n_retains=saving_model_num,
        )
        trainer.extend(
            ext,
            trigger=HighValueTrigger(
                ("valid/main/phoneme_accuracy"
                 if valid_iter is not None else "test/main/phoneme_accuracy"),
                trigger=trigger_eval,
            ),
        )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # resumable trainer snapshot; autoload restores the latest one on restart
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
def train_phase(generator, train, valid, args):
    """Train *generator* adversarially with a DCGAN-style updater.

    Builds generator/discriminator optimizers, early stopping on validation
    loss, LR decay, visualization and snapshot extensions, then runs training.
    """
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model: generator wrapped as a regressor + separate discriminator
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)
    discriminator = build_discriminator()
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)
    model.to(device)
    discriminator.to(device)

    # setup an optimizer per player (negative decay is clamped to 0)
    optimizer_G = torch.optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    # snapshot/validation frequency; -1 derives it from the iteration budget
    frequency = max(args.iteration//80, 1) if args.frequency == -1 else max(1, args.frequency)
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr: ramp from args.lr to 0 between iteration//2 and iteration
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration//2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration//2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer; all images shown in the [-1, 1] range
    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}
    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None,
                                 clims=clims)

    # setup a validator that also renders example outputs
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    # periodic full-trainer and per-model snapshots
    trainer.extend(extensions.snapshot(filename='snapshot_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(generator, 'generator_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis', 'validation/main/accuracy']
    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup log ploter: one figure per metric family (loss / accuracy)
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [key for key in log_keys
                         if key.split('/')[-1].startswith(plot_key)]
            trainer.extend(
                extensions.PlotReport(plot_keys, 'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration'))
            )

    trainer.extend(PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=1))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()