def train_phase(predictor, train, valid, args):
    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)
    model = Classifier(predictor)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))
    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))
    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
        ]))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
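# Usage sketch for train_phase above: an argparse.Namespace mirroring the
# attributes the function actually reads. The predictor and dataset objects
# are hypothetical stand-ins, not defined in this snippet.
import argparse

args = argparse.Namespace(
    batchsize=100,   # minibatch size for both iterators
    gpu='cpu',       # anything torch.device() accepts, e.g. 'cuda:0'
    decay=1e-4,      # weight_decay passed to Adam
    epoch=20,        # training length in epochs
    frequency=-1,    # -1 means snapshot once per run
    out='result',    # output directory for logs and snapshots
    plot=True,       # enable PlotReport when available
    resume='',       # trainer snapshot path, or '' to start fresh
)
# train_phase(predictor, train_dataset, valid_dataset, args)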
def _setup(self, stream=None, delete_flush=False):
    self.logreport = mock.MagicMock(spec=extensions.LogReport(
        ['epoch'], trigger=(1, 'iteration'), log_name=None))
    if stream is None:
        self.stream = mock.MagicMock()
        if delete_flush:
            del self.stream.flush
    else:
        self.stream = stream
    self.report = extensions.PrintReport(['epoch'],
                                         log_report=self.logreport,
                                         out=self.stream)
    self.trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=(1, 'iteration'))
    self.trainer.extend(self.logreport)
    self.trainer.extend(self.report)
    self.logreport.log = [{'epoch': 0}]
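# Hedged sketch of how the fixture above is typically exercised: run the
# mocked trainer for its single iteration and check that PrintReport wrote
# to the mocked stream. The test body is illustrative, not from the source.
def test_print(self):
    self._setup()
    self.trainer.run()
    self.stream.write.assert_called()  # PrintReport emitted the 'epoch' column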
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()
    assert_config(config)

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(
        loss_config=config.loss,
        predictor=predictor,
        local_padding_size=config.dataset.local_padding_size,
    )
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True, for_eval=False)
    test_iter = _create_iterator(datasets["test"], for_train=False, for_eval=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False, for_eval=True)

    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"], for_train=False, for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    shift_ext = None
    if config.train.linear_shift is not None:
        shift_ext = extensions.LinearShift(**config.train.linear_shift)
    if config.train.step_shift is not None:
        shift_ext = extensions.StepShift(**config.train.step_shift)
    if shift_ext is not None:
        trainer.extend(shift_ext)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        predictor=predictor,
        use_gpu=True,
        max_batch_size=(config.train.eval_batchsize
                        if config.train.eval_batchsize is not None
                        else config.train.batchsize),
        use_fast_inference=False,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.time_length_evaluate,
        local_padding_time_length=config.dataset.local_padding_time_length_evaluate,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter, generate_evaluator, device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorboardReport(writer=SummaryWriter(Path(output))),
                   trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
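# Minimal driver shared by the create_trainer variants in this file: load a
# config dict from YAML, build the trainer, and run it to completion. The
# config file name is hypothetical.
from pathlib import Path

import yaml

config_dict = yaml.safe_load(Path("config.yaml").read_text())
trainer = create_trainer(config_dict=config_dict, output=Path("output"))
trainer.run()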
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    predictor = create_predictor(config.network)
    model = Model(
        model_config=config.model,
        predictor=predictor,
        local_padding_length=config.dataset.local_padding_length,
    )
    init_weights(model, "orthogonal")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batchsize,
        eval_batch_size=config.train.eval_batchsize,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True, for_eval=False)
    test_iter = _create_iterator(datasets["test"], for_train=False, for_eval=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False, for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    use_amp = config.train.use_amp if config.train.use_amp is not None else amp_exist
    if use_amp:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    # # error at randint
    # sample_data = datasets["train"][0]
    # writer.add_graph(
    #     model,
    #     input_to_model=(
    #         sample_data["wave"].unsqueeze(0).to(device),
    #         sample_data["local"].unsqueeze(0).to(device),
    #         sample_data["speaker_id"].unsqueeze(0).to(device)
    #         if predictor.with_speaker
    #         else None,
    #     ),
    # )

    if config.train.multistep_shift is not None:
        trainer.extend(extensions.MultistepShift(**config.train.multistep_shift))
    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(
        config=config,
        noise_schedule_config=NoiseScheduleModelConfig(start=1e-4, stop=0.05, num=50),
        predictor=predictor,
        sampling_rate=config.dataset.sampling_rate,
        use_gpu=True,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        local_padding_time_second=config.dataset.evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/mcd", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    trainer.extend(TensorboardReport(writer=writer), trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda") if config.train.use_gpu else torch.device("cpu")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    eval_iter = _create_iterator(datasets["eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = make_optimizer(config_dict=config.train.optimizer, model=model)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        networks.predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_snapshot)

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(model_config=config.model, predictor=predictor)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        converter=list_concat,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    sample_data = datasets["train"][0]
    writer.add_graph(
        model,
        input_to_model=(
            [sample_data["f0"].to(device)],
            [sample_data["phoneme"].to(device)],
            [sample_data["phoneme_list"].to(device)],
            ([sample_data["speaker_id"].to(device)]
             if predictor.with_speaker else None),
        ),
    )

    ext = extensions.Evaluator(test_iter, model,
                               converter=list_concat, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.snapshot_iteration / 10)
    else:
        saving_model_num = 10
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("test/main/loss", trigger=trigger_eval),
    )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=writer)
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
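# Several create_trainer variants above gate the predictor snapshot on
# LowValueTrigger, whose implementation is not shown in this file. A
# plausible reading, sketched as a guess rather than the library's actual
# code: fire only when the watched metric reaches a new minimum at each
# evaluation interval.
class LowValueTriggerSketch:
    def __init__(self, key, trigger):
        self.key = key                       # e.g. "test/main/loss"
        self.interval, self.unit = trigger   # e.g. (1000, "iteration")
        self.best = float("inf")

    def __call__(self, trainer):
        # Assumes the trainer exposes the current iteration and an
        # observation dict holding the most recently reported metrics.
        if trainer.updater.iteration % self.interval != 0:
            return False
        value = trainer.observation.get(self.key)
        if value is not None and float(value) < self.best:
            self.best = float(value)
            return True
        return False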
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    networks = create_network(config.network)
    model = Model(
        model_config=config.model,
        networks=networks,
        local_padding_length=config.dataset.local_padding_length,
    )
    model.to(device)

    if config.model.discriminator_input_type is not None:
        discriminator_model = DiscriminatorModel(
            model_config=config.model,
            networks=networks,
            local_padding_length=config.dataset.local_padding_length,
        )
        discriminator_model.to(device)
    else:
        discriminator_model = None

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=300,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    test_eval_iter = _create_iterator(datasets["test_eval"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    optimizer = create_optimizer(config.train.optimizer, model)
    if config.train.discriminator_optimizer is not None:
        discriminator_optimizer = create_optimizer(
            config.train.discriminator_optimizer, discriminator_model)
    else:
        discriminator_optimizer = None

    # updater
    updater = Updater(
        iterator=train_iter,
        optimizer=optimizer,
        discriminator_model=discriminator_model,
        model=model,
        discriminator_optimizer=discriminator_optimizer,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    if config.train.step_shift is not None:
        trainer.extend(extensions.StepShift(**config.train.step_shift))

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)
    if discriminator_model is not None:
        ext = extensions.Evaluator(test_iter, discriminator_model, device=device)
        trainer.extend(ext, name="test", trigger=trigger_log)

    generator = Generator(config=config, predictor=networks.predictor, use_gpu=True)
    generate_evaluator = GenerateEvaluator(
        generator=generator,
        time_length=config.dataset.evaluate_time_second,
        local_padding_time_length=config.dataset.evaluate_local_padding_time_second,
    )
    ext = extensions.Evaluator(test_eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_snapshot)

    ext = extensions.snapshot_object(
        networks.predictor, filename="predictor_{.updater.iteration}.pth")
    trainer.extend(ext, trigger=trigger_snapshot)
    # ext = extensions.snapshot_object(
    #     trainer, filename="trainer_{.updater.iteration}.pth"
    # )
    # trainer.extend(ext, trigger=trigger_snapshot)
    # if networks.discriminator is not None:
    #     ext = extensions.snapshot_object(
    #         networks.discriminator, filename="discriminator_{.updater.iteration}.pth"
    #     )
    #     trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))
    if discriminator_model is not None:
        (output / "discriminator_struct.txt").write_text(repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
def train_phase(predictor, train, valid, args):
    # visualize
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(), y=train.y.ravel(),
                         color='blue', s=55, alpha=0.3)
    ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Ground-truth', 'Observation'])
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=False)

    # setup a model
    device = torch.device(args.gpu)
    lossfun = noised_mean_squared_error
    accfun = lambda y, t: F.l1_loss(y[0], t)
    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))
    # trainer.extend(DumpGraph(model, 'main/loss'))

    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))
    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch', file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
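# noised_mean_squared_error is referenced above but not defined in this
# snippet. Given that the predictor exposes a learnable noise scale
# (reported as main/predictor/sigma) and returns its prediction as y[0],
# one plausible form is a Gaussian negative log-likelihood; this is an
# assumption, not the source's definition.
import torch

def noised_mean_squared_error_sketch(y, t):
    pred, sigma = y[0], y[1]  # assumed output layout: (prediction, noise scale)
    return torch.mean((pred - t) ** 2 / (2 * sigma ** 2) + torch.log(sigma))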
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
    dataset_dir: Optional[Path],
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    device = torch.device("cuda")
    networks = create_network(config.network)
    generator_model = GeneratorModel(model_config=config.model,
                                     networks=networks).to(device)
    moving_generator_model = deepcopy(generator_model).to(device)
    discriminator_model = DiscriminatorModel(model_config=config.model,
                                             networks=networks).to(device)

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=60 * 15,
        )

    datasets = create_dataset(config.dataset, dataset_dir=dataset_dir)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    style_transfer_optimizer = create_optimizer(
        config=config.train.style_transfer_optimizer,
        model=networks.style_transfer)
    mapping_network_optimizer = create_optimizer(
        config=config.train.mapping_network_optimizer,
        model=networks.mapping_network)
    style_encoder_optimizer = create_optimizer(
        config=config.train.style_encoder_optimizer,
        model=networks.style_encoder)
    discriminator_optimizer = create_optimizer(
        config=config.train.discriminator_optimizer,
        model=networks.discriminator)

    # updater
    updater = Updater(
        iterator=train_iter,
        optimizer=dict(
            style_transfer=style_transfer_optimizer,
            mapping_network=mapping_network_optimizer,
            style_encoder=style_encoder_optimizer,
            discriminator=discriminator_optimizer,
        ),
        model=dict(
            generator=generator_model,
            discriminator=discriminator_model,
            moving_generator=moving_generator_model,
        ),
        moving_average_rate=config.train.moving_average_rate,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_snapshot = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    def eval_func(**kwargs):
        generator_model.forward_with_latent(**kwargs)
        generator_model.forward_with_reference(**kwargs)
        discriminator_model.forward_with_latent(**kwargs)
        discriminator_model.forward_with_reference(**kwargs)
        moving_generator_model.forward_with_latent(**kwargs)
        moving_generator_model.forward_with_reference(**kwargs)

    ext = extensions.Evaluator(
        test_iter,
        target=dict(
            generator=generator_model,
            discriminator=discriminator_model,
            moving_generator=moving_generator_model,
        ),
        eval_func=eval_func,
        device=device,
    )
    trainer.extend(ext, name="test", trigger=trigger_log)

    def add_snapshot_object(target, name):
        ext = extensions.snapshot_object(
            target, filename=name + "_{.updater.iteration}.pth")
        trainer.extend(ext, trigger=trigger_snapshot)

    add_snapshot_object(networks.style_transfer, "style_transfer")
    add_snapshot_object(networks.mapping_network, "mapping_network")
    add_snapshot_object(networks.style_encoder, "style_encoder")

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport([
            "iteration", "generator/latent/loss", "test/generator/latent/loss"
        ]),
        trigger=trigger_log,
    )

    if config.train.model_config_linear_shift is not None:
        ext = ObjectLinearShift(target=config.model,
                                **config.train.model_config_linear_shift)
        trainer.extend(ext, trigger=(1, "iteration"))

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / "generator_struct.txt").write_text(repr(generator_model))
    (output / "discriminator_struct.txt").write_text(repr(discriminator_model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
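# The updater above maintains moving_generator_model next to generator_model
# and takes moving_average_rate. A standard exponential moving average of the
# parameters, shown as an assumption about what the (unshown) Updater does
# after each optimizer step:
import torch

@torch.no_grad()
def update_moving_average(moving_model, model, rate):
    for p_avg, p in zip(moving_model.parameters(), model.parameters()):
        # p_avg = rate * p_avg + (1 - rate) * p
        p_avg.mul_(rate).add_(p, alpha=1.0 - rate)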
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(parents=True)
    with (output / 'config.yaml').open(mode='w') as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(model_config=config.model, networks=networks)

    device = torch.device('cuda')
    model.to(device)

    # dataset
    def _create_iterator(dataset, for_train: bool):
        return MultiprocessIterator(
            dataset,
            config.train.batchsize,
            repeat=for_train,
            shuffle=for_train,
            n_processes=config.train.num_processes,
            dataset_timeout=60,
        )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets['train'], for_train=True)
    test_iter = _create_iterator(datasets['test'], for_train=False)
    train_test_iter = _create_iterator(datasets['train_test'], for_train=False)

    warnings.simplefilter('error', MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop('name').lower()
    if n == 'adam':
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == 'sgd':
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # trainer
    trigger_log = (config.train.log_iteration, 'iteration')
    trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
    trigger_stop = (
        (config.train.stop_iteration, 'iteration')
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name='test', trigger=trigger_log)
    ext = extensions.Evaluator(train_test_iter, model, device=device)
    trainer.extend(ext, name='train', trigger=trigger_log)

    ext = extensions.snapshot_object(
        networks.predictor, filename='predictor_{.updater.iteration}.npz')
    trainer.extend(ext, trigger=trigger_snapshot)

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(['iteration', 'main/loss', 'test/main/loss']),
        trigger=trigger_log)

    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    (output / 'struct.txt').write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    return trainer
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
):
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    networks = create_network(config.network)
    model = Model(config=config.model, networks=networks)
    init_orthogonal(model)

    device = torch.device("cuda")
    model.to(device)

    # dataset
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        eval_batch_size=config.train.eval_batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True, for_eval=False)
    test_iter = _create_iterator(datasets["test"], for_train=False, for_eval=False)

    valid_iter = None
    if datasets["valid"] is not None:
        valid_iter = _create_iterator(datasets["valid"], for_train=False, for_eval=True)

    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()
    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    elif n == "ranger":
        optimizer = Ranger(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    if not config.train.use_amp:
        updater = StandardUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )
    else:
        updater = AmpUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            model=model,
            device=device,
        )

    # trainer
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.eval_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)

    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)
    if valid_iter is not None:
        ext = extensions.Evaluator(valid_iter, model, device=device)
        trainer.extend(ext, name="valid", trigger=trigger_eval)

    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.eval_iteration / 10)
    else:
        saving_model_num = 10
    for field in dataclasses.fields(Networks):
        ext = extensions.snapshot_object(
            getattr(networks, field.name),
            filename=field.name + "_{.updater.iteration}.pth",
            n_retains=saving_model_num,
        )
        trainer.extend(
            ext,
            trigger=HighValueTrigger(
                ("valid/main/phoneme_accuracy"
                 if valid_iter is not None
                 else "test/main/phoneme_accuracy"),
                trigger=trigger_eval,
            ),
        )

    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )
    ext = TensorboardReport(writer=SummaryWriter(Path(output)))
    trainer.extend(ext, trigger=trigger_log)

    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer
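# Because the trainer snapshot above is registered with autoload=True and the
# output directory is created with exist_ok=True, re-running the same call
# resumes from the newest "trainer_{iteration}.pth" in that directory rather
# than iteration 0. Restart sketch (config path hypothetical):
from pathlib import Path

import yaml

config_dict = yaml.safe_load(Path("config.yaml").read_text())
trainer = create_trainer(config_dict=config_dict, output=Path("output"))
trainer.run()  # continues from the autoloaded iteration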
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='cpu',
                        help='Device specifier accepted by torch.device, '
                        'e.g. "cpu" or "cuda:0"')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--autoload', action='store_true',
                        help='Automatically load trainer snapshots in case'
                        ' of preemption or other temporary system failure')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu', '-g', dest='device',
                       type=int, nargs='?', const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    device = torch.device(args.device)

    print('Device: {}'.format(device))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = Classifier(MLP(784, args.unit, 10))
    model.to(device)

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Load the MNIST dataset
    transform = transforms.ToTensor()
    train = datasets.MNIST('data', train=True, download=True,
                           transform=transform)
    test = datasets.MNIST('data', train=False, transform=transform)

    train_iter = pytorch_trainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = pytorch_trainer.iterators.SerialIterator(test, args.batchsize,
                                                         repeat=False,
                                                         shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(train_iter, optimizer, model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=device),
                   call_before_training=True)

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    # trainer.extend(extensions.DumpGraph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)

    # Take a snapshot each ``frequency`` epoch, delete old stale
    # snapshots and automatically load from snapshot files if any
    # files are already resident at result directory.
    trainer.extend(extensions.snapshot(n_retains=1, autoload=args.autoload),
                   trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the result dir
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch', file_name='loss.png'),
        call_before_training=True)
    trainer.extend(
        extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            'epoch', file_name='accuracy.png'),
        call_before_training=True)

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]), call_before_training=True)

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume is not None:
        # Resume from a snapshot (Note: this loaded model is to be
        # overwritten by --autoload option, autoloading snapshots, if
        # any snapshots exist in output directory)
        trainer.load_state_dict(torch.load(args.resume))

    # Run the training
    trainer.run()
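# Example invocations (script name hypothetical):
#   python train_mnist.py --device cpu --epoch 5 --batchsize 64
#   python train_mnist.py -d cuda:0 --autoload   # snapshot-resumable GPU run
if __name__ == '__main__':
    main()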