Example #1
    def create_loss(self, gt, alpha):

        loss_coarse = chamfer(self.coarse, gt)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(self.fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
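The `chamfer` helper used here is assumed to come from the project's utility module; in the PCN code base it wraps a compiled nearest-neighbour op. Purely as an illustration, a dense TensorFlow sketch of the same symmetric distance (mean square-root nearest-neighbour distance in both directions, which is the convention these losses appear to use) could look like:

import tensorflow as tf

def chamfer_dense(pcd1, pcd2):
    # pcd1: (B, N, 3), pcd2: (B, M, 3)
    diff = tf.expand_dims(pcd1, 2) - tf.expand_dims(pcd2, 1)   # (B, N, M, 3)
    dist = tf.reduce_sum(tf.square(diff), axis=-1)             # squared distances, (B, N, M)
    dist1 = tf.reduce_min(dist, axis=2)  # for each point in pcd1, its nearest point in pcd2
    dist2 = tf.reduce_min(dist, axis=1)  # for each point in pcd2, its nearest point in pcd1
    return tf.reduce_mean(tf.sqrt(dist1)) + tf.reduce_mean(tf.sqrt(dist2))

The dense version materialises a (B, N, M) distance matrix, so it only scales to small clouds; the compiled op avoids that cost.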
Example #2
    def create_loss(self, coarse, fine, gt, alpha):
        
        gt_ds = gt[:, :coarse.shape[1], :]
        
        loss_coarse = earth_mover(coarse, gt_ds)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
Example #3
    def create_loss(self, gt, alpha):

        gt_ds = gt[:, :self.coarse.shape[1], :]
        loss_coarse = tf_util.earth_mover(self.coarse, gt_ds)
        tf_util.add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = tf_util.add_valid_summary('valid/coarse_loss',
                                                  loss_coarse)

        loss_fine = tf_util.chamfer(self.fine, gt)
        tf_util.add_train_summary('train/fine_loss', loss_fine)
        update_fine = tf_util.add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        tf_util.add_train_summary('train/loss', loss)
        update_loss = tf_util.add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
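Example #3 pulls the same helpers from `tf_util`, whose implementations are not shown. A minimal sketch consistent with how the returned update ops are used in the training loops below (per-batch update of a running mean, reset via `tf.local_variables_initializer()`) might be:

import tensorflow as tf

def add_train_summary(name, value):
    # scalar logged every training step, collected under 'train_summary'
    tf.summary.scalar(name, value, collections=['train_summary'])

def add_valid_summary(name, value):
    # running mean over the validation set; the returned update op is run once per batch
    avg, update = tf.metrics.mean(value)
    tf.summary.scalar(name, avg, collections=['valid_summary'])
    return update

Under this assumption, 'train_summary' scalars are logged per step, while 'valid_summary' scalars report an average accumulated over the evaluation loop.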
Example #4
    def create_loss(self, coarse, fine, gt, alpha):

        # print('coarse shape:', coarse.shape)
        # print('fine shape:', fine.shape)
        # print('gt shape:', gt.shape)

        loss_coarse = chamfer(coarse, gt)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
Example #5
def train(args):
	# with tf.Graph().as_default() as graph:
	# with tf.device('/gpu:'+str(args.gpu)):

	is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
	global_step = tf.Variable(0, trainable=False, name='global_step')
	alpha = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
										[0.01, 0.1, 0.5, 1.0], 'alpha_op')
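	# alpha weights the fine loss: 0.01 up to step 10K, 0.1 until 20K, 0.5 until 50K, then 1.0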

	# ModelNet provides a fixed number of input points per cloud
	# ShapeNet provides a varying number of input points per cloud
	inputs_pl = tf.placeholder(tf.float32, (1, BATCH_SIZE * NUM_POINT, 3), 'inputs')
	npts_pl = tf.placeholder(tf.int32, (BATCH_SIZE,), 'num_points')
	gt_pl = tf.placeholder(tf.float32, (BATCH_SIZE, args.num_gt_points, 3), 'ground_truths')
	add_train_summary('alpha', alpha)
	bn_decay = get_bn_decay(global_step)
	add_train_summary('bn_decay', bn_decay)

	model_module = importlib.import_module('.%s' % args.model_type, 'completion_models')
	model = model_module.Model(inputs_pl, npts_pl, gt_pl, alpha, bn_decay=bn_decay, is_training=is_training_pl)
	
	# Another Solution instead of importlib:
	# ldic = locals()
	# exec('from completion_models.%s import Model' % args.model_type, globals(), ldic)
	# model = ldic['Model'](inputs_pl, npts_pl, gt_pl, alpha, bn_decay=bn_decay, is_training=is_training_pl)

	if args.lr_decay:
		learning_rate = tf.train.exponential_decay(args.base_lr, global_step,
												   args.lr_decay_steps, args.lr_decay_rate,
												   staircase=True, name='lr')
		learning_rate = tf.maximum(learning_rate, args.lr_clip)
		add_train_summary('learning_rate', learning_rate)
	else:
		learning_rate = tf.constant(args.base_lr, name='lr')

	trainer = tf.train.AdamOptimizer(learning_rate)
	train_op = trainer.minimize(model.loss, global_step)
	# note: this setup differs from the schedule described in the paper:
	saver = tf.train.Saver(max_to_keep=10)
	''' from PCN paper:
	All our models are trained using the Adam optimizer with an initial learning rate of 0.0001 for 50 epochs
	and a batch size of 32. The learning rate is decayed by 0.7 every 50K iterations.
	'''
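	# a sketch (not part of the original script) of the schedule quoted above, using the paper's values:
	# learning_rate = tf.train.exponential_decay(1e-4, global_step,
	#                                            decay_steps=50000, decay_rate=0.7,
	#                                            staircase=True, name='lr_paper')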

	if args.store_grad:
		grads_and_vars = trainer.compute_gradients(model.loss)
		for g, v in grads_and_vars:
			tf.summary.histogram(v.name, v, collections=['train_summary'])
			tf.summary.histogram(v.name + '_grad', g, collections=['train_summary'])

	train_summary = tf.summary.merge_all('train_summary')
	valid_summary = tf.summary.merge_all('valid_summary')

	# the number of input points in the partially observed clouds is not fixed
	df_train, num_train = lmdb_dataflow(
		args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True)
	train_gen = df_train.get_data()
	df_valid, num_valid = lmdb_dataflow(
		args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False)
	valid_gen = df_valid.get_data()

	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True
	config.allow_soft_placement = True
	sess = tf.Session(config=config)
	# saver = tf.train.Saver()

	if args.restore:
		saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
		writer = tf.summary.FileWriter(args.log_dir)
	else:
		sess.run(tf.global_variables_initializer())
		if os.path.exists(args.log_dir):
			delete_key = input(colored('%s exists. Delete? [y/n]' % args.log_dir, 'white', 'on_red'))
			if delete_key == 'y' or delete_key == "yes":
				os.system('rm -rf %s/*' % args.log_dir)
				os.makedirs(os.path.join(args.log_dir, 'plots'))
		else:
			os.makedirs(os.path.join(args.log_dir, 'plots'))
		with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
			for arg in sorted(vars(args)):
				log.write(arg + ': ' + str(getattr(args, arg)) + '\n')
		os.system('cp completion_models/%s.py %s' % (args.model_type, args.log_dir))  # bkp of model scripts
		os.system('cp train_completion.py %s' % args.log_dir)  # bkp of train procedure
		writer = tf.summary.FileWriter(args.log_dir, sess.graph)  # GOOD habit

	log_fout = open(os.path.join(args.log_dir, 'log_train.txt'), 'a+')
	for arg in sorted(vars(args)):
		log_fout.write(arg + ': ' + str(getattr(args, arg)) + '\n')
		log_fout.flush()

	total_time = 0
	train_start = time.time()
	init_step = sess.run(global_step)

	for step in range(init_step + 1, args.max_step + 1):
		# Epoch: how many times the model has seen each sample
		# Step: how many batches have been drawn from the generator
		epoch = step * args.batch_size // num_train + 1
		ids, inputs, npts, gt = next(train_gen)
		if epoch > args.epoch:
			break
		if DATASET == 'shapenet8':
			inputs, npts = vary2fix(inputs, npts)

		start = time.time()
		feed_dict = {inputs_pl: inputs, npts_pl: npts, gt_pl: gt, is_training_pl: True}
		_, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict)
		total_time += time.time() - start
		writer.add_summary(summary, step)

		if step % args.steps_per_print == 0:
			print('epoch %d  step %d  loss %.8f - time per batch %.4f' %
				  (epoch, step, loss, total_time / args.steps_per_print))
			total_time = 0

		if step % args.steps_per_eval == 0:
			print(colored('Testing...', 'grey', 'on_green'))
			num_eval_steps = num_valid // args.batch_size
			total_loss = 0
			total_time = 0
			sess.run(tf.local_variables_initializer())
			for i in range(num_eval_steps):
				start = time.time()
				ids, inputs, npts, gt = next(valid_gen)
				if DATASET == 'shapenet8':
					inputs, npts = vary2fix(inputs, npts)
				feed_dict = {inputs_pl: inputs, npts_pl: npts, gt_pl: gt, is_training_pl: False}
				loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict)
				total_loss += loss
				total_time += time.time() - start
			summary = sess.run(valid_summary, feed_dict={is_training_pl: False})
			writer.add_summary(summary, step)
			print(colored('epoch %d  step %d  loss %.8f - time per batch %.4f' %
						  (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps),
						  'grey', 'on_green'))
			total_time = 0

		if step % args.steps_per_visu == 0:
			all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
			for i in range(0, args.batch_size, args.visu_freq):
				plot_path = os.path.join(args.log_dir, 'plots',
										 'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
				pcds = [x[i] for x in all_pcds]
				plot_pcd_three_views(plot_path, pcds, model.visualize_titles)
		# if step % args.steps_per_save == 0:
		if (epoch % args.epochs_per_save == 0) and \
				not os.path.exists(os.path.join(args.log_dir, 'model-%d.meta' % epoch)):
			saver.save(sess, os.path.join(args.log_dir, 'model'), epoch)
			print(colored('Epoch:%d, Model saved at %s' % (epoch, args.log_dir), 'white', 'on_blue'))

	print('Total time', datetime.timedelta(seconds=time.time() - train_start))
	sess.close()
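The `get_bn_decay` helper called near the top of this example is not shown. A plausible definition, following the batch-norm momentum schedule from the original PointNet training code (the module-level constants below are assumed placeholder values, not taken from this repository):

import tensorflow as tf

BN_INIT_DECAY = 0.5
BN_DECAY_RATE = 0.5
BN_DECAY_STEP = 200000
BN_DECAY_CLIP = 0.99
BATCH_SIZE = 32

def get_bn_decay(global_step):
    # exponentially decay the batch-norm momentum, then cap the resulting decay at BN_DECAY_CLIP
    bn_momentum = tf.train.exponential_decay(BN_INIT_DECAY,
                                             global_step * BATCH_SIZE,
                                             BN_DECAY_STEP, BN_DECAY_RATE,
                                             staircase=True)
    return tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)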
Example #6
def train(args):

    log_string('\n\n' + '=' * 50)
    log_string('Start Training, Time: %s' % datetime.datetime.now())
    log_string('=' * 50 + '\n\n')

    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(
        0, trainable=False,
        name='global_step')  # will be used in defining train_op
    inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
    labels_pl = tf.placeholder(tf.int32, (BATCH_SIZE, ), 'labels')
    npts_pl = tf.placeholder(tf.int32, (BATCH_SIZE, ), 'num_points')

    bn_decay = get_bn_decay(global_step, BN_INIT_DECAY, BATCH_SIZE,
                            BN_DECAY_STEP, BN_DECAY_RATE, BN_DECAY_CLIP)

    # model_module = importlib.import_module('.%s' % args.model, 'cls_models')
    # MODEL = model_module.Model(inputs_pl, npts_pl, labels_pl, is_training_pl, bn_decay=bn_decay)
    ''' === To fix issues when running on woma === '''
    ldic = locals()
    exec('from cls_models.%s import Model' % args.model, globals(), ldic)
    MODEL = ldic['Model'](inputs_pl,
                          npts_pl,
                          labels_pl,
                          is_training_pl,
                          bn_decay=bn_decay)
    pred, loss = MODEL.pred, MODEL.loss
    tf.summary.scalar('loss', loss)
    # pdb.set_trace()

    # useful information to display during training
    correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
    accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
    tf.summary.scalar('accuracy', accuracy)

    learning_rate = get_learning_rate(global_step, BASE_LR, BATCH_SIZE,
                                      DECAY_STEP, DECAY_RATE, LR_CLIP)
    add_train_summary('learning_rate', learning_rate)
    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(MODEL.loss, global_step)
    saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # config.log_device_placement = True
    sess = tf.Session(config=config)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                         sess.graph)
    val_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'val'))

    # Init variables
    init = tf.global_variables_initializer()
    log_string('\nModel Parameters have been Initialized\n')
    sess.run(init, {is_training_pl: True})  # restore will overwrite the randomly initialized parameters

    # to save the randomized variables
    if not args.restore and args.just_save:
        save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
        print(colored('randomly initialized model saved at %s' % save_path,
                      'white', 'on_blue'))
        print(colored('just saving the model, exiting now', 'white', 'on_red'))
        sys.exit()
    '''current solution: first load the pretrained head, assemble it with the output layers, then save as a checkpoint'''
    # to partially load the saved head from:
    # if args.load_pretrained_head:
    #   sess.close()
    #   load_pretrained_head(args.pretrained_head_path, os.path.join(LOG_DIR, 'model.ckpt'), None, args.verbose)
    #   print('shared variables have been restored from ', args.pretrained_head_path)
    #
    #   sess = tf.Session(config=config)
    #   log_string('\nModel Parameters has been Initialized\n')
    #   sess.run(init, {is_training_pl: True})
    #   saver.restore(sess, tf.train.latest_checkpoint(LOG_DIR))
    #   log_string('\nModel Parameters have been restored with pretrained weights from %s' % args.pretrained_head_path)

    if args.restore:
        # load_pretrained_var(args.restore_path, os.path.join(LOG_DIR, "model.ckpt"), args.verbose)
        saver.restore(sess, tf.train.latest_checkpoint(args.restore_path))
        log_string('\n')
        log_string(
            colored(
                'Model Parameters have been restored from %s' %
                args.restore_path, 'white', 'on_red'))

    for arg in sorted(vars(args)):
        print(arg + ': ' + str(getattr(args, arg)) + '\n')  # log of arguments
    os.system('cp cls_models/%s.py %s' %
              (args.model, LOG_DIR))  # bkp of model def
    os.system('cp train_cls.py %s' % LOG_DIR)  # bkp of train procedure

    train_start = time.time()

    ops = {
        'pointclouds_pl': inputs_pl,
        'labels_pl': labels_pl,
        'is_training_pl': is_training_pl,
        'npts_pl': npts_pl,
        'pred': pred,
        'loss': loss,
        'train_op': train_op,
        'merged': merged,
        'step': global_step
    }

    ESC = EarlyStoppingCriterion(patience=args.patience)

    for epoch in range(args.epoch):
        log_string('\n\n')
        log_string(colored('**** EPOCH %03d ****' % epoch, 'grey', 'on_green'))
        sys.stdout.flush()
        '''=== training the model ==='''
        train_one_epoch(sess, ops, train_writer)
        '''=== evaluating the model ==='''
        eval_mean_loss, eval_acc, eval_cls_acc = eval_one_epoch(
            sess, ops, val_writer)
        '''=== check whether to early stop ==='''
        early_stop, save_checkpoint = ESC.step(eval_acc, epoch=epoch)
        if save_checkpoint:
            save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
            log_string(
                colored('model saved at %s' % save_path, 'white', 'on_blue'))
        if early_stop:
            break

    log_string('total time: %s' %
               datetime.timedelta(seconds=time.time() - train_start))
    log_string('stop epoch: %d, best eval acc: %f' %
               (ESC.best_epoch + 1, ESC.best_dev_score))
    sess.close()
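`EarlyStoppingCriterion` is imported from elsewhere in the project and not shown. A minimal sketch matching the interface used above (`step()` returns an `(early_stop, save_checkpoint)` pair, and the object exposes `best_epoch` and `best_dev_score`), assuming higher validation accuracy is better:

class EarlyStoppingCriterion:

    def __init__(self, patience=10):
        self.patience = patience
        self.best_dev_score = float('-inf')
        self.best_epoch = -1
        self.num_bad_epochs = 0

    def step(self, dev_score, epoch):
        # returns (early_stop, save_checkpoint)
        if dev_score > self.best_dev_score:
            # new best score: reset the counter and ask the caller to checkpoint
            self.best_dev_score = dev_score
            self.best_epoch = epoch
            self.num_bad_epochs = 0
            return False, True
        self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience, False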