def load_marcua_dataset():
    for expid, (prop, answers, truth) in enumerate(data_iterator()):
        cur_wid = 0
        wids = {}
        rowid = 0
        for ans in answers:
            # Normalize worker ids to dense integers per experiment
            if ans[1] not in wids:
                wids[ans[1]] = cur_wid
                cur_wid += 1
            row = list(ans)
            row[1] = wids[ans[1]]
            row.append(rowid)
            rowid += 1
            row.insert(0, prop)
            row.insert(0, expid)
            row.append(truth)
            yield row
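# A minimal usage sketch for load_marcua_dataset, traced by hand. The schema
# below is an assumption (not from the original code): data_iterator() yields
# (prop, answers, truth) triples where each answer is a
# (question_id, worker_id, value) tuple. Each yielded row then reads
# [expid, prop, question_id, normalized_worker_id, value, rowid, truth].
#
#   def data_iterator():  # hypothetical stand-in for the real iterator
#       yield ("color", [("q0", "worker_a", "red"),
#                        ("q0", "worker_b", "blue"),
#                        ("q1", "worker_a", "red")], "red")
#
#   for row in load_marcua_dataset():
#       print(row)
#   # [0, 'color', 'q0', 0, 'red', 0, 'red']
#   # [0, 'color', 'q0', 1, 'blue', 1, 'red']
#   # [0, 'color', 'q1', 0, 'red', 2, 'red']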
def main():
    # Args
    args = get_args()
    save_args(args)

    # Context
    ctx = extension_context(args.context,
                            device_id=args.device_id,
                            type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Data Iterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)

    # Model
    generator = Generator(use_bn=args.use_bn,
                          last_act=args.last_act,
                          use_wscale=args.not_use_wscale,
                          use_he_backward=args.use_he_backward)
    discriminator = Discriminator(use_ln=args.use_ln,
                                  alpha=args.leaky_alpha,
                                  use_wscale=args.not_use_wscale,
                                  use_he_backward=args.use_he_backward)

    # Solver
    solver_gen = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1, beta2=args.beta2)
    solver_dis = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1, beta2=args.beta2)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis = MonitorSeries("Discriminator Loss", monitor, interval=10)
    monitor_p_fake = MonitorSeries("Fake Probability", monitor, interval=10)
    monitor_p_real = MonitorSeries("Real Probability", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training Time per Resolution",
                                      monitor, interval=1)
    monitor_image_tile = MonitorImageTileWithName(
        "Image Tile", monitor, num_images=4,
        normalize_method=lambda x: (x + 1.) / 2.)

    # TODO: use argument
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(di, generator, discriminator, solver_gen, solver_dis,
                      args.monitor_path,
                      monitor_loss_gen, monitor_loss_dis,
                      monitor_p_fake, monitor_p_real,
                      monitor_time, monitor_image_tile,
                      resolution_list, channel_list,
                      n_latent=args.latent, n_critic=args.critic,
                      save_image_interval=args.save_image_interval,
                      hyper_sphere=args.hyper_sphere,
                      l2_fake_weight=args.l2_fake_weight)

    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution)
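# The --hyper-sphere flag above is only forwarded to the Trainer. For
# reference, a minimal sketch of what hyper-sphere latent sampling presumably
# means here (this helper and its name are illustrative; the Trainer's actual
# sampling code may differ): draw z ~ N(0, I), then rescale each latent
# vector onto the unit hypersphere.
import numpy as np


def sample_latent_sketch(batch_size, n_latent, hyper_sphere=True,
                         rng=np.random):
    z = rng.randn(batch_size, n_latent, 1, 1)
    if hyper_sphere:
        # Normalize each latent vector to unit length along the channel axis.
        norm = np.sqrt(np.sum(z ** 2, axis=1, keepdims=True))
        z = z / norm
    return z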
def main():
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    num_batches = num_images // args.batch_size

    # DataIterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)

    # Generator
    gen = load_gen(args.model_load_path,
                   use_bn=args.use_bn,
                   last_act=args.last_act,
                   use_wscale=args.not_use_wscale,
                   use_he_backward=args.use_he_backward)

    # Compute metric
    if args.validation_metric == "ms-ssim":
        logger.info("Multi Scale SSIM")
        monitor_time = MonitorTimeElapsed("MS-SSIM-ValidationTime",
                                          monitor, interval=1)
        monitor_metric = MonitorSeries("MS-SSIM", monitor, interval=1)
        from ms_ssim import compute_metric
        score = compute_metric(gen, args.batch_size, num_images,
                               args.latent, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)
    elif args.validation_metric == "swd":
        logger.info("Sliced Wasserstein Distance")
        monitor_time = MonitorTimeElapsed("SWD-ValidationTime",
                                          monitor, interval=1)
        monitor_metric = MonitorSeries("SWD", monitor, interval=1)
        nhoods_per_image = 128
        nhood_size = 7
        level_list = [128, 64, 32, 16]  # TODO: use argument
        dir_repeats = 4
        dirs_per_repeat = 128
        from sliced_wasserstein import compute_metric
        score = compute_metric(di, gen, args.latent, num_batches,
                               nhoods_per_image, nhood_size, level_list,
                               dir_repeats, dirs_per_repeat,
                               args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)  # averaged in the log
    else:
        logger.info("Set `validation-metric` as either `ms-ssim` or `swd`.")
        # Bail out here: `score` is undefined on this path.
        return

    logger.info(score)
    logger.info("End validation")
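# sliced_wasserstein.compute_metric is treated as a black box above. For
# reference, a minimal sketch of the sliced Wasserstein distance itself, in
# the style of the PGGAN paper: project two sets of patch descriptors onto
# random unit directions, sort the projections, and average the absolute
# differences (the 1-D Wasserstein distance). This is an illustrative
# re-implementation, not the module's actual code.
import numpy as np


def sliced_wasserstein_distance(A, B, dir_repeats=4, dirs_per_repeat=128,
                                rng=np.random):
    # A, B: (num_descriptors, descriptor_dim) arrays of patch descriptors.
    dists = []
    for _ in range(dir_repeats):
        dirs = rng.randn(A.shape[1], dirs_per_repeat)
        dirs /= np.sqrt(np.sum(dirs ** 2, axis=0, keepdims=True))  # unit norm
        proj_a = np.sort(A.dot(dirs), axis=0)
        proj_b = np.sort(B.dot(dirs), axis=0)
        dists.append(np.mean(np.abs(proj_a - proj_b)))
    return np.mean(dists)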
def train(args):
    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    aug_list = args.aug_list

    # Model
    scope_gen = "Generator"
    scope_dis = "Discriminator"

    # Generator loss
    z = nn.Variable([args.batch_size, args.latent, 1, 1])
    x_fake = Generator(z, scope_name=scope_gen, img_size=args.image_size)
    p_fake = Discriminator([augment(xf, aug_list) for xf in x_fake],
                           label="fake", scope_name=scope_dis)
    lossG = loss_gen(p_fake)

    # Discriminator loss
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    x_real_aug = augment(x_real, aug_list)
    p_real, rec_imgs, part = Discriminator(
        x_real_aug, label="real", scope_name=scope_dis)
    lossD_fake = loss_dis_fake(p_fake)
    lossD_real = loss_dis_real(p_real, rec_imgs, part, x_real_aug)
    lossD = lossD_fake + lossD_real

    # Generator with fixed latent values for test
    # Use train=True even in an inference phase
    z_test = nn.Variable.from_numpy_array(
        np.random.randn(args.batch_size, args.latent, 1, 1))
    x_test = Generator(z_test, scope_name=scope_gen, train=True,
                       img_size=args.image_size)[0]

    # Exponential Moving Average (EMA) model
    # Use train=True even in an inference phase
    scope_gen_ema = "Generator_EMA"
    x_test_ema = Generator(z_test, scope_name=scope_gen_ema, train=True,
                           img_size=args.image_size)[0]
    copy_params(scope_gen, scope_gen_ema)
    update_ema_var = make_ema_updater(scope_gen_ema, scope_gen, 0.999)

    # Solver
    solver_gen = S.Adam(args.lr, beta1=0.5)
    solver_dis = S.Adam(args.lr, beta1=0.5)
    with nn.parameter_scope(scope_gen):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope(scope_dis):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries(
        "Generator Loss", monitor, interval=10)
    monitor_loss_dis_real = MonitorSeries(
        "Discriminator Loss Real", monitor, interval=10)
    monitor_loss_dis_fake = MonitorSeries(
        "Discriminator Loss Fake", monitor, interval=10)
    monitor_time = MonitorTimeElapsed(
        "Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile(
        "Image Tile Train", monitor, num_images=args.batch_size,
        interval=1, normalize_method=lambda x: (x + 1.) / 2.)
    monitor_image_tile_test = MonitorImageTile(
        "Image Tile Test", monitor, num_images=args.batch_size,
        interval=1, normalize_method=lambda x: (x + 1.) / 2.)
    monitor_image_tile_test_ema = MonitorImageTile(
        "Image Tile Test EMA", monitor, num_images=args.batch_size,
        interval=1, normalize_method=lambda x: (x + 1.) / 2.)
    # Data Iterator
    rng = np.random.RandomState(141)
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.image_size, args.image_size),
                       num_samples=args.train_samples, rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake[0].need_grad = False  # no need to backprop into the generator
        x_fake[1].need_grad = False  # no need to backprop into the generator
        solver_dis.zero_grad()
        x_real.d = di.next()[0]
        z.d = np.random.randn(args.batch_size, args.latent, 1, 1)
        lossD.forward()
        lossD.backward()
        solver_dis.update()

        # Train generator
        x_fake[0].need_grad = True  # need to backprop into the generator
        x_fake[1].need_grad = True  # need to backprop into the generator
        solver_gen.zero_grad()
        lossG.forward()
        lossG.backward()
        solver_gen.update()

        # Update EMA model
        update_ema_var.forward()

        # Monitor
        monitor_loss_gen.add(i, lossG.d)
        monitor_loss_dis_real.add(i, lossD_real.d)
        monitor_loss_dis_fake.add(i, lossD_fake.d)
        monitor_time.add(i)

        # Save
        if (i + 1) % args.save_interval == 0:
            with nn.parameter_scope(scope_gen):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "Gen_iter{}.h5".format(i + 1)))
            with nn.parameter_scope(scope_gen_ema):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "GenEMA_iter{}.h5".format(i + 1)))
            with nn.parameter_scope(scope_dis):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "Dis_iter{}.h5".format(i + 1)))
        if (i + 1) % args.test_interval == 0:
            x_test.forward(clear_buffer=True)
            x_test_ema.forward(clear_buffer=True)
            monitor_image_tile_train.add(i + 1, x_fake[0])
            monitor_image_tile_test.add(i + 1, x_test)
            monitor_image_tile_test_ema.add(i + 1, x_test_ema)

    # Last
    x_test.forward(clear_buffer=True)
    x_test_ema.forward(clear_buffer=True)
    monitor_image_tile_train.add(args.max_iter, x_fake[0])
    monitor_image_tile_test.add(args.max_iter, x_test)
    monitor_image_tile_test_ema.add(args.max_iter, x_test_ema)
    with nn.parameter_scope(scope_gen):
        nn.save_parameters(os.path.join(
            args.monitor_path, "Gen_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_gen_ema):
        nn.save_parameters(os.path.join(
            args.monitor_path, "GenEMA_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_dis):
        nn.save_parameters(os.path.join(
            args.monitor_path, "Dis_iter{}.h5".format(args.max_iter)))
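# make_ema_updater and copy_params are used in train() but defined elsewhere.
# A minimal sketch of what the EMA updater might look like in nnabla (an
# assumption about the helper, not its actual definition): build a graph of
# F.assign ops so that a single forward() blends the live generator
# parameters into the EMA copy in place.
import nnabla as nn
import nnabla.functions as F


def make_ema_updater_sketch(scope_ema, scope, decay=0.999):
    with nn.parameter_scope(scope_ema):
        params_ema = nn.get_parameters()
    with nn.parameter_scope(scope):
        params = nn.get_parameters()
    updates = []
    for name, p_ema in params_ema.items():
        p = params[name]
        # p_ema <- decay * p_ema + (1 - decay) * p
        updates.append(F.assign(p_ema, decay * p_ema + (1.0 - decay) * p))
    # Sink bundles all assigns so one forward() applies every update.
    return F.sink(*updates)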