def create_loss(self, coarse, fine, gt, alpha):
    """Build the two-stage Chamfer objective and register its summaries.

    Both the coarse and the fine prediction are compared with ``gt`` via
    ``chamfer``; the total objective is ``coarse_term + alpha * fine_term``.
    Each of the three values is passed to ``add_train_summary`` and to
    ``add_valid_summary``.

    Args:
        coarse: coarse prediction.
        fine: fine prediction.
        gt: ground-truth point cloud.
        alpha: weight applied to the fine Chamfer term.

    Returns:
        ``(loss, updates)`` where ``updates`` lists the values returned by
        ``add_valid_summary`` for the coarse, fine and total terms, in that
        order.
    """
    def _track(train_tag, valid_tag, value):
        # Log the value for training and hand back the validation handle.
        add_train_summary(train_tag, value)
        return add_valid_summary(valid_tag, value)

    coarse_term = chamfer(coarse, gt)
    coarse_update = _track('train/coarse_loss', 'valid/coarse_loss', coarse_term)

    fine_term = chamfer(fine, gt)
    fine_update = _track('train/fine_loss', 'valid/fine_loss', fine_term)

    objective = coarse_term + alpha * fine_term
    objective_update = _track('train/loss', 'valid/loss', objective)

    return objective, [coarse_update, fine_update, objective_update]
def create_loss(self, coarse_highres, coarse, fine, gt, theta,
                highres_weight=0.5, repulsion_weight=0.2):
    """Build the multi-term completion objective and register its summaries.

    The objective is::

        highres_weight * CD(coarse_highres, gt) + CD(coarse, gt)
            + theta * CD(fine, gt) + repulsion_weight * repulsion(coarse)

    where CD is ``chamfer`` and repulsion is ``get_repulsion_loss4``.  The
    default weights (0.5 and 0.2) are the values this method always used, so
    existing callers get identical results.

    Args:
        coarse_highres: additional coarse prediction compared against ``gt``.
        coarse: coarse prediction; also the input to the repulsion term.
        fine: fine prediction.
        gt: ground-truth point cloud.
        theta: weight of the fine Chamfer term.
        highres_weight: weight of the ``coarse_highres`` Chamfer term.
        repulsion_weight: weight of the repulsion term on ``coarse``.

    Returns:
        ``(loss, loss_fine, [update_coarse, update_fine, update_loss])``; the
        list holds the values returned by ``add_valid_summary``.
    """
    loss_coarse_highres = chamfer(coarse_highres, gt)
    loss_coarse = chamfer(coarse, gt)
    add_train_summary('train/coarse_loss', loss_coarse)
    update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

    loss_fine = chamfer(fine, gt)
    add_train_summary('train/fine_loss', loss_fine)
    update_fine = add_valid_summary('valid/fine_loss', loss_fine)

    # NOTE(review): the repulsion term is computed on ``coarse`` (not
    # ``fine``), and neither it nor ``loss_coarse_highres`` gets a summary
    # of its own; they only show up inside 'train/loss' / 'valid/loss'.
    repulsion_loss = get_repulsion_loss4(coarse)

    # Same left-to-right summation order as the previous fixed expression.
    loss = (highres_weight * loss_coarse_highres + loss_coarse
            + theta * loss_fine + repulsion_weight * repulsion_loss)
    add_train_summary('train/loss', loss)
    update_loss = add_valid_summary('valid/loss', loss)
    return loss, loss_fine, [update_coarse, update_fine, update_loss]
def create_loss(self, coarse, fine, gt, alpha):
    """Build the Chamfer + semantic-label objective and register its summaries.

    Channels 0:3 of ``coarse``, ``fine`` and ``gt`` are used as xyz
    coordinates.  Channels 3: of each prediction are softmax-normalised into
    per-point class probabilities, and channel 3 of ``gt`` is decoded into an
    integer class id (see the NOTE on the ``* 80 * 12`` scaling below).

    For each prediction the loss is its xyz Chamfer distance to ``gt`` plus,
    for every sample in the batch, a weighted binary cross-entropy between the
    predicted class probabilities and the one-hot (12-class) label of the
    matched ground-truth point.  The total objective is
    ``loss_coarse + alpha * loss_fine``.

    Args:
        coarse: coarse prediction; xyz in channels 0:3, presumably class
            logits after that (12 of them for the shapes to match).
        fine: fine prediction with the same channel layout as ``coarse``.
        gt: ground truth; xyz in channels 0:3, encoded label in channel 3.
        alpha: weight applied to the fine-stage loss.

    Returns:
        ``(loss, [update_coarse, update_fine, update_loss])`` where the list
        holds the values returned by ``add_valid_summary``.
    """
    # Geometric term: Chamfer distance on the xyz channels only.
    loss_coarse = chamfer(coarse[:, :, 0:3], gt[:, :, 0:3])
    # retb presumably holds, for every coarse point, the index of its nearest
    # gt point (second output of nn_distance) -- TODO confirm against
    # tf_nndistance.  retd is not used.
    _, retb, _, retd = tf_nndistance.nn_distance(coarse[:, :, 0:3], gt[:, :, 0:3])
    # Python-level loop over the batch; needs a statically known batch size.
    for i in range(np.shape(gt)[0]):
        index = tf.expand_dims(retb[i], -1)
        # Class probabilities from every channel after xyz.
        sem_feat = tf.nn.softmax(coarse[i, :, 3:], -1)
        # NOTE(review): gt channel 3 appears to store a class id scaled into a
        # float; "* 80 * 12" undoes that before the int cast -- confirm the
        # encoding against the data pipeline.  The label of each matched gt
        # point is then one-hot encoded over 12 classes.
        sem_gt = tf.cast(
            tf.one_hot(
                tf.gather_nd(tf.cast(gt[i, :, 3] * 80 * 12, tf.int32), index),
                12), tf.float32)
        # Weighted binary cross-entropy per class (0.9 on the positive class,
        # 0.1 on the negatives), summed over classes and averaged over points;
        # the 1e-6 offsets keep log() finite.
        loss_sem_coarse = tf.reduce_mean(-tf.reduce_sum(
            0.9 * sem_gt * tf.log(1e-6 + sem_feat) +
            (1 - 0.9) * (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
        # The semantic term is summed (not averaged) over the batch.
        loss_coarse += loss_sem_coarse
    add_train_summary('train/coarse_loss', loss_coarse)
    update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

    # Same Chamfer + semantic construction for the fine prediction.
    loss_fine = chamfer(fine[:, :, 0:3], gt[:, :, 0:3])
    _, retb, _, retd = tf_nndistance.nn_distance(fine[:, :, 0:3], gt[:, :, 0:3])
    for i in range(np.shape(gt)[0]):
        index = tf.expand_dims(retb[i], -1)
        sem_feat = tf.nn.softmax(fine[i, :, 3:], -1)
        sem_gt = tf.cast(
            tf.one_hot(
                tf.gather_nd(tf.cast(gt[i, :, 3] * 80 * 12, tf.int32), index),
                12), tf.float32)
        loss_sem_fine = tf.reduce_mean(-tf.reduce_sum(
            0.9 * sem_gt * tf.log(1e-6 + sem_feat) +
            (1 - 0.9) * (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
        loss_fine += loss_sem_fine
    add_train_summary('train/fine_loss', loss_fine)
    update_fine = add_valid_summary('valid/fine_loss', loss_fine)

    loss = loss_coarse + alpha * loss_fine
    add_train_summary('train/loss', loss)
    update_loss = add_valid_summary('valid/loss', loss)
    return loss, [update_coarse, update_fine, update_loss]
def create_loss(self, coarse, fine, gt, alpha):
    """Build the coarse-EMD / fine-Chamfer objective and register summaries.

    The coarse output is matched with ``earth_mover`` against the first
    ``coarse.shape[1]`` ground-truth points (the slice gives both clouds the
    same point count).  The fine output is compared with the full ``gt`` via
    ``chamfer``.  The total is ``emd_term + alpha * cd_term``.

    Returns:
        ``(loss, [coarse_update, fine_update, total_update])`` where the list
        holds the values returned by ``add_valid_summary``.
    """
    num_coarse_points = coarse.shape[1]
    gt_prefix = gt[:, :num_coarse_points, :]

    emd_term = earth_mover(coarse, gt_prefix)
    add_train_summary('train/coarse_loss', emd_term)
    coarse_update = add_valid_summary('valid/coarse_loss', emd_term)

    cd_term = chamfer(fine, gt)
    add_train_summary('train/fine_loss', cd_term)
    fine_update = add_valid_summary('valid/fine_loss', cd_term)

    total = emd_term + alpha * cd_term
    add_train_summary('train/loss', total)
    total_update = add_valid_summary('valid/loss', total)

    updates = [coarse_update, fine_update, total_update]
    return total, updates
def create_loss(self, outputs, gt):
    """Return the Chamfer loss of ``outputs`` vs ``gt`` and its validation handle.

    The loss is logged under 'train/loss'; the second return value is
    whatever ``add_valid_summary`` returns for 'valid/loss'.
    """
    chamfer_loss = chamfer(outputs, gt)
    add_train_summary('train/loss', chamfer_loss)
    return chamfer_loss, add_valid_summary('valid/loss', chamfer_loss)
def test(args):
    """Evaluate a completion model on a list of ``<synset_id>/<model_id>`` ids.

    For each id listed in ``args.list_path``, the partial and complete point
    clouds are read from ``args.data_dir/partial`` and
    ``args.data_dir/complete``.  The model's ``outputs`` are computed, and the
    Chamfer and Earth mover distances to the complete cloud are measured.

    Side effects:
        * ``<results_dir>/results.csv`` gets one ``id, cd, emd`` row per model.
        * Every ``args.plot_freq``-th model is plotted to
          ``<results_dir>/plots/<synset_id>/<model_id>.png``.
        * If ``args.save_pcd`` is set, each completion is written to
          ``<results_dir>/pcds/<synset_id>/<model_id>.pcd``.
        * Average time / CD / EMD, overall and per synset, are printed.
    """
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, tf.constant(1.0))

    # Metrics run on a separately fed placeholder, so only the model's
    # forward pass falls inside the timed region of the loop below.
    output = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    cd_op = chamfer(output, gt)
    emd_op = earth_mover(output, gt)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    try:
        saver = tf.train.Saver()
        saver.restore(sess, args.checkpoint)

        os.makedirs(args.results_dir, exist_ok=True)
        # newline='' is what the csv module expects for files it writes to.
        with open(os.path.join(args.results_dir, 'results.csv'), 'w',
                  newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['id', 'cd', 'emd'])
            with open(args.list_path) as file:
                model_list = file.read().splitlines()

            total_time = 0
            total_cd = 0
            total_emd = 0
            cd_per_cat = {}
            emd_per_cat = {}
            for i, full_id in enumerate(model_list):
                partial = read_pcd(
                    os.path.join(args.data_dir, 'partial', '%s.pcd' % full_id))
                complete = read_pcd(
                    os.path.join(args.data_dir, 'complete', '%s.pcd' % full_id))

                start = time.time()
                completion = sess.run(model.outputs,
                                      feed_dict={inputs: [partial]})
                total_time += time.time() - start

                cd, emd = sess.run([cd_op, emd_op], feed_dict={
                    output: completion,
                    gt: [complete]
                })
                total_cd += cd
                total_emd += emd
                writer.writerow([full_id, cd, emd])

                synset_id, model_id = full_id.split('/')
                cd_per_cat.setdefault(synset_id, []).append(cd)
                emd_per_cat.setdefault(synset_id, []).append(emd)

                if i % args.plot_freq == 0:
                    plot_dir = os.path.join(args.results_dir, 'plots', synset_id)
                    os.makedirs(plot_dir, exist_ok=True)
                    plot_path = os.path.join(plot_dir, '%s.png' % model_id)
                    plot_pcd_three_views(plot_path,
                                         [partial, completion[0], complete],
                                         ['input', 'output', 'ground truth'],
                                         'CD %.4f EMD %.4f' % (cd, emd),
                                         [5, 0.5, 0.5])
                if args.save_pcd:
                    pcd_dir = os.path.join(args.results_dir, 'pcds', synset_id)
                    os.makedirs(pcd_dir, exist_ok=True)
                    # Write into the per-synset directory just created, so
                    # models sharing an id in different synsets stay distinct.
                    save_pcd(os.path.join(pcd_dir, '%s.pcd' % model_id),
                             completion[0])
    finally:
        sess.close()

    num_models = len(model_list)
    if num_models == 0:
        print('No models listed in %s' % args.list_path)
        return
    print('Average time: %f' % (total_time / num_models))
    print('Average Chamfer distance: %f' % (total_cd / num_models))
    print('Average Earth mover distance: %f' % (total_emd / num_models))
    print('Chamfer distance per category')
    for synset_id, values in cd_per_cat.items():
        print(synset_id, '%f' % np.mean(values))
    print('Earth mover distance per category')
    for synset_id, values in emd_per_cat.items():
        print(synset_id, '%f' % np.mean(values))
def _rotate_about_y(points, angle, extra_columns=()):
    """Rotate the xyz columns of ``points`` by ``angle`` radians about the y axis.

    Returns a new array whose columns are the rotated x, the unchanged y and
    the rotated z, followed by ``points[:, c]`` for each ``c`` in
    ``extra_columns``.  Any other column is dropped.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    columns = [
        cos_a * points[:, 0] - sin_a * points[:, 2],
        points[:, 1],
        sin_a * points[:, 0] + cos_a * points[:, 2],
    ]
    columns.extend(points[:, c] for c in extra_columns)
    return np.stack(columns, axis=-1)


def _semantic_colors(points, num_channel):
    """Map each point's argmax over channels 3:3+num_channel to a Set1 RGB colour."""
    labels = np.argmax(points[:, 3:3 + num_channel], -1)
    return matplotlib.cm.Set1(
        (labels + 1) / num_channel - 0.5 / num_channel)[:, 0:3]


def _ply_dir(results_dir, subdir, synset_id):
    """Create (if needed) and return ``results_dir/subdir/synset_id``."""
    directory = os.path.join(results_dir, subdir, synset_id)
    os.makedirs(directory, exist_ok=True)
    return directory


def _save_colored_ply(path, coords, colors):
    """Write ``coords`` with per-point RGB ``colors`` to ``path`` (ASCII)."""
    pcd = PointCloud()
    pcd.points = Vector3dVector(coords)
    pcd.colors = Vector3dVector(colors)
    write_point_cloud(path, pcd, write_ascii=True)


def test(args):
    """Evaluate a coarse/fine semantic completion model and export its outputs.

    For each id in ``args.list_path``, the partial and complete clouds are
    read.  ``args.experiment`` selects the ShapeNet ``<synset>/<model>``
    layout or the SUNCG ``pcd_partial``/``pcd_complete`` layout.  With
    ``args.rotate``, both clouds are rotated by a random angle about the y
    axis (the RNG is seeded with 1).  The partial input keeps only xyz.  The
    complete cloud (six columns, matching the ``gt`` placeholder) is resampled
    to ``args.num_gt_points`` points.

    For each model:
        * one ``id, cd, emd`` row is written to ``results.csv``;
        * every ``args.plot_freq``-th model is plotted;
        * with ``args.save_pcd``, the input, coarse/fine outputs, per-channel
          region heat maps and ground truth are written as ASCII PLY files.

    Averages, overall and per category, are printed and appended to the CSV.

    NOTE(review): the "emd" value currently repeats the Chamfer distance.
    The EMD op is built, but its evaluation was disabled upstream (see the
    commented-out line in the loop).

    Raises:
        ValueError: if ``args.experiment`` is not 'shapenet' or 'suncg'.
    """
    if args.experiment not in ('shapenet', 'suncg'):
        raise ValueError("unknown experiment %r (expected 'shapenet' or 'suncg')"
                         % (args.experiment,))

    inputs = tf.placeholder(tf.float32, (1, None, 3))
    npts = tf.placeholder(tf.int32, (1, ))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 6))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, npts, gt, tf.constant(1.0),
                               args.num_channel)
    output = tf.placeholder(tf.float32,
                            (1, args.num_gt_points, 3 + args.num_channel))
    cd_op = chamfer(output[:, :, 0:3], gt[:, :, 0:3])
    emd_op = earth_mover(output[:, :, 0:3], gt[:, :, 0:3])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    try:
        saver = tf.train.Saver()
        saver.restore(sess, args.checkpoint)

        os.makedirs(args.results_dir, exist_ok=True)
        with open(os.path.join(args.results_dir, 'results.csv'), 'w',
                  newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['id', 'cd', 'emd'])
            with open(args.list_path) as file:
                model_list = file.read().splitlines()

            total_time = 0
            total_cd = 0
            total_emd = 0
            cd_per_cat = {}
            emd_per_cat = {}
            np.random.seed(1)
            for i, model_id in enumerate(model_list):
                if args.experiment == 'shapenet':
                    synset_id, model_id = model_id.split('/')
                    partial = read_pcd(
                        os.path.join(args.data_dir, 'partial', synset_id,
                                     '%s.pcd' % model_id))
                    complete = read_pcd(
                        os.path.join(args.data_dir, 'complete', synset_id,
                                     '%s.pcd' % model_id))
                else:  # 'suncg' (validated above)
                    synset_id = 'all_rooms'
                    partial = read_pcd(
                        os.path.join(args.data_dir, 'pcd_partial',
                                     '%s.pcd' % model_id))
                    complete = read_pcd(
                        os.path.join(args.data_dir, 'pcd_complete',
                                     '%s.pcd' % model_id))

                if args.rotate:
                    angle = np.random.rand(1) * 2 * np.pi
                    partial = _rotate_about_y(partial, angle)
                    complete = _rotate_about_y(complete, angle,
                                               extra_columns=(3, 4, 5))
                partial = partial[:, :3]
                # Must match the gt placeholder's point count.
                complete = resample_pcd(complete, args.num_gt_points)

                start = time.time()
                completion1, completion2, mesh_out = sess.run(
                    [model.outputs1, model.outputs2, model.gt_can],
                    feed_dict={
                        inputs: [partial],
                        npts: [partial.shape[0]],
                        gt: [complete]
                    })
                # Zero every channel past xyz + the num_channel class scores.
                for prediction in (completion1, completion2, mesh_out):
                    prediction[0][:, (3 + args.num_channel):] *= 0
                total_time += time.time() - start

                # cd, emd = sess.run([cd_op, emd_op],
                cd, emd = sess.run([cd_op, cd_op], feed_dict={
                    output: completion2,
                    gt: [complete]
                })
                total_cd += cd
                total_emd += emd
                cd_per_cat.setdefault(synset_id, []).append(cd)
                emd_per_cat.setdefault(synset_id, []).append(emd)
                writer.writerow([model_id, cd, emd])

                if i % args.plot_freq == 0:
                    os.makedirs(
                        os.path.join(args.results_dir, 'plots', synset_id),
                        exist_ok=True)
                    plot_path = os.path.join(args.results_dir, 'plots',
                                             synset_id, '%s.png' % model_id)
                    plot_pcd_three_views(
                        plot_path,
                        [partial, completion1[0], completion2[0], mesh_out[0],
                         complete],
                        ['input', 'coarse', 'fine', 'mesh', 'ground truth'],
                        'CD %.4f EMD %.4f' % (cd, emd),
                        [5, 0.5, 0.5, 0.5, 0.5],
                        num_channel=args.num_channel)

                if args.save_pcd:
                    ply_name = '%s.ply' % model_id
                    # Input coloured by its y coordinate.
                    _save_colored_ply(
                        os.path.join(
                            _ply_dir(args.results_dir, 'input', synset_id),
                            ply_name),
                        partial[:, 0:3],
                        matplotlib.cm.cool((partial[:, 1]))[:, 0:3])
                    # Coarse and fine outputs coloured by predicted class.
                    for subdir, completion in (('output1', completion1),
                                               ('output2', completion2)):
                        _save_colored_ply(
                            os.path.join(
                                _ply_dir(args.results_dir, subdir, synset_id),
                                ply_name),
                            completion[0][:, 0:3],
                            _semantic_colors(completion[0], args.num_channel))
                    # One heat-map PLY per class channel of the fine output.
                    regions_dir = _ply_dir(args.results_dir, 'regions',
                                           synset_id)
                    fine_xyz = completion2[0][:, 0:3]
                    for idx in range(3, 3 + args.num_channel):
                        channel = completion2[0][:, idx]
                        val_min = np.min(channel)
                        val_range = np.max(channel) - val_min
                        if val_range > 0:
                            scaled = (channel - val_min) / val_range
                        else:
                            # Constant channel: avoid 0/0 (NaN colours).
                            scaled = np.zeros_like(channel)
                        pts_color = 0.8 * matplotlib.cm.Reds(scaled)[:, 0:3]
                        pts_color += 0.2 * matplotlib.cm.gist_gray(scaled)[:, 0:3]
                        _save_colored_ply(
                            os.path.join(regions_dir,
                                         '%s_%s.ply' % (model_id, idx - 3)),
                            fine_xyz, pts_color)
                    if args.experiment == 'shapenet':
                        gt_color = matplotlib.cm.cool(complete[:, 1])[:, 0:3]
                    else:  # 'suncg'
                        gt_color = matplotlib.cm.Set1(
                            complete[:, 3] - 0.5 / args.num_channel)[:, 0:3]
                    _save_colored_ply(
                        os.path.join(
                            _ply_dir(args.results_dir, 'gt', synset_id),
                            ply_name),
                        complete[:, 0:3], gt_color)

            num_models = len(model_list)
            if num_models == 0:
                print('No models listed in %s' % args.list_path)
                return
            print('Average time: %f' % (total_time / num_models))
            print('Average Chamfer distance: %f' % (total_cd / num_models))
            print('Average Earth mover distance: %f' % (total_emd / num_models))
            writer.writerow([
                total_time / num_models, total_cd / num_models,
                total_emd / num_models
            ])
            print('Chamfer distance per category')
            for synset_id, values in cd_per_cat.items():
                print(synset_id, '%f' % np.mean(values))
                writer.writerow([synset_id, np.mean(values)])
            print('Earth mover distance per category')
            for synset_id, values in emd_per_cat.items():
                print(synset_id, '%f' % np.mean(values))
                writer.writerow([synset_id, np.mean(values)])
    finally:
        sess.close()
def train(args):
    """Train the generator (encoder/decoder) with a feature-space MMD GAN.

    The generator's encoder maps a partial cloud to a 1024-d feature (plus an
    intermediate feature compared against a 256-d target).  The decoder
    produces coarse and fine completions.  The generator loss combines:

    * an MMD adversarial term (``mmd.mmd2``) between discriminator embeddings
      of the encoder feature and the precomputed features loaded from the
      training files, with a gradient penalty for the discriminator;
    * an L2 feature-matching term against those precomputed features;
    * a Chamfer reconstruction term on the coarse and fine outputs.

    Each batch is split over ``NUM_GPUS`` towers of ``DEVICE_BATCH_SIZE``.
    Per batch, the discriminator is updated ``args.dis_iter`` times and the
    generator ``args.gen_iter`` times.  Every ``args.steps_per_eval`` batches,
    the fine output's Chamfer distance on ``args.h5_validate`` is measured,
    and the model is checkpointed when it beats ``args.best_loss``.

    Raises:
        ValueError: for an unsupported ``loss_type``, ``num_gt_points`` or
            (when ``num_gt_points == 2048``) ``step_ratio``.
    """
    if args.loss_type not in ('cd_1', 'cd_2'):
        raise ValueError("loss_type must be 'cd_1' or 'cd_2', got %r"
                         % (args.loss_type,))
    # HDF5 dataset names of the precomputed (1024-d, 256-d) features stored in
    # the training file when num_gt_points == 2048, keyed by step_ratio.
    feature_keys_2048 = {
        2: ('complete_feature', 'complete_feature0'),
        4: ('complete_feature1_4', 'complete_feature0_4'),
        8: ('complete_feature1_8', 'complete_feature0_8'),
        16: ('complete_feature1_16', 'complete_feature0_16'),
    }
    if args.num_gt_points == 2048:
        if args.step_ratio not in feature_keys_2048:
            raise ValueError('unsupported step_ratio %r for num_gt_points=2048'
                             % (args.step_ratio,))
    elif args.num_gt_points != 16384:
        raise ValueError('unsupported num_gt_points %r (expected 2048 or 16384)'
                         % (args.num_gt_points,))

    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # train_D and train_G both increment global_step, and each batch runs them
    # dis_iter + gen_iter times, so this division recovers the batch counter.
    steps_per_round = args.gen_iter + args.dis_iter
    alpha = tf.train.piecewise_constant(
        global_step // steps_per_round, args.fine_step,
        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    inputs_pl = tf.placeholder(
        tf.float32, (args.batch_size, args.num_input_points, 3), 'inputs')
    gt_pl = tf.placeholder(
        tf.float32, (args.batch_size, args.num_gt_points, 3), 'gt')
    complete_feature = tf.placeholder(
        tf.float32, (args.batch_size, 1024), 'complete_feature')
    complete_feature0 = tf.placeholder(
        tf.float32, (args.batch_size, 256), 'complete_feature0')
    label_pl = tf.placeholder(tf.int32, shape=(args.batch_size))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')

    with h5py.File(args.h5_train, 'r') as file_train:
        incomplete_pcds_train = file_train['incomplete_pcds'][()]
        complete_pcds_train = file_train['complete_pcds'][()]
        labels_train = file_train['labels'][()]
        if args.num_gt_points == 2048:
            key1, key0 = feature_keys_2048[args.step_ratio]
            complete_features_train = file_train[key1][()]
            complete_features_train0 = file_train[key0][()]
        else:  # 16384 (validated above)
            with h5py.File(
                    args.pretrain_complete_decoder +
                    '/train_complete_feature.h5', 'r') as file_train_feature:
                complete_features_train = file_train_feature['complete_feature1'][()]
                complete_features_train0 = file_train_feature['complete_feature0'][()]
    assert complete_features_train.shape[0] == complete_features_train0.shape[
        0] == incomplete_pcds_train.shape[0]
    assert complete_features_train.shape[1] == 1024
    assert complete_features_train0.shape[1] == 256

    train_num = incomplete_pcds_train.shape[0]
    decay_steps = int(train_num / args.batch_size * args.lr_decay_epochs)
    learning_rate_d = tf.train.exponential_decay(
        args.base_lr_d, global_step // steps_per_round, decay_steps,
        args.lr_decay_rate, staircase=True, name='lr_d')
    learning_rate_d = tf.maximum(learning_rate_d, args.lr_clip)
    learning_rate_g = tf.train.exponential_decay(
        args.base_lr_g, global_step // steps_per_round, decay_steps,
        args.lr_decay_rate, staircase=True, name='lr_g')
    learning_rate_g = tf.maximum(learning_rate_g, args.lr_clip)

    # NOTE(review): UPDATE_OPS is read here, before any model op exists, and
    # only the optimizer constructors sit under the dependency.  These
    # control dependencies are therefore presumably no-ops (e.g. batch-norm
    # moving averages would not be refreshed by train_G/train_D).  Confirm
    # whether encoder/decoder create update ops.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        G_optimizers = tf.train.AdamOptimizer(learning_rate_g, beta1=0.9)
        D_optimizers = tf.train.AdamOptimizer(learning_rate_d, beta1=0.5)

    coarse_gpu = []
    fine_gpu = []
    tower_grads_g = []
    tower_grads_d = []
    total_dis_loss_gpu = []
    total_gen_loss_gpu = []
    total_lossReconstruction_gpu = []
    total_lossFeature_gpu = []
    for i in range(NUM_GPUS):
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            with tf.device('/gpu:%d' % (i)), tf.name_scope('gpu_%d' % (i)):
                offset = int(i * DEVICE_BATCH_SIZE)
                size = int(DEVICE_BATCH_SIZE)
                inputs_pl_batch = tf.slice(inputs_pl, [offset, 0, 0],
                                           [size, -1, -1])
                gt_pl_batch = tf.slice(gt_pl, [offset, 0, 0], [size, -1, -1])
                complete_feature_batch = tf.slice(complete_feature,
                                                  [offset, 0], [size, -1])
                complete_feature_batch0 = tf.slice(complete_feature0,
                                                   [offset, 0], [size, -1])

                with tf.variable_scope('generator'):
                    features_partial_0, partial_reconstruct = model_module.encoder(
                        inputs_pl_batch, embed_size=1024)
                    coarse_batch, fine_batch = model_module.decoder(
                        inputs_pl_batch, partial_reconstruct,
                        step_ratio=args.step_ratio,
                        num_fine=args.step_ratio * 1024)

                with tf.variable_scope('discriminator') as dis_scope:
                    errD_fake = mlp(tf.expand_dims(partial_reconstruct, axis=1),
                                    [16], bn=None, bn_params=None)
                    dis_scope.reuse_variables()
                    errD_real = mlp(tf.expand_dims(complete_feature_batch, axis=1),
                                    [16], bn=None, bn_params=None)
                    kernel = getattr(mmd, '_%s_kernel' % args.mmd_kernel)
                    kerGI = kernel(errD_fake[:, 0, :], errD_real[:, 0, :])
                    errG = mmd.mmd2(kerGI)
                    errD = -errG
                    # Gradient penalty on a random interpolation of real and
                    # generated features.  d_hat is built in the reusing
                    # discriminator scope so that it shares the critic's
                    # weights (assuming mlp creates them via tf.get_variable).
                    epsilon = tf.random_uniform([], 0.0, 1.0)
                    x_hat = complete_feature_batch * (
                        1 - epsilon) + epsilon * partial_reconstruct
                    d_hat = mlp(tf.expand_dims(x_hat, axis=1), [16], bn=None,
                                bn_params=None)

                    def Ekx(yy):
                        return tf.reduce_mean(
                            kernel(d_hat[:, 0, :], yy, K_XY_only=True), axis=1)

                    Ekxr, Ekxf = Ekx(errD_real[:, 0, :]), Ekx(errD_fake[:, 0, :])
                    witness = Ekxr - Ekxf
                    gradients = tf.gradients(witness, [x_hat])[0]
                    penalty = tf.reduce_mean(
                        tf.square(mmd.safer_norm(gradients, axis=1) - 1.0))
                errD_loss_batch = penalty * args.gp_weight + errD
                errG_loss_batch = errG

                feature_loss = tf.reduce_mean(
                    tf.squared_difference(partial_reconstruct, complete_feature_batch)) + \
                    tf.reduce_mean(tf.squared_difference(features_partial_0[:, 0, :],
                                                         complete_feature_batch0))
                dist1_fine, dist2_fine = chamfer(fine_batch, gt_pl_batch)
                dist1_coarse, dist2_coarse = chamfer(coarse_batch, gt_pl_batch)
                total_loss_fine = (tf.reduce_mean(tf.sqrt(dist1_fine)) +
                                   tf.reduce_mean(tf.sqrt(dist2_fine))) / 2
                total_loss_coarse = (tf.reduce_mean(tf.sqrt(dist1_coarse)) +
                                     tf.reduce_mean(tf.sqrt(dist2_coarse))) / 2
                total_loss_rec_batch = alpha * total_loss_fine + total_loss_coarse

                t_vars = tf.global_variables()
                gen_tvars = [
                    var for var in t_vars if var.name.startswith("generator")
                ]
                dis_tvars = [
                    var for var in t_vars if var.name.startswith("discriminator")
                ]
                total_gen_loss_batch = (args.feat_weight * feature_loss +
                                        errG_loss_batch +
                                        args.rec_weight * total_loss_rec_batch)
                total_dis_loss_batch = errD_loss_batch

                # Gradients for this tower; averaged across towers below.
                grads_g = G_optimizers.compute_gradients(total_gen_loss_batch,
                                                         var_list=gen_tvars)
                grads_d = D_optimizers.compute_gradients(total_dis_loss_batch,
                                                         var_list=dis_tvars)
                tower_grads_g.append(grads_g)
                tower_grads_d.append(grads_d)
                coarse_gpu.append(coarse_batch)
                fine_gpu.append(fine_batch)
                total_dis_loss_gpu.append(total_dis_loss_batch)
                total_gen_loss_gpu.append(errG_loss_batch)
                total_lossReconstruction_gpu.append(args.rec_weight *
                                                    total_loss_rec_batch)
                total_lossFeature_gpu.append(args.feat_weight * feature_loss)

    fine = tf.concat(fine_gpu, 0)
    grads_g = average_gradients(tower_grads_g)
    grads_d = average_gradients(tower_grads_d)
    train_G = G_optimizers.apply_gradients(grads_g, global_step=global_step)
    train_D = D_optimizers.apply_gradients(grads_d, global_step=global_step)
    total_dis_loss = tf.reduce_mean(tf.stack(total_dis_loss_gpu, 0))
    total_gen_loss = tf.reduce_mean(tf.stack(total_gen_loss_gpu, 0))
    total_loss_rec = tf.reduce_mean(tf.stack(total_lossReconstruction_gpu, 0))
    total_loss_fea = tf.reduce_mean(tf.stack(total_lossFeature_gpu, 0))
    dist1_eval, dist2_eval = chamfer(fine, gt_pl)
    # Built once; creating it inside the loop added a new op per evaluation.
    local_init_op = tf.local_variables_initializer()

    with h5py.File(args.h5_validate, 'r') as file_validate:
        incomplete_pcds_validate = file_validate['incomplete_pcds'][()]
        complete_pcds_validate = file_validate['complete_pcds'][()]
        labels_validate = file_validate['labels'][()]
    num_validate = incomplete_pcds_validate.shape[0]

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)
    try:
        saver = tf.train.Saver(max_to_keep=3)
        sess.run(tf.global_variables_initializer())
        # Warm-start the decoder/folding weights from the pretrained checkpoint.
        saver_decoder = tf.train.Saver(var_list=[
            var for var in tf.global_variables()
            if var.name.startswith(("generator/decoder", "generator/folding"))
        ])
        saver_decoder.restore(sess, args.pretrain_complete_decoder)
        if args.restore:
            saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))

        init_step = sess.run(global_step) // steps_per_round
        epoch = init_step * args.batch_size // train_num + 1
        print('init_step:%d,' % init_step, 'epoch:%d' % epoch,
              'training data number:%d' % train_num)
        train_idxs = np.arange(0, train_num)
        num_batches = train_num // args.batch_size
        for ep_cnt in range(epoch, args.max_epoch + 1):
            np.random.shuffle(train_idxs)
            for batch_idx in range(num_batches):
                init_step += 1
                start_idx = batch_idx * args.batch_size
                end_idx = min(start_idx + args.batch_size, train_num)
                ids_train = list(np.sort(train_idxs[start_idx:end_idx]))
                feed_dict = {
                    inputs_pl: incomplete_pcds_train[ids_train],
                    gt_pl: complete_pcds_train[ids_train],
                    is_training_pl: True,
                    label_pl: labels_train[ids_train],
                    complete_feature: complete_features_train[ids_train],
                    complete_feature0: complete_features_train0[ids_train]
                }
                for _step in range(args.dis_iter):
                    _, loss_dis = sess.run([train_D, total_dis_loss],
                                           feed_dict=feed_dict)
                for _step in range(args.gen_iter):
                    _, loss_gen, rec_loss, fea_loss = sess.run(
                        [train_G, total_gen_loss, total_loss_rec, total_loss_fea],
                        feed_dict=feed_dict)
                if init_step % args.steps_per_print == 0:
                    print('epoch %d step %d gen_loss %.8f rec_loss %.8f fea_loss %.8f dis_loss %.8f' % \
                          (ep_cnt, init_step, loss_gen, rec_loss, fea_loss, loss_dis))

                if init_step % args.steps_per_eval == 0:
                    total_loss = 0
                    sess.run(local_init_op)
                    batch_data = np.zeros(
                        (args.batch_size, incomplete_pcds_validate[0].shape[0], 3), 'f')
                    batch_gt = np.zeros((args.batch_size, args.num_gt_points, 3), 'f')
                    labels = np.zeros((args.batch_size, ), dtype=np.int32)
                    # The eval ops depend only on inputs_pl and gt_pl; the
                    # feature placeholders are fed zeros rather than training
                    # features, which have no relation to validation samples.
                    feature_complete_input = np.zeros(
                        (args.batch_size, 1024)).astype(np.float32)
                    feature_complete_input0 = np.zeros(
                        (args.batch_size, 256)).astype(np.float32)
                    for batch_idx_eval in range(0, num_validate, args.batch_size):
                        start_idx = batch_idx_eval
                        end_idx = min(start_idx + args.batch_size, num_validate)
                        count = end_idx - start_idx
                        # Rows past `count` keep stale data from the previous
                        # batch; they are excluded by the [0:count] slices.
                        batch_data[0:count] = incomplete_pcds_validate[start_idx:end_idx]
                        batch_gt[0:count] = complete_pcds_validate[start_idx:end_idx]
                        labels[0:count] = labels_validate[start_idx:end_idx]
                        feed_dict = {
                            inputs_pl: batch_data,
                            gt_pl: batch_gt,
                            is_training_pl: False,
                            label_pl: labels,
                            complete_feature: feature_complete_input,
                            complete_feature0: feature_complete_input0
                        }
                        dist1_out, dist2_out = sess.run([dist1_eval, dist2_eval],
                                                        feed_dict=feed_dict)
                        if args.loss_type == 'cd_1':
                            total_loss += np.mean(dist1_out[0:count]) * count \
                                + np.mean(dist2_out[0:count]) * count
                        else:  # 'cd_2' (validated above)
                            total_loss += (np.mean(np.sqrt(dist1_out[0:count])) * count \
                                + np.mean(np.sqrt(dist2_out[0:count])) * count) / 2
                    mean_loss = total_loss / num_validate
                    if mean_loss < args.best_loss:
                        args.best_loss = mean_loss
                        saver.save(sess, os.path.join(args.log_dir, 'model'),
                                   init_step)
                    print('epoch %d step %d loss %.8f best_loss %.8f' %
                          (ep_cnt, init_step, mean_loss, args.best_loss))
    finally:
        sess.close()
def test(args):
    """Report the fine completion's Chamfer distance on an HDF5 test set.

    Reads ``incomplete_pcds``, ``complete_pcds`` and integer ``labels`` from
    the HDF5 file ``args.data_dir``.  For each sample, the generator's fine
    output is computed and scored against the complete cloud:

    * ``args.loss_type == 'cd_1'``: ``mean(dist1) + mean(dist2)``;
    * ``args.loss_type == 'cd_2'``: ``(mean(sqrt(dist1)) + mean(sqrt(dist2))) / 2``.

    The overall and per-synset averages are printed.  Labels are mapped to
    synset ids through ``objects`` and ``snc_synth_id_to_category``.

    Raises:
        ValueError: if ``args.loss_type`` is neither 'cd_1' nor 'cd_2'.
    """
    if args.loss_type not in ('cd_1', 'cd_2'):
        raise ValueError("loss_type must be 'cd_1' or 'cd_2', got %r"
                         % (args.loss_type,))

    inputs = tf.placeholder(tf.float32, (1, 2048, 3))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    reconstruction = tf.placeholder(tf.float32, (1, 1024 * args.step_ratio, 3))
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
        _, features_partial = model_module.encoder(inputs)
        _coarse, fine = model_module.decoder(inputs, features_partial,
                                             args.step_ratio,
                                             args.step_ratio * 1024)

    # Metric graph on a separately fed placeholder for the completion.
    dist1_fine, dist2_fine = chamfer(reconstruction, gt)
    if args.loss_type == 'cd_1':
        loss = tf.reduce_mean(dist1_fine) + tf.reduce_mean(dist2_fine)
    else:  # 'cd_2'
        loss = (tf.reduce_mean(tf.sqrt(dist1_fine)) +
                tf.reduce_mean(tf.sqrt(dist2_fine))) / 2

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    try:
        saver = tf.train.Saver()
        saver.restore(sess, args.checkpoint)

        with h5py.File(args.data_dir, 'r') as data_all:
            partial_all = data_all['incomplete_pcds'][()]
            complete_all = data_all['complete_pcds'][()]
            model_list = data_all['labels'][()].astype(int)

        # Reverse lookup (category name -> synset id), built once instead of
        # per sample.  setdefault keeps the first synset for a repeated
        # category, as list.index() over the dict's values did.
        category_to_synset = {}
        for synset, category_name in snc_synth_id_to_category.items():
            category_to_synset.setdefault(category_name, synset)

        cd_per_cat = {}
        total_cd = 0
        for i, label in enumerate(model_list):
            completion = sess.run(fine, feed_dict={
                inputs: [partial_all[i]],
                is_training_pl: False
            })
            cd = sess.run(loss, feed_dict={
                reconstruction: completion,
                gt: [complete_all[i]],
                is_training_pl: False
            })
            total_cd += cd
            synset_id = category_to_synset[objects[label]]
            cd_per_cat.setdefault(synset_id, []).append(cd)
    finally:
        sess.close()

    if len(model_list) == 0:
        print('No samples in %s' % args.data_dir)
        return
    print('Average Chamfer distance: %f' % (total_cd / len(model_list)))
    print('Chamfer distance per category')
    for synset_id in sorted(cd_per_cat):
        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))