Ejemplo n.º 1
0
def _trainable_vars_with_prefix(prefix):
    """Return the trainable variables whose names start with ``prefix``."""
    return [
        var for var in tf.trainable_variables()
        if var.name.startswith(prefix)
    ]


def _adam_minimize(loss, var_list, step_size, beta1, global_step=None):
    """Build an Adam update op for ``loss`` restricted to ``var_list``.

    ``beta2`` is fixed at 0.9, the value every optimiser in ``train`` uses.
    """
    optimizer = tf.train.AdamOptimizer(step_size, beta1=beta1, beta2=0.9)
    return optimizer.minimize(loss,
                              var_list=var_list,
                              global_step=global_step)


def train(n_epochs, learning_rate_G, learning_rate_D, batch_size, mid_flag,
          check_num, discriminative):
    """Build the depvox GAN graph and run its training loop.

    Every iteration runs the supervised updates (SSC encoder, completion,
    semantic and refinement losses).  When ``discriminative is True`` it also
    runs the adversarial generator/discriminator updates.  Every
    ``cfg.CHECK_FREQ`` iterations a checkpoint is written and, in
    discriminative mode, a batch of sampled voxel grids is saved under
    ``cfg.DIR.TRAIN_OBJ_PATH``.

    Args:
        n_epochs: Exclusive upper bound of the epoch loop.
        learning_rate_G: Adam learning rate for the encoder, generator and
            refinement optimisers.
        learning_rate_D: Adam learning rate for the discriminator optimisers.
        batch_size: Number of samples per minibatch.
        mid_flag: If truthy, restore weights from checkpoint ``check_num`` and
            reload the previously saved ``sample_z.npy``; otherwise draw a new
            sample latent batch and save it.
        check_num: Checkpoint index to resume from; the iteration counter
            starts at ``check_num * cfg.CHECK_FREQ + 1``.
        discriminative: Adversarial training is enabled only when this is
            exactly ``True``; checkpoints then go to
            ``cfg.DIR.CHECK_POINT_PATH + '-d'``.

    Raises:
        ValueError: If the dataset holds fewer samples than ``batch_size``.
    """
    beta_G = cfg.TRAIN.ADAM_BETA_G
    beta_D = cfg.TRAIN.ADAM_BETA_D
    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[-1]]
    com_shape = [n_vox[0], n_vox[1], n_vox[2], 2]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    freq = cfg.CHECK_FREQ
    record_vox_num = cfg.RECORD_VOX_NUM
    use_gan = discriminative is True
    # Shape of one batch of latent codes fed to Z_tf / Z_tf_sample.
    latent_shape = (batch_size, start_vox_size[0], start_vox_size[1],
                    start_vox_size[2], dim_z)

    depvox_gan_model = depvox_gan(batch_size=batch_size,
                                  vox_shape=vox_shape,
                                  com_shape=com_shape,
                                  dim_z=dim_z,
                                  dim=dim,
                                  start_vox_size=start_vox_size,
                                  kernel=cfg.NET.KERNEL,
                                  stride=cfg.NET.STRIDE,
                                  dilations=cfg.NET.DILATIONS,
                                  discriminative=discriminative,
                                  is_train=True)

    # The full tuple is unpacked to document build_model()'s return order,
    # although only some of the tensors are used below.
    Z_tf, z_part_enc_tf, surf_tf, full_tf, full_gen_tf, surf_dec_tf, full_dec_tf,\
    gen_loss_tf, discrim_loss_tf, recons_ssc_loss_tf, recons_com_loss_tf, recons_sem_loss_tf, encode_loss_tf, refine_loss_tf, summary_tf,\
    part_tf, part_dec_tf, complete_gt_tf, complete_gen_tf, complete_dec_tf, sscnet_tf, scores_tf = depvox_gan_model.build_model()
    global_step = tf.Variable(0, name='global_step', trainable=False)
    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.allow_growth = True
    sess = tf.Session(config=config_gpu)
    # NOTE(review): the Saver is constructed before the optimisers, so it
    # presumably does not track the Adam slot variables. Keep this ordering
    # unless existing checkpoints can be regenerated.
    saver = tf.train.Saver(max_to_keep=cfg.SAVER_MAX)

    data_paths = scene_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION)
    print('---amount of data:', str(len(data_paths)))
    data_process = DataProcess(data_paths, batch_size, repeat=True)

    # Variable groups are selected by name prefix.
    enc_sscnet_vars = _trainable_vars_with_prefix('enc_ssc')
    enc_sdf_vars = _trainable_vars_with_prefix('enc_x')
    dis_sdf_vars = _trainable_vars_with_prefix('dis_x')
    dis_com_vars = _trainable_vars_with_prefix('dis_g')
    dis_sem_vars = _trainable_vars_with_prefix('dis_y')
    gen_com_vars = _trainable_vars_with_prefix('gen_x')
    gen_sem_vars = _trainable_vars_with_prefix('gen_y')
    gen_sdf_vars = _trainable_vars_with_prefix('gen_z')
    # 'gen_y_ref' also matches the 'gen_y' prefix, so these variables are a
    # subset of gen_sem_vars.
    refine_vars = _trainable_vars_with_prefix('gen_y_ref')

    # NOTE(review): lr_VAE is fed during training, but no optimiser built in
    # this function consumes it; they all use the fixed learning_rate_G /
    # learning_rate_D. Confirm whether the schedule was meant to be wired in.
    lr_VAE = tf.placeholder(tf.float32, shape=[])

    # main optimiser
    train_op_pred_sscnet = _adam_minimize(recons_ssc_loss_tf, enc_sscnet_vars,
                                          learning_rate_G, beta_G)
    train_op_pred_com = _adam_minimize(
        recons_com_loss_tf, enc_sdf_vars + gen_com_vars + gen_sdf_vars,
        learning_rate_G, beta_G)
    train_op_pred_sem = _adam_minimize(
        recons_sem_loss_tf, enc_sdf_vars + gen_sem_vars + gen_sdf_vars,
        learning_rate_G, beta_G)

    # refine optimiser
    train_op_refine = _adam_minimize(refine_loss_tf, refine_vars,
                                     learning_rate_G, beta_G)

    gen_ops = ()
    dis_ops = ()
    Z_tf_sample = None
    full_tf_sample = None
    if use_gan:
        train_op_gen_sdf = _adam_minimize(gen_loss_tf, gen_sdf_vars,
                                          learning_rate_G, beta_G)
        train_op_gen_com = _adam_minimize(gen_loss_tf, gen_com_vars,
                                          learning_rate_G, beta_G)
        train_op_gen_sem = _adam_minimize(gen_loss_tf,
                                          gen_sem_vars + gen_com_vars,
                                          learning_rate_G, beta_G)
        train_op_dis_sdf = _adam_minimize(discrim_loss_tf, dis_sdf_vars,
                                          learning_rate_D, beta_D)
        train_op_dis_com = _adam_minimize(discrim_loss_tf, dis_com_vars,
                                          learning_rate_D, beta_D)
        # Only this op increments global_step.
        train_op_dis_sem = _adam_minimize(discrim_loss_tf,
                                          dis_sem_vars,
                                          learning_rate_D,
                                          beta_D,
                                          global_step=global_step)
        # gen_ops[k] is gated on scores_discrim[2k] - scores_discrim[2k + 1].
        gen_ops = (train_op_gen_sdf, train_op_gen_com, train_op_gen_sem)
        dis_ops = (train_op_dis_sdf, train_op_dis_com, train_op_dis_sem)

        Z_tf_sample, comp_tf_sample, full_tf_sample, full_ref_tf_sample, part_tf_sample, scores_tf_sample = depvox_gan_model.samples_generator(
            visual_size=batch_size)

        model_path = cfg.DIR.CHECK_POINT_PATH + '-d'
    else:
        model_path = cfg.DIR.CHECK_POINT_PATH

    # Pass the Graph itself; handing a GraphDef to FileWriter is deprecated.
    writer = tf.summary.FileWriter(cfg.DIR.LOG_PATH, sess.graph)
    try:
        sess.run(tf.global_variables_initializer())

        if mid_flag:
            chckpt_path = model_path + '/checkpoint' + str(check_num)
            saver.restore(sess, chckpt_path)
            Z_var_np_sample = np.load(cfg.DIR.TRAIN_OBJ_PATH +
                                      '/sample_z.npy').astype(np.float32)
            Z_var_np_sample = Z_var_np_sample[:batch_size]
            print('---weights restored')
        else:
            Z_var_np_sample = np.random.normal(size=latent_shape).astype(
                np.float32)
            np.save(cfg.DIR.TRAIN_OBJ_PATH + '/sample_z.npy', Z_var_np_sample)

        iters_per_epoch = int(len(data_paths) / batch_size)
        if iters_per_epoch == 0:
            raise ValueError(
                'batch_size (%d) exceeds the number of training samples (%d)'
                % (batch_size, len(data_paths)))

        ite = check_num * freq + 1
        cur_epochs = int(ite / iters_per_epoch)

        # training
        for epoch in np.arange(cur_epochs, n_epochs):
            epoch_flag = True
            while epoch_flag:
                print(
                    colored('---Iteration:%d, epoch:%d', 'blue') %
                    (ite, epoch))
                db_inds, epoch_flag = data_process.get_next_minibatch()
                batch_tsdf = data_process.get_tsdf(db_inds)
                batch_surf = data_process.get_surf(db_inds)
                batch_voxel = data_process.get_voxel(db_inds)

                # Evaluation masks
                # NOTICE that the target should never have negative values,
                # otherwise the one-hot coding never works for that region
                if cfg.TYPE_TASK == 'scene':
                    batch_tsdf[batch_tsdf < -1] = 0
                    batch_surf[batch_surf < 0] = 0
                    batch_voxel[batch_voxel < 0] = 0

                lr = learning_rate(cfg.LEARNING_RATE_V, ite)

                batch_z_var = np.random.normal(size=latent_shape).astype(
                    np.float32)

                # Discriminator runs use the plain feed; the supervised and
                # generator runs additionally feed lr_VAE.
                feed = {
                    Z_tf: batch_z_var,
                    part_tf: batch_tsdf,
                    surf_tf: batch_surf,
                    full_tf: batch_voxel,
                }
                feed_with_lr = dict(feed)
                feed_with_lr[lr_VAE] = lr

                # updating for the main network
                sess.run(train_op_pred_sscnet, feed_dict=feed_with_lr)
                sess.run(
                    [train_op_pred_com, train_op_pred_sem, train_op_refine],
                    feed_dict=feed_with_lr)
                gen_com_loss_val, gen_sem_loss_val, z_part_enc_val = sess.run(
                    [recons_com_loss_tf, recons_sem_loss_tf, z_part_enc_tf],
                    feed_dict=feed_with_lr)

                if use_gan:
                    discrim_loss_val, gen_loss_val, scores_discrim = sess.run(
                        [discrim_loss_tf, gen_loss_tf, scores_tf],
                        feed_dict=feed)
                    # A generator is updated only when its score gap exceeds
                    # the margin (presumably real vs. generated scores --
                    # confirm against build_model).
                    score_margin = 0.3
                    for pair, gen_op in enumerate(gen_ops):
                        gap = (scores_discrim[2 * pair] -
                               scores_discrim[2 * pair + 1])
                        if gap > score_margin:
                            sess.run(gen_op, feed_dict=feed_with_lr)
                    # Discriminators are updated one after another, each in
                    # its own session call.
                    for dis_op in dis_ops:
                        sess.run(dis_op, feed_dict=feed)

                print('GAN')
                np.set_printoptions(precision=2)
                print('reconstruct-com loss:', gen_com_loss_val)

                print('reconstruct-sem loss:', gen_sem_loss_val)

                if use_gan:
                    print('            gen loss:', "%.2f" % gen_loss_val)

                    print('      output discrim:', "%.2f" % discrim_loss_val)

                    print('      scores discrim:',
                          colored("%.2f" % scores_discrim[0], 'green'),
                          colored("%.2f" % scores_discrim[1], 'magenta'),
                          colored("%.2f" % scores_discrim[2], 'green'),
                          colored("%.2f" % scores_discrim[3], 'magenta'),
                          colored("%.2f" % scores_discrim[4], 'green'),
                          colored("%.2f" % scores_discrim[5], 'magenta'))

                print('     avarage of code:',
                      np.mean(np.mean(z_part_enc_val, 4)))

                print('         std of code:',
                      np.mean(np.std(z_part_enc_val, 4)))

                if np.mod(ite, freq) == 0:
                    if use_gan:
                        # Decode the fixed sample latents to record progress.
                        full_models = sess.run(
                            full_tf_sample,
                            feed_dict={Z_tf_sample: Z_var_np_sample},
                        )
                        full_models_cat = np.argmax(full_models, axis=4)
                        record_vox = full_models_cat[:record_vox_num]
                        np.save(
                            cfg.DIR.TRAIN_OBJ_PATH + '/' + str(ite // freq) +
                            '.npy', record_vox)
                    saver.save(sess,
                               model_path + '/checkpoint' + str(ite // freq),
                               global_step=None)

                ite += 1
    finally:
        # Release the event file and the session's resources even when
        # training stops on an exception.
        writer.close()
        sess.close()
Ejemplo n.º 2
0
def evaluate(batch_size, checknum, mode, discriminative):

    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[4]]
    complete_shape = [n_vox[0], n_vox[1], n_vox[2], 2]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    kernel = cfg.NET.KERNEL
    stride = cfg.NET.STRIDE
    dilations = cfg.NET.DILATIONS
    freq = cfg.CHECK_FREQ

    save_path = cfg.DIR.EVAL_PATH
    if discriminative is True:
        model_path = cfg.DIR.CHECK_POINT_PATH + '-d'
    else:
        model_path = cfg.DIR.CHECK_POINT_PATH
    chckpt_path = model_path + '/checkpoint' + str(checknum)

    depvox_gan_model = depvox_gan(batch_size=batch_size,
                                  vox_shape=vox_shape,
                                  complete_shape=complete_shape,
                                  dim_z=dim_z,
                                  dim=dim,
                                  start_vox_size=start_vox_size,
                                  kernel=kernel,
                                  stride=stride,
                                  dilations=dilations,
                                  discriminative=discriminative,
                                  is_train=False)


    Z_tf, z_enc_tf, surf_tf, full_tf, full_gen_tf, surf_dec_tf, full_dec_tf,\
    gen_loss_tf, discrim_loss_tf, recons_ssc_loss_tf, recons_com_loss_tf, recons_sem_loss_tf, encode_loss_tf, refine_loss_tf, summary_tf,\
    part_tf, part_dec_tf, complete_gt_tf, complete_gen_tf, complete_dec_tf, ssc_tf, scores_tf = depvox_gan_model.build_model()
    if discriminative is True:
        Z_tf_sample, comp_tf_sample, surf_tf_sample, full_tf_sample, part_tf_sample, scores_tf_sample = depvox_gan_model.samples_generator(
            visual_size=batch_size)
    sess = tf.InteractiveSession()
    saver = tf.train.Saver()

    # Restore variables from disk.
    saver.restore(sess, chckpt_path)

    print("...Weights restored.")

    if mode == 'recons':
        # evaluation for reconstruction
        voxel_test, surf_test, part_test, num, data_paths = scene_model_id_pair_test(
            dataset_portion=cfg.TRAIN.DATASET_PORTION)

        # Evaluation masks
        if cfg.TYPE_TASK == 'scene':
            # occluded region
            """
            space_effective = np.where(voxel_test > -1, 1, 0) * np.where(
                part_test > -1, 1, 0)
            voxel_test *= space_effective
            part_test *= space_effective
            """
            # occluded region
            part_test[part_test < -1] = 0
            surf_test[surf_test < 0] = 0
            voxel_test[voxel_test < 0] = 0

        num = voxel_test.shape[0]
        print("test voxels loaded")
        for i in np.arange(int(num / batch_size)):
            batch_tsdf = part_test[i * batch_size:i * batch_size + batch_size]
            batch_surf = surf_test[i * batch_size:i * batch_size + batch_size]
            batch_voxel = voxel_test[i * batch_size:i * batch_size +
                                     batch_size]

            batch_pred_surf, batch_pred_full, batch_pred_part, batch_part_enc_Z, batch_complete_gt, batch_pred_complete, batch_ssc = sess.run(
                [
                    surf_dec_tf, full_dec_tf, part_dec_tf, z_enc_tf,
                    complete_gt_tf, complete_dec_tf, ssc_tf
                ],
                feed_dict={
                    part_tf: batch_tsdf,
                    surf_tf: batch_surf,
                    full_tf: batch_voxel
                })

            if i == 0:
                pred_part = batch_pred_part
                pred_surf = batch_pred_surf
                pred_full = batch_pred_full
                pred_ssc = batch_ssc
                part_enc_Z = batch_part_enc_Z
                complete_gt = batch_complete_gt
                pred_complete = batch_pred_complete
            else:
                pred_part = np.concatenate((pred_part, batch_pred_part),
                                           axis=0)
                pred_surf = np.concatenate((pred_surf, batch_pred_surf),
                                           axis=0)
                pred_full = np.concatenate((pred_full, batch_pred_full),
                                           axis=0)
                pred_ssc = np.concatenate((pred_ssc, batch_ssc), axis=0)
                part_enc_Z = np.concatenate((part_enc_Z, batch_part_enc_Z),
                                            axis=0)
                complete_gt = np.concatenate((complete_gt, batch_complete_gt),
                                             axis=0)
                pred_complete = np.concatenate(
                    (pred_complete, batch_pred_complete), axis=0)

        print("forwarded")

        # For visualization
        bin_file = np.uint8(voxel_test)
        bin_file.tofile(save_path + '/scene.bin')

        surface = np.array(part_test)
        if cfg.TYPE_TASK == 'scene':
            surface = np.abs(surface)
            surface *= 10
            pred_part = np.abs(pred_part)
            pred_part *= 10
        elif cfg.TYPE_TASK == 'object':
            surface = np.clip(surface, 0, 1)
            pred_part = np.clip(pred_part, 0, 1)
        surface.astype('uint8').tofile(save_path + '/surface.bin')
        pred_part.astype('uint8').tofile(save_path + '/dec_part.bin')

        depsem_gt = np.multiply(voxel_test, np.clip(surface, 0, 1))
        if cfg.TYPE_TASK == 'scene':
            depsem_gt[depsem_gt < 0] = 0
        depsem_gt.astype('uint8').tofile(save_path + '/depth_seg_scene.bin')

        # decoded
        np.argmax(pred_ssc,
                  axis=4).astype('uint8').tofile(save_path + '/dec_ssc.bin')
        error = np.array(
            np.clip(np.argmax(pred_ssc, axis=4), 0, 1) +
            np.argmax(complete_gt, axis=4) * 2)
        error.astype('uint8').tofile(save_path + '/dec_ssc_error.bin')
        np.argmax(pred_surf,
                  axis=4).astype('uint8').tofile(save_path + '/dec_surf.bin')
        error = np.array(
            np.clip(np.argmax(pred_surf, axis=4), 0, 1) +
            np.argmax(complete_gt, axis=4) * 2)
        error.astype('uint8').tofile(save_path + '/dec_surf_error.bin')
        np.argmax(pred_full,
                  axis=4).astype('uint8').tofile(save_path + '/dec_full.bin')
        error = np.array(
            np.clip(np.argmax(pred_full, axis=4), 0, 1) +
            np.argmax(complete_gt, axis=4) * 2)
        error.astype('uint8').tofile(save_path + '/dec_full_error.bin')
        np.argmax(pred_complete,
                  axis=4).astype('uint8').tofile(save_path +
                                                 '/dec_complete.bin')
        np.argmax(complete_gt,
                  axis=4).astype('uint8').tofile(save_path +
                                                 '/complete_gt.bin')

        # reconstruction and generation from normal distribution evaluation
        # generator from random distribution
        if discriminative is True:
            np.save(save_path + '/decode_z.npy', part_enc_Z)
            sample_times = 10
            for j in np.arange(sample_times):
                Z_var_np_sample = np.random.normal(
                    size=(batch_size, start_vox_size[0], start_vox_size[1],
                          start_vox_size[2], dim_z)).astype(np.float32)

                z_comp_rand, z_surf_rand, z_full_rand, z_part_rand, scores_sample = sess.run(
                    [
                        comp_tf_sample, surf_tf_sample, full_tf_sample,
                        part_tf_sample, scores_tf_sample
                    ],
                    feed_dict={Z_tf_sample: Z_var_np_sample})
                if j == 0:
                    z_comp_rand_all = z_comp_rand
                    z_part_rand_all = z_part_rand
                    z_surf_rand_all = z_surf_rand
                    z_full_rand_all = z_full_rand
                else:
                    z_comp_rand_all = np.concatenate(
                        [z_comp_rand_all, z_comp_rand], axis=0)
                    z_part_rand_all = np.concatenate(
                        [z_part_rand_all, z_part_rand], axis=0)
                    z_surf_rand_all = np.concatenate(
                        [z_surf_rand_all, z_surf_rand], axis=0)
                    z_full_rand_all = np.concatenate(
                        [z_full_rand_all, z_full_rand], axis=0)
                    print(scores_sample)
            Z_var_np_sample.astype('float32').tofile(save_path +
                                                     '/sample_z.bin')
            np.argmax(z_comp_rand_all,
                      axis=4).astype('uint8').tofile(save_path +
                                                     '/gen_comp.bin')
            np.argmax(z_surf_rand_all,
                      axis=4).astype('uint8').tofile(save_path +
                                                     '/gen_surf.bin')
            np.argmax(z_full_rand_all,
                      axis=4).astype('uint8').tofile(save_path +
                                                     '/gen_full.bin')
            if cfg.TYPE_TASK == 'scene':
                z_part_rand_all = np.abs(z_part_rand_all)
                z_part_rand_all *= 10
            elif cfg.TYPE_TASK == 'object':
                z_part_rand_all[z_part_rand_all <= 0.4] = 0
                z_part_rand_all[z_part_rand_all > 0.4] = 1
                z_part_rand = np.squeeze(z_part_rand)
            z_part_rand_all.astype('uint8').tofile(save_path + '/gen_part.bin')

            eigen_shape = False
            if eigen_shape:
                z_U, z_V = pca(np.reshape(part_enc_Z, [
                    200, start_vox_size[0] * start_vox_size[1] *
                    start_vox_size[2] * dim_z
                ]),
                               dim_remain=200)
                z_V = np.reshape(np.transpose(z_V[:, 0:8]), [
                    8, start_vox_size[0], start_vox_size[1], start_vox_size[2],
                    dim_z
                ])
                z_surf_rand, z_full_rand, z_part_rand = sess.run(
                    [surf_tf_sample, full_tf_sample, part_tf_sample],
                    feed_dict={Z_tf_sample: z_V})
                np.argmax(z_surf_rand,
                          axis=4).astype('uint8').tofile(save_path +
                                                         '/gen_surf.bin')
                if cfg.TYPE_TASK == 'scene':
                    z_part_rand = np.abs(z_part_rand)
                    z_part_rand *= 10
                elif cfg.TYPE_TASK == 'object':
                    z_part_rand[z_part_rand <= 0.4] = 0
                    z_part_rand[z_part_rand > 0.4] = 1
                    z_part_rand = np.squeeze(z_part_rand)
                z_part_rand.astype('uint8').tofile(save_path + '/gen_sdf.bin')

        print("voxels saved")

        # numerical evalutation

        # calc_IoU
        # completion
        on_complete_gt = complete_gt
        complete_gen = np.argmax(pred_complete, axis=4)
        on_complete_gen = onehot(complete_gen, 2)
        IoU_comp = np.zeros([2 + 1])
        AP_comp = np.zeros([2 + 1])
        print(colored("Completion", 'cyan'))
        IoU_comp = IoU(on_complete_gt, on_complete_gen, IoU_comp,
                       [vox_shape[0], vox_shape[1], vox_shape[2], 2])

        # depth segmentation
        on_depsem_gt = onehot(depsem_gt, vox_shape[3])
        on_depsem_ssc = np.multiply(
            onehot(np.argmax(pred_ssc, axis=4), vox_shape[3]),
            np.expand_dims(np.clip(surface, 0, 1), -1))
        on_depsem_dec = np.multiply(
            onehot(np.argmax(pred_full, axis=4), vox_shape[3]),
            np.expand_dims(np.clip(surface, 0, 1), -1))
        print(colored("Geometric segmentation", 'cyan'))
        IoU_class = np.zeros([vox_shape[3] + 1])
        IoU_class = IoU(on_depsem_gt, on_depsem_ssc, IoU_class, vox_shape)
        IoU_all = np.expand_dims(IoU_class, axis=1)

        print(colored("Generative segmentation", 'cyan'))
        IoU_class = np.zeros([vox_shape[3] + 1])
        IoU_class = IoU(on_depsem_gt, on_depsem_dec, IoU_class, vox_shape)
        IoU_all = np.expand_dims(IoU_class, axis=1)

        # volume segmentation
        on_surf_gt = onehot(surf_test, vox_shape[3])
        on_full_gt = onehot(voxel_test, vox_shape[3])
        print(colored("Geometric semantic completion", 'cyan'))
        on_pred = onehot(np.argmax(pred_ssc, axis=4), vox_shape[3])
        IoU_class = IoU(on_full_gt, on_pred, IoU_class, vox_shape)
        IoU_all = np.concatenate((IoU_all, np.expand_dims(IoU_class, axis=1)),
                                 axis=1)
        print(colored("Generative semantic completion", 'cyan'))
        on_pred = onehot(np.argmax(pred_surf, axis=4), vox_shape[3])
        IoU_class = IoU(on_full_gt, on_pred, IoU_class, vox_shape)
        IoU_all = np.concatenate((IoU_all, np.expand_dims(IoU_class, axis=1)),
                                 axis=1)
        print(colored("Solid generative semantic completion", 'cyan'))
        on_pred = onehot(np.argmax(pred_full, axis=4), vox_shape[3])
        IoU_class = IoU(on_full_gt, on_pred, IoU_class, vox_shape)
        IoU_all = np.concatenate((IoU_all, np.expand_dims(IoU_class, axis=1)),
                                 axis=1)

        np.savetxt(save_path + '/IoU.csv',
                   np.transpose(IoU_all[1:] * 100),
                   delimiter=" & ",
                   fmt='%2.1f')

    # interpolation evaluation
    if mode == 'interpolate':
        interpolate_num = 8
        #interpolatioin latent vectores
        decode_z = np.load(save_path + '/decode_z.npy')
        print(save_path)
        decode_z = decode_z[20:20 + batch_size]
        for l in np.arange(batch_size):
            for r in np.arange(batch_size):
                if l != r:
                    print l, r
                    base_num_left = l
                    base_num_right = r
                    left = np.reshape(decode_z[base_num_left], [
                        1, start_vox_size[0], start_vox_size[1],
                        start_vox_size[2], dim_z
                    ])
                    right = np.reshape(decode_z[base_num_right], [
                        1, start_vox_size[0], start_vox_size[1],
                        start_vox_size[2], dim_z
                    ])

                    duration = (right - left) / (interpolate_num - 1)
                    # left is the reference sample and Z_np_sample is the remaining samples
                    if base_num_left == 0:
                        Z_np_sample = decode_z[1:]
                    elif base_num_left == batch_size - 1:
                        Z_np_sample = decode_z[:batch_size - 1]
                    else:
                        Z_np_sample_before = np.reshape(
                            decode_z[:base_num_left], [
                                base_num_left, start_vox_size[0],
                                start_vox_size[1], start_vox_size[2], dim_z
                            ])
                        Z_np_sample_after = np.reshape(
                            decode_z[base_num_left + 1:], [
                                batch_size - base_num_left - 1,
                                start_vox_size[0], start_vox_size[1],
                                start_vox_size[2], dim_z
                            ])
                        Z_np_sample = np.concatenate(
                            [Z_np_sample_before, Z_np_sample_after], axis=0)
                    for i in np.arange(interpolate_num):
                        if i == 0:
                            Z = copy.copy(left)
                            interpolate_z = copy.copy(Z)
                        else:
                            Z = Z + duration
                            interpolate_z = np.concatenate([interpolate_z, Z],
                                                           axis=0)

                        # Z_np_sample is used to fill up the batch
                        Z_var_np_sample = np.concatenate([Z, Z_np_sample],
                                                         axis=0)
                        pred_full_rand, pred_part_rand = sess.run(
                            [full_tf_sample, part_tf_sample],
                            feed_dict={Z_tf_sample: Z_var_np_sample})
                        interpolate_vox = np.reshape(pred_full_rand[0], [
                            1, vox_shape[0], vox_shape[1], vox_shape[2],
                            vox_shape[3]
                        ])
                        interpolate_part = np.reshape(pred_part_rand[0], [
                            1, vox_shape[0], vox_shape[1], vox_shape[2],
                            complete_shape[3]
                        ])

                        if i == 0:
                            pred_full = interpolate_vox
                            pred_part = interpolate_part
                        else:
                            pred_full = np.concatenate(
                                [pred_full, interpolate_vox], axis=0)
                            pred_part = np.concatenate(
                                [pred_part, interpolate_part], axis=0)
                    interpolate_z.astype('uint8').tofile(
                        save_path + '/interpolate/interpolation_z' + str(l) +
                        '-' + str(r) + '.bin')

                    full_models_cat = np.argmax(pred_full, axis=4)
                    full_models_cat.astype('uint8').tofile(
                        save_path + '/interpolate/interpolation_f' + str(l) +
                        '-' + str(r) + '.bin')
                    if cfg.TYPE_TASK == 'scene':
                        pred_part = np.abs(pred_part)
                        pred_part[pred_part < 0.2] = 0
                        pred_part[pred_part >= 0.2] = 1
                    elif cfg.TYPE_TASK == 'object':
                        pred_part = np.argmax(pred_part, axis=4)
                    pred_part.astype('uint8').tofile(
                        save_path + '/interpolate/interpolation_p' + str(l) +
                        '-' + str(r) + '.bin')
        print("voxels saved")

    # add noise evaluation
    if mode == 'noise':
        decode_z = np.load(save_path + '/decode_z.npy')
        decode_z = decode_z[:batch_size]
        noise_num = 10
        for base_num in np.arange(batch_size):
            print base_num
            base = np.reshape(decode_z[base_num], [
                1, start_vox_size[0], start_vox_size[1], start_vox_size[2],
                dim_z
            ])
            eps = np.random.normal(size=(noise_num - 1,
                                         dim_z)).astype(np.float32)

            if base_num == 0:
                Z_np_sample = decode_z[1:]
            elif base_num == batch_size - 1:
                Z_np_sample = decode_z[:batch_size - 1]
            else:
                Z_np_sample_before = np.reshape(decode_z[:base_num], [
                    base_num, start_vox_size[0], start_vox_size[1],
                    start_vox_size[2], dim_z
                ])
                Z_np_sample_after = np.reshape(decode_z[base_num + 1:], [
                    batch_size - base_num - 1, start_vox_size[0],
                    start_vox_size[1], start_vox_size[2], dim_z
                ])
                Z_np_sample = np.concatenate(
                    [Z_np_sample_before, Z_np_sample_after], axis=0)

            for c in np.arange(start_vox_size[0]):
                for l in np.arange(start_vox_size[1]):
                    for d in np.arange(start_vox_size[2]):

                        for i in np.arange(noise_num):
                            if i == 0:
                                Z = copy.copy(base)
                                noise_z = copy.copy(Z)
                            else:
                                Z = copy.copy(base)
                                Z[0, c, l, d, :] += eps[i - 1]
                                noise_z = np.concatenate([noise_z, Z], axis=0)
                            Z_var_np_sample = np.concatenate([Z, Z_np_sample],
                                                             axis=0)
                            pred_full_rand = sess.run(
                                full_tf_sample,
                                feed_dict={Z_tf_sample: Z_var_np_sample})
                            """
                            refined_voxs_rand = sess.run(
                                sample_refine_full_tf,
                                feed_dict={
                                    sample_full_tf: pred_full_rand
                                })
                            """
                            noise_vox = np.reshape(pred_full_rand[0], [
                                1, vox_shape[0], vox_shape[1], vox_shape[2],
                                vox_shape[3]
                            ])
                            if i == 0:
                                pred_full = noise_vox
                            else:
                                pred_full = np.concatenate(
                                    [pred_full, noise_vox], axis=0)

                        np.save(
                            save_path + '/noise_z' + str(base_num) + '_' +
                            str(c) + str(l) + str(d) + '.npy', noise_z)

                        full_models_cat = np.argmax(pred_full, axis=4)
                        np.save(
                            save_path + '/noise' + str(base_num) + '_' +
                            str(c) + str(l) + str(d) + '.npy', full_models_cat)

        print("voxels saved")
# Example no. 3
# 0
def byproduct(batch_size, checknum, output_dir=None):
    """Decode the TSDF prediction of every model on disk and save it as .npy.

    The checkpoint ``cfg.DIR.CHECK_PT_PATH + str(checknum)`` is restored into
    a freshly built ``depvox_gan`` graph.  Every entry listed under
    ``cfg.DIR.ROOT_PATH`` is loaded through the ``cfg.DIR.VOXEL_PATH`` and
    ``cfg.DIR.TSDF_PATH`` templates, placed in slot 0 of an otherwise
    all-zero batch, forwarded through the network, and the argmax over the
    last axis of ``tsdf_gen_decode_tf`` is written to ``output_dir``.

    Args:
        batch_size: batch size the graph is built with; only the first slot
            of each batch carries data.
        checknum: checkpoint identifier appended to ``cfg.DIR.CHECK_PT_PATH``.
        output_dir: directory that receives one array per model, named after
            the model id.  ``None`` selects the location this function has
            historically written to.
    """
    if output_dir is None:
        output_dir = ('/media/wangyida/SSD2T/database/SUNCG_Yida/test/'
                      'depth_tsdf_vae_npy/')

    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[4]]
    tsdf_shape = [n_vox[0], n_vox[1], n_vox[2], 3]

    chckpt_path = cfg.DIR.CHECK_PT_PATH + str(checknum)

    depvox_gan_model = depvox_gan(
        batch_size=batch_size,
        vox_shape=vox_shape,
        tsdf_shape=tsdf_shape,
        dim_z=cfg.NET.DIM_Z,
        dim=dim,
        start_vox_size=cfg.NET.START_VOX,
        kernel=cfg.NET.KERNEL,
        stride=cfg.NET.STRIDE,
        dilations=cfg.NET.DILATIONS,
        generative=cfg.NET.GENERATIVE)

    # The full tuple is unpacked so a change in build_model()'s arity is
    # noticed here; only vox_tf, tsdf_tf and tsdf_gen_decode_tf are used.
    Z_tf, z_tsdf_enc_tf, z_vox_enc_tf, vox_tf, vox_gen_tf, vox_gen_decode_tf, vox_vae_decode_tf, vox_cc_decode_tf,\
    recon_vae_loss_tf, recon_cc_loss_tf, recon_gen_loss_tf, code_encode_loss_tf, gen_loss_tf, discrim_loss_tf,\
    cost_enc_tf, cost_code_tf, cost_gen_tf, cost_discrim_tf, summary_tf,\
    tsdf_tf, tsdf_gen_tf, tsdf_gen_decode_tf, tsdf_vae_decode_tf, tsdf_cc_decode_tf = depvox_gan_model.build_model()
    sess = tf.InteractiveSession()
    saver = tf.train.Saver()

    # Restore variables from disk.
    saver.restore(sess, chckpt_path)

    print("...Weights restored.")

    models = os.listdir(cfg.DIR.ROOT_PATH)

    # Buffers are allocated once; each model overwrites slot 0 and the
    # remaining slots stay zero.
    batch_tsdf = np.zeros((batch_size, n_vox[0], n_vox[1], n_vox[2]),
                          dtype=np.float32)
    batch_voxel = np.zeros((batch_size, n_vox[0], n_vox[1], n_vox[2]),
                           dtype=np.float32)

    for model_id in models:
        batch_voxel[0, :, :, :] = np.load(cfg.DIR.VOXEL_PATH % model_id)
        batch_tsdf[0, :, :, :] = np.load(cfg.DIR.TSDF_PATH % model_id)

        # Evaluation masks: keep only voxels that are positive in either the
        # voxel grid or the TSDF grid.  Strings are compared by value; an
        # identity test against a literal is implementation-dependent.
        if cfg.TYPE_TASK == 'scene':
            volume_effective = np.clip(
                np.where(batch_voxel > 0, 1, 0) + np.where(
                    batch_tsdf > 0, 1, 0), 0, 1)
            batch_voxel *= volume_effective
            batch_tsdf *= volume_effective

        # Only the generated TSDF decoding is saved, so only it is fetched.
        batch_pred_tsdf = sess.run(
            tsdf_gen_decode_tf,
            feed_dict={
                tsdf_tf: batch_tsdf,
                vox_tf: batch_voxel
            })

        batch_pred_tsdf = np.argmax(batch_pred_tsdf, axis=4).astype('float32')
        np.save(os.path.join(output_dir, model_id), batch_pred_tsdf[0])
# Example no. 4
# 0
def evaluate(batch_size, checknum, mode, discriminative):
    """Run one evaluation mode against a restored ``depvox_gan`` checkpoint.

    The graph is built with ``is_train=False`` and the weights are restored
    from ``<checkpoint dir>/checkpoint<checknum>``, where the checkpoint
    directory is ``cfg.DIR.CHECK_POINT_PATH`` (with a ``'-d'`` suffix when
    ``discriminative is True``).  All outputs go below ``cfg.DIR.EVAL_PATH``.

    Modes:
        'recons': forward the test set in full batches, dump ground truth
            and decoded volumes as ``.bin`` files, write point clouds via
            ``save_pcd``, optionally sample from random latent codes (only
            when ``discriminative is True``; this also writes
            ``decode_z.npy``) and write ``IoU.csv``.
        'interpolate': linearly interpolate between pairs of latent codes
            read from ``decode_z.npy`` and dump the decoded volumes.
        'noise': perturb one latent cell at a time with Gaussian noise and
            save the decoded label volumes.
        Any other value only builds the graph and restores the weights.

    Args:
        batch_size: batch size the graph is built with and fed with.
        checknum: checkpoint index appended to ``'/checkpoint'``.
        mode: one of the mode strings described above.
        discriminative: forwarded to ``depvox_gan``; the latent sampler graph
            is only built when this is exactly ``True``.

    Raises:
        ValueError: if 'interpolate' or 'noise' is requested without
            ``discriminative is True`` (those modes need the sampler graph),
            if the test set holds fewer samples than one batch, or if
            ``decode_z.npy`` does not provide ``batch_size`` codes.
    """
    if mode in ('interpolate', 'noise') and discriminative is not True:
        # Both modes feed Z_tf_samp, which only exists for the
        # discriminative model; fail before building the graph.
        raise ValueError(
            "mode '%s' requires discriminative=True" % mode)

    n_vox = cfg.CONST.N_VOX
    dim = cfg.NET.DIM
    vox_shape = [n_vox[0], n_vox[1], n_vox[2], dim[4]]
    com_shape = [n_vox[0], n_vox[1], n_vox[2], 2]
    dim_z = cfg.NET.DIM_Z
    start_vox_size = cfg.NET.START_VOX
    # Shape of a single latent code, without the leading batch axis.
    code_shape = [
        start_vox_size[0], start_vox_size[1], start_vox_size[2], dim_z
    ]

    save_path = cfg.DIR.EVAL_PATH
    if discriminative is True:
        model_path = cfg.DIR.CHECK_POINT_PATH + '-d'
    else:
        model_path = cfg.DIR.CHECK_POINT_PATH
    chckpt_path = model_path + '/checkpoint' + str(checknum)

    depvox_gan_model = depvox_gan(batch_size=batch_size,
                                  vox_shape=vox_shape,
                                  com_shape=com_shape,
                                  dim_z=dim_z,
                                  dim=dim,
                                  start_vox_size=start_vox_size,
                                  kernel=cfg.NET.KERNEL,
                                  stride=cfg.NET.STRIDE,
                                  dilations=cfg.NET.DILATIONS,
                                  discriminative=discriminative,
                                  is_train=False)

    Z_tf, z_enc_tf, surf_tf, full_tf, full_gen_tf, surf_dec_tf, full_dec_tf,\
    gen_loss_tf, discrim_loss_tf, recons_ssc_loss_tf, recons_com_loss_tf, recons_sem_loss_tf, encode_loss_tf, refine_loss_tf, summary_tf,\
    part_tf, part_dec_tf, comp_gt_tf, comp_gen_tf, comp_dec_tf, ssc_tf, scores_tf = depvox_gan_model.build_model()
    if discriminative is True:
        Z_tf_samp, comp_tf_samp, surf_tf_samp, full_tf_samp, part_tf_samp, scores_tf_samp = depvox_gan_model.samples_generator(
            visual_size=batch_size)
    sess = tf.InteractiveSession()
    saver = tf.train.Saver()

    # Restore variables from disk.
    saver.restore(sess, chckpt_path)

    print("...Weights restored.")

    if mode == 'recons':
        # evaluation for reconstruction
        voxel_test, surf_test, part_test, num, data_paths = scene_model_id_pair_test(
            dataset_portion=cfg.TRAIN.DATASET_PORTION)

        # Evaluation masks: zero out the negative markers in the test data.
        if cfg.TYPE_TASK == 'scene':
            part_test[part_test < -1] = 0
            surf_test[surf_test < 0] = 0
            voxel_test[voxel_test < 0] = 0

        num = voxel_test.shape[0]
        n_batches = num // batch_size
        if n_batches == 0:
            raise ValueError(
                "test set has %d samples, fewer than batch_size=%d" %
                (num, batch_size))
        # Only full batches are forwarded.  Drop the unprocessed remainder so
        # ground truth and predictions keep the same sample count in the
        # element-wise products and IoU computations below.
        num_eval = n_batches * batch_size
        voxel_test = voxel_test[:num_eval]
        surf_test = surf_test[:num_eval]
        part_test = part_test[:num_eval]
        print("test voxels loaded")

        fetches = [
            surf_dec_tf, full_dec_tf, part_dec_tf, z_enc_tf, comp_gt_tf,
            comp_dec_tf, ssc_tf
        ]
        # One list of per-batch results per fetched tensor; joined once after
        # the loop instead of re-copying the accumulated array every batch.
        per_fetch = [[] for _ in fetches]
        from progressbar import ProgressBar
        pbar = ProgressBar()
        for i in pbar(np.arange(n_batches)):
            lo = i * batch_size
            hi = lo + batch_size
            batch_out = sess.run(fetches,
                                 feed_dict={
                                     part_tf: part_test[lo:hi],
                                     surf_tf: surf_test[lo:hi],
                                     full_tf: voxel_test[lo:hi]
                                 })
            for bucket, value in zip(per_fetch, batch_out):
                bucket.append(value)

        pd_surf, pd_full, pd_part, part_enc_Z, comp_gt, pd_comp, pd_ssc = [
            np.concatenate(bucket, axis=0) for bucket in per_fetch
        ]

        print("forwarded")

        # For visualization
        np.uint8(voxel_test).tofile(save_path + '/scene.bin')

        sdf_volume = np.round(10 * np.abs(np.array(part_test)))
        observed = np.array(part_test)
        if cfg.TYPE_TASK == 'scene':
            observed = np.abs(observed)
            observed *= 10
            observed -= 7
            observed = np.round(observed)
            pd_part = np.abs(pd_part)
            pd_part *= 10
            pd_part -= 7
        elif cfg.TYPE_TASK == 'object':
            observed = np.clip(observed, 0, 1)
            pd_part = np.clip(pd_part, 0, 1)
        sdf_volume.astype('uint8').tofile(save_path + '/surface.bin')
        pd_part.astype('uint8').tofile(save_path + '/dec_part.bin')

        # Ground-truth labels restricted to the observed voxels.
        depsem_gt = np.multiply(voxel_test, np.clip(observed, 0, 1))
        if cfg.TYPE_TASK == 'scene':
            depsem_gt[depsem_gt < 0] = 0
        depsem_gt.astype('uint8').tofile(save_path + '/depth_seg_scene.bin')

        # Per-voxel label volumes, computed once and reused for the dumps,
        # the point clouds and the IoU evaluation.
        ssc_labels = np.argmax(pd_ssc, axis=4)
        surf_labels = np.argmax(pd_surf, axis=4)
        full_labels = np.argmax(pd_full, axis=4)
        comp_labels = np.argmax(pd_comp, axis=4)
        comp_gt_labels = np.argmax(comp_gt, axis=4)

        # decoded
        do_save_pcd = True
        if do_save_pcd:
            for i in range(ssc_labels.shape[0]):
                pcd_idx = np.where(ssc_labels[i] > 0)
                # NOTE(review): 80 and 11 look like the grid resolution and
                # the number of classes -- confirm against the config.
                pts_coord = np.float32(np.transpose(pcd_idx)) / 80 - 0.5
                pts_color = matplotlib.cm.Paired(
                    np.float32(ssc_labels[i][pcd_idx]) / 11 - 0.5 / 11)
                output_name = os.path.join('results_pcds',
                                           '%s.pcd' % data_paths[i][1][:-4])
                output_pcds = np.concatenate((pts_coord, pts_color[:, 0:3]),
                                             -1)
                save_pcd(output_name, output_pcds)

        # For each decoder output: the label volume and an error map that
        # adds the binarised prediction to twice the completion ground truth.
        for tag, labels in (('ssc', ssc_labels), ('surf', surf_labels),
                            ('full', full_labels)):
            labels.astype('uint8').tofile(save_path + '/dec_' + tag + '.bin')
            error = np.array(np.clip(labels, 0, 1) + comp_gt_labels * 2)
            error.astype('uint8').tofile(save_path + '/dec_' + tag +
                                         '_error.bin')
        comp_labels.astype('uint8').tofile(save_path + '/dec_complete.bin')
        comp_gt_labels.astype('uint8').tofile(save_path + '/complete_gt.bin')

        # reconstruction and generation from normal distribution evaluation
        # generator from random distribution
        if discriminative is True:
            np.save(save_path + '/decode_z.npy', part_enc_Z)
            sample_times = 10
            comp_batches = []
            part_batches = []
            surf_batches = []
            full_batches = []
            for j in np.arange(sample_times):
                gaussian_samp = np.random.normal(
                    size=tuple([batch_size] + code_shape)).astype(np.float32)

                z_comp_rnd, z_surf_rnd, z_full_rnd, z_part_rnd, scores_samp = sess.run(
                    [
                        comp_tf_samp, surf_tf_samp, full_tf_samp, part_tf_samp,
                        scores_tf_samp
                    ],
                    feed_dict={Z_tf_samp: gaussian_samp})
                comp_batches.append(z_comp_rnd)
                part_batches.append(z_part_rnd)
                surf_batches.append(z_surf_rnd)
                full_batches.append(z_full_rnd)
                print('Discrim score: ' +
                      colored(np.mean(scores_samp), 'blue'))

            z_comp_rnd_all = np.concatenate(comp_batches, axis=0)
            z_part_rnd_all = np.concatenate(part_batches, axis=0)
            z_surf_rnd_all = np.concatenate(surf_batches, axis=0)
            z_full_rnd_all = np.concatenate(full_batches, axis=0)

            # NOTE(review): only the last drawn latent batch is written,
            # while the gen_*.bin files hold all sample_times batches.
            gaussian_samp.astype('float32').tofile(save_path + '/sample_z.bin')
            np.argmax(z_comp_rnd_all,
                      axis=4).astype('uint8').tofile(save_path +
                                                     '/gen_comp.bin')
            np.argmax(z_surf_rnd_all,
                      axis=4).astype('uint8').tofile(save_path +
                                                     '/gen_surf.bin')
            np.argmax(z_full_rnd_all,
                      axis=4).astype('uint8').tofile(save_path +
                                                     '/gen_full.bin')
            if cfg.TYPE_TASK == 'scene':
                z_part_rnd_all = np.abs(z_part_rnd_all)
                z_part_rnd_all *= 10
                z_part_rnd_all -= 7
            elif cfg.TYPE_TASK == 'object':
                z_part_rnd_all[z_part_rnd_all <= 0.4] = 0
                z_part_rnd_all[z_part_rnd_all > 0.4] = 1
            z_part_rnd_all.astype('uint8').tofile(save_path + '/gen_part.bin')

        print("voxels saved")

        # numerical evaluation
        iou_eval = True
        if iou_eval:
            # completion
            print(colored("Completion:", 'red'))
            IoU(comp_gt, onehot(comp_labels, 2),
                [vox_shape[0], vox_shape[1], vox_shape[2], 2])

            stages = (("encoded", ssc_labels), ("decoded", surf_labels),
                      ("solidly decoded", full_labels))
            observed_mask = np.expand_dims(np.clip(observed, 0, 1), -1)
            # First pass ("Segmentation") scores predictions restricted to
            # the observed voxels; second pass ("Semantic Completion")
            # scores the whole volume.
            passes = (("Segmentation:", depsem_gt, observed_mask),
                      ("Semantic Completion:", voxel_test, None))

            iou_columns = []
            for header, gt_labels, mask in passes:
                print(colored(header, 'red'))
                on_gt = onehot(gt_labels, vox_shape[3])
                previous = None
                for title, labels in stages:
                    print(colored(title, 'cyan'))
                    on_pd = onehot(labels, vox_shape[3])
                    if mask is not None:
                        on_pd = np.multiply(on_pd, mask)
                    if previous is None:
                        IoU_temp = IoU(on_gt, on_pd, vox_shape)
                    else:
                        # Later stages are compared against the stage
                        # evaluated just before them.
                        IoU_temp = IoU(on_gt,
                                       on_pd,
                                       vox_shape,
                                       IoU_compared=previous)
                    column = np.expand_dims(IoU_temp, axis=1)
                    iou_columns.append(column)
                    previous = column[:, -1]
            IoU_all = np.concatenate(iou_columns, axis=1)

            # The first row is left out of the exported table.
            np.savetxt(save_path + '/IoU.csv',
                       np.transpose(IoU_all[1:] * 100),
                       delimiter=" & ",
                       fmt='%2.1f')

    # interpolation evaluation
    if mode == 'interpolate':
        interpolate_num = 8
        # interpolation latent vectors
        decode_z = np.load(save_path + '/decode_z.npy')
        print(save_path)
        decode_z = decode_z[20:20 + batch_size]
        if decode_z.shape[0] < batch_size:
            raise ValueError(
                "decode_z.npy provides %d codes from offset 20, need %d" %
                (decode_z.shape[0], batch_size))
        decode_z = np.reshape(decode_z, [-1] + code_shape)

        interp_dir = save_path + '/interpolate'
        if not os.path.isdir(interp_dir):
            os.makedirs(interp_dir)

        for l in range(batch_size):
            # `left` is the reference sample; the remaining codes only pad
            # the batch up to batch_size.
            left = decode_z[l:l + 1]
            Z_np_samp = np.delete(decode_z, l, axis=0)
            for r in range(batch_size):
                if l == r:
                    continue
                print('%d %d' % (l, r))
                right = decode_z[r:r + 1]
                duration = (right - left) / (interpolate_num - 1)

                z_steps = []
                full_steps = []
                part_steps = []
                Z = left
                for i in range(interpolate_num):
                    if i > 0:
                        Z = Z + duration
                    z_steps.append(Z)

                    gaussian_samp = np.concatenate([Z, Z_np_samp], axis=0)
                    pd_full_rnd, pd_part_rnd = sess.run(
                        [full_tf_samp, part_tf_samp],
                        feed_dict={Z_tf_samp: gaussian_samp})
                    # Only slot 0 belongs to the interpolated code.
                    full_steps.append(
                        np.reshape(pd_full_rnd[0], [1] + vox_shape))
                    part_steps.append(
                        np.reshape(pd_part_rnd[0], [1] + com_shape))

                interpolate_z = np.concatenate(z_steps, axis=0)
                pd_full = np.concatenate(full_steps, axis=0)
                pd_part = np.concatenate(part_steps, axis=0)
                suffix = str(l) + '-' + str(r) + '.bin'

                # NOTE(review): the latent codes are cast to uint8 before
                # being written, which discards their fractional values --
                # confirm that downstream readers really expect this format.
                interpolate_z.astype('uint8').tofile(
                    interp_dir + '/interpolation_z' + suffix)

                full_models_cat = np.argmax(pd_full, axis=4)
                full_models_cat.astype('uint8').tofile(
                    interp_dir + '/interpolation_f' + suffix)
                if cfg.TYPE_TASK == 'scene':
                    pd_part = np.abs(pd_part)
                    pd_part *= 10
                    pd_part -= 7
                elif cfg.TYPE_TASK == 'object':
                    pd_part = np.argmax(pd_part, axis=4)
                pd_part.astype('uint8').tofile(
                    interp_dir + '/interpolation_p' + suffix)
        print("voxels saved")

    # add noise evaluation
    if mode == 'noise':
        decode_z = np.load(save_path + '/decode_z.npy')
        decode_z = decode_z[:batch_size]
        if decode_z.shape[0] < batch_size:
            raise ValueError(
                "decode_z.npy provides %d codes, need %d" %
                (decode_z.shape[0], batch_size))
        decode_z = np.reshape(decode_z, [-1] + code_shape)
        noise_num = 10
        for base_num in range(batch_size):
            print(base_num)
            base = decode_z[base_num:base_num + 1]
            # One noise vector per perturbed variant; variant 0 is the
            # unperturbed base code.
            eps = np.random.normal(size=(noise_num - 1,
                                         dim_z)).astype(np.float32)
            # The other codes only pad the batch up to batch_size.
            Z_np_samp = np.delete(decode_z, base_num, axis=0)

            for c in range(start_vox_size[0]):
                for l in range(start_vox_size[1]):
                    for d in range(start_vox_size[2]):
                        z_variants = []
                        vox_variants = []
                        for i in range(noise_num):
                            # Copy so the perturbation never leaks into
                            # decode_z or into the other variants.
                            Z = base.copy()
                            if i > 0:
                                Z[0, c, l, d, :] += eps[i - 1]
                            z_variants.append(Z)

                            gaussian_samp = np.concatenate([Z, Z_np_samp],
                                                           axis=0)
                            pd_full_rnd = sess.run(
                                full_tf_samp,
                                feed_dict={Z_tf_samp: gaussian_samp})
                            vox_variants.append(
                                np.reshape(pd_full_rnd[0], [1] + vox_shape))

                        noise_z = np.concatenate(z_variants, axis=0)
                        pd_full = np.concatenate(vox_variants, axis=0)
                        cell_tag = (str(base_num) + '_' + str(c) + str(l) +
                                    str(d) + '.npy')

                        np.save(save_path + '/noise_z' + cell_tag, noise_z)

                        full_models_cat = np.argmax(pd_full, axis=4)
                        np.save(save_path + '/noise' + cell_tag,
                                full_models_cat)

        print("voxels saved")