def compareFixed():
    t = Tasks()
    x_test, y_test = t.sequence_type_1(100)

    add_params = torch.load('program_memory/add.pt')
    mul_params = torch.load('program_memory/mul.pt')

    hnm = HNM(10, 20, add_params, mul_params)
    hnm.load_state_dict(torch.load("learned_params/hnm_arch_2.pt"))

    ntm = NTM(10, 20)
    ntm.load_state_dict(torch.load("learned_params/ntm.pt"))

    lstm = LSTM(14, 256, 325, 1)
    lstm.load_state_dict(torch.load("learned_params/lstm.pt"))

    hnm_diff, ntm_diff, lstm_diff = 0, 0, 0
    for i in range(len(x_test)):
        hnm_out = hnm.recurrent_forward(x_test[i:i + 1])
        ntm_out = ntm.recurrent_forward(x_test[i:i + 1])
        lstm_out = lstm.recurrent_forward(x_test[i:i + 1])

        # accumulate |argmax(prediction) - argmax(target)| for each model
        answer = np.argmax(y_test[i:i + 1].detach().numpy())
        hnm_diff += abs(answer - np.argmax(hnm_out.detach().numpy()))
        ntm_diff += abs(answer - np.argmax(ntm_out.detach().numpy()))
        lstm_diff += abs(answer - np.argmax(lstm_out.detach().numpy()))

    # mean absolute argmax error per model
    print(hnm_diff / len(y_test), ntm_diff / len(y_test),
          lstm_diff / len(y_test))
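# The three accumulations above repeat the same argmax-difference pattern. A
# refactoring sketch; the helper name `mean_argmax_error` is hypothetical,
# while `recurrent_forward` is the interface the models expose above:
def mean_argmax_error(model, x_test, y_test):
    """Mean absolute difference between predicted and true argmax indices."""
    total = 0
    for i in range(len(x_test)):
        out = model.recurrent_forward(x_test[i:i + 1])
        answer = np.argmax(y_test[i:i + 1].detach().numpy())
        total += abs(answer - np.argmax(out.detach().numpy()))
    return total / len(y_test)

# e.g. inside compareFixed(): print(mean_argmax_error(hnm, x_test, y_test))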
def compare():
    obstacle, wall_cw, wall_acw = Obstacle(), WallCW(), WallACW()
    obstacle_params = torch.load('program_memory/move.pt')
    wall_cw_params = torch.load('program_memory/cw.pt')
    wall_acw_params = torch.load('program_memory/acw.pt')

    networks = [obstacle, wall_cw, wall_acw]
    params = [obstacle_params, wall_cw_params, wall_acw_params]

    hnm = HNM(10, 14, networks, params)
    hnm.load_state_dict(torch.load('learned_params/hnm.pt'))

    ntm = NTM(10, 14)
    ntm.load_state_dict(torch.load('learned_params/ntm.pt'))

    lstm = LSTM(14, 64, 3, 1)
    lstm.load_state_dict(torch.load('learned_params/lstm.pt'))

    testX, testY = getTestData()

    hnm_correct, ntm_correct, lstm_correct = 0, 0, 0
    totSamples = 0
    for i in range(25):
        s = torch.from_numpy(np.array(testX[i:i + 1][0])).float().unsqueeze(0)
        s_lstm = s.view(s.size()[0], s.size()[2], -1)
        l = np.array(testY[i:i + 1][0])
        print(i)

        # reset each model's recurrent state before every sequence
        hnm_read_weights, hnm_write_weights = hnm._initialise()
        ntm_read_weights, ntm_write_weights = ntm._initialise()
        lstm_h = lstm.h0.expand(s_lstm.size()[0], 64)
        lstm_c = lstm.c0.expand(s_lstm.size()[0], 64)

        for j in range(s.size()[1]):
            hnm_out, hnm_read_weights, hnm_write_weights = hnm.forward(
                s[:, j, :], hnm_read_weights, hnm_write_weights)
            ntm_out, ntm_read_weights, ntm_write_weights = ntm.forward(
                s[:, j, :], ntm_read_weights, ntm_write_weights)
            lstm_h, lstm_c, lstm_out = lstm.forward(
                s_lstm[:, :, j], lstm_h, lstm_c)

            # count per-timestep argmax matches for each model
            if np.argmax(hnm_out.detach().numpy()) == np.argmax(l[j]):
                hnm_correct += 1
            if np.argmax(ntm_out.detach().numpy()) == np.argmax(l[j]):
                ntm_correct += 1
            if np.argmax(lstm_out.detach().numpy()) == np.argmax(l[j]):
                lstm_correct += 1
            totSamples += 1

    print(hnm_correct, ntm_correct, lstm_correct)
    print(totSamples)
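    # A small follow-up sketch (the percentage formatting is an addition; the
    # counters are the ones accumulated above): per-timestep accuracy per model.
    for name, correct in [('HNM', hnm_correct), ('NTM', ntm_correct),
                          ('LSTM', lstm_correct)]:
        print('%s accuracy: %.2f%%' % (name, 100.0 * correct / totSamples))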
def generate_target_original_plots(iteration, task_params, model_path, image_output):
    dataset = PrioritySort(task_params)
    criterion = nn.BCELoss()

    ntm = NTM(input_size=task_params['seq_width'] + 1,
              output_size=task_params['seq_width'],
              controller_size=task_params['controller_size'],
              memory_units=task_params['memory_units'],
              memory_unit_size=task_params['memory_unit_size'],
              num_heads=task_params['num_heads'],
              save_weigths=True,
              multi_layer_controller=task_params['multi_layer_controller'])
    ntm.load_state_dict(torch.load(model_path))

    # -------------------------------------------------------------------------
    # --- evaluation
    # -------------------------------------------------------------------------
    ntm.reset()
    data = dataset[0]  # 0 is a dummy index
    input, target = data['input'], data['target']
    out = torch.zeros(target.size())

    # -------------------------------------------------------------------------
    # loop for other tasks
    # -------------------------------------------------------------------------
    for i in range(input.size()[0]):
        # unsqueeze to keep dimensions consistent (torch.cat was raising an error otherwise)
        in_data = torch.unsqueeze(input[i], 0)
        ntm(in_data)

    # pass a zero vector as the input while generating the target sequence
    in_data = torch.unsqueeze(torch.zeros(input.size()[1]), 0)
    for i in range(target.size()[0]):
        out[i] = ntm(in_data)

    loss = criterion(out, target)
    binary_output = out.clone()
    binary_output = binary_output.detach().apply_(lambda x: 0 if x < 0.5 else 1)

    # sequence prediction error is calculated in bits per sequence
    error = torch.sum(torch.abs(binary_output - target))

    fig = plt.figure()
    ax1 = fig.add_subplot(211)
    ax2 = fig.add_subplot(212)  # was 221, which overlapped ax1
    ax1.set_title("Result")
    ax2.set_title("Target")
    sns.heatmap(binary_output, ax=ax1, vmin=0, vmax=1,
                linewidths=.5, cbar=False, square=True)
    sns.heatmap(target, ax=ax2, vmin=0, vmax=1,
                linewidths=.5, cbar=False, square=True)
    plt.savefig(
        image_output +
        "/priority_sort_{}_{}_{}_{}_{}_{}_{}_{}_{}_image_{}.png".format(
            task_params['seq_width'] + 1, task_params['seq_width'],
            task_params['controller_size'], task_params['memory_units'],
            task_params['memory_unit_size'], task_params['num_heads'],
            task_params['uniform'], task_params['random_distr'],
            task_params['multi_layer_controller'], iteration))

    fig = plt.figure(figsize=(15, 6))
    ax1_2 = fig.add_subplot(211)
    ax2_2 = fig.add_subplot(212)
    ax1_2.set_title("Read Weights")
    ax2_2.set_title("Write Weights")
    sns.heatmap(ntm.all_read_w, ax=ax1_2, linewidths=.01, square=True)
    sns.heatmap(ntm.all_write_w, ax=ax2_2, linewidths=.01, square=True)
    plt.tight_layout()
    plt.savefig(
        image_output +
        "/priority_sort_{}_{}_{}_{}_{}_{}_{}_{}_{}_weigths_{}.png".format(
            task_params['seq_width'] + 1, task_params['seq_width'],
            task_params['controller_size'], task_params['memory_units'],
            task_params['memory_unit_size'], task_params['num_heads'],
            task_params['uniform'], task_params['random_distr'],
            task_params['multi_layer_controller'], iteration),
        dpi=250)

    # --- logging ---
    print('[*] Checkpoint Loss: %.2f\tError in bits per sequence: %.2f' %
          (loss, error))
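# Usage sketch for the function above. The concrete values, model_path and
# image_output are illustrative assumptions; the task_params keys are the
# ones the function actually reads.
task_params = {
    'seq_width': 8,
    'controller_size': 100,
    'memory_units': 128,
    'memory_unit_size': 20,
    'num_heads': 1,
    'uniform': True,
    'random_distr': False,
    'multi_layer_controller': False,
}
generate_target_original_plots(iteration=10000,
                               task_params=task_params,
                               model_path='saved_models/priority_sort.pt',
                               image_output='plots')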
    controller_hid_dim=args.controller_hidden_dim,
)
print(model)

criterion = torch.nn.BCELoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)

print("--------- Number of parameters -----------")
print(model.calculate_num_params())
print("--------- Start training -----------")

losses = []
if args.loadmodel != '':
    model.load_state_dict(torch.load(args.loadmodel))

for e, (X, Y) in enumerate(dataloader):
    tmp = time()  # batch start time
    model.initalize_state()
    optimizer.zero_grad()

    inp_seq_len = args.sequence_length + 2  # sequence plus two delimiter timesteps
    out_seq_len = args.sequence_length

    X.requires_grad = True
    # network input: feed the sequence one timestep at a time
    for t in range(0, inp_seq_len):
        model(X[:, t])
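    # The loop body is cut off after the input phase above. What follows is a
    # hedged sketch of the usual continuation for copy-style training loops
    # (assumed, not taken from this file): read the output while feeding
    # zeros, compute the BCE loss, and update the parameters.
    out = torch.zeros(Y.size())
    zero_in = torch.zeros(X[:, 0].size())
    for t in range(out_seq_len):
        out[:, t] = model(zero_in)
    loss = criterion(out, Y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())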
For the Copy task, input_size: seq_width + 2, output_size: seq_width
For the RepeatCopy task, input_size: seq_width + 2, output_size: seq_width + 1
For the Associative task, input_size: seq_width + 2, output_size: seq_width
For the NGram task, input_size: 1, output_size: 1
For the Priority Sort task, input_size: seq_width + 1, output_size: seq_width
"""
ntm = NTM(input_size=task_params['seq_width'] + 1,
          output_size=task_params['seq_width'],
          controller_size=task_params['controller_size'],
          memory_units=task_params['memory_units'],
          memory_unit_size=task_params['memory_unit_size'],
          num_heads=task_params['num_heads'],
          multi_layer_controller=task_params['multi_layer_controller'])

if args.load_model != "":
    ntm.load_state_dict(torch.load(args.load_model))

criterion = nn.BCELoss()
# As the learning rate is task specific, this argument could be moved to the json file
optimizer = optim.RMSprop(ntm.parameters(),
                          lr=args.lr,
                          alpha=args.alpha,
                          momentum=args.momentum)
'''
optimizer = optim.Adam(ntm.parameters(), lr=args.lr,
                       betas=(args.beta1, args.beta2))
'''

'''
args.saved_model = 'saved_model_copy.pt'
args.saved_model = 'saved_model_repeatcopy.pt'
args.saved_model = 'saved_model_associative.pt'
'''
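# The task-to-dimension table in the docstring above, captured as a small
# helper. The function name and task keys are hypothetical, not part of the
# repo; the (input_size, output_size) pairs come straight from the docstring.
def task_dims(task, w):
    """Return (input_size, output_size) for a task with seq_width == w."""
    return {
        'copy': (w + 2, w),
        'repeatcopy': (w + 2, w + 1),
        'associative': (w + 2, w),
        'ngram': (1, 1),
        'prioritysort': (w + 1, w),
    }[task]

# e.g. task_dims('prioritysort', task_params['seq_width']) matches the NTM
# constructor arguments used above.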
""" For the Copy task, input_size: seq_width + 2, output_size: seq_width For the RepeatCopy task, input_size: seq_width + 2, output_size: seq_width + 1 For the Associative task, input_size: seq_width + 2, output_size: seq_width For the NGram task, input_size: 1, output_size: 1 For the Priority Sort task, input_size: seq_width + 1, output_size: seq_width """ ntm = NTM(input_size=task_params['seq_width'] + 1, output_size=task_params['seq_width'], controller_size=task_params['controller_size'], memory_units=task_params['memory_units'], memory_unit_size=task_params['memory_unit_size'], num_heads=task_params['num_heads']) ntm.load_state_dict(torch.load(PATH)) # ----------------------------------------------------------------------------- # --- evaluation # ----------------------------------------------------------------------------- ntm.reset() data = dataset[0] # 0 is a dummy index input, target = data['input'], data['target'] out = torch.zeros(target.size()) # ----------------------------------------------------------------------------- # loop for other tasks # ----------------------------------------------------------------------------- for i in range(input.size()[0]): # to maintain consistency in dimensions as torch.cat was throwing error in_data = torch.unsqueeze(input[i], 0)