# predicted_transforms = tf.concat(
#     [expected_transforms[:, :3, :3],
#      tf.reshape(predicted_transforms[:, :3, 3], shape=[batch_size, 3, 1])],
#     axis=-1)

# transforms depth maps by the predicted transformation
depth_maps_predicted, cloud_pred = tf.map_fn(
    lambda x: at3._simple_transformer(X2_pooled[x, :, :, 0] * 40.0 + 40.0,
                                      predicted_transforms[x], K_final,
                                      small_transform),
    elems=tf.range(0, batch_size * time_step, 1),
    dtype=(tf.float32, tf.float32))

# transforms depth maps by the expected transformation
depth_maps_expected, cloud_exp = tf.map_fn(
    lambda x: at3._simple_transformer(X2_pooled[x, :, :, 0] * 40.0 + 40.0,
                                      expected_transforms[x], K_final,
                                      small_transform),
    elems=tf.range(0, batch_size * time_step, 1),
    dtype=(tf.float32, tf.float32))
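
# at3._simple_transformer is assumed to back-project the depth map into a
# point cloud with the intrinsics K, apply the rigid transform, and render the
# result back to a depth image. A numpy sketch of the geometry (illustrative
# only; the z-buffered re-rendering done by the real op is omitted here):
import numpy as np

def _transform_depth_sketch(depth, T, K):
    """depth: [H, W] metric depth; T: [4, 4] rigid transform; K: [3, 3]."""
    h, w = depth.shape
    v, u = np.mgrid[0:h, 0:w]
    pix = np.stack([u, v, np.ones_like(u)], -1).reshape(-1, 3).T  # [3, H*W]
    pts = np.linalg.inv(K) @ (pix * depth.reshape(1, -1))         # back-project
    pts = np.vstack([pts, np.ones((1, h * w))])                   # homogeneous
    pts_t = (T @ pts)[:3]                                         # rigid transform
    proj = K @ pts_t
    uv = proj[:2] / np.maximum(proj[2:], 1e-8)                    # reprojected pixels
    return pts_t.T, uv.T  # transformed cloud and its image-plane coordinates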

# photometric loss between predicted and expected transformation
photometric_loss = tf.nn.l2_loss(
    tf.subtract((depth_maps_expected[:, 10:-10, 10:-10] - 40.0) / 40.0,
                (depth_maps_predicted[:, 10:-10, 10:-10] - 40.0) / 40.0))

# Chamfer distance between the point clouds
cloud_loss = model_utils.get_cd_loss(cloud_pred, cloud_exp)
# earth mover's distance between point clouds
emd_loss = model_utils.get_emd_loss(cloud_pred, cloud_exp)
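
# For reference, a minimal sketch of the Chamfer distance that a helper like
# model_utils.get_cd_loss typically computes (illustrative; the repo's version
# may reduce or weight the two terms differently):
def _chamfer_distance_sketch(cloud_a, cloud_b):
    """cloud_a: [B, N, 3], cloud_b: [B, M, 3]."""
    diff = tf.expand_dims(cloud_a, 2) - tf.expand_dims(cloud_b, 1)  # [B, N, M, 3]
    sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)               # [B, N, M]
    a_to_b = tf.reduce_min(sq_dist, axis=2)  # nearest neighbour in cloud_b
    b_to_a = tf.reduce_min(sq_dist, axis=1)  # nearest neighbour in cloud_a
    return tf.reduce_mean(a_to_b) + tf.reduce_mean(b_to_a)
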
# regression loss between predicted and expected pose vectors
output_vectors_exp = tf.map_fn(
    lambda x: transform_functions.convert(expected_transforms[x]),
    elems=tf.range(0, batch_size, 1),
    dtype=tf.float32)
output_vectors_exp = tf.squeeze(output_vectors_exp)
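
# transform_functions.convert presumably flattens a rigid transform into a 6-D
# pose vector. A sketch using a ZYX Euler decomposition (an assumption; the
# repo's parameterization may differ, e.g. axis-angle):
def _transform_to_vector_sketch(T):
    """T: [3, 4] or [4, 4] rigid transform -> [6] = (tx, ty, tz, roll, pitch, yaw)."""
    R, t = T[:3, :3], T[:3, 3]
    roll = tf.atan2(R[2, 1], R[2, 2])
    pitch = tf.atan2(-R[2, 0], tf.sqrt(R[0, 0] ** 2 + R[1, 0] ** 2))
    yaw = tf.atan2(R[1, 0], R[0, 0])
    return tf.concat([t, tf.stack([roll, pitch, yaw])], axis=0)
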
tr_loss = tf.norm(output_vectors[:, :3] - output_vectors_exp[:, :3], axis=1)
ro_loss = tf.norm(output_vectors[:, 3:] - output_vectors_exp[:, 3:], axis=1)
tr_loss = tf.nn.l2_loss(tr_loss)
ro_loss = tf.nn.l2_loss(ro_loss)

# final loss term
train_loss = _GAMMA_CONST * tr_loss + _THETA_CONST * emd_loss + _ALPHA_CONST * photometric_loss + _EPSILON_CONST * ro_loss

tf.add_to_collection('losses1', train_loss)
loss1 = tf.add_n(tf.get_collection('losses1'))

predicted_loss_validation = tf.nn.l2_loss(
    tf.subtract((depth_maps_expected[:, 10:-10, 10:-10] - 40.0) / 40.0,
                (depth_maps_predicted[:, 10:-10, 10:-10] - 40.0) / 40.0))
Example #2
def train(assign_model_path=None):
    is_training = True
    bn_decay = 0.95
    step = tf.Variable(0, trainable=False)
    learning_rate = BASE_LEARNING_RATE
    tf.summary.scalar('bn_decay', bn_decay)
    tf.summary.scalar('learning_rate', learning_rate)

    # get placeholder
    pointclouds_pl, pointclouds_gt, pointclouds_gt_normal, pointclouds_radius = MODEL_GEN.placeholder_inputs(
        BATCH_SIZE, NUM_POINT, UP_RATIO)

    # create the generator model
    pred, _ = MODEL_GEN.get_gen_model(pointclouds_pl,
                                      is_training,
                                      scope='generator',
                                      bradius=pointclouds_radius,
                                      reuse=None,
                                      use_normal=False,
                                      use_bn=False,
                                      use_ibn=False,
                                      bn_decay=bn_decay,
                                      up_ratio=UP_RATIO)

    # get EMD loss
    gen_loss_emd, matchl_out = model_utils.get_emd_loss(
        pred, pointclouds_gt, pointclouds_radius)
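    # Note: this variant of get_emd_loss also takes the patch radius and
    # returns the matched point order; repos like this typically wrap a CUDA
    # approximate-matching op. An exact-EMD reference for small clouds
    # (illustrative sketch only) could look like:
    #
    #   import numpy as np
    #   from scipy.optimize import linear_sum_assignment
    #
    #   def emd_reference(a, b):  # a, b: [N, 3] arrays with the same N
    #       cost = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    #       row, col = linear_sum_assignment(cost)
    #       return cost[row, col].mean()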

    # get repulsion loss
    if USE_REPULSION_LOSS:
        gen_repulsion_loss = model_utils.get_repulsion_loss4(pred)
        tf.summary.scalar('loss/gen_repulsion_loss', gen_repulsion_loss)
    else:
        gen_repulsion_loss = 0.0
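
    # For reference, a minimal repulsion-loss sketch in the spirit of PU-Net
    # (illustrative; model_utils.get_repulsion_loss4 may differ in its
    # weighting and radius handling):
    def _repulsion_loss_sketch(points, nn_k=5, h=0.03):
        """points: [B, N, 3]; penalises neighbours that fall inside radius h."""
        diff = tf.expand_dims(points, 2) - tf.expand_dims(points, 1)
        sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)  # [B, N, N]
        neg_top, _ = tf.nn.top_k(-sq_dist, k=nn_k + 1)     # nearest first
        nn_sq = -neg_top[:, :, 1:]                         # drop the self-match
        weight = tf.exp(-nn_sq / (h ** 2))                 # fast-decaying weight
        return tf.reduce_mean(
            tf.nn.relu(h - tf.sqrt(nn_sq + 1e-12)) * weight)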

    # get total loss
    pre_gen_loss = (100 * gen_loss_emd + gen_repulsion_loss +
                    tf.losses.get_regularization_loss())

    # create pre-generator ops
    gen_update_ops = [
        op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if op.name.startswith("generator")
    ]
    gen_tvars = [
        var for var in tf.trainable_variables()
        if var.name.startswith("generator")
    ]

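    # Running the optimizer under the UPDATE_OPS dependencies ensures any
    # batch-norm moving-average updates in the generator run on every step.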
    with tf.control_dependencies(gen_update_ops):
        pre_gen_train = tf.train.AdamOptimizer(
            learning_rate,
            beta1=0.9).minimize(pre_gen_loss,
                                var_list=gen_tvars,
                                colocate_gradients_with_ops=True,
                                global_step=step)
    # merge summaries and add point-cloud image summaries
    tf.summary.scalar('loss/gen_emd', gen_loss_emd)
    tf.summary.scalar('loss/regularization', tf.losses.get_regularization_loss())
    tf.summary.scalar('loss/pre_gen_total', pre_gen_loss)
    pretrain_merged = tf.summary.merge_all()

    pointclouds_image_input = tf.placeholder(tf.float32,
                                             shape=[None, 500, 1500, 1])
    pointclouds_input_summary = tf.summary.image('pointcloud_input',
                                                 pointclouds_image_input,
                                                 max_outputs=1)
    pointclouds_image_pred = tf.placeholder(tf.float32,
                                            shape=[None, 500, 1500, 1])
    pointclouds_pred_summary = tf.summary.image('pointcloud_pred',
                                                pointclouds_image_pred,
                                                max_outputs=1)
    pointclouds_image_gt = tf.placeholder(tf.float32,
                                          shape=[None, 500, 1500, 1])
    pointclouds_gt_summary = tf.summary.image('pointcloud_gt',
                                              pointclouds_image_gt,
                                              max_outputs=1)
    image_merged = tf.summary.merge([
        pointclouds_input_summary, pointclouds_pred_summary,
        pointclouds_gt_summary
    ])

    # Create a session
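    # allow_growth avoids reserving all GPU memory up front; soft placement
    # falls back to CPU for ops that have no GPU kernel.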
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(os.path.join(MODEL_DIR, 'train'),
                                             sess.graph)
        init = tf.global_variables_initializer()
        sess.run(init)
        ops = {
            'pointclouds_pl': pointclouds_pl,
            'pointclouds_gt': pointclouds_gt,
            'pointclouds_gt_normal': pointclouds_gt_normal,
            'pointclouds_radius': pointclouds_radius,
            'pointclouds_image_input': pointclouds_image_input,
            'pointclouds_image_pred': pointclouds_image_pred,
            'pointclouds_image_gt': pointclouds_image_gt,
            'pretrain_merged': pretrain_merged,
            'image_merged': image_merged,
            'gen_loss_emd': gen_loss_emd,
            'pre_gen_train': pre_gen_train,
            'pred': pred,
            'step': step,
        }
        # restore the model
        saver = tf.train.Saver(max_to_keep=6)
        restore_epoch, checkpoint_path = model_utils.pre_load_checkpoint(
            MODEL_DIR)
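        # pre_load_checkpoint presumably wraps tf.train.get_checkpoint_state;
        # a minimal sketch (an assumption, the real helper may differ):
        #
        #   def pre_load_checkpoint(model_dir):
        #       ckpt = tf.train.get_checkpoint_state(model_dir)
        #       if ckpt and ckpt.model_checkpoint_path:
        #           # saver.save(..., global_step=epoch) names files "model-<epoch>"
        #           epoch = int(ckpt.model_checkpoint_path.split('-')[-1])
        #           return epoch, ckpt.model_checkpoint_path
        #       return 0, None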
        global LOG_FOUT
        if restore_epoch == 0:
            LOG_FOUT = open(os.path.join(MODEL_DIR, 'log_train.txt'), 'w')
            LOG_FOUT.write(str(socket.gethostname()) + '\n')
            LOG_FOUT.write(str(FLAGS) + '\n')
        else:
            LOG_FOUT = open(os.path.join(MODEL_DIR, 'log_train.txt'), 'a')
            saver.restore(sess, checkpoint_path)

        # assign the generator weights from another model file
        if assign_model_path is not None:
            print("Load pre-trained model from %s" % assign_model_path)
            assign_saver = tf.train.Saver(var_list=[
                var for var in tf.trainable_variables()
                if var.name.startswith("generator")
            ])
            assign_saver.restore(sess, assign_model_path)

        # read data
        input_data, gt_data, data_radius, _ = data_provider.load_patch_data(
            skip_rate=1,
            num_point=NUM_POINT,
            norm=USE_DATA_NORM,
            use_randominput=USE_RANDOM_INPUT)

        fetchworker = data_provider.Fetcher(input_data, gt_data, data_radius,
                                            BATCH_SIZE, NUM_POINT,
                                            USE_RANDOM_INPUT, USE_DATA_NORM)
        fetchworker.start()
        for epoch in tqdm(range(restore_epoch, MAX_EPOCH + 1), ncols=55):
            log_string('**** EPOCH %03d ****\t' % (epoch))
            train_one_epoch(sess, ops, fetchworker, train_writer)
            if epoch % 20 == 0:
                saver.save(sess,
                           os.path.join(MODEL_DIR, "model"),
                           global_step=epoch)
        fetchworker.shutdown()
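
# Typical entry point for this example (illustrative; the original script sets
# up FLAGS and MODEL_DIR before calling train):
# if __name__ == '__main__':
#     train()
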
Example #3

# transforms depth maps by the predicted transformation
depth_maps_predicted, cloud_pred = tf.map_fn(lambda x: at3._simple_transformer(
    X2_pooled[x, :, :, 0] * 40.0 + 40.0, predicted_transforms[x], K_final,
    small_transform),
                                             elems=tf.range(
                                                 0, batch_size * time_step, 1),
                                             dtype=(tf.float32, tf.float32))

# transforms depth maps by the expected transformation
depth_maps_expected, cloud_exp = tf.map_fn(lambda x: at3._simple_transformer(
    X2_pooled[x, :, :, 0] * 40.0 + 40.0, expected_transforms[x], K_final,
    small_transform),
                                           elems=tf.range(
                                               0, batch_size * time_step, 1),
                                           dtype=(tf.float32, tf.float32))

# photometric loss between predicted and expected transformation
photometric_loss = tf.nn.l2_loss(
    tf.subtract((depth_maps_expected[:, 10:-10, 10:-10] - 40.0) / 40.0,
                (depth_maps_predicted[:, 10:-10, 10:-10] - 40.0) / 40.0))

# Chamfer distance between the point clouds
cloud_loss = model_utils.get_cd_loss(cloud_pred, cloud_exp)
# earth mover's distance between the point clouds
emd_loss = model_utils.get_emd_loss(cloud_pred, cloud_exp)

# final loss term
train_loss = _ALPHA_CONST * photometric_loss + _BETA_CONST * cloud_loss + _THETA_CONST * emd_loss

tf.add_to_collection('losses1', train_loss)
loss1 = tf.add_n(tf.get_collection('losses1'))

predicted_loss_test = tf.nn.l2_loss(
    tf.subtract((depth_maps_expected[:, 10:-10, 10:-10] - 40.0) / 40.0,
                (depth_maps_predicted[:, 10:-10, 10:-10] - 40.0) / 40.0))
cloud_loss_test = model_utils.get_cd_loss(cloud_pred, cloud_exp)
emd_loss_test = model_utils.get_emd_loss(cloud_pred, cloud_exp)
output_vectors_exp = tf.map_fn(
    lambda x: transform_functions.convert(expected_transforms[x]),
    elems=tf.range(0, batch_size, 1),
    dtype=tf.float32)