def full_bootstrap(batch_size):
    """Assemble the full bootstrap network: flow stage feeding the depth/motion stage.

    Builds the flow network, then wires its image-pair input together with its
    flow and confidence outputs into the depth/motion network, and wraps the
    whole thing in a single Keras ``Model``.

    Args:
        batch_size: batch size passed to both sub-network constructors.

    Returns:
        A Keras ``Model`` mapping the image pair to the depth/motion outputs.
    """
    flow_net = bootstrap.bootstrap_net(batch_size)
    flow_net._make_predict_function()

    image_pair = flow_net.inputs[0]
    flow_outputs = flow_net.outputs
    # Outputs 2 and 3 of the flow stage (flow and confidence, per the
    # sibling training code) feed the depth/motion stage alongside the images.
    depth_stage_inputs = [image_pair, flow_outputs[2], flow_outputs[3]]
    depth_stage_outputs = bootstrap.bootstrap_net_depth(batch_size)(
        depth_stage_inputs)

    combined = Model(inputs=image_pair, outputs=depth_stage_outputs)
    combined._make_predict_function()
    return combined
def train_bootstrap_depth_motion(data_loader, batch_size=32):
    """Train the bootstrap depth/motion stage on top of a frozen flow stage.

    Loads pretrained flow-network weights, freezes the flow layers, stacks the
    depth/motion network on the flow outputs, and trains the combined model,
    checkpointing weights periodically and pickling the training history.

    Args:
        data_loader: data source handed to ``generate_motion_depth_normals``
            to produce training batches.
        batch_size: batch size for both sub-network constructors (was
            previously hard-coded to 32; parameterized to match the sibling
            ``full_bootstrap`` helper).
    """
    checkpoint_path = TRAINING_SAVE_FOLDER + "cp_depth_motion_bootstrap.ckpt"
    flow_checkpoint_path = TRAINING_SAVE_FOLDER + "cp.ckpt"

    # Create checkpoint callback.
    # NOTE(review): ``period=`` is deprecated in newer tf.keras in favour of
    # ``save_freq=`` — kept as-is to preserve behavior on the TF version this
    # script targets; confirm before upgrading TF.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1,
                                                     period=1000)
    # Stop training immediately if the loss goes NaN.
    nan_callback = tf.keras.callbacks.TerminateOnNaN()

    bootstrap_flow = bootstrap.bootstrap_net(batch_size)
    # Load pretrained flow weights.
    bootstrap_flow.load_weights(flow_checkpoint_path)
    bootstrap_flow._make_predict_function()
    # Freeze the flow stage so only the depth/motion stage trains.
    for layer in bootstrap_flow.layers:
        layer.trainable = False
    print(bootstrap_flow.summary())

    bootstrap_out = bootstrap_flow.outputs
    image_pair = bootstrap_flow.inputs[0]
    # Outputs 2 and 3 are the flow and confidence predictions that the
    # depth/motion stage consumes alongside the raw image pair.
    flow_out = bootstrap_out[2]
    conf_out = bootstrap_out[3]
    bootstrap_motion_input = [image_pair, flow_out, conf_out]
    bootstrap_motion_outputs = bootstrap.bootstrap_net_depth(batch_size)(
        bootstrap_motion_input)
    bootstrap_full = Model(inputs=image_pair, outputs=bootstrap_motion_outputs)
    bootstrap_full._make_predict_function()

    # Compile model. One loss per depth/motion output head, with fixed
    # relative weights.
    adam = Adam(decay=0.0004)
    loss_list = [
        custom_losses.mean_abs_error,
        custom_losses.euclidean_with_gradient_loss,
        custom_losses.normals_loss_from_depth_gt
    ]
    weights = [15, 300, 100]
    bootstrap_full.compile(loss=loss_list, optimizer=adam, loss_weights=weights)
    print(bootstrap_full.summary())
    # plot_model(bootstrap_net, to_file='bootstrap.png', show_shapes=True)

    # Train model: one generator step per "epoch" for 250k iterations, so the
    # checkpoint callback's per-epoch cadence acts as a per-step cadence.
    history = bootstrap_full.fit_generator(
        generate_motion_depth_normals(data_loader),
        steps_per_epoch=1,
        epochs=250000,
        callbacks=[cp_callback, nan_callback])
    print("Training complete!")

    # Persist the loss history for later analysis.
    with open(
            TRAINING_SAVE_FOLDER + 'train_history_bootstrap_depth_motion.pickle',
            'wb+') as file_pi:
        pickle.dump(history.history, file_pi)
    print("Saved weights to file")