# NOTE(review): this chunk arrived collapsed onto a single physical line
# (all newlines stripped), which is a syntax error. The code below
# reconstructs the statement boundaries and the most plausible nesting:
# the summary writes and metric resets belong inside the training `try`
# whose header lies above this chunk (implied by the `except
# KeyboardInterrupt:` that follows them). Confirm exact loop nesting
# against the original repository.
try:
    # ... training loop continues from the part of the file above this chunk ...

    # Log running metrics to TensorBoard. The step axis is expressed in
    # training sequences seen so far: checkpoint step * batch size.
    tf.summary.scalar("loss", train_loss.result(),
                      step=(int(ckpt.step) * arg.batch_size))
    tf.summary.scalar("cost_per_sequence", train_cost.result(),
                      step=(int(ckpt.step) * arg.batch_size))
    # Reset the streaming Keras metrics so the next logging window
    # accumulates from zero instead of averaging over all history.
    train_loss.reset_states()
    train_cost.reset_states()
except KeyboardInterrupt:
    # Let Ctrl-C stop training gracefully and fall through to the
    # evaluation/visualization section below.
    print("User interrupted")

# Visualize the prediction made by the model
if arg.test or arg.visualize:
    # Generate a fresh evaluation batch of copy-task patterns.
    # assumes generate_patterns returns (inputs, targets) — defined
    # elsewhere in this project; TODO confirm signature.
    x, y = generate_patterns(arg.batch_size, arg.max_sequence,
                             arg.min_sequence, arg.in_bits, arg.out_bits,
                             fixed_seq_len=arg.random_seq_len)
    y_pred = ntm_model(x)
    # Pull the NTM's internal traces for plotting: reads (rt), read
    # weights (r_wt), adds (at), write weights (w_wt), memory (Mt).
    rt, r_wt, at, w_wt, Mt = ntm_model.debug_ntm()

    cmap_jet = plt.get_cmap('jet')
    cmap_gray = plt.get_cmap('gray')

    if arg.visualize:
        # One stacked row per internal NTM signal.
        fig_ntm, (ax_at, ax_wwt, ax_mt, ax_rwt, ax_rt) = plt.subplots(5, 1)
        fig_ntm.subplots_adjust(top=0.85, bottom=0.15, left=0.05,
                                right=0.95, hspace=0.3)
        ax_at.set_ylabel('Adds')
        ax_wwt.set_ylabel('Write Weights')
        ax_mt.set_ylabel("Memory")
        ax_rwt.set_ylabel("Read Weights")
        ax_rt.set_ylabel("Reads")
        # Transpose so time runs along the horizontal axis of the plot.
        ax_at.matshow(np.transpose(at), aspect='equal', cmap=cmap_jet)