Example #1
def __init__(self, args, is_training):
    df_train, self.num_train = lmdb_dataflow(args.lmdb_train, args.batch_size,
                                             args.num_input_points, args.num_gt_points,
                                             is_training=True)
    batch_train = get_queued_data(df_train.get_data(), [tf.string, tf.float32, tf.float32],
                                  [[args.batch_size],
                                   [args.batch_size, args.num_input_points, 3],
                                   [args.batch_size, args.num_gt_points, 3]])
    df_valid, self.num_valid = lmdb_dataflow(args.lmdb_valid, args.batch_size,
                                             args.num_input_points, args.num_gt_points,
                                             is_training=False)
    batch_valid = get_queued_data(df_valid.get_data(), [tf.string, tf.float32, tf.float32],
                                  [[args.batch_size],
                                   [args.batch_size, args.num_input_points, 3],
                                   [args.batch_size, args.num_gt_points, 3]])
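    # Note: batch_train and batch_valid are created outside the lambdas, so
    # TF1's tf.cond only selects which tensor is returned; both queues are
    # still dequeued on every run of batch_data.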
    self.batch_data = tf.cond(is_training, lambda: batch_train, lambda: batch_valid)
Example #2
def __init__(self, args, is_training):
    df_test, self.num_test = lmdb_dataflow(args.lmdb_test,
                                           args.batch_size,
                                           args.num_input_points,
                                           args.num_gt_points,
                                           is_training=False)
    batch_test = get_queued_data(
        df_test.get_data(), [tf.string, tf.float32, tf.float32],
        [[args.batch_size], [args.batch_size, args.num_input_points, 3],
         [args.batch_size, args.num_gt_points, 3]])
    self.batch_data = batch_test
Example #3
    avg_acc, acc_update = tf.metrics.accuracy(labels_pl, prediction)
    avg_iou, iou_update = tf.metrics.mean_iou(labels_pl, prediction, n_parts)
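    # Note: tf.metrics.accuracy / mean_iou accumulate in local variables; run
    # tf.local_variables_initializer() before each evaluation pass so the
    # streaming averages start from zero (as Example #4 does).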

    tf.summary.scalar('train/learning rate', learning_rate, collections=['train'])
    tf.summary.scalar('train/gradient norm', global_norm, collections=['train'])
    tf.summary.scalar('train/bn decay', bn_decay, collections=['train'])
    tf.summary.scalar('valid/loss', avg_loss, collections=['valid'])
    tf.summary.scalar('valid/accuracy', avg_acc, collections=['valid'])
    tf.summary.scalar('valid/mean iou', avg_iou, collections=['valid'])
    update_ops = [loss_update, acc_update, iou_update]
    train_summary = tf.summary.merge_all('train')
    valid_summary = tf.summary.merge_all('valid')

    lmdb_train = os.path.join(args.data_dir, 'train.lmdb')
    lmdb_valid = os.path.join(args.data_dir, 'valid.lmdb')
    df_train, num_train = lmdb_dataflow(lmdb_train, args.batch_size, args.num_points,
                                        shuffle=True, render=True, task='seg')
    df_valid, num_valid = lmdb_dataflow(lmdb_valid, args.batch_size, args.num_points,
                                        shuffle=False, task='seg')
    train_gen = df_train.get_data()
    valid_gen = df_valid.get_data()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())  # reset the tf.metrics accumulators

    writer = create_log_dir(args, sess)

    saver = tf.train.Saver()
    step = sess.run(global_step)
Example #4
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
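    # alpha is stepped up during training; in PCN it weights the fine
    # reconstruction loss against the coarse one.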
    alpha = tf.train.piecewise_constant(global_step, [10000, 20000, 50000],
                                        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
    npts_pl = tf.placeholder(tf.int32, (args.batch_size, ), 'num_points')
    gt_pl = tf.placeholder(tf.float32,
                           (args.batch_size, args.num_gt_points, 3),
                           'ground_truths')
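    # The variable-size partial clouds of a batch are packed into one
    # (1, sum(npts), 3) tensor; npts_pl records each sample's point count.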

    model_module = importlib.import_module('.%s' % args.model_type, 'models')

    model = model_module.Model(inputs_pl, npts_pl, gt_pl, alpha)
    add_train_summary('alpha', alpha)

    if args.lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr,
                                                   global_step,
                                                   args.lr_decay_steps,
                                                   args.lr_decay_rate,
                                                   staircase=True,
                                                   name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')
    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')

    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)

    df_train, num_train = lmdb_dataflow(args.lmdb_train,
                                        args.batch_size,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=True)
    train_gen = df_train.get_data()
    df_valid, num_valid = lmdb_dataflow(args.lmdb_valid,
                                        args.batch_size,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=False)
    valid_gen = df_valid.get_data()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()

    if args.restore:
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
        writer = tf.summary.FileWriter(args.log_dir)
    else:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(args.log_dir):
            delete_key = input(
                colored('%s exists. Delete? [y (or enter)/N]' % args.log_dir,
                        'white', 'on_red'))
            if delete_key == 'y' or delete_key == "":
                os.system('rm -rf %s/*' % args.log_dir)
                os.makedirs(os.path.join(args.log_dir, 'plots'))
        else:
            os.makedirs(os.path.join(args.log_dir, 'plots'))
        with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) +
                          '\n')  # log of arguments
        os.system('cp models/%s.py %s' %
                  (args.model_type, args.log_dir))  # backup of the model definition
        os.system('cp train.py %s' % args.log_dir)  # backup of the training script
        writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    total_time = 0
    train_start = time.time()
    init_step = sess.run(global_step)
    for step in range(init_step + 1, args.max_step + 1):
        epoch = step * args.batch_size // num_train + 1
        ids, inputs, npts, gt = next(train_gen)
        start = time.time()
        feed_dict = {
            inputs_pl: inputs,
            npts_pl: npts,
            gt_pl: gt,
            is_training_pl: True
        }
        _, loss, summary = sess.run([train_op, model.loss, train_summary],
                                    feed_dict=feed_dict)
        total_time += time.time() - start
        writer.add_summary(summary, step)
        if step % args.steps_per_print == 0:
            print('epoch %d  step %d  loss %.8f - time per batch %.4f' %
                  (epoch, step, loss, total_time / args.steps_per_print))
            total_time = 0
        if step % args.steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            num_eval_steps = num_valid // args.batch_size
            total_loss = 0
            total_time = 0
            sess.run(tf.local_variables_initializer())
            for i in range(num_eval_steps):
                start = time.time()
                ids, inputs, npts, gt = next(valid_gen)
                feed_dict = {
                    inputs_pl: inputs,
                    npts_pl: npts,
                    gt_pl: gt,
                    is_training_pl: False
                }
                loss, _ = sess.run([model.loss, model.update],
                                   feed_dict=feed_dict)
                total_loss += loss
                total_time += time.time() - start
            summary = sess.run(valid_summary,
                               feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(
                colored(
                    'epoch %d  step %d  loss %.8f - time per batch %.4f' %
                    (epoch, step, total_loss / num_eval_steps,
                     total_time / num_eval_steps), 'grey', 'on_green'))
            total_time = 0
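            # Nested inside the eval block, so plots are produced only on steps
            # that are multiples of both steps_per_eval and steps_per_visu.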
            if step % args.steps_per_visu == 0:
                all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
                for i in range(0, args.batch_size, args.visu_freq):
                    plot_path = os.path.join(
                        args.log_dir, 'plots',
                        'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
                    pcds = [x[i] for x in all_pcds]
                    plot_pcd_three_views(plot_path, pcds,
                                         model.visualize_titles)
        if step % args.steps_per_save == 0:
            saver.save(sess, os.path.join(args.log_dir, 'model'), step)
            print(
                colored('Model saved at %s' % args.log_dir, 'white',
                        'on_blue'))

    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    sess.close()
Example #5
def get_data(args):

    # Use a batch size of exactly 1.
    # NOTE: is_training is deliberately set to False; with is_training=True the
    # dataflow yielded duplicate ids within an epoch (cause unknown).
    df_train, num_train = lmdb_dataflow(args.lmdb_train,
                                        1,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=False)
    train_gen = df_train.get_data()
    df_valid, num_valid = lmdb_dataflow(args.lmdb_valid,
                                        1,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=False)
    valid_gen = df_valid.get_data()

    print('==================================================================')
    print('We have {} train files and {} val files'.format(
        num_train, num_valid))
    print('==================================================================')
    ###########################################################################

    # SAVE TRAINING
    id_train_list = []
    for step in tqdm(range(num_train)):
        ids, inputs, npts, gt = next(train_gen)

        # Sanity check
        assert inputs.shape[1] == npts[0], 'number of points does not match'
        # Save input
        with open(
                os.path.join(SAVE_INPUT_PATH_TRAIN, '{:08d}.npy'.format(step)),
                'wb') as f:
            np.save(f, inputs.reshape(-1, 3))
        # Save gt
        with open(os.path.join(SAVE_GT_PATH_TRAIN, '{:08d}.npy'.format(step)),
                  'wb') as f:
            np.save(f, gt.reshape(-1, 3))

        #assert ids[0] not in id_train_list, 'Double id at step {}, id {}'.format(step, ids[0])
        if ids[0] in id_train_list:
            print('step {} is the same with step {}'.format(
                step, id_train_list.index(ids[0])))
        id_train_list.append(ids[0])
    print('==================================================================')
    print('Finish saving {} train files'.format(num_train))
    print('==================================================================')

    ###########################################################################
    # SAVE VALIDATION
    id_val_list = []
    for step in tqdm(range(num_valid)):
        ids, inputs, npts, gt = next(valid_gen)

        # Sanity check
        assert inputs.shape[1] == npts[0], 'number of points does not match'
        print(inputs.shape, gt.shape)
        # Save input
        with open(os.path.join(SAVE_INPUT_PATH_VAL, '{:08d}.npy'.format(step)),
                  'wb') as f:
            np.save(f, inputs.reshape(-1, 3))
        # Save gt
        with open(os.path.join(SAVE_GT_PATH_VAL, '{:08d}.npy'.format(step)),
                  'wb') as f:
            np.save(f, gt.reshape(-1, 3))

        assert ids[0] not in id_val_list, 'Double id at step {}, id {}'.format(
            step, ids[0])
        id_val_list.append(ids[0])
    print('==================================================================')
    print('Finish saving {} val files'.format(num_valid))
    print('==================================================================')
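A minimal sketch of reading one saved pair back for inspection (hedged: it assumes the SAVE_INPUT_PATH_TRAIN / SAVE_GT_PATH_TRAIN constants defined elsewhere in this script, and simply reloads what np.save wrote):

import os
import numpy as np

# Hypothetical check: reload the first saved train pair and verify shapes.
step = 0
partial = np.load(os.path.join(SAVE_INPUT_PATH_TRAIN, '{:08d}.npy'.format(step)))
complete = np.load(os.path.join(SAVE_GT_PATH_TRAIN, '{:08d}.npy'.format(step)))
print(partial.shape, complete.shape)  # (N_input, 3) and (N_gt, 3)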
Example #6
def get_transformed_test_data(
    args,
    xmax=40.,
    zmax=40.,
):
    """Make and save a test data for PCN: 
        apply random y rotation & random xz translation
        Save the point cloud, and also the pose"""
    # Speicfy batch size just equals to 1

    df_valid, num_valid = lmdb_dataflow(args.lmdb_valid,
                                        1,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=False)
    valid_gen = df_valid.get_data()

    print('==================================================================')
    print('We have {} val files'.format(num_valid))
    print('==================================================================')
    ###########################################################################
    # SAVE VALIDATION
    id_val_list = []
    for step in tqdm(range(num_valid)):
        ids, inputs, npts, gt = next(valid_gen)
        # inputs, gt ---- (1,N_input,3), (1,N_gt,3)

        # Sanity check
        assert inputs.shape[1] == npts[0], 'number of points does not match'

        # (1,?,3) -> (?,3)
        inputs = inputs.reshape(-1, 3)
        gt = gt.reshape(-1, 3)

        print(inputs.shape, gt.shape)

        # Random y-rotation
        ang = np.random.uniform(low=0., high=2. * np.pi)
        R = np.eye(3)
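        # Standard rotation about the y axis:
        #   [[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]]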
        # First col (x)
        R[0, 0] = np.cos(ang)
        R[2, 0] = -np.sin(ang)
        # Third col (z)
        R[0, 2] = np.sin(ang)
        R[2, 2] = np.cos(ang)

        # Note: points must be (3, N) before the rotation is applied, then
        # transposed back to (N, 3).
        gt_transformed = (R @ gt.transpose()).transpose()
        inputs_transformed = (R @ inputs.transpose()).transpose()

        # Random x-z translation
        xtrans = np.random.uniform(low=-xmax, high=xmax)
        ztrans = np.random.uniform(low=-zmax, high=zmax)

        gt_transformed[:, 0] = gt_transformed[:, 0] + xtrans
        gt_transformed[:, 2] = gt_transformed[:, 2] + ztrans
        inputs_transformed[:, 0] = inputs_transformed[:, 0] + xtrans
        inputs_transformed[:, 2] = inputs_transformed[:, 2] + ztrans

        # Pose layout: [x translation, y translation (always 0), z translation,
        # y-rotation angle]
        pose = np.array([xtrans, 0, ztrans, ang])

        # Save input
        with open(
                os.path.join(SAVE_INPUT_PATH_VAL_TRANSFORMED,
                             '{:08d}.npy'.format(step)), 'wb') as f:
            np.save(f, inputs_transformed)
        # Save original gt
        with open(
                os.path.join(SAVE_GT_PATH_VAL_ORIGINAL,
                             '{:08d}.npy'.format(step)), 'wb') as f:
            np.save(f, gt)
        # Save gt
        with open(
                os.path.join(SAVE_GT_PATH_VAL_TRANSFORMED,
                             '{:08d}.npy'.format(step)), 'wb') as f:
            np.save(f, gt_transformed)
        # Save pose
        with open(os.path.join(SAVE_POSE_PATH, '{:08d}.npy'.format(step)),
                  'wb') as f:
            np.save(f, pose)

        assert ids[0] not in id_val_list, 'Double id at step {}, id {}'.format(
            step, ids[0])
        id_val_list.append(ids[0])

    print('==================================================================')
    print('Finish saving {} val files'.format(num_valid))
    print('==================================================================')
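The pose saved above fully determines the transform (p' = R p + t). A minimal sketch of undoing it, under the pose layout assumed above ([xtrans, 0, ztrans, ang]); invert_transform is a hypothetical helper, not part of the original script:

import numpy as np

def invert_transform(points, pose):
    """Undo the y rotation + xz translation applied above.

    points: (N, 3) transformed cloud; pose: [xtrans, 0, ztrans, ang].
    """
    xtrans, _, ztrans, ang = pose
    R = np.eye(3)
    R[0, 0] = np.cos(ang)
    R[2, 0] = -np.sin(ang)
    R[0, 2] = np.sin(ang)
    R[2, 2] = np.cos(ang)
    p = points.copy()
    # Undo the translation first, then the rotation (R is orthonormal, so
    # its inverse is its transpose).
    p[:, 0] -= xtrans
    p[:, 2] -= ztrans
    return (R.T @ p.T).T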
Example #7
                f.close()
                print("=== ModelNet40 %s %s Done ===\n" % (t, res))

    elif args.dataset == 'shapenet8':
        print("\n\n=== ShapeNet8 ===\n")
        for t in ['train', 'valid']:
            sum_dict = json.loads(
                open(os.path.join(SHAPENET8_PATH, 'keys.json')).read())
            for key in sum_dict.keys():
                sum_dict[key] = np.zeros(3)  # [num objects, num points, avg points/object]

            # the point clouds stored in the lmdb files have varying numbers of points
            df, num = lmdb_dataflow(
                lmdb_path=os.path.join(SHAPENET8_PATH, '%s.lmdb' % t),
                batch_size=1,
                input_size=1000000,
                output_size=1,
                is_training=False)

            data_gen = df.get_data()
            for _ in tqdm(range(num)):
                ids, _, npts, _ = next(data_gen)
                model_name = ids[0][:8]
                sum_dict[model_name][1] += npts[0]
                sum_dict[model_name][0] += 1

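                # Keep a running average of points per object.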
                sum_dict[model_name][2] = (
                    sum_dict[model_name][1] / sum_dict[model_name][0])

            f = open("./dump_sum_points/shapenet8_%s.json" % t, "w+")
            for key in sum_dict.keys():
Example #8
    parser.add_argument("--hdf5_path",
                        type=str,
                        default=r'../data/modelnet40_pcn/hdf5_partial_1024')
    parser.add_argument("--partial",
                        action='store_true',
                        help='store partial scan or not')
    parser.add_argument('--num_per_obj', type=int, default=1024)
    parser.add_argument('--num_scan', type=int, default=10)

    args = parser.parse_args()

    lmdb_file = os.path.join(args.lmdb_path, args.f_name + '.lmdb')
    os.system('mkdir -p %s' % args.hdf5_path)
    df_train, num_train = lmdb_dataflow(lmdb_path=lmdb_file,
                                        batch_size=1,
                                        input_size=args.num_per_obj,
                                        output_size=args.num_per_obj,
                                        is_training=False)

    if args.partial:
        print('Now we generate point cloud from partial observed objects.')

    file_per_h5 = 2048 * 4  # of objects within each hdf5 file
    data_gen = df_train.get_data()

    idx = 0
    data_np = np.zeros((file_per_h5, args.num_per_obj, 3))
    label_np = np.zeros((file_per_h5, ), dtype=np.int32)
    ids_np = np.chararray((file_per_h5, ), itemsize=32)

    # convert label string to integers
def get_transformed_test_data(
    args,
    xmax=40.,
    zmax=40.,
):
    """Make and save a test data for PCN: 
        apply random y rotation & random xz translation
        Save the point cloud, and also the pose"""
    # Speicfy batch size just equals to 1

    df_valid, num_valid = lmdb_dataflow(args.lmdb_valid,
                                        1,
                                        args.num_input_points,
                                        args.num_gt_points,
                                        is_training=False)
    valid_gen = df_valid.get_data()

    posename = 'pose1'
    posename = 'pose_trans'
    posename = 'pose_rot'
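    # NOTE: only the last assignment takes effect; posename is 'pose_rot' here.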
    print('==================================================================')
    print('We have {} val files'.format(num_valid))
    print('==================================================================')
    ###########################################################################
    # SAVE VALIDATION
    id_val_list = []
    for step in tqdm(range(num_valid)):
        ids, inputs, npts, gt = next(valid_gen)
        # inputs, gt ---- (1,N_input,3), (1,N_gt,3)

        # Sanity check
        assert inputs.shape[1] == npts[0], 'number of points does not match'

        # (1,?,3) -> (?,3)
        inputs = inputs.reshape(-1, 3)
        gt = gt.reshape(-1, 3)

        print(inputs.shape, gt.shape)

        # TODO
        save_pose_path = '/home/josephinemonica/Documents/gpu_link/joint_pose_and_shape_estimation/data/data-PCN/my_data'
        save_pose_path = os.path.join(save_pose_path, posename)
        with open(os.path.join(save_pose_path, '{:08d}.npy'.format(step)),
                  'rb') as f:
            pose_random = np.load(f)

        # y rotation loaded from the saved pose
        ang = pose_random[3]
        R = np.eye(3)
        # First col (x)
        R[0, 0] = np.cos(ang)
        R[2, 0] = -np.sin(ang)
        # Third col (z)
        R[0, 2] = np.sin(ang)
        R[2, 2] = np.cos(ang)

        # Note: points must be (3, N) before the rotation is applied, then
        # transposed back to (N, 3).
        gt_transformed = (R @ gt.transpose()).transpose()
        inputs_transformed = (R @ inputs.transpose()).transpose()

        # x-z translation loaded from the saved pose
        xtrans = pose_random[0]
        ztrans = pose_random[2]

        gt_transformed[:, 0] = gt_transformed[:, 0] + xtrans
        gt_transformed[:, 2] = gt_transformed[:, 2] + ztrans
        inputs_transformed[:, 0] = inputs_transformed[:, 0] + xtrans
        inputs_transformed[:, 2] = inputs_transformed[:, 2] + ztrans

        # Save input
        with open(
                os.path.join(SAVE_INPUT_PATH_VAL_TRANSFORMED,
                             '{:08d}.npy'.format(step)), 'wb') as f:
            np.save(f, inputs_transformed)
        # Save original gt
        with open(
                os.path.join(SAVE_GT_PATH_VAL_ORIGINAL,
                             '{:08d}.npy'.format(step)), 'wb') as f:
            np.save(f, gt)
        # Save gt
        with open(
                os.path.join(SAVE_GT_PATH_VAL_TRANSFORMED,
                             '{:08d}.npy'.format(step)), 'wb') as f:
            np.save(f, gt_transformed)

        assert ids[0] not in id_val_list, 'Double id at step {}, id {}'.format(
            step, ids[0])
        id_val_list.append(ids[0])

    print('==================================================================')
    print('Finish saving {} val files'.format(num_valid))
    print('==================================================================')