Example #1
File: itn.py Project: hmgoforth/pcn
    def compute_loss(self, inputs, gt_pose, est_pose):
        # see equation (1) from IT-Net
        # est_pose: world -> body coords
        # gt_pose: body -> world coords (so it is inverted before applying to inputs)
        est_inputs = transform_tf(inputs, est_pose)
        gt_inputs = transform_tf(inputs, tf.linalg.inv(gt_pose))
        sq_dist = tf.reduce_sum(tf.square(est_inputs - gt_inputs), axis=2)
        loss = tf.reduce_mean(tf.reduce_mean(sq_dist, axis=1), axis=0)

        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, update_loss
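The snippet depends on a transform_tf helper that is not shown in this listing. A minimal sketch of what such a helper could look like, assuming points is B x N x 3 and pose is a batch of 4 x 4 homogeneous transforms (the actual helper in hmgoforth/pcn may differ):

import tensorflow as tf

def transform_tf(points, pose):
    # points: B x N x 3, pose: B x 4 x 4 (assumed shapes)
    ones = tf.ones_like(points[:, :, :1])
    homog = tf.concat([points, ones], axis=2)               # B x N x 4
    transformed = tf.matmul(homog, pose, transpose_b=True)  # row-vector form of pose @ p
    return transformed[:, :, :3]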
Example #2
    def create_loss(self, coarse, fine, gt, alpha):
        loss_coarse = chamfer(coarse, gt)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine  # alpha ramps up the fine-loss weight as training progresses
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
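chamfer is provided by the repo's custom CUDA ops (Example #5 below calls tf_nndistance directly). A sketch of a symmetric Chamfer distance built on nn_distance, assuming the pc_distance module layout from the PCN repo; forks may average or weight the two directions differently:

import tensorflow as tf
from pc_distance import tf_nndistance  # custom CUDA op; import path is an assumption

def chamfer(pcd1, pcd2):
    # nn_distance returns squared nearest-neighbor distances in both directions
    dist1, _, dist2, _ = tf_nndistance.nn_distance(pcd1, pcd2)
    # mean root distance, averaged over both directions
    return (tf.reduce_mean(tf.sqrt(dist1)) + tf.reduce_mean(tf.sqrt(dist2))) / 2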
Example #3
    def create_loss(self, coarse, fine, gt, alpha):
        gt_ds = gt[:, :coarse.shape[1], :]  # EMD needs equal-sized clouds; truncate gt to the coarse size
        loss_coarse = earth_mover(coarse, gt_ds)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
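earth_mover is likewise a repo-provided op. A sketch of an EMD wrapper over an approximate-matching CUDA op, assuming the tf_approxmatch module shipped with the PCN repo (signatures may differ in other forks):

import tensorflow as tf
from pc_distance import tf_approxmatch  # custom CUDA op; import path is an assumption

def earth_mover(pcd1, pcd2):
    # approximate bipartite matching between two equal-sized clouds,
    # then the mean per-point transport cost
    num_points = tf.cast(tf.shape(pcd1)[1], tf.float32)
    match = tf_approxmatch.approx_match(pcd1, pcd2)
    cost = tf_approxmatch.match_cost(pcd1, pcd2, match)
    return tf.reduce_mean(cost / num_points)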
Example #4
    def create_loss(self, coarse_highres, coarse, fine, gt, theta):
        loss_coarse_highres = chamfer(coarse_highres, gt)

        loss_coarse = chamfer(coarse, gt)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = chamfer(fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        repulsion_loss = get_repulsion_loss4(coarse)

        # theta ramps up the fine term; the highres and repulsion weights are fixed
        loss = 0.5 * loss_coarse_highres + loss_coarse + theta * loss_fine + 0.2 * repulsion_loss
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, loss_fine, [update_coarse, update_fine, update_loss]
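get_repulsion_loss4 appears to come from the PU-Net codebase and is not shown here. A naive O(N^2) sketch of the same idea, penalizing points that sit closer than a radius h to their k nearest neighbors (hypothetical name and defaults, not the PU-Net implementation):

import tensorflow as tf

def repulsion_loss_sketch(pcd, k=4, h=0.03):
    # pcd: B x N x 3; squared pairwise distances via |a-b|^2 = |a|^2 - 2ab + |b|^2
    sq = tf.reduce_sum(pcd * pcd, axis=2, keepdims=True)  # B x N x 1
    dist = sq - 2 * tf.matmul(pcd, pcd, transpose_b=True) + tf.transpose(sq, [0, 2, 1])
    # the k+1 smallest entries per point include the point itself (distance 0); drop it
    neg_topk, _ = tf.nn.top_k(-dist, k=k + 1)
    knn_dist = tf.sqrt(tf.maximum(-neg_topk[:, :, 1:], 1e-12))  # B x N x k
    return tf.reduce_mean(tf.nn.relu(h - knn_dist))  # only neighbors closer than h contribute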
Example #5
    def create_loss(self, coarse, fine, gt, alpha):
        gt_ds = gt[:, :coarse.shape[1], :]
        loss_coarse = 10 * earth_mover(coarse[:, :, 0:3], gt_ds[:, :, 0:3])
        # retb[i][n] indexes the nearest point in gt_ds for coarse point n; distances are unused here
        _, retb, _, _ = tf_nndistance.nn_distance(coarse[:, :, 0:3], gt_ds[:, :, 0:3])
        for i in range(np.shape(gt_ds)[0]):
            index = tf.expand_dims(retb[i], -1)
            sem_feat = tf.nn.softmax(coarse[i, :, 3:], -1)
            sem_gt = tf.cast(
                tf.one_hot(
                    tf.gather_nd(tf.cast(gt_ds[i, :, 3] * 80 * 12, tf.int32),
                                 index), 12), tf.float32)
            loss_sem_coarse = tf.reduce_mean(-tf.reduce_sum(
                0.9 * sem_gt * tf.log(1e-6 + sem_feat) + (1 - 0.9) *
                (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
            loss_coarse += loss_sem_coarse
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = 10 * chamfer(fine[:, :, 0:3], gt[:, :, 0:3])
        _, retb, _, _ = tf_nndistance.nn_distance(fine[:, :, 0:3], gt[:, :, 0:3])
        for i in range(np.shape(gt)[0]):
            index = tf.expand_dims(retb[i], -1)
            sem_feat = tf.nn.softmax(fine[i, :, 3:], -1)
            sem_gt = tf.cast(
                tf.one_hot(
                    tf.gather_nd(tf.cast(gt[i, :, 3] * 80 * 12, tf.int32),
                                 index), 12), tf.float32)
            loss_sem_fine = tf.reduce_mean(-tf.reduce_sum(
                0.9 * sem_gt * tf.log(1e-6 + sem_feat) + (1 - 0.9) *
                (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
            loss_fine += loss_sem_fine
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
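The per-sample Python loops above rebuild the same subgraph for every batch element. A vectorized sketch of the coarse branch, reusing the snippet's coarse, gt_ds, and retb tensors; it assumes a TF version where tf.gather supports batch_dims (otherwise tf.batch_gather), and keeps the original label encoding gt_ds[:, :, 3] * 80 * 12 as-is:

import tensorflow as tf

sem_feat = tf.nn.softmax(coarse[:, :, 3:], axis=-1)           # B x N x 12
gt_labels = tf.cast(gt_ds[:, :, 3] * 80 * 12, tf.int32)      # B x M
nn_labels = tf.gather(gt_labels, retb, batch_dims=1)          # nearest-gt label per coarse point
sem_gt = tf.one_hot(nn_labels, 12, dtype=tf.float32)
ce = -(0.9 * sem_gt * tf.log(1e-6 + sem_feat)
       + 0.1 * (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat))
# sum over classes, mean over points, sum over the batch -- matching the loop's accumulation
loss_sem_coarse = tf.reduce_sum(tf.reduce_mean(tf.reduce_sum(ce, axis=2), axis=1))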
Example #6
File: fc.py Project: mihaibujanca/pcn
    def create_loss(self, outputs, gt):
        loss = chamfer(outputs, gt)
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)
        return loss, update_loss
Example #7
def train(args):
    min_loss_fine = 1.0  # best validation fine loss seen so far; used to decide when to checkpoint
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    alpha = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'alpha_op')

    provider = TrainProvider(args, is_training_pl)
    ids, inputs, gt = provider.batch_data
    num_eval_steps = provider.num_valid // args.batch_size

    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, alpha, is_training_pl)
    add_train_summary('alpha', alpha)

    if args.lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr,
                                                   global_step,
                                                   args.lr_decay_steps,
                                                   args.lr_decay_rate,
                                                   staircase=True,
                                                   name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')

    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=15)
    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
        #saver.restore(sess, 'data/trained_models/pcn_cd')
    else:
        if os.path.exists(args.log_dir):
            delete_key = input(
                colored('%s exists. Delete? [y (or enter)/N]' % args.log_dir,
                        'white', 'on_red'))
            if delete_key == 'y' or delete_key == "":
                os.system('rm -rf %s/*' % args.log_dir)
                os.makedirs(os.path.join(args.log_dir, 'plots'))
        else:
            os.makedirs(os.path.join(args.log_dir, 'plots'))
        with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) +
                          '\n')  # log of arguments
        os.system('cp models/%s.py %s' %
                  (args.model_type, args.log_dir))  # bkp of model def
        os.system('cp train.py %s' % args.log_dir)  # bkp of train procedure

    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')
    writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    total_time = 0
    train_start = time.time()
    step = sess.run(global_step)
    while not coord.should_stop():
        step += 1
        epoch = step * args.batch_size // provider.num_train + 1
        start = time.time()
        _, loss, loss_fine, summary = sess.run(
            [train_op, model.loss, model.loss_fine, train_summary],
            feed_dict={is_training_pl: True})
        total_time += time.time() - start
        writer.add_summary(summary, step)
        if step % args.steps_per_print == 0:
            print(
                'epoch %d  step %d  loss %.8f loss_fine %.8f - time per batch %.4f'
                % (epoch, step, loss, loss_fine,
                   total_time / args.steps_per_print))
            total_time = 0
        # evaluate 10x less frequently during the first 100k steps
        if step < 100000:
            steps_per_eval = args.steps_per_eval * 10
        else:
            steps_per_eval = args.steps_per_eval
        if step % steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            total_loss = 0
            total_time = 0
            total_loss_fine = 0
            sess.run(tf.local_variables_initializer())
            for i in range(num_eval_steps):
                start = time.time()
                loss, loss_fine, _ = sess.run(
                    [model.loss, model.loss_fine, model.update],
                    feed_dict={is_training_pl: False})
                total_loss += loss
                total_loss_fine += loss_fine
                total_time += time.time() - start
            summary = sess.run(valid_summary,
                               feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(
                colored(
                    'epoch %d  step %d  loss %.8f loss_fine %.8f - time per batch %.4f'
                    % (epoch, step, total_loss / num_eval_steps,
                       total_loss_fine / num_eval_steps,
                       total_time / num_eval_steps), 'grey', 'on_green'))
            total_time = 0
            if (total_loss_fine / num_eval_steps) < min_loss_fine:
                min_loss_fine = total_loss_fine / num_eval_steps
                saver.save(sess, os.path.join(args.log_dir, 'model'), step)
                print(
                    colored('Model saved at %s' % args.log_dir, 'white',
                            'on_blue'))
        if step % args.steps_per_visu == 0:
            model_id, pcds = sess.run([ids[0], model.visualize_ops],
                                      feed_dict={is_training_pl: True})
            model_id = model_id.decode('utf-8')
            plot_path = os.path.join(
                args.log_dir, 'plots',
                'epoch_%d_step_%d_%s.png' % (epoch, step, model_id))
            plot_pcd_three_views(plot_path, pcds, model.visualize_titles)
        #if step % args.steps_per_save == 0:
        #    saver.save(sess, os.path.join(args.log_dir, 'model'), step)
        #    print(colored('Model saved at %s' % args.log_dir, 'white', 'on_blue'))
        if step >= args.max_step:
            break
    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    coord.request_stop()
    coord.join(threads)
    sess.close()
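For reference, tf.train.piecewise_constant keeps each value up to and including its boundary step. A quick standalone check of the alpha schedule used above (illustrative snippet, not part of the original file):

import tensorflow as tf

step_pl = tf.placeholder(tf.int32, shape=())
alpha = tf.train.piecewise_constant(step_pl, [10000, 20000, 50000],
                                    [0.01, 0.1, 0.5, 1.0])
with tf.Session() as sess:
    for s in [0, 10000, 10001, 20001, 50001]:
        print(s, sess.run(alpha, feed_dict={step_pl: s}))
# 0 -> 0.01, 10000 -> 0.01, 10001 -> 0.1, 20001 -> 0.5, 50001 -> 1.0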
Example #8
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    alpha = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
    npts_pl = tf.placeholder(tf.int32, (args.batch_size, ), 'num_points')
    gt_pl = tf.placeholder(tf.float32,
                           (args.batch_size, args.num_gt_points, 3),
                           'ground_truths')

    model_module = importlib.import_module('.%s' % args.model_type, 'models')

    model = model_module.Model(inputs_pl, npts_pl, gt_pl, alpha)
    add_train_summary('alpha', alpha)

    if args.lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr,
                                                   global_step,
                                                   args.lr_decay_steps,
                                                   args.lr_decay_rate,
                                                   staircase=True,
                                                   name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')
    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')

    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)

    df_train, num_train = lmdb_dataflow(args.lmdb_train,
                                        args.batch_size,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=True)
    train_gen = df_train.get_data()
    df_valid, num_valid = lmdb_dataflow(args.lmdb_valid,
                                        args.batch_size,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=False)
    valid_gen = df_valid.get_data()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()

    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
        writer = tf.summary.FileWriter(args.log_dir)
    else:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(args.log_dir):
            delete_key = input(
                colored('%s exists. Delete? [y (or enter)/N]' % args.log_dir,
                        'white', 'on_red'))
            if delete_key == 'y' or delete_key == "":
                os.system('rm -rf %s/*' % args.log_dir)
                os.makedirs(os.path.join(args.log_dir, 'plots'))
        else:
            os.makedirs(os.path.join(args.log_dir, 'plots'))
        with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) +
                          '\n')  # log of arguments
        os.system('cp models/%s.py %s' %
                  (args.model_type, args.log_dir))  # bkp of model def
        os.system('cp train.py %s' % args.log_dir)  # bkp of train procedure
        writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    total_time = 0
    train_start = time.time()
    init_step = sess.run(global_step)
    for step in range(init_step + 1, args.max_step + 1):
        epoch = step * args.batch_size // num_train + 1
        ids, inputs, npts, gt = next(train_gen)
        start = time.time()
        feed_dict = {
            inputs_pl: inputs,
            npts_pl: npts,
            gt_pl: gt,
            is_training_pl: True
        }
        _, loss, summary = sess.run([train_op, model.loss, train_summary],
                                    feed_dict=feed_dict)
        total_time += time.time() - start
        writer.add_summary(summary, step)
        if step % args.steps_per_print == 0:
            print('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                  (epoch, step, loss, total_time / args.steps_per_print))
            total_time = 0
        if step % args.steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            num_eval_steps = num_valid // args.batch_size
            total_loss = 0
            total_time = 0
            sess.run(tf.local_variables_initializer())
            for i in range(num_eval_steps):
                start = time.time()
                ids, inputs, npts, gt = next(valid_gen)
                feed_dict = {
                    inputs_pl: inputs,
                    npts_pl: npts,
                    gt_pl: gt,
                    is_training_pl: False
                }
                loss, _ = sess.run([model.loss, model.update],
                                   feed_dict=feed_dict)
                total_loss += loss
                total_time += time.time() - start
            summary = sess.run(valid_summary,
                               feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(
                colored(
                    'epoch %d  step %d  loss %.8f - time per batch %.4f' %
                    (epoch, step, total_loss / num_eval_steps,
                     total_time / num_eval_steps), 'grey', 'on_green'))
            total_time = 0
            if step % args.steps_per_visu == 0:
                all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
                for i in range(0, args.batch_size, args.visu_freq):
                    plot_path = os.path.join(
                        args.log_dir, 'plots',
                        'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
                    pcds = [x[i] for x in all_pcds]
                    plot_pcd_three_views(plot_path, pcds,
                                         model.visualize_titles)
        if step % args.steps_per_save == 0:
            saver.save(sess, os.path.join(args.log_dir, 'model'), step)
            print(
                colored('Model saved at %s' % args.log_dir, 'white',
                        'on_blue'))

    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    sess.close()
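Unlike Example #7, this variant feeds data through placeholders. inputs_pl has shape (1, None, 3) because variable-size partial clouds are packed into one concatenated cloud, with npts_pl holding each cloud's length. A sketch of that packing with hypothetical data (not the lmdb_dataflow internals):

import numpy as np

clouds = [np.random.rand(n, 3).astype(np.float32) for n in (2048, 1531, 987, 2048)]
inputs = np.concatenate(clouds, axis=0)[None, ...]        # 1 x sum(npts) x 3
npts = np.array([c.shape[0] for c in clouds], np.int32)   # per-cloud point counts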
Example #9
File: train_itn.py Project: hmgoforth/pcn
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    provider = TrainProvider(args, is_training_pl)

    ids, inputs, gt_pose = provider.batch_data
    num_eval_steps = provider.num_valid // args.batch_size

    model = itn.ITN(inputs, gt_pose, args.iterations,
                    args.validation_iterations, args.no_batchnorm,
                    args.rot_representation, is_training_pl)

    if not args.no_lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr,
                                                   global_step,
                                                   args.lr_decay_steps,
                                                   args.lr_decay_rate,
                                                   staircase=True,
                                                   name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')

    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    saver = tf.train.Saver()

    # timestamped log dir: log_dir_HH_MM_DD_MM_YYYY (US/Pacific)
    now = datetime.datetime.now(pytz.timezone('US/Pacific'))
    log_dir = args.log_dir + now.strftime('_%H_%M_%d_%m_%Y')

    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(log_dir))
    else:
        if os.path.exists(log_dir):
            delete_key = input(
                colored('%s exists. Delete? [y (or enter)/N]' % log_dir,
                        'white', 'on_red'))
            if delete_key == 'y' or delete_key == "":
                os.system('rm -rf %s' % log_dir)
                os.makedirs(log_dir)
                os.makedirs(os.path.join(log_dir, 'plots'))
        else:
            os.makedirs(os.path.join(log_dir, 'plots'))
        with open(os.path.join(log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) +
                          '\n')  # log of arguments
        os.system('cp models/itn.py %s' % (log_dir))  # bkp of model def
        os.system('cp train.py %s' % log_dir)  # bkp of train procedure

    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')
    writer = tf.summary.FileWriter(log_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    total_time = 0
    train_start = time.time()
    step = sess.run(global_step)

    while not coord.should_stop():
        step += 1
        epoch = step * args.batch_size // provider.num_train + 1
        start = time.time()
        __, loss, summary = sess.run([train_op, model.loss, train_summary],
                                     feed_dict={is_training_pl: True})
        total_time += time.time() - start
        writer.add_summary(summary, step)

        if step % args.steps_per_print == 0:
            print('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                  (epoch, step, loss, total_time / args.steps_per_print))
            total_time = 0

        if step % args.steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            total_loss = 0
            total_time = 0
            sess.run(tf.local_variables_initializer())
            for i in range(num_eval_steps):
                start = time.time()
                loss, _ = sess.run([model.loss, model.update],
                                   feed_dict={is_training_pl: False})
                total_loss += loss
                total_time += time.time() - start
            summary = sess.run(valid_summary,
                               feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(
                colored(
                    'epoch %d  step %d  loss %.8f - time per batch %.4f' %
                    (epoch, step, total_loss / num_eval_steps,
                     total_time / num_eval_steps), 'grey', 'on_green'))
            total_time = 0

        if step % args.steps_per_visu == 0:
            model_id, pcds = sess.run([ids[0], model.visualize_ops],
                                      feed_dict={is_training_pl: True})
            model_id = model_id.decode('utf-8')
            plot_path = os.path.join(
                log_dir, 'plots',
                'epoch_%d_step_%d_%s.png' % (epoch, step, model_id))
            plot_pcd_three_views(plot_path, pcds, model.visualize_titles)

        if step % args.steps_per_save == 0:
            saver.save(sess, os.path.join(log_dir, 'model'), step)
            print(colored('Model saved at %s' % log_dir, 'white', 'on_blue'))

        if step >= args.max_step:
            break

    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    coord.request_stop()
    coord.join(threads)
    sess.close()
Example #10
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')

    # Note that theta is a parameter used for progressive training
    theta = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'theta_op')

    provider = TrainProvider(args, is_training_pl)
    ids, inputs, gt = provider.batch_data
    num_eval_steps = provider.num_test // args.batch_size

    print('provider.num_test', provider.num_test)
    print('num_eval_steps', num_eval_steps)

    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, theta, False)
    add_train_summary('alpha', theta)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=10)
    saver.restore(sess, args.model_path)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    start_time = time.time()
    while not coord.should_stop():
        print(colored('Testing...', 'grey', 'on_green'))
        total_time = 0
        total_loss_fine = 0
        cd_per_cat = {}
        sess.run(tf.local_variables_initializer())
        for j in range(num_eval_steps):
            start = time.time()
            ids_eval, inputs_eval, gt_eval, loss_fine, fine = sess.run(
                [ids, inputs, gt, model.loss_fine, model.fine],
                feed_dict={is_training_pl: False})
            # ids are byte strings like b'synset_model'; recover the synset id
            synset_id = str(ids_eval[0]).split('_')[0].split('\'')[1]
            total_loss_fine += loss_fine
            total_time += time.time() - start

            if not cd_per_cat.get(synset_id):
                cd_per_cat[synset_id] = []
            cd_per_cat[synset_id].append(loss_fine)

            if args.plot:
                for i in range(args.batch_size):
                    model_id = str(ids_eval[i]).split('_')[1]
                    os.makedirs(os.path.join(args.save_path, 'plots',
                                             synset_id),
                                exist_ok=True)
                    plot_path = os.path.join(args.save_path, 'plots',
                                             synset_id, '%s.png' % model_id)
                    plot_pcd_three_views(plot_path,
                                         [inputs_eval[i], fine[i], gt_eval[i]],
                                         ['input', 'output', 'ground truth'],
                                         'CD %.4f' % (loss_fine),
                                         [0.5, 0.5, 0.5])
        print('Average Chamfer distance: %f' %
              (total_loss_fine / num_eval_steps))
        print('Chamfer distance per category')
        dict_novel = {
            '02924116': 'bus',
            '02818832': 'bed',
            '02871439': 'bookshelf',
            '02828884': 'bench',
            '03467517': 'guitar',
            '03790512': 'motorbike',
            '04225987': 'skateboard',
            '03948459': 'pistol'
        }
        temp_loss = 0
        for synset_id in dict_novel.keys():
            temp_loss += np.mean(cd_per_cat[synset_id])
            print(dict_novel[synset_id],
                  ' %f' % np.mean(cd_per_cat[synset_id]))
        break
    print('Total time', datetime.timedelta(seconds=time.time() - start_time))
    coord.request_stop()
    coord.join(threads)
    sess.close()
Example #11
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')

    # Note that theta is a parameter used for progressive training
    theta = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'theta_op')

    provider = TrainProvider(args, is_training_pl)
    ids, inputs, gt = provider.batch_data
    num_eval_steps = provider.num_test // args.batch_size

    print('provider.num_test', provider.num_test)
    print('num_eval_steps', num_eval_steps)

    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, theta, False)
    add_train_summary('alpha', theta)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=10)
    saver.restore(sess, args.model_path)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    start_time = time.time()
    while not coord.should_stop():
        print(colored('Testing...', 'grey', 'on_green'))
        total_time = 0
        total_fscore = 0
        fscore_per_cat = {}
        sess.run(tf.local_variables_initializer())
        for j in range(num_eval_steps):
            start = time.time()
            ids_eval, inputs_eval, gt_eval, loss_fine, fine = sess.run(
                [ids, inputs, gt, model.loss_fine, model.fine],
                feed_dict={is_training_pl: False})
            # assumes batch_size == 1, so squeeze() yields an N x 3 array
            pc_gt = open3d.geometry.PointCloud()
            pc_pr = open3d.geometry.PointCloud()
            pc_gt.points = open3d.utility.Vector3dVector(np.squeeze(gt_eval))
            pc_pr.points = open3d.utility.Vector3dVector(np.squeeze(fine))

            f_score = calculate_fscore(pc_gt, pc_pr)
            # print('f_score:', f_score)

            synset_id = str(ids_eval[0]).split('_')[0].split('\'')[1]
            total_fscore += f_score
            total_time += time.time() - start

            if not fscore_per_cat.get(synset_id):
                fscore_per_cat[synset_id] = []
            fscore_per_cat[synset_id].append(f_score)

            # if args.plot:
            #     for i in range(args.batch_size):
            #         model_id = str(ids_eval[i]).split('_')[1]
            #         os.makedirs(os.path.join(args.save_path, 'plots', synset_id), exist_ok=True)
            #         plot_path = os.path.join(args.save_path, 'plots', synset_id, '%s.png' % model_id)
            #         plot_pcd_three_views(plot_path, [inputs_eval[i], fine[i], gt_eval[i]],
            #                              ['input', 'output', 'ground truth'],
            #                              'CD %.4f' % (loss_fine),
            #                              [0.5, 0.5, 0.5])
        print('Average F-score: %f' % (total_fscore / num_eval_steps))
        print('F-score per category')
        dict_known = {
            '02691156': 'airplane',
            '02933112': 'cabinet',
            '02958343': 'car',
            '03001627': 'chair',
            '03636649': 'lamp',
            '04256520': 'sofa',
            '04379243': 'table',
            '04530566': 'vessel'
        }
        temp_loss = 0
        for synset_id in dict_known.keys():
            temp_loss += np.mean(fscore_per_cat[synset_id])
            print(dict_known[synset_id],
                  ' %f' % np.mean(fscore_per_cat[synset_id]))
        break
    print('Total time', datetime.timedelta(seconds=time.time() - start_time))
    coord.request_stop()
    coord.join(threads)
    sess.close()
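calculate_fscore is not shown in this listing. A sketch of the usual point-cloud F-score at a fixed distance threshold, using Open3D's point-to-cloud distances (the threshold default is an assumption):

import open3d

def calculate_fscore(gt, pr, th=0.01):
    # distances from each gt point to the prediction, and vice versa
    d1 = gt.compute_point_cloud_distance(pr)
    d2 = pr.compute_point_cloud_distance(gt)
    precision = float(sum(d < th for d in d2)) / max(len(d2), 1)  # predicted points near gt
    recall = float(sum(d < th for d in d1)) / max(len(d1), 1)     # gt points covered by prediction
    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0.0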
Example #12
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')

    # Note that theta is a parameter used for progressive training
    theta = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'theta_op')

    provider = TrainProvider(args, is_training_pl)
    ids, inputs, gt = provider.batch_data
    num_eval_steps = provider.num_valid // args.batch_size

    print('provider.num_valid', provider.num_valid)
    print('num_eval_steps', num_eval_steps)

    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, theta, False)
    add_train_summary('alpha', theta)

    # [new] to output pcds
    # out_path = '/mnt/data3/zwx/results_pcds'
    # f_out_pcd = h5py.File( os.path.join(out_path, 'SFA_point_comp_pcds.h5'), 'w')
    # g_output_pcd = f_out_pcd.create_group("output")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=10)
    saver.restore(sess, args.model_path)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    start_time = time.time()
    while not coord.should_stop():
        print(colored('Testing...', 'grey', 'on_green'))
        total_time = 0
        total_loss_fine = 0
        cd_per_cat = {}
        sess.run(tf.local_variables_initializer())
        for j in range(num_eval_steps):
            start = time.time()
            ids_eval, inputs_eval, gt_eval, loss_fine, fine = sess.run(
                [ids, inputs, gt, model.loss_fine, model.fine],
                feed_dict={is_training_pl: False})
            synset_id = str(ids_eval[0]).split('_')[0].split('\'')[1]
            total_loss_fine += loss_fine
            total_time += time.time() - start

            if not cd_per_cat.get(synset_id):
                cd_per_cat[synset_id] = []
            cd_per_cat[synset_id].append(loss_fine)

            # [new] to output pcds
            # for i in range(args.batch_size):
            #     ind = args.batch_size * j + i
            #     g_output_pcd[f"{ind}"] = fine[i]

            # print('ids : ', str(ids_eval[0]), synset_id)

            out_rel = str(ids_eval[0]).split('\'')[1].replace('_', '/')
            ofile = os.path.join('benchmark', out_rel + '.h5')
            os.makedirs(os.path.dirname(ofile), exist_ok=True)
            # fname = clouds_data[0][idx][:clouds_data[0][idx].rfind('.')]
            # synset = fname.split('/')[-2]
            # outp = outputs[idx:idx + 1, ...].squeeze()
            # odir = args.odir + '/benchmark/%s' % (synset)
            # if not os.path.isdir(odir):
            #     print("Creating %s ..." % (odir))
            #     os.makedirs(odir)
            # ofile = os.path.join(odir, fname.split('/')[-1])
            print("Saving to %s ..." % (ofile))
            print(fine.shape)
            with h5py.File(ofile, "w") as f:
                f.create_dataset("data", data=np.squeeze(fine))

            # if args.plot:
            #     for i in range(args.batch_size):
            #         model_id = str(ids_eval[i]).split('_')[1]
            #         os.makedirs(os.path.join(args.save_path, 'plots', synset_id), exist_ok=True)
            #         plot_path = os.path.join(args.save_path, 'plots', synset_id, '%s.png' % model_id)
            #         plot_pcd_three_views(plot_path, [inputs_eval[i], fine[i], gt_eval[i]],
            #                              ['input', 'output', 'ground truth'],
            #                              'CD %.4f' % (loss_fine),
            #                              [0.5, 0.5, 0.5])
        subprocess.call('zip -r submission.zip *', cwd='benchmark', shell=True)
        print('zip file generated.')
        print('Average Chamfer distance: %f' %
              (total_loss_fine / num_eval_steps))
        # print(colored('epoch %d  step %d  loss %.8f loss_fine %.8f - time per batch %.4f' %
        #               (epoch, step, total_loss / num_eval_steps, total_loss_fine / num_eval_steps, total_time / num_eval_steps),
        #               'grey', 'on_green'))
        print('Chamfer distance per category')
        dict_known = {
            '02691156': 'airplane',
            '02933112': 'cabinet',
            '02958343': 'car',
            '03001627': 'chair',
            '03636649': 'lamp',
            '04256520': 'sofa',
            '04379243': 'table',
            '04530566': 'vessel'
        }
        dict_novel = {
            '02924116': 'bus',
            '02818832': 'bed',
            '02871439': 'bookshelf',
            '02828884': 'bench',
            '03467517': 'guitar',
            '03790512': 'motorbike',
            '04225987': 'skateboard',
            '03948459': 'pistol'
        }
        dict_known_list = [
            '02691156', '02933112', '02958343', '03001627', '03636649',
            '04256520', '04379243', '04530566'
        ]
        dict_novel_list = [
            '02924116', '02818832', '02871439', '02828884', '03467517',
            '03790512', '04225987', '03948459'
        ]
        temp_loss = 0
        for synset_id in dict_known_list:
            # print(len(dict_novel_list[:4]))
            temp_loss += np.mean(cd_per_cat[synset_id])
            print('%f' % np.mean(cd_per_cat[synset_id]), '&', end='')
        print(temp_loss / len(dict_known_list))  # mean CD over the 8 known categories
        # temp_loss=0
        # for synset_id in cd_per_cat.keys():
        #     temp_loss += np.mean(cd_per_cat[synset_id])
        #     print(dict[synset_id], '%f' % np.mean(cd_per_cat[synset_id]))
        break
    print('Total time', datetime.timedelta(seconds=time.time() - start_time))
    coord.request_stop()
    coord.join(threads)
    sess.close()

    # f_out_pcd.close()  # belongs to the commented-out h5py output block above