示例#1
0
    def shape_complete(tracklet, car_id):
        """Complete one partial car scan and save a three-view plot.

        The partial cloud ``<pcd_dir>/<car_id>.pcd`` is normalized into the
        frame of its bounding box ``<bbox_dir>/<car_id>.txt``:
        - centered on the box;
        - rotated so the edge from corner 0 to corner 3 lies along +x;
        - scaled so that edge has unit length;
        - with its y and z axes swapped.

        The normalized cloud is then fed to the completion model. Unless the
        model is decoder-only, the output is reduced to ``num_gt_points``
        points with farthest point sampling (``model.fps``).

        Uses ``config``, ``data_config``, ``sess``, ``model``, ``inputs``,
        ``npts``, ``output`` and ``nearest_dist_op`` from the enclosing scope.

        Args:
            tracklet: Tracklet name; the plot is written to
                ``<results_dir>/plots/<tracklet>/<car_id>.png``.
            car_id: Base name of the partial-scan and bounding-box files.

        Returns:
            ``(completion, fidelity)``: the completed cloud for this car
            (batch dimension dropped) and the mean of the first output of
            ``nearest_dist_op`` evaluated on (partial, completion).
        """
        origin_partial = read_pcd(
            os.path.join(config['pcd_dir'], '%s.pcd' % car_id))
        bbox = np.loadtxt(os.path.join(config['bbox_dir'], '%s.txt' % car_id))

        # Read the decoder-only flag once so that the FPS step and the
        # plotting branch can never disagree. Previously FPS used the model
        # attribute and plotting used the config entry, which could leave
        # ``is_from_decoder_fps`` undefined at plot time.
        use_decoder_only = model.use_decoder_only

        # Calculate center, rotation and scale.
        center = (bbox.min(0) + bbox.max(0)) / 2
        bbox -= center
        yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0])
        rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0],
                             [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]])
        # Right-multiplying row vectors by ``rotation`` maps the
        # corner-0 -> corner-3 edge onto +x, so its length is the x extent.
        bbox = np.dot(bbox, rotation)
        scale = bbox[3, 0] - bbox[0, 0]

        partial = np.dot(origin_partial - center, rotation) / scale
        # Swap the y and z axes.
        partial = np.dot(partial, [[1, 0, 0], [0, 0, 1], [0, 1, 0]])

        completion = sess.run(model.completion,
                              feed_dict={
                                  inputs: [partial],
                                  npts: [partial.shape[0]]
                              })

        is_from_decoder_fps = None
        if not use_decoder_only:
            completion, fps_indices = sess.run(
                model.fps(data_config['num_gt_points'], completion))
            # Sampled points whose source index is at or beyond
            # upsampling_ratio * (number of input points) are marked as
            # decoder output.
            is_from_decoder_fps = fps_indices >= config['model'][
                'upsampling_ratio'] * partial.shape[0]
        completion = completion[0]

        nn_dists1, _ = sess.run(nearest_dist_op,
                                feed_dict={
                                    inputs: [partial],
                                    output: [completion]
                                })
        fidelity = np.mean(nn_dists1)

        # visualize
        plot_dir = os.path.join(config['results_dir'], 'plots', tracklet)
        os.makedirs(plot_dir, exist_ok=True)
        plot_path = os.path.join(plot_dir, '%s.png' % car_id)
        suptitle = '%d input points' % partial.shape[0]
        if not use_decoder_only and config['visualizing']['visu_split']:
            # Colour points by origin (upsampled input vs. decoder).
            plot_pcd_three_views(plot_path, [partial, completion], [
                'input', 'output', 'upsampling',
                config['model']['decoder']['type']
            ], is_from_decoder_fps[0], suptitle, [5, 0.5])
        else:
            # Plain plot. This also covers the non-split case, which
            # previously created the directory but never drew anything.
            plot_pcd_three_views(plot_path, [partial, completion],
                                 ['input', 'output'], None, suptitle,
                                 [5, 0.5])

        return completion, fidelity
示例#2
0
文件: test.py 项目: countywest/SAUM
def test_vanilla(config):
    """Evaluate a trained completion model with Chamfer and EMD metrics.

    For every id listed in ``config['list_path']``, the function loads the
    partial and ground-truth clouds and runs the model on the partial cloud.
    It then scores the completion against the ground truth with Chamfer
    distance (CD) and Earth mover's distance (EMD).

    Outputs:
    - ``<results_dir>/results.csv``: per-shape rows, then the overall
      average, then per-category averages.
    - ``<results_dir>/results_summary.txt``: a text summary.
    - The same averages are printed to stdout.

    Args:
        config: Parsed configuration dict. Keys read here:
            - ``test_setting``: ``plot_freq``, ``save_pcd``
            - ``dataset``: ``type``, ``dir``, ``num_gt_points``
            - ``model.decoder.type``
            - ``results_dir``, ``config_path``, ``checkpoint``, ``list_path``

    Raises:
        NotImplementedError: if ``dataset.type`` is not ``pcn``, ``car`` or
            ``topnet``.
        ValueError: if the model list file is empty.
    """
    test_config = config['test_setting']
    data_config = config['dataset']
    results_dir = config['results_dir']

    # Data
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    npts = tf.placeholder(tf.int32, (1,))
    gt = tf.placeholder(tf.float32, (1, data_config['num_gt_points'], 3))
    output = tf.placeholder(tf.float32, (1, data_config['num_gt_points'], 3))

    # Model
    model_module = importlib.import_module(config['model']['decoder']['type'])
    model = model_module.model(config, inputs, npts, gt, False)

    # Metric
    cd_op = chamfer(output, gt)
    emd_op = earth_mover(output, gt)

    # make results directory & save configuration
    if os.path.exists(results_dir):
        delete_key = input(colored('%s exists. Delete? [y (or enter)/N]'
                                   % results_dir, 'white', 'on_red'))
        if delete_key == 'y' or delete_key == "":
            os.system('rm -rf %s/*' % results_dir)
    else:
        os.makedirs(results_dir)

    os.system('cp %s %s' % (config['config_path'], results_dir))
    os.system('cp test.py %s' % results_dir)

    # TF Config
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    config_proto.allow_soft_placement = True
    sess = tf.Session(config=config_proto)
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(config['checkpoint']))

    # Test
    test_start = time.time()
    print(colored("Testing...", 'grey', 'on_green'))

    with open(config['list_path']) as file:
        model_list = file.read().splitlines()
    if not model_list:
        # Every average below divides by the list length.
        sess.close()
        raise ValueError('no models listed in %s' % config['list_path'])
    num_models = len(model_list)

    total_time = 0
    total_cd = 0
    total_emd = 0
    cd_per_cat = {}
    emd_per_cat = {}
    os.makedirs(results_dir, exist_ok=True)
    # ``with`` guarantees the CSV is closed even if a shape fails to load.
    with open(os.path.join(results_dir, 'results.csv'), 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',', quotechar='"')
        writer.writerow(['id', 'cd', 'emd'])

        for i, model_id in enumerate(model_list):
            start = time.time()

            # data
            if data_config['type'] in ('pcn', 'car'):
                partial = read_pcd(os.path.join(
                    data_config['dir'], 'partial', '%s.pcd' % model_id))
                gt_complete = read_pcd(os.path.join(
                    data_config['dir'], 'complete', '%s.pcd' % model_id))
            elif data_config['type'] == 'topnet':
                # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads
                # the whole dataset into a numpy array.
                with h5py.File(os.path.join(data_config['dir'], 'partial',
                                            '%s.h5' % model_id), 'r') as f:
                    partial = f['data'][()].astype(np.float32)
                with h5py.File(os.path.join(data_config['dir'], 'gt',
                                            '%s.h5' % model_id), 'r') as f:
                    gt_complete = f['data'][()].astype(np.float32)
            else:
                raise NotImplementedError
            partial_npts = partial.shape[0]

            # inference
            completion = sess.run(model.completion,
                                  feed_dict={inputs: [partial],
                                             npts: [partial_npts]})

            cd = sess.run(cd_op, feed_dict={output: completion,
                                            gt: [gt_complete]})
            emd = sess.run(emd_op, feed_dict={output: completion,
                                              gt: [gt_complete]})
            total_cd += cd
            total_emd += emd

            total_time += time.time() - start

            writer.writerow([model_id, cd, emd])
            csv_file.flush()

            # ids have the form "<synset_id>/<model_name>"
            synset_id, model_name = model_id.split('/')
            cd_per_cat.setdefault(synset_id, []).append(cd)
            emd_per_cat.setdefault(synset_id, []).append(emd)

            # visualize
            if i % test_config['plot_freq'] == 0:
                plot_dir = os.path.join(results_dir, 'plots', synset_id)
                os.makedirs(plot_dir, exist_ok=True)
                plot_pcd_three_views(
                    os.path.join(plot_dir, '%s.png' % model_name),
                    [partial, completion[0], gt_complete],
                    model.visualize_titles, None,
                    'CD %.4f EMD %.4f' % (cd, emd))
            if test_config['save_pcd']:
                pcd_dir = os.path.join(results_dir, 'pcds', synset_id)
                os.makedirs(pcd_dir, exist_ok=True)
                save_pcd(os.path.join(pcd_dir, '%s.pcd' % model_name),
                         completion[0])

        avg_cd = total_cd / num_models
        avg_emd = total_emd / num_models
        writer.writerow(["average", avg_cd, avg_emd])
        for synset_id in cd_per_cat:
            writer.writerow([synset_id, np.mean(cd_per_cat[synset_id]),
                             np.mean(emd_per_cat[synset_id])])

    with open(os.path.join(results_dir, 'results_summary.txt'), 'w') as log:
        log.write('Average Chamfer distance: %.8f \n' % avg_cd)
        log.write('Average Earth mover distance: %.8f \n' % avg_emd)
        log.write('## Summary for each category ## \n')
        log.write('ID  CD  EMD  \n')
        for synset_id in cd_per_cat:
            log.write('%s %.8f %.8f\n' % (synset_id,
                                          np.mean(cd_per_cat[synset_id]),
                                          np.mean(emd_per_cat[synset_id])))

    # print results
    print('Average time: %f' % (total_time / num_models))
    print('Average Chamfer distance: %f' % avg_cd)
    print('Average Earth mover distance: %f' % avg_emd)
    print('Chamfer distance per category')
    for synset_id in cd_per_cat:
        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))
    print('Earth mover distance per category')
    for synset_id in emd_per_cat:
        print(synset_id, '%f' % np.mean(emd_per_cat[synset_id]))
    sess.close()
    print(colored("Test ended!", 'grey', 'on_green'))
    print('Total testing time',
          datetime.timedelta(seconds=time.time() - test_start))
示例#3
0
def train(args):
    """Train a point-cloud completion model from LMDB data.

    Setup:
    - Builds the input placeholders.
    - Builds the model class ``Model`` from
      ``completion_models.<args.model_type>``.
    - Builds an Adam optimizer, optionally with staircase exponential
      learning-rate decay clipped at ``args.lr_clip``.

    Training loop, until ``args.max_step`` or ``args.epoch`` is reached:
    - prints the loss every ``steps_per_print`` steps;
    - evaluates on the validation set every ``steps_per_eval`` steps;
    - plots predictions every ``steps_per_visu`` steps;
    - saves a checkpoint named after the epoch whenever
      ``epoch % epochs_per_save == 0`` and that checkpoint does not exist
      yet.

    Module-level ``BATCH_SIZE``, ``NUM_POINT`` and ``DATASET`` are read in
    addition to ``args``.

    Args:
        args: Parsed command-line namespace.
    """
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # alpha steps through 0.01 / 0.1 / 0.5 / 1.0 at global steps 10k, 20k, 50k.
    alpha = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'alpha_op')

    # ModelNet inputs have a fixed number of points; ShapeNet inputs vary.
    # The batch is fed as one (1, BATCH_SIZE * NUM_POINT, 3) tensor, with
    # per-cloud point counts in npts_pl.
    inputs_pl = tf.placeholder(tf.float32, (1, BATCH_SIZE * NUM_POINT, 3), 'inputs')
    npts_pl = tf.placeholder(tf.int32, (BATCH_SIZE,), 'num_points')
    gt_pl = tf.placeholder(tf.float32, (BATCH_SIZE, args.num_gt_points, 3), 'ground_truths')
    add_train_summary('alpha', alpha)
    bn_decay = get_bn_decay(global_step)
    add_train_summary('bn_decay', bn_decay)

    model_module = importlib.import_module('.%s' % args.model_type, 'completion_models')
    model = model_module.Model(inputs_pl, npts_pl, gt_pl, alpha,
                               bn_decay=bn_decay, is_training=is_training_pl)

    if args.lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr, global_step,
                                                   args.lr_decay_steps, args.lr_decay_rate,
                                                   staircase=True, name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')

    # The PCN paper trains with Adam at an initial learning rate of 1e-4 for
    # 50 epochs, batch size 32, decaying the learning rate by 0.7 every 50K
    # iterations. The settings passed in ``args`` may differ from that.
    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)
    saver = tf.train.Saver(max_to_keep=10)

    if args.store_grad:
        grads_and_vars = trainer.compute_gradients(model.loss)
        for g, v in grads_and_vars:
            tf.summary.histogram(v.name, v, collections=['train_summary'])
            # compute_gradients yields None for variables that do not affect
            # the loss; a histogram of None cannot be built.
            if g is not None:
                tf.summary.histogram(v.name + '_grad', g, collections=['train_summary'])

    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')
    # Built once: calling tf.local_variables_initializer() inside the loop
    # would add a new op to the graph at every evaluation.
    local_init_op = tf.local_variables_initializer()

    # The number of input points of a partial observation is not fixed.
    df_train, num_train = lmdb_dataflow(
        args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True)
    train_gen = df_train.get_data()
    df_valid, num_valid = lmdb_dataflow(
        args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False)
    valid_gen = df_valid.get_data()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    plots_dir = os.path.join(args.log_dir, 'plots')
    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
        writer = tf.summary.FileWriter(args.log_dir)
        os.makedirs(plots_dir, exist_ok=True)
    else:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(args.log_dir):
            delete_key = input(colored('%s exists. Delete? [y/n]' % args.log_dir, 'white', 'on_red'))
            if delete_key == 'y' or delete_key == "yes":
                os.system('rm -rf %s/*' % args.log_dir)
        # Create plots/ (and log_dir itself) even when the user keeps an
        # existing log directory; visualization always writes there.
        os.makedirs(plots_dir, exist_ok=True)
        with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) + '\n')
        os.system('cp completion_models/%s.py %s' % (args.model_type, args.log_dir))  # bkp of model scripts
        os.system('cp train_completion.py %s' % args.log_dir)  # bkp of train procedure
        writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    # Append the run arguments to the training log; the file is closed
    # right away since nothing else is written to it.
    with open(os.path.join(args.log_dir, 'log_train.txt'), 'a+') as log_fout:
        for arg in sorted(vars(args)):
            log_fout.write(arg + ': ' + str(getattr(args, arg)) + '\n')

    total_time = 0
    train_start = time.time()
    init_step = sess.run(global_step)

    for step in range(init_step + 1, args.max_step + 1):
        # epoch: index of the current pass over the training set (1-based);
        # step: number of batches drawn so far.
        epoch = step * args.batch_size // num_train + 1
        # Check the epoch limit before drawing a batch that would be unused.
        if epoch > args.epoch:
            break
        ids, inputs, npts, gt = next(train_gen)
        if DATASET == 'shapenet8':
            inputs, npts = vary2fix(inputs, npts)

        start = time.time()
        feed_dict = {inputs_pl: inputs, npts_pl: npts, gt_pl: gt, is_training_pl: True}
        _, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict)
        total_time += time.time() - start
        writer.add_summary(summary, step)

        if step % args.steps_per_print == 0:
            print('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                  (epoch, step, loss, total_time / args.steps_per_print))
            total_time = 0

        if step % args.steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            num_eval_steps = num_valid // args.batch_size
            total_loss = 0
            total_time = 0
            sess.run(local_init_op)
            for _ in range(num_eval_steps):
                start = time.time()
                ids, inputs, npts, gt = next(valid_gen)
                if DATASET == 'shapenet8':
                    inputs, npts = vary2fix(inputs, npts)
                feed_dict = {inputs_pl: inputs, npts_pl: npts, gt_pl: gt, is_training_pl: False}
                loss, _update = sess.run([model.loss, model.update], feed_dict=feed_dict)
                total_loss += loss
                total_time += time.time() - start
            summary = sess.run(valid_summary, feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(colored('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                          (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps),
                          'grey', 'on_green'))
            total_time = 0

        if step % args.steps_per_visu == 0:
            # Uses the most recent batch: the last validation batch if
            # evaluation ran at this step, otherwise the training batch.
            all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
            for i in range(0, args.batch_size, args.visu_freq):
                plot_path = os.path.join(plots_dir,
                                         'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
                pcds = [x[i] for x in all_pcds]
                plot_pcd_three_views(plot_path, pcds, model.visualize_titles)

        if (epoch % args.epochs_per_save == 0) and \
                not os.path.exists(os.path.join(args.log_dir, 'model-%d.meta' % epoch)):
            saver.save(sess, os.path.join(args.log_dir, 'model'), epoch)
            print(colored('Epoch:%d, Model saved at %s' % (epoch, args.log_dir), 'white', 'on_blue'))

    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    sess.close()
示例#4
0
def train(config):
    """Train the completion model described by ``config``.

    Setup:
    - Builds the placeholders and the train/validation ``Dataset`` objects.
    - Builds the model from the module named by
      ``config['model']['decoder']['type']``.
    - Builds the optimizer from the ``optimizer`` module.

    Training loop:
    - prints the target loss every ``steps_per_print`` steps;
    - evaluates on the validation set every ``steps_per_eval`` steps,
      checkpointing whenever the mean validation loss is <= the best so far;
    - plots a validation batch every ``visualizing.steps_per_visu`` steps.

    With ``config['restore']``, the latest checkpoint in ``log_dir`` is
    restored and the best validation loss is recomputed before training
    resumes. Otherwise variables are initialized from scratch, ``log_dir``
    is optionally cleared, and the config file and ``train.py`` are copied
    into it.

    Args:
        config: Parsed configuration dict with keys ``dataset``,
            ``train_setting``, ``model``, ``visualizing``, ``restore``,
            ``log_dir`` and ``config_path``.
    """
    data_config = config['dataset']
    train_config = config['train_setting']
    lr_config = train_config['learning_rate']
    batch_size = train_config['batch_size']
    num_gt_points = data_config['num_gt_points']
    log_dir = config['log_dir']

    # Data
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
    npts_pl = tf.placeholder(tf.int32, (batch_size, ), 'num_points')
    gt_pl = tf.placeholder(tf.float32, (batch_size, num_gt_points, 3),
                           'ground_truths')

    train_set = Dataset(data_config, train_config, is_training=True)
    valid_set = Dataset(data_config, train_config, is_training=False)
    num_train = train_set.get_num_data()
    num_valid = valid_set.get_num_data()

    # Model
    model_module = importlib.import_module(config['model']['decoder']['type'])
    model = model_module.model(config, inputs_pl, npts_pl, gt_pl,
                               is_training_pl)

    # Optimizer
    optimizer = importlib.import_module('optimizer').optimizer(
        lr_config, model.global_step, model.target_loss)

    # TF Config
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    config_proto.allow_soft_placement = True
    sess = tf.Session(config=config_proto)
    saver = tf.train.Saver()
    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')
    # Built once: calling tf.local_variables_initializer() inside the loop
    # would add a new op to the graph at every evaluation.
    local_init_op = tf.local_variables_initializer()

    def make_feed(inputs, npts, gt, is_training):
        # Every ground-truth batch is cut to num_gt_points so it matches
        # gt_pl's static shape.
        return {
            inputs_pl: inputs,
            npts_pl: npts,
            gt_pl: gt[:, :num_gt_points, :],
            is_training_pl: is_training
        }

    if config['restore']:
        # restart training
        saver.restore(sess, tf.train.latest_checkpoint(log_dir))
        writer = tf.summary.FileWriter(log_dir)

        # Recompute the best validation loss of the restored model so only
        # real improvements trigger a new checkpoint.
        num_eval_steps = num_valid // batch_size
        total_eval_loss = 0
        sess.run(local_init_op)
        for _ in range(num_eval_steps):
            _ids, inputs, npts, gt = valid_set.fetch(sess)
            total_eval_loss += sess.run(
                model.evaluation_loss,
                feed_dict=make_feed(inputs, npts, gt, False))
        best_valid_loss = total_eval_loss / num_eval_steps
    else:
        # train from scratch
        sess.run(tf.global_variables_initializer())
        if os.path.exists(log_dir):
            delete_key = input(
                colored('%s exists. Delete? [y (or enter)/N]' % log_dir,
                        'white', 'on_red'))
            if delete_key == 'y' or delete_key == "":
                os.system('rm -rf %s/*' % log_dir)
        # Visualization always writes into <log_dir>/plots, so create it even
        # when the user keeps an existing log directory.
        os.makedirs(os.path.join(log_dir, 'plots'), exist_ok=True)

        # save configuration in log directory
        os.system('cp %s %s' % (config['config_path'], log_dir))
        os.system('cp train.py %s' % log_dir)

        writer = tf.summary.FileWriter(log_dir, sess.graph)
        best_valid_loss = 1e5  # initialize with a big enough number

    print(colored("Training will begin.. ", 'grey', 'on_green'))
    print(colored("Batch_size: " + str(batch_size), 'grey', 'on_green'))
    print(colored("Batch norm use?: " + str(config['model']['use_bn']),
                  'red', 'on_green'))
    print(colored("Decoder arch: " + config['model']['decoder']['type'],
                  'grey', 'on_green'))
    print(colored("Last best_validation_loss: " + str(best_valid_loss),
                  'grey', 'on_green'))

    # Training
    total_time = 0
    train_start = time.time()
    init_step = sess.run(model.global_step)

    for step in range(init_step + 1, train_config['max_step'] + 1):
        epoch = step * batch_size // num_train + 1

        _ids, inputs, npts, gt = train_set.fetch(sess)
        feed_dict = make_feed(inputs, npts, gt, True)
        start = time.time()
        _, target_loss, summary = sess.run(
            [optimizer.train, model.target_loss, train_summary],
            feed_dict=feed_dict)
        total_time += time.time() - start
        writer.add_summary(summary, step)

        # logging
        if step % train_config['steps_per_print'] == 0:
            print('epoch %d  step %d  target_loss %.8f - time per batch %.4f' %
                  (epoch, step, target_loss,
                   total_time / train_config['steps_per_print']))
            total_time = 0

        # eval on validation set
        if step % train_config['steps_per_eval'] == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            num_eval_steps = num_valid // batch_size
            total_eval_loss = 0
            total_time = 0
            sess.run(local_init_op)
            for _ in range(num_eval_steps):
                start = time.time()
                _ids, inputs, npts, gt = valid_set.fetch(sess)
                evaluation_loss, _update = sess.run(
                    [model.evaluation_loss, model.update],
                    feed_dict=make_feed(inputs, npts, gt, False))
                total_eval_loss += evaluation_loss
                total_time += time.time() - start
            summary = sess.run(valid_summary,
                               feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            temp_valid_loss = total_eval_loss / num_eval_steps
            print(
                colored(
                    'epoch %d  step %d  eval_loss %.8f - time per batch %.4f' %
                    (epoch, step, temp_valid_loss,
                     total_time / num_eval_steps), 'grey', 'on_green'))
            if temp_valid_loss <= best_valid_loss:  # save best model for validation set
                best_valid_loss = temp_valid_loss
                saver.save(sess, os.path.join(log_dir, 'model'), step)
                print(colored('Model saved at %s' % log_dir, 'white',
                              'on_blue'))
            total_time = 0

        # visualize
        if step % config['visualizing']['steps_per_visu'] == 0:
            print('visualizing!')
            vis_ids, vis_inputs, vis_npts, vis_gt = valid_set.fetch(sess)
            if data_config['type'] == 'topnet':
                # Replace "/" in the ids with "_" so they can be used in
                # file names.
                vis_ids = vis_ids.astype('U')
                vis_ids = np.char.split(vis_ids, sep='/', maxsplit=1)
                vis_ids = np.char.join(['_'] * batch_size, vis_ids)

            # vis_gt goes through make_feed, so it is truncated like the
            # train/eval ground truths. The untruncated array would not
            # match gt_pl whenever the dataset stores more points.
            all_pcds = sess.run(
                model.visualize_ops,
                feed_dict=make_feed(vis_inputs, vis_npts, vis_gt, False))
            # Points at index >= upsampling_ratio * num_input_points are
            # marked as decoder output.
            upsampled_npts = (config['model']['upsampling_ratio'] *
                              train_config['num_input_points'])
            is_from_decoder = np.arange(
                0, config['model']['decoder']['num_decoder_points'] +
                upsampled_npts) >= upsampled_npts

            for i in range(0, batch_size,
                           config['visualizing']['visu_freq']):
                plot_path = os.path.join(
                    log_dir, 'plots',
                    'epoch_%d_step_%d_%s.png' % (epoch, step, vis_ids[i]))
                pcds = [x[i] for x in all_pcds]
                if config['visualizing']['visu_split']:
                    plot_pcd_three_views(plot_path, pcds,
                                         model.visualize_titles,
                                         is_from_decoder)
                else:
                    plot_pcd_three_views(plot_path, pcds,
                                         model.visualize_titles, None)

    print(colored("Training ended!", 'grey', 'on_green'))
    print('Total training time',
          datetime.timedelta(seconds=time.time() - train_start))
    sess.close()
示例#5
0
def test_saum(config):
    """Self-consistency evaluation: complete each ground-truth cloud.

    Every ground-truth cloud listed in ``config['list_path']`` is fed to the
    model as its own input.
    - The raw completion is scored with Chamfer distance (CD) against the
      ground truth.
    - A farthest-point-sampled (FPS) subset of ``num_gt_points`` points is
      scored with CD, Earth mover's distance (EMD), and an F1 score at
      distance ``test_setting.threshold``.

    Outputs:
    - ``<results_dir>/results.csv``: per-shape rows, then the overall
      average, then per-category averages.
    - ``<results_dir>/results_summary.txt``: a text summary, also printed
      to stdout.

    Args:
        config: Parsed configuration dict. Keys read here:
            - ``test_setting``: ``threshold``, ``plot_freq``, ``save_pcd``
            - ``dataset``: ``type``, ``dir``, ``num_gt_points``
            - ``model``: ``decoder.type``, ``decoder.num_decoder_points``,
              ``upsampling_ratio``
            - ``visualizing.visu_split``
            - ``results_dir``, ``checkpoint``, ``list_path``

    Raises:
        NotImplementedError: if ``dataset.type`` is not ``pcn``, ``car`` or
            ``topnet``.
        ValueError: if the model list file is empty.
    """
    test_config = config['test_setting']
    data_config = config['dataset']
    results_dir = config['results_dir']
    num_gt_points = data_config['num_gt_points']

    # Data
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    npts = tf.placeholder(tf.int32, (1, ))
    gt = tf.placeholder(tf.float32, (1, num_gt_points, 3))
    output = tf.placeholder(tf.float32, (1, None, 3))
    sampled_output = tf.placeholder(tf.float32, (1, num_gt_points, 3))

    # Model
    model_module = importlib.import_module(config['model']['decoder']['type'])
    model = model_module.model(config, inputs, npts, gt, False)

    # Loss
    cd_op = chamfer(output, gt)
    emd_op = earth_mover(sampled_output, gt)
    nearest_dist_op = dist_to_nearest(output, gt)

    # make results directory
    if os.path.exists(results_dir):
        delete_key = input(
            colored('%s exists. Delete? [y (or enter)/N]' % results_dir,
                    'white', 'on_red'))
        if delete_key == 'y' or delete_key == "":
            os.system('rm -rf %s/*' % results_dir)
    else:
        os.makedirs(results_dir)

    os.system('cp test_self_consistency.py %s' % results_dir)

    # TF Config
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    config_proto.allow_soft_placement = True
    sess = tf.Session(config=config_proto)
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(config['checkpoint']))

    # Test
    test_start = time.time()
    print(colored("Testing...", 'grey', 'on_green'))

    with open(config['list_path']) as file:
        model_list = file.read().splitlines()
    if not model_list:
        # Every average below divides by the list length.
        sess.close()
        raise ValueError('no models listed in %s' % config['list_path'])
    num_models = len(model_list)

    threshold = test_config['threshold']
    upsampling_ratio = config['model']['upsampling_ratio']
    num_decoder_points = config['model']['decoder']['num_decoder_points']

    total_time = 0
    total_cd = 0
    total_fps_cd = 0
    total_fps_emd = 0
    total_fps_f1_score = 0

    cd_per_cat = {}
    fps_cd_per_cat = {}
    fps_emd_per_cat = {}
    fps_f1_score_per_cat = {}

    os.makedirs(results_dir, exist_ok=True)
    # ``with`` guarantees the CSV is closed even if a shape fails to load.
    with open(os.path.join(results_dir, 'results.csv'), 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',', quotechar='"')
        writer.writerow(['id', 'cd', 'fps_cd', 'fps_emd', 'fps_f1_score'])

        for i, model_id in enumerate(model_list):
            start = time.time()

            # data
            if data_config['type'] in ('pcn', 'car'):
                gt_complete = read_pcd(
                    os.path.join(data_config['dir'], 'complete',
                                 '%s.pcd' % model_id))
            elif data_config['type'] == 'topnet':
                # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads
                # the whole dataset into a numpy array.
                with h5py.File(
                        os.path.join(data_config['dir'], 'gt',
                                     '%s.h5' % model_id), 'r') as f:
                    gt_complete = f['data'][()].astype(np.float32)
            else:
                raise NotImplementedError
            gt_complete_npts = gt_complete.shape[0]

            # inference: the ground truth itself is the model input
            completion = sess.run(model.completion,
                                  feed_dict={
                                      inputs: [gt_complete],
                                      npts: [gt_complete_npts]
                                  })

            # NOTE(review): model.fps is called with a numpy array on every
            # iteration; if it builds new graph ops each call, the graph
            # grows with the list length. Confirm against the model API.
            fps_completion, fps_indices = sess.run(
                model.fps(num_gt_points, completion))

            # Points at index >= upsampling_ratio * (number of input points)
            # are marked as decoder output.
            upsampled_npts = upsampling_ratio * gt_complete_npts
            is_from_decoder_raw = np.arange(
                0, num_decoder_points + upsampled_npts) >= upsampled_npts
            is_from_decoder_fps = fps_indices >= upsampled_npts

            total_time += time.time() - start

            # raw
            cd = sess.run(cd_op, feed_dict={output: completion,
                                            gt: [gt_complete]})
            total_cd += cd

            # farthest point sampling: cd, emd
            fps_cd = sess.run(cd_op, feed_dict={output: fps_completion,
                                                gt: [gt_complete]})
            fps_emd = sess.run(emd_op, feed_dict={sampled_output: fps_completion,
                                                  gt: [gt_complete]})
            total_fps_cd += fps_cd
            total_fps_emd += fps_emd

            # f1_score: fraction of completed (P) and ground-truth (R) points
            # whose nearest-neighbour distance is below the threshold.
            # Assumes dist_to_nearest returns (output->gt, gt->output) distances.
            fps_nn_dists1, fps_nn_dists2 = sess.run(
                nearest_dist_op, feed_dict={output: fps_completion,
                                            gt: [gt_complete]})
            fps_P = np.count_nonzero(fps_nn_dists1 < threshold) / num_gt_points
            fps_R = np.count_nonzero(fps_nn_dists2 < threshold) / num_gt_points
            # F1 is 0 when precision and recall are both 0; the plain
            # formula would raise ZeroDivisionError there.
            if fps_P + fps_R > 0:
                fps_f1_score = 2 * fps_P * fps_R / (fps_P + fps_R)
            else:
                fps_f1_score = 0.0
            total_fps_f1_score += fps_f1_score

            writer.writerow([model_id, cd, fps_cd, fps_emd, fps_f1_score])
            csv_file.flush()

            # ids have the form "<synset_id>/<model_name>"
            synset_id, model_name = model_id.split('/')
            cd_per_cat.setdefault(synset_id, []).append(cd)
            fps_cd_per_cat.setdefault(synset_id, []).append(fps_cd)
            fps_emd_per_cat.setdefault(synset_id, []).append(fps_emd)
            fps_f1_score_per_cat.setdefault(synset_id, []).append(fps_f1_score)

            # visualize
            if i % test_config['plot_freq'] == 0:
                if config['visualizing']['visu_split']:
                    raw_dir = os.path.join(results_dir, 'plots', 'raw',
                                           synset_id)
                    fps_dir = os.path.join(results_dir, 'plots', 'fps',
                                           synset_id)
                    os.makedirs(raw_dir, exist_ok=True)
                    os.makedirs(fps_dir, exist_ok=True)

                    plot_pcd_three_views(
                        os.path.join(raw_dir, '%s.png' % model_name),
                        [gt_complete, completion[0], gt_complete],
                        model.visualize_titles, is_from_decoder_raw,
                        'CD %.4f' % (cd))
                    plot_pcd_three_views(
                        os.path.join(fps_dir, '%s.png' % model_name),
                        [gt_complete, fps_completion[0], gt_complete],
                        model.visualize_titles, is_from_decoder_fps[0],
                        'FPS_CD %.4f FPS_EMD %.4f FPS_f1_score %.4f' %
                        (fps_cd, fps_emd, fps_f1_score))
                else:
                    plot_dir = os.path.join(results_dir, 'plots', synset_id)
                    os.makedirs(plot_dir, exist_ok=True)
                    plot_pcd_three_views(
                        os.path.join(plot_dir, '%s.png' % model_name),
                        [gt_complete, completion[0], gt_complete],
                        model.visualize_titles, None,
                        'CD %.4f FPS_CD %.4f FPS_EMD %.4f FPS_f1_score %.4f' %
                        (cd, fps_cd, fps_emd, fps_f1_score))

            if test_config['save_pcd']:
                pcd_dir = os.path.join(results_dir, 'pcds', synset_id)
                os.makedirs(pcd_dir, exist_ok=True)
                save_pcd(os.path.join(pcd_dir, '%s.pcd' % model_name),
                         completion[0])
                save_pcd(os.path.join(pcd_dir, '%s_fps.pcd' % model_name),
                         fps_completion[0])

        avg_cd = total_cd / num_models
        avg_fps_cd = total_fps_cd / num_models
        avg_fps_emd = total_fps_emd / num_models
        avg_fps_f1_score = total_fps_f1_score / num_models

        # write average info in csv file
        writer.writerow(["average", avg_cd, avg_fps_cd, avg_fps_emd,
                         avg_fps_f1_score])
        for synset_id in cd_per_cat:
            writer.writerow([
                synset_id,
                np.mean(cd_per_cat[synset_id]),
                np.mean(fps_cd_per_cat[synset_id]),
                np.mean(fps_emd_per_cat[synset_id]),
                np.mean(fps_f1_score_per_cat[synset_id])
            ])

    # write average distances (cd, emd) and f1 score in txt file
    with open(os.path.join(results_dir, 'results_summary.txt'), 'w') as log:
        log.write('Average Chamfer distance: %.8f \n' % avg_cd)
        log.write('Average FPS Chamfer distance: %.8f \n' % avg_fps_cd)
        log.write('Average FPS Earth mover distance: %.8f \n' % avg_fps_emd)
        log.write('Average FPS f1_score(threshold: %.4f): %.8f \n' %
                  (threshold, avg_fps_f1_score))

        log.write('## Summary for each category ## \n')
        log.write('ID  CD  FPS_CD  FPS_EMD  FPS_f1_score\n')
        for synset_id in cd_per_cat:
            log.write('%s %.8f %.8f %.8f %.8f\n' %
                      (synset_id, np.mean(cd_per_cat[synset_id]),
                       np.mean(fps_cd_per_cat[synset_id]),
                       np.mean(fps_emd_per_cat[synset_id]),
                       np.mean(fps_f1_score_per_cat[synset_id])))

    # print results
    print('Average time: %f' % (total_time / num_models))
    print('Average Chamfer distance: %f' % avg_cd)
    print('Average FPS Chamfer distance: %f' % avg_fps_cd)
    print('Average FPS Earth mover distance: %f' % avg_fps_emd)
    print('Average FPS f1_score(threshold: %.4f): %f' %
          (threshold, avg_fps_f1_score))

    print('Chamfer distance per category')
    for synset_id in cd_per_cat:
        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))
    print('Average FPS Chamfer distance per category')
    for synset_id in fps_cd_per_cat:
        print(synset_id, '%f' % np.mean(fps_cd_per_cat[synset_id]))
    print('Average FPS Earth mover distance per category')
    for synset_id in fps_emd_per_cat:
        print(synset_id, '%f' % np.mean(fps_emd_per_cat[synset_id]))
    print('Average FPS f1_score per category')
    for synset_id in fps_f1_score_per_cat:
        print(synset_id, '%f' % np.mean(fps_f1_score_per_cat[synset_id]))

    sess.close()

    print(colored("Test ended!", 'grey', 'on_green'))
    print('Total testing time',
          datetime.timedelta(seconds=time.time() - test_start))