def main():
    """Train the DRAW-style generative model on a single device.

    Loads every ``.npy`` array found in ``args.dataset_path``, scales pixels
    from [0, 255] to [-1, 1], then trains for ``args.training_steps``
    iterations on one fixed batch while annealing the observation noise
    sigma_t. Every 10 iterations, reconstructions and an unconditional
    sample are visualized with matplotlib.
    """
    # Create the snapshot directory; tolerate only "already exists" instead
    # of the original bare `except: pass`, which hid real errors such as
    # permission failures.
    try:
        os.mkdir(args.snapshot_directory)
    except FileExistsError:
        pass

    # Load all dataset arrays and normalize pixels to [-1, 1].
    images = []
    files = os.listdir(args.dataset_path)
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 255 * 2.0 - 1.0
        images.append(image)
    images = np.vstack(images)
    # NHWC -> NCHW, the layout the convolutional model consumes.
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)

    # Train/dev split. BUG FIX: the original sliced both sets at
    # args.batch_size instead of num_train_images, so the computed split
    # ratio was ignored and the dev set overlapped the training prefix.
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    num_dev_images = num_images - num_train_images
    images_train = images[:num_train_images]
    images_dev = images[num_train_images:]

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.layer_normalization_enabled = args.layer_normalization
    hyperparams.pixel_n = args.pixel_n
    hyperparams.chz_channels = args.chz_channels
    hyperparams.inference_channels_downsampler_x = args.channels_downsampler_x
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    hyperparams.chrz_size = (32, 32)
    hyperparams.save(args.snapshot_directory)
    hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_directory)
    if using_gpu:
        model.to_gpu()

    optimizer = AdamOptimizer(
        model.parameters, lr_i=args.initial_lr, lr_f=args.final_lr)
    optimizer.print()

    # Observation-noise buffers; sigma_t is annealed inside the loop from
    # pixel_sigma_i down to pixel_sigma_f over the first pixel_n iterations.
    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        sigma_t**2,
        dtype="float32")
    pixel_ln_var = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(sigma_t**2),
        dtype="float32")
    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

    figure = plt.figure(figsize=(20, 4))
    axis_1 = figure.add_subplot(1, 5, 1)
    axis_2 = figure.add_subplot(1, 5, 2)
    axis_3 = figure.add_subplot(1, 5, 3)
    axis_4 = figure.add_subplot(1, 5, 4)
    axis_5 = figure.add_subplot(1, 5, 5)

    for iteration in range(args.training_steps):
        # The original effectively trained on one fixed batch (the first
        # args.batch_size images); keep that behavior, now drawn from the
        # correctly sized training split.
        x = to_gpu(images_train[:args.batch_size])
        loss_kld = 0

        z_t_params_array, r_final = model.generate_z_params_and_x_from_posterior(
            x)

        # Sum KL(q(z_t|x) || p(z_t)) over all generation steps.
        for params in z_t_params_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
            kld = draw.nn.functions.gaussian_kl_divergence(
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
            loss_kld += cf.sum(kld)

        mean_x_enc = r_final
        negative_log_likelihood = draw.nn.functions.gaussian_negative_log_likelihood(
            x, mean_x_enc, pixel_var, pixel_ln_var)
        loss_nll = cf.sum(negative_log_likelihood)
        loss_mse = cf.mean_squared_error(mean_x_enc, x)

        loss_nll /= args.batch_size
        loss_kld /= args.batch_size

        # NOTE(review): the original computed `loss = loss_nll + loss_kld`
        # and immediately overwrote it with `loss = loss_nll`, so the KL
        # term never contributed to the gradient. Behavior preserved here —
        # confirm whether excluding the KL term is intentional.
        loss = loss_nll

        model.cleargrads()
        loss.backward()
        optimizer.update(iteration)

        # Linear sigma annealing, clamped at the final value sf.
        sf = hyperparams.pixel_sigma_f
        si = hyperparams.pixel_sigma_i
        sigma_t = max(
            sf + (si - sf) * (1.0 - iteration / hyperparams.pixel_n), sf)
        pixel_var[...] = sigma_t**2
        pixel_ln_var[...] = math.log(sigma_t**2)

        model.serialize(args.snapshot_directory)

        print(
            "\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - sigma_t: {:.6f}"
            .format(iteration + 1,
                    float(loss_nll.data) / num_pixels, float(loss_mse.data),
                    float(loss_kld.data), optimizer.learning_rate, sigma_t))

        if iteration % 10 == 0:
            # Panels: training input, its reconstruction, a held-out dev
            # image, its reconstruction, and an unconditional sample.
            axis_1.imshow(make_uint8(x[0]))
            axis_2.imshow(make_uint8(mean_x_enc.data[0]))

            x_dev = images_dev[random.choice(range(num_dev_images))]
            axis_3.imshow(make_uint8(x_dev))

            with chainer.using_config("train", False), chainer.using_config(
                    "enable_backprop", False):
                x_dev = to_gpu(x_dev)[None, ...]
                _, r_final = model.generate_z_params_and_x_from_posterior(
                    x_dev)
                mean_x_enc = r_final
                axis_4.imshow(make_uint8(mean_x_enc.data[0]))

                mean_x_d = model.generate_image(batch_size=1, xp=xp)
                axis_5.imshow(make_uint8(mean_x_d[0]))

            plt.pause(0.01)
def main():
    """Data-parallel training loop (ChainerMN) for the DRAW-style model.

    Each MPI rank loads a distinct subset of the dataset files, trains with
    minibatch updates through an all-reduce communicator, and rank 0 is
    responsible for logging and model snapshots.
    """
    # Create the snapshot directory; tolerate only "already exists" instead
    # of the original bare `except: pass`, which hid real errors.
    try:
        os.mkdir(args.snapshot_directory)
    except FileExistsError:
        pass

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    cuda.get_device(device).use()
    xp = cp

    # Shard the file list across ranks. Sorting makes the order identical
    # on every rank; each rank then takes a subset_size window starting at
    # rank * subset_size, with deque rotation handling wrap-around when the
    # file count is not evenly divisible by comm.size.
    images = []
    files = os.listdir(args.dataset_path)
    files.sort()
    subset_size = int(math.ceil(len(files) / comm.size))
    files = deque(files)
    files.rotate(-subset_size * comm.rank)
    files = list(files)[:subset_size]
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        # Scale pixels to [0, 1); uniform dequantization noise in
        # [0, 1/256) is added per batch inside the training loop.
        image = image / 256
        images.append(image)
    print(comm.rank, files)

    images = np.vstack(images)
    # NHWC -> NCHW for the convolutional model.
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)

    # Train/dev split; only the training portion is consumed below.
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    num_dev_images = num_images - num_train_images
    images_train = images[:num_train_images]

    # To avoid OpenMPI bug
    # multiprocessing.set_start_method("forkserver")
    # p = multiprocessing.Process(target=print, args=("", ))
    # p.start()
    # p.join()

    hyperparams = HyperParameters()
    hyperparams.chz_channels = args.chz_channels
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.generator_downsampler_channels = args.generator_downsampler_channels
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.inference_downsampler_channels = args.inference_downsampler_channels
    hyperparams.batch_normalization_enabled = args.enable_batch_normalization
    hyperparams.use_gru = args.use_gru
    hyperparams.no_backprop_diff_xr = args.no_backprop_diff_xr

    # Only rank 0 writes hyperparameters to disk and to the console.
    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)
        hyperparams.print()

    if args.use_gru:
        model = GRUModel(
            hyperparams, snapshot_directory=args.snapshot_directory)
    else:
        model = LSTMModel(
            hyperparams, snapshot_directory=args.snapshot_directory)
    model.to_gpu()

    optimizer = AdamOptimizer(
        model.parameters,
        lr_i=args.initial_lr,
        lr_f=args.final_lr,
        beta_1=args.adam_beta1,
        communicator=comm)
    if comm.rank == 0:
        optimizer.print()

    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

    dataset = draw.data.Dataset(images_train)
    iterator = draw.data.Iterator(dataset, batch_size=args.batch_size)

    num_updates = 0

    for iteration in range(args.training_steps):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        start_time = time.time()

        for batch_index, data_indices in enumerate(iterator):
            x = dataset[data_indices]
            # Uniform dequantization noise matching the 1/256 pixel scale.
            x += np.random.uniform(0, 1 / 256, size=x.shape)
            x = to_gpu(x)

            z_t_param_array, x_param, r_t_array = model.sample_z_and_x_params_from_posterior(
                x)

            # Sum KL(q || p) over all generation steps.
            loss_kld = 0
            for params in z_t_param_array:
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                kld = draw.nn.functions.gaussian_kl_divergence(
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                loss_kld += cf.sum(kld)

            # Auxiliary sum-of-squares reconstruction term over the
            # intermediate canvases r_t.
            loss_sse = 0
            for r_t in r_t_array:
                loss_sse += cf.sum(cf.squared_error(r_t, x))

            mu_x, ln_var_x = x_param
            loss_nll = cf.gaussian_nll(x, mu_x, ln_var_x)

            loss_nll /= args.batch_size
            loss_kld /= args.batch_size
            loss_sse /= args.batch_size
            loss = args.loss_beta * loss_nll + loss_kld + args.loss_alpha * loss_sse

            model.cleargrads()
            loss.backward(loss_scale=optimizer.loss_scale())
            optimizer.update(num_updates, loss_value=float(loss.array))
            num_updates += 1

            mean_kld += float(loss_kld.data)
            mean_nll += float(loss_nll.data)
            mean_mse += float(loss_sse.data) / num_pixels / (
                hyperparams.generator_generation_steps - 1)

            # math.log(256.0) converts the NLL of the [0, 1)-scaled pixels
            # back to the 8-bit scale for reporting.
            printr(
                "Iteration {}: Batch {} / {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e}"
                .format(
                    iteration + 1, batch_index + 1, len(iterator),
                    float(loss_nll.data) / num_pixels + math.log(256.0),
                    float(loss_sse.data) / num_pixels /
                    (hyperparams.generator_generation_steps - 1),
                    float(loss_kld.data), optimizer.learning_rate))

            # Periodic mid-epoch checkpoint from rank 0 only.
            if comm.rank == 0 and batch_index > 0 and batch_index % 100 == 0:
                model.serialize(args.snapshot_directory)

        # End-of-iteration checkpoint and summary (the original had two
        # consecutive `if comm.rank == 0:` blocks here; merged, same
        # behavior).
        if comm.rank == 0:
            model.serialize(args.snapshot_directory)
            elapsed_time = time.time() - start_time
            print(
                "\r\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - elapsed_time: {:.3f} min"
                .format(
                    iteration + 1,
                    mean_nll / len(iterator) / num_pixels + math.log(256.0),
                    mean_mse / len(iterator), mean_kld / len(iterator),
                    optimizer.learning_rate, elapsed_time / 60))