Example #1
    def create_loss(self, coarse, fine, gt, alpha):
        loss_coarse = chamfer(coarse, gt)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
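Every create_loss above leans on two small summary helpers (in the PCN codebase they live in tf_util.py). A minimal sketch of what they are assumed to do: the training variant logs the instantaneous scalar, while the validation variant wraps the value in a streaming mean and returns the update op, which is why the update ops are returned alongside the loss.

import tensorflow as tf

def add_train_summary(name, value):
    # instantaneous scalar, gathered in the 'train_summary' collection
    tf.summary.scalar(name, value, collections=['train_summary'])

def add_valid_summary(name, value):
    # running mean over the validation set; the caller runs the returned
    # update op once per validation batch before reading the summary
    avg, update = tf.metrics.mean(value)
    tf.summary.scalar(name, avg, collections=['valid_summary'])
    return update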
Example #2
    def create_loss(self, coarse_highres, coarse, fine, gt, theta):
        loss_coarse_highres = chamfer(coarse_highres, gt)

        loss_coarse = chamfer(coarse, gt)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        repulsion_loss = get_repulsion_loss4(coarse)

        loss = 0.5 * loss_coarse_highres + loss_coarse + theta * loss_fine + 0.2 * repulsion_loss
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, loss_fine, [update_coarse, update_fine, update_loss]
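get_repulsion_loss4 appears to come from the PU-Net family of losses and discourages the coarse points from clumping together. A rough sketch of such a repulsion term (parameters k and h are illustrative, not the repo's exact implementation):

def repulsion_loss_sketch(pred, k=4, h=0.03):
    # pred: (B, N, 3) predicted points
    dot = tf.matmul(pred, pred, transpose_b=True)                  # (B, N, N)
    sq = tf.reduce_sum(tf.square(pred), axis=-1, keepdims=True)    # (B, N, 1)
    dist2 = sq - 2.0 * dot + tf.transpose(sq, [0, 2, 1])           # squared pairwise distances
    neg_dist2, _ = tf.nn.top_k(-dist2, k=k + 1)                    # k+1 smallest, incl. self
    knn_dist2 = tf.maximum(-neg_dist2[:, :, 1:], 1e-12)            # drop the self-distance
    knn_dist = tf.sqrt(knn_dist2)
    weight = tf.exp(-knn_dist2 / (h ** 2))
    # minimising relu(h - d) pushes neighbours apart until they are at least h away
    return tf.reduce_mean(tf.nn.relu(h - knn_dist) * weight)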
Example #3
    def create_loss(self, coarse, fine, gt, alpha):
        loss_coarse = chamfer(coarse[:, :, 0:3], gt[:, :, 0:3])
        _, retb, _, retd = tf_nndistance.nn_distance(coarse[:, :, 0:3],
                                                     gt[:, :, 0:3])
        for i in range(np.shape(gt)[0]):
            index = tf.expand_dims(retb[i], -1)
            sem_feat = tf.nn.softmax(coarse[i, :, 3:], -1)
            sem_gt = tf.cast(
                tf.one_hot(
                    tf.gather_nd(tf.cast(gt[i, :, 3] * 80 * 12, tf.int32),
                                 index), 12), tf.float32)
            loss_sem_coarse = tf.reduce_mean(-tf.reduce_sum(
                0.9 * sem_gt * tf.log(1e-6 + sem_feat) + (1 - 0.9) *
                (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
            loss_coarse += loss_sem_coarse
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine[:, :, 0:3], gt[:, :, 0:3])
        _, retb, _, retd = tf_nndistance.nn_distance(fine[:, :, 0:3],
                                                     gt[:, :, 0:3])
        for i in range(np.shape(gt)[0]):
            index = tf.expand_dims(retb[i], -1)
            sem_feat = tf.nn.softmax(fine[i, :, 3:], -1)
            sem_gt = tf.cast(
                tf.one_hot(
                    tf.gather_nd(tf.cast(gt[i, :, 3] * 80 * 12, tf.int32),
                                 index), 12), tf.float32)
            loss_sem_fine = tf.reduce_mean(-tf.reduce_sum(
                0.9 * sem_gt * tf.log(1e-6 + sem_feat) + (1 - 0.9) *
                (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
            loss_fine += loss_sem_fine
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
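The per-sample Python loop above adds one graph branch per batch element. Assuming a TensorFlow version where tf.gather supports batch_dims (1.14+), the same semantic cross-entropy can be written in batched form (sketch only; the loop sums per-sample means over the batch, so multiply by the batch size to match it exactly):

def semantic_loss_batched(pred, gt, retb, num_classes=12):
    # pred: (B, M, 3 + C) completion with per-point class logits in channels 3:
    # gt:   (B, N, 4)     ground truth with an encoded label in channel 3
    # retb: (B, M)        index of each predicted point's nearest gt point
    sem_feat = tf.nn.softmax(pred[:, :, 3:], -1)                   # (B, M, C)
    gt_labels = tf.cast(gt[:, :, 3] * 80 * 12, tf.int32)           # (B, N)
    sem_idx = tf.gather(gt_labels, retb, batch_dims=1)             # (B, M)
    sem_gt = tf.one_hot(sem_idx, num_classes, dtype=tf.float32)    # (B, M, C)
    return tf.reduce_mean(-tf.reduce_sum(
        0.9 * sem_gt * tf.log(1e-6 + sem_feat) +
        0.1 * (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), axis=-1))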
Example #4
    def create_loss(self, coarse, fine, gt, alpha):
        gt_ds = gt[:, :coarse.shape[1], :]
        loss_coarse = earth_mover(coarse, gt_ds)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
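The slice gt[:, :coarse.shape[1], :] exists because earth_mover requires both clouds to have the same number of points; it implicitly assumes the ground-truth points are stored in random order. If that is not guaranteed, a random subsample is a common alternative (hypothetical drop-in for the first two lines of create_loss above; the same random subset is reused for every element of the batch):

        idx = tf.random_shuffle(tf.range(tf.shape(gt)[1]))[:coarse.shape[1]]
        gt_ds = tf.gather(gt, idx, axis=1)
        loss_coarse = earth_mover(coarse, gt_ds)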
Example #5
File: fc.py  Project: mihaibujanca/pcn
    def create_loss(self, outputs, gt):
        loss = chamfer(outputs, gt)
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)
        return loss, update_loss
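A brief usage sketch (assumed, not from the repo) of how the returned loss and validation update op are typically driven; model, sess, train_feed and valid_batches are placeholders for illustration:

loss, update_loss = model.create_loss(model.outputs, gt)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
train_summary = tf.summary.merge_all('train_summary')
valid_summary = tf.summary.merge_all('valid_summary')

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())      # required by tf.metrics.mean
# training step: optimise and log the instantaneous training loss
_, summary = sess.run([train_op, train_summary], feed_dict=train_feed)
# validation: run update_loss once per batch to accumulate the running mean,
# then fetch the merged validation summary
for valid_feed in valid_batches:
    sess.run(update_loss, feed_dict=valid_feed)
summary = sess.run(valid_summary)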
Example #6
def test(args):
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, tf.constant(1.0))

    output = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    cd_op = chamfer(output, gt)
    emd_op = earth_mover(output, gt)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    saver = tf.train.Saver()
    saver.restore(sess, args.checkpoint)

    os.makedirs(args.results_dir, exist_ok=True)
    csv_file = open(os.path.join(args.results_dir, 'results.csv'), 'w')
    writer = csv.writer(csv_file)
    writer.writerow(['id', 'cd', 'emd'])

    with open(args.list_path) as file:
        model_list = file.read().splitlines()
    total_time = 0
    total_cd = 0
    total_emd = 0
    cd_per_cat = {}
    emd_per_cat = {}
    for i, model_id in enumerate(model_list):
        partial = read_pcd(
            os.path.join(args.data_dir, 'partial', '%s.pcd' % model_id))
        complete = read_pcd(
            os.path.join(args.data_dir, 'complete', '%s.pcd' % model_id))
        start = time.time()
        completion = sess.run(model.outputs, feed_dict={inputs: [partial]})
        total_time += time.time() - start
        cd, emd = sess.run([cd_op, emd_op],
                           feed_dict={
                               output: completion,
                               gt: [complete]
                           })
        total_cd += cd
        total_emd += emd
        writer.writerow([model_id, cd, emd])

        synset_id, model_id = model_id.split('/')
        if not cd_per_cat.get(synset_id):
            cd_per_cat[synset_id] = []
        if not emd_per_cat.get(synset_id):
            emd_per_cat[synset_id] = []
        cd_per_cat[synset_id].append(cd)
        emd_per_cat[synset_id].append(emd)

        if i % args.plot_freq == 0:
            os.makedirs(os.path.join(args.results_dir, 'plots', synset_id),
                        exist_ok=True)
            plot_path = os.path.join(args.results_dir, 'plots', synset_id,
                                     '%s.png' % model_id)
            plot_pcd_three_views(plot_path, [partial, completion[0], complete],
                                 ['input', 'output', 'ground truth'],
                                 'CD %.4f  EMD %.4f' % (cd, emd),
                                 [5, 0.5, 0.5])
        if args.save_pcd:
            os.makedirs(os.path.join(args.results_dir, 'pcds', synset_id),
                        exist_ok=True)
            # save under the per-synset directory created above so that
            # model ids from different categories do not collide
            save_pcd(
                os.path.join(args.results_dir, 'pcds', synset_id,
                             '%s.pcd' % model_id), completion[0])
    csv_file.close()
    sess.close()

    print('Average time: %f' % (total_time / len(model_list)))
    print('Average Chamfer distance: %f' % (total_cd / len(model_list)))
    print('Average Earth mover distance: %f' % (total_emd / len(model_list)))
    print('Chamfer distance per category')
    for synset_id in cd_per_cat.keys():
        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))
    print('Earth mover distance per category')
    for synset_id in emd_per_cat.keys():
        print(synset_id, '%f' % np.mean(emd_per_cat[synset_id]))
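A test() like this is normally driven by an argparse entry point. A sketch of such a driver (flag names follow the attributes used above; default values are purely illustrative):

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--list_path', default='data/shapenet/test.list')
    parser.add_argument('--data_dir', default='data/shapenet/test')
    parser.add_argument('--model_type', default='pcn_cd')
    parser.add_argument('--checkpoint', default='log/pcn_cd/model')
    parser.add_argument('--results_dir', default='results/shapenet_pcn_cd')
    parser.add_argument('--num_gt_points', type=int, default=16384)
    parser.add_argument('--plot_freq', type=int, default=100)
    parser.add_argument('--save_pcd', action='store_true')
    args = parser.parse_args()
    test(args)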
Example #7
def test(args):
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    npts = tf.placeholder(tf.int32, (1, ))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 6))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, npts, gt, tf.constant(1.0),
                               args.num_channel)

    output = tf.placeholder(tf.float32,
                            (1, args.num_gt_points, 3 + args.num_channel))
    cd_op = chamfer(output[:, :, 0:3], gt[:, :, 0:3])
    emd_op = earth_mover(output[:, :, 0:3], gt[:, :, 0:3])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    saver = tf.train.Saver()
    saver.restore(sess, args.checkpoint)

    os.makedirs(args.results_dir, exist_ok=True)
    csv_file = open(os.path.join(args.results_dir, 'results.csv'), 'w')
    writer = csv.writer(csv_file)
    writer.writerow(['id', 'cd', 'emd'])

    with open(args.list_path) as file:
        model_list = file.read().splitlines()
    total_time = 0
    total_cd = 0
    total_emd = 0
    cd_per_cat = {}
    emd_per_cat = {}
    np.random.seed(1)
    for i, model_id in enumerate(model_list):
        if args.experiment == 'shapenet':
            synset_id, model_id = model_id.split('/')
            partial = read_pcd(
                os.path.join(args.data_dir, 'partial', synset_id,
                             '%s.pcd' % model_id))
            complete = read_pcd(
                os.path.join(args.data_dir, 'complete', synset_id,
                             '%s.pcd' % model_id))
        elif args.experiment == 'suncg':
            synset_id = 'all_rooms'
            partial = read_pcd(
                os.path.join(args.data_dir, 'pcd_partial',
                             '%s.pcd' % model_id))
            complete = read_pcd(
                os.path.join(args.data_dir, 'pcd_complete',
                             '%s.pcd' % model_id))
        if args.rotate:
            angle = np.random.rand(1) * 2 * np.pi
            partial = np.stack([
                np.cos(angle) * partial[:, 0] - np.sin(angle) * partial[:, 2],
                partial[:, 1],
                np.sin(angle) * partial[:, 0] + np.cos(angle) * partial[:, 2]
            ],
                               axis=-1)
            complete = np.stack([
                np.cos(angle) * complete[:, 0] -
                np.sin(angle) * complete[:, 2], complete[:, 1],
                np.sin(angle) * complete[:, 0] +
                np.cos(angle) * complete[:, 2], complete[:, 3], complete[:, 4],
                complete[:, 5]
            ],
                                axis=-1)
        partial = partial[:, :3]
        complete = resample_pcd(complete, 16384)
        start = time.time()
        completion1, completion2, mesh_out = sess.run(
            [model.outputs1, model.outputs2, model.gt_can],
            feed_dict={
                inputs: [partial],
                npts: [partial.shape[0]],
                gt: [complete]
            })
        completion1[0][:, (3 + args.num_channel):] *= 0
        completion2[0][:, (3 + args.num_channel):] *= 0
        mesh_out[0][:, (3 + args.num_channel):] *= 0
        total_time += time.time() - start
        # EMD evaluation is commented out here: cd_op is run twice, so the
        # value reported as "emd" below is actually the Chamfer distance.
        # cd, emd = sess.run([cd_op, emd_op],
        cd, emd = sess.run([cd_op, cd_op],
                           feed_dict={
                               output: completion2,
                               gt: [complete]
                           })
        total_cd += cd
        total_emd += emd
        if not cd_per_cat.get(synset_id):
            cd_per_cat[synset_id] = []
        if not emd_per_cat.get(synset_id):
            emd_per_cat[synset_id] = []
        cd_per_cat[synset_id].append(cd)
        emd_per_cat[synset_id].append(emd)
        writer.writerow([model_id, cd, emd])

        if i % args.plot_freq == 0:
            os.makedirs(
                os.path.join(args.results_dir, 'plots', synset_id),
                exist_ok=True)
            plot_path = os.path.join(args.results_dir, 'plots', synset_id,
                                     '%s.png' % model_id)
            plot_pcd_three_views(
                plot_path, [
                    partial, completion1[0], completion2[0], mesh_out[0],
                    complete
                ], ['input', 'coarse', 'fine', 'mesh', 'ground truth'],
                'CD %.4f  EMD %.4f' % (cd, emd), [5, 0.5, 0.5, 0.5, 0.5],
                num_channel=args.num_channel)
        if args.save_pcd:
            os.makedirs(
                os.path.join(args.results_dir, 'input', synset_id),
                exist_ok=True)
            pts_coord = partial[:, 0:3]
            pts_color = matplotlib.cm.cool((partial[:, 1]))[:, 0:3]
            # save_pcd(os.path.join(args.results_dir, 'input', synset_id, '%s.ply' % model_id), np.concatenate((pts_coord, pts_color), -1))
            pcd = PointCloud()
            pcd.points = Vector3dVector(pts_coord)
            pcd.colors = Vector3dVector(pts_color)
            write_point_cloud(
                os.path.join(args.results_dir, 'input', synset_id,
                             '%s.ply' % model_id),
                pcd,
                write_ascii=True)
            os.makedirs(
                os.path.join(args.results_dir, 'output1', synset_id),
                exist_ok=True)
            pts_coord = completion1[0][:, 0:3]
            pts_color = matplotlib.cm.Set1(
                (np.argmax(completion1[0][:, 3:3 + args.num_channel], -1) +
                 1) / args.num_channel - 0.5 / args.num_channel)[:, 0:3]
            # pts_color = matplotlib.cm.tab20((np.argmax(completion1[0][:, 3:3+args.num_channel], -1) + 1)/args.num_channel - 0.5/args.num_channel)[:,0:3]
            # save_pcd(os.path.join(args.results_dir, 'output1', synset_id, '%s.ply' % model_id), np.concatenate((pts_coord, pts_color), -1))
            pcd.points = Vector3dVector(pts_coord)
            pcd.colors = Vector3dVector(pts_color)
            write_point_cloud(
                os.path.join(args.results_dir, 'output1', synset_id,
                             '%s.ply' % model_id),
                pcd,
                write_ascii=True)
            os.makedirs(
                os.path.join(args.results_dir, 'output2', synset_id),
                exist_ok=True)
            pts_coord = completion2[0][:, 0:3]
            pts_color = matplotlib.cm.Set1(
                (np.argmax(completion2[0][:, 3:3 + args.num_channel], -1) +
                 1) / args.num_channel - 0.5 / args.num_channel)[:, 0:3]
            # pts_color = matplotlib.cm.tab20((np.argmax(completion2[0][:, 3:3+args.num_channel], -1) + 1)/args.num_channel - 0.5/args.num_channel)[:,0:3]
            # save_pcd(os.path.join(args.results_dir, 'output2', synset_id, '%s.ply' % model_id), np.concatenate((pts_coord, pts_color), -1))
            pcd.points = Vector3dVector(pts_coord)
            pcd.colors = Vector3dVector(pts_color)
            write_point_cloud(
                os.path.join(args.results_dir, 'output2', synset_id,
                             '%s.ply' % model_id),
                pcd,
                write_ascii=True)
            #######
            os.makedirs(
                os.path.join(args.results_dir, 'regions', synset_id),
                exist_ok=True)
            for idx in range(3, 3 + args.num_channel):
                val_min = np.min(completion2[0][:, idx])
                val_max = np.max(completion2[0][:, idx])
                pts_color = 0.8 * matplotlib.cm.Reds(
                    (completion2[0][:, idx] - val_min) /
                    (val_max - val_min))[:, 0:3]
                pts_color += 0.2 * matplotlib.cm.gist_gray(
                    (completion2[0][:, idx] - val_min) /
                    (val_max - val_min))[:, 0:3]
                pcd.colors = Vector3dVector(pts_color)
                write_point_cloud(
                    os.path.join(args.results_dir, 'regions', synset_id,
                                 '%s_%s.ply' % (model_id, idx - 3)),
                    pcd,
                    write_ascii=True)
            os.makedirs(
                os.path.join(args.results_dir, 'gt', synset_id), exist_ok=True)
            pts_coord = complete[:, 0:3]
            if args.experiment == 'shapenet':
                pts_color = matplotlib.cm.cool(complete[:, 1])[:, 0:3]
            elif args.experiment == 'suncg':
                pts_color = matplotlib.cm.Set1(complete[:, 3] -
                                               0.5 / args.num_channel)[:, 0:3]
            # save_pcd(os.path.join(args.results_dir, 'gt', synset_id, '%s.ply' % model_id), np.concatenate((pts_coord, pts_color), -1))
            pcd.points = Vector3dVector(pts_coord)
            pcd.colors = Vector3dVector(pts_color)
            write_point_cloud(
                os.path.join(args.results_dir, 'gt', synset_id,
                             '%s.ply' % model_id),
                pcd,
                write_ascii=True)
    sess.close()

    print('Average time: %f' % (total_time / len(model_list)))
    print('Average Chamfer distance: %f' % (total_cd / len(model_list)))
    print('Average Earth mover distance: %f' % (total_emd / len(model_list)))
    writer.writerow([
        total_time / len(model_list), total_cd / len(model_list),
        total_emd / len(model_list)
    ])
    print('Chamfer distance per category')
    for synset_id in cd_per_cat.keys():
        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))
        writer.writerow([synset_id, np.mean(cd_per_cat[synset_id])])
    print('Earth mover distance per category')
    for synset_id in emd_per_cat.keys():
        print(synset_id, '%f' % np.mean(emd_per_cat[synset_id]))
        writer.writerow([synset_id, np.mean(emd_per_cat[synset_id])])
    csv_file.close()
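test() above calls resample_pcd(complete, 16384) to bring the ground truth to a fixed size. A sketch of the drop-or-duplicate implementation such a helper is assumed to have:

def resample_pcd(pcd, n):
    # randomly permute the points, then pad with random duplicates if the
    # cloud is smaller than n; extra channels are carried along unchanged
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate(
            [idx, np.random.randint(pcd.shape[0], size=n - idx.shape[0])])
    return pcd[idx[:n]]
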
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    alpha = tf.train.piecewise_constant(
        global_step // (args.gen_iter + args.dis_iter), args.fine_step,
        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    inputs_pl = tf.placeholder(tf.float32,
                               (args.batch_size, args.num_input_points, 3),
                               'inputs')
    gt_pl = tf.placeholder(tf.float32,
                           (args.batch_size, args.num_gt_points, 3), 'gt')
    complete_feature = tf.placeholder(tf.float32, (args.batch_size, 1024),
                                      'complete_feature')
    complete_feature0 = tf.placeholder(tf.float32, (args.batch_size, 256),
                                       'complete_feature0')
    label_pl = tf.placeholder(tf.int32, shape=(args.batch_size))

    model_module = importlib.import_module('.%s' % args.model_type, 'models')

    file_train = h5py.File(args.h5_train, 'r')
    incomplete_pcds_train = file_train['incomplete_pcds'][()]
    complete_pcds_train = file_train['complete_pcds'][()]
    labels_train = file_train['labels'][()]
    if args.num_gt_points == 2048:
        if args.step_ratio == 2:
            complete_features_train = file_train['complete_feature'][()]
            complete_features_train0 = file_train['complete_feature0'][()]
        elif args.step_ratio == 4:
            complete_features_train = file_train['complete_feature1_4'][()]
            complete_features_train0 = file_train['complete_feature0_4'][()]
        elif args.step_ratio == 8:
            complete_features_train = file_train['complete_feature1_8'][()]
            complete_features_train0 = file_train['complete_feature0_8'][()]
        elif args.step_ratio == 16:
            complete_features_train = file_train['complete_feature1_16'][()]
            complete_features_train0 = file_train['complete_feature0_16'][()]
    elif args.num_gt_points == 16384:
        file_train_feature = h5py.File(
            args.pretrain_complete_decoder + '/train_complete_feature.h5', 'r')
        complete_features_train = file_train_feature['complete_feature1'][()]
        complete_features_train0 = file_train_feature['complete_feature0'][()]
        file_train_feature.close()
    file_train.close()

    assert complete_features_train.shape[0] == complete_features_train0.shape[
        0] == incomplete_pcds_train.shape[0]
    assert complete_features_train.shape[1] == 1024
    assert complete_features_train0.shape[1] == 256

    train_num = incomplete_pcds_train.shape[0]
    BN_DECAY_DECAY_STEP = int(train_num / args.batch_size *
                              args.lr_decay_epochs)

    learning_rate_d = tf.train.exponential_decay(
        args.base_lr_d,
        global_step // (args.gen_iter + args.dis_iter),
        BN_DECAY_DECAY_STEP,
        args.lr_decay_rate,
        staircase=True,
        name='lr_d')
    learning_rate_d = tf.maximum(learning_rate_d, args.lr_clip)

    learning_rate_g = tf.train.exponential_decay(
        args.base_lr_g,
        global_step // (args.gen_iter + args.dis_iter),
        BN_DECAY_DECAY_STEP,
        args.lr_decay_rate,
        staircase=True,
        name='lr_g')
    learning_rate_g = tf.maximum(learning_rate_g, args.lr_clip)

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        G_optimizers = tf.train.AdamOptimizer(learning_rate_g, beta1=0.9)
        D_optimizers = tf.train.AdamOptimizer(learning_rate_d, beta1=0.5)

    coarse_gpu = []
    fine_gpu = []
    tower_grads_g = []
    tower_grads_d = []
    total_dis_loss_gpu = []
    total_gen_loss_gpu = []
    total_lossReconstruction_gpu = []
    total_lossFeature_gpu = []

    for i in range(NUM_GPUS):
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            with tf.device('/gpu:%d' % (i)), tf.name_scope('gpu_%d' %
                                                           (i)) as scope:
                inputs_pl_batch = tf.slice(inputs_pl,
                                           [int(i * DEVICE_BATCH_SIZE), 0, 0],
                                           [int(DEVICE_BATCH_SIZE), -1, -1])
                gt_pl_batch = tf.slice(gt_pl,
                                       [int(i * DEVICE_BATCH_SIZE), 0, 0],
                                       [int(DEVICE_BATCH_SIZE), -1, -1])
                complete_feature_batch = tf.slice(
                    complete_feature, [int(i * DEVICE_BATCH_SIZE), 0],
                    [int(DEVICE_BATCH_SIZE), -1])
                complete_feature_batch0 = tf.slice(
                    complete_feature0, [int(i * DEVICE_BATCH_SIZE), 0],
                    [int(DEVICE_BATCH_SIZE), -1])

                with tf.variable_scope('generator'):
                    features_partial_0, partial_reconstruct = model_module.encoder(
                        inputs_pl_batch, embed_size=1024)
                    coarse_batch, fine_batch = model_module.decoder(
                        inputs_pl_batch,
                        partial_reconstruct,
                        step_ratio=args.step_ratio,
                        num_fine=args.step_ratio * 1024)

                with tf.variable_scope('discriminator') as dis_scope:
                    errD_fake = mlp(tf.expand_dims(partial_reconstruct,
                                                   axis=1), [16],
                                    bn=None,
                                    bn_params=None)
                    dis_scope.reuse_variables()
                    errD_real = mlp(tf.expand_dims(complete_feature_batch,
                                                   axis=1), [16],
                                    bn=None,
                                    bn_params=None)

                    kernel = getattr(mmd, '_%s_kernel' % args.mmd_kernel)
                    kerGI = kernel(errD_fake[:, 0, :], errD_real[:, 0, :])
                    errG = mmd.mmd2(kerGI)
                    errD = -errG
                    epsilon = tf.random_uniform([], 0.0, 1.0)
                    x_hat = complete_feature_batch * (
                        1 - epsilon) + epsilon * partial_reconstruct
                    d_hat = mlp(tf.expand_dims(x_hat, axis=1), [16],
                                bn=None,
                                bn_params=None)
                    Ekx = lambda yy: tf.reduce_mean(
                        kernel(d_hat[:, 0, :], yy, K_XY_only=True), axis=1)
                    Ekxr = Ekx(errD_real[:, 0, :])
                    Ekxf = Ekx(errD_fake[:, 0, :])
                    witness = Ekxr - Ekxf
                    gradients = tf.gradients(witness, [x_hat])[0]
                    penalty = tf.reduce_mean(
                        tf.square(mmd.safer_norm(gradients, axis=1) - 1.0))
                    errD_loss_batch = penalty * args.gp_weight + errD
                    errG_loss_batch = errG

                feature_loss = tf.reduce_mean(tf.squared_difference(partial_reconstruct,complete_feature_batch))+ \
                               tf.reduce_mean(tf.squared_difference(features_partial_0[:,0,:], complete_feature_batch0))

                dist1_fine, dist2_fine = chamfer(fine_batch, gt_pl_batch)
                dist1_coarse, dist2_coarse = chamfer(coarse_batch, gt_pl_batch)
                total_loss_fine = (tf.reduce_mean(tf.sqrt(dist1_fine)) +
                                   tf.reduce_mean(tf.sqrt(dist2_fine))) / 2
                total_loss_coarse = (tf.reduce_mean(tf.sqrt(dist1_coarse)) +
                                     tf.reduce_mean(tf.sqrt(dist2_coarse))) / 2
                total_loss_rec_batch = alpha * total_loss_fine + total_loss_coarse

                t_vars = tf.global_variables()
                gen_tvars = [
                    var for var in t_vars if var.name.startswith("generator")
                ]
                dis_tvars = [
                    var for var in t_vars
                    if var.name.startswith("discriminator")
                ]
                total_gen_loss_batch = args.feat_weight * feature_loss + errG_loss_batch + args.rec_weight * total_loss_rec_batch
                total_dis_loss_batch = errD_loss_batch

                # Calculate the gradients for the batch of data on this tower.
                grads_g = G_optimizers.compute_gradients(total_gen_loss_batch,
                                                         var_list=gen_tvars)
                grads_d = D_optimizers.compute_gradients(total_dis_loss_batch,
                                                         var_list=dis_tvars)

                # Keep track of the gradients across all towers.
                tower_grads_g.append(grads_g)
                tower_grads_d.append(grads_d)

                coarse_gpu.append(coarse_batch)
                fine_gpu.append(fine_batch)

                total_dis_loss_gpu.append(total_dis_loss_batch)
                total_gen_loss_gpu.append(errG_loss_batch)
                total_lossReconstruction_gpu.append(args.rec_weight *
                                                    total_loss_rec_batch)
                total_lossFeature_gpu.append(args.feat_weight * feature_loss)

    fine = tf.concat(fine_gpu, 0)

    grads_g = average_gradients(tower_grads_g)
    grads_d = average_gradients(tower_grads_d)

    # apply the gradients with our optimizers
    train_G = G_optimizers.apply_gradients(grads_g, global_step=global_step)
    train_D = D_optimizers.apply_gradients(grads_d, global_step=global_step)

    total_dis_loss = tf.reduce_mean(tf.stack(total_dis_loss_gpu, 0))
    total_gen_loss = tf.reduce_mean(tf.stack(total_gen_loss_gpu, 0))
    total_loss_rec = tf.reduce_mean(tf.stack(total_lossReconstruction_gpu, 0))
    total_loss_fea = tf.reduce_mean(tf.stack(total_lossFeature_gpu, 0))

    dist1_eval, dist2_eval = chamfer(fine, gt_pl)

    file_validate = h5py.File(args.h5_validate, 'r')
    incomplete_pcds_validate = file_validate['incomplete_pcds'][()]
    complete_pcds_validate = file_validate['complete_pcds'][()]
    labels_validate = file_validate['labels'][()]
    file_validate.close()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)
    saver = tf.train.Saver(max_to_keep=3)
    sess.run(tf.global_variables_initializer())

    saver_decoder = tf.train.Saver(var_list=[var for var in tf.global_variables() if (var.name.startswith("generator/decoder") \
                                                                                      or var.name.startswith("generator/folding"))])
    saver_decoder.restore(sess, args.pretrain_complete_decoder)

    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))

    init_step = sess.run(global_step) // (args.gen_iter + args.dis_iter)
    epoch = init_step * args.batch_size // train_num + 1
    print('init_step:%d,' % init_step, 'epoch:%d' % epoch,
          'training data number:%d' % train_num)
    train_idxs = np.arange(0, train_num)

    for ep_cnt in range(epoch, args.max_epoch + 1):
        num_batches = train_num // args.batch_size
        np.random.shuffle(train_idxs)

        for batch_idx in range(num_batches):
            init_step += 1
            start_idx = batch_idx * args.batch_size
            end_idx = min(start_idx + args.batch_size, train_num)
            ids_train = list(np.sort(train_idxs[start_idx:end_idx]))
            batch_data = incomplete_pcds_train[ids_train]
            batch_gt = complete_pcds_train[ids_train]
            labels = labels_train[ids_train]
            # partial_feature_input=incomplete_features_train[ids_train]
            complete_feature_input = complete_features_train[ids_train]
            complete_feature_input0 = complete_features_train0[ids_train]

            feed_dict = {
                inputs_pl: batch_data,
                gt_pl: batch_gt,
                is_training_pl: True,
                label_pl: labels,
                complete_feature: complete_feature_input,
                complete_feature0: complete_feature_input0
            }

            for i in range(args.dis_iter):
                _, loss_dis = sess.run([train_D, total_dis_loss],
                                       feed_dict=feed_dict)
            for i in range(args.gen_iter):
                _, loss_gen, rec_loss, fea_loss = sess.run(
                    [train_G, total_gen_loss, total_loss_rec, total_loss_fea],
                    feed_dict=feed_dict)

            if init_step % args.steps_per_print == 0:
                print('epoch %d step %d gen_loss %.8f rec_loss %.8f fea_loss %.8f dis_loss %.8f' % \
                    (ep_cnt, init_step, loss_gen,rec_loss,fea_loss,loss_dis))

            if init_step % args.steps_per_eval == 0:
                total_loss = 0
                sess.run(tf.local_variables_initializer())
                batch_data = np.zeros(
                    (args.batch_size, incomplete_pcds_validate[0].shape[0], 3),
                    'f')
                batch_gt = np.zeros((args.batch_size, args.num_gt_points, 3),
                                    'f')
                # partial_feature_input = np.zeros((args.batch_size, 1024), 'f')
                labels = np.zeros((args.batch_size, ), dtype=np.int32)
                feature_complete_input = np.zeros(
                    (args.batch_size, 1024)).astype(np.float32)
                feature_complete_input0 = np.zeros(
                    (args.batch_size, 256)).astype(np.float32)
                for batch_idx_eval in range(0,
                                            incomplete_pcds_validate.shape[0],
                                            args.batch_size):
                    # start = time.time()
                    start_idx = batch_idx_eval
                    end_idx = min(start_idx + args.batch_size,
                                  incomplete_pcds_validate.shape[0])

                    batch_data[0:end_idx - start_idx] = \
                        incomplete_pcds_validate[start_idx:end_idx]
                    batch_gt[0:end_idx - start_idx] = \
                        complete_pcds_validate[start_idx:end_idx]
                    labels[0:end_idx - start_idx] = \
                        labels_validate[start_idx:end_idx]
                    # the complete_feature placeholders must be fed, but they only
                    # enter the discriminator/feature-loss branch and do not affect
                    # the chamfer evaluation below, so training features are reused
                    feature_complete_input[0:end_idx - start_idx] = \
                        complete_features_train[start_idx:end_idx]
                    feature_complete_input0[0:end_idx - start_idx] = \
                        complete_features_train0[start_idx:end_idx]

                    feed_dict = {
                        inputs_pl: batch_data,
                        gt_pl: batch_gt,
                        is_training_pl: False,
                        label_pl: labels,
                        complete_feature: feature_complete_input,
                        complete_feature0: feature_complete_input0
                    }
                    dist1_out, dist2_out = sess.run([dist1_eval, dist2_eval],
                                                    feed_dict=feed_dict)
                    if args.loss_type == 'cd_1':
                        total_loss += np.mean(dist1_out[0:end_idx - start_idx]) * (end_idx - start_idx) \
                                      + np.mean(dist2_out[0:end_idx - start_idx]) * (end_idx - start_idx)
                    elif args.loss_type == 'cd_2':
                        total_loss += (np.mean(np.sqrt(dist1_out[0:end_idx - start_idx])) * (end_idx - start_idx) \
                                       + np.mean(np.sqrt(dist2_out[0:end_idx - start_idx])) * (end_idx - start_idx)) / 2

                mean_loss = total_loss / incomplete_pcds_validate.shape[0]
                if mean_loss < args.best_loss:
                    args.best_loss = mean_loss
                    saver.save(sess, os.path.join(args.log_dir, 'model'),
                               init_step)

                print('epoch %d  step %d  loss %.8f best_loss %.8f' %
                      (ep_cnt, init_step, mean_loss, args.best_loss))
    sess.close()
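
train() merges the per-GPU gradients with an average_gradients helper; a sketch of the conventional implementation (as in TensorFlow's multi-GPU CIFAR-10 example) it is assumed to follow:

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # average the same variable's gradient across all towers
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
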
def test(args):
    inputs = tf.placeholder(tf.float32, (1, 2048, 3))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    reconstruction = tf.placeholder(tf.float32, (1, 1024 * args.step_ratio, 3))
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    model_module = importlib.import_module('.%s' % args.model_type, 'models')

    with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
        _, features_partial = model_module.encoder(inputs)
        coarse, fine = model_module.decoder(inputs, features_partial,
                                            args.step_ratio,
                                            args.step_ratio * 1024)

    dist1_fine, dist2_fine = chamfer(reconstruction, gt)
    if args.loss_type == 'cd_1':
        loss = tf.reduce_mean(dist1_fine) + tf.reduce_mean(dist2_fine)
    elif args.loss_type == 'cd_2':
        loss = (tf.reduce_mean(tf.sqrt(dist1_fine)) +
                tf.reduce_mean(tf.sqrt(dist2_fine))) / 2

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    saver.restore(sess, os.path.join(args.checkpoint))

    data_all = h5py.File(args.data_dir, 'r')
    partial_all = data_all['incomplete_pcds'][()]
    complete_all = data_all['complete_pcds'][()]
    model_list = data_all['labels'][()].astype(int)
    data_all.close()

    cd_per_cat = {}

    total_cd = 0
    for i, model_id in enumerate(model_list):
        partial = partial_all[i]
        complete = complete_all[i]

        label = model_list[i]

        completion = sess.run(fine,
                              feed_dict={
                                  inputs: [partial],
                                  is_training_pl: False
                              })
        cd = sess.run(loss,
                      feed_dict={
                          reconstruction: completion,
                          gt: [complete],
                          is_training_pl: False
                      })
        total_cd += cd

        category = objects[label]
        key_list = list(snc_synth_id_to_category.keys())
        val_list = list(snc_synth_id_to_category.values())
        synset_id = key_list[val_list.index(category)]
        if not cd_per_cat.get(synset_id):
            cd_per_cat[synset_id] = []
        cd_per_cat[synset_id].append(cd)

    print('Average Chamfer distance: %f' % (total_cd / len(model_list)))
    print('Chamfer distance per category')
    for synset_id in sorted(cd_per_cat.keys()):
        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))

    sess.close()