def load_df(exp_id, min_shots = None):
    """Load every saved experiment record of ``exp_id`` and summarize them in a DataFrame.

    Scans ``<variational_model_PATH>/trained_models/<exp_id>_good/`` for files whose
    name contains ``"info"``, unpickles each one, rebuilds the last saved model,
    plots its training record and task predictions, and collects summary metrics.

    Relies on module-level names defined elsewhere in the file:
    ``variational_model_PATH``, ``tasks_test``, ``task_id_list``, ``task_settings``,
    ``is_VAE`` and the plotting / loading helpers.

    Args:
        exp_id: Experiment identifier used to build the directory name.
        min_shots: Forwarded to ``plot_few_shot_loss`` for "meta" experiments.

    Returns:
        A ``pandas.DataFrame`` with one row per successfully loaded file. Each row
        holds the fields returned by ``parse_dict(filename)`` plus ``mse_few_shot``,
        ``mean_last``, ``median_last``, ``mean_min``, ``median_min`` and, for "meta"
        experiments, ``corr_info``. Rows are sorted ascending by ``median_last``.
        An empty DataFrame is returned when no record could be loaded.
    """
    dirname = variational_model_PATH + "/trained_models/{0}_good/".format(exp_id)
    data_dict_list = []
    for filename in os.listdir(dirname):
        if "info" not in filename:
            continue

        # NOTE: pickle can execute arbitrary code while loading; only point this
        # function at directories containing trusted experiment output.
        try:
            with open(dirname + filename, "rb") as f:
                info_dict = pickle.load(f)
        except Exception:
            # Skip unreadable files. Previously execution fell through and used an
            # undefined (or the previous iteration's) info_dict.
            print("Cannot load {0}".format(filename))
            continue

        data_record = info_dict["data_record"]
        data_dict = parse_dict(filename)
        exp_mode = data_dict["exp_mode"]

        # Rebuild the most recently saved model for this record.
        if exp_mode == "meta":
            model = load_model_dict(info_dict["model_dict"][-1])
            data_dict["mse_few_shot"] = plot_few_shot_loss(model, tasks_test, min_shots = min_shots)
        elif exp_mode == "baseline":
            print(filename)
            model = load_model_dict_net(info_dict["model_dict"][-1])
            data_dict["mse_few_shot"] = [None]

        print("\n{0}".format(filename))
        plot_data_record(data_record, is_VAE = is_VAE)

        # Summary metrics. The [-4:-1] slice averages the three entries before the
        # final one (the final entry itself is excluded).
        # NOTE(review): load_file uses [-6:-1] for the same metrics; confirm which
        # window is intended.
        data_dict["mean_last"] = np.mean(data_record['mse_mean_test'][-4:-1])
        data_dict["median_last"] = np.mean(data_record['mse_median_test'][-4:-1])
        data_dict["mean_min"] = np.min(data_record['mse_mean_test'])
        data_dict["median_min"] = np.min(data_record['mse_median_test'])
        # The first field used to be "{0:6f}" (a width, not a precision).
        print("mean_last: {0:.6f}\tmean_min: {1:.6f}\tmedian_last: {2:.6f}\tmedian_min: {3:.6f}".format(
            data_dict["mean_last"], data_dict["mean_min"], data_dict["median_last"], data_dict["median_min"]))

        if exp_mode == "meta":
            statistics_list_test, z_list_test = plot_task_ensembles(
                tasks_test, model.statistics_Net, model.generative_Net,
                is_VAE = is_VAE, title = "y_pred_test vs. y_test")
            print("test statistics vs. z:")
            data_dict["corr_info"] = plot_statistics_vs_z(z_list_test, statistics_list_test)

        if "bounce" in task_id_list[0]:
            if exp_mode == "meta":
                plot_individual_tasks_bounce(tasks_test, num_examples_show = 40, num_tasks_show = 9, master_model = model, num_shots = 200)
            elif exp_mode == "baseline":
                plot_individual_tasks_bounce(tasks_test, num_examples_show = 40, num_tasks_show = 9, model = model, num_shots = 200)
        elif exp_mode == "meta":
            # Use the model loaded above. The original referenced `master_model`,
            # which is not defined in this function.
            # NOTE(review): baseline models come from load_model_dict_net and are
            # not assumed to expose statistics_Net / generative_Net, so they are
            # skipped here - confirm.
            _ = plot_individual_tasks(tasks_test, model.statistics_Net, model.generative_Net, is_VAE = is_VAE, xlim = task_settings["xlim"])

        data_dict_list.append(data_dict)

    df = pd.DataFrame(data_dict_list)
    if not data_dict_list:
        # Nothing loaded: sorting by a missing column would raise KeyError.
        return df
    df = df.sort_values(by = "median_last")
    return df
def load_file(exp_id, filename):
    """Load one saved experiment record, plot its results and return the model.

    Unpickles ``<variational_model_PATH>/trained_models/<exp_id>_good/<filename>``,
    rebuilds the last saved model with ``load_model_dict``, plots the training
    record and the test-task predictions, and collects summary metrics.

    Relies on module-level names defined elsewhere in the file:
    ``variational_model_PATH``, ``tasks_test``, ``task_id_list``, ``task_settings``,
    ``is_VAE`` and the plotting / loading helpers.

    Args:
        exp_id: Experiment identifier used to build the directory name.
        filename: Name of the pickled info file inside that directory.

    Returns:
        A tuple ``(master_model, info_dict, data_dict)``: the rebuilt model, the
        unpickled record, and the fields from ``parse_dict(filename)`` extended
        with ``mse_few_shot``, ``mean_last``, ``median_last``, ``mean_min``,
        ``median_min`` and ``corr_info``.
    """
    dirname = variational_model_PATH + "/trained_models/{0}_good/".format(exp_id)

    # NOTE: pickle can execute arbitrary code while loading; only load trusted
    # experiment output. The context manager closes the handle, which the
    # original left open.
    with open(dirname + filename, "rb") as f:
        info_dict = pickle.load(f)

    data_record = info_dict["data_record"]
    data_dict = parse_dict(filename)
    master_model = load_model_dict(info_dict["model_dict"][-1])

    print("\n{0}".format(filename))
    data_dict["mse_few_shot"] = plot_few_shot_loss(master_model, tasks_test)
    plot_data_record(data_record, is_VAE = is_VAE)

    # Summary metrics. The [-6:-1] slice averages the five entries before the
    # final one (the final entry itself is excluded).
    data_dict["mean_last"] = np.mean(data_record['mse_mean_test'][-6:-1])
    data_dict["median_last"] = np.mean(data_record['mse_median_test'][-6:-1])
    data_dict["mean_min"] = np.min(data_record['mse_mean_test'])
    data_dict["median_min"] = np.min(data_record['mse_median_test'])
    # The first field used to be "{0:6f}" (a width, not a precision).
    print("mean_last: {0:.6f}\tmean_min: {1:.6f}\tmedian_last: {2:.6f}\tmedian_min: {3:.6f}".format(
        data_dict["mean_last"], data_dict["mean_min"], data_dict["median_last"], data_dict["median_min"]))

    statistics_list_test, z_list_test = plot_task_ensembles(
        tasks_test, master_model.statistics_Net, master_model.generative_Net,
        is_VAE = is_VAE, title = "y_pred_test vs. y_test")
    print("test statistics vs. z:")
    data_dict["corr_info"] = plot_statistics_vs_z(z_list_test, statistics_list_test)

    if "bounce" in task_id_list[0]:
        plot_individual_tasks_bounce(tasks_test, num_examples_show = 40, num_tasks_show = 9, master_model = master_model, num_shots = 200)
    else:
        _ = plot_individual_tasks(tasks_test, master_model.statistics_Net, master_model.generative_Net, is_VAE = is_VAE, xlim = task_settings["xlim"])

    return master_model, info_dict, data_dict
statistics_list_train, z_list_train = plot_task_ensembles(tasks_train, statistics_Net, generative_Net, is_VAE = is_VAE, is_regulated_net = is_regulated_net, title = "y_pred_train vs. y_train", isplot = isplot) statistics_list_test, z_list_test = plot_task_ensembles(tasks_test, statistics_Net, generative_Net, is_VAE = is_VAE, is_regulated_net = is_regulated_net, title = "y_pred_test vs. y_test", isplot = isplot) record_data(data_record, [np.array(z_list_train), np.array(z_list_test), np.array(statistics_list_train), np.array(statistics_list_test)], ["z_list_train_list", "z_list_test_list", "statistics_list_train_list", "statistics_list_test_list"]) if isplot: print("train statistics vs. z:") plot_statistics_vs_z(z_list_train, statistics_list_train) print("test statistics vs. z:") plot_statistics_vs_z(z_list_test, statistics_list_test) # Plotting individual test data: if "bounce" in task_id_list[0]: plot_individual_tasks_bounce(tasks_test, num_examples_show = 40, num_tasks_show = 6, master_model = master_model, num_shots = 200) else: print("train tasks:") plot_individual_tasks(tasks_train, statistics_Net, generative_Net, generative_Net_logstd = generative_Net_logstd, is_VAE = is_VAE, is_regulated_net = is_regulated_net, xlim = task_settings["xlim"]) print("test tasks:") plot_individual_tasks(tasks_test, statistics_Net, generative_Net, generative_Net_logstd = generative_Net_logstd, is_VAE = is_VAE, is_regulated_net = is_regulated_net, xlim = task_settings["xlim"]) print("=" * 50 + "\n\n") try: sys.stdout.flush() except: pass if i % save_interval == 0 or to_stop: record_data(info_dict, [master_model.model_dict, i], ["model_dict", "iter"]) pickle.dump(info_dict, open(filename + "info.p", "wb")) if to_stop: print("The training loss stops decreasing for {0} steps. Early stopping at {1}.".format(patience, i)) break
# In[ ]: # plot_types = ["standard"] plot_types = ["gradient"] # plot_types = ["slider"] for plot_type in plot_types: if plot_type == "standard": plot_data_record(data_record, is_VAE = is_VAE) statistics_list_test, z_list_test = plot_task_ensembles(tasks_test, statistics_Net, generative_Net, is_VAE = is_VAE, title = "y_pred_test vs. y_test") print("test statistics vs. z:") plot_statistics_vs_z(z_list_test, statistics_list_test) _ = plot_individual_tasks(tasks_test, statistics_Net, generative_Net, is_VAE = is_VAE, xlim = task_settings["xlim"]) elif plot_type == "gradient": batch_size = 256 sample_task_id = task_id_list[0] + "_{0}".format(np.random.randint(num_test_tasks)) print("sample_task_id: {0}".format(sample_task_id)) ((X_train, y_train), (X_test, y_test)), _ = tasks_test[sample_task_id] epochs_statistics = 50 lr_statistics = 1e-3 optim_type_statistics = "adam" # epochs = 50 # lr = 1e-3 # optimizer = "adam" epochs = 50
statistics_list_train, z_list_train = plot_task_ensembles(tasks_train, master_model = master_model, model = model, is_time_series = is_time_series, is_VAE = is_VAE, is_uncertainty_net = is_uncertainty_net, is_regulated_net = is_regulated_net, title = "y_pred_train vs. y_train", isplot = isplot) statistics_list_test, z_list_test = plot_task_ensembles(tasks_test, master_model = master_model, model = model, is_time_series = is_time_series, is_VAE = is_VAE, is_uncertainty_net = is_uncertainty_net, is_regulated_net = is_regulated_net, title = "y_pred_test vs. y_test", isplot = isplot) record_data(data_record, [np.array(z_list_train), np.array(z_list_test), np.array(statistics_list_train), np.array(statistics_list_test)], ["z_list_train_list", "z_list_test_list", "statistics_list_train_list", "statistics_list_test_list"]) if isplot: print("train statistics vs. z:") plot_statistics_vs_z(z_list_train, statistics_list_train) print("test statistics vs. z:") plot_statistics_vs_z(z_list_test, statistics_list_test) # Plotting individual test data: if "bounce" in task_id_list[0]: plot_individual_tasks_bounce(tasks_test, num_examples_show = 40, num_tasks_show = 6, master_model = master_model, model = model, num_shots = 200, valid_input_dims = input_size - z_size, target_forward_steps = len(forward_steps), eval_forward_steps = len(forward_steps)) else: print("train tasks:") plot_individual_tasks(tasks_train, master_model = master_model, model = model, is_time_series = is_time_series, is_VAE = is_VAE, is_uncertainty_net = is_uncertainty_net, is_regulated_net = is_regulated_net, is_oracle = is_oracle, xlim = task_settings["xlim"]) print("test tasks:") plot_individual_tasks(tasks_test, master_model = master_model, model = model, is_time_series = is_time_series, is_VAE = is_VAE, is_uncertainty_net = is_uncertainty_net, is_regulated_net = is_regulated_net, is_oracle = is_oracle, xlim = task_settings["xlim"]) print("=" * 50 + "\n\n") try: sys.stdout.flush() except: pass if i % 
save_interval == 0 or to_stop: if master_model is not None: record_data(info_dict, [master_model.model_dict], ["model_dict"]) else: record_data(info_dict, [model.model_dict], ["model_dict"]) pickle.dump(info_dict, open(filename + "info.p", "wb")) if to_stop: print("The training loss stops decreasing for {0} steps. Early stopping at {1}.".format(patience, i))