        # Total number of training iterations already completed when resuming
        # from epoch `current_epoch`
        total_iterations_train = int(current_epoch*config.net_params['total_frames_train']/(batch_size * time_step))

        # Total number of validation iterations already completed
        total_iterations_validate = int(current_epoch*config.net_params['total_frames_validation']/(batch_size * time_step))

        # Total number of test iterations
        # total_iterations_test = int(current_epoch / 10 * config.net_params['total_frames_test']/(batch_size))
        total_iterations_test = 0
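
        # Worked example (assumed numbers, not from the config): resuming at
        # current_epoch = 10 with 20,000 training frames, batch_size = 4 and
        # time_step = 2 gives 10 * 20000 / (4 * 2) = 25,000 iterations already run.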

    for epoch in range(current_epoch, n_epochs):
        # Number of data partitions per split; assumes each frame count divides
        # evenly by partition_limit
        total_partitions_train = config.net_params['total_frames_train']/config.net_params['partition_limit']
        total_partitions_validation = config.net_params['total_frames_validation']/config.net_params['partition_limit']
        total_partitions_test = config.net_params['total_frames_test']/config.net_params['partition_limit']

        # Reshuffle the training and validation sets at the start of every epoch
        ldr.dataset_train, ldr.dataset_validation = ldr.shuffle(ldr.dataset_train, ldr.dataset_validation)

        for part in range(int(total_partitions_train)):
            # Load one partition of source/target depth maps, RGB images, and
            # ground-truth transforms
            source_container, target_container, source_img_container, target_img_container, transforms_container = ldr.load(part, mode="train")
            for source_b, target_b, source_img_b, target_img_b, transforms_b in zip(source_container, target_container, source_img_container, target_img_container, transforms_container):

                # One optimization step; the fetch list returns the predicted and
                # expected depth maps, the individual loss terms, and the merged
                # training summary (index 5)
                outputs = sess.run([depth_maps_predicted, depth_maps_expected, train_loss, X2_pooled, train_step,
                                    merge_train, predicted_transforms, cloud_loss, photometric_loss, loss1,
                                    emd_loss, tr_loss, ro_loss],
                                   feed_dict={X1: source_img_b,
                                              X2: source_b,
                                              depth_maps_target: target_b,
                                              expected_transforms: transforms_b,
                                              phase: True,
                                              fc_keep_prob: 0.7,
                                              phase_rgb: True})

                dmaps_pred = outputs[0]
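
                # Hypothetical continuation (not in the original snippet): log the
                # merged training summary fetched at index 5 and advance the
                # iteration counter, assuming a tf.summary.FileWriter named
                # `writer` as in Example #2 below.
                summary_train = outputs[5]
                writer.add_summary(summary_train, total_iterations_train)
                total_iterations_train += 1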
Example #2
    # TensorBoard writer for this run's event logs
    writer = tf.summary.FileWriter("./logs_simple_transformer/")

    total_iterations_train = 0
    total_iterations_validate = 0

    writer.add_graph(sess.graph)

    checkpoint_path = config.paths['checkpoint_path']

    print("Restoring Checkpoint")

    saver.restore(sess, checkpoint_path + "/model-%d" % current_epoch)

    total_partitions_train = config.net_params['total_frames_train']/config.net_params['partition_limit']
    total_partitions_validation = config.net_params['total_frames_validation']/config.net_params['partition_limit']
    # Shuffle the loader's datasets in place
    ldr.shuffle()

    # Load the first partition in inference mode
    source_container, target_container, source_img_container, target_img_container, transforms_container = ldr.load(0, mode="inference")

    # Single forward pass on the first batch; no train_step is fetched here
    outputs = sess.run([depth_maps_predicted, depth_maps_expected, predicted_loss_train, predicted_transforms],
                       feed_dict={X1: source_img_container[0],
                                  X2: source_container[0],
                                  depth_maps_target: target_container[0],
                                  expected_transforms: transforms_container[0],
                                  phase: True,
                                  keep_prob: 0.5,
                                  phase_rgb: False})

    dmaps_pred = outputs[0]
    dmaps_exp = outputs[1]
    loss = outputs[2]

    # Sanity-check the output shapes and the predicted transforms
    print(dmaps_pred.shape)
    print(dmaps_exp.shape)
    print(outputs[3])
    print(source_img_container.shape)

    # cv2.imwrite('result.png', dmaps_pred[0, :, :])
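
    # A minimal sketch of the save hinted at above (assumed, not from the
    # original): depth predictions are float-valued, so rescale them to 8-bit
    # before writing with OpenCV.
    import cv2
    import numpy as np
    dmap = dmaps_pred[0, :, :]
    dmap_8u = cv2.normalize(dmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    cv2.imwrite('result.png', dmap_8u)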