def compare_kpts(data_dir, loss_to_use, num_keypoints, latent_dim_size, env,
                 img_input, img_size, colour_input, patch_sizes, lsp_layers,
                 noise_type, batch_size, eval_split, tp_fname, pkey_fname,
                 tp_epoch, pkey_epoch, save_base_dir, ablation, _run):
    # input params
    tp_ckpt_load_dir = "transporter_exp/" + img_input + "/" + noise_type \
                       + "/" + env + "/" + str(num_keypoints) + "/" + \
                       tp_fname + "/ckpt_"
    pkey_ckpt_load_dir = "permakey_exp/" + img_input + "/" + noise_type \
                         + "/" + env + "/" + str(num_keypoints) + "/" + \
                         pkey_fname + "/ckpt_"

    # numerical string after '.' serves as a unique model_id
    if not ablation:
        model_id = pkey_ckpt_load_dir.split(".")[1][0:6] + "_" \
                   + tp_ckpt_load_dir.split(".")[1][0:6]
    else:
        model_id = pkey_ckpt_load_dir.split(".")[1][0:6]

    save_dir = save_base_dir + img_input + "/" + noise_type + "/" \
               + env + "/" + str(num_keypoints) + "/" + model_id

    # set up data pipeline
    if img_input == "dm_atari":
        eval_dataset = preprocess.deepmind_atari(data_dir, env, eval_split,
                                                 loss_to_use, batch_size,
                                                 noise_type, colour_input)
    else:
        raise ValueError("Eval data %s does not exist" % img_input)

    # create models for the best checkpoints
    pkey_model_list = create_model()
    tp_kp_model_list = transporter_train.create_model(colour_input,
                                                      num_keypoints, 0.1,
                                                      "transporter")

    # unpack models from the permakey model list
    encoder, decoder, lsp_models, pnet = (pkey_model_list[0],
                                          pkey_model_list[1],
                                          pkey_model_list[2],
                                          pkey_model_list[3])

    # run one forward pass so the subclassed models build their weights
    # before load_weights() is called
    if colour_input:
        test_inputs = tf.zeros((batch_size, img_size, img_size, 3))
    else:
        test_inputs = tf.zeros((batch_size, img_size, img_size, 1))
    _ = ul_loss.pkey_loss(pkey_model_list, test_inputs, latent_dim_size,
                          patch_sizes, batch_size, img_size, lsp_layers,
                          loss_to_use, training=True)

    # restore the best permakey model weights
    encoder.load_weights(pkey_ckpt_load_dir + 'encoder-' + str(pkey_epoch) + '.h5')
    decoder.load_weights(pkey_ckpt_load_dir + 'decoder-' + str(pkey_epoch) + '.h5')
    pnet.load_weights(pkey_ckpt_load_dir + 'pnet-' + str(pkey_epoch) + '.h5')
    for m in range(len(lsp_models)):
        lsp_models[m].load_weights(pkey_ckpt_load_dir + 'lsp_model-layer-' +
                                   str(lsp_layers[m]) + '-' +
                                   str(pkey_epoch) + '.h5')
    pkey_model_list = [encoder, decoder, lsp_models, pnet]

    # unpack models from the transporter model list (tp_decoder avoids
    # shadowing the permakey decoder above)
    tp_encoder, keypointer, tp_decoder = (tp_kp_model_list[0],
                                          tp_kp_model_list[1],
                                          tp_kp_model_list[2])

    if colour_input:
        test_inputs = tf.zeros((batch_size, img_size, img_size, 3, 2))
    else:
        test_inputs = tf.zeros((batch_size, img_size, img_size, 1, 2))
    _ = ul_loss.transporter_loss(test_inputs, tp_encoder, keypointer,
                                 tp_decoder, training=True)

    # restore the best transporter model weights
    tp_encoder.load_weights(tp_ckpt_load_dir + 'encoder-' + str(tp_epoch) + '.h5')
    tp_decoder.load_weights(tp_ckpt_load_dir + 'decoder-' + str(tp_epoch) + '.h5')
    keypointer.load_weights(tp_ckpt_load_dir + 'keypointer-' + str(tp_epoch) + '.h5')

    batch_num = 0
    for x_test in eval_dataset:
        batch_num = batch_num + 1

        # inference with the permakey model
        x_pred, kpts, gauss_mask, error_mask, _ = ul_loss.pkey_loss(
            pkey_model_list, x_test, latent_dim_size, patch_sizes, batch_size,
            img_size, lsp_layers, loss_to_use, training=False)

        # inference with the transporter model: duplicate the batch to form
        # the frame pairs transporter_loss expects
        tp_x_test = tf.stack([x_test, x_test], axis=4)
        kpts_tp, gauss_mask_tp, features, x_pred_tp, _ = ul_loss.transporter_loss(
            tp_x_test, tp_encoder, keypointer, tp_decoder, training=False)

        # log results for visualisation
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # save data from the permakey model
        np.savez(save_dir + "/" + "batch_" + str(batch_num) + "_preds_masks.npz",
                 x_pred, x_test.numpy(), kpts, gauss_mask, error_mask)

        # save data from the transporter model
        np.savez(save_dir + "/" + "batch_" + str(batch_num) + "_keypoints.npz",
                 x_pred_tp, x_test.numpy(), kpts_tp, gauss_mask_tp)

    return 0
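# Aside: a minimal, self-contained sketch (illustrative only, not part of this
# module) of the frame-pair trick used in compare_kpts above. transporter_loss
# consumes frame pairs stacked on a new trailing axis, so duplicating a single
# evaluation batch along axis 4 turns (B, H, W, C) into (B, H, W, C, 2).
import tensorflow as tf

x = tf.zeros((8, 84, 84, 3))        # a batch of single frames
pair = tf.stack([x, x], axis=4)     # duplicate each frame into a "pair"
assert pair.shape == (8, 84, 84, 3, 2)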
def train(img_input, data_dir, env, batch_size, loss_to_use, img_size,
          lsp_layers, latent_dim_size, max_patience, colour_input, noise_type,
          patch_sizes, learning_rate, decay_rate, decay_steps, epochs,
          checkpoint_prefix, _run):
    # set up data pipeline
    if img_input == "dm_atari":
        train_dataset = preprocess.deepmind_atari(data_dir, env, "train",
                                                  loss_to_use, batch_size,
                                                  noise_type, colour_input)
        valid_dataset = preprocess.deepmind_atari(data_dir, env, "valid",
                                                  loss_to_use, batch_size,
                                                  noise_type, colour_input)
    else:
        raise ValueError("Input data %s does not exist" % img_input)

    # create models
    model_list = create_model()

    # unpack models
    encoder, decoder, lsp_models, pnet = (model_list[0], model_list[1],
                                          model_list[2], model_list[3])

    # set up the optimizer and learning-rate decay
    lr_decay = tf.keras.optimizers.schedules.InverseTimeDecay(learning_rate,
                                                              decay_steps,
                                                              decay_rate,
                                                              staircase=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_decay)

    def train_step(inputs, loss_type):
        # lsp_tape is persistent because .gradient() is called once per
        # lsp model below
        with tf.GradientTape() as vae_tape, \
                tf.GradientTape(persistent=True) as lsp_tape, \
                tf.GradientTape() as pnet_tape:
            loss_list = ul_loss.pkey_loss(model_list, inputs, latent_dim_size,
                                          patch_sizes, batch_size, img_size,
                                          lsp_layers, loss_type, training=True)
        nll_loss, kl_loss, lsp_loss, pnet_loss = (loss_list[0], loss_list[1],
                                                  loss_list[2], loss_list[3])
        vae_loss = nll_loss + kl_loss

        # vae update
        vae_params = encoder.trainable_variables + decoder.trainable_variables
        vae_grads = vae_tape.gradient(vae_loss, vae_params)
        optimizer.apply_gradients(zip(vae_grads, vae_params))

        # lsp update
        for l in range(len(lsp_models)):
            lsp_params = lsp_models[l].trainable_variables
            lsp_grads = lsp_tape.gradient(lsp_loss, lsp_params)
            optimizer.apply_gradients(zip(lsp_grads, lsp_params))

        # pnet update
        pnet_params = pnet.trainable_variables
        pnet_grads = pnet_tape.gradient(pnet_loss, pnet_params)
        optimizer.apply_gradients(zip(pnet_grads, pnet_params))

        return loss_list

    def test_step(inputs):
        x_hat, kpts, gauss_mask, error_mask, loss = ul_loss.pkey_loss(
            model_list, inputs, latent_dim_size, patch_sizes, batch_size,
            img_size, lsp_layers, loss_to_use, training=False)
        return x_hat, kpts, gauss_mask, error_mask, loss

    # TRAINING LOOP
    best_validation_loss, best_validation_epoch = float("inf"), -1
    patience = 0
    step = 0
    for epoch in range(epochs):
        total_nll_loss = 0.0
        total_kl_loss = 0.0
        total_lsp_loss = 0.0
        total_pnet_loss = 0.0
        num_batches = 0

        # TRAIN LOOP
        start_time_epoch = time.time()
        for x_train in train_dataset:
            start_time = time.time()
            loss = train_step(x_train, loss_to_use)
            nll_loss, kl_loss, lsp_loss, pnet_loss = (loss[0], loss[1],
                                                      loss[2], loss[3])
            print("batch number: %4d nll_loss: %4.5f kl_loss: %4.5f "
                  "lsp_loss: %4.5f pnet_loss: %4.5f took %4.5f s"
                  % (num_batches, nll_loss.numpy(), kl_loss.numpy(),
                     lsp_loss.numpy(), pnet_loss.numpy(),
                     time.time() - start_time))
            total_nll_loss = total_nll_loss + nll_loss
            total_kl_loss = total_kl_loss + kl_loss
            total_lsp_loss = total_lsp_loss + lsp_loss
            total_pnet_loss = total_pnet_loss + pnet_loss
            num_batches += 1
            step += 1

            # log step losses to Sacred
            add_sacred_log("train.step_nll_loss", float(nll_loss.numpy()), _run)
            add_sacred_log("train.step_kl_loss", float(kl_loss.numpy()), _run)
            add_sacred_log("train.step_lsp_loss", float(lsp_loss.numpy()), _run)
            add_sacred_log("train.step_pnet_loss", float(pnet_loss.numpy()), _run)
            add_sacred_log("train.step", step, _run)

        epoch_nll_loss = total_nll_loss / num_batches
        epoch_kl_loss = total_kl_loss / num_batches
        epoch_lsp_loss = total_lsp_loss / num_batches
        epoch_pnet_loss = total_pnet_loss / num_batches
        batch_per_second = num_batches / (time.time() - start_time_epoch)

        # log avg. epoch losses to Sacred
        add_sacred_log("train.epoch_nll_loss", float(epoch_nll_loss.numpy()), _run)
        add_sacred_log("train.epoch_kl_loss", float(epoch_kl_loss.numpy()), _run)
        add_sacred_log("train.epoch_lsp_loss", float(epoch_lsp_loss.numpy()), _run)
        add_sacred_log("train.epoch_pnet_loss", float(epoch_pnet_loss.numpy()), _run)
        add_sacred_log("train.epoch", epoch, _run)

        # VALIDATION LOOP
        # at the end of every epoch, compute the validation loss and
        # checkpoint models based on it
        total_valid_nll_loss = 0.0
        total_valid_kl_loss = 0.0
        total_valid_lsp_loss = 0.0
        total_valid_pnet_loss = 0.0
        valid_num_batch = 0
        for x_valid in valid_dataset:
            valid_num_batch = valid_num_batch + 1
            x_pred, kpts, gauss_mask, error_mask, valid_batch_loss = test_step(
                x_valid)
            val_nll_loss, val_kl_loss, val_lsp_loss, val_pnet_loss = \
                (valid_batch_loss[0], valid_batch_loss[1],
                 valid_batch_loss[2], valid_batch_loss[3])
            total_valid_nll_loss = total_valid_nll_loss + val_nll_loss
            total_valid_kl_loss = total_valid_kl_loss + val_kl_loss
            total_valid_lsp_loss = total_valid_lsp_loss + val_lsp_loss
            total_valid_pnet_loss = total_valid_pnet_loss + val_pnet_loss

        epoch_val_nll_loss = total_valid_nll_loss / valid_num_batch
        epoch_val_kl_loss = total_valid_kl_loss / valid_num_batch
        epoch_val_lsp_loss = total_valid_lsp_loss / valid_num_batch
        epoch_val_pnet_loss = total_valid_pnet_loss / valid_num_batch

        # print avg. train losses at the end of every epoch (the original
        # mistakenly printed the validation lsp/pnet values here)
        print("end of epoch: %2d avg. train_nll_loss: %3.4f "
              "avg. train_kl_loss: %3.4f avg. train_lsp_loss: %3.4f "
              "avg. pnet_loss: %3.4f batch/s: %3.4f"
              % (epoch, epoch_nll_loss.numpy(), epoch_kl_loss.numpy(),
                 epoch_lsp_loss.numpy(), epoch_pnet_loss.numpy(),
                 batch_per_second))

        # log validation losses to Sacred
        add_sacred_log("validation.epoch_nll_loss",
                       float(epoch_val_nll_loss.numpy()), _run)
        add_sacred_log("validation.epoch_kl_loss",
                       float(epoch_val_kl_loss.numpy()), _run)
        add_sacred_log("validation.epoch_lsp_loss",
                       float(epoch_val_lsp_loss.numpy()), _run)
        add_sacred_log("validation.epoch_pnet_loss",
                       float(epoch_val_pnet_loss.numpy()), _run)
        add_sacred_log("validation.epoch", epoch, _run)

        # checkpoint models based on the validation loss
        validation_loss = epoch_val_pnet_loss + epoch_val_lsp_loss
        if validation_loss < best_validation_loss:
            best_validation_loss, best_validation_epoch = validation_loss, epoch
            encoder.save_weights(checkpoint_prefix + '_encoder-' +
                                 str(best_validation_epoch) + '.h5')
            decoder.save_weights(checkpoint_prefix + '_decoder-' +
                                 str(best_validation_epoch) + '.h5')
            for m in range(len(lsp_models)):
                lsp_models[m].save_weights(checkpoint_prefix +
                                           '_lsp_model-layer-' +
                                           str(lsp_layers[m]) + '-' +
                                           str(best_validation_epoch) + '.h5')
            pnet.save_weights(checkpoint_prefix + '_pnet-' +
                              str(best_validation_epoch) + '.h5')
            # reset the early-stopping counter
            patience = 0
        else:
            # early-stopping check: break out if max_patience is reached
            patience = patience + 1
            if patience == max_patience:
                break

    print("Training complete!! Best validation loss: %3.4f achieved at epoch"
          ": %2d" % (best_validation_loss, best_validation_epoch))
    add_sacred_log("validation.best_val_loss", float(best_validation_loss), _run)
    add_sacred_log("validation.best_val_epoch", best_validation_epoch, _run)

    return best_validation_loss, best_validation_epoch
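# Aside: a minimal sketch (illustrative only) of why lsp_tape above is created
# with persistent=True. A non-persistent GradientTape releases its resources
# after the first .gradient() call, so calling it once per lsp model requires
# a persistent tape.
import tensorflow as tf

x = tf.Variable(3.0)
y = tf.Variable(2.0)
with tf.GradientTape(persistent=True) as tape:
    loss = x * x + y * y
print(tape.gradient(loss, x).numpy())  # 6.0
print(tape.gradient(loss, y).numpy())  # 4.0; raises on a non-persistent tape
del tape  # release the tape's resources explicitly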
def evaluate(data_dir, env, ckpt_load_dir, test_logs_prefix, loss_to_use,
             noise_type, eval_split, img_input, colour_input, num_keypoints,
             gauss_std, batch_size, epoch, _run):
    test_inputs, keypoints, heatmaps, x_pred = 0.0, 0.0, 0.0, 0.0
    encoder, keypointer, decoder = None, None, None

    # set up data pipeline
    if img_input == "dm_atari":
        eval_dataset = preprocess.deepmind_atari(data_dir, env, eval_split,
                                                 loss_to_use, batch_size,
                                                 noise_type, colour_input)
    else:
        raise ValueError("Eval data %s does not exist" % img_input)

    # load the best checkpointed models
    if loss_to_use == "transporter":
        encoder = TransporterEncoder()
        keypointer = TransporterKeypointer(num_keypoints=num_keypoints,
                                           gauss_std=gauss_std)
        decoder = TransporterDecoder(colour_input)

        # run one forward pass so the subclassed models build their weights
        # before load_weights() is called
        if colour_input:
            test_inputs = tf.zeros((batch_size, 84, 84, 3, 2))
        else:
            test_inputs = tf.zeros((batch_size, 84, 84, 1, 2))
        _ = ul_loss.transporter_loss(test_inputs, encoder, keypointer,
                                     decoder, training=True)

        # restore the best model weights
        encoder.load_weights(ckpt_load_dir + 'encoder-' + str(epoch) + '.h5')
        decoder.load_weights(ckpt_load_dir + 'decoder-' + str(epoch) + '.h5')
        keypointer.load_weights(ckpt_load_dir + 'keypointer-' + str(epoch) + '.h5')

    batch_num = 0
    test_recon_loss = 0.0
    for x_test in eval_dataset:
        batch_num = batch_num + 1
        if loss_to_use == "transporter":
            keypoints, heatmaps, features, x_pred, loss = ul_loss.transporter_loss(
                x_test, encoder, keypointer, decoder, training=False)
            test_recon_loss = test_recon_loss + loss

        # save data
        if not os.path.exists(test_logs_prefix):
            os.makedirs(test_logs_prefix)
        np.savez(test_logs_prefix + "/" + "epoch_" + str(epoch) + "_batch_" +
                 str(batch_num) + "_keypoints.npz",
                 x_pred, x_test.numpy(), keypoints, heatmaps)

    # log the avg. test loss to Sacred
    test_recon_loss = test_recon_loss / batch_num
    add_sacred_log("test.epoch_recon_loss", float(test_recon_loss.numpy()), _run)
    print("avg. test_recon_loss: %3.4f" % test_recon_loss.numpy())

    return 0
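# Aside: a minimal sketch (illustrative only; TinyNet is a hypothetical
# stand-in for this repo's subclassed models) of the "run one forward pass
# before load_weights" fix used above. Subclassed Keras models create their
# variables lazily on the first call, so loading weights into an unbuilt
# model fails.
import tensorflow as tf

class TinyNet(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return self.dense(inputs)

net = TinyNet()
_ = net(tf.zeros((1, 8)))      # builds the variables, mirroring the dummy pass
net.save_weights("tiny.h5")
net.load_weights("tiny.h5")    # would raise if the model were still unbuilt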
def evaluate(ckpt_load_dir, test_logs_prefix, data_dir, loss_to_use, img_size,
             latent_dim_size, env, img_input, colour_input, patch_sizes,
             lsp_layers, noise_type, batch_size, epoch, eval_split, _run):
    # set up data pipeline
    if img_input == "dm_atari":
        eval_dataset = preprocess.deepmind_atari(data_dir, env, eval_split,
                                                 loss_to_use, batch_size,
                                                 noise_type, colour_input)
    else:
        raise ValueError("Eval data %s does not exist" % img_input)

    # load the best checkpointed models
    model_list = create_model()

    # unpack models from the model list
    encoder, decoder, lsp_models, pnet = (model_list[0], model_list[1],
                                          model_list[2], model_list[3])

    # run one forward pass so the subclassed models build their weights
    # before load_weights() is called
    if colour_input:
        test_inputs = tf.zeros((batch_size, img_size, img_size, 3))
    else:
        test_inputs = tf.zeros((batch_size, img_size, img_size, 1))
    _ = ul_loss.pkey_loss(model_list, test_inputs, latent_dim_size,
                          patch_sizes, batch_size, img_size, lsp_layers,
                          loss_to_use, training=True)

    # restore the best model weights
    encoder.load_weights(ckpt_load_dir + 'encoder-' + str(epoch) + '.h5')
    decoder.load_weights(ckpt_load_dir + 'decoder-' + str(epoch) + '.h5')
    pnet.load_weights(ckpt_load_dir + 'pnet-' + str(epoch) + '.h5')
    for m in range(len(lsp_models)):
        lsp_models[m].load_weights(ckpt_load_dir + 'lsp_model-layer-' +
                                   str(lsp_layers[m]) + '-' + str(epoch) + '.h5')
    model_list = [encoder, decoder, lsp_models, pnet]

    batch_num = 0
    test_nll_loss = 0.0
    test_kl_loss = 0.0
    test_lsp_loss = 0.0
    test_pnet_loss = 0.0
    for x_test in eval_dataset:
        batch_num = batch_num + 1
        x_pred, kpts, gauss_mask, error_mask, loss = ul_loss.pkey_loss(
            model_list, x_test, latent_dim_size, patch_sizes, batch_size,
            img_size, lsp_layers, loss_to_use, training=False)
        nll_loss, kl_loss, lsp_loss, pnet_loss = (loss[0], loss[1],
                                                  loss[2], loss[3])
        test_nll_loss = test_nll_loss + nll_loss
        test_kl_loss = test_kl_loss + kl_loss
        test_lsp_loss = test_lsp_loss + lsp_loss
        test_pnet_loss = test_pnet_loss + pnet_loss

        # save data
        if not os.path.exists(test_logs_prefix):
            os.makedirs(test_logs_prefix)
        np.savez(test_logs_prefix + "/" + "epoch_" + str(epoch) + "_batch_" +
                 str(batch_num) + "_preds_masks.npz",
                 x_pred, x_test.numpy(), kpts, gauss_mask, error_mask)

    # average test losses
    test_nll_loss = test_nll_loss / batch_num
    test_kl_loss = test_kl_loss / batch_num
    test_lsp_loss = test_lsp_loss / batch_num
    test_pnet_loss = test_pnet_loss / batch_num

    # log avg. test losses to Sacred
    add_sacred_log("test.epoch_nll_loss", float(test_nll_loss.numpy()), _run)
    add_sacred_log("test.epoch_kl_loss", float(test_kl_loss.numpy()), _run)
    add_sacred_log("test.epoch_lsp_loss", float(test_lsp_loss.numpy()), _run)
    add_sacred_log("test.epoch_pnet_loss", float(test_pnet_loss.numpy()), _run)
    print("%s: avg. nll_loss: %3.4f avg. kl_loss: %3.4f avg. lsp_loss: %3.4f "
          "avg. pnet_loss: %3.4f"
          % (eval_split, test_nll_loss.numpy(), test_kl_loss.numpy(),
             test_lsp_loss.numpy(), test_pnet_loss.numpy()))

    # return the same combined loss the training loop uses for model selection
    loss = test_pnet_loss + test_lsp_loss
    return loss
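# Aside: a minimal sketch (illustrative only) of how the .npz batch files
# written above would be read back. Arrays passed positionally to np.savez are
# stored under the keys 'arr_0', 'arr_1', ... in argument order, so here
# arr_0 holds x_pred, arr_1 the input batch, and so on.
import numpy as np

np.savez("demo.npz", np.zeros((2, 2)), np.ones(3))
with np.load("demo.npz") as data:
    print(sorted(data.files))  # ['arr_0', 'arr_1']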
def train(img_input, data_dir, env, batch_size, loss_to_use, decay_steps,
          decay_rate, max_patience, colour_input, noise_type, learning_rate,
          epochs, checkpoint_prefix, _run):
    # set up data pipeline
    if img_input == "dm_atari":
        train_dataset = preprocess.deepmind_atari(data_dir, env, "train",
                                                  loss_to_use, batch_size,
                                                  noise_type, colour_input)
        valid_dataset = preprocess.deepmind_atari(data_dir, env, "valid",
                                                  loss_to_use, batch_size,
                                                  noise_type, colour_input)
    else:
        raise ValueError("Input data %s does not exist" % img_input)

    # create models
    if loss_to_use == "transporter":
        encoder, keypointer, decoder = create_model()

    # set up the optimizer and learning-rate decay
    lr_decay = tf.keras.optimizers.schedules.InverseTimeDecay(learning_rate,
                                                              decay_steps,
                                                              decay_rate,
                                                              staircase=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_decay)

    def train_step(images, loss_type):
        if loss_type == "transporter":
            with tf.GradientTape() as tape:
                reconstruction_loss = ul_loss.transporter_loss(images, encoder,
                                                               keypointer,
                                                               decoder,
                                                               training=True)
            # update params
            model_params = (encoder.trainable_variables +
                            keypointer.trainable_variables +
                            decoder.trainable_variables)
            grads = tape.gradient(reconstruction_loss, model_params)
            optimizer.apply_gradients(zip(grads, model_params))
            return reconstruction_loss

    def test_step(images, loss_type):
        if loss_type == "transporter":
            keypoints, heatmaps, features, x_hat, loss = ul_loss.transporter_loss(
                images, encoder, keypointer, decoder, training=False)
            return keypoints, heatmaps, x_hat, loss

    # TRAINING LOOP
    best_validation_loss, best_validation_epoch = float("inf"), -1
    patience = 0
    step = 0
    for epoch in range(epochs):
        total_recon_loss = 0.0
        num_batches = 0

        # TRAIN LOOP
        start_time_epoch = time.time()
        for x_train in train_dataset:
            start_time = time.time()
            loss = train_step(x_train, loss_to_use)
            print("batch number: %4d reconstruction_loss: %4.5f took %4.5f s"
                  % (num_batches, loss.numpy(), time.time() - start_time))
            total_recon_loss = total_recon_loss + loss
            num_batches += 1
            step += 1

            # log the step reconstruction loss to Sacred
            add_sacred_log("train.step_reconstruction_loss",
                           float(loss.numpy()), _run)
            add_sacred_log("train.step", step, _run)

        epoch_recon_loss = total_recon_loss / num_batches
        batch_per_second = num_batches / (time.time() - start_time_epoch)

        # log avg. epoch losses to Sacred
        add_sacred_log("train.epoch_reconstruction_loss",
                       float(epoch_recon_loss.numpy()), _run)
        add_sacred_log("train.epoch", epoch, _run)

        # VALIDATION LOOP
        # at the end of every epoch, compute the validation loss and
        # checkpoint models based on it
        total_valid_recon_loss = 0.0
        valid_num_batch = 0
        for x_valid in valid_dataset:
            keypoints, heatmaps, x_hat, valid_batch_loss = test_step(
                x_valid, loss_to_use)
            total_valid_recon_loss = total_valid_recon_loss + valid_batch_loss
            valid_num_batch = valid_num_batch + 1
        epoch_val_recon_loss = total_valid_recon_loss / valid_num_batch

        # print avg. train and validation losses at the end of every epoch
        print("end of epoch: %2d avg. train_recon_loss: %3.4f batch/s: %3.4f"
              % (epoch, epoch_recon_loss.numpy(), batch_per_second))
        print("end of epoch: %2d avg. val_recon_loss: %3.4f"
              % (epoch, epoch_val_recon_loss.numpy()))

        # log validation losses to Sacred (the original logged this under the
        # misleading key "validation.epoch_nll_loss")
        add_sacred_log("validation.epoch_reconstruction_loss",
                       float(epoch_val_recon_loss.numpy()), _run)
        add_sacred_log("validation.epoch", epoch, _run)

        # checkpoint models based on the validation loss
        validation_loss = epoch_val_recon_loss
        if validation_loss.numpy() < best_validation_loss:
            best_validation_loss, best_validation_epoch = (
                validation_loss.numpy(), epoch)
            encoder.save_weights(checkpoint_prefix + '_encoder-' +
                                 str(best_validation_epoch) + '.h5')
            decoder.save_weights(checkpoint_prefix + '_decoder-' +
                                 str(best_validation_epoch) + '.h5')
            keypointer.save_weights(checkpoint_prefix + '_keypointer-' +
                                    str(best_validation_epoch) + '.h5')
            # reset the early-stopping counter
            patience = 0
        else:
            # early-stopping check: break out if max_patience is reached
            patience = patience + 1
            if patience == max_patience:
                break

    print("Training complete!! Best validation loss: %3.4f achieved at "
          "epoch: %2d" % (best_validation_loss, best_validation_epoch))
    add_sacred_log("validation.best_val_loss", float(best_validation_loss), _run)
    add_sacred_log("validation.best_val_epoch", best_validation_epoch, _run)

    return best_validation_loss, best_validation_epoch
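# Aside: a minimal sketch (illustrative only) of the InverseTimeDecay schedule
# both train() functions use. With staircase=True the learning rate at a given
# step is initial_lr / (1 + decay_rate * floor(step / decay_steps)), i.e. it
# drops in discrete steps rather than continuously.
import tensorflow as tf

sched = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.5,
    staircase=True)
print(float(sched(0)))     # 0.001
print(float(sched(999)))   # still 0.001 (staircase holds it flat)
print(float(sched(1000)))  # 0.001 / 1.5 ~= 0.000667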