def main(args):
    """Render one validation view with a trained IDR model.

    Loads parameters from ``args.model_load_path``, renders the view at
    ``conf.valid_index`` and writes the image into the directory that
    contains the model file.
    """
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = conf.data_path
    R = conf.n_rays  # rays sampled per image by the data source
    ctx = get_extension_context('cudnn', device_id=device_id)
    nn.set_default_context(ctx)

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)

    # Monitor: outputs go next to the loaded model file.
    monitor_path = "/".join(args.model_load_path.split("/")[0:-1])
    monitor = Monitor(monitor_path)
    monitor_image = MonitorImage("Rendered image synthesis",
                                 monitor, interval=1)

    # Load model
    nn.load_parameters(args.model_load_path)

    # Render the single validation view (slice keeps a batch axis of 1).
    pose = ds.poses[conf.valid_index:conf.valid_index+1, ...]
    intrinsic = ds.intrinsics[conf.valid_index:conf.valid_index+1, ...]
    mask_obj = ds.masks[conf.valid_index:conf.valid_index+1, ...]
    image = render(pose, intrinsic, mask_obj, conf)
    monitor_image.add(conf.valid_index, image)
def morph(args):
    """Generate images morphing between two classes of a conditional GAN.

    A single truncated latent sample is held fixed while the class
    conditioning is blended from ``args.from_class_id`` to
    ``args.to_class_id`` over ``args.n_morphs`` steps.
    """
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model: beta is defined on the graph as 1 - alpha, so updating alpha.d
    # keeps the two class-embedding coefficients a convex combination.
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    alpha = nn.Variable.from_numpy_array(np.zeros([1, 1]))
    beta = (nn.Variable.from_numpy_array(np.ones([1, 1])) - alpha)
    y_fake_a = nn.Variable([batch_size])
    y_fake_b = nn.Variable([batch_size])
    x_fake = generator(z, [y_fake_a, y_fake_b], maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn,
                       coefs=[alpha, beta]).apply(persistent=True)

    # Monitor
    monitor = Monitor(args.monitor_path)
    name = "Morphed Image {} {}".format(args.from_class_id, args.to_class_id)
    monitor_image = MonitorImage(name, monitor, interval=1, num_images=1,
                                 normalize_method=normalize_method)

    # Morph: only the mixing coefficient is swept; note alpha ranges over
    # [0, (n_morphs-1)/n_morphs], never reaching exactly 1.0.
    z_data = resample(batch_size, latent, threshold)
    z.d = z_data
    for i in range(args.n_morphs):
        alpha.d = 1.0 * i / args.n_morphs
        y_fake_a.d = generate_one_class(args.from_class_id, batch_size)
        y_fake_b.d = generate_one_class(args.to_class_id, batch_size)
        x_fake.forward(clear_buffer=True)
        monitor_image.add(i, x_fake.d)
def generate(args):
    """Translate images between domains with a trained MUNIT model.

    For every image of domain A, produces ``args.num_repeats`` translations
    into domain B (and vice versa).  Styles are random draws unless
    ``args.example_guided`` is set, in which case the style is encoded from
    one example image of the target domain.
    """
    # Context — set BEFORE loading parameters so they are created under the
    # cudnn context, consistent with the other entry points in this file.
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Load model
    nn.load_parameters(args.model_load_path)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # generate over domains (domain A): random style or example-encoded style
    z_style_a = F.randn(
        shape=x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # generate over domains (domain B): random style or example-encoded style
    z_style_b = F.randn(
        shape=x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_a = MonitorImage("Fake Image B to A {} Valid".format(suffix),
                                   monitor, interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B {} Valid".format(suffix),
                                   monitor, interval=1)

    # DataIterator
    di_a = munit_data_iterator(args.img_path_a, args.batch_size)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size)

    # generate (A -> B); each saved row is [real | fake_1 | ... | fake_n]
    if args.example_guided:
        x_real_b.d = di_b.next()[0]
    for i in range(di_a.size):
        x_real_a.d = di_a.next()[0]
        images = [x_real_a.d.copy()]
        for _ in range(args.num_repeats):
            x_fake_b.forward(clear_buffer=True)
            images.append(x_fake_b.d.copy())
        monitor_image_b.add(i, np.concatenate(images, axis=3))

    # generate (B -> A)
    if args.example_guided:
        x_real_a.d = di_a.next()[0]
    for i in range(di_b.size):
        x_real_b.d = di_b.next()[0]
        images = [x_real_b.d.copy()]
        for _ in range(args.num_repeats):
            x_fake_a.forward(clear_buffer=True)
            images.append(x_fake_a.d.copy())
        monitor_image_a.add(i, np.concatenate(images, axis=3))
def generate(args):
    """Sample images from a trained class-conditional generator.

    When ``args.generate_all`` is set, one image tile is produced per class;
    otherwise a single batch is generated for ``args.class_id`` (or for
    randomly drawn classes when it is -1).
    """
    # Communicator and Context
    ctx = get_extension_context("cudnn", type_config=args.type_config)
    nn.set_default_context(ctx)

    # Generation settings
    latent, maps = args.latent, args.maps
    batch_size, image_size = args.batch_size, args.image_size
    n_classes, not_sn = args.n_classes, args.not_sn
    threshold = args.truncation_threshold

    # Inference graph on top of the loaded parameters
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn).apply(persistent=True)

    monitor = Monitor(args.monitor_path)

    # Generate one tile for every class, then stop.
    if args.generate_all:
        monitor_image = MonitorImageTile("Generated Image Tile All", monitor,
                                         interval=1,
                                         num_images=args.batch_size,
                                         normalize_method=normalize_method)
        for class_id in range(args.n_classes):
            z.d = resample(batch_size, latent, threshold)
            y_fake.d = generate_one_class(class_id, batch_size)
            x_fake.forward(clear_buffer=True)
            monitor_image.add(class_id, x_fake.d)
        return

    # Generate individually for one (or a random) class.
    if args.class_id != -1:
        tile_name = "Generated Image Tile {}".format(args.class_id)
        image_name = "Generated Image {}".format(args.class_id)
    else:
        tile_name = "Generated Image Tile"
        image_name = "Generated Image"
    monitor_image_tile = MonitorImageTile(tile_name, monitor, interval=1,
                                          num_images=args.batch_size,
                                          normalize_method=normalize_method)
    monitor_image = MonitorImage(image_name, monitor, interval=1,
                                 num_images=args.batch_size,
                                 normalize_method=normalize_method)

    z.d = resample(batch_size, latent, threshold)
    if args.class_id == -1:
        y_fake.d = generate_random_class(n_classes, batch_size)
    else:
        y_fake.d = generate_one_class(args.class_id, batch_size)
    x_fake.forward(clear_buffer=True)
    monitor_image.add(0, x_fake.d)
    monitor_image_tile.add(0, x_fake.d)
def make_monitor_image(name):
    """Build a MonitorImage that maps [-1, 1] tensors onto [0, 255]."""
    def denormalize_unit_range(x):
        # inverse of the usual x/127.5 - 1 preprocessing
        return (x + 1.0) * 127.5

    return MonitorImage(name, monitor, interval=1,
                        normalize_method=denormalize_unit_range)
def main(args):
    """Train an IDR model on a DTU-MVS scene.

    Runs ``conf.train_epoch`` epochs of the combined IDR loss
    (color + mask + eikonal), periodically rendering the validation view
    and checkpointing parameters, with alpha/learning-rate decay at the
    epochs listed in the config.
    """
    # Setting
    device_id = args.device_id
    conf = args.conf
    path = args.conf.data_path
    B = conf.batch_size
    R = conf.n_rays

    # Dataset
    ds = DTUMVSDataSource(path, R, shuffle=True)
    di = data_iterator_dtumvs(ds, B)

    # Graph inputs
    camloc = nn.Variable([B, 3])
    raydir = nn.Variable([B, R, 3])
    alpha = nn.Variable.from_numpy_array(conf.alpha)
    color_gt = nn.Variable([B, R, 3])
    mask_obj = nn.Variable([B, R, 1])

    # Monitor
    interval = di.size
    monitor_path = create_monitor_path(conf.data_path, args.monitor_path)
    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=interval)
    monitor_mhit = MonitorSeries("Hit count", monitor, interval=1)
    monitor_color_loss = MonitorSeries(
        "Training color loss", monitor, interval=interval)
    monitor_mask_loss = MonitorSeries(
        "Training mask loss", monitor, interval=interval)
    monitor_eikonal_loss = MonitorSeries(
        "Training eikonal loss", monitor, interval=interval)
    monitor_time = MonitorTimeElapsed(
        "Training time", monitor, interval=interval)
    monitor_image = MonitorImage("Rendered image", monitor, interval=1)

    # Solver
    solver = S.Adam(conf.learning_rate)
    loss, color_loss, mask_loss, eikonal_loss, mask_hit = \
        idr_loss(camloc, raydir, alpha, color_gt, mask_obj, conf)
    solver.set_parameters(nn.get_parameters())

    # Defined once, outside the loop: the original defined this inside the
    # skip_val branch, so the final validate(i) call raised NameError when
    # validation was skipped during training.
    def validate(i):
        # Render the held-out view and checkpoint the parameters.
        pose_ = ds.poses[conf.valid_index:conf.valid_index+1, ...]
        intrinsic_ = ds.intrinsics[conf.valid_index:conf.valid_index+1, ...]
        mask_obj_ = ds.masks[conf.valid_index:conf.valid_index+1, ...]
        image = render(pose_, intrinsic_, mask_obj_, conf)
        monitor_image.add(i, image)
        nn.save_parameters(f"{monitor_path}/model_{i:05d}.h5")

    # Training loop
    for i in range(conf.train_epoch):
        ds.change_sampling_idx()

        # Validate
        if i % conf.valid_epoch_interval == 0 and not args.skip_val:
            validate(i)

        # Train
        for j in range(di.size):
            # Feed data
            color_, mask_, intrinsic_, pose_, xy_ = di.next()
            color_gt.d = color_
            mask_obj.d = mask_
            raydir_, camloc_ = generate_raydir_camloc(pose_, intrinsic_, xy_)
            raydir.d = raydir_
            camloc.d = camloc_

            # Network
            loss.forward()
            solver.zero_grad()
            loss.backward(clear_buffer=True)
            solver.update()

            # Monitor
            t = i * di.size + j
            monitor_mhit.add(t, np.sum(mask_hit.d))
            monitor_loss.add(t, loss.d)
            monitor_color_loss.add(t, color_loss.d)
            monitor_mask_loss.add(t, mask_loss.d)
            monitor_eikonal_loss.add(t, eikonal_loss.d)
            monitor_time.add(t)

        # Decay
        if i in conf.alpha_decay:
            alpha.d = alpha.d * 2.0
        if i in conf.lr_decay:
            solver.set_learning_rate(solver.learning_rate() * 0.5)

    # Final render + checkpoint after the last epoch.
    validate(i)
def match(args):
    """Find the real training images closest to a generated image.

    Generates one image for ``args.class_id``, embeds both it and every
    real image of that class with an Inception model loaded from an nnp
    file, ranks real images by squared feature distance, and saves the
    generated image together with its top-n matches.
    """
    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = 1
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model (SAGAN)
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn).apply(persistent=True)

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, h = get_input_and_output(nnp, batch_size, args.variable_name)

    # DataIterator for a given class_id
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                batch_size=batch_size,
                                n_classes=args.n_classes,
                                noise=False,
                                class_id=args.class_id)

    # Monitor
    monitor = Monitor(args.monitor_path)
    name = "Matched Image {}".format(args.class_id)
    monitor_image = MonitorImage(name, monitor, interval=1,
                                 num_images=batch_size,
                                 normalize_method=lambda x: (x + 1.) / 2. * 255.)
    name = "Matched Image Tile {}".format(args.class_id)
    monitor_image_tile = MonitorImageTile(name, monitor, interval=1,
                                          num_images=batch_size + args.top_n,
                                          normalize_method=lambda x: (x + 1.) / 2. * 255.)

    # Generate one fake image and compute its Inception feature.
    z.d = resample(batch_size, latent, threshold)
    y_fake.d = generate_one_class(args.class_id, batch_size)
    x_fake.forward(clear_buffer=True)
    x_fake_d = preprocess(
        x_fake.d.copy(), (args.image_size, args.image_size),
        args.nnp_preprocess)
    x.d = x_fake_d
    h.forward(clear_buffer=True)
    h_fake_d = h.d.copy()

    # Feature matching: squared L2 distance between the fake image's
    # feature and each real image's feature.
    norm2_list = []
    x_data_list = [x_fake.d]  # index 0 holds the generated image
    for i in range(di.size):
        # forward for real data
        x_d, _ = di.next()
        x_data_list.append(x_d)
        x_d = preprocess(
            x_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_d
        h.forward(clear_buffer=True)
        h_real_d = h.d.copy()
        # norm computation over all non-batch axes
        axis = tuple(np.arange(1, len(h.shape)).tolist())
        norm2 = np.sum((h_real_d - h_fake_d) ** 2.0, axis=axis)
        norm2_list.append(norm2)

    # Save top-n closest images.  The original code computed argsort on the
    # un-flattened list (sorting along the wrong axis) and then indexed with
    # the loop counter, saving the first images instead of the best matches,
    # and tiled every image rather than the batch_size + top_n the tile
    # monitor was sized for.
    argmins = np.argsort(np.concatenate(norm2_list))
    top_images = [x_data_list[0]]  # generated image first
    for i in range(args.top_n):
        matched = x_data_list[argmins[i] + 1]  # +1: fake occupies index 0
        monitor_image.add(i, matched)
        top_images.append(matched)
    matched_images = np.concatenate(top_images)
    monitor_image_tile.add(0, matched_images)
def train(args):
    """Train a MUNIT model (two-domain image translation) on multiple devices.

    Builds within-domain reconstruction, cross-domain generation and
    multi-scale discriminator graphs for domains A and B, then alternates
    generator and discriminator updates with gradient all-reduce across
    devices.  Only local rank 0 writes monitors and checkpoints.
    """
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    # One device per local rank; the context is rebound before use.
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domain A)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(
        x_style_rec_a, z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(
        x_content_rec_b, x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(
        x_style_rec_b, z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(
        x_content_rec_a, x_content_a).apply(persistent=True)

    # adversarial: sum the per-scale discriminator losses
    def f(x, y):
        return x + y
    loss_gen_a = reduce(
        f, [lsgan_loss(p_f) for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(
        f, [lsgan_loss(p_f) for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)

    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A", monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor, interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor, interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B", monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor, interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor, interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train", monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train", monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator: distinct RNG seeds per device so devices see
    # different shuffles.
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train; iteration counters are scaled by n_devices so the schedule
    # matches the single-device configuration.
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators; gradients are blocked at the fakes so the
        # generator is untouched by the discriminator step.
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    # Final monitor/save after the loop (may overwrite the last in-loop
    # checkpoint when the loop ended on a save interval).
    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         "param_{:05d}.h5".format(i)))
def train_single_scale(args, index, model, reals, prev_models, Zs,
                       noise_amps, monitor):
    """Train one scale of a SinGAN-style pyramid.

    ``index`` selects the scale; ``prev_models``/``Zs``/``noise_amps`` hold
    the already-trained coarser scales used (via ``_draw_concat``) to build
    the conditioning input.  Returns the fixed reconstruction noise
    ``z_opt`` for this scale.

    NOTE(review): ``z_prev``, ``noise`` and ``prev`` deliberately leak
    across loop iterations — ``z_prev`` is set only on the first d-step of
    epoch 0 and reused by every generator step afterwards, and the
    generator loop reuses the last d-step's ``noise``/``prev``.
    """
    # prepare log monitors
    monitor_train_d_real = MonitorSeries('train_d_real%d' % index, monitor)
    monitor_train_d_fake = MonitorSeries('train_d_fake%d' % index, monitor)
    monitor_train_g_fake = MonitorSeries('train_g_fake%d' % index, monitor)
    monitor_train_g_rec = MonitorSeries('train_g_rec%d' % index, monitor)
    monitor_image_g = MonitorImage('image_g_%d' % index, monitor, interval=1,
                                   num_images=1, normalize_method=denormalize)

    real = reals[index]
    # spatial dims of this scale's real image (shape assumed NCHW)
    ch, w, h = real.shape[1], real.shape[2], real.shape[3]

    # training loop
    for epoch in range(args.niter):
        d_real_error_history = []
        d_fake_error_history = []
        g_fake_error_history = []
        g_rec_error_history = []
        # Coarsest scale uses single-channel noise; finer scales use
        # full-channel noise on top of the upsampled previous output.
        if index == 0:
            z_opt = np.random.normal(0.0, 1.0, size=(1, 1, w, h))
            noise_ = np.random.normal(0.0, 1.0, size=(1, 1, w, h))
        else:
            z_opt = np.zeros((1, ch, w, h))
            noise_ = np.random.normal(0.0, 1.0, size=(1, ch, w, h))

        # discriminator training loop
        for d_step in range(args.d_steps):
            # previous outputs
            if d_step == 0 and epoch == 0:
                if index == 0:
                    # no coarser scale: condition on zeros
                    prev = np.zeros_like(noise_)
                    z_prev = np.zeros_like(z_opt)
                    args.noise_amp = 1
                else:
                    # draw conditioning from the trained coarser scales and
                    # calibrate the noise amplitude to the reconstruction
                    # error of those scales
                    prev = _draw_concat(args, index, prev_models, Zs, reals,
                                        noise_amps, 'rand')
                    z_prev = _draw_concat(args, index, prev_models, Zs, reals,
                                          noise_amps, 'rec')
                    rmse = np.sqrt(np.mean((real - z_prev)**2))
                    args.noise_amp = args.noise_amp_init * rmse
            else:
                prev = _draw_concat(args, index, prev_models, Zs, reals,
                                    noise_amps, 'rand')
            # input noise
            if index == 0:
                noise = noise_
            else:
                noise = args.noise_amp * noise_ + prev
            fake_error, real_error = model.update_d(epoch, noise, prev)
            # accumulate errors for logging
            d_real_error_history.append(real_error)
            d_fake_error_history.append(fake_error)

        # generator training loop (reuses the last d-step's noise/prev)
        for g_step in range(args.g_steps):
            noise_rec = args.noise_amp * z_opt + z_prev
            fake_error, rec_error = model.update_g(epoch, noise, prev,
                                                   noise_rec, z_prev)
            # accumulate errors for logging
            g_fake_error_history.append(fake_error)
            g_rec_error_history.append(rec_error)

        # save errors (averaged over the steps of this epoch)
        monitor_train_d_real.add(epoch, np.mean(d_real_error_history))
        monitor_train_d_fake.add(epoch, np.mean(d_fake_error_history))
        monitor_train_g_fake.add(epoch, np.mean(g_fake_error_history))
        monitor_train_g_rec.add(epoch, np.mean(g_rec_error_history))
        # save generated image
        monitor_image_g.add(epoch, model.generate(noise, prev))

    return z_opt