Example #1
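# Note: the original snippet omits its module-level context. It appears to
# assume roughly the following imports and a parsed argparse namespace
# (exact module paths are assumptions, not shown in the source):
#   import math, os, random
#   import numpy as np
#   import cupy as cp
#   import chainer
#   import chainer.functions as cf
#   import matplotlib.pyplot as plt
#   from chainer.backends import cuda
#   import draw  # project-local package (nn.functions, data, ...)
#   # HyperParameters, Model, AdamOptimizer, to_gpu, make_uint8 and `args`
#   # come from the surrounding project code.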
def main():
    os.makedirs(args.snapshot_directory, exist_ok=True)

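    # Load every .npy file in the dataset directory and rescale pixel
    # values from [0, 255] to [-1, 1].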
    images = []
    files = os.listdir(args.dataset_path)
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 255 * 2.0 - 1.0
        images.append(image)

    images = np.vstack(images)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    num_dev_images = num_images - num_train_images
    images_train = images[:num_train_images]
    images_dev = images[num_train_images:]

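    # Pick the array backend: NumPy on CPU, CuPy when a GPU device id is given.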
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.layer_normalization_enabled = args.layer_normalization
    hyperparams.pixel_n = args.pixel_n
    hyperparams.chz_channels = args.chz_channels
    hyperparams.inference_channels_downsampler_x = args.channels_downsampler_x
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    hyperparams.chrz_size = (32, 32)
    hyperparams.save(args.snapshot_directory)
    hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_directory)
    if using_gpu:
        model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              lr_i=args.initial_lr,
                              lr_f=args.final_lr)
    optimizer.print()

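    # Per-pixel output variance buffers. sigma_t starts at pixel_sigma_i and
    # is annealed toward pixel_sigma_f inside the training loop; image_size
    # is assumed to be provided by HyperParameters (it is not set above).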
    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        sigma_t**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(sigma_t**2),
                           dtype="float32")
    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

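    # Live visualization: training input, its reconstruction, a dev input,
    # its reconstruction, and an unconditional sample from the prior.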
    figure = plt.figure(figsize=(20, 4))
    axis_1 = figure.add_subplot(1, 5, 1)
    axis_2 = figure.add_subplot(1, 5, 2)
    axis_3 = figure.add_subplot(1, 5, 3)
    axis_4 = figure.add_subplot(1, 5, 4)
    axis_5 = figure.add_subplot(1, 5, 5)

    for iteration in range(args.training_steps):
        # Sample a random minibatch from the training split.
        batch_indices = np.random.choice(
            num_train_images, size=args.batch_size, replace=False)
        x = to_gpu(images_train[batch_indices])
        loss_kld = 0

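        # Run the inference/generation networks, then sum the KL between the
        # step-wise posterior q(z_t | x) and prior p(z_t) over all steps.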
        z_t_params_array, r_final = model.generate_z_params_and_x_from_posterior(
            x)
        for params in z_t_params_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
            kld = draw.nn.functions.gaussian_kl_divergence(
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
            loss_kld += cf.sum(kld)

        mean_x_enc = r_final
        negative_log_likelihood = draw.nn.functions.gaussian_negative_log_likelihood(
            x, mean_x_enc, pixel_var, pixel_ln_var)
        loss_nll = cf.sum(negative_log_likelihood)
        loss_mse = cf.mean_squared_error(mean_x_enc, x)

        loss_nll /= args.batch_size
        loss_kld /= args.batch_size
        loss = loss_nll + loss_kld
        model.cleargrads()
        loss.backward()
        optimizer.update(iteration)

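        # Linearly anneal the output sigma from pixel_sigma_i down to
        # pixel_sigma_f over the first pixel_n iterations, then hold it.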
        sf = hyperparams.pixel_sigma_f
        si = hyperparams.pixel_sigma_i
        sigma_t = max(sf + (si - sf) * (1.0 - iteration / hyperparams.pixel_n),
                      sf)

        pixel_var[...] = sigma_t**2
        pixel_ln_var[...] = math.log(sigma_t**2)

        model.serialize(args.snapshot_directory)
        print(
            "\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - sigma_t: {:.6f}"
            .format(iteration + 1,
                    float(loss_nll.data) / num_pixels, float(loss_mse.data),
                    float(loss_kld.data), optimizer.learning_rate, sigma_t))

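        # Every 10 iterations, refresh the plots: reconstructions of a
        # training image and a held-out dev image, plus a prior sample.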
        if iteration % 10 == 0:
            axis_1.imshow(make_uint8(x[0]))
            axis_2.imshow(make_uint8(mean_x_enc.data[0]))

            x_dev = images_dev[random.randrange(num_dev_images)]
            axis_3.imshow(make_uint8(x_dev))

            with chainer.using_config("train", False), chainer.using_config(
                    "enable_backprop", False):
                x_dev = to_gpu(x_dev)[None, ...]
                _, r_final = model.generate_z_params_and_x_from_posterior(
                    x_dev)
                mean_x_enc = r_final
                axis_4.imshow(make_uint8(mean_x_enc.data[0]))

                mean_x_d = model.generate_image(batch_size=1, xp=xp)
                axis_5.imshow(make_uint8(mean_x_d[0]))

            plt.pause(0.01)
Example #2
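# Note: as with Example #1, the module-level context is omitted. This
# data-parallel variant additionally appears to assume (paths are assumptions):
#   import time
#   from collections import deque
#   import chainermn
#   # GRUModel, LSTMModel, printr and the rest come from the project code.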
def main():
    os.makedirs(args.snapshot_directory, exist_ok=True)

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    cuda.get_device(device).use()
    xp = cp

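    # Shard the file list across MPI ranks: sort for a deterministic order,
    # rotate by rank * subset_size, and take the first subset_size entries.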
    images = []
    files = os.listdir(args.dataset_path)
    files.sort()
    subset_size = int(math.ceil(len(files) / comm.size))
    files = deque(files)
    files.rotate(-subset_size * comm.rank)
    files = list(files)[:subset_size]
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 256
        images.append(image)

    print(comm.rank, files)  # debug output: this rank's file shard

    images = np.vstack(images)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    num_dev_images = num_images - num_train_images
    images_train = images[:num_train_images]

    # To avoid OpenMPI bug
    # multiprocessing.set_start_method("forkserver")
    # p = multiprocessing.Process(target=print, args=("", ))
    # p.start()
    # p.join()

    hyperparams = HyperParameters()
    hyperparams.chz_channels = args.chz_channels
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.generator_downsampler_channels = args.generator_downsampler_channels
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.inference_downsampler_channels = args.inference_downsampler_channels
    hyperparams.batch_normalization_enabled = args.enable_batch_normalization
    hyperparams.use_gru = args.use_gru
    hyperparams.no_backprop_diff_xr = args.no_backprop_diff_xr

    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)
        hyperparams.print()

    if args.use_gru:
        model = GRUModel(hyperparams,
                         snapshot_directory=args.snapshot_directory)
    else:
        model = LSTMModel(hyperparams,
                          snapshot_directory=args.snapshot_directory)
    model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              lr_i=args.initial_lr,
                              lr_f=args.final_lr,
                              beta_1=args.adam_beta1,
                              communicator=comm)
    if comm.rank == 0:
        optimizer.print()

    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

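    # Iterate over the local shard in minibatches of args.batch_size.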
    dataset = draw.data.Dataset(images_train)
    iterator = draw.data.Iterator(dataset, batch_size=args.batch_size)

    num_updates = 0

    for iteration in range(args.training_steps):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        start_time = time.time()

        for batch_index, data_indices in enumerate(iterator):
            x = dataset[data_indices]
            x += np.random.uniform(0, 1 / 256, size=x.shape)
            x = to_gpu(x)

            z_t_param_array, x_param, r_t_array = model.sample_z_and_x_params_from_posterior(
                x)

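            # KL term between the step-wise posterior and prior, summed over
            # all generation steps.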
            loss_kld = 0
            for params in z_t_param_array:
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                kld = draw.nn.functions.gaussian_kl_divergence(
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                loss_kld += cf.sum(kld)

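            # Auxiliary sum-of-squared-errors between each intermediate
            # canvas r_t and the target, weighted by args.loss_alpha below.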
            loss_sse = 0
            for r_t in r_t_array:
                loss_sse += cf.sum(cf.squared_error(r_t, x))

            mu_x, ln_var_x = x_param

            loss_nll = cf.gaussian_nll(x, mu_x, ln_var_x)

            loss_nll /= args.batch_size
            loss_kld /= args.batch_size
            loss_sse /= args.batch_size
            loss = args.loss_beta * loss_nll + loss_kld + args.loss_alpha * loss_sse

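            # Backprop with the optimizer's loss scale (useful for fp16),
            # then apply the distributed Adam update.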
            model.cleargrads()
            loss.backward(loss_scale=optimizer.loss_scale())
            optimizer.update(num_updates, loss_value=float(loss.array))

            num_updates += 1
            mean_kld += float(loss_kld.data)
            mean_nll += float(loss_nll.data)
            mean_mse += float(loss_sse.data) / num_pixels / (
                hyperparams.generator_generation_steps - 1)

            printr(
                "Iteration {}: Batch {} / {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e}"
                .format(
                    iteration + 1, batch_index + 1, len(iterator),
                    float(loss_nll.data) / num_pixels + math.log(256.0),
                    float(loss_sse.data) / num_pixels /
                    (hyperparams.generator_generation_steps - 1),
                    float(loss_kld.data), optimizer.learning_rate))

            if comm.rank == 0 and batch_index > 0 and batch_index % 100 == 0:
                model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            print(
                "\r\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - elapsed_time: {:.3f} min"
                .format(
                    iteration + 1,
                    mean_nll / len(iterator) / num_pixels + math.log(256.0),
                    mean_mse / len(iterator), mean_kld / len(iterator),
                    optimizer.learning_rate, elapsed_time / 60))
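Both examples delegate the latent loss to draw.nn.functions.gaussian_kl_divergence, a project-local helper that is not shown here. For reference, here is a minimal sketch of what such a helper typically computes: the closed-form KL divergence between two diagonal Gaussians, written with plain chainer.functions (the name kl_diagonal_gaussians and its argument order are illustrative assumptions, not the project's actual API):

import chainer.functions as cf

def kl_diagonal_gaussians(mean_q, ln_var_q, mean_p, ln_var_p):
    # Elementwise KL(q || p) for diagonal Gaussians:
    #   0.5 * (ln var_p - ln var_q + (var_q + (mu_q - mu_p)^2) / var_p - 1)
    var_q = cf.exp(ln_var_q)
    var_p = cf.exp(ln_var_p)
    return 0.5 * (ln_var_p - ln_var_q +
                  (var_q + cf.square(mean_q - mean_p)) / var_p - 1.0)

Summed over elements with cf.sum and divided by the batch size, this matches how loss_kld is accumulated in both training loops above.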