def test_train():
    """Run the complete training schedule using the module-level ``opt``.

    For each epoch yielded by the iteration counter, every batch drives one
    discriminator step; a generator step is run first whenever the batch
    index is a multiple of ``opt.D_steps_per_G``.  Loss logging, image
    display and 'latest' checkpoints are triggered by the iteration counter.
    A 'latest' and a numbered checkpoint are written every
    ``opt.save_epoch_freq`` epochs and at the final epoch.
    """
    # print options to help debugging (disabled)
    # print(' '.join(sys.argv))

    # Construction order is kept: dataset, trainer, counter, visualizer.
    loader = data.create_dataloader(opt)
    model_trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(loader))
    viz = Visualizer(opt)

    for epoch in counter.training_epochs():
        counter.record_epoch_start(epoch)

        # Batch indices start at counter.epoch_iter rather than at zero.
        for step, batch in enumerate(loader, start=counter.epoch_iter):
            counter.record_one_iteration()

            # Training: generator only on selected steps, discriminator always.
            if step % opt.D_steps_per_G == 0:
                model_trainer.run_generator_one_step(batch)
            model_trainer.run_discriminator_one_step(batch)

            # Visualizations: textual / plotted losses.
            if counter.needs_printing():
                latest_losses = model_trainer.get_latest_losses()
                viz.print_current_errors(
                    epoch,
                    counter.epoch_iter,
                    latest_losses,
                    counter.time_per_iter,
                )
                viz.plot_current_errors(
                    latest_losses, counter.total_steps_so_far)

            # Visualizations: label / generated / real image triplet.
            if counter.needs_displaying():
                panels = OrderedDict()
                panels['input_label'] = batch['label']
                panels['synthesized_image'] = \
                    model_trainer.get_latest_generated()
                panels['real_image'] = batch['image']
                viz.display_current_results(
                    panels, epoch, counter.total_steps_so_far)

            # Rolling checkpoint inside the epoch.
            if counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, counter.total_steps_so_far))
                model_trainer.save('latest')
                counter.record_current_iter()

        model_trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        # Skip the end-of-epoch checkpoint unless this epoch is due for one.
        if not (epoch % opt.save_epoch_freq == 0
                or epoch == counter.total_epochs):
            continue

        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, counter.total_steps_so_far))
        model_trainer.save('latest')
        model_trainer.save(epoch)

    print('Training was successfully finished.')
# build_model_and_get_trainer: reads config.model.structure, prints
# 'Create the model', and for the 'pix2pix' structure builds the generator and
# discriminator (via get_generator_model / get_discriminator_model) under
# strategy.scope(), then returns a Pix2PixTrainer wired with both models, the
# data loader, the strategy and the config.
# Raises ValueError for any other structure name.
# NOTE(review): the line breaks of this function were lost, so it cannot be
# told from here whether Pix2PixTrainer(...) is constructed inside or after the
# `with strategy.scope():` block -- confirm against the original file.
def build_model_and_get_trainer(config: DotMap, data_loader: DataLoader, strategy: tf.distribute.Strategy) -> Trainer: model_structure = config.model.structure print('Create the model') if model_structure == 'pix2pix': with strategy.scope(): generator = get_generator_model(config) discriminator = get_discriminator_model(config) trainer = Pix2PixTrainer(generator=generator, discriminator=discriminator, data_loader=data_loader, strategy=strategy, config=config) return trainer else: raise ValueError(f"unknown model structure {model_structure}")
# Training-script excerpt: imports the data package, IterationCounter,
# Visualizer and Pix2PixTrainer; parses TrainOptions into `opt`; echoes the
# command line; then builds the dataloader, trainer, iteration counter and
# visualizer and enters the epoch / batch loop.
# NOTE(review): the excerpt has lost its line breaks and is cut off right after
# the `if i % opt.D_steps_per_G == 0:` header -- the loop body is not visible
# here, and `TrainOptions` / `sys` must be imported outside this excerpt.
import data from util.iter_counter import IterationCounter from util.visualizer import Visualizer from trainers.pix2pix_trainer import Pix2PixTrainer # parse options opt = TrainOptions().parse() # print options to help debugging print(' '.join(sys.argv)) # load the dataset dataloader = data.create_dataloader(opt) # create trainer for our model trainer = Pix2PixTrainer(opt) # create tool for counting iterations iter_counter = IterationCounter(opt, len(dataloader)) # create tool for visualization visualizer = Visualizer(opt) for epoch in iter_counter.training_epochs(): iter_counter.record_epoch_start(epoch) for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): iter_counter.record_one_iteration() # Training # train generator if i % opt.D_steps_per_G == 0:
def train():
    """Continue training with a longer schedule and a reduced learning rate.

    Mutates the module-level ``opt`` in place before anything is built:
    ``opt.niter`` is increased by 20 and ``opt.lr`` is set to ``0.00002``.
    Because ``opt`` is shared, calling this function again extends
    ``opt.niter`` by a further 20.

    The trainer is constructed exactly once.  The previous version built
    ``Pix2PixTrainer(opt)`` twice and threw the first instance away, which
    only repeated the network construction.

    Inside the loop, each batch drives one discriminator step, preceded by a
    generator step whenever the batch index is a multiple of
    ``opt.D_steps_per_G``.  Logging, display and 'latest' checkpoints are
    triggered by the iteration counter; a 'latest' and a numbered checkpoint
    are written every ``opt.save_epoch_freq`` epochs and at the final epoch.

    NOTE(review): the original comment mentioned freezing model layers, but
    nothing in this function freezes anything.  If freezing is required it
    has to happen inside ``Pix2PixTrainer`` -- confirm.
    """
    opt.niter = opt.niter + 20  # 20 more iterations of training
    opt.lr = 0.00002  # 1/10th of the original lr

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create trainer for our model (once, after the option overrides above)
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)

        # Batch indices start at iter_counter.epoch_iter rather than at zero.
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(
                    epoch, iter_counter.epoch_iter,
                    losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(
                    losses, iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([
                    ('input_label', data_i['label']),
                    ('synthesized_image', trainer.get_latest_generated()),
                    ('real_image', data_i['image']),
                ])
                visualizer.display_current_results(
                    visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.save_epoch_freq == 0 or \
           epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy, so ``--eval_mode False`` used to
    switch evaluation mode ON.  This converter accepts the usual spellings
    and rejects anything else.  argparse turns the ``ValueError`` into a
    normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


parser.add_argument('--save_freq', type=int, default=10)
parser.add_argument('--print_loss_freq', type=int, default=40)
parser.add_argument('--eval_mode', type=_str2bool, default=False)

# data loader
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=1)

# model hyperparameters
parser.add_argument('--inner_channels', type=int, default=64)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--norm', type=str, default='batch')

# training hyperparameters
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--decay_ratio', type=float, default=0.5)
parser.add_argument('--lamb', type=float, default=10.0)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)

args = parser.parse_args()

# Select the trainer.  Fail loudly on an unsupported model instead of leaving
# `trainer` as None and hitting an opaque AttributeError on the call below.
if args.model == 'pix2pix':
    trainer = Pix2PixTrainer(args=args)
else:
    raise ValueError('unknown model %r' % (args.model,))

# Any op other than 'train' runs the test path with the given checkpoint.
if args.op == 'train':
    trainer.train()
else:
    trainer.test(args.checkpoint)
# main_worker: per-process entry point for distributed training.
#   gpu        -- local GPU index; stored in opt.gpu and passed to
#                 torch.cuda.set_device.
#   world_size -- used as the number of GPUs per node (copied to
#                 ngpus_per_node); the local name is then overwritten with
#                 opt.world_size.
#   idx_server -- node index; the global rank is
#                 idx_server * ngpus_per_node + gpu.
#   opt        -- options object (dist_url, world_size, save_epoch_freq, ...).
# Initialises the NCCL process group, builds the dataloader / trainer /
# iteration counter / visualizer, calls dataloader.sampler.set_epoch(epoch) at
# the start of every epoch, and runs one generator and one discriminator step
# per batch. Checkpoints are written only when rank == 0.
# NOTE(review): the line breaks of this function were lost. It cannot be told
# from here whether the display_current_results call and the
# `if rank == 0:` 'latest' save are nested under
# `if iter_counter.needs_printing():` or run on every batch, nor whether the
# final print belongs to the function body -- confirm against the original.
def main_worker(gpu, world_size, idx_server, opt): print('Use GPU: {} for training'.format(gpu)) ngpus_per_node = world_size world_size = opt.world_size rank = idx_server * ngpus_per_node + gpu opt.gpu = gpu dist.init_process_group(backend='nccl', init_method=opt.dist_url, world_size=world_size, rank=rank) torch.cuda.set_device(opt.gpu) # load the dataset dataloader = data.create_dataloader(opt, world_size, rank) # create trainer for our model trainer = Pix2PixTrainer(opt) # create tool for counting iterations iter_counter = IterationCounter(opt, len(dataloader), world_size, rank) # create tool for visualization visualizer = Visualizer(opt, rank) for epoch in iter_counter.training_epochs(): # set epoch for data sampler dataloader.sampler.set_epoch(epoch) iter_counter.record_epoch_start(epoch) for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): iter_counter.record_one_iteration() # Training # train generator trainer.run_generator_one_step(data_i) # train discriminator trainer.run_discriminator_one_step(data_i) # Visualizations if iter_counter.needs_printing(): losses = trainer.get_latest_losses() visualizer.print_current_errors(epoch, iter_counter.epoch_iter, losses, iter_counter.time_per_iter) visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) visuals = OrderedDict([('input_label', data_i['label']), ('synthesized_image', trainer.get_latest_generated()), ('real_image', data_i['image'])]) visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) if rank == 0: print('saving the latest model (epoch %d, total_steps %d)' % (epoch, iter_counter.total_steps_so_far)) trainer.save('latest') iter_counter.record_current_iter() trainer.update_learning_rate(epoch) iter_counter.record_epoch_end() if (epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs) and (rank == 0): print('saving the model at the end of epoch %d, iters %d' % (epoch, iter_counter.total_steps_so_far)) trainer.save(epoch) 
print('Training was successfully finished.')
# Training-script excerpt: parses TrainOptions into `opt`, echoes the command
# line, builds the dataloader and iteration counter, then creates the trainer
# with resume_epoch=iter_counter.first_epoch. At the start of every epoch it
# stores the epoch in opt.epoch and prints which injection mode applies:
#   'inject nothing' when opt.maskmix is falsy,
#   'inject noise'   when maskmix and noise_for_mask are set and
#                    epoch > opt.mask_epoch,
#   'inject mask'    otherwise,
# followed by the dataset's real/hard reference probabilities.
# NOTE(review): `dataloader.dataset[11]` fetches one sample and discards it --
# presumably a leftover smoke test; confirm. `len_dataloader` and `save_root`
# are assigned but not used in the visible part. The excerpt has lost its line
# breaks and is cut off inside the batch loop.
opt = TrainOptions().parse() # print options to help debugging print(' '.join(sys.argv)) #torch.manual_seed(0) # load the dataset dataloader = data.create_dataloader(opt) len_dataloader = len(dataloader) dataloader.dataset[11] # create tool for counting iterations iter_counter = IterationCounter(opt, len(dataloader)) # create trainer for our model trainer = Pix2PixTrainer(opt, resume_epoch=iter_counter.first_epoch) save_root = os.path.join(os.path.dirname(opt.checkpoints_dir), 'output', opt.name) for epoch in iter_counter.training_epochs(): opt.epoch = epoch if not opt.maskmix: print('inject nothing') elif opt.maskmix and opt.noise_for_mask and epoch > opt.mask_epoch: print('inject noise') else: print('inject mask') print('real_reference_probability is :{}'.format(dataloader.dataset.real_reference_probability)) print('hard_reference_probability is :{}'.format(dataloader.dataset.hard_reference_probability)) iter_counter.record_epoch_start(epoch) for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): iter_counter.record_one_iteration()
def do_train(opt):
    """Run the complete training schedule described by ``opt``.

    Each batch drives one discriminator step, preceded by a generator step
    whenever the batch index is a multiple of ``opt.D_steps_per_G``.  Loss
    logging, image display and 'latest' checkpoints are triggered by the
    iteration counter.  A 'latest' and a numbered checkpoint are written
    every ``opt.save_epoch_freq`` epochs and at the final epoch.

    Notes kept from a recorded sample run (previously pasted inline):
      * the dataloader reported "dataset [CustomDataset] of size 2000";
      * the trainer created a SPADEGenerator (92.5 million parameters) and a
        MultiscaleDiscriminator (5.6 million parameters) and downloaded the
        vgg19 weights from download.pytorch.org;
      * the visualizer created ./checkpoints/ipdb_test/web;
      * one batch was a dict with the keys 'label', 'instance', 'image' and
        'path' -- tensors for the first three, and a one-element list of
        file paths for 'path'.
    """
    # Construction order is kept: dataset, trainer, counter, visualizer.
    loader = data.create_dataloader(opt)
    model_trainer = Pix2PixTrainer(opt)
    counter = IterationCounter(opt, len(loader))
    viz = Visualizer(opt)

    for epoch in counter.training_epochs():
        counter.record_epoch_start(epoch)

        # Batch indices start at counter.epoch_iter rather than at zero.
        for step, batch in enumerate(loader, start=counter.epoch_iter):
            counter.record_one_iteration()

            # Training: generator only on selected steps, discriminator always.
            if step % opt.D_steps_per_G == 0:
                model_trainer.run_generator_one_step(batch)
            model_trainer.run_discriminator_one_step(batch)

            # Visualizations: textual / plotted losses.
            if counter.needs_printing():
                latest_losses = model_trainer.get_latest_losses()
                viz.print_current_errors(
                    epoch,
                    counter.epoch_iter,
                    latest_losses,
                    counter.time_per_iter,
                )
                viz.plot_current_errors(
                    latest_losses, counter.total_steps_so_far)

            # Visualizations: label / generated / real image triplet.
            if counter.needs_displaying():
                panels = OrderedDict()
                panels['input_label'] = batch['label']
                panels['synthesized_image'] = \
                    model_trainer.get_latest_generated()
                panels['real_image'] = batch['image']
                viz.display_current_results(
                    panels, epoch, counter.total_steps_so_far)

            # Rolling checkpoint inside the epoch.
            if counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, counter.total_steps_so_far))
                model_trainer.save('latest')
                counter.record_current_iter()

        model_trainer.update_learning_rate(epoch)
        counter.record_epoch_end()

        # Skip the end-of-epoch checkpoint unless this epoch is due for one.
        if not (epoch % opt.save_epoch_freq == 0
                or epoch == counter.total_epochs):
            continue

        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, counter.total_steps_so_far))
        model_trainer.save('latest')
        model_trainer.save(epoch)

    print('Training was successfully finished.')
# do_train (FID-tracking variant): trains the pix2pix model and, when
# opt.train_eval is set, evaluates FID every opt.eval_epoch_freq epochs.
# Setup when opt.train_eval is set:
#   * opt.no_flip / opt.phase / opt.isTrain are switched to True / 'test' /
#     False while the validation dataloader, visualizer and results webpage
#     are created, then set back to the saved no_flip / 'train' / True;
#   * an InceptionV3 block selected by opt.eval_dims is moved to CUDA;
#   * reference statistics m0, s0 are loaded from
#     datasets/train_mu_si/<opt.dataset_mode>/m.npy and s.npy;
#   * the best FID so far is read from <checkpoints_dir>/<name>/fid.txt when
#     opt.continue_train is set, otherwise it starts at 1000.
# Without opt.train_eval the FID score is simply fixed at 1000.
# During evaluation opt.use_vae is switched off and restored afterwards (via
# `flag`), generated images are saved through the webpage, FID is computed
# from the saved files, and trainer.save('best') runs when the FID improves.
# NOTE(review): `tmp = tensor2im(generated[b])` is never used.
# NOTE(review): calculate_activation_statistics and calculate_frechet_distance
# are not imported inside this function (only calculate_fid_given_paths is) --
# they must come from module scope; confirm.
# NOTE(review): line breaks were lost, so the nesting of `webpage.save()`, the
# `break`, and the use_vae restore cannot be verified from here. One string
# literal ('start evalidation .... ') is split across two lines as shown.
def do_train(opt): # print options to help debugging print(' '.join(sys.argv)) # load the dataset dataloader = data.create_dataloader(opt) # create trainer for our model trainer = Pix2PixTrainer(opt) # create tool for counting iterations iter_counter = IterationCounter(opt, len(dataloader)) # create tool for visualization visualizer = Visualizer(opt) if opt.train_eval: # val_opt = TestOptions().parse() original_flip = opt.no_flip opt.no_flip = True opt.phase = 'test' opt.isTrain = False dataloader_val = data.create_dataloader(opt) val_visualizer = Visualizer(opt) # # create a webpage that summarizes the all results web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) opt.phase = 'train' opt.isTrain = True opt.no_flip = original_flip # process for calculate FID scores from inception import InceptionV3 from fid_score import calculate_fid_given_paths import pathlib # define the inceptionV3 block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[opt.eval_dims] eval_model = InceptionV3([block_idx]).cuda() # load real images distributions on the training set mu_np_root = os.path.join('datasets/train_mu_si',opt.dataset_mode,'m.npy') st_np_root = os.path.join('datasets/train_mu_si',opt.dataset_mode,'s.npy') m0, s0 = np.load(mu_np_root), np.load(st_np_root) # load previous best FID if opt.continue_train: fid_record_dir = os.path.join(opt.checkpoints_dir, opt.name, 'fid.txt') FID_score, _ = np.loadtxt(fid_record_dir, delimiter=',', dtype=float) else: FID_score = 1000 else: FID_score = 1000 for epoch in iter_counter.training_epochs(): iter_counter.record_epoch_start(epoch) for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): iter_counter.record_one_iteration() # Training # train generator if i % opt.D_steps_per_G == 0: trainer.run_generator_one_step(data_i) # train discriminator trainer.run_discriminator_one_step(data_i) # 
Visualizations if iter_counter.needs_printing(): losses = trainer.get_latest_losses() if opt.train_eval: visualizer.print_current_errors(epoch, iter_counter.epoch_iter, losses, iter_counter.time_per_iter, FID_score) else: visualizer.print_current_errors(epoch, iter_counter.epoch_iter, losses, iter_counter.time_per_iter) visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) # if iter_counter.needs_displaying(): # visuals = OrderedDict([('input_label', data_i['label']), # ('synthesized_image', trainer.get_latest_generated()), # ('real_image', data_i['image'])]) # visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) if iter_counter.needs_saving(): print('saving the latest model (epoch %d, total_steps %d)' % (epoch, iter_counter.total_steps_so_far)) trainer.save('latest') iter_counter.record_current_iter(FID_score) trainer.update_learning_rate(epoch) iter_counter.record_epoch_end() if epoch % opt.eval_epoch_freq == 0 and opt.train_eval: # generate fake image trainer.pix2pix_model.eval() print('start evalidation .... 
') if opt.use_vae: flag = True opt.use_vae = False else: flag = False for i, data_i in enumerate(dataloader_val): if data_i['label'].size()[0] != opt.batchSize: if opt.batchSize > 2*data_i['label'].size()[0]: print('batch size is too large') break data_i = repair_data(data_i, opt.batchSize) generated = trainer.pix2pix_model(data_i, mode='inference') img_path = data_i['path'] for b in range(generated.shape[0]): tmp = tensor2im(generated[b]) visuals = OrderedDict([('input_label', data_i['label'][b]), ('synthesized_image', generated[b])]) val_visualizer.save_images(webpage, visuals, img_path[b:b + 1]) webpage.save() trainer.pix2pix_model.train() if flag: opt.use_vae = True # cal fid score fake_path = pathlib.Path(os.path.join(web_dir, 'images/synthesized_image/')) files = list(fake_path.glob('*.jpg')) + list(fake_path.glob('*.png')) m1, s1 = calculate_activation_statistics(files, eval_model, 1, opt.eval_dims, True, images=None) fid_value = calculate_frechet_distance(m0, s0, m1, s1) visualizer.print_eval_fids(epoch, fid_value, FID_score) # save the best model if necessary if fid_value < FID_score: FID_score = fid_value trainer.save('best') if epoch % opt.save_epoch_freq == 0 or \ epoch == iter_counter.total_epochs: print('saving the model at the end of epoch %d, iters %d' % (epoch, iter_counter.total_steps_so_far)) trainer.save('latest') trainer.save(epoch) print('Training was successfully finished.')