def generate_validation_data():
    """Regenerate the module-level validation and test data sets.

    Rebinds the globals ``val_inputs``/``val_shifts``/``val_targets`` and
    ``tst_inputs``/``tst_shifts``/``tst_targets`` with the results of two
    ``generate_data_2d`` calls.  Both sets hold ``cfg.n_val_samples`` samples
    of size ``cfg.width`` x ``cfg.height``; the test set is additionally
    requested with ``binary=True``.
    """
    global val_inputs, val_shifts, val_targets
    global tst_inputs, tst_shifts, tst_targets

    print("Generating validation data...")

    # Validation set: generator defaults apart from the hot/id switches.
    # NOTE(review): unlike generate_new_data, the targets are not replaced by
    # a copy of the inputs when use_id_data is set -- confirm that passing
    # force_id is sufficient here.
    validation_set = generate_data_2d(
        cfg.width, cfg.height, cfg.n_val_samples,
        with_hot=with_hot, force_id=use_id_data)
    val_inputs, val_shifts, val_targets = validation_set

    # Test set: same sample count, but generated with binary=True.
    test_set = generate_data_2d(
        cfg.width, cfg.height, cfg.n_val_samples,
        binary=True, with_hot=with_hot, force_id=use_id_data)
    tst_inputs, tst_shifts, tst_targets = test_set

    print("Done.")
def generate_new_data():
    """Replace the module-level training batch with freshly generated data.

    Rebinds the globals ``trn_inputs``, ``trn_shifts`` and ``trn_targets``
    with ``cfg.n_batch`` samples from ``generate_data_2d``.  When
    ``use_id_data`` is set, the targets are overwritten with a copy of the
    inputs.
    """
    global trn_inputs, trn_shifts, trn_targets

    fresh_batch = generate_data_2d(
        cfg.width, cfg.height, cfg.n_batch,
        with_hot=with_hot, force_id=use_id_data)
    trn_inputs, trn_shifts, trn_targets = fresh_batch

    if not use_id_data:
        return

    # Identity mode: the target is the (unshifted) input itself.
    trn_targets = trn_inputs.copy()
# ---- final weights ----
# Redraw the weight plot on a cleared figure, titled with the last iteration.
plt.clf()
plot_all_weights(ps)
plt.suptitle("final iter=%d" % iteration)
plt.draw()
if ml.common.plot.headless:
    # The weight figure is only written to disk in headless mode.
    plt.savefig(plot_dir + "/weights_final.pdf")

# ---- final loss ----
# The loss history plot is always saved, regardless of headless mode.
plt.clf()
his.plot()
plt.savefig(plot_dir + "/loss.pdf")

# ---- check with simple patterns ----
# Three samples generated with binary=True, run through the model output
# function and printed next to their shifts and targets.
sim_inputs, sim_shifts, sim_targets = generate_data_2d(
    cfg.width, cfg.height, 3,
    binary=True, with_hot=with_hot, force_id=use_id_data)
if use_id_data:
    # Identity mode: the expected output is the input itself.
    sim_targets = sim_inputs.copy()
sim_results = gather(
    f_output(ps.data, post(sim_inputs), post(sim_shifts)))

# Only column 0 (the first of the three generated samples, assuming samples
# are stored column-wise -- TODO confirm) is shown.  threshold=None is passed
# straight through to format_sample.
threshold = None
print("input: ")
print(format_sample(sim_inputs[:, 0], cfg.width, cfg.height,
                    threshold=threshold))
print("shift: ")
print(sim_shifts[:, 0])
print("targets: ")
print(format_sample(sim_targets[:, 0], cfg.width, cfg.height,
                    threshold=threshold))
print("results: ")
print(format_sample(sim_results[:, 0], cfg.width, cfg.height,
                    threshold=threshold))

# test on global test set