def network_simulation(self, networkpath):
    # load the trained model
    torch.set_default_dtype(torch.float64)
    state_dict = torch.load(networkpath)
    # hidden_layer_size = state_dict["lstm.weight_hh_l0"][1]
    model = LSTM_multi_modal() if output_model == "multi_modal" else LSTM_fixed()
    model.load_state_dict(state_dict)
    model.eval()

    # init one set of hidden states per agent
    # hidden_state = [model.init_hidden(1, num_layers) for agent in range(self.num_guppys)]
    states = [[model.init_hidden(1, 1, hidden_layer_size) for _ in range(num_layers * 2)]
              for _ in range(self.num_guppys)]

    for i in range(1, len(self.agent_data) - 1):
        for agent in range(self.num_guppys):
            with torch.no_grad():
                # get the input data for this frame
                sensory = self.craft_vector(i, agent)
                data = torch.from_numpy(numpy.concatenate((self.loc_vec, sensory)))
                data = data.view(1, 1, -1)

                # predict the new angular turn and linear speed
                # out, hidden_state[agent] = model.predict(data, hidden_state[agent])
                out, states[agent] = model.predict(data, states[agent])
                ang_turn = out[0].item() if output_model == "multi_modal" else out[0][0][0].item()
                lin_speed = out[1].item() if output_model == "multi_modal" else out[0][0][1].item()

                # rotate the agent orientation by the angle predicted by the network
                cos_a = cos(ang_turn)
                sin_a = sin(ang_turn)
                agent_pos = self.data[agent][i][0], self.data[agent][i][1]
                agent_ori = self.data[agent][i][2], self.data[agent][i][3]
                new_ori = [cos_a * agent_ori[0] - sin_a * agent_ori[1],
                           sin_a * agent_ori[0] + cos_a * agent_ori[1]]

                # rotating a unit vector by a unit vector should again give a unit vector,
                # but numerical errors accumulate, so normalize the orientation again
                normalize_ori(new_ori)

                # multiply the new orientation by the linear speed and add it to the old position
                translation_vec = scalar_mul(lin_speed, new_ori)
                new_pos = vec_add(agent_pos, translation_vec)

                # the network sometimes fails to learn the tank walls properly, so clamp the
                # position and let the fish bump against the wall
                normalize_pos(new_pos)

                # write back position and orientation for the next timestep
                self.data[agent][i + 1][0], self.data[agent][i + 1][1] = new_pos
                self.data[agent][i + 1][2], self.data[agent][i + 1][3] = new_ori

    self.plot_guppy_bins(bins=False)
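
# The helpers normalize_ori, scalar_mul, vec_add and normalize_pos are used above but not
# shown in this excerpt. A minimal sketch of what they are assumed to do follows; these are
# hedged stand-ins inferred from the call sites, not the project's actual implementations.
# In particular, normalize_pos assumes a hypothetical square tank spanning [-1, 1] on both axes.
from math import hypot

def normalize_ori(ori):
    # rescale the orientation vector to unit length, in place
    length = hypot(ori[0], ori[1])
    if length > 0:
        ori[0] /= length
        ori[1] /= length

def scalar_mul(scalar, vec):
    # component-wise scalar multiplication
    return [scalar * component for component in vec]

def vec_add(a, b):
    # component-wise vector addition
    return [x + y for x, y in zip(a, b)]

def normalize_pos(pos, bound=1.0):
    # clamp the position to the assumed tank extent so agents "bump" against the walls
    pos[0] = max(-bound, min(bound, pos[0]))
    pos[1] = max(-bound, min(bound, pos[1]))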
print("###################################") print(f'epoch: {i:3} loss: {loss.item():10.10f}') print("###################################") # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25) #loss.backward() # optimizer.step() except KeyboardInterrupt: if input("Do you want to save the model trained so far? y/n") == "y": torch.save(model.state_dict(), network_path + f".epochs{i}") print("network saved at " + network_path + f".epochs{i}") sys.exit(0) ########validation####### model.eval() for inputs, targets in valloader: if output_model == "multi_modal": targets = targets.type(torch.LongTensor) for s in range(0, inputs.size()[1] - seq_len, seq_len): states = [tuple([each.data for each in s]) for s in states] if arch == "ey" else \ tuple([each.data for each in states]) angle_pred, speed_pred, _ = model.forward( inputs[:, s:s + seq_len, :], states) # angle_pred, speed_pred, states = model.forward(inputs[:, s:s + seq_len, :], states) angle_pred = angle_pred.view( angle_pred.shape[0] * angle_pred.shape[1], -1) speed_pred = speed_pred.view(