def full_bootstrap(batch_size):
    """Build the combined flow + depth/motion bootstrap network.

    The flow bootstrap network (``bootstrap.bootstrap_net``) is built first.
    Its first input (the image pair) and its outputs at indices 2 and 3
    (flow and confidence) are fed into ``bootstrap.bootstrap_net_depth``.
    Both are then wrapped in a single ``Model`` whose only input is the image
    pair.

    Args:
        batch_size: Forwarded to both ``bootstrap.bootstrap_net`` and
            ``bootstrap.bootstrap_net_depth``.

    Returns:
        The combined ``Model``, with ``_make_predict_function()`` already
        called on it.
    """
    flow_net = bootstrap.bootstrap_net(batch_size)
    # NOTE(review): private Keras API; presumably used to pre-build the
    # predict function before the graph is reused (e.g. from other
    # threads) — confirm the intent.
    flow_net._make_predict_function()

    flow_outputs = flow_net.outputs
    images = flow_net.inputs[0]
    flow, confidence = flow_outputs[2], flow_outputs[3]

    depth_motion = bootstrap.bootstrap_net_depth(batch_size)(
        [images, flow, confidence])

    combined = Model(inputs=images, outputs=depth_motion)
    combined._make_predict_function()
    return combined
def train_bootstrap_depth_motion(data_loader, batch_size=32):
    """Train the depth/motion bootstrap head on top of a frozen flow network.

    Loads pretrained flow-bootstrap weights from
    ``TRAINING_SAVE_FOLDER + "cp.ckpt"`` and freezes every layer of that
    network. It then stacks ``bootstrap.bootstrap_net_depth`` on the flow
    network's image-pair input and its outputs 2 and 3 (flow and confidence).
    The combined model is trained on batches from
    ``generate_motion_depth_normals(data_loader)``.

    Weights are checkpointed to ``cp_depth_motion_bootstrap.ckpt``, and the
    Keras training history is pickled to
    ``train_history_bootstrap_depth_motion.pickle``. Both files are written
    under ``TRAINING_SAVE_FOLDER``.

    Args:
        data_loader: Passed unchanged to ``generate_motion_depth_normals``.
        batch_size: Batch size used to build both sub-networks. Defaults to
            32, the value that was previously hard-coded.
            NOTE(review): presumably must match the batch size yielded by
            ``generate_motion_depth_normals`` — confirm.
    """
    checkpoint_path = TRAINING_SAVE_FOLDER + "cp_depth_motion_bootstrap.ckpt"
    flow_checkpoint_path = TRAINING_SAVE_FOLDER + "cp.ckpt"
    history_path = (TRAINING_SAVE_FOLDER +
                    'train_history_bootstrap_depth_motion.pickle')

    # Save weights only. With steps_per_epoch=1 below, period=1000 means one
    # checkpoint every 1000 training steps.
    # NOTE(review): `period` is deprecated in newer tf.keras in favour of
    # `save_freq`; kept as-is to avoid changing checkpoint cadence.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1,
                                                     period=1000)
    # Stop training as soon as the loss becomes NaN.
    nan_callback = tf.keras.callbacks.TerminateOnNaN()

    bootstrap_flow = bootstrap.bootstrap_net(batch_size)
    bootstrap_flow.load_weights(flow_checkpoint_path)
    bootstrap_flow._make_predict_function()

    # Freeze the pretrained flow network so that only the depth/motion head
    # is updated. Keras only honours `trainable` changes made before
    # compile(), which happens below on the combined model.
    for layer in bootstrap_flow.layers:
        layer.trainable = False

    # summary() prints the table itself and returns None, so it is not
    # wrapped in print().
    bootstrap_flow.summary()

    flow_outputs = bootstrap_flow.outputs
    image_pair = bootstrap_flow.inputs[0]
    flow_out = flow_outputs[2]
    conf_out = flow_outputs[3]

    bootstrap_motion_outputs = bootstrap.bootstrap_net_depth(batch_size)(
        [image_pair, flow_out, conf_out])

    bootstrap_full = Model(inputs=image_pair, outputs=bootstrap_motion_outputs)
    bootstrap_full._make_predict_function()

    # One loss per output of the depth/motion network, with matching weights.
    adam = Adam(decay=0.0004)
    loss_list = [
        custom_losses.mean_abs_error,
        custom_losses.euclidean_with_gradient_loss,
        custom_losses.normals_loss_from_depth_gt,
    ]
    weights = [15, 300, 100]
    bootstrap_full.compile(loss=loss_list,
                           optimizer=adam,
                           loss_weights=weights)
    bootstrap_full.summary()

    # plot_model(bootstrap_full, to_file='bootstrap.png', show_shapes=True)

    # 250k epochs of a single step each, i.e. 250k training steps.
    history = bootstrap_full.fit_generator(
        generate_motion_depth_normals(data_loader),
        steps_per_epoch=1,
        epochs=250000,
        callbacks=[cp_callback, nan_callback])

    print("Training complete!")

    # Only the history dict is persisted here; weights are saved by
    # cp_callback during training.
    with open(history_path, 'wb') as file_pi:
        pickle.dump(history.history, file_pi)
    print("Saved training history to file")