def compute_frechet_inception_distance(z, y_fake, x_fake, x, y, args, di=None):
    """Compute the Frechet Inception Distance (FID) between fake and real images.

    FID = ||mu_r - mu_f||^2 + Tr(C_r + C_f - 2 * sqrtm(C_r @ C_f)), where the
    means/covariances are taken over feature vectors `y` extracted from
    generated images and from real batches drawn from `di`.

    Args:
        z: latent input Variable fed to the generator.
        y_fake: class-label input Variable fed to the generator.
        x_fake: generator output Variable (fake images).
        x: input Variable of the feature-extraction network.
        y: output Variable of the feature-extraction network.
        args: namespace providing max_iter, batch_size, latent, n_classes,
            image_size, nnp_preprocess.
        di: data iterator yielding `(images, labels)` batches of real data.
            Required despite the default.

    Returns:
        float: the FID score (lower is better).

    Raises:
        ValueError: if `di` is None.
    """
    if di is None:
        # The original code dereferenced `di` unconditionally and would fail
        # with an opaque AttributeError; fail fast with a clear message.
        raise ValueError("A data iterator `di` is required to compute FID.")

    h_fakes = []
    h_reals = []
    for i in range(args.max_iter):
        logger.info("Compute at {}-th batch".format(i))
        # Generate a batch of fake images.
        z.d = np.random.randn(args.batch_size, args.latent)
        y_fake.d = generate_random_class(args.n_classes, args.batch_size)
        x_fake.forward(clear_buffer=True)
        # Extract features for the fake batch.
        x_fake_d = x_fake.d.copy()
        x_fake_d = preprocess(
            x_fake_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_fake_d
        y.forward(clear_buffer=True)
        h_fakes.append(y.d.copy().squeeze())
        # Extract features for a real batch.
        x_d, _ = di.next()
        x_d = preprocess(
            x_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_d
        y.forward(clear_buffer=True)
        h_reals.append(y.d.copy().squeeze())
    h_fakes = np.concatenate(h_fakes)
    h_reals = np.concatenate(h_reals)

    # FID score.
    ave_h_real = np.mean(h_reals, axis=0)
    ave_h_fake = np.mean(h_fakes, axis=0)
    cov_h_real = np.cov(h_reals, rowvar=False)
    cov_h_fake = np.cov(h_fakes, rowvar=False)
    # sqrtm frequently returns a complex-valued matrix due to numerical error
    # even when the true square root is real; keep only the real part,
    # otherwise the returned score itself becomes complex.
    cov_mean = sqrtm(np.dot(cov_h_real, cov_h_fake))
    if np.iscomplexobj(cov_mean):
        cov_mean = cov_mean.real
    score = np.sum((ave_h_real - ave_h_fake) ** 2) \
        + np.trace(cov_h_real + cov_h_fake - 2.0 * cov_mean)
    return score
def compute_inception_score(z, y_fake, x_fake, x, y, args):
    """Compute the Inception Score (IS) of generated images.

    IS = exp(E_x[KL(p(y|x) || p(y))]), where p(y|x) are the classifier
    predictions for generated images and p(y) is their marginal.

    Args:
        z: latent input Variable fed to the generator.
        y_fake: class-label input Variable fed to the generator.
        x_fake: generator output Variable (fake images).
        x: input Variable of the classifier network.
        y: output Variable of the classifier network (class probabilities).
        args: namespace providing max_iter, batch_size, latent, n_classes,
            image_size, nnp_preprocess.

    Returns:
        float: the Inception Score (higher is better).
    """
    preds = []
    for i in range(args.max_iter):
        logger.info("Compute at {}-th batch".format(i))
        # Generate a batch of fake images.
        z.d = np.random.randn(args.batch_size, args.latent)
        y_fake.d = generate_random_class(args.n_classes, args.batch_size)
        x_fake.forward(clear_buffer=True)
        # Predict class probabilities for the fake batch.
        x_fake_d = x_fake.d.copy()
        x_fake_d = preprocess(
            x_fake_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_fake_d
        y.forward(clear_buffer=True)
        preds.append(y.d.copy())
    p_yx = np.concatenate(preds)

    # Score. An epsilon guards against log(0): softmax outputs can underflow
    # to exactly 0, which previously turned the whole score into NaN.
    eps = 1e-12
    p_y = np.mean(p_yx, axis=0)
    kld = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    score = np.exp(np.mean(kld))
    return score
def generate(args):
    """Generate images with a trained conditional generator and write them
    through nnabla monitors.

    When ``args.generate_all`` is set, one image tile per class is produced.
    Otherwise a single batch is generated, either for ``args.class_id`` or
    for randomly sampled classes when it is -1.
    """
    # Context setup (cudnn backend).
    ctx = get_extension_context("cudnn", type_config=args.type_config)
    nn.set_default_context(ctx)

    # Shorthands pulled from args.
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Build the generator graph over the loaded parameters.
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn).apply(persistent=True)

    if args.generate_all:
        # One tile per class.
        monitor = Monitor(args.monitor_path)
        monitor_image = MonitorImageTile("Generated Image Tile All",
                                         monitor, interval=1,
                                         num_images=args.batch_size,
                                         normalize_method=normalize_method)
        for class_id in range(args.n_classes):
            z.d = resample(batch_size, latent, threshold)
            y_fake.d = generate_one_class(class_id, batch_size)
            x_fake.forward(clear_buffer=True)
            monitor_image.add(class_id, x_fake.d)
        return

    # Single batch: one specific class, or random classes when class_id is -1.
    monitor = Monitor(args.monitor_path)
    suffix = " {}".format(args.class_id) if args.class_id != -1 else ""
    monitor_image_tile = MonitorImageTile("Generated Image Tile" + suffix,
                                          monitor, interval=1,
                                          num_images=args.batch_size,
                                          normalize_method=normalize_method)
    monitor_image = MonitorImage("Generated Image" + suffix,
                                 monitor, interval=1,
                                 num_images=args.batch_size,
                                 normalize_method=normalize_method)
    z.d = resample(batch_size, latent, threshold)
    if args.class_id == -1:
        y_fake.d = generate_random_class(n_classes, batch_size)
    else:
        y_fake.d = generate_one_class(args.class_id, batch_size)
    x_fake.forward(clear_buffer=True)
    monitor_image.add(0, x_fake.d)
    monitor_image_tile.add(0, x_fake.d)
def train(args):
    """Train a conditional GAN (generator + discriminator) with multi-device
    data-parallel training via an nnabla communicator.

    Alternates discriminator and generator updates with gradient
    accumulation, periodically averages weights across devices, and (on
    rank 0) saves parameters/images and logs losses.
    """
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn)

    # Solver (separate Adam optimizers and learning rates for G and D)
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor (only the rank-0 process writes logs and images)
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile(
            "Image Tile Train", monitor,
            num_images=args.batch_size, interval=1,
            normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile(
            "Image Tile Test", monitor,
            num_images=args.batch_size, interval=1,
            normalize_method=normalize_method)

    # DataIterator (per-device RNG so each device draws different batches)
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        # Accumulate gradients over args.accum_grad sub-batches before update.
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            # Scale by accum steps * device count so the allreduced sum
            # averages correctly.
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image (rank 0 only)
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    # Final save after the loop (uses the last value of `i`).
    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)