Example #1
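All snippets below are TensorFlow 1.x graph-mode code from PCN-style point cloud completion projects. They rely on project-local helpers (mlp, mlp_conv, chamfer, farthest_point_sample, gather_point, symmetric_sample, ...) and on a module preamble along these lines; the exact imports are an assumption, since the source does not show them:

import importlib
import math
import os

import h5py
import numpy as np
import tensorflow as tf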
    def create_decoder(self, features):
        # Coarse stage: predict num_coarse points, each carrying xyz plus a
        # 12-dim per-point feature.
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            coarse = mlp(features, [2048, 2048, self.num_coarse * (3 + 12)])
            coarse = tf.reshape(coarse, [-1, self.num_coarse, 3 + 12])

        with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
            # Fold a small 2-D grid onto every coarse point to produce the
            # dense output.
            grid = tf.meshgrid(tf.linspace(-0.05, 0.05, self.grid_size),
                               tf.linspace(-0.05, 0.05, self.grid_size))
            grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]),
                                  0)
            grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])

            point_feat = tf.tile(tf.expand_dims(coarse, 2),
                                 [1, 1, self.grid_size**2, 1])
            point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3 + 12])

            global_feat = tf.tile(tf.expand_dims(features, 1),
                                  [1, self.num_fine, 1])

            feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)
            # feat_1 = mlp(feat, [512, 512, 3])
            # feat_2 = tf.concat([feat_1, point_feat, global_feat], axis=2)

            center = tf.tile(tf.expand_dims(coarse, 2),
                             [1, 1, self.grid_size**2, 1])
            center = tf.reshape(center, [-1, self.num_fine, 3 + 12])

            # with tf.variable_scope('folding_1', reuse=tf.AUTO_REUSE):
            fine = mlp(feat, [512, 512, 3 + 12]) + center
        return coarse, fine
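The snippets call an mlp(features, layer_dims, ...) helper that is not shown. A minimal sketch of such a helper in TF 1.x, assuming plain dense layers with ReLU everywhere except the linear output layer (the project-local version may add batch norm via the bn/bn_params arguments):

def mlp(features, layer_dims, bn=None, bn_params=None):
    # Fully connected stack; bn/bn_params are accepted but ignored in this
    # sketch.
    for i, dim in enumerate(layer_dims[:-1]):
        features = tf.contrib.layers.fully_connected(
            features, dim, activation_fn=tf.nn.relu, scope='fc_%d' % i)
    return tf.contrib.layers.fully_connected(
        features, layer_dims[-1], activation_fn=None,
        scope='fc_%d' % (len(layer_dims) - 1))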
Example #2
File: itn.py  Project: hmgoforth/pcn
        def while_body(i, est_pose, est_inputs):
            # `inputs` and `B` (the batch size) are captured from the
            # enclosing scope.
            with tf.variable_scope('encoder_0', reuse=tf.AUTO_REUSE):
                features = mlp_conv(est_inputs, [128, 256], bn=self.bn)
                features_global = tf.reduce_max(features, axis=1, keepdims=True, name='maxpool_0')
                features = tf.concat([features, tf.tile(features_global, [1, tf.shape(inputs)[1], 1])], axis=2)
            with tf.variable_scope('encoder_1', reuse=tf.AUTO_REUSE):
                features = mlp_conv(features, [512, 1024], bn=self.bn)
                features = tf.reduce_max(features, axis=1, name='maxpool_1')
            with tf.variable_scope('fc', reuse=tf.AUTO_REUSE):
                if self.rot_representation == 'quat':
                    est_pose_rep_i = mlp(features, [1024, 1024, 512, 512, 256, 7], bn=self.bn)
                elif self.rot_representation == '6dof':
                    est_pose_rep_i = mlp(features, [1024, 1024, 512, 512, 256, 9], bn=self.bn)

            with tf.variable_scope('est', reuse=tf.AUTO_REUSE):
                if self.rot_representation == 'quat':
                    t = tf.expand_dims(est_pose_rep_i[:, 4:], axis=2)
                    q = est_pose_rep_i[:, 0:4]
                    R = quat2rotm_tf(q)
                elif self.rot_representation == '6dof':
                    t = tf.expand_dims(est_pose_rep_i[:, 6:], axis=2)
                    mat6d = est_pose_rep_i[:, 0:6]
                    R = mat6d2rotm_tf(mat6d)

                est_pose_T_i = tf.concat([
                    tf.concat([R, t], axis=2),
                    tf.concat([tf.zeros([B, 1, 3]), tf.ones([B, 1, 1])], axis=2)],
                    axis=1)
                est_inputs = transform_tf(est_inputs, est_pose_T_i)
                est_pose = tf.linalg.matmul(est_pose_T_i, est_pose)

            return [tf.add(i, 1), est_pose, est_inputs]
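while_body is written for tf.while_loop, which repeatedly re-estimates and applies a rigid transform. A sketch of the wiring, assuming inputs is a [B, N, 3] tensor and n_iters a Python int (both names are assumptions from context):

# Minimal driver for the loop body above (names assumed).
i0 = tf.constant(0)
est_pose0 = tf.eye(4, batch_shape=[B])  # start from identity poses
_, est_pose, est_inputs = tf.while_loop(
    lambda i, pose, pts: tf.less(i, n_iters),
    while_body,
    [i0, est_pose0, inputs])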
Example #3
def decoder(inputs, features, step_ratio=16, num_fine=16 * 1024):
    num_coarse = 1024
    assert num_fine == num_coarse * step_ratio
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        coarse = mlp(features, [1024, 1024, num_coarse * 3], bn=None, bn_params=None)
        coarse = tf.reshape(coarse, [-1, num_coarse, 3])

    p1_idx = farthest_point_sample(512, coarse)
    coarse_1 = gather_point(coarse, p1_idx)
    input_fps = symmetric_sample(inputs, int(512 / 2))
    coarse = tf.concat([input_fps, coarse_1], 1)

    with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
        # Use a 1-D grid when step_ratio is not a perfect square, otherwise a
        # square 2-D grid.
        if step_ratio ** 0.5 % 1 != 0:
            grid = gen_1d_grid(step_ratio)
        else:
            grid = gen_grid(np.round(np.sqrt(step_ratio)).astype(np.int32))
        grid = tf.expand_dims(grid, 0)
        grid_feat = tf.tile(grid, [features.shape[0], num_coarse, 1])
        point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, step_ratio, 1])
        point_feat = tf.reshape(point_feat, [-1, num_fine, 3])
        global_feat = tf.tile(tf.expand_dims(features, 1), [1, num_fine, 1])
        feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)
        fine = mlp_conv(feat, [512, 512, 3], bn=None, bn_params=None) + point_feat
    return coarse, fine
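gen_grid and gen_1d_grid are project helpers that produce the folding seed. A plausible minimal gen_grid, assuming it returns a [grid_size**2, 2] tensor of 2-D coordinates (the scale and point ordering are assumptions):

def gen_grid(grid_size, grid_scale=0.05):
    # Square 2-D grid spanning [-grid_scale, grid_scale]; 0.05 mirrors the
    # hard-coded span used in the other examples.
    x = tf.linspace(-grid_scale, grid_scale, grid_size)
    y = tf.linspace(-grid_scale, grid_scale, grid_size)
    grid = tf.meshgrid(x, y)
    return tf.reshape(tf.stack(grid, axis=2), [-1, 2])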
Example #4
    def create_decoder(self, features):
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            coarse = mlp(features, [1024, 1024, self.num_coarse * (3 + 12)])
            coarse = tf.reshape(coarse, [-1, self.num_coarse, 3 + 12])

        with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
            x = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
            y = tf.linspace(-self.grid_scale, self.grid_scale, self.grid_size)
            grid = tf.meshgrid(x, y)
            grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]),
                                  0)
            grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])

            point_feat = tf.tile(tf.expand_dims(coarse, 2),
                                 [1, 1, self.grid_size**2, 1])
            point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3 + 12])

            global_feat = tf.tile(tf.expand_dims(features, 1),
                                  [1, self.num_fine, 1])

            feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)

            center = tf.tile(tf.expand_dims(coarse, 2),
                             [1, 1, self.grid_size**2, 1])
            center = tf.reshape(center, [-1, self.num_fine, 3 + 12])

            # Squeeze only the dummy spatial axis, so a batch size of 1 is
            # not collapsed as well.
            fine = tf.squeeze(
                mlp_conv(tf.expand_dims(feat, -2),
                         [512, 512, 3 + 12]), axis=[2]) + center
        return coarse, fine
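mlp_conv is the shared point-wise counterpart of mlp. Example #4 wraps feat in a dummy axis before calling it, which suggests a conv2d-based helper in that project; a common conv1d-based sketch for the rank-3 call sites (an assumption, the project versions may differ):

def mlp_conv(inputs, layer_dims, bn=None, bn_params=None):
    # Point-wise MLP via kernel-size-1 convolutions; linear last layer.
    for i, dim in enumerate(layer_dims[:-1]):
        inputs = tf.contrib.layers.conv1d(
            inputs, dim, kernel_size=1, scope='conv_%d' % i)
    return tf.contrib.layers.conv1d(
        inputs, layer_dims[-1], kernel_size=1, activation_fn=None,
        scope='conv_%d' % (len(layer_dims) - 1))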
Example #5
    def create_decoder(self, features):
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            coarse = mlp(features, [1024, 1024, self.num_coarse * 3])
            coarse = tf.reshape(coarse, [-1, self.num_coarse, 3])
            print('coarse', coarse)
        with tf.variable_scope('folding', reuse=tf.AUTO_REUSE):
            grid = tf.meshgrid(tf.linspace(-0.05, 0.05, self.grid_size), tf.linspace(-0.05, 0.05, self.grid_size))
            print('grid:', grid)
            grid = tf.expand_dims(tf.reshape(tf.stack(grid, axis=2), [-1, 2]), 0)
            print('grid:', grid)
            grid_feat = tf.tile(grid, [features.shape[0], self.num_coarse, 1])
            print('grid_feat', grid_feat)

            point_feat = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
            point_feat = tf.reshape(point_feat, [-1, self.num_fine, 3])
            print('point_feat', point_feat)

            global_feat = tf.tile(tf.expand_dims(features, 1), [1, self.num_fine, 1])

            print('global_feat', global_feat)

            feat = tf.concat([grid_feat, point_feat, global_feat], axis=2)
            print('feat:', feat)

            center = tf.tile(tf.expand_dims(coarse, 2), [1, 1, self.grid_size ** 2, 1])
            center = tf.reshape(center, [-1, self.num_fine, 3])

            print('center', center)

            fine = mlp_conv(feat, [512, 512, 3]) + center
            print('fine:', fine)
        return coarse, fine
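The tiling in these folding decoders fixes num_fine = num_coarse * grid_size**2. A quick numeric check under the usual PCN setting (the values are assumptions):

num_coarse = 1024
grid_size = 4
num_fine = num_coarse * grid_size ** 2  # 16384
# grid_feat  : [B, num_fine, 2]     the grid, tiled once per coarse point
# point_feat : [B, num_fine, 3]     each coarse point repeated 16 times
# global_feat: [B, num_fine, 1024]  the global code broadcast to every point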
Example #6
def create_decoder(code,
                   inputs,
                   step_ratio,
                   num_extract=512,
                   mean_feature=None):
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):

        level0 = tf_util.mlp(code, [1024, 1024, 512 * 3],
                             bn=None,
                             bn_params=None,
                             name='coarse')
        level0 = tf.tanh(level0)
        level0 = tf.reshape(level0, [-1, 512, 3])
        coarse = level0

        input_fps = symmetric_sample(inputs, int(num_extract / 2))
        level0 = tf.concat([input_fps, level0], 1)
        if num_extract > 512:
            level0 = gather_point(level0, farthest_point_sample(1024, level0))

        for i in range(int(math.log2(step_ratio))):
            num_fine = 2**(i + 1) * 1024
            grid = tf_util.gen_grid_up(2**(i + 1))
            grid = tf.expand_dims(grid, 0)
            grid_feat = tf.tile(grid, [level0.shape[0], 1024, 1])
            point_feat = tf.tile(tf.expand_dims(level0, 2), [1, 1, 2, 1])
            point_feat = tf.reshape(point_feat, [-1, num_fine, 3])
            global_feat = tf.tile(tf.expand_dims(code, 1), [1, num_fine, 1])

            mean_feature_use = tf.contrib.layers.fully_connected(
                mean_feature, 128, activation_fn=tf.nn.relu, scope='mean_fc')
            mean_feature_use = tf.expand_dims(mean_feature_use, 1)
            mean_feature_use = tf.tile(mean_feature_use, [1, num_fine, 1])
            feat = tf.concat(
                [grid_feat, point_feat, global_feat, mean_feature_use], axis=2)

            feat1 = tf_util.mlp_conv(feat, [128, 64],
                                     bn=None,
                                     bn_params=None,
                                     name='up_branch')
            feat1 = tf.nn.relu(feat1)
            feat2 = contract_expand_operation(feat1, 2)
            feat = feat1 + feat2

            fine = tf_util.mlp_conv(
                feat, [512, 512, 3], bn=None, bn_params=None,
                name='fine') + point_feat

            level0 = fine
        return coarse, fine
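gen_grid_up(2**(i + 1)) supplies the seed grid for each doubling step. A plausible sketch, assuming it factors the up-ratio into a near-square 2-D grid (the span and the factoring are assumptions):

def gen_grid_up(up_ratio):
    # Pick the divisor pair (num_x, num_y) of up_ratio closest to a square.
    sqrted = int(math.sqrt(up_ratio)) + 1
    for i in range(1, sqrted + 1):
        if up_ratio % i == 0:
            num_x, num_y = i, up_ratio // i
    grid_x = tf.linspace(-0.2, 0.2, num_x)
    grid_y = tf.linspace(-0.2, 0.2, num_y)
    x, y = tf.meshgrid(grid_x, grid_y)
    return tf.reshape(tf.stack([x, y], axis=-1), [-1, 2])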
Example #7
File: fc.py  Project: mihaibujanca/pcn
    def create_decoder(self, features):
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            outputs = mlp(features, [1024, 1024, self.num_output_points * 3])
            outputs = tf.reshape(outputs, [-1, self.num_output_points, 3])
        return outputs
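Unlike the folding decoders above, this fully connected head regresses all output coordinates in one shot. An assumed usage sketch (model and create_encoder are illustrative names, not from the source):

features = model.create_encoder(inputs)   # [B, 1024] global code
outputs = model.create_decoder(features)  # [B, num_output_points, 3]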
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    alpha = tf.train.piecewise_constant(
        global_step // (args.gen_iter + args.dis_iter), args.fine_step,
        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    inputs_pl = tf.placeholder(tf.float32,
                               (args.batch_size, args.num_input_points, 3),
                               'inputs')
    gt_pl = tf.placeholder(tf.float32,
                           (args.batch_size, args.num_gt_points, 3), 'gt')
    complete_feature = tf.placeholder(tf.float32, (args.batch_size, 1024),
                                      'complete_feature')
    complete_feature0 = tf.placeholder(tf.float32, (args.batch_size, 256),
                                       'complete_feature0')
    label_pl = tf.placeholder(tf.int32, shape=(args.batch_size,))

    model_module = importlib.import_module('.%s' % args.model_type, 'models')

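    # Load the training split and the precomputed features of the complete
    # shapes; which feature datasets are read depends on the ground-truth
    # resolution and step_ratio.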
    file_train = h5py.File(args.h5_train, 'r')
    incomplete_pcds_train = file_train['incomplete_pcds'][()]
    complete_pcds_train = file_train['complete_pcds'][()]
    labels_train = file_train['labels'][()]
    if args.num_gt_points == 2048:
        if args.step_ratio == 2:
            complete_features_train = file_train['complete_feature'][()]
            complete_features_train0 = file_train['complete_feature0'][()]
        elif args.step_ratio == 4:
            complete_features_train = file_train['complete_feature1_4'][()]
            complete_features_train0 = file_train['complete_feature0_4'][()]
        elif args.step_ratio == 8:
            complete_features_train = file_train['complete_feature1_8'][()]
            complete_features_train0 = file_train['complete_feature0_8'][()]
        elif args.step_ratio == 16:
            complete_features_train = file_train['complete_feature1_16'][()]
            complete_features_train0 = file_train['complete_feature0_16'][()]
    elif args.num_gt_points == 16384:
        file_train_feature = h5py.File(
            args.pretrain_complete_decoder + '/train_complete_feature.h5', 'r')
        complete_features_train = file_train_feature['complete_feature1'][()]
        complete_features_train0 = file_train_feature['complete_feature0'][()]
        file_train_feature.close()
    file_train.close()

    assert complete_features_train.shape[0] == complete_features_train0.shape[
        0] == incomplete_pcds_train.shape[0]
    assert complete_features_train.shape[1] == 1024
    assert complete_features_train0.shape[1] == 256

    train_num = incomplete_pcds_train.shape[0]
    BN_DECAY_DECAY_STEP = int(train_num / args.batch_size *
                              args.lr_decay_epochs)

    learning_rate_d = tf.train.exponential_decay(
        args.base_lr_d,
        global_step // (args.gen_iter + args.dis_iter),
        BN_DECAY_DECAY_STEP,
        args.lr_decay_rate,
        staircase=True,
        name='lr_d')
    learning_rate_d = tf.maximum(learning_rate_d, args.lr_clip)

    learning_rate_g = tf.train.exponential_decay(
        args.base_lr_g,
        global_step // (args.gen_iter + args.dis_iter),
        BN_DECAY_DECAY_STEP,
        args.lr_decay_rate,
        staircase=True,
        name='lr_g')
    learning_rate_g = tf.maximum(learning_rate_g, args.lr_clip)

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        G_optimizers = tf.train.AdamOptimizer(learning_rate_g, beta1=0.9)
        D_optimizers = tf.train.AdamOptimizer(learning_rate_d, beta1=0.5)

    coarse_gpu = []
    fine_gpu = []
    tower_grads_g = []
    tower_grads_d = []
    total_dis_loss_gpu = []
    total_gen_loss_gpu = []
    total_lossReconstruction_gpu = []
    total_lossFeature_gpu = []

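    # Build one replica of the model per GPU; every tower consumes its own
    # slice of the batch and contributes gradients that are averaged below.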
    for i in range(NUM_GPUS):
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            with tf.device('/gpu:%d' % (i)), tf.name_scope('gpu_%d' %
                                                           (i)) as scope:
                inputs_pl_batch = tf.slice(inputs_pl,
                                           [int(i * DEVICE_BATCH_SIZE), 0, 0],
                                           [int(DEVICE_BATCH_SIZE), -1, -1])
                gt_pl_batch = tf.slice(gt_pl,
                                       [int(i * DEVICE_BATCH_SIZE), 0, 0],
                                       [int(DEVICE_BATCH_SIZE), -1, -1])
                complete_feature_batch = tf.slice(
                    complete_feature, [int(i * DEVICE_BATCH_SIZE), 0],
                    [int(DEVICE_BATCH_SIZE), -1])
                complete_feature_batch0 = tf.slice(
                    complete_feature0, [int(i * DEVICE_BATCH_SIZE), 0],
                    [int(DEVICE_BATCH_SIZE), -1])

                with tf.variable_scope('generator'):
                    features_partial_0, partial_reconstruct = model_module.encoder(
                        inputs_pl_batch, embed_size=1024)
                    coarse_batch, fine_batch = model_module.decoder(
                        inputs_pl_batch,
                        partial_reconstruct,
                        step_ratio=args.step_ratio,
                        num_fine=args.step_ratio * 1024)

                with tf.variable_scope('discriminator') as dis_scope:
                    errD_fake = mlp(tf.expand_dims(partial_reconstruct,
                                                   axis=1), [16],
                                    bn=None,
                                    bn_params=None)
                    dis_scope.reuse_variables()
                    errD_real = mlp(tf.expand_dims(complete_feature_batch,
                                                   axis=1), [16],
                                    bn=None,
                                    bn_params=None)

                    kernel = getattr(mmd, '_%s_kernel' % args.mmd_kernel)
                    kerGI = kernel(errD_fake[:, 0, :], errD_real[:, 0, :])
                    errG = mmd.mmd2(kerGI)
                    errD = -errG
                    epsilon = tf.random_uniform([], 0.0, 1.0)
                    x_hat = complete_feature_batch * (
                        1 - epsilon) + epsilon * partial_reconstruct
                    d_hat = mlp(tf.expand_dims(x_hat, axis=1), [16],
                                bn=None,
                                bn_params=None)
                    Ekx = lambda yy: tf.reduce_mean(
                        kernel(d_hat[:, 0, :], yy, K_XY_only=True), axis=1)
                    Ekxr = Ekx(errD_real[:, 0, :])
                    Ekxf = Ekx(errD_fake[:, 0, :])
                    witness = Ekxr - Ekxf
                    gradients = tf.gradients(witness, [x_hat])[0]
                    penalty = tf.reduce_mean(
                        tf.square(mmd.safer_norm(gradients, axis=1) - 1.0))
                    errD_loss_batch = penalty * args.gp_weight + errD
                    errG_loss_batch = errG

                feature_loss = (
                    tf.reduce_mean(
                        tf.squared_difference(partial_reconstruct,
                                              complete_feature_batch)) +
                    tf.reduce_mean(
                        tf.squared_difference(features_partial_0[:, 0, :],
                                              complete_feature_batch0)))

                dist1_fine, dist2_fine = chamfer(fine_batch, gt_pl_batch)
                dist1_coarse, dist2_coarse = chamfer(coarse_batch, gt_pl_batch)
                total_loss_fine = (tf.reduce_mean(tf.sqrt(dist1_fine)) +
                                   tf.reduce_mean(tf.sqrt(dist2_fine))) / 2
                total_loss_coarse = (tf.reduce_mean(tf.sqrt(dist1_coarse)) +
                                     tf.reduce_mean(tf.sqrt(dist2_coarse))) / 2
                total_loss_rec_batch = alpha * total_loss_fine + total_loss_coarse

                t_vars = tf.global_variables()
                gen_tvars = [
                    var for var in t_vars if var.name.startswith("generator")
                ]
                dis_tvars = [
                    var for var in t_vars
                    if var.name.startswith("discriminator")
                ]
                total_gen_loss_batch = args.feat_weight * feature_loss + errG_loss_batch + args.rec_weight * total_loss_rec_batch
                total_dis_loss_batch = errD_loss_batch

                # Calculate the gradients for the batch of data on this tower.
                grads_g = G_optimizers.compute_gradients(total_gen_loss_batch,
                                                         var_list=gen_tvars)
                grads_d = D_optimizers.compute_gradients(total_dis_loss_batch,
                                                         var_list=dis_tvars)

                # Keep track of the gradients across all towers.
                tower_grads_g.append(grads_g)
                tower_grads_d.append(grads_d)

                coarse_gpu.append(coarse_batch)
                fine_gpu.append(fine_batch)

                total_dis_loss_gpu.append(total_dis_loss_batch)
                total_gen_loss_gpu.append(errG_loss_batch)
                total_lossReconstruction_gpu.append(args.rec_weight *
                                                    total_loss_rec_batch)
                total_lossFeature_gpu.append(args.feat_weight * feature_loss)

    fine = tf.concat(fine_gpu, 0)

    grads_g = average_gradients(tower_grads_g)
    grads_d = average_gradients(tower_grads_d)

    # apply the gradients with our optimizers
    train_G = G_optimizers.apply_gradients(grads_g, global_step=global_step)
    train_D = D_optimizers.apply_gradients(grads_d, global_step=global_step)

    total_dis_loss = tf.reduce_mean(tf.stack(total_dis_loss_gpu, 0))
    total_gen_loss = tf.reduce_mean(tf.stack(total_gen_loss_gpu, 0))
    total_loss_rec = tf.reduce_mean(tf.stack(total_lossReconstruction_gpu, 0))
    total_loss_fea = tf.reduce_mean(tf.stack(total_lossFeature_gpu, 0))

    dist1_eval, dist2_eval = chamfer(fine, gt_pl)

    file_validate = h5py.File(args.h5_validate, 'r')
    incomplete_pcds_validate = file_validate['incomplete_pcds'][()]
    complete_pcds_validate = file_validate['complete_pcds'][()]
    labels_validate = file_validate['labels'][()]
    file_validate.close()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)
    saver = tf.train.Saver(max_to_keep=3)
    sess.run(tf.global_variables_initializer())

    saver_decoder = tf.train.Saver(var_list=[var for var in tf.global_variables() if (var.name.startswith("generator/decoder") \
                                                                                      or var.name.startswith("generator/folding"))])
    saver_decoder.restore(sess, args.pretrain_complete_decoder)

    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))

    init_step = sess.run(global_step) // (args.gen_iter + args.dis_iter)
    epoch = init_step * args.batch_size // train_num + 1
    print('init_step:%d,' % init_step, 'epoch:%d' % epoch,
          'training data number:%d' % train_num)
    train_idxs = np.arange(0, train_num)

    for ep_cnt in range(epoch, args.max_epoch + 1):
        num_batches = train_num // args.batch_size
        np.random.shuffle(train_idxs)

        for batch_idx in range(num_batches):
            init_step += 1
            start_idx = batch_idx * args.batch_size
            end_idx = min(start_idx + args.batch_size, train_num)
            ids_train = list(np.sort(train_idxs[start_idx:end_idx]))
            batch_data = incomplete_pcds_train[ids_train]
            batch_gt = complete_pcds_train[ids_train]
            labels = labels_train[ids_train]
            # partial_feature_input=incomplete_features_train[ids_train]
            complete_feature_input = complete_features_train[ids_train]
            complete_feature_input0 = complete_features_train0[ids_train]

            feed_dict = {
                inputs_pl: batch_data,
                gt_pl: batch_gt,
                is_training_pl: True,
                label_pl: labels,
                complete_feature: complete_feature_input,
                complete_feature0: complete_feature_input0
            }

            for i in range(args.dis_iter):
                _, loss_dis = sess.run([train_D, total_dis_loss],
                                       feed_dict=feed_dict)
            for i in range(args.gen_iter):
                _, loss_gen, rec_loss, fea_loss = sess.run(
                    [train_G, total_gen_loss, total_loss_rec, total_loss_fea],
                    feed_dict=feed_dict)

            if init_step % args.steps_per_print == 0:
                print('epoch %d step %d gen_loss %.8f rec_loss %.8f fea_loss %.8f dis_loss %.8f' %
                      (ep_cnt, init_step, loss_gen, rec_loss, fea_loss, loss_dis))

            if init_step % args.steps_per_eval == 0:
                total_loss = 0
                sess.run(tf.local_variables_initializer())
                batch_data = np.zeros(
                    (args.batch_size, incomplete_pcds_validate[0].shape[0], 3),
                    'f')
                batch_gt = np.zeros((args.batch_size, args.num_gt_points, 3),
                                    'f')
                # partial_feature_input = np.zeros((args.batch_size, 1024), 'f')
                labels = np.zeros((args.batch_size, ), dtype=np.int32)
                feature_complete_input = np.zeros(
                    (args.batch_size, 1024)).astype(np.float32)
                feature_complete_input0 = np.zeros(
                    (args.batch_size, 256)).astype(np.float32)
                for batch_idx_eval in range(0,
                                            incomplete_pcds_validate.shape[0],
                                            args.batch_size):
                    # start = time.time()
                    start_idx = batch_idx_eval
                    end_idx = min(start_idx + args.batch_size,
                                  incomplete_pcds_validate.shape[0])

                    batch_data[0:end_idx -
                               start_idx] = incomplete_pcds_validate[
                                   start_idx:end_idx]
                    batch_gt[0:end_idx - start_idx] = complete_pcds_validate[
                        start_idx:end_idx]
                    labels[0:end_idx -
                           start_idx] = labels_validate[start_idx:end_idx]
                    # The complete-feature placeholders are filled from the
                    # training feature arrays here; they do not affect the
                    # Chamfer distances fetched below.
                    feature_complete_input[
                        0:end_idx -
                        start_idx] = complete_features_train[start_idx:end_idx]
                    feature_complete_input0[
                        0:end_idx - start_idx] = complete_features_train0[
                            start_idx:end_idx]

                    feed_dict = {
                        inputs_pl: batch_data,
                        gt_pl: batch_gt,
                        is_training_pl: False,
                        label_pl: labels,
                        complete_feature: feature_complete_input,
                        complete_feature0: feature_complete_input0
                    }
                    dist1_out, dist2_out = sess.run([dist1_eval, dist2_eval],
                                                    feed_dict=feed_dict)
                    if args.loss_type == 'cd_1':
                        total_loss += np.mean(dist1_out[0:end_idx - start_idx]) * (end_idx - start_idx) \
                                      + np.mean(dist2_out[0:end_idx - start_idx]) * (end_idx - start_idx)
                    elif args.loss_type == 'cd_2':
                        total_loss += (np.mean(np.sqrt(dist1_out[0:end_idx - start_idx])) * (end_idx - start_idx) \
                                       + np.mean(np.sqrt(dist2_out[0:end_idx - start_idx])) * (end_idx - start_idx)) / 2

                mean_loss = total_loss / incomplete_pcds_validate.shape[0]
                if mean_loss < args.best_loss:
                    args.best_loss = mean_loss
                    saver.save(sess, os.path.join(args.log_dir, 'model'),
                               init_step)

                print('epoch %d  step %d  loss %.8f best_loss %.8f' %
                      (ep_cnt, init_step, mean_loss, args.best_loss))
    sess.close()
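average_gradients is not shown; it matches the standard multi-tower helper from the TensorFlow CIFAR-10 multi-GPU tutorial. A sketch, in case the project-local version differs:

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU tower.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # Variables are shared across towers, so take the first tower's.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads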