def generate(args):
    """Sample batches from a trained generator and dump them as image tiles.

    Loads parameters from ``args.model_load_path``, then repeatedly draws a
    fresh latent batch, runs the generator, and writes the result through a
    ``MonitorImageTile``.
    """
    # Compute context (CPU/GPU, dtype) from the command-line arguments.
    ctx = get_extension_context(args.context, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Frequently used arguments.
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size

    # Build the inference generator graph on top of the loaded parameters.
    nn.load_parameters(args.model_load_path)
    z_test = nn.Variable([batch_size, latent])
    x_test = generator(z_test, maps=maps, test=True, up=args.up)

    # Tile monitor used to save each generated batch.
    monitor = Monitor(args.monitor_path)
    monitor_image_tile_test = MonitorImageTile(
        "Image Tile Generated", monitor, num_images=batch_size,
        interval=1, normalize_method=denormalize)

    # One image tile per iteration, each from an independent latent draw.
    for itr in range(args.num_generation):
        z_test.d = np.random.randn(batch_size, latent)
        x_test.forward(clear_buffer=True)
        monitor_image_tile_test.add(itr, x_test)
def generate(args):
    """Generate one image batch with a trained generator and its EMA copy.

    Loads the raw-generator and EMA-generator snapshots (iteration 100000)
    from ``args.model_load_path``, runs both over the same latent batch, and
    writes one image tile for each.
    """
    ctx = get_extension_context(args.context, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    scope_gen = "Generator"
    scope_gen_ema = "Generator_EMA"
    # Use os.path.join instead of string concatenation so the path is built
    # correctly regardless of a trailing separator in model_load_path.
    # NOTE(review): the checkpoint iteration (100000) is hard-coded — consider
    # exposing it as a command-line argument.
    gen_param_path = os.path.join(args.model_load_path, 'Gen_iter100000.h5')
    gen_ema_param_path = os.path.join(args.model_load_path, 'GenEMA_iter100000.h5')
    with nn.parameter_scope(scope_gen):
        nn.load_parameters(gen_param_path)
    with nn.parameter_scope(scope_gen_ema):
        nn.load_parameters(gen_ema_param_path)

    # Single shared normalizer: map generator output from [-1, 1] to [0, 1].
    def denorm(x):
        return (x + 1.) / 2.

    monitor = Monitor(args.monitor_path)
    monitor_image_tile_test = MonitorImageTile(
        "Image Tile", monitor, num_images=args.batch_size,
        interval=1, normalize_method=denorm)
    monitor_image_tile_test_ema = MonitorImageTile(
        "Image Tile with EMA", monitor, num_images=args.batch_size,
        interval=1, normalize_method=denorm)

    # Build both generator graphs over one shared latent input.
    # train=True is deliberate — this model is used that way at inference too.
    z_test = nn.Variable([args.batch_size, args.latent, 1, 1])
    x_test = Generator(z_test, scope_name=scope_gen, train=True,
                       img_size=args.image_size)[0]
    x_test_ema = Generator(z_test, scope_name=scope_gen_ema, train=True,
                           img_size=args.image_size)[0]

    # One latent draw, two forward passes, two tiles.
    z_test.d = np.random.randn(args.batch_size, args.latent, 1, 1)
    x_test.forward(clear_buffer=True)
    x_test_ema.forward(clear_buffer=True)
    monitor_image_tile_test.add(0, x_test)
    monitor_image_tile_test_ema.add(0, x_test_ema)
def train(data_iterator, monitor, config, comm, args):
    """Train (or sample from) a VQ-VAE, optionally with a PixelCNN prior.

    Parameters
    ----------
    data_iterator : callable
        Factory that returns a data iterator given (config, comm, train=bool).
    monitor : Monitor
        nnabla monitor used for all series/image logging (rank 0 only).
    config : dict
        Nested configuration (train/val/model/prior/monitor sections).
    comm : communicator
        Distributed communicator; monitors are created only on rank 0.
    args : Namespace
        Flags: pixelcnn_prior, sample_from_pixelcnn, load_checkpoint, ...
    """
    # Monitors exist only on rank 0; other ranks pass None to the trainer.
    monitor_train_loss, monitor_train_recon = None, None
    monitor_val_loss, monitor_val_recon = None, None
    if comm.rank == 0:
        monitor_train_loss = MonitorSeries(
            config['monitor']['train_loss'], monitor,
            interval=config['train']['logger_step_interval'])
        monitor_train_recon = MonitorImageTile(
            config['monitor']['train_recon'], monitor,
            interval=config['train']['logger_step_interval'],
            num_images=config['train']['batch_size'])
        monitor_val_loss = MonitorSeries(
            config['monitor']['val_loss'], monitor,
            interval=config['train']['logger_step_interval'])
        monitor_val_recon = MonitorImageTile(
            config['monitor']['val_recon'], monitor,
            interval=config['train']['logger_step_interval'],
            num_images=config['train']['batch_size'])

    model = VQVAE(config)

    # Solver and data are only needed when actually training (not when merely
    # sampling from a trained PixelCNN prior).
    if not args.sample_from_pixelcnn:
        if config['train']['solver'] == 'adam':
            solver = S.Adam()
        else:
            # BUG FIX: nnabla's solver class is S.Momentum; the lowercase
            # S.momentum() raised AttributeError.
            solver = S.Momentum()
        solver.set_learning_rate(config['train']['learning_rate'])

        train_loader = data_iterator(config, comm, train=True)
        if config['dataset']['name'] != 'imagenet':
            val_loader = data_iterator(config, comm, train=False)
        else:
            val_loader = None
    else:
        solver, train_loader, val_loader = None, None, None

    # Pick the trainer: plain VQ-VAE, or VQ-VAE + GatedPixelCNN prior.
    if not args.pixelcnn_prior:
        trainer = VQVAEtrainer(model, solver, train_loader, val_loader,
                               monitor_train_loss, monitor_train_recon,
                               monitor_val_loss, monitor_val_recon, config, comm)
        num_epochs = config['train']['num_epochs']
    else:
        pixelcnn_model = GatedPixelCNN(config['prior'])
        trainer = TrainerPrior(model, pixelcnn_model, solver, train_loader,
                               val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss,
                               monitor_val_recon, config, comm,
                               eval=args.sample_from_pixelcnn)
        num_epochs = config['prior']['train']['num_epochs']

    # Resume from a checkpoint when requested (mandatory for sampling).
    if os.path.exists(config['model']['checkpoint']) and (args.load_checkpoint or args.sample_from_pixelcnn):
        checkpoint_path = config['model']['checkpoint'] if not args.pixelcnn_prior else config['prior']['checkpoint']
        trainer.load_checkpoint(
            checkpoint_path,
            msg='Parameters loaded from {}'.format(checkpoint_path),
            pixelcnn=args.pixelcnn_prior,
            load_solver=not args.sample_from_pixelcnn)

    # Sampling mode: generate and exit without training.
    if args.sample_from_pixelcnn:
        trainer.random_generate(
            args.sample_from_pixelcnn, args.sample_save_path)
        return

    for epoch in range(num_epochs):
        trainer.train(epoch)

        if epoch % config['val']['interval'] == 0 and val_loader is not None:
            trainer.validate(epoch)

        if comm.rank == 0:
            if epoch % config['train']['save_param_step_interval'] == 0 or epoch == config['train']['num_epochs']-1:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch,
                    pixelcnn=args.pixelcnn_prior)
def train(args):
    """WGAN-GP training loop on CIFAR-10.

    Builds separate discriminator graphs for the generator loss, the critic
    loss, and the gradient penalty, then alternates n_critic critic updates
    with one generator update per iteration.
    """
    # Compute context.
    ctx = get_extension_context(args.context, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Frequently used arguments.
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # --- Graph: generator loss ---
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = generator_side_p = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(generator_side_p).apply(persistent=True)

    # --- Graph: discriminator (critic) loss ---
    # A second discriminator graph over the same x_fake output.
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)

    # --- Graph: gradient penalty on random real/fake interpolates ---
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp

    # --- Graph: fixed-latent generator output for monitoring ---
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solvers, one per parameter scope.
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitors.
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile(
        "Image Tile Train", monitor, num_images=batch_size, interval=1,
        normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile(
        "Image Tile Test", monitor, num_images=batch_size, interval=1,
        normalize_method=denormalize)

    # Data iterator (train split).
    di = data_iterator_cifar10(batch_size, True)

    # Training loop.
    for i in range(args.max_iter):
        # Critic updates: block gradients from flowing into the generator.
        x_fake.need_grad = False
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Generator update: gradients must reach the generator again.
        x_fake.need_grad = True
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Series monitors.
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Periodic snapshot of images and parameters.
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Final snapshot.
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
def generate(args):
    """Generate class-conditional images with a trained (SA)GAN generator.

    With --generate_all, writes one image tile per class; otherwise writes a
    single batch for one class (or random classes when class_id == -1).
    """
    # Context (cudnn extension).
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Frequently used arguments.
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Inference graph on top of the loaded parameters.
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True,
                       sn=not_sn).apply(persistent=True)

    # --- All-classes mode: one tile per class, then return. ---
    if args.generate_all:
        monitor = Monitor(args.monitor_path)
        name = "Generated Image Tile All"
        monitor_image = MonitorImageTile(
            name, monitor, interval=1, num_images=args.batch_size,
            normalize_method=normalize_method)
        for class_id in range(args.n_classes):
            # Truncated latent draw + one-hot class batch.
            z.d = resample(batch_size, latent, threshold)
            y_fake.d = generate_one_class(class_id, batch_size)
            x_fake.forward(clear_buffer=True)
            monitor_image.add(class_id, x_fake.d)
        return

    # --- Single-class (or random-class) mode. ---
    monitor = Monitor(args.monitor_path)
    name = "Generated Image Tile {}".format(
        args.class_id) if args.class_id != -1 else "Generated Image Tile"
    monitor_image_tile = MonitorImageTile(
        name, monitor, interval=1, num_images=args.batch_size,
        normalize_method=normalize_method)
    name = "Generated Image {}".format(
        args.class_id) if args.class_id != -1 else "Generated Image"
    monitor_image = MonitorImage(
        name, monitor, interval=1, num_images=args.batch_size,
        normalize_method=normalize_method)

    z_data = resample(batch_size, latent, threshold)
    y_data = generate_random_class(n_classes, batch_size) if args.class_id == -1 else \
        generate_one_class(args.class_id, batch_size)
    z.d = z_data
    y_fake.d = y_data
    x_fake.forward(clear_buffer=True)
    monitor_image.add(0, x_fake.d)
    monitor_image_tile.add(0, x_fake.d)
def __init__(self, monitor, config, args, comm, few_shot_config):
    """Set up monitors, discriminator, EMA generator, solvers and data,
    then immediately launch training via self.train()."""
    super(Train, self).__init__(monitor, config, args, comm, few_shot_config)

    # Monitors default to None; real monitors exist only on rank 0.
    self.monitor_train_loss, self.monitor_train_gen = None, None
    self.monitor_val_loss, self.monitor_val_gen = None, None
    if comm is not None:
        if comm.rank == 0:
            # NOTE(review): the attribute assigned here is
            # monitor_train_gen_loss, while the None default above is
            # monitor_train_loss — confirm which name is read later.
            self.monitor_train_gen_loss = MonitorSeries(
                config['monitor']['train_loss'], monitor,
                interval=self.config['logger_step_interval'])
            self.monitor_train_gen = MonitorImageTile(
                config['monitor']['train_gen'], monitor,
                interval=self.config['logger_step_interval'],
                num_images=self.config['batch_size'])
            # NOTE(review): the discriminator series reuses the 'train_loss'
            # monitor name — presumably intentional; verify against config.
            self.monitor_train_disc_loss = MonitorSeries(
                config['monitor']['train_loss'], monitor,
                interval=self.config['logger_step_interval'])

    # Output locations.
    os.makedirs(self.config['saved_weights_dir'], exist_ok=True)
    self.results_dir = args.results_dir
    self.save_weights_dir = args.weights_path
    self.few_shot_config = few_shot_config

    # Discriminator and the EMA copy of the generator.
    self.discriminator = Discriminator(config['discriminator'], self.img_size)
    # EMA decay constant: half-life of 10k steps measured in 32-sample units.
    self.gen_exp_weight = 0.5**(32 / (10 * 1000))
    self.generator_ema = Generator(config['generator'], self.img_size,
                                   config['train']['mix_after'],
                                   global_scope='GeneratorEMA')

    # Solvers — skipped when a subclass already created gen_solver.
    if 'gen_solver' not in dir(self):
        if self.config['solver'] == 'Adam':
            self.gen_solver = S.Adam(beta1=0, beta2=0.99)
            self.disc_solver = S.Adam(beta1=0, beta2=0.99)
        else:
            # NOTE(review): eval() over a config string — fine for trusted
            # configs, but never feed untrusted input here.
            self.gen_solver = eval('S.' + self.config['solver'])()
            self.disc_solver = eval('S.' + self.config['solver'])()
    self.gen_solver.set_learning_rate(self.config['learning_rate'])
    self.disc_solver.set_learning_rate(self.config['learning_rate'])

    self.gen_mean_path_length = 0.0
    self.dali = args.dali
    self.args = args

    # Data loader (FFHQ only; DALI or plain iterator).
    if args.data == 'ffhq':
        if args.dali:
            self.train_loader = get_dali_iterator_ffhq(
                args.dataset_path, config['data'], self.img_size,
                self.batch_size, self.comm)
        else:
            self.train_loader = get_data_iterator_ffhq(
                args.dataset_path, config['data'], self.batch_size,
                self.img_size, self.comm)
    else:
        print('Dataset not recognized')
        exit(1)

    # Kick off training from the constructor.
    self.train()
def match(args):
    """Find the training images closest to a generated one in feature space.

    Generates one image for ``args.class_id``, embeds it and every real image
    of that class with an Inception model loaded from an nnp file, and saves
    the top-n nearest real images (squared-L2 in feature space) via monitors.
    """
    # Context.
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Frequently used arguments (single-image batches throughout).
    latent = args.latent
    maps = args.maps
    batch_size = 1
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model (SAGAN generator).
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True,
                       sn=not_sn).apply(persistent=True)

    # Model (Inception) from the nnp file: x is its input, h the feature.
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, h = get_input_and_output(nnp, batch_size, args.variable_name)

    # Data iterator restricted to the requested class.
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                batch_size=batch_size, n_classes=args.n_classes,
                                noise=False, class_id=args.class_id)

    # Monitors.
    monitor = Monitor(args.monitor_path)
    name = "Matched Image {}".format(args.class_id)
    monitor_image = MonitorImage(
        name, monitor, interval=1, num_images=batch_size,
        normalize_method=lambda x: (x + 1.) / 2. * 255.)
    name = "Matched Image Tile {}".format(args.class_id)
    monitor_image_tile = MonitorImageTile(
        name, monitor, interval=1, num_images=batch_size + args.top_n,
        normalize_method=lambda x: (x + 1.) / 2. * 255.)

    # Generate one fake image and compute its feature h(x_fake).
    z.d = resample(batch_size, latent, threshold)
    y_fake.d = generate_one_class(args.class_id, batch_size)
    x_fake.forward(clear_buffer=True)
    x_fake_d = x_fake.d.copy()
    x_fake_d = preprocess(
        x_fake_d, (args.image_size, args.image_size), args.nnp_preprocess)
    x.d = x_fake_d
    h.forward(clear_buffer=True)
    h_fake_d = h.d.copy()

    # Feature distance against every real image of the class.
    # x_data_list[0] is the generated image; real image j is at index j + 1.
    norm2_list = []
    x_data_list = [x_fake.d]
    for i in range(di.size):
        x_d, _ = di.next()
        x_data_list.append(x_d)
        x_d = preprocess(
            x_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_d
        h.forward(clear_buffer=True)
        h_real_d = h.d.copy()
        # Squared L2 over all non-batch axes.
        axis = tuple(np.arange(1, len(h.shape)).tolist())
        norm2 = np.sum((h_real_d - h_fake_d) ** 2.0, axis=axis)
        norm2_list.append(norm2)

    # Save the top-n closest real images.
    # BUG FIX: argmins was computed but never used — the loop saved the first
    # top_n entries of x_data_list (starting with the fake image itself)
    # instead of the nearest matches. Also flatten the per-sample (1,)-shaped
    # distances before argsort so the sort runs over samples, not axis -1.
    argmins = np.argsort(np.concatenate(norm2_list))
    top_matches = [x_data_list[j + 1] for j in argmins[:args.top_n]]
    for i, x_d in enumerate(top_matches):
        monitor_image.add(i, x_d)
    # Tile layout: the generated image followed by its top-n matches
    # (matches num_images = batch_size + top_n above).
    matched_images = np.concatenate([x_data_list[0]] + top_matches)
    monitor_image_tile.add(0, matched_images)
def train(args):
    """Distributed class-conditional GAN training (multi-device, ImageNet).

    Each process drives one device; gradients are all-reduced across devices
    and weights are periodically re-synchronized by averaging.
    """
    # Communicator and context.
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # NOTE(review): name spelling matches nnabla's communicator API — verify
    # against the installed nnabla version.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Frequently used arguments.
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Workaround so every process starts from identical weights.
    np.random.seed(412)

    # --- Graph: generator loss ---
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)

    # --- Graph: discriminator loss ---
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)

    # --- Graph: fixed test inputs for monitoring ---
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn)

    # Solvers, one per parameter scope.
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitors live only on rank 0.
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile(
            "Image Tile Train", monitor, num_images=args.batch_size,
            interval=1, normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile(
            "Image Tile Test", monitor, num_images=args.batch_size,
            interval=1, normalize_method=normalize_method)

    # Data iterator with a per-device RNG so shards differ across devices.
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Training loop.
    for i in range(args.max_iter):
        # Discriminator updates: gradients stop at x_fake.
        x_fake.need_grad = False
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # Real batch.
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # Fake batch.
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Generator updates: gradients must flow through x_fake again.
        x_fake.need_grad = True
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Keep weights in sync by averaging them across devices.
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Periodic snapshot (rank 0 only).
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Series monitors (rank 0 only).
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    # Final snapshot (rank 0 only).
    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)
def interpolate(args):
    """MUNIT-style style interpolation between two latent draws.

    For each repeat r in [0, 1), decodes every input image with the style
    code linearly interpolated between two fixed random draws, and writes
    side-by-side (input | translation) tiles for both directions A->B and
    B->A.
    """
    # Load model parameters first, then set up the context.
    nn.load_parameters(args.model_load_path)
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input variables (single-image batches).
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])
    one = nn.Variable.from_numpy_array(np.ones((1, 1, 1, 1)) * 0.5)

    # Model graphs.
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # Cross-domain generation (into A): style is random unless example-guided.
    z_style_a = nn.Variable(
        x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # Cross-domain generation (into B).
    z_style_b = nn.Variable(
        x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor names derived from the input file names.
    def file_names(path):
        """Basename of *path* with the exact '_AB.jpg' suffix removed.

        BUG FIX: the previous rstrip("_AB.jpg") stripped any trailing run of
        the characters {_, A, B, ., j, p, g} — e.g. "imgAB_AB.jpg" became
        "im" instead of "imgAB".
        """
        base = path.split("/")[-1]
        tail = "_AB.jpg"
        return base[:-len(tail)] if base.endswith(tail) else base

    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_tile_a = MonitorImageTile(
        "Fake Image Tile {} B to A {} Interpolation".format(
            "-".join([file_names(path) for path in args.img_files_b]), suffix),
        monitor, interval=1, num_images=len(args.img_files_b))
    monitor_image_tile_b = MonitorImageTile(
        "Fake Image Tile {} A to B {} Interpolation".format(
            "-".join([file_names(path) for path in args.img_files_a]), suffix),
        monitor, interval=1, num_images=len(args.img_files_a))

    # Data iterators (ordered, no shuffling).
    di_a = munit_data_iterator(args.img_files_a, b, shuffle=False)
    di_b = munit_data_iterator(args.img_files_b, b, shuffle=False)
    rng = np.random.RandomState(args.seed)

    # Interpolate (A -> B)
    # NOTE(review): endpoints are drawn with z_style_a.shape but fed to
    # z_style_b — the two style shapes are presumably identical; verify.
    z_data_0 = [rng.randn(*z_style_a.shape) for j in range(di_a.size)]
    z_data_1 = [rng.randn(*z_style_a.shape) for j in range(di_a.size)]
    for i in range(args.num_repeats):
        r = 1.0 * i / args.num_repeats
        images = []
        for j in range(di_a.size):
            x_data_a = di_a.next()[0]
            x_real_a.d = x_data_a
            # Linear interpolation between the two fixed style draws.
            z_style_b.d = z_data_0[j] * (1.0 - r) + z_data_1[j] * r
            x_fake_b.forward(clear_buffer=True)
            cmp_image = np.concatenate([x_data_a, x_fake_b.d.copy()], axis=3)
            images.append(cmp_image)
        images = np.concatenate(images)
        monitor_image_tile_b.add(i, images)

    # Interpolate (B -> A)
    z_data_0 = [rng.randn(*z_style_b.shape) for j in range(di_b.size)]
    z_data_1 = [rng.randn(*z_style_b.shape) for j in range(di_b.size)]
    for i in range(args.num_repeats):
        r = 1.0 * i / args.num_repeats
        images = []
        for j in range(di_b.size):
            x_data_b = di_b.next()[0]
            x_real_b.d = x_data_b
            z_style_a.d = z_data_0[j] * (1.0 - r) + z_data_1[j] * r
            x_fake_a.forward(clear_buffer=True)
            cmp_image = np.concatenate([x_data_b, x_fake_a.d.copy()], axis=3)
            images.append(cmp_image)
        images = np.concatenate(images)
        monitor_image_tile_a.add(i, images)
def train(args):
    """GAN training loop with data augmentation and an EMA generator.

    Alternates one discriminator step and one generator step per iteration,
    maintains an exponential-moving-average copy of the generator, and
    periodically dumps images and parameter snapshots.
    """
    # Context.
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    aug_list = args.aug_list

    # Parameter scopes.
    scope_gen = "Generator"
    scope_dis = "Discriminator"

    # --- Graph: generator loss ---
    z = nn.Variable([args.batch_size, args.latent, 1, 1])
    x_fake = Generator(z, scope_name=scope_gen, img_size=args.image_size)
    p_fake = Discriminator([augment(xf, aug_list) for xf in x_fake],
                           label="fake", scope_name=scope_dis)
    lossG = loss_gen(p_fake)

    # --- Graph: discriminator loss (fake + real with reconstruction) ---
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    x_real_aug = augment(x_real, aug_list)
    p_real, rec_imgs, part = Discriminator(
        x_real_aug, label="real", scope_name=scope_dis)
    lossD_fake = loss_dis_fake(p_fake)
    lossD_real = loss_dis_real(p_real, rec_imgs, part, x_real_aug)
    lossD = lossD_fake + lossD_real

    # --- Graph: fixed-latent test generators ---
    # Use train=True even in an inference phase
    z_test = nn.Variable.from_numpy_array(
        np.random.randn(args.batch_size, args.latent, 1, 1))
    x_test = Generator(z_test, scope_name=scope_gen, train=True,
                       img_size=args.image_size)[0]

    # Exponential Moving Average (EMA) copy of the generator.
    scope_gen_ema = "Generator_EMA"
    x_test_ema = Generator(z_test, scope_name=scope_gen_ema, train=True,
                           img_size=args.image_size)[0]
    copy_params(scope_gen, scope_gen_ema)
    update_ema_var = make_ema_updater(scope_gen_ema, scope_gen, 0.999)

    # Solvers, one per parameter scope.
    solver_gen = S.Adam(args.lr, beta1=0.5)
    solver_dis = S.Adam(args.lr, beta1=0.5)
    with nn.parameter_scope(scope_gen):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope(scope_dis):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitors.
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries(
        "Generator Loss", monitor, interval=10)
    monitor_loss_dis_real = MonitorSeries(
        "Discriminator Loss Real", monitor, interval=10)
    monitor_loss_dis_fake = MonitorSeries(
        "Discriminator Loss Fake", monitor, interval=10)
    monitor_time = MonitorTimeElapsed(
        "Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile(
        "Image Tile Train", monitor, num_images=args.batch_size, interval=1,
        normalize_method=lambda x: (x + 1.) / 2.)
    monitor_image_tile_test = MonitorImageTile(
        "Image Tile Test", monitor, num_images=args.batch_size, interval=1,
        normalize_method=lambda x: (x + 1.) / 2.)
    monitor_image_tile_test_ema = MonitorImageTile(
        "Image Tile Test EMA", monitor, num_images=args.batch_size, interval=1,
        normalize_method=lambda x: (x + 1.) / 2.)

    # Data iterator with a fixed RNG seed.
    rng = np.random.RandomState(141)
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.image_size, args.image_size),
                       num_samples=args.train_samples, rng=rng)

    # Training loop.
    for i in range(args.max_iter):
        # Discriminator step: block gradients into the generator outputs.
        x_fake[0].need_grad = False
        x_fake[1].need_grad = False
        solver_dis.zero_grad()
        x_real.d = di.next()[0]
        z.d = np.random.randn(args.batch_size, args.latent, 1, 1)
        lossD.forward()
        lossD.backward()
        solver_dis.update()

        # Generator step: re-enable gradients through the generator outputs.
        x_fake[0].need_grad = True
        x_fake[1].need_grad = True
        solver_gen.zero_grad()
        lossG.forward()
        lossG.backward()
        solver_gen.update()

        # EMA weight update.
        update_ema_var.forward()

        # Series monitors.
        monitor_loss_gen.add(i, lossG.d)
        monitor_loss_dis_real.add(i, lossD_real.d)
        monitor_loss_dis_fake.add(i, lossD_fake.d)
        monitor_time.add(i)

        # Periodic parameter snapshots (raw, EMA, discriminator).
        if (i+1) % args.save_interval == 0:
            with nn.parameter_scope(scope_gen):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "Gen_iter{}.h5".format(i+1)))
            with nn.parameter_scope(scope_gen_ema):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "GenEMA_iter{}.h5".format(i+1)))
            with nn.parameter_scope(scope_dis):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "Dis_iter{}.h5".format(i+1)))

        # Periodic image dump.
        if (i+1) % args.test_interval == 0:
            x_test.forward(clear_buffer=True)
            x_test_ema.forward(clear_buffer=True)
            monitor_image_tile_train.add(i+1, x_fake[0])
            monitor_image_tile_test.add(i+1, x_test)
            monitor_image_tile_test_ema.add(i+1, x_test_ema)

    # Final image dump and parameter snapshots.
    x_test.forward(clear_buffer=True)
    x_test_ema.forward(clear_buffer=True)
    monitor_image_tile_train.add(args.max_iter, x_fake[0])
    monitor_image_tile_test.add(args.max_iter, x_test)
    monitor_image_tile_test_ema.add(args.max_iter, x_test_ema)
    with nn.parameter_scope(scope_gen):
        nn.save_parameters(os.path.join(
            args.monitor_path, "Gen_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_gen_ema):
        nn.save_parameters(os.path.join(
            args.monitor_path, "GenEMA_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_dis):
        nn.save_parameters(os.path.join(
            args.monitor_path, "Dis_iter{}.h5".format(args.max_iter)))
def main(args):
    """Train the depth-estimation CNN with an L1 loss.

    Builds the training graph, runs the epoch/iteration loop, logs losses and
    visualization tiles through nnabla monitors, and saves a parameter
    checkpoint after every epoch.
    """
    from numpy.random import seed
    seed(46)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context('cudnn', device_id='0', type_config='float')
    nn.set_default_context(ctx)

    # === TRAIN graph ===
    image = nn.Variable([args.batch_size, 3, args.img_height, args.img_width])
    label = nn.Variable([args.batch_size, 1, args.img_height, args.img_width])
    pred = depth_cnn_model(image, test=False)
    pred.persistent = True  # keep pred.d available after the loss forward
    loss = l1_loss(pred, label)

    # === VAL === (kept for reference; currently disabled)
    #vimage = nn.Variable([args.batch_size, 3, args.img_height, args.img_width])
    #vlabel = nn.Variable([args.batch_size, 1, args.img_height, args.img_width])
    #vpred = depth_cnn_model(vimage, test=True)
    #vloss = l1_loss(vpred, vlabel)

    # Prepare monitors. The visualization tile only writes every viz_interval
    # iterations, so the expensive tile is also only built at that cadence.
    viz_interval = 1000
    monitor = Monitor(os.path.join(args.log_dir, 'nnmonitor'))
    monitors = {
        'train_epoch_loss': MonitorSeries('Train epoch loss', monitor, interval=1),
        'train_itr_loss': MonitorSeries('Train itr loss', monitor, interval=100),
        # 'val_epoch_loss': MonitorSeries('Val epoch loss', monitor, interval=1),
        'train_viz': MonitorImageTile('Train images', monitor,
                                      interval=viz_interval, num_images=4)
    }

    # Create Solver. If training from checkpoint, load the info.
    if args.optimizer == "adam":
        solver = S.Adam(alpha=args.learning_rate, beta1=0.9, beta2=0.999)
    elif args.optimizer == "sgd":
        solver = S.Momentum(lr=args.learning_rate, momentum=0.9)
    else:
        # BUG FIX: previously an unknown optimizer left `solver` undefined
        # and crashed later with NameError; fail fast with a clear message.
        raise ValueError("Unsupported optimizer: {}".format(args.optimizer))
    solver.set_parameters(nn.get_parameters())

    # Initialize DataIterator.
    data_dic = prepare_dataloader(args.dataset_path,
                                  datatype_list=['train', 'val'],
                                  batch_size=args.batch_size,
                                  img_size=(args.img_height, args.img_width))

    # Training loop.
    logger.info("Start training!!!")
    total_itr_index = 0
    for epoch in range(1, args.epochs + 1):
        # === training ===
        total_train_loss = 0
        index = 0
        while index < data_dic['train']['size']:
            # Forward / backward / update.
            image.d, label.d = data_dic['train']['itr'].next()
            loss.forward(clear_no_need_grad=True)
            solver.zero_grad()
            loss.backward(clear_buffer=True)
            if args.optimizer == 'sgd':
                solver.weight_decay(1e-4)
            solver.update()

            # Update counters and per-iteration loss monitor.
            index += 1
            total_itr_index += 1
            total_train_loss += loss.d
            monitors['train_itr_loss'].add(total_itr_index, loss.d)

            # Visualization — only at iterations where the tile monitor
            # actually writes (was previously recomputed every iteration).
            if total_itr_index % viz_interval == 0:
                pred.forward(clear_buffer=True)
                train_viz = np.concatenate([
                    image.d,
                    convert_depth2colormap(label.d),
                    convert_depth2colormap(pred.d)
                ], axis=3)
                monitors['train_viz'].add(total_itr_index, train_viz)

            # Logger
            logger.info("[{}] {}/{} Train Loss {} ({})".format(
                epoch, index, data_dic['train']['size'],
                total_train_loss / index, loss.d))

        # Pass the epoch-averaged training loss to its monitor.
        train_error = total_train_loss / data_dic['train']['size']
        monitors['train_epoch_loss'].add(epoch, train_error)

        # Save a parameter checkpoint for this epoch.
        out_param_file = os.path.join(
            args.log_dir, 'checkpoint' + str(epoch) + '.h5')
        nn.save_parameters(out_param_file)