Example #1
0
def main():
    """Train the model with a Chainer ``Trainer`` on one or more GPUs.

    All settings are read from the module-level ``args`` namespace. The
    function loads the HDF5 training (and optional test) datasets into
    memory, restores the meter, model and variance scheduler from
    ``args.snapshot_directory``, builds an updater matching ``args.ngpu``
    and runs the trainer for ``args.epochs`` epochs.

    Raises:
        NotImplementedError: if ``args.ngpu`` is less than 1; only the
            single-GPU and multi-GPU paths build an updater.
    """
    os.environ['CHAINER_SEED'] = str(args.seed)
    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])

    _mkdir(args.snapshot_directory)
    _mkdir(args.log_directory)

    meter_train = Meter()
    meter_train.load(args.snapshot_directory)

    #==============================================================================
    # Dataset
    #==============================================================================
    def _as_examples(images, viewpoints):
        """Pair every image with the viewpoint at the same index."""
        images = np.array(images)
        viewpoints = np.array(viewpoints)
        return [
            {'image': image, 'viewpoint': viewpoints[index]}
            for index, image in enumerate(images)
        ]

    def read_npy_files(directory):
        """Load paired ``images/*.npy`` and ``viewpoints/*.npy`` files.

        Not called below; kept as an alternative loader to ``read_files``.
        """
        filenames = sorted(
            name for name in os.listdir(os.path.join(directory, "images"))
            if name.endswith(".npy"))

        dataset_images = []
        dataset_viewpoints = []
        for filename in filenames:
            tmp_images = np.load(os.path.join(directory, "images", filename))
            tmp_viewpoints = np.load(
                os.path.join(directory, "viewpoints", filename))

            assert tmp_images.shape[0] == tmp_viewpoints.shape[0]

            dataset_images.extend(tmp_images)
            dataset_viewpoints.extend(tmp_viewpoints)

        return _as_examples(dataset_images, dataset_viewpoints)

    def read_files(directory):
        """Load every ``*.h5`` file in ``directory`` into a list of examples.

        Each file must provide an ``images`` and a ``viewpoints`` dataset.
        Files are visited in sorted filename order.
        """
        filenames = sorted(
            name for name in os.listdir(directory) if name.endswith(".h5"))

        dataset_images = []
        dataset_viewpoints = []
        for filename in filenames:
            # Open read-only and close the handle as soon as the arrays have
            # been copied into memory, instead of leaking one handle per file.
            with h5py.File(os.path.join(directory, filename), "r") as h5_file:
                dataset_images.extend(h5_file["images"][...])
                dataset_viewpoints.extend(h5_file["viewpoints"][...])

        return _as_examples(dataset_images, dataset_viewpoints)

    dataset_train = read_files(args.train_dataset_directory)
    if args.test_dataset_directory is not None:
        dataset_test = read_files(args.test_dataset_directory)

    #==============================================================================
    # Hyperparameters
    #==============================================================================
    hyperparams = HyperParameters()
    hyperparams.num_layers = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.r_channels = args.r_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_sigma_annealing_steps = args.pixel_sigma_annealing_steps
    hyperparams.initial_pixel_sigma = args.initial_pixel_sigma
    hyperparams.final_pixel_sigma = args.final_pixel_sigma

    hyperparams.save(args.snapshot_directory)
    print(hyperparams, "\n")

    #==============================================================================
    # Model
    #==============================================================================
    model = Model(hyperparams)
    model.load(args.snapshot_directory, meter_train.epoch)

    #==============================================================================
    # Pixel-variance annealing
    #==============================================================================
    variance_scheduler = PixelVarianceScheduler(
        sigma_start=args.initial_pixel_sigma,
        sigma_end=args.final_pixel_sigma,
        final_num_updates=args.pixel_sigma_annealing_steps)
    variance_scheduler.load(args.snapshot_directory)
    print(variance_scheduler, "\n")

    # Created on the host with NumPy and handed to the updater as-is.
    pixel_log_sigma = np.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(variance_scheduler.standard_deviation),
        dtype="float32")

    #==============================================================================
    # Selecting the GPU
    #==============================================================================
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # GPU 0 is the main device, GPUs 1..ngpu-1 are registered as 'sub_<id>'.
        devices = {'main': gpu_id}
        for gid in range(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    #==============================================================================
    # Logging
    #==============================================================================
    csv = DataFrame()
    csv.load(args.log_directory)

    #==============================================================================
    # Optimizer
    #==============================================================================
    # No learning rate is passed here; the AnnealLearningRate extension
    # registered below receives the optimizer together with the initial and
    # final learning rates.
    optimizer = chainer.optimizers.Adam(beta1=0.9, beta2=0.99, eps=1e-8)
    optimizer.setup(model)
    print(optimizer, "\n")

    #==============================================================================
    # Training iterations
    #==============================================================================
    if ngpu > 1:
        # One iterator per device, each fed its own random split of the
        # training set. The original passed the full dataset to every
        # iterator, which discarded the split.
        train_iters = [
            chainer.iterators.MultiprocessIterator(
                subset,
                args.batch_size,
                n_processes=args.number_processes,
                order_sampler=chainer.iterators.ShuffleOrderSampler())
            for subset in chainer.datasets.split_dataset_n_random(
                dataset_train, len(devices))
        ]
        updater = CustomParallelUpdater(
            train_iters,
            optimizer,
            devices,
            converter=chainer.dataset.concat_examples,
            pixel_log_sigma=pixel_log_sigma)
    elif ngpu == 1:
        train_iters = chainer.iterators.SerialIterator(
            dataset_train, args.batch_size, shuffle=True)
        updater = CustomUpdater(
            train_iters,
            optimizer,
            device=0,
            converter=chainer.dataset.concat_examples,
            pixel_log_sigma=pixel_log_sigma)
    else:
        raise NotImplementedError('Implement for single gpu or cpu')

    trainer = chainer.training.Trainer(
        updater, (args.epochs, 'epoch'), args.snapshot_directory)

    trainer.extend(
        AnnealLearningRate(
            initial_lr=args.initial_lr,
            final_lr=args.final_lr,
            annealing_steps=args.pixel_sigma_annealing_steps,
            optimizer=optimizer),
        trigger=(1, 'iteration'))

    trainer.extend(extensions.LogReport())

    # NOTE(review): the snapshot interval reuses args.report_interval_iters
    # but counts epochs, not iterations - confirm this is intended.
    trainer.extend(
        extensions.snapshot(
            filename='snapshot_epoch_{.updater.epoch}',
            savefun=chainer.serializers.save_hdf5,
            target=optimizer.target),
        trigger=(args.report_interval_iters, 'epoch'))

    trainer.extend(extensions.ProgressBar())
    reports = ['epoch', 'main/loss', 'main/bits_per_pixel', 'main/NLL', 'main/MSE']

    # Validation
    if args.test_dataset_directory is not None:
        test_iters = chainer.iterators.SerialIterator(
            dataset_test, args.batch_size * 6, repeat=False, shuffle=False)

        trainer.extend(Validation(
            test_iters, chainer.dataset.concat_examples, optimizer.target,
            variance_scheduler, device=0))

        reports.append('validation/main/bits_per_pixel')
        reports.append('validation/main/NLL')
        reports.append('validation/main/MSE')
    reports.append('elapsed_time')

    trainer.extend(
        extensions.PrintReport(reports),
        trigger=(args.report_interval_iters, 'iteration'))

    trainer.run()
Example #2
0
def main():
    """Train the model with a hand-written loop on the CPU or a single GPU.

    All settings are read from the module-level ``args`` namespace. For each
    of ``args.training_iterations`` passes over the dataset the function
    samples a random set of observation views and one query view per batch,
    minimizes the scaled negative log-likelihood plus KL divergence, anneals
    the pixel variance, and serializes the model after every subset and at
    the end of every pass.
    """
    # Create the snapshot directory (and parents) if needed; unlike a bare
    # ``except: pass`` this still reports real failures such as permissions.
    os.makedirs(args.snapshot_directory, exist_ok=True)

    np.random.seed(0)

    xp = np
    device_gpu = args.gpu_device
    using_gpu = device_gpu >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_directory)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_channels = args.representation_channels
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_variance
    hyperparams.pixel_sigma_f = args.final_pixel_variance
    hyperparams.save(args.snapshot_directory)
    print(hyperparams)

    model = Model(hyperparams,
                  snapshot_directory=args.snapshot_directory,
                  optimized=args.optimized)
    if using_gpu:
        model.to_gpu()

    scheduler = Scheduler(sigma_start=args.initial_pixel_variance,
                          sigma_end=args.final_pixel_variance,
                          final_num_updates=args.pixel_n,
                          snapshot_directory=args.snapshot_directory)
    print(scheduler)

    optimizer = AdamOptimizer(model.parameters,
                              mu_i=args.initial_lr,
                              mu_f=args.final_lr,
                              initial_training_step=scheduler.num_updates)
    print(optimizer)

    # Constant per-pixel variance buffers; refilled in place after every
    # scheduler step below.
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        scheduler.pixel_variance**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(scheduler.pixel_variance**2),
                           dtype="float32")

    representation_shape = (args.batch_size,
                            hyperparams.representation_channels,
                            args.image_size // 4, args.image_size // 4)

    fig = plt.figure(figsize=(9, 3))
    axis_data = fig.add_subplot(1, 3, 1)
    axis_data.set_title("Data")
    axis_data.axis("off")
    axis_reconstruction = fig.add_subplot(1, 3, 2)
    axis_reconstruction.set_title("Reconstruction")
    axis_reconstruction.axis("off")
    axis_generation = fig.add_subplot(1, 3, 3)
    axis_generation.set_title("Generation")
    axis_generation.axis("off")

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        mean_elbo = 0
        total_num_batch = 0
        start_time = time.time()

        for subset_index, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)

                total_views = images.shape[1]

                # Sample number of views
                num_views = random.choice(range(1, total_views + 1))
                observation_view_indices = list(range(total_views))
                random.shuffle(observation_view_indices)
                observation_view_indices = observation_view_indices[:num_views]

                if num_views > 0:
                    observation_images = preprocess_images(
                        images[:, observation_view_indices])
                    observation_query = viewpoints[:, observation_view_indices]
                    representation = model.compute_observation_representation(
                        observation_images, observation_query)
                else:
                    # Not reachable while num_views is drawn from
                    # [1, total_views]; kept as the zero-observation fallback.
                    representation = xp.zeros(representation_shape,
                                              dtype="float32")
                    representation = chainer.Variable(representation)

                # Sample query
                query_index = random.choice(range(total_views))
                query_images = preprocess_images(images[:, query_index])
                query_viewpoints = viewpoints[:, query_index]

                # Transfer to gpu if necessary
                query_images = to_device(query_images, device_gpu)
                query_viewpoints = to_device(query_viewpoints, device_gpu)

                z_t_param_array, mean_x = model.sample_z_and_x_params_from_posterior(
                    query_images, query_viewpoints, representation)

                # Compute loss
                ## KL Divergence
                loss_kld = 0
                for params in z_t_param_array:
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                    kld = gqn.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                    loss_kld += cf.sum(kld)

                ## Negative log-likelihood of generated image
                loss_nll = cf.sum(
                    gqn.functions.gaussian_negative_log_likelihood(
                        query_images, mean_x, pixel_var, pixel_ln_var))

                # Calculate the average loss value
                loss_nll = loss_nll / args.batch_size
                loss_kld = loss_kld / args.batch_size

                loss = loss_nll / scheduler.pixel_variance + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                # From here on the losses are plain floats used for reporting.
                loss_nll = float(loss_nll.data) + math.log(256.0)
                loss_kld = float(loss_kld.data)

                elbo = -(loss_nll + loss_kld)

                loss_mse = float(
                    cf.mean_squared_error(query_images, mean_x).data)

                printr(
                    "Iteration {}: Subset {} / {}: Batch {} / {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.6e} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {}  "
                    .format(iteration + 1,
                            subset_index + 1, len(dataset), batch_index + 1,
                            len(iterator), elbo, loss_nll, loss_mse, loss_kld,
                            optimizer.learning_rate, scheduler.pixel_variance,
                            current_training_step))

                scheduler.step(iteration, current_training_step)
                pixel_var[...] = scheduler.pixel_variance**2
                pixel_ln_var[...] = math.log(scheduler.pixel_variance**2)

                total_num_batch += 1
                current_training_step += 1
                mean_kld += loss_kld
                mean_nll += loss_nll
                mean_mse += loss_mse
                mean_elbo += elbo

            model.serialize(args.snapshot_directory)

            # Visualize the last processed batch; skipped until at least one
            # batch has run so the variables read below are defined.
            if args.with_visualization and total_num_batch > 0:
                axis_data.imshow(make_uint8(query_images[0]),
                                 interpolation="none")
                axis_reconstruction.imshow(make_uint8(mean_x.data[0]),
                                           interpolation="none")

                with chainer.no_backprop_mode():
                    generated_x = model.generate_image(
                        query_viewpoints[None, 0], representation[None, 0])
                    axis_generation.imshow(make_uint8(generated_x[0]),
                                           interpolation="none")
                plt.pause(1e-8)

        elapsed_time = time.time() - start_time
        # Avoid ZeroDivisionError when the dataset produced no batches.
        num_batches = max(total_num_batch, 1)
        print(
            "\033[2KIteration {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.6e} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {} - time: {:.3f} min"
            .format(iteration + 1, mean_elbo / num_batches,
                    mean_nll / num_batches, mean_mse / num_batches,
                    mean_kld / num_batches, optimizer.learning_rate,
                    scheduler.pixel_variance, current_training_step,
                    elapsed_time / 60))
        model.serialize(args.snapshot_directory)
Example #3
0
def main():
    """Train the model across multiple GPUs with ChainerMN.

    All settings are read from the module-level ``args`` namespace. Each
    process selects the GPU given by its intra-node rank, restores the meter,
    model and variance scheduler from ``args.snapshot_directory``, and then
    for every epoch trains on randomly chosen subsets, optionally evaluates
    on the test dataset, and saves the model, scheduler, meter and CSV log.
    """
    _mkdir(args.snapshot_directory)
    _mkdir(args.log_directory)

    meter_train = Meter()
    meter_train.load(args.snapshot_directory)

    #==============================================================================
    # Workaround to fix OpenMPI bug
    #==============================================================================
    multiprocessing.set_start_method("forkserver")
    p = multiprocessing.Process(target=print, args=("", ))
    p.start()
    p.join()

    #==============================================================================
    # Selecting the GPU
    #==============================================================================
    comm = chainermn.create_communicator()
    device = comm.intra_rank
    cuda.get_device(device).use()

    def _print(*args):
        """Print only on rank 0 so output is not repeated per process."""
        if comm.rank == 0:
            print(*args)

    _print("Using {} GPUs".format(comm.size))

    #==============================================================================
    # Dataset
    #==============================================================================
    dataset_train = Dataset(args.train_dataset_directory)
    dataset_test = None
    if args.test_dataset_directory is not None:
        dataset_test = Dataset(args.test_dataset_directory)

    #==============================================================================
    # Hyperparameters
    #==============================================================================
    hyperparams = HyperParameters()
    hyperparams.num_layers = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.r_channels = args.r_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_sigma_annealing_steps = args.pixel_sigma_annealing_steps
    hyperparams.initial_pixel_sigma = args.initial_pixel_sigma
    hyperparams.final_pixel_sigma = args.final_pixel_sigma
    _print(hyperparams, "\n")

    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)

    #==============================================================================
    # Model
    #==============================================================================
    model = Model(hyperparams)
    model.load(args.snapshot_directory, meter_train.epoch)
    model.to_gpu()

    #==============================================================================
    # Pixel-variance annealing
    #==============================================================================
    variance_scheduler = PixelVarianceScheduler(
        sigma_start=args.initial_pixel_sigma,
        sigma_end=args.final_pixel_sigma,
        final_num_updates=args.pixel_sigma_annealing_steps)
    variance_scheduler.load(args.snapshot_directory)
    _print(variance_scheduler, "\n")

    pixel_log_sigma = cp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(variance_scheduler.standard_deviation),
        dtype="float32")

    #==============================================================================
    # Logging
    #==============================================================================
    csv = DataFrame()
    csv.load(args.log_directory)

    #==============================================================================
    # Optimizer
    #==============================================================================
    optimizer = AdamOptimizer(
        model.parameters,
        initial_lr=args.initial_lr,
        final_lr=args.final_lr,
        initial_training_step=variance_scheduler.training_step)
    _print(optimizer, "\n")

    #==============================================================================
    # Algorithms
    #==============================================================================
    def encode_scene(images, viewpoints):
        """Sample observation views and one query view from a batch.

        Returns the scene representation computed from the sampled
        observations, plus the query images and query viewpoints moved to
        the GPU.
        """
        # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
        images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)

        # Sample number of views
        total_views = images.shape[1]
        num_views = random.choice(range(1, total_views + 1))

        # Sample views
        observation_view_indices = list(range(total_views))
        random.shuffle(observation_view_indices)
        observation_view_indices = observation_view_indices[:num_views]

        observation_images = preprocess_images(
            images[:, observation_view_indices])

        observation_query = viewpoints[:, observation_view_indices]
        representation = model.compute_observation_representation(
            observation_images, observation_query)

        # Sample query view
        query_index = random.choice(range(total_views))
        query_images = preprocess_images(images[:, query_index])
        query_viewpoints = viewpoints[:, query_index]

        # Transfer to gpu if necessary
        query_images = cuda.to_gpu(query_images)
        query_viewpoints = cuda.to_gpu(query_viewpoints)

        return representation, query_images, query_viewpoints

    def estimate_ELBO(query_images, z_t_param_array, pixel_mean,
                      pixel_log_sigma):
        """Return (ELBO, bits_per_pixel, negative_log_likelihood, kl_divergence).

        Both the KL term and the log-likelihood are averaged over the actual
        batch size of ``query_images``.
        """
        # Normalize by the real batch size: validation batches are larger
        # than args.batch_size, so dividing the KL term by args.batch_size
        # would put it on a different scale than the log-likelihood.
        batch_size = query_images.shape[0]

        # KL divergence
        # NOTE(review): the values named ln_var_* are passed as log_scale
        # (log standard deviation) - confirm which one the model emits.
        kl_divergence = 0
        for params_t in z_t_param_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params_t
            normal_q = chainer.distributions.Normal(
                mean_z_q, log_scale=ln_var_z_q)
            normal_p = chainer.distributions.Normal(
                mean_z_p, log_scale=ln_var_z_p)
            kld_t = chainer.kl_divergence(normal_q, normal_p)
            kl_divergence += cf.sum(kld_t)
        kl_divergence = kl_divergence / batch_size

        # Negative log-likelihood of generated image
        num_pixels_per_batch = np.prod(query_images.shape[1:])
        normal = chainer.distributions.Normal(
            query_images, log_scale=pixel_log_sigma)

        log_px = cf.sum(normal.log_prob(pixel_mean)) / batch_size
        negative_log_likelihood = -log_px

        # Empirical ELBO
        ELBO = log_px - kl_divergence

        # https://arxiv.org/abs/1604.08772 Section.2
        # https://www.reddit.com/r/MachineLearning/comments/56m5o2/discussion_calculation_of_bitsdims/
        bits_per_pixel = -(ELBO / num_pixels_per_batch - np.log(256)) / np.log(
            2)

        return ELBO, bits_per_pixel, negative_log_likelihood, kl_divergence

    #==============================================================================
    # Training iterations
    #==============================================================================
    random.seed(0)
    np.random.seed(0)
    cp.random.seed(0)

    for epoch in range(args.epochs):
        _print("Epoch {}/{}:".format(
            epoch + 1,
            args.epochs,
        ))
        meter_train.next_epoch()

        # Number of subsets each process trains on per epoch, rounded up.
        subset_indices = list(range(len(dataset_train.subset_filenames)))
        subset_size_per_gpu = len(subset_indices) // comm.size
        if len(subset_indices) % comm.size != 0:
            subset_size_per_gpu += 1

        for subset_loop in range(subset_size_per_gpu):
            # Reshuffle, then take the entry at this process's rank.
            random.shuffle(subset_indices)
            subset_index = subset_indices[comm.rank]
            subset = dataset_train.read(subset_index)
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                #------------------------------------------------------------------------------
                # Scene encoder
                #------------------------------------------------------------------------------
                # images.shape: (batch, views, height, width, channels)
                images, viewpoints = subset[data_indices]
                representation, query_images, query_viewpoints = encode_scene(
                    images, viewpoints)

                #------------------------------------------------------------------------------
                # Compute empirical ELBO
                #------------------------------------------------------------------------------
                # Compute distribution parameters
                (z_t_param_array,
                 pixel_mean) = model.sample_z_and_x_params_from_posterior(
                     query_images, query_viewpoints, representation)

                # Compute ELBO
                (ELBO, bits_per_pixel, negative_log_likelihood,
                 kl_divergence) = estimate_ELBO(query_images, z_t_param_array,
                                                pixel_mean, pixel_log_sigma)

                #------------------------------------------------------------------------------
                # Update parameters
                #------------------------------------------------------------------------------
                loss = -ELBO
                model.cleargrads()
                loss.backward()
                optimizer.update(meter_train.num_updates)

                #------------------------------------------------------------------------------
                # Logging
                #------------------------------------------------------------------------------
                with chainer.no_backprop_mode():
                    mean_squared_error = cf.mean_squared_error(
                        query_images, pixel_mean)
                meter_train.update(
                    ELBO=float(ELBO.data),
                    bits_per_pixel=float(bits_per_pixel.data),
                    negative_log_likelihood=float(
                        negative_log_likelihood.data),
                    kl_divergence=float(kl_divergence.data),
                    mean_squared_error=float(mean_squared_error.data))

                #------------------------------------------------------------------------------
                # Annealing
                #------------------------------------------------------------------------------
                variance_scheduler.update(meter_train.num_updates)
                pixel_log_sigma[...] = math.log(
                    variance_scheduler.standard_deviation)

            if subset_loop % 100 == 0:
                _print("    Subset {}/{}:".format(
                    subset_loop + 1,
                    subset_size_per_gpu,
                ))
                _print("        {}".format(meter_train))
                _print("        lr: {} - sigma: {}".format(
                    optimizer.learning_rate,
                    variance_scheduler.standard_deviation))

        #------------------------------------------------------------------------------
        # Validation
        #------------------------------------------------------------------------------
        meter_test = None
        if dataset_test is not None:
            meter_test = Meter()
            batch_size_test = args.batch_size * 6
            subset_indices_test = list(
                range(len(dataset_test.subset_filenames)))
            pixel_log_sigma_test = cp.full(
                (batch_size_test, 3) + hyperparams.image_size,
                math.log(variance_scheduler.standard_deviation),
                dtype="float32")

            # Rounded down, so every process evaluates the same number of
            # subsets.
            subset_size_per_gpu = len(subset_indices_test) // comm.size

            with chainer.no_backprop_mode():
                for subset_loop in range(subset_size_per_gpu):
                    subset_index = subset_indices_test[subset_loop * comm.size
                                                       + comm.rank]
                    # Read from the test dataset: the indices above come from
                    # dataset_test, and reading dataset_train here would
                    # evaluate on training data.
                    subset = dataset_test.read(subset_index)
                    iterator = gqn.data.Iterator(
                        subset, batch_size=batch_size_test)

                    for data_indices in iterator:
                        images, viewpoints = subset[data_indices]

                        # Scene encoder
                        representation, query_images, query_viewpoints = encode_scene(
                            images, viewpoints)

                        # Compute empirical ELBO
                        (z_t_param_array, pixel_mean
                         ) = model.sample_z_and_x_params_from_posterior(
                             query_images, query_viewpoints, representation)
                        (ELBO, bits_per_pixel, negative_log_likelihood,
                         kl_divergence) = estimate_ELBO(
                             query_images, z_t_param_array, pixel_mean,
                             pixel_log_sigma_test)
                        mean_squared_error = cf.mean_squared_error(
                            query_images, pixel_mean)

                        # Logging
                        meter_test.update(
                            ELBO=float(ELBO.data),
                            bits_per_pixel=float(bits_per_pixel.data),
                            negative_log_likelihood=float(
                                negative_log_likelihood.data),
                            kl_divergence=float(kl_divergence.data),
                            mean_squared_error=float(mean_squared_error.data))

            meter = meter_test.allreduce(comm)

            _print("    Test:")
            _print("        {} - done in {:.3f} min".format(
                meter,
                meter.elapsed_time,
            ))

        # Persist state and report at the end of every epoch, whether or not
        # a test dataset was given.
        # NOTE(review): every rank writes these files; hyperparams.save above
        # is restricted to rank 0 - confirm whether saving should be too.
        model.save(args.snapshot_directory, meter_train.epoch)
        variance_scheduler.save(args.snapshot_directory)
        meter_train.save(args.snapshot_directory)
        csv.save(args.log_directory)

        _print("Epoch {} done in {:.3f} min".format(
            epoch + 1,
            meter_train.epoch_elapsed_time,
        ))
        _print("    {}".format(meter_train))
        _print("    lr: {} - sigma: {} - training_steps: {}".format(
            optimizer.learning_rate,
            variance_scheduler.standard_deviation,
            meter_train.num_updates,
        ))
        _print("    Time elapsed: {:.3f} min".format(
            meter_train.elapsed_time))
def main():
    """Train the GQN model on multiple GPUs using ChainerMN.

    Every MPI process selects the GPU given by its intra-node rank, reads a
    dataset subset per outer step, and minimizes the sum of the pixel
    negative log-likelihood and the KL divergence accumulated over all
    generation steps.  The pixel standard deviation ``sigma_t`` is annealed
    linearly from ``hyperparams.pixel_sigma_i`` down to
    ``hyperparams.pixel_sigma_f`` over ``hyperparams.pixel_n`` steps.

    Only rank 0 saves hyperparameters, dataset statistics and model
    snapshots, and only rank 0 prints progress.  All settings are read from
    the module-level ``args`` object.
    """
    # Every rank executes this, so the directory may already exist.  Only
    # that case is tolerated; real failures (e.g. permissions) now surface
    # here instead of being silently swallowed.
    os.makedirs(args.snapshot_path, exist_ok=True)

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    print("device", device, "/", comm.size)
    cuda.get_device(device).use()
    xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.channels_chz = args.channels_chz
    hyperparams.generator_channels_u = args.channels_u
    hyperparams.inference_channels_map_x = args.channels_map_x
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    if comm.rank == 0:
        hyperparams.save(args.snapshot_path)
        hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    model.to_gpu()

    optimizer = Optimizer(
        model.parameters,
        communicator=comm,
        mu_i=args.initial_lr,
        mu_f=args.final_lr)
    if comm.rank == 0:
        optimizer.print()

    dataset_mean, dataset_std = dataset.load_mean_and_std()

    # Store the unmodified statistics next to the snapshot.
    if comm.rank == 0:
        np.save(os.path.join(args.snapshot_path, "mean.npy"), dataset_mean)
        np.save(os.path.join(args.snapshot_path, "std.npy"), dataset_std)

    # avoid division by zero
    dataset_std += 1e-12

    # Constant per-pixel variance buffers; refilled in place whenever
    # sigma_t changes.
    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        sigma_t**2,
        dtype="float32")
    pixel_ln_var = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(sigma_t**2),
        dtype="float32")

    # Fixed seed: every rank draws the same random sequence.
    random.seed(0)
    subset_indices = list(range(len(dataset.subset_filenames)))

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        total_batch = 0
        subset_size_per_gpu = len(subset_indices) // comm.size
        start_time = time.time()

        for subset_loop in range(subset_size_per_gpu):
            # Reshuffle, then let each rank pick the entry at its own rank.
            random.shuffle(subset_indices)
            subset_index = subset_indices[comm.rank]
            subset = dataset.read(subset_index)
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # preprocessing
                images = (images - dataset_mean) / dataset_std

                # (batch, views, height, width, channels) ->  (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3))

                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if current_training_step == 0 and num_views == 0:
                    num_views = 1  # avoid OpenMPI error

                if num_views > 0:
                    # The first `num_views` views serve as observations.
                    r = model.compute_observation_representation(
                        images[:, :num_views], viewpoints[:, :num_views])
                else:
                    # No observations: use an all-zero representation.
                    r = xp.zeros(
                        (args.batch_size, hyperparams.channels_r) +
                        hyperparams.chrz_size,
                        dtype="float32")
                    r = chainer.Variable(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]
                # transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                h0_gen, c0_gen, u_0, h0_enc, c0_enc = model.generate_initial_state(
                    args.batch_size, xp)

                loss_kld = 0

                hl_enc = h0_enc
                cl_enc = c0_enc
                hl_gen = h0_gen
                cl_gen = c0_gen
                ul_enc = u_0

                xq = model.inference_downsampler.downsample(query_images)

                for l in range(model.generation_steps):
                    inference_core = model.get_inference_core(l)
                    inference_posterior = model.get_inference_posterior(l)
                    generation_core = model.get_generation_core(l)
                    generation_prior = model.get_generation_prior(l)

                    h_next_enc, c_next_enc = inference_core.forward_onestep(
                        hl_gen, hl_enc, cl_enc, xq, query_viewpoints, r)

                    # Posterior parameters and a latent sample drawn from it.
                    mean_z_q = inference_posterior.compute_mean_z(hl_enc)
                    ln_var_z_q = inference_posterior.compute_ln_var_z(hl_enc)
                    ze_l = cf.gaussian(mean_z_q, ln_var_z_q)

                    # Prior parameters from the generator state.
                    mean_z_p = generation_prior.compute_mean_z(hl_gen)
                    ln_var_z_p = generation_prior.compute_ln_var_z(hl_gen)

                    h_next_gen, c_next_gen, u_next_enc = generation_core.forward_onestep(
                        hl_gen, cl_gen, ul_enc, ze_l, query_viewpoints, r)

                    kld = gqn.nn.chainer.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)

                    loss_kld += cf.sum(kld)

                    hl_gen = h_next_gen
                    cl_gen = c_next_gen
                    ul_enc = u_next_enc
                    hl_enc = h_next_enc
                    cl_enc = c_next_enc

                mean_x = model.generation_observation.compute_mean_x(ul_enc)
                negative_log_likelihood = gqn.nn.chainer.functions.gaussian_negative_log_likelihood(
                    query_images, mean_x, pixel_var, pixel_ln_var)
                loss_nll = cf.sum(negative_log_likelihood)

                # Average both terms over the batch.
                loss_nll /= args.batch_size
                loss_kld /= args.batch_size
                loss = loss_nll + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                if comm.rank == 0:
                    printr(
                        "Iteration {}: Subset {} / {}: Batch {} / {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f}".
                        format(iteration + 1, subset_loop * comm.size + 1,
                               len(dataset), batch_index + 1,
                               len(subset) // args.batch_size,
                               float(loss_nll.data), float(loss_kld.data),
                               optimizer.learning_rate, sigma_t))

                # Linear annealing of the pixel sigma, clamped at the final
                # value.
                sf = hyperparams.pixel_sigma_f
                si = hyperparams.pixel_sigma_i
                sigma_t = max(
                    sf + (si - sf) *
                    (1.0 - current_training_step / hyperparams.pixel_n), sf)

                pixel_var[...] = sigma_t**2
                pixel_ln_var[...] = math.log(sigma_t**2)

                total_batch += 1
                # The step counter advances by the number of processes per
                # update.
                current_training_step += comm.size
                mean_kld += float(loss_kld.data)
                mean_nll += float(loss_nll.data)

            # Snapshot after every subset.
            if comm.rank == 0:
                model.serialize(args.snapshot_path)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            print(
                "\033[2KIteration {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f} - step: {} - elapsed_time: {:.3f} min".
                format(iteration + 1, mean_nll / total_batch,
                       mean_kld / total_batch, optimizer.learning_rate,
                       sigma_t, current_training_step, elapsed_time / 60))
            model.serialize(args.snapshot_path)
Example #5
0
def main():
    """Fit the DRAW model to one fixed batch of images and plot progress.

    All ``.npy`` arrays in ``args.dataset_path`` are loaded, rescaled from
    [0, 255] to [-1, 1] and stacked.  The first ``args.batch_size`` images
    form the (single) training batch; the remaining images are used only for
    the held-out reconstruction panels.  Every 10th iteration a five-panel
    figure is refreshed: training input, its reconstruction, a held-out
    image, its reconstruction, and an image generated by the model.

    All settings are read from the module-level ``args`` object.
    """
    # Tolerate only an already-existing directory; other errors surface here
    # rather than being silently swallowed.
    os.makedirs(args.snapshot_directory, exist_ok=True)

    images = []
    # Sorted so that the contents of the training batch are reproducible
    # (os.listdir returns entries in arbitrary order).
    for filename in sorted(os.listdir(args.dataset_path)):
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 255 * 2.0 - 1.0
        images.append(image)

    images = np.vstack(images)
    # (N, height, width, channels) -> (N, channels, height, width)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    images_train = images[:args.batch_size]
    images_dev = images[args.batch_size:]
    # Count the held-out images from the array that is actually indexed
    # below, so the sampled index is always valid.
    num_dev_images = images_dev.shape[0]

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.layer_normalization_enabled = args.layer_normalization
    hyperparams.pixel_n = args.pixel_n
    hyperparams.chz_channels = args.chz_channels
    hyperparams.inference_channels_downsampler_x = args.channels_downsampler_x
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    hyperparams.chrz_size = (32, 32)
    hyperparams.save(args.snapshot_directory)
    hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_directory)
    if using_gpu:
        model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              lr_i=args.initial_lr,
                              lr_f=args.final_lr)
    optimizer.print()

    # Constant per-pixel variance buffers; refilled in place whenever
    # sigma_t changes.
    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        sigma_t**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(sigma_t**2),
                           dtype="float32")
    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

    figure = plt.figure(figsize=(20, 4))
    axis_1 = figure.add_subplot(1, 5, 1)
    axis_2 = figure.add_subplot(1, 5, 2)
    axis_3 = figure.add_subplot(1, 5, 3)
    axis_4 = figure.add_subplot(1, 5, 4)
    axis_5 = figure.add_subplot(1, 5, 5)

    for iteration in range(args.training_steps):
        x = to_gpu(images_train)
        loss_kld = 0

        z_t_params_array, r_final = model.generate_z_params_and_x_from_posterior(
            x)
        for params in z_t_params_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
            kld = draw.nn.functions.gaussian_kl_divergence(
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
            loss_kld += cf.sum(kld)

        mean_x_enc = r_final
        negative_log_likelihood = draw.nn.functions.gaussian_negative_log_likelihood(
            x, mean_x_enc, pixel_var, pixel_ln_var)
        loss_nll = cf.sum(negative_log_likelihood)
        loss_mse = cf.mean_squared_error(mean_x_enc, x)

        # Average both terms over the batch.
        loss_nll /= args.batch_size
        loss_kld /= args.batch_size
        # NOTE(review): only the NLL term is optimized; loss_kld is computed
        # and logged but contributes no gradient.  Confirm this is intended.
        loss = loss_nll
        model.cleargrads()
        loss.backward()
        optimizer.update(iteration)

        # Linear annealing of the pixel sigma, clamped at the final value.
        sf = hyperparams.pixel_sigma_f
        si = hyperparams.pixel_sigma_i
        sigma_t = max(sf + (si - sf) * (1.0 - iteration / hyperparams.pixel_n),
                      sf)

        pixel_var[...] = sigma_t**2
        pixel_ln_var[...] = math.log(sigma_t**2)

        model.serialize(args.snapshot_directory)
        print(
            "\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - sigma_t: {:.6f}"
            .format(iteration + 1,
                    float(loss_nll.data) / num_pixels, float(loss_mse.data),
                    float(loss_kld.data), optimizer.learning_rate, sigma_t))

        if iteration % 10 == 0:
            # Panels 1-2: first training image and its reconstruction.
            axis_1.imshow(make_uint8(x[0]))
            axis_2.imshow(make_uint8(mean_x_enc.data[0]))

            with chainer.using_config("train", False), chainer.using_config(
                    "enable_backprop", False):
                # Panels 3-4: a random held-out image and its
                # reconstruction (skipped when there is no held-out data).
                if num_dev_images > 0:
                    x_dev = images_dev[random.choice(range(num_dev_images))]
                    axis_3.imshow(make_uint8(x_dev))

                    x_dev = to_gpu(x_dev)[None, ...]
                    _, r_final = model.generate_z_params_and_x_from_posterior(
                        x_dev)
                    mean_x_enc = r_final
                    axis_4.imshow(make_uint8(mean_x_enc.data[0]))

                # Panel 5: an image generated by the model.
                mean_x_d = model.generate_image(batch_size=1, xp=xp)
                axis_5.imshow(make_uint8(mean_x_d[0]))

            plt.pause(0.01)
Example #6
0
def main():
    """Train the GQN model on multiple GPUs using ChainerMN.

    Every MPI process selects the GPU given by its intra-node rank.  Images
    are divided by 255 and perturbed with uniform noise in [0, 1/256).  A
    random set of views is encoded into a scene representation, a random
    view is used as the query, and the optimized loss is
    ``nll / scheduler.pixel_variance + kld`` (both averaged over the batch).
    The reported NLL has ``log(256)`` added and the reported ELBO is
    ``-(nll + kld)``.

    Only rank 0 saves hyperparameters and model snapshots and prints
    progress.  All settings are read from the module-level ``args`` object.
    """
    ##############################################
    # To avoid OpenMPI bug
    multiprocessing.set_start_method("forkserver")
    p = multiprocessing.Process(target=print, args=("", ))
    p.start()
    p.join()
    ##############################################

    # Every rank executes this, so the directory may already exist.  Only
    # that case is tolerated; real failures (e.g. permissions) now surface
    # here instead of being silently swallowed.
    os.makedirs(args.snapshot_directory, exist_ok=True)

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    print("device", device, "/", comm.size)
    cuda.get_device(device).use()
    xp = cupy

    dataset = gqn.data.Dataset(args.dataset_directory)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_u_channels = args.u_channels
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.generator_subpixel_convolution_enabled = args.generator_subpixel_convolution_enabled
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.inference_downsampler_channels = args.inference_downsampler_channels
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.representation_channels = args.representation_channels
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_variance
    hyperparams.pixel_sigma_f = args.final_pixel_variance
    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)
        print(hyperparams)

    model = Model(hyperparams, snapshot_directory=args.snapshot_directory)
    model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              communicator=comm,
                              mu_i=args.initial_lr,
                              mu_f=args.final_lr)
    if comm.rank == 0:
        print(optimizer)

    scheduler = Scheduler(sigma_start=args.initial_pixel_variance,
                          sigma_end=args.final_pixel_variance,
                          final_num_updates=args.pixel_n)
    if comm.rank == 0:
        print(scheduler)

    # Constant per-pixel variance buffers; refilled in place after every
    # scheduler step.
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        scheduler.pixel_variance**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(scheduler.pixel_variance**2),
                           dtype="float32")
    num_pixels = hyperparams.image_size[0] * hyperparams.image_size[1] * 3

    # Fixed seed: every rank draws the same random sequence.
    random.seed(0)
    subset_indices = list(range(len(dataset.subset_filenames)))

    representation_shape = (
        args.batch_size,
        hyperparams.representation_channels) + hyperparams.chrz_size

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        mean_elbo = 0
        total_num_batch = 0
        # Ceiling division of the subset count by the number of processes.
        subset_size_per_gpu = len(subset_indices) // comm.size
        if len(subset_indices) % comm.size != 0:
            subset_size_per_gpu += 1
        start_time = time.time()

        for subset_loop in range(subset_size_per_gpu):
            # Reshuffle, then let each rank pick the entry at its own rank.
            random.shuffle(subset_indices)
            subset_index = subset_indices[comm.rank]
            subset = dataset.read(subset_index)
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # NOTE(review): the division by 255 below suggests raw pixel
                # values in [0, 255] rather than [-1, 1] -- confirm.
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) ->  (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = images / 255.0
                # Add uniform noise in [0, 1/256) to every pixel.
                images += np.random.uniform(
                    0, 1.0 / 256.0, size=images.shape).astype(np.float32)

                total_views = images.shape[1]

                # Sample observations
                num_views = random.choice(range(total_views + 1))
                if current_training_step == 0 and num_views == 0:
                    num_views = 1  # avoid OpenMPI error

                observation_view_indices = list(range(total_views))
                random.shuffle(observation_view_indices)
                observation_view_indices = observation_view_indices[:num_views]

                if num_views > 0:
                    representation = model.compute_observation_representation(
                        images[:, observation_view_indices],
                        viewpoints[:, observation_view_indices])
                else:
                    # No observations: use an all-zero representation.
                    representation = xp.zeros(representation_shape,
                                              dtype="float32")
                    representation = chainer.Variable(representation)

                # Sample query
                query_index = random.choice(range(total_views))
                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # Transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                z_t_param_array, mean_x = model.sample_z_and_x_params_from_posterior(
                    query_images, query_viewpoints, representation)

                # Compute loss
                ## KL Divergence
                loss_kld = chainer.Variable(xp.zeros((), dtype=xp.float32))
                for params in z_t_param_array:
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                    kld = gqn.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                    loss_kld += cf.sum(kld)

                ##Negative log-likelihood of generated image
                loss_nll = cf.sum(
                    gqn.functions.gaussian_negative_log_likelihood(
                        query_images, mean_x, pixel_var, pixel_ln_var))

                # Calculate the average loss value
                loss_nll = loss_nll / args.batch_size
                loss_kld = loss_kld / args.batch_size

                loss = (loss_nll / scheduler.pixel_variance) + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                # From here on the losses are plain floats used for logging.
                loss_nll = float(loss_nll.data) + math.log(256.0)
                loss_kld = float(loss_kld.data)

                elbo = -(loss_nll + loss_kld)

                loss_mse = float(
                    cf.mean_squared_error(query_images, mean_x).data)

                if comm.rank == 0:
                    printr(
                        "Iteration {}: Subset {} / {}: Batch {} / {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.5f} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {}  "
                        .format(iteration + 1, subset_loop + 1,
                                subset_size_per_gpu, batch_index + 1,
                                len(iterator), elbo, loss_nll, loss_mse,
                                loss_kld, optimizer.learning_rate,
                                scheduler.pixel_variance,
                                current_training_step))

                total_num_batch += 1
                # The step counter advances by the number of processes per
                # update.
                current_training_step += comm.size
                mean_kld += loss_kld
                mean_nll += loss_nll
                mean_mse += loss_mse
                mean_elbo += elbo

                scheduler.step(current_training_step)
                pixel_var[...] = scheduler.pixel_variance**2
                pixel_ln_var[...] = math.log(scheduler.pixel_variance**2)

            # Snapshot after every subset.
            if comm.rank == 0:
                model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            mean_elbo /= total_num_batch
            mean_nll /= total_num_batch
            mean_mse /= total_num_batch
            mean_kld /= total_num_batch
            print(
                "\033[2KIteration {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.5f} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {} - time: {:.3f} min"
                .format(iteration + 1, mean_elbo, mean_nll, mean_mse, mean_kld,
                        optimizer.learning_rate, scheduler.pixel_variance,
                        current_training_step, elapsed_time / 60))
            model.serialize(args.snapshot_directory)
Example #7
0
def main():
    """Train the DRAW model (GRU or LSTM variant) on multiple GPUs.

    The sorted file list of ``args.dataset_path`` is split into
    ``ceil(len(files) / comm.size)``-sized slices; each rank rotates the
    list by ``subset_size * rank`` and loads the leading slice.  Images are
    divided by 256 and the first 90% are used for training, with uniform
    noise in [0, 1/256) added to each batch.

    The optimized loss is
    ``args.loss_beta * nll + kld + args.loss_alpha * sse`` where ``sse`` is
    the summed squared error between every entry of ``r_t_array`` and the
    input (all terms divided by the batch size).  Rank 0 saves
    hyperparameters and model snapshots.  All settings are read from the
    module-level ``args`` object.
    """
    # Every rank executes this, so the directory may already exist.  Only
    # that case is tolerated; real failures (e.g. permissions) now surface
    # here instead of being silently swallowed.
    os.makedirs(args.snapshot_directory, exist_ok=True)

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    cuda.get_device(device).use()
    xp = cp

    images = []
    files = os.listdir(args.dataset_path)
    files.sort()
    # Rotate the file list so that this rank's slice comes first, then keep
    # only that slice.
    subset_size = int(math.ceil(len(files) / comm.size))
    files = deque(files)
    files.rotate(-subset_size * comm.rank)
    files = list(files)[:subset_size]
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 256
        images.append(image)

    print(comm.rank, files)

    images = np.vstack(images)
    # (N, height, width, channels) -> (N, channels, height, width)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    images_train = images[:num_train_images]

    hyperparams = HyperParameters()
    hyperparams.chz_channels = args.chz_channels
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.generator_downsampler_channels = args.generator_downsampler_channels
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.inference_downsampler_channels = args.inference_downsampler_channels
    hyperparams.batch_normalization_enabled = args.enable_batch_normalization
    hyperparams.use_gru = args.use_gru
    hyperparams.no_backprop_diff_xr = args.no_backprop_diff_xr

    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)
        hyperparams.print()

    if args.use_gru:
        model = GRUModel(hyperparams,
                         snapshot_directory=args.snapshot_directory)
    else:
        model = LSTMModel(hyperparams,
                          snapshot_directory=args.snapshot_directory)
    model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              lr_i=args.initial_lr,
                              lr_f=args.final_lr,
                              beta_1=args.adam_beta1,
                              communicator=comm)
    if comm.rank == 0:
        optimizer.print()

    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

    dataset = draw.data.Dataset(images_train)
    iterator = draw.data.Iterator(dataset, batch_size=args.batch_size)

    num_updates = 0

    for iteration in range(args.training_steps):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        start_time = time.time()

        for batch_index, data_indices in enumerate(iterator):
            x = dataset[data_indices]
            # Add uniform noise in [0, 1/256) to every pixel.
            x += np.random.uniform(0, 1 / 256, size=x.shape)
            x = to_gpu(x)

            z_t_param_array, x_param, r_t_array = model.sample_z_and_x_params_from_posterior(
                x)

            loss_kld = 0
            for params in z_t_param_array:
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                kld = draw.nn.functions.gaussian_kl_divergence(
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                loss_kld += cf.sum(kld)

            # Squared error of every entry of r_t_array against the input.
            loss_sse = 0
            for r_t in r_t_array:
                loss_sse += cf.sum(cf.squared_error(r_t, x))

            mu_x, ln_var_x = x_param

            loss_nll = cf.gaussian_nll(x, mu_x, ln_var_x)

            # Average all terms over the batch.
            loss_nll /= args.batch_size
            loss_kld /= args.batch_size
            loss_sse /= args.batch_size
            loss = args.loss_beta * loss_nll + loss_kld + args.loss_alpha * loss_sse

            model.cleargrads()
            loss.backward(loss_scale=optimizer.loss_scale())
            optimizer.update(num_updates, loss_value=float(loss.array))

            num_updates += 1
            mean_kld += float(loss_kld.data)
            mean_nll += float(loss_nll.data)
            mean_mse += float(loss_sse.data) / num_pixels / (
                hyperparams.generator_generation_steps - 1)

            printr(
                "Iteration {}: Batch {} / {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e}"
                .format(
                    iteration + 1, batch_index + 1, len(iterator),
                    float(loss_nll.data) / num_pixels + math.log(256.0),
                    float(loss_sse.data) / num_pixels /
                    (hyperparams.generator_generation_steps - 1),
                    float(loss_kld.data), optimizer.learning_rate))

            # Periodic snapshot every 100 batches (not at batch 0).
            if comm.rank == 0 and batch_index > 0 and batch_index % 100 == 0:
                model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            print(
                "\r\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - elapsed_time: {:.3f} min"
                .format(
                    iteration + 1,
                    mean_nll / len(iterator) / num_pixels + math.log(256.0),
                    mean_mse / len(iterator), mean_kld / len(iterator),
                    optimizer.learning_rate, elapsed_time / 60))
Example #8
0
def main():
    """Train the GQN model on a single device (CPU or one GPU).

    A GPU is used when ``args.gpu_device`` is non-negative.  For every batch
    a random number of leading views is encoded into a scene
    representation, a random view serves as the query, and the sum of the
    pixel negative log-likelihood and the KL divergence accumulated over
    all generation steps is minimized.  The pixel standard deviation
    ``sigma_t`` is annealed linearly from ``hyperparams.pixel_sigma_i`` down
    to ``hyperparams.pixel_sigma_f`` over ``hyperparams.pixel_n`` steps.

    With ``args.with_visualization`` a window shows the query image, its
    reconstruction and a generated image.  A model snapshot is written
    after every subset.  All settings are read from the module-level
    ``args`` object.
    """
    # Tolerate only an already-existing directory; other errors surface here
    # rather than being silently swallowed.
    os.makedirs(args.snapshot_path, exist_ok=True)

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    hyperparams.save(args.snapshot_path)
    hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    optimizer = Optimizer(model.parameters,
                          mu_i=args.initial_lr,
                          mu_f=args.final_lr)
    optimizer.print()

    if args.with_visualization:
        # Three side-by-side panels, each one third of the window width.
        figure = gqn.imgplot.figure()
        axis1 = gqn.imgplot.image()
        axis2 = gqn.imgplot.image()
        axis3 = gqn.imgplot.image()
        figure.add(axis1, 0, 0, 1 / 3, 1)
        figure.add(axis2, 1 / 3, 0, 1 / 3, 1)
        figure.add(axis3, 2 / 3, 0, 1 / 3, 1)
        plot = gqn.imgplot.window(
            figure, (500 * 3, 500),
            "Query image / Reconstructed image / Generated image")
        plot.show()

    # Constant per-pixel variance buffers; refilled in place whenever
    # sigma_t changes.
    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        sigma_t**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(sigma_t**2),
                           dtype="float32")

    dataset_mean, dataset_std = dataset.load_mean_and_std()

    # Store the unmodified statistics next to the snapshot.
    np.save(os.path.join(args.snapshot_path, "mean.npy"), dataset_mean)
    np.save(os.path.join(args.snapshot_path, "std.npy"), dataset_std)

    # avoid division by zero
    dataset_std += 1e-12

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        total_batch = 0

        for subset_index, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # preprocessing
                images = (images - dataset_mean) / dataset_std

                # (batch, views, height, width, channels) ->  (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3))

                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if num_views > 0:
                    # The first `num_views` views serve as observations.
                    r = model.compute_observation_representation(
                        images[:, :num_views], viewpoints[:, :num_views])
                else:
                    # No observations: use an all-zero representation.
                    r = xp.zeros((args.batch_size, hyperparams.channels_r) +
                                 hyperparams.chrz_size,
                                 dtype="float32")
                    r = chainer.Variable(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                h0_gen, c0_gen, u_0, h0_enc, c0_enc = model.generate_initial_state(
                    args.batch_size, xp)

                loss_kld = 0

                hl_enc = h0_enc
                cl_enc = c0_enc
                hl_gen = h0_gen
                cl_gen = c0_gen
                ul_enc = u_0

                xq = model.inference_downsampler.downsample(query_images)

                for l in range(model.generation_steps):
                    inference_core = model.get_inference_core(l)
                    inference_posterior = model.get_inference_posterior(l)
                    generation_core = model.get_generation_core(l)
                    generation_prior = model.get_generation_prior(l)

                    h_next_enc, c_next_enc = inference_core.forward_onestep(
                        hl_gen, hl_enc, cl_enc, xq, query_viewpoints, r)

                    # Posterior parameters and a latent sample drawn from it.
                    mean_z_q = inference_posterior.compute_mean_z(hl_enc)
                    ln_var_z_q = inference_posterior.compute_ln_var_z(hl_enc)
                    ze_l = cf.gaussian(mean_z_q, ln_var_z_q)

                    # Prior parameters from the generator state.
                    mean_z_p = generation_prior.compute_mean_z(hl_gen)
                    ln_var_z_p = generation_prior.compute_ln_var_z(hl_gen)

                    h_next_gen, c_next_gen, u_next_enc = generation_core.forward_onestep(
                        hl_gen, cl_gen, ul_enc, ze_l, query_viewpoints, r)

                    kld = gqn.nn.chainer.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)

                    loss_kld += cf.sum(kld)

                    hl_gen = h_next_gen
                    cl_gen = c_next_gen
                    ul_enc = u_next_enc
                    hl_enc = h_next_enc
                    cl_enc = c_next_enc

                mean_x = model.generation_observation.compute_mean_x(ul_enc)
                negative_log_likelihood = gqn.nn.chainer.functions.gaussian_negative_log_likelihood(
                    query_images, mean_x, pixel_var, pixel_ln_var)
                loss_nll = cf.sum(negative_log_likelihood)

                # Average both terms over the batch.
                loss_nll /= args.batch_size
                loss_kld /= args.batch_size
                loss = loss_nll + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                # Refresh the preview window unless the user closed it.
                if args.with_visualization and plot.closed() is False:
                    axis1.update(
                        make_uint8(query_images[0], dataset_mean, dataset_std))
                    axis2.update(
                        make_uint8(mean_x.data[0], dataset_mean, dataset_std))

                    with chainer.no_backprop_mode():
                        # Generate for the first sample of the batch only.
                        generated_x = model.generate_image(
                            query_viewpoints[None, 0], r[None, 0], xp)
                        axis3.update(
                            make_uint8(generated_x[0], dataset_mean,
                                       dataset_std))

                printr(
                    "Iteration {}: Subset {} / {}: Batch {} / {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f}"
                    .format(iteration + 1,
                            subset_index + 1, len(dataset), batch_index + 1,
                            len(iterator), float(loss_nll.data),
                            float(loss_kld.data), optimizer.learning_rate,
                            sigma_t))

                # Linear annealing of the pixel sigma, clamped at the final
                # value.
                sf = hyperparams.pixel_sigma_f
                si = hyperparams.pixel_sigma_i
                sigma_t = max(
                    sf + (si - sf) *
                    (1.0 - current_training_step / hyperparams.pixel_n), sf)

                pixel_var[...] = sigma_t**2
                pixel_ln_var[...] = math.log(sigma_t**2)

                total_batch += 1
                current_training_step += 1
                mean_kld += float(loss_kld.data)
                mean_nll += float(loss_nll.data)

            # Snapshot after every subset.
            model.serialize(args.snapshot_path)

        print(
            "\033[2KIteration {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f} - step: {}"
            .format(iteration + 1, mean_nll / total_batch,
                    mean_kld / total_batch, optimizer.learning_rate, sigma_t,
                    current_training_step))
Example #9
0
def main():
    """Train the model epoch by epoch, validating and snapshotting each epoch.

    All configuration is read from the module-level ``args``. The function
    builds the datasets, hyperparameters, model, pixel-variance scheduler,
    optimizer and matplotlib axes, then for every epoch:

    1. iterates over all subsets/batches of the training dataset, maximizing
       the empirical ELBO and annealing the pixel standard deviation,
    2. optionally evaluates the same metrics on the test dataset without
       back-propagation,
    3. appends the meters to the CSV log and saves model, scheduler, meter
       and log to ``args.snapshot_directory`` / ``args.log_directory``.
    """
    _mkdir(args.snapshot_directory)
    _mkdir(args.log_directory)

    meter_train = Meter()
    meter_train.load(args.snapshot_directory)

    #==============================================================================
    # Selecting the GPU
    #==============================================================================
    xp = np
    gpu_device = args.gpu_device
    using_gpu = gpu_device >= 0
    if using_gpu:
        cuda.get_device(gpu_device).use()
        xp = cp

    #==============================================================================
    # Dataset
    #==============================================================================
    dataset_train = Dataset(args.train_dataset_directory)
    dataset_test = None
    if args.test_dataset_directory is not None:
        dataset_test = Dataset(args.test_dataset_directory)

    #==============================================================================
    # Hyperparameters
    #==============================================================================
    hyperparams = HyperParameters()
    hyperparams.num_layers = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.r_channels = args.r_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_sigma_annealing_steps = args.pixel_sigma_annealing_steps
    hyperparams.initial_pixel_sigma = args.initial_pixel_sigma
    hyperparams.final_pixel_sigma = args.final_pixel_sigma

    hyperparams.save(args.snapshot_directory)
    print(hyperparams, "\n")

    #==============================================================================
    # Model
    #==============================================================================
    model = Model(hyperparams)
    model.load(args.snapshot_directory, meter_train.epoch)
    if using_gpu:
        model.to_gpu()

    #==============================================================================
    # Pixel-variance annealing
    #==============================================================================
    variance_scheduler = PixelVarianceScheduler(
        sigma_start=args.initial_pixel_sigma,
        sigma_end=args.final_pixel_sigma,
        final_num_updates=args.pixel_sigma_annealing_steps)
    variance_scheduler.load(args.snapshot_directory)
    print(variance_scheduler, "\n")

    # One log-sigma value per pixel of a full training batch; refreshed in
    # place after every update as the scheduler anneals sigma.
    pixel_log_sigma = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(variance_scheduler.standard_deviation),
        dtype="float32")

    #==============================================================================
    # Logging
    #==============================================================================
    csv = DataFrame()
    csv.load(args.log_directory)

    #==============================================================================
    # Optimizer
    #==============================================================================
    optimizer = AdamOptimizer(
        model.parameters,
        initial_lr=args.initial_lr,
        final_lr=args.final_lr,
        initial_training_step=variance_scheduler.training_step)
    print(optimizer, "\n")

    #==============================================================================
    # Visualization
    #==============================================================================
    fig = plt.figure(figsize=(9, 6))
    axes_train = [
        fig.add_subplot(2, 3, 1),
        fig.add_subplot(2, 3, 2),
        fig.add_subplot(2, 3, 3),
    ]
    axes_train[0].set_title("Training Data")
    axes_train[0].axis("off")
    axes_train[1].set_title("Reconstruction")
    axes_train[1].axis("off")
    axes_train[2].set_title("Generation")
    axes_train[2].axis("off")
    axes_test = [
        fig.add_subplot(2, 3, 4),
        fig.add_subplot(2, 3, 5),
        fig.add_subplot(2, 3, 6),
    ]
    axes_test[0].set_title("Validation Data")
    axes_test[0].axis("off")
    axes_test[1].set_title("Reconstruction")
    axes_test[1].axis("off")
    axes_test[2].set_title("Generation")
    axes_test[2].axis("off")

    #==============================================================================
    # Algorithms
    #==============================================================================
    def encode_scene(images, viewpoints):
        """Encode a random subset of views and pick one random query view.

        Args:
            images: array indexed as (batch, views, height, width, channels).
            viewpoints: array whose first two axes are (batch, views).

        Returns:
            Tuple ``(representation, query_images, query_viewpoints)`` where
            the query arrays have been passed through ``to_device``.
        """
        # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
        images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)

        # Sample number of views
        total_views = images.shape[1]
        num_views = random.choice(range(1, total_views + 1))

        # Sample views
        observation_view_indices = list(range(total_views))
        random.shuffle(observation_view_indices)
        observation_view_indices = observation_view_indices[:num_views]

        observation_images = preprocess_images(
            images[:, observation_view_indices])

        observation_query = viewpoints[:, observation_view_indices]
        representation = model.compute_observation_representation(
            observation_images, observation_query)

        # Sample query view (may coincide with one of the observed views)
        query_index = random.choice(range(total_views))
        query_images = preprocess_images(images[:, query_index])
        query_viewpoints = viewpoints[:, query_index]

        # Transfer to gpu if necessary
        query_images = to_device(query_images, gpu_device)
        query_viewpoints = to_device(query_viewpoints, gpu_device)

        return representation, query_images, query_viewpoints

    def estimate_ELBO(query_images, z_t_param_array, pixel_mean,
                      pixel_log_sigma):
        """Compute the per-sample empirical ELBO and related metrics.

        Args:
            query_images: target images; axis 0 is the batch axis.
            z_t_param_array: iterable of
                ``(mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)`` tuples, one
                per generation step.
            pixel_mean: predicted pixel means, same shape as ``query_images``.
            pixel_log_sigma: log standard deviation of the pixel likelihood;
                only the first ``len(query_images)`` rows are used.

        Returns:
            Tuple ``(ELBO, bits_per_pixel, negative_log_likelihood,
            kl_divergence)``, each averaged over the batch.
        """
        # Normalize every term by the size of the batch actually passed in;
        # validation batches are larger than ``args.batch_size``.
        batch_size = query_images.shape[0]

        # KL divergence between posterior and prior, summed over steps
        kl_divergence = 0
        for params_t in z_t_param_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params_t
            # NOTE(review): values named ``ln_var_*`` are passed as
            # ``log_scale`` (log standard deviation) -- confirm against the
            # model that this is the intended parameterization.
            normal_q = chainer.distributions.Normal(
                mean_z_q, log_scale=ln_var_z_q)
            normal_p = chainer.distributions.Normal(
                mean_z_p, log_scale=ln_var_z_p)
            kld_t = chainer.kl_divergence(normal_q, normal_p)
            kl_divergence += cf.sum(kld_t)
        kl_divergence = kl_divergence / batch_size

        # Negative log-likelihood of generated image
        num_dims_per_image = np.prod(query_images.shape[1:])
        # Slice so that a final, smaller batch still matches the buffer shape.
        normal = chainer.distributions.Normal(
            query_images, log_scale=pixel_log_sigma[:batch_size])

        log_px = cf.sum(normal.log_prob(pixel_mean)) / batch_size
        negative_log_likelihood = -log_px

        # Empirical ELBO
        ELBO = log_px - kl_divergence

        # https://arxiv.org/abs/1604.08772 Section.2
        # https://www.reddit.com/r/MachineLearning/comments/56m5o2/discussion_calculation_of_bitsdims/
        bits_per_pixel = -(ELBO / num_dims_per_image - np.log(256)) / np.log(
            2)

        return ELBO, bits_per_pixel, negative_log_likelihood, kl_divergence

    #==============================================================================
    # Training iterations
    #==============================================================================
    dataset_size = len(dataset_train)
    np.random.seed(0)
    # NOTE(review): seeds CuPy even when running on CPU -- confirm this is
    # acceptable on machines without a CUDA device.
    cp.random.seed(0)

    for epoch in range(meter_train.epoch, args.epochs):
        print("Epoch {}/{}:".format(
            epoch + 1,
            args.epochs,
        ))
        meter_train.next_epoch()

        for subset_index, subset in enumerate(dataset_train):
            iterator = Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                #------------------------------------------------------------------------------
                # Scene encoder
                #------------------------------------------------------------------------------
                # images.shape: (batch, views, height, width, channels)
                images, viewpoints = subset[data_indices]
                representation, query_images, query_viewpoints = encode_scene(
                    images, viewpoints)

                #------------------------------------------------------------------------------
                # Compute empirical ELBO
                #------------------------------------------------------------------------------
                # Compute distribution parameters (same unpacking as the
                # validation loop below)
                (z_t_param_array, pixel_mean
                 ) = model.sample_z_and_x_params_from_posterior(
                     query_images, query_viewpoints, representation)

                # Compute ELBO
                (ELBO, bits_per_pixel, negative_log_likelihood,
                 kl_divergence) = estimate_ELBO(query_images, z_t_param_array,
                                                pixel_mean, pixel_log_sigma)

                #------------------------------------------------------------------------------
                # Update parameters
                #------------------------------------------------------------------------------
                loss = -ELBO
                model.cleargrads()
                loss.backward()
                optimizer.update(meter_train.num_updates)

                #------------------------------------------------------------------------------
                # Logging
                #------------------------------------------------------------------------------
                with chainer.no_backprop_mode():
                    mean_squared_error = cf.mean_squared_error(
                        query_images, pixel_mean)
                meter_train.update(
                    ELBO=float(ELBO.data),
                    bits_per_pixel=float(bits_per_pixel.data),
                    negative_log_likelihood=float(
                        negative_log_likelihood.data),
                    kl_divergence=float(kl_divergence.data),
                    mean_squared_error=float(mean_squared_error.data))

                #------------------------------------------------------------------------------
                # Annealing
                #------------------------------------------------------------------------------
                variance_scheduler.update(meter_train.num_updates)
                pixel_log_sigma[...] = math.log(
                    variance_scheduler.standard_deviation)

            if subset_index % 100 == 0:
                print("    Subset {}/{}:".format(
                    subset_index + 1,
                    dataset_size,
                ))
                print("        {}".format(meter_train))
                print("        lr: {} - sigma: {}".format(
                    optimizer.learning_rate,
                    variance_scheduler.standard_deviation))

        #------------------------------------------------------------------------------
        # Visualization
        #------------------------------------------------------------------------------
        # Shows the first sample of the last training batch of this epoch.
        if args.visualize:
            axes_train[0].imshow(
                make_uint8(query_images[0]), interpolation="none")
            axes_train[1].imshow(
                make_uint8(pixel_mean.data[0]), interpolation="none")

            with chainer.no_backprop_mode():
                generated_x = model.generate_image(query_viewpoints[None, 0],
                                                   representation[None, 0])
                axes_train[2].imshow(
                    make_uint8(generated_x[0]), interpolation="none")

        #------------------------------------------------------------------------------
        # Validation
        #------------------------------------------------------------------------------
        meter_test = None
        if dataset_test is not None:
            meter_test = Meter()
            # No gradients are kept during validation, so use a larger batch.
            batch_size_test = args.batch_size * 6
            pixel_log_sigma_test = xp.full(
                (batch_size_test, 3) + hyperparams.image_size,
                math.log(variance_scheduler.standard_deviation),
                dtype="float32")

            with chainer.no_backprop_mode():
                for subset in dataset_test:
                    iterator = Iterator(subset, batch_size=batch_size_test)
                    for data_indices in iterator:
                        images, viewpoints = subset[data_indices]

                        # Scene encoder
                        representation, query_images, query_viewpoints = encode_scene(
                            images, viewpoints)

                        # Compute empirical ELBO
                        (z_t_param_array, pixel_mean
                         ) = model.sample_z_and_x_params_from_posterior(
                             query_images, query_viewpoints, representation)
                        (ELBO, bits_per_pixel, negative_log_likelihood,
                         kl_divergence) = estimate_ELBO(
                             query_images, z_t_param_array, pixel_mean,
                             pixel_log_sigma_test)
                        mean_squared_error = cf.mean_squared_error(
                            query_images, pixel_mean)

                        # Logging
                        meter_test.update(
                            ELBO=float(ELBO.data),
                            bits_per_pixel=float(bits_per_pixel.data),
                            negative_log_likelihood=float(
                                negative_log_likelihood.data),
                            kl_divergence=float(kl_divergence.data),
                            mean_squared_error=float(mean_squared_error.data))

            print("    Test:")
            print("        {} - done in {:.3f} min".format(
                meter_test,
                meter_test.elapsed_time,
            ))

            if args.visualize:
                axes_test[0].imshow(
                    make_uint8(query_images[0]), interpolation="none")
                axes_test[1].imshow(
                    make_uint8(pixel_mean.data[0]), interpolation="none")

                with chainer.no_backprop_mode():
                    generated_x = model.generate_image(
                        query_viewpoints[None, 0], representation[None, 0])
                    axes_test[2].imshow(
                        make_uint8(generated_x[0]), interpolation="none")

        if args.visualize:
            # A tiny pause lets matplotlib redraw the figure.
            plt.pause(1e-10)

        csv.append(epoch, meter_train, meter_test)

        #------------------------------------------------------------------------------
        # Snapshot
        #------------------------------------------------------------------------------
        model.save(args.snapshot_directory, epoch)
        variance_scheduler.save(args.snapshot_directory)
        meter_train.save(args.snapshot_directory)
        csv.save(args.log_directory)

        print("Epoch {} done in {:.3f} min".format(
            epoch + 1,
            meter_train.epoch_elapsed_time,
        ))
        print("    {}".format(meter_train))
        print("    lr: {} - sigma: {} - training_steps: {}".format(
            optimizer.learning_rate,
            variance_scheduler.standard_deviation,
            meter_train.num_updates,
        ))
        print("    Time elapsed: {:.3f} min".format(meter_train.elapsed_time))