Пример #1
0
def train(assign_model_path=None):
    """Build the generator pre-training graph and run the training loop.

    The graph consists of the generator from ``MODEL_GEN``, an EMD loss, an
    optional repulsion loss (controlled by ``USE_REPULSION_LOSS``) and the
    collected regularization loss.  Only variables whose name starts with
    ``"generator"`` are optimized.  A checkpoint is written to
    ``MODEL_DIR/model`` every 20 epochs.

    Args:
        assign_model_path: optional checkpoint path.  When given, the
            trainable ``generator*`` variables are restored from it after the
            regular initialization / checkpoint restore, overriding them.

    Side effects:
        Rebinds the module-level ``LOG_FOUT`` to the opened training log file
        and leaves it open (it is a module-level handle; presumably
        ``log_string`` writes to it -- verify before closing it here).
    """
    global LOG_FOUT

    is_training = True
    bn_decay = 0.95
    step = tf.Variable(0, trainable=False)
    learning_rate = BASE_LEARNING_RATE
    tf.summary.scalar('bn_decay', bn_decay)
    tf.summary.scalar('learning_rate', learning_rate)

    # Input placeholders: sparse input, ground truth, gt normals, patch radius.
    pointclouds_pl, pointclouds_gt, pointclouds_gt_normal, pointclouds_radius = MODEL_GEN.placeholder_inputs(
        BATCH_SIZE, NUM_POINT, UP_RATIO)

    # Create the generator model.
    pred, _ = MODEL_GEN.get_gen_model(pointclouds_pl,
                                      is_training,
                                      scope='generator',
                                      bradius=pointclouds_radius,
                                      reuse=None,
                                      use_normal=False,
                                      use_bn=False,
                                      use_ibn=False,
                                      bn_decay=bn_decay,
                                      up_ratio=UP_RATIO)

    # EMD loss between prediction and ground truth.
    gen_loss_emd, matchl_out = model_utils.get_emd_loss(
        pred, pointclouds_gt, pointclouds_radius)

    # Optional repulsion loss.
    if USE_REPULSION_LOSS:
        gen_repulsion_loss = model_utils.get_repulsion_loss4(pred)
        tf.summary.scalar('loss/gen_repulsion_loss', gen_repulsion_loss)
    else:
        gen_repulsion_loss = 0.0

    # Total loss: weighted EMD + repulsion + regularization.
    pre_gen_loss = (100 * gen_loss_emd + gen_repulsion_loss +
                    tf.losses.get_regularization_loss())

    # Restrict update ops and trainable variables to the generator scope.
    gen_update_ops = [
        op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if op.name.startswith("generator")
    ]
    gen_tvars = [
        var for var in tf.trainable_variables()
        if var.name.startswith("generator")
    ]

    with tf.control_dependencies(gen_update_ops):
        pre_gen_train = tf.train.AdamOptimizer(
            learning_rate,
            beta1=0.9).minimize(pre_gen_loss,
                                var_list=gen_tvars,
                                colocate_gradients_with_ops=True,
                                global_step=step)

    # Scalar summaries.  merge_all() must run BEFORE the image summaries
    # below are created so that they are not part of `pretrain_merged`
    # (they are fed through separate placeholders).
    tf.summary.scalar('loss/gen_emd', gen_loss_emd)
    tf.summary.scalar('loss/regularation', tf.losses.get_regularization_loss())
    tf.summary.scalar('loss/pre_gen_total', pre_gen_loss)
    pretrain_merged = tf.summary.merge_all()

    # Image summaries for rendered input / prediction / ground-truth clouds.
    pointclouds_image_input = tf.placeholder(tf.float32,
                                             shape=[None, 500, 1500, 1])
    pointclouds_input_summary = tf.summary.image('pointcloud_input',
                                                 pointclouds_image_input,
                                                 max_outputs=1)
    pointclouds_image_pred = tf.placeholder(tf.float32,
                                            shape=[None, 500, 1500, 1])
    pointclouds_pred_summary = tf.summary.image('pointcloud_pred',
                                                pointclouds_image_pred,
                                                max_outputs=1)
    pointclouds_image_gt = tf.placeholder(tf.float32,
                                          shape=[None, 500, 1500, 1])
    pointclouds_gt_summary = tf.summary.image('pointcloud_gt',
                                              pointclouds_image_gt,
                                              max_outputs=1)
    image_merged = tf.summary.merge([
        pointclouds_input_summary, pointclouds_pred_summary,
        pointclouds_gt_summary
    ])

    # Create a session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(os.path.join(MODEL_DIR, 'train'),
                                             sess.graph)
        init = tf.global_variables_initializer()
        sess.run(init)
        ops = {
            'pointclouds_pl': pointclouds_pl,
            'pointclouds_gt': pointclouds_gt,
            'pointclouds_gt_normal': pointclouds_gt_normal,
            'pointclouds_radius': pointclouds_radius,
            'pointclouds_image_input': pointclouds_image_input,
            'pointclouds_image_pred': pointclouds_image_pred,
            'pointclouds_image_gt': pointclouds_image_gt,
            'pretrain_merged': pretrain_merged,
            'image_merged': image_merged,
            'gen_loss_emd': gen_loss_emd,
            'pre_gen_train': pre_gen_train,
            'pred': pred,
            'step': step,
        }

        # Restore the latest checkpoint from MODEL_DIR, if any.
        saver = tf.train.Saver(max_to_keep=6)
        restore_epoch, checkpoint_path = model_utils.pre_load_checkpoint(
            MODEL_DIR)
        if restore_epoch == 0:
            # Epoch 0 is treated as a fresh run: truncate the log and write
            # a header with the host name and the flags.
            LOG_FOUT = open(os.path.join(MODEL_DIR, 'log_train.txt'), 'w')
            LOG_FOUT.write(str(socket.gethostname()) + '\n')
            LOG_FOUT.write(str(FLAGS) + '\n')
        else:
            LOG_FOUT = open(os.path.join(MODEL_DIR, 'log_train.txt'), 'a')
            saver.restore(sess, checkpoint_path)

        # Optionally overwrite the generator weights from another model file.
        if assign_model_path is not None:
            # Single pre-formatted string: prints identically on Python 2/3.
            print("Load pre-train model from %s" % (assign_model_path))
            assign_saver = tf.train.Saver(var_list=[
                var for var in tf.trainable_variables()
                if var.name.startswith("generator")
            ])
            assign_saver.restore(sess, assign_model_path)

        # Read the training data.
        input_data, gt_data, data_radius, _ = data_provider.load_patch_data(
            skip_rate=1,
            num_point=NUM_POINT,
            norm=USE_DATA_NORM,
            use_randominput=USE_RANDOM_INPUT)

        fetchworker = data_provider.Fetcher(input_data, gt_data, data_radius,
                                            BATCH_SIZE, NUM_POINT,
                                            USE_RANDOM_INPUT, USE_DATA_NORM)
        fetchworker.start()
        try:
            for epoch in tqdm(range(restore_epoch, MAX_EPOCH + 1), ncols=55):
                log_string('**** EPOCH %03d ****\t' % (epoch))
                train_one_epoch(sess, ops, fetchworker, train_writer)
                if epoch % 20 == 0:
                    saver.save(sess,
                               os.path.join(MODEL_DIR, "model"),
                               global_step=epoch)
        finally:
            # Always stop the fetcher and flush/close the summary writer,
            # even when an epoch raises (or the run is interrupted), so the
            # process does not hang on a still-running worker.
            fetchworker.shutdown()
            train_writer.close()
Пример #2
0
def prediction_whole_model(data_folder=None, show=False, use_normal=False):
    """Run the restored generator on every ``*.xyz`` file in a folder.

    For each sample the prediction is saved under
    ``MODEL_DIR/result/<phase>/<name>.xyz`` (with zero-filled columns appended
    in place of normals) and the network input is saved next to it as
    ``<name>_input.xyz``.  Finally the mean ``sess.run`` time per sample is
    printed.

    Args:
        data_folder: folder containing the ``.xyz`` samples.  Defaults to
            ``'../data/test_data/our_collected_data/MC_5k'`` when ``None``.
        show: when true, display three-view renderings of input, prediction
            and ground truth for every sample via matplotlib.
        use_normal: when true, feed 6 columns per point to the network;
            otherwise only the first 3 columns are used.

    Raises:
        IndexError: if ``data_folder`` contains no ``.xyz`` files.
    """
    # Honour the caller's folder; fall back to the historical default.
    if data_folder is None:
        data_folder = '../data/test_data/our_collected_data/MC_5k'
    phase = data_folder.split('/')[-2] + data_folder.split('/')[-1]
    save_path = os.path.join(MODEL_DIR, 'result/' + phase)

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    samples = sorted(glob(data_folder + "/*.xyz"))
    if not samples:
        raise IndexError("no .xyz files found in %s" % data_folder)

    # The placeholder has a fixed point count taken from one reference file
    # (the last in ascending order), so every sample is expected to have the
    # same number of points.
    reference_cloud = np.loadtxt(samples[-1])
    num_channels = 6 if use_normal else 3
    pointclouds_ipt = tf.placeholder(
        tf.float32, shape=(1, reference_cloud.shape[0], num_channels))

    pred, _ = MODEL_GEN.get_gen_model(pointclouds_ipt,
                                      is_training=False,
                                      scope='generator',
                                      bradius=1.0,
                                      reuse=None,
                                      use_normal=use_normal,
                                      use_bn=False,
                                      use_ibn=False,
                                      bn_decay=0.95,
                                      up_ratio=UP_RATIO)
    saver = tf.train.Saver()
    _, restore_model_path = model_utils.pre_load_checkpoint(MODEL_DIR)
    print(restore_model_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    with tf.Session(config=config) as sess:
        saver.restore(sess, restore_model_path)
        total_time = 0
        for item in samples:
            gt = np.loadtxt(item)
            # Add the batch dimension expected by the placeholder.
            cloud = np.expand_dims(gt, axis=0)

            if not use_normal:
                cloud = cloud[:, :, 0:3]
                gt = gt[:, 0:3]
            # Single pre-formatted string: prints identically on Python 2/3.
            print("%s %s" % (item, cloud.shape))

            # Only the forward pass is timed.
            start_time = time.time()
            pred_pl = sess.run(pred, feed_dict={pointclouds_ipt: cloud})
            total_time += time.time() - start_time
            # No normals are predicted; zeros are written in their place.
            norm_pl = np.zeros_like(pred_pl)

            # -------------- visualize predicted point cloud --------------
            path = os.path.join(save_path, os.path.basename(item))
            if show:
                f, axis = plt.subplots(3)
                axis[0].imshow(
                    pc_util.point_cloud_three_views(cloud[0, :, 0:3],
                                                    diameter=5))
                axis[1].imshow(
                    pc_util.point_cloud_three_views(pred_pl[0, :, :],
                                                    diameter=5))
                axis[2].imshow(
                    pc_util.point_cloud_three_views(gt[:, 0:3], diameter=5))
                plt.show()
            data_provider.save_pl(
                path, np.hstack((pred_pl[0, ...], norm_pl[0, ...])))
            input_path = os.path.splitext(path)[0] + '_input.xyz'
            data_provider.save_pl(input_path, cloud[0])
        # Mean inference time over the samples actually processed.
        print(total_time / len(samples))