class SEP_TRANSLATION(): def __init__(self): ## General parameters self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') autograd.set_detect_anomaly(True) torch.set_printoptions(precision=8) ## Set default parameters ## -> Can be changed during experiments self.grad_bias = 1.5 ############################################################################ ########## PARAMETERS #################################################### def set_weighted_gradient_bias(self, bias): # bias > 1 => favor recent # bias < 1 => favor earlier self.grad_bias = bias print(f'Reset bias for gradient weighting: {self.grad_bias}') def set_data_parameters_(self, num_frames, num_observations, num_input_features, num_input_dimesions): ## Define data parameters self.num_frames = num_frames self.num_observations = num_observations self.num_input_features = num_input_features self.num_input_dimensions = num_input_dimesions self.input_per_frame = self.num_input_features * self.num_input_dimensions self.perspective_taker = Perspective_Taker( self.num_observations, self.num_input_dimensions, rotation_gradient_init=True, translation_gradient_init=True) self.preprocessor = Preprocessor(self.num_observations, self.num_input_features, self.num_input_dimensions) self.evaluator = BAPTAT_evaluator(self.num_frames, self.num_observations, self.num_input_features, self.preprocessor) def set_tuning_parameters_(self, tuning_length, num_tuning_cycles, loss_function, at_learning_rate_translation, at_learning_rate_state, at_momentum_translation): ## Define tuning parameters self.tuning_length = tuning_length # length of tuning horizon self.tuning_cycles = num_tuning_cycles # number of tuning cycles in each iteration # possible loss functions self.at_loss = loss_function self.mse = nn.MSELoss() self.l1Loss = nn.L1Loss() self.smL1Loss = nn.SmoothL1Loss(reduction='sum') self.l2Loss = lambda x, y: self.mse(x, y) * (self.num_input_dimensions * self.num_input_features) # define learning parameters self.at_learning_rate = at_learning_rate_translation self.at_learning_rate_state = at_learning_rate_state self.c_momentum = at_momentum_translation self.at_loss_function = self.mse print('Parameters set.') def init_model_(self, model_path): ## Load model self.core_model = CORE_NET() self.core_model.load_state_dict(torch.load(model_path)) self.core_model.eval() self.core_model.to(self.device) print('Model loaded.') def init_inference_tools(self): ## Define tuning variables # general self.obs_count = 0 self.at_inputs = torch.tensor([]).to(self.device) self.at_predictions = torch.tensor([]).to(self.device) self.at_final_predictions = torch.tensor([]).to(self.device) self.at_losses = [] # state self.at_states = [] # translation self.Cs = [] self.C_grads = [None] * (self.tuning_length + 1) self.C_upd = [None] * (self.tuning_length + 1) self.c_losses = [] def set_comparison_values(self, ideal_translation): self.ideal_translation = ideal_translation.to(self.device) ############################################################################ ########## INFERENCE ##################################################### def run_inference(self, observations, grad_calculation): at_final_predictions = torch.tensor([]).to(self.device) at_final_inputs = torch.tensor([]).to(self.device) tb = self.perspective_taker.init_translation_bias_() # print(tb) for i in range(self.tuning_length + 1): transba = tb.clone().to(self.device) transba.requires_grad = True self.Cs.append(transba) ## Core state # define scaler state_scaler = 0.95 # init state at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device) at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device) at_h.requires_grad = True at_c.requires_grad = True init_state = (at_h, at_c) state = (init_state[0], init_state[1]) ############################################################################ ########## FORWARD PASS ################################################## for i in range(self.tuning_length): o = observations[self.obs_count].to(self.device) self.at_inputs = torch.cat((self.at_inputs, o.reshape(1, self.num_input_features, self.num_input_dimensions)), 0) self.obs_count += 1 x_C = self.perspective_taker.translate(o, self.Cs[i]) x = self.preprocessor.convert_data_AT_to_LSTM(x_C) state = (state[0] * state_scaler, state[1] * state_scaler) new_prediction, state = self.core_model(x, state) self.at_states.append(state) self.at_predictions = torch.cat( (self.at_predictions, new_prediction.reshape(1, self.input_per_frame)), 0) ############################################################################ ########## ACTIVE TUNING ################################################## while self.obs_count < self.num_frames: # TODO folgendes evtl in function auslagern o = observations[self.obs_count].to(self.device) self.obs_count += 1 x_C = self.perspective_taker.translate(o, self.Cs[-1]) x = self.preprocessor.convert_data_AT_to_LSTM(x_C) ## Generate current prediction with torch.no_grad(): state = self.at_states[-1] state = (state[0] * state_scaler, state[1] * state_scaler) new_prediction, state = self.core_model(x, state) ## For #tuning_cycles for cycle in range(self.tuning_cycles): print('----------------------------------------------') # Get prediction p = self.at_predictions[-1] # Calculate error loss = self.at_loss(p, x[0]) # Propagate error back through tuning horizon loss.backward(retain_graph=True) self.at_losses.append(loss.clone().detach().cpu().numpy()) print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}') # Update parameters with torch.no_grad(): ## Get gradients for i in range(self.tuning_length + 1): # save grads for all parameters in all time steps of tuning horizon self.C_grads[i] = self.Cs[i].grad # print(self.C_grads[self.tuning_length]) # Calculate overall gradients if grad_calculation == 'lastOfTunHor': ### version 1 grad_C = self.C_grads[0] elif grad_calculation == 'meanOfTunHor': ### version 2 / 3 grad_C = torch.mean(torch.stack(self.C_grads), dim=0) elif grad_calculation == 'weightedInTunHor': ### version 4 weighted_grads_C = [None] * (self.tuning_length + 1) for i in range(self.tuning_length + 1): weighted_grads_C[i] = np.power(self.grad_bias, i) * self.C_grads[i] grad_C = torch.mean(torch.stack(weighted_grads_C), dim=0) # print(f'grad_C: {grad_C}') # Update parameters in time step t-H with saved gradients grad_C = grad_C.to(self.device) upd_C = self.perspective_taker.update_translation_bias_( self.Cs[0], grad_C, self.at_learning_rate, self.c_momentum) # print(upd_C) # Compare translation bias to ideal bias trans_loss = self.mse(self.ideal_translation, upd_C) self.c_losses.append(trans_loss) print(f'loss of translation bias (MSE): {trans_loss}') # Zero out gradients for all parameters in all time steps of tuning horizon for i in range(self.tuning_length + 1): self.Cs[i].requires_grad = False self.Cs[i].grad.data.zero_() # Update all parameters for all time steps for i in range(self.tuning_length + 1): translation = upd_C.clone() translation.requires_grad_() self.Cs[i] = translation # print(self.Cs[0]) # Initial state g_h = at_h.grad.to(self.device) g_c = at_c.grad.to(self.device) upd_h = init_state[0] - self.at_learning_rate_state * g_h upd_c = init_state[1] - self.at_learning_rate_state * g_c at_h.data = upd_h.clone().detach().requires_grad_() at_c.data = upd_c.clone().detach().requires_grad_() at_h.grad.data.zero_() at_c.grad.data.zero_() # state_optimizer.step() # print(f'updated init_state: {init_state}') ## REORGANIZE FOR MULTIPLE CYCLES!!!!!!!!!!!!! # forward pass from t-H to t with new parameters # Update init state??? init_state = (at_h, at_c) state = (init_state[0], init_state[1]) self.at_predictions = torch.tensor([]).to(self.device) for i in range(self.tuning_length): x_C = self.perspective_taker.translate( self.at_inputs[i], self.Cs[i]) x = self.preprocessor.convert_data_AT_to_LSTM(x_C) state = (state[0] * state_scaler, state[1] * state_scaler) upd_prediction, state = self.core_model(x, state) self.at_predictions = torch.cat( (self.at_predictions, upd_prediction.reshape(1, self.input_per_frame)), 0) # for last tuning cycle update initial state to track gradients if cycle == (self.tuning_cycles - 1) and i == 0: with torch.no_grad(): final_prediction = self.at_predictions[0].clone( ).detach().to(self.device) final_input = x.clone().detach().to(self.device) at_h = state[0].clone().detach().requires_grad_() at_c = state[1].clone().detach().requires_grad_() init_state = (at_h, at_c) state = (init_state[0], init_state[1]) self.at_states[i] = state # Update current rotation x_C = self.perspective_taker.translate(o, self.Cs[-1]) x = self.preprocessor.convert_data_AT_to_LSTM(x_C) # END tuning cycle ## Generate updated prediction state = self.at_states[-1] state = (state[0] * state_scaler, state[1] * state_scaler) new_prediction, state = self.core_model(x, state) ## Reorganize storage variables # observations at_final_inputs = torch.cat( (at_final_inputs, final_input.reshape( 1, self.input_per_frame)), 0) self.at_inputs = torch.cat((self.at_inputs[1:], o.reshape(1, self.num_input_features, self.num_input_dimensions)), 0) # predictions at_final_predictions = torch.cat( (at_final_predictions, final_prediction.reshape(1, self.input_per_frame)), 0) self.at_predictions = torch.cat( (self.at_predictions[1:], new_prediction.reshape(1, self.input_per_frame)), 0) # END active tuning # store rest of predictions in at_final_predictions for i in range(self.tuning_length): at_final_predictions = torch.cat( (at_final_predictions, self.at_predictions[i].reshape( 1, self.input_per_frame)), 0) at_final_inputs = torch.cat( (at_final_inputs, self.at_inputs[i].reshape( 1, self.input_per_frame)), 0) # get final translation bias final_translation_bias = self.Cs[0].clone().detach() print(f'final translation bias: {final_translation_bias}') return at_final_inputs, at_final_predictions, final_translation_bias ############################################################################ ########## EVALUATION ##################################################### def get_result_history(self, observations, at_final_predictions): pred_errors = self.evaluator.prediction_errors(observations, at_final_predictions, self.mse) return [pred_errors, self.at_losses, self.c_losses]