weighted_grads_B[i] = np.power(bias, i) * B_grads[i] grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0) # print(f'grad_B: {grad_B}') # print(f'grad_B: {torch.norm(grad_B, 1)}') # exit() # Update parameters in time step t-H with saved gradients upd_B = binder.update_binding_matrix_(Bs[0], grad_B, at_learning_rate, bm_momentum) # Compare binding matrix to ideal matrix c_bm = binder.scale_binding_matrix(upd_B) mat_loss = evaluator.FBE(c_bm, ideal_binding) bm_losses.append(mat_loss) print(f'loss of binding matrix (FBE): {mat_loss}') # Compute determinante of binding matrix det = torch.det(c_bm) bm_dets.append(det) print(f'determinante of binding matrix: {det}') # Zero out gradients for all parameters in all time steps of tuning horizon for i in range(tuning_length + 1): Bs[i].requires_grad = False Bs[i].grad.data.zero_() # Update all parameters for all time steps for i in range(tuning_length + 1):
class SEP_BINDING_GESTALTEN():
    """Infer a binding matrix by active tuning against a pre-trained core model.

    The prediction error of the core model is propagated back through a
    tuning horizon of ``tuning_length`` time steps; the resulting gradients
    are used to update the binding entries and the initial hidden/cell
    state of the core model.

    Expected call order (each step uses attributes set by the previous
    ones): ``set_data_parameters_`` -> ``set_tuning_parameters_`` ->
    ``init_model_`` -> ``init_inference_tools`` -> (optionally)
    ``set_comparison_values`` -> ``run_inference`` -> ``get_result_history``.
    """

    def __init__(self):
        """Select the device and set the default experiment parameters."""
        ## General parameters
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)
        torch.set_printoptions(precision=8)

        ## Set default parameters
        ##   -> Can be changed during experiments
        self.scale_mode = 'rcwSM'
        self.scale_combo = 'comp_mult'
        self.grad_bias = 1.5
        self.nxm = False
        self.additional_features = None
        # Must be a plain string: a trailing comma here would silently turn
        # the default into the tuple ('square',).
        self.nxm_enhance = 'square'
        self.nxm_last_line_scale = 0.1
        self.dummie_init = 0.1

    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_scale_mode(self, mode):
        """Set the scale mode that is passed to the binder's scaling routine."""
        self.scale_mode = mode
        print('Reset scale mode: ' + self.scale_mode)

    def set_scale_combination(self, combination):
        """Set the scale combination that is passed to the binder's scaling routine."""
        self.scale_combo = combination
        print('Reset scale combination: ' + self.scale_combo)

    def set_additional_features(self, index_addition):
        """Store the indices of additional features of the LSTM input."""
        self.additional_features = index_addition
        print(f'Additional features to the LSTM-input at indices {self.additional_features}')

    def set_nxm_enhancement(self, enhancement):
        """Set the enhancement mode used for the outcast line (NxM case)."""
        self.nxm_enhance = enhancement
        print(f'Enhancement for outcast line: {self.nxm_enhance}')

    def set_nxm_last_line_scale(self, scale_factor):
        """Set the scale factor used for the outcast line (NxM case)."""
        self.nxm_last_line_scale = scale_factor
        print(f'Scaler for outcast line: {self.nxm_last_line_scale}')

    def set_dummie_init_value(self, init_value):
        """Set the initial value of the dummy line appended in the NxM case."""
        self.dummie_init = init_value
        print(f'Initial value for dummie line: {self.dummie_init}')

    def set_weighted_gradient_bias(self, bias):
        """Set the bias used by the 'weightedInTunHor' gradient calculation.

        The gradient at horizon index ``i`` is weighted with ``bias ** i``:
            bias > 1 => favor recent
            bias < 1 => favor earlier
        """
        self.grad_bias = bias
        print(f'Reset bias for gradient weighting: {self.grad_bias}')

    def set_data_parameters_(self,
                             num_frames, num_observations, num_input_features, num_input_dimesions):
        """Store the data dimensions and create binder, preprocessor and evaluator.

        ``self.nxm`` is set to True when the number of observations differs
        from the number of input features.
        """
        ## Define data parameters
        self.num_frames = num_frames
        self.num_observations = num_observations
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimesions
        self.input_per_frame = self.num_input_features * self.num_input_dimensions
        self.nxm = (self.num_observations != self.num_input_features)

        self.binder = BINDER_NxM(
            num_observations=self.num_observations,
            num_features=self.num_input_features,
            gradient_init=True)

        self.preprocessor = Preprocessor(
            self.num_observations,
            self.num_input_features,
            self.num_input_dimensions)

        self.evaluator = BAPTAT_evaluator(
            self.num_frames,
            self.num_observations,
            self.num_input_features,
            self.preprocessor)

    def set_tuning_parameters_(self,
                               tuning_length, num_tuning_cycles, loss_function,
                               at_learning_rate_binding, at_learning_rate_state,
                               at_momentum_binding):
        """Store tuning horizon, cycle count, loss function and learning parameters.

        ``loss_function`` is the callable used for the active-tuning loss in
        ``run_inference``. The additional loss objects created here are kept
        as attributes; ``self.mse`` is also used by ``get_result_history``.
        Requires ``set_data_parameters_`` to have been called (``l2Loss``
        reads the data dimensions when invoked).
        """
        ## Define tuning parameters
        self.tuning_length = tuning_length      # length of tuning horizon
        self.tuning_cycles = num_tuning_cycles  # number of tuning cycles in each iteration

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        self.l2Loss = lambda x, y: self.mse(x, y) * (self.num_input_dimensions * self.num_input_features)

        # define learning parameters
        self.at_learning_rate = at_learning_rate_binding
        self.at_learning_rate_state = at_learning_rate_state
        self.bm_momentum = at_momentum_binding
        self.at_loss_function = self.mse

        print('Parameters set.')

    def get_additional_features(self):
        """Return the stored additional-feature indices (None by default)."""
        return self.additional_features

    def get_oc_grads(self):
        """Return the gradients collected for the last binding line (NxM case).

        Only available after ``init_inference_tools`` has been called.
        """
        return self.oc_grads

    def init_model_(self, model_path):
        """Load the core model weights from ``model_path`` and switch to eval mode."""
        ## Load model
        self.core_model = CORE_NET()
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine (self.device falls back to 'cpu' in that case).
        self.core_model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        self.core_model.eval()
        self.core_model.to(self.device)

        print('Model loaded.')

    def init_inference_tools(self):
        """Reset all buffers used by ``run_inference``.

        Requires ``set_data_parameters_`` and ``set_tuning_parameters_`` to
        have been called, since it reads ``num_input_features`` and
        ``tuning_length``.
        """
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state
        self.at_states = []
        # state_optimizer = torch.optim.Adam(init_state, at_learning_rate)

        # binding
        # default comparison target is the identity; may be replaced via
        # set_comparison_values()
        self.ideal_binding = torch.Tensor(np.identity(self.num_input_features)).to(self.device)

        self.Bs = []
        self.B_grads = [None] * (self.tuning_length + 1)
        self.B_upd = [None] * (self.tuning_length + 1)
        self.bm_losses = []
        self.bm_dets = []
        self.oc_grads = []

    def set_comparison_values(self, ideal_binding):
        """Set the binding matrix that inferred matrices are compared against.

        In the NxM case the given matrix is first expanded by the binder's
        ``ideal_nxm_binding`` using the stored additional features.
        """
        self.ideal_binding = ideal_binding.to(self.device)
        if self.nxm:
            self.ideal_binding = self.binder.ideal_nxm_binding(
                self.additional_features, self.ideal_binding).to(self.device)

    ############################################################################
    ##########  HELPERS  #######################################################

    def _bind_to_lstm_input(self, observation, binding_entries):
        """Scale the binding entries, bind the observation, convert to LSTM input."""
        bm = self.binder.scale_binding_matrix(
            binding_entries, self.scale_mode, self.scale_combo,
            self.nxm_enhance, self.nxm_last_line_scale)
        if self.nxm:
            # drop the additional (outcast) line before binding
            bm = bm[:-1]
        x_B = self.binder.bind(observation, bm)
        return self.preprocessor.convert_data_AT_to_LSTM(x_B)

    def _combine_binding_gradients(self, grad_calculation):
        """Reduce the per-time-step gradients in ``self.B_grads`` to one gradient.

        Raises:
            ValueError: if ``grad_calculation`` is not a known mode.
        """
        if grad_calculation == 'lastOfTunHor':
            ### version 1
            return self.B_grads[0]
        if grad_calculation == 'meanOfTunHor':
            ### version 2 / 3
            return torch.mean(torch.stack(self.B_grads), dim=0)
        if grad_calculation == 'weightedInTunHor':
            ### version 4
            weighted_grads_B = [
                (self.grad_bias ** i) * grad
                for i, grad in enumerate(self.B_grads)
            ]
            return torch.mean(torch.stack(weighted_grads_B), dim=0)
        raise ValueError(
            f'Unknown gradient calculation mode: {grad_calculation!r} '
            "(expected 'lastOfTunHor', 'meanOfTunHor' or 'weightedInTunHor')")

    ############################################################################
    ##########  INFERENCE  #####################################################

    def run_inference(self, observations, grad_calculation, order, reorder):
        """Run active-tuning inference of the binding matrix over all frames.

        Args:
            observations: indexable sequence of tensors, one per frame; each
                is reshaped to (1, num_observations, num_input_dimensions).
            grad_calculation: 'lastOfTunHor', 'meanOfTunHor' or
                'weightedInTunHor'; selects how the gradients of the tuning
                horizon are combined.
            order: only checked against None; when not None the scaled
                binding matrix is re-ordered with ``reorder`` before it is
                compared to the ideal binding.
            reorder: index tensor used with ``gather`` along dim 1, or None.
                NOTE(review): must be provided whenever ``order`` is given.

        Returns:
            Tuple ``(at_final_inputs, at_final_predictions,
            final_binding_matrix, final_binding_entries)``.

        Raises:
            ValueError: if ``grad_calculation`` is not a known mode (raised
                when the first gradient combination is attempted).
        """
        if reorder is not None:
            reorder = reorder.to(self.device)

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        ## Binding matrices
        # Init binding entries
        bm = self.binder.init_binding_matrix_det_()
        # bm = binder.init_binding_matrix_rand_()
        dummie_line = torch.ones(1, self.num_observations).to(self.device) * self.dummie_init

        # one independent copy of the binding entries per time step of the
        # tuning horizon (plus one for the current frame)
        for i in range(self.tuning_length + 1):
            matrix = bm.clone().to(self.device)
            if self.nxm:
                matrix = torch.cat([matrix, dummie_line])
            matrix.requires_grad_()
            self.Bs.append(matrix)

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat(
                (self.at_inputs,
                 o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            self.obs_count += 1

            x = self._bind_to_lstm_input(o, self.Bs[i])

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly move the following into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            x = self._bind_to_lstm_input(o, self.Bs[-1])

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters
                with torch.no_grad():

                    # Collect gradients with respect to the entries
                    for i in range(self.tuning_length + 1):
                        self.B_grads[i] = self.Bs[i].grad

                    # Calculate overall gradients
                    grad_B = self._combine_binding_gradients(grad_calculation)

                    # Update parameters in time step t-H with saved gradients
                    grad_B = grad_B.to(self.device)
                    upd_B = self.binder.update_binding_matrix_(
                        self.Bs[0], grad_B, self.at_learning_rate, self.bm_momentum)

                    # Compare binding matrix to ideal matrix
                    # NOTE: ideal matrix is always identity, bc then the FBE and
                    # determinant can be calculated => provide reorder
                    c_bm = self.binder.scale_binding_matrix(
                        upd_B, self.scale_mode, self.scale_combo)
                    if order is not None:
                        c_bm = c_bm.gather(1, reorder.unsqueeze(0).expand(c_bm.shape))

                    if self.nxm:
                        self.oc_grads.append(grad_B[-1])
                        FBE = self.evaluator.FBE_nxm_additional_features(
                            c_bm, self.ideal_binding, self.additional_features)
                        c_bm = self.evaluator.clear_nxm_binding_matrix(
                            c_bm, self.additional_features)

                    mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)

                    if self.nxm:
                        mat_loss = torch.stack([mat_loss, FBE, mat_loss + FBE])
                    self.bm_losses.append(mat_loss)
                    print(f'loss of binding matrix (FBE): {mat_loss}')

                    # Compute determinant of binding matrix
                    det = torch.det(c_bm)
                    self.bm_dets.append(det)
                    print(f'determinante of binding matrix: {det}')

                    # Zero out gradients for all parameters in all time steps of tuning horizon
                    for i in range(self.tuning_length + 1):
                        self.Bs[i].requires_grad = False
                        self.Bs[i].grad.data.zero_()

                    # Update all parameters for all time steps
                    for i in range(self.tuning_length + 1):
                        self.Bs[i].data = upd_B.clone().data
                        self.Bs[i].requires_grad = True

                    # Initial state: plain gradient step on hidden and cell state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                ## REORGANIZE FOR MULTIPLE CYCLES!!!!!!!!!!!!!

                # forward pass from t-H to t with new parameters
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    x = self._bind_to_lstm_input(self.at_inputs[i], self.Bs[i])

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone().detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        # the state after the oldest frame becomes the new
                        # initial state, because that frame leaves the horizon
                        at_h = state[0].clone().detach().requires_grad_().to(self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(self.device)
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current binding
                x = self._bind_to_lstm_input(o, self.Bs[-1])

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs,
                 final_input.reshape(1, self.input_per_frame)), 0)
            self.at_inputs = torch.cat(
                (self.at_inputs[1:],
                 o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 self.at_predictions[i].reshape(1, self.input_per_frame)), 0)
            at_final_inputs = torch.cat(
                (at_final_inputs,
                 self.at_inputs[i].reshape(1, self.input_per_frame)), 0)

        # get final binding matrix
        final_binding_matrix = self.binder.scale_binding_matrix(
            self.Bs[-1].clone().detach(), self.scale_mode, self.scale_combo)
        final_binding_entries = self.Bs[-1].clone().detach()

        return at_final_inputs, at_final_predictions, final_binding_matrix, final_binding_entries

    ############################################################################
    ##########  EVALUATION #####################################################

    def get_result_history(self, observations, at_final_predictions):
        """Return prediction errors and the histories collected during inference.

        In the NxM case ``self.bm_losses`` is replaced by a stacked tensor
        before it is returned.

        Returns:
            ``[pred_errors, at_losses, bm_dets, bm_losses]``.
        """
        if self.nxm:
            pred_errors = self.evaluator.prediction_errors_nxm(
                observations,
                self.additional_features,
                self.num_observations,
                at_final_predictions,
                self.mse
            )
            self.bm_losses = torch.stack(self.bm_losses)
        else:
            pred_errors = self.evaluator.prediction_errors(
                observations,
                at_final_predictions,
                self.mse)

        return [pred_errors, self.at_losses, self.bm_dets, self.bm_losses]
class COMBI_BAPTAT(): def __init__(self): ## General parameters self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') autograd.set_detect_anomaly(True) torch.set_printoptions(precision=8) ## Set default parameters ## -> Can be changed during experiments self.scale_mode = 'rcwSM' self.scale_combo = 'comp_mult' self.rotation_type = 'qrotate' self.grad_bias_binding = 1.5 self.grad_bias_rotation = 1.5 self.grad_bias_translation = 1.5 self.nxm = False self.additional_features = None self.nxm_enhance = 'square', self.nxm_last_line_scale = 0.1 self.dummie_init = 0.1 ############################################################################ ########## PARAMETERS #################################################### def set_scale_mode(self, mode): self.scale_mode = mode print('Reset scale mode: ' + self.scale_mode) def set_scale_combination(self, combination): self.scale_combo = combination print('Reset scale combination: ' + self.scale_combo) def set_additional_features(self, index_addition): self.additional_features = index_addition print(f'Additional features to the LSTM-input at indices {self.additional_features}') def set_nxm_enhancement(self, enhancement): self.nxm_enhance = enhancement print(f'Enhancement for outcast line: {self.nxm_enhance}') def set_nxm_last_line_scale(self, scale_factor): self.nxm_last_line_scale = scale_factor print(f'Scaler for outcast line: {self.nxm_last_line_scale}') def set_dummie_init_value(self, init_value): self.dummie_init = init_value print(f'Initial value for dummie line: {self.dummie_init}') def set_rotation_type(self, rotation): self.rotation_type = rotation print('Reset type of rotation: ' + self.rotation_type) def set_weighted_gradient_biases(self, biases): # bias > 1 => favor recent # bias < 1 => favor earlier print('Reset biases for gradient weighting:') self.grad_bias_binding = biases[0] print(f'\t> binding: {self.grad_bias_binding}') self.grad_bias_rotation = biases[1] print(f'\t> rotation: 
{self.grad_bias_rotation}') self.grad_bias_translation = biases[2] print(f'\t> translation: {self.grad_bias_translation}') def set_data_parameters_(self, num_frames, num_observations, num_input_features, num_input_dimesions): ## Define data parameters self.num_frames = num_frames self.num_observations = num_observations self.num_input_features = num_input_features self.num_input_dimensions = num_input_dimesions self.input_per_frame = self.num_input_features * self.num_input_dimensions self.nxm = (self.num_observations != self.num_input_features) self.binder = BINDER_NxM( num_observations=self.num_observations, num_features=self.num_input_features, gradient_init=True) self.perspective_taker = Perspective_Taker( self.num_input_features, self.num_input_dimensions, rotation_gradient_init=True, translation_gradient_init=True) self.preprocessor = Preprocessor( self.num_observations, self.num_input_features, self.num_input_dimensions) self.evaluator = BAPTAT_evaluator( self.num_frames, self.num_observations, self.num_input_features, self.preprocessor) def set_tuning_parameters_(self, tuning_length, num_tuning_cycles, loss_function, at_learning_rates_BAPTAT, at_learning_rate_state, at_momenta_BAPTAT): ## Define tuning parameters self.tuning_length = tuning_length # length of tuning horizon self.tuning_cycles = num_tuning_cycles # number of tuning cycles in each iteration # possible loss functions self.at_loss = loss_function self.mse = nn.MSELoss() self.l1Loss = nn.L1Loss() self.smL1Loss = nn.SmoothL1Loss(reduction='sum') self.l2Loss = lambda x,y: self.mse(x, y) * (self.num_input_dimensions * self.num_input_features) # define learning parameters self.at_learning_rate_binding = at_learning_rates_BAPTAT[0] self.at_learning_rate_rotation = at_learning_rates_BAPTAT[1] self.at_learning_rate_translation = at_learning_rates_BAPTAT[2] self.at_learning_rate_state = at_learning_rate_state self.bm_momentum = at_momenta_BAPTAT[0] self.r_momentum = at_momenta_BAPTAT[1] self.c_momentum 
= at_momenta_BAPTAT[2] self.at_loss_function = self.mse print('Parameters set.') def get_additional_features(self): return self.additional_features def get_oc_grads(self): return self.oc_grads def init_model_(self, model_path): ## Load model self.core_model = CORE_NET() self.core_model.load_state_dict(torch.load(model_path)) self.core_model.eval() self.core_model.to(self.device) print('Model loaded.') def init_inference_tools(self): ## Define tuning variables # general self.obs_count = 0 self.at_inputs = torch.tensor([]).to(self.device) self.at_predictions = torch.tensor([]).to(self.device) self.at_final_predictions = torch.tensor([]).to(self.device) self.at_losses = [] # state self.at_states = [] # binding self.Bs = [] self.B_grads = [None] * (self.tuning_length+1) self.B_upd = [None] * (self.tuning_length+1) self.bm_losses = [] self.bm_dets = [] self.oc_grads = [] # rotation self.Rs = [] self.R_grads = [None] * (self.tuning_length+1) self.R_upd = [None] * (self.tuning_length+1) self.rm_losses = [] self.rv_losses = [] # translation self.Cs = [] self.C_grads = [None] * (self.tuning_length+1) self.C_upd = [None] * (self.tuning_length+1) self.c_losses = [] def set_comparison_values(self, ideal_binding, ideal_rotation, ideal_translation): # binding if ideal_binding is not None: self.ideal_binding = ideal_binding.to(self.device) if self.nxm: self.ideal_binding = self.binder.ideal_nxm_binding( self.additional_features, self.ideal_binding).to(self.device) # rotation self.identity_matrix = torch.Tensor(np.identity(self.num_input_dimensions)) if ideal_rotation is not None: (ideal_rotation_values, self.ideal_rotation) = ideal_rotation self.ideal_rotation = self.ideal_rotation.to(self.device) if self.rotation_type == 'qrotate': self.ideal_quat = ideal_rotation_values.to(self.device) self.ideal_angle = self.perspective_taker.qeuler(self.ideal_quat, 'xyz').to(self.device) elif self.rotation_type == 'eulrotate': self.ideal_angle = ideal_rotation_values.to(self.device) else: 
print(f'ERROR: Received unknown rotation type!\n\trotation type: {self.rotation_type}') exit() # translation if ideal_translation is not None: self.ideal_translation = ideal_translation.to(self.device) ############################################################################ ########## INFERENCE ##################################################### def run_inference(self, observations, grad_calculations, do_binding, do_rotation, do_translation, order, reorder): [grad_calc_binding, grad_calc_rotation, grad_calc_translation] = grad_calculations if reorder is not None: reorder = reorder.to(self.device) at_final_predictions = torch.tensor([]).to(self.device) at_final_inputs = torch.tensor([]).to(self.device) ########################### BINDING ################################# if do_binding: ## Binding matrices # Init binding entries bm = self.binder.init_binding_matrix_det_() # bm = binder.init_binding_matrix_rand_() # print(bm) dummie_line = torch.ones(1,self.num_observations).to(self.device) * self.dummie_init for i in range(self.tuning_length+1): matrix = bm.clone().to(self.device) if self.nxm: matrix = torch.cat([matrix, dummie_line]) matrix.requires_grad_() self.Bs.append(matrix) ########################### ROTATION ################################ if do_rotation: if self.rotation_type == 'qrotate': ## Rotation quaternion rq = self.perspective_taker.init_quaternion() # print(rq) for i in range(self.tuning_length+1): quat = rq.clone().to(self.device) quat.requires_grad_() self.Rs.append(quat) elif self.rotation_type == 'eulrotate': ## Rotation euler angles # ra = perspective_taker.init_angles_() # ra = torch.Tensor([[309.89], [82.234], [95.765]]) ra = torch.Tensor([[75.0], [6.0], [128.0]]) # print(ra) for i in range(self.tuning_length+1): angles = [] for j in range(self.num_input_dimensions): angle = ra[j].clone().to(self.device) angle.requires_grad_() angles.append(angle) self.Rs.append(angles) else: print('ERROR: Received unknown rotation type!') exit() 
########################### TRANSLATION ############################# if do_translation: tb = self.perspective_taker.init_translation_bias_() # print(tb) for i in range(self.tuning_length+1): transba = tb.clone().to(self.device) transba.requires_grad = True self.Cs.append(transba) ####################################################################### ## Core state # define scaler state_scaler = 0.95 # init state at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device) at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device) at_h.requires_grad = True at_c.requires_grad = True init_state = (at_h, at_c) state = (init_state[0], init_state[1]) ############################################################################ ########## FORWARD PASS ################################################## for i in range(self.tuning_length): o = observations[self.obs_count].to(self.device) self.at_inputs = torch.cat(( self.at_inputs, o.reshape(1, self.num_observations, self.num_input_dimensions)), 0) self.obs_count += 1 ########################### BINDING ################################# if do_binding: bm = self.binder.scale_binding_matrix( self.Bs[i], self.scale_mode, self.scale_combo, self.nxm_enhance, self.nxm_last_line_scale) if self.nxm: bm = bm[:-1] x_B = self.binder.bind(o, bm) else: x_B = o ########################### ROTATION ################################ if do_rotation: if self.rotation_type == 'qrotate': x_R = self.perspective_taker.qrotate(x_B, self.Rs[i]) else: rotmat = self.perspective_taker.compute_rotation_matrix_( self.Rs[i][0], self.Rs[i][1], self.Rs[i][2]) x_R = self.perspective_taker.rotate(x_B, rotmat) else: x_R = x_B ########################### TRANSLATION ############################# if do_translation: x_C = self.perspective_taker.translate(x_R, self.Cs[i]) else: x_C = x_R ####################################################################### x = self.preprocessor.convert_data_AT_to_LSTM(x_C) state = (state[0] * state_scaler, state[1] 
* state_scaler) new_prediction, state = self.core_model(x, state) self.at_states.append(state) self.at_predictions = torch.cat((self.at_predictions, new_prediction.reshape(1,self.input_per_frame)), 0) ############################################################################ ########## ACTIVE TUNING ################################################## while self.obs_count < self.num_frames: # TODO folgendes evtl in function auslagern o = observations[self.obs_count].to(self.device) self.obs_count += 1 ########################### BINDING ################################# if do_binding: bm = self.binder.scale_binding_matrix( self.Bs[-1], self.scale_mode, self.scale_combo, self.nxm_enhance, self.nxm_last_line_scale) if self.nxm: bm = bm[:-1] x_B = self.binder.bind(o, bm) else: x_B = o ########################### ROTATION ################################ if do_rotation: if self.rotation_type == 'qrotate': x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1]) else: rotmat = self.perspective_taker.compute_rotation_matrix_( self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2]) x_R = self.perspective_taker.rotate(x_B, rotmat) else: x_R = x_B ########################### TRANSLATION ############################# if do_translation: x_C = self.perspective_taker.translate(x_R, self.Cs[-1]) else: x_C = x_R ####################################################################### x = self.preprocessor.convert_data_AT_to_LSTM(x_C) ## Generate current prediction with torch.no_grad(): state = self.at_states[-1] state = (state[0] * state_scaler, state[1] * state_scaler) new_prediction, state = self.core_model(x, state) ## For #tuning_cycles for cycle in range(self.tuning_cycles): print('----------------------------------------------') # Get prediction p = self.at_predictions[-1] # Calculate error loss = self.at_loss(p,x[0]) # Propagate error back through tuning horizon loss.backward(retain_graph = True) self.at_losses.append(loss.clone().detach().cpu().numpy()) print(f'frame: 
{self.obs_count} cycle: {cycle} loss: {loss}') # Update parameters with torch.no_grad(): ########################### BINDING ################################# if do_binding: # Calculate gradients with respect to the entires for i in range(self.tuning_length+1): self.B_grads[i] = self.Bs[i].grad # print(B_grads[tuning_length]) # Calculate overall gradients if grad_calc_binding == 'lastOfTunHor': ### version 1 grad_B = self.B_grads[0] elif grad_calc_binding == 'meanOfTunHor': ### version 2 / 3 grad_B = torch.mean(torch.stack(self.B_grads), dim=0) elif grad_calc_binding == 'weightedInTunHor': ### version 4 weighted_grads_B = [None] * (self.tuning_length+1) for i in range(self.tuning_length+1): weighted_grads_B[i] = np.power(self.grad_bias_binding, i) * self.B_grads[i] grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0) # print(f'grad_B: {grad_B}') # Update parameters in time step t-H with saved gradients grad_B = grad_B.to(self.device) upd_B = self.binder.decay_update_binding_matrix_( # upd_B = self.binder.update_binding_matrix_( self.Bs[0], grad_B, self.at_learning_rate_binding, self.bm_momentum) # Compare binding matrix to ideal matrix # NOTE: ideal matrix is always identity, bc then the FBE and determinant can be calculated => provide reorder c_bm = self.binder.scale_binding_matrix(upd_B, self.scale_mode, self.scale_combo) if order is not None: c_bm = c_bm.gather(1, reorder.unsqueeze(0).expand(c_bm.shape)) if self.nxm: self.oc_grads.append(grad_B[-1]) FBE = self.evaluator.FBE_nxm_additional_features( c_bm, self.ideal_binding, self.additional_features) c_bm = self.evaluator.clear_nxm_binding_matrix(c_bm, self.additional_features) mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding) if self.nxm: mat_loss = torch.stack([mat_loss, FBE, mat_loss+FBE]) self.bm_losses.append(mat_loss) print(f'loss of binding matrix (FBE): {mat_loss}') # Compute determinante of binding matrix det = torch.det(c_bm) self.bm_dets.append(det) print(f'determinante of binding matrix: 
{det}') # Zero out gradients for all parameters in all time steps of tuning horizon for i in range(self.tuning_length+1): self.Bs[i].requires_grad = False self.Bs[i].grad.data.zero_() # Update all parameters for all time steps for i in range(self.tuning_length+1): self.Bs[i].data = upd_B.clone().data self.Bs[i].requires_grad = True ########################### ROTATION ################################ if do_rotation: ## get gradients if self.rotation_type == 'qrotate': for i in range(self.tuning_length+1): # save grads for all parameters in all time steps of tuning horizon self.R_grads[i] = self.Rs[i].grad else: for i in range(self.tuning_length+1): # save grads for all parameters in all time steps of tuning horizon grad = [] for j in range(self.num_input_dimensions): grad.append(self.Rs[i][j].grad) self.R_grads[i] = torch.stack(grad) # print(self.R_grads[self.tuning_length]) # Calculate overall gradients if grad_calc_rotation == 'lastOfTunHor': ### version 1 grad_R = self.R_grads[0] elif grad_calc_rotation == 'meanOfTunHor': ### version 2 / 3 grad_R = torch.mean(torch.stack(self.R_grads), dim=0) elif grad_calc_rotation == 'weightedInTunHor': ### version 4 weighted_grads_R = [None] * (self.tuning_length+1) for i in range(self.tuning_length+1): weighted_grads_R[i] = np.power(self.grad_bias_rotation, i) * self.R_grads[i] grad_R = torch.mean(torch.stack(weighted_grads_R), dim=0) # print(f'grad_R: {grad_R}') grad_R = grad_R.to(self.device) if self.rotation_type == 'qrotate': # Update parameters in time step t-H with saved gradients upd_R = self.perspective_taker.update_quaternion( self.Rs[0], grad_R, self.at_learning_rate_rotation, self.r_momentum) print(f'updated quaternion: {upd_R}') # Compare quaternion values quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R)) print(f'loss of quaternion: {quat_loss}') self.rv_losses.append(quat_loss) # Compute rotation matrix rotmat = self.perspective_taker.quaternion2rotmat(upd_R) # Zero out gradients for all 
parameters in all time steps of tuning horizon for i in range(self.tuning_length+1): self.Rs[i].requires_grad = False self.Rs[i].grad.data.zero_() # Update all parameters for all time steps for i in range(self.tuning_length+1): quat = upd_R.clone() quat.requires_grad_() self.Rs[i] = quat else: # Update parameters in time step t-H with saved gradients upd_R = self.perspective_taker.update_rotation_angles_( self.Rs[0], grad_R, self.at_learning_rate_rotation) print(f'updated angles: {upd_R}') # Save rotation angles rotang = torch.stack(upd_R) # angles: ang_diff = rotang - self.ideal_angle ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1) print(f'loss of rotation angles: \n {ang_loss}, \n with norm {torch.norm(ang_loss)}') self.rv_losses.append(torch.norm(ang_loss)) # Compute rotation matrix rotmat = self.perspective_taker.compute_rotation_matrix_( upd_R[0], upd_R[1], upd_R[2])[0] # Zero out gradients for all parameters in all time steps of tuning horizon for i in range(self.tuning_length+1): for j in range(self.num_input_dimensions): self.Rs[i][j].requires_grad = False self.Rs[i][j].grad.data.zero_() # Update all parameters for all time steps for i in range(self.tuning_length+1): angles = [] for j in range(3): angle = upd_R[j].clone() angle.requires_grad_() angles.append(angle) self.Rs[i] = angles # Calculate and save rotation losses # matrix: mat_loss = self.mse( (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))), self.identity_matrix ) print(f'loss of rotation matrix: {mat_loss}') self.rm_losses.append(mat_loss) ########################### TRANSLATION ############################# if do_translation: ## Get gradients for i in range(self.tuning_length+1): # save grads for all parameters in all time steps of tuning horizon self.C_grads[i] = self.Cs[i].grad # print(self.C_grads[self.tuning_length]) # Calculate overall gradients if grad_calc_translation == 'lastOfTunHor': ### version 1 grad_C = self.C_grads[0] elif grad_calc_translation == 
'meanOfTunHor': ### version 2 / 3 grad_C = torch.mean(torch.stack(self.C_grads), dim=0) elif grad_calc_translation == 'weightedInTunHor': ### version 4 weighted_grads_C = [None] * (self.tuning_length+1) for i in range(self.tuning_length+1): weighted_grads_C[i] = np.power(self.grad_bias_translation, i) * self.C_grads[i] grad_C = torch.mean(torch.stack(weighted_grads_C), dim=0) # Update parameters in time step t-H with saved gradients grad_C = grad_C.to(self.device) upd_C = self.perspective_taker.update_translation_bias_( self.Cs[0], grad_C, self.at_learning_rate_translation, self.c_momentum) # Compare translation bias to ideal bias trans_loss = self.mse(self.ideal_translation, upd_C) self.c_losses.append(trans_loss) print(f'loss of translation bias (MSE): {trans_loss}') # Zero out gradients for all parameters in all time steps of tuning horizon for i in range(self.tuning_length+1): self.Cs[i].requires_grad = False self.Cs[i].grad.data.zero_() # Update all parameters for all time steps for i in range(self.tuning_length+1): translation = upd_C.clone() translation.requires_grad_() self.Cs[i] = translation ####################################################################### # Initial state g_h = at_h.grad.to(self.device) g_c = at_c.grad.to(self.device) upd_h = init_state[0] - self.at_learning_rate_state * g_h upd_c = init_state[1] - self.at_learning_rate_state * g_c at_h.data = upd_h.clone().detach().requires_grad_() at_c.data = upd_c.clone().detach().requires_grad_() at_h.grad.data.zero_() at_c.grad.data.zero_() # print(f'updated init_state: {init_state}') # forward pass from t-H to t with new parameters init_state = (at_h, at_c) state = (init_state[0], init_state[1]) self.at_predictions = torch.tensor([]).to(self.device) for i in range(self.tuning_length): ########################### BINDING ################################# if do_binding: bm = self.binder.scale_binding_matrix( self.Bs[i], self.scale_mode, self.scale_combo, self.nxm_enhance, 
self.nxm_last_line_scale) if self.nxm: bm = bm[:-1] x_B = self.binder.bind(self.at_inputs[i], bm) else: x_B = self.at_inputs[i] ########################### ROTATION ################################ if do_rotation: if self.rotation_type == 'qrotate': x_R = self.perspective_taker.qrotate(x_B, self.Rs[i]) else: rotmat = self.perspective_taker.compute_rotation_matrix_( self.Rs[i][0], self.Rs[i][1], self.Rs[i][2]) x_R = self.perspective_taker.rotate(x_B, rotmat) else: x_R = x_B ########################### TRANSLATION ############################# if do_translation: x_C = self.perspective_taker.translate(x_R, self.Cs[i]) else: x_C = x_R ####################################################################### x = self.preprocessor.convert_data_AT_to_LSTM(x_C) state = (state[0] * state_scaler, state[1] * state_scaler) upd_prediction, state = self.core_model(x, state) self.at_predictions = torch.cat((self.at_predictions, upd_prediction.reshape(1,self.input_per_frame)), 0) # for last tuning cycle update initial state to track gradients if cycle==(self.tuning_cycles-1) and i==0: with torch.no_grad(): final_prediction = self.at_predictions[0].clone().detach().to(self.device) final_input = x.clone().detach().to(self.device) at_h = state[0].clone().detach().requires_grad_().to(self.device) at_c = state[1].clone().detach().requires_grad_().to(self.device) init_state = (at_h, at_c) state = (init_state[0], init_state[1]) self.at_states[i] = state # Update current input ########################### BINDING ################################# if do_binding: bm = self.binder.scale_binding_matrix( self.Bs[-1], self.scale_mode, self.scale_combo, self.nxm_enhance, self.nxm_last_line_scale) if self.nxm: bm = bm[:-1] x_B = self.binder.bind(o, bm) else: x_B = o ########################### ROTATION ################################ if do_rotation: if self.rotation_type == 'qrotate': x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1]) else: rotmat = 
self.perspective_taker.compute_rotation_matrix_( self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2]) x_R = self.perspective_taker.rotate(x_B, rotmat) else: x_R = x_B ########################### TRANSLATION ############################# if do_translation: x_C = self.perspective_taker.translate(x_R, self.Cs[-1]) else: x_C = x_R ####################################################################### x = self.preprocessor.convert_data_AT_to_LSTM(x_C) # END tuning cycle ## Generate updated prediction state = self.at_states[-1] state = (state[0] * state_scaler, state[1] * state_scaler) new_prediction, state = self.core_model(x, state) ## Reorganize storage variables # observations at_final_inputs = torch.cat( (at_final_inputs, final_input.reshape(1,self.input_per_frame)), 0) self.at_inputs = torch.cat( (self.at_inputs[1:], o.reshape(1, self.num_observations, self.num_input_dimensions)), 0) # predictions at_final_predictions = torch.cat( (at_final_predictions, final_prediction.reshape(1,self.input_per_frame)), 0) self.at_predictions = torch.cat( (self.at_predictions[1:], new_prediction.reshape(1,self.input_per_frame)), 0) # END active tuning # store rest of predictions in at_final_predictions for i in range(self.tuning_length): at_final_predictions = torch.cat( (at_final_predictions, self.at_predictions[i].reshape(1,self.input_per_frame)), 0) at_final_inputs = torch.cat( (at_final_inputs, self.at_inputs[i].reshape(1,self.input_per_frame)), 0) ########################### BINDING ################################# # get final binding matrix if do_binding: final_binding_matrix = self.binder.scale_binding_matrix( self.Bs[-1].clone().detach(), self.scale_mode, self.scale_combo) print(f'final binding matrix: {final_binding_matrix}') final_binding_entries = self.Bs[-1].clone().detach() print(f'final binding entires: {final_binding_entries}') else: final_binding_entries, final_binding_matrix = None, None ########################### ROTATION ################################ # get final 
rotation matrix if do_rotation: if self.rotation_type == 'qrotate': final_rotation_values = self.Rs[0].clone().detach() # get final quaternion print(f'final quaternion: {final_rotation_values}') final_rotation_matrix = self.perspective_taker.quaternion2rotmat(final_rotation_values) else: final_rotation_values = [ self.Rs[0][i].clone().detach() for i in range(self.num_input_dimensions)] print(f'final euler angles: {final_rotation_values}') final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_( final_rotation_values[0], final_rotation_values[1], final_rotation_values[2]) print(f'final rotation matrix: \n{final_rotation_matrix}') else: final_rotation_matrix, final_rotation_values = None, None ########################### TRANSLATION ############################# # get final translation bias if do_translation: final_translation_values = self.Cs[0].clone().detach() print(f'final translation bias: {final_translation_values}') else: final_translation_values = None ####################################################################### return [at_final_inputs, at_final_predictions, final_binding_matrix, final_binding_entries, final_rotation_values, final_rotation_matrix, final_translation_values] ############################################################################ ########## EVALUATION ##################################################### def get_result_history( self, observations, at_final_predictions): if self.nxm: pred_errors = self.evaluator.prediction_errors_nxm( observations, self.additional_features, self.num_observations, at_final_predictions, self.mse ) self.bm_losses = torch.stack(self.bm_losses) else: pred_errors = self.evaluator.prediction_errors( observations, at_final_predictions, self.mse) return [pred_errors, self.at_losses, self.bm_dets, self.bm_losses, self.rm_losses, self.rv_losses, self.c_losses]
class SEP_BINDING():
    """Active-tuning inference that adapts only the binding matrix.

    During inference the entries of a binding matrix (one copy per time
    step of the tuning horizon) and the initial LSTM state are tuned by
    back-propagating the prediction error of ``core_model`` through the
    tuning horizon.

    Typical call order: ``prepare_inference()`` (or the individual
    ``set_*`` / ``init_*`` methods), ``load_data()``, ``run_inference()``
    and finally ``evaluate()`` or ``get_result_history()``.
    """

    def __init__(self):
        """Select the torch device and configure autograd and printing."""
        ## General parameters
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)

        torch.set_printoptions(precision=8)

    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_data_parameters(self):
        """Set the default data parameters, helpers and data paths."""
        ## Define data parameters
        self.num_frames = 20
        self.num_input_features = 15
        self.num_input_dimensions = 3
        self.preprocessor = Preprocessor(
            self.num_input_features, self.num_input_dimensions)
        self.evaluator = BAPTAT_evaluator(
            self.num_frames, self.num_input_features, self.preprocessor)
        self.data_at_unlike_train = False
        ## Note: sample needs to be changed in the future

        # data paths
        self.data_asf_path = 'Data_Compiler/S35T07.asf'
        self.data_amc_path = 'Data_Compiler/S35T07.amc'

    def set_data_parameters_(self, num_frames, num_input_features,
                             num_input_dimesions):
        """Set caller-supplied data parameters and rebuild the helpers.

        Args:
            num_frames: Number of frames consumed by ``run_inference``.
            num_input_features: Number of features per frame.
            num_input_dimesions: Number of dimensions per feature.
        """
        ## Define data parameters
        self.num_frames = num_frames
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimesions
        self.preprocessor = Preprocessor(
            self.num_input_features, self.num_input_dimensions)
        self.evaluator = BAPTAT_evaluator(
            self.num_frames, self.num_input_features, self.preprocessor)

    def set_model_parameters(self):
        """Set the default path of the stored core model."""
        ## Define model parameters
        self.model_path = 'CoreLSTM/models/LSTM_46_cell.pt'

    def set_tuning_parameters(self):
        """Set the default tuning horizon, loss functions and step sizes."""
        ## Define tuning parameters
        self.tuning_length = 10     # length of tuning horizon
        self.tuning_cycles = 3      # number of tuning cycles in each iteration

        # possible loss functions
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        # MSE scaled by the number of values per frame
        self.l2Loss = lambda x, y: self.mse(x, y) * (
            self.num_input_dimensions * self.num_input_features)

        # loss that is back-propagated during active tuning
        self.at_loss = self.smL1Loss

        # define learning parameters
        self.at_learning_rate = 1
        self.at_learning_rate_state = 0.0
        self.bm_momentum = 0.0

        # loss used by evaluate() for the prediction errors
        self.at_loss_function = self.mse

    def set_tuning_parameters_(self, tuning_length, num_tuning_cycles,
                               loss_function, at_learning_rate_binding,
                               at_learning_rate_state, at_momentum_binding):
        """Set caller-supplied tuning parameters.

        Args:
            tuning_length: Length of the tuning horizon.
            num_tuning_cycles: Number of tuning cycles per new frame.
            loss_function: Loss that is back-propagated during tuning.
            at_learning_rate_binding: Step size for the binding entries.
            at_learning_rate_state: Step size for the initial LSTM state.
            at_momentum_binding: Momentum passed to the binding update.
        """
        ## Define tuning parameters
        self.tuning_length = tuning_length          # length of tuning horizon
        self.tuning_cycles = num_tuning_cycles      # cycles in each iteration

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        # MSE scaled by the number of values per frame
        self.l2Loss = lambda x, y: self.mse(x, y) * (
            self.num_input_dimensions * self.num_input_features)

        # define learning parameters
        self.at_learning_rate = at_learning_rate_binding
        self.at_learning_rate_state = at_learning_rate_state
        self.bm_momentum = at_momentum_binding
        self.at_loss_function = self.mse

        print('Parameters set.')

    def init_inference_tools(self):
        """Create the (empty) buffers and helpers used by ``run_inference``.

        Requires ``num_input_features`` and ``tuning_length`` to be set.
        """
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state
        self.at_states = []

        # binding
        self.binder = BinderExMat(
            num_features=self.num_input_features, gradient_init=True)
        self.ideal_binding = torch.Tensor(np.identity(
            self.num_input_features)).to(self.device)

        self.Bs = []
        self.B_grads = [None] * (self.tuning_length + 1)
        self.B_upd = [None] * (self.tuning_length + 1)
        self.bm_losses = []
        self.bm_dets = []

    def set_comparison_values(self, ideal_binding):
        """Replace the binding matrix the tuned matrix is compared against."""
        self.ideal_binding = ideal_binding

    ############################################################################
    ##########  INITIALIZATIONS  ###############################################

    def load_data(self):
        """Load observations and feature names via the preprocessor.

        Returns:
            Tuple ``(observations, feature_names)`` as returned by
            ``Preprocessor.get_AT_data``.
        """
        ## Load data
        observations, feature_names = self.preprocessor.get_AT_data(
            self.data_asf_path, self.data_amc_path, self.num_frames)
        return observations, feature_names

    def init_model(self):
        """Load the core model from ``self.model_path`` in eval mode."""
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(self.model_path))
        self.core_model.eval()

    def init_model_(self, model_path):
        """Load the core model from ``model_path`` in eval mode."""
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()

    def prepare_inference(self):
        """Run all default setters/initialisers and print a summary."""
        self.set_data_parameters()
        self.set_model_parameters()
        self.set_tuning_parameters()
        self.init_inference_tools()
        self.init_model()

        print(
            'Ready to run AT inference for binding task! \nInitialized parameters with: \n'
            + f' - number of features: \t\t{self.num_input_features}\n'
            + f' - number of dimensions: \t{self.num_input_dimensions}\n'
            + f' - number of tuning cycles: \t{self.tuning_cycles}\n'
            + f' - size of tuning horizon: \t{self.tuning_length}\n'
            + f' - learning rate: \t\t{self.at_learning_rate}\n'
            + f' - learning rate (state): \t{self.at_learning_rate_state}\n'
            + f' - momentum: \t\t\t{self.bm_momentum}\n'
            + f' - model: \t\t\t{self.model_path}\n'
            + f' - number of features: \t\t{self.num_input_features}\n')

    ############################################################################
    ##########  PRIVATE HELPERS  ###############################################

    def _frame_size(self):
        """Return the number of values in one flattened frame."""
        return self.num_input_features * self.num_input_dimensions

    def _bind_to_lstm_input(self, observation, binding_entries):
        """Scale the binding entries, bind the observation, convert for the LSTM."""
        bm = self.binder.scale_binding_matrix(binding_entries)
        x_B = self.binder.bind(observation, bm)
        return self.preprocessor.convert_data_AT_to_LSTM(x_B)

    def _collect_remaining_predictions(self, at_final_predictions):
        """Append each prediction still in the tuning horizon, in order."""
        frame_size = self._frame_size()
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 self.at_predictions[i].reshape(1, frame_size)), 0)
        return at_final_predictions

    ############################################################################
    ##########  INFERENCE  #####################################################

    def run_inference(self, observations, order, reorder):
        """Run active-tuning inference over ``self.num_frames`` observations.

        The first ``tuning_length`` frames are only passed forward to fill
        the tuning horizon. For every later frame the binding entries and
        the initial LSTM state are tuned for ``tuning_cycles`` cycles.

        Appends to ``self.Bs`` and ``self.at_states`` and reads the buffers
        created by ``init_inference_tools``.

        Args:
            observations: Indexable sequence of frames; each frame is
                reshaped to ``(1, num_input_features, num_input_dimensions)``.
            order: If not ``None``, the scaled binding matrix is re-ordered
                with ``reorder`` before it is compared to the ideal matrix.
            reorder: Index tensor used for that re-ordering (only read when
                ``order`` is not ``None``).

        Returns:
            Tuple ``(at_final_predictions, final_binding_matrix,
            final_binding_entries)``.
        """
        frame_size = self._frame_size()
        horizon = self.tuning_length + 1
        at_final_predictions = torch.tensor([]).to(self.device)

        ## Binding matrices
        # Init binding entries
        bm = self.binder.init_binding_matrix_det_()
        print(bm)

        for i in range(horizon):
            matrix = bm.clone()
            matrix.requires_grad_()
            self.Bs.append(matrix)

        print(f'BMs different in list: {self.Bs[0] is not self.Bs[1]}')

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state; created directly on the target device so that the
        # tensors stay autograd leaves and receive a ``.grad``
        hidden_size = self.core_model.hidden_size
        at_h = torch.zeros(
            1, hidden_size, device=self.device, requires_grad=True)
        at_c = torch.zeros(
            1, hidden_size, device=self.device, requires_grad=True)

        init_state = (at_h, at_c)
        self.at_states.append(init_state)
        state = (init_state[0], init_state[1])

        ########################################################################
        ##########  FORWARD PASS  ##############################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count]
            self.at_inputs = torch.cat(
                (self.at_inputs,
                 o.reshape(1, self.num_input_features,
                           self.num_input_dimensions)), 0)
            self.obs_count += 1

            x = self._bind_to_lstm_input(o, self.Bs[i])

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, frame_size)), 0)

        ########################################################################
        ##########  ACTIVE TUNING  #############################################

        while self.obs_count < self.num_frames:
            # TODO: possibly move the following into a function
            o = observations[self.obs_count]
            self.obs_count += 1

            x = self._bind_to_lstm_input(o, self.Bs[-1])

            ## Generate current prediction
            # NOTE(review): both results of this block are overwritten
            # before they are read; kept in case the model call matters.
            with torch.no_grad():
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, _ = self.core_model(x, self.at_states[-1])

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # .cpu() so that the conversion also works for CUDA tensors
                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                # Update parameters
                with torch.no_grad():

                    # Collect gradients with respect to the entries
                    for i in range(horizon):
                        self.B_grads[i] = self.Bs[i].grad

                    # Calculate overall gradients as weighted mean over the
                    # tuning horizon:
                    #   bias > 1 => favor recent
                    #   bias < 1 => favor earlier
                    bias = 1.5
                    weighted_grads_B = [
                        np.power(bias, i) * self.B_grads[i]
                        for i in range(horizon)
                    ]
                    grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0)

                    # Update parameters in time step t-H with saved gradients
                    upd_B = self.binder.update_binding_matrix_(
                        self.Bs[0], grad_B,
                        self.at_learning_rate, self.bm_momentum)

                    # Compare binding matrix to ideal matrix
                    c_bm = self.binder.scale_binding_matrix(upd_B)
                    if order is not None:
                        # re-order the columns according to `reorder`
                        c_bm = c_bm.gather(
                            1, reorder.unsqueeze(0).expand(c_bm.shape))

                    mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)
                    self.bm_losses.append(mat_loss)
                    print(f'loss of binding matrix (FBE): {mat_loss}')

                    # Compute determinante of binding matrix
                    det = torch.det(c_bm)
                    self.bm_dets.append(det)
                    print(f'determinante of binding matrix: {det}')

                    # Zero out gradients for all parameters in all time
                    # steps of tuning horizon
                    for i in range(horizon):
                        self.Bs[i].requires_grad = False
                        self.Bs[i].grad.data.zero_()

                    # Update all parameters for all time steps
                    for i in range(horizon):
                        self.Bs[i].data = upd_B.clone().data
                        self.Bs[i].requires_grad = True

                    # Initial state
                    g_h = at_h.grad
                    g_c = at_c.grad

                    upd_h = (self.at_states[0][0]
                             - self.at_learning_rate_state * g_h)
                    upd_c = (self.at_states[0][1]
                             - self.at_learning_rate_state * g_c)

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                # forward pass from t-H to t with new parameters
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                for i in range(self.tuning_length):
                    x = self._bind_to_lstm_input(
                        self.at_inputs[i], self.Bs[i])

                    state = (state[0] * state_scaler,
                             state[1] * state_scaler)
                    self.at_predictions[i], state = self.core_model(x, state)

                    # for last tuning cycle update initial state to track
                    # gradients; move to the device *before* enabling grad
                    # so that the new tensors are autograd leaves
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        at_h = state[0].clone().detach().to(
                            self.device).requires_grad_()
                        at_c = state[1].clone().detach().to(
                            self.device).requires_grad_()
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i + 1] = state

                # Update current binding
                x = self._bind_to_lstm_input(o, self.Bs[-1])

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # states
            self.at_states.append(state)
            self.at_states[0][0].requires_grad = False
            self.at_states[0][1].requires_grad = False
            self.at_states = self.at_states[1:]

            # observations
            self.at_inputs = torch.cat(
                (self.at_inputs[1:],
                 o.reshape(1, self.num_input_features,
                           self.num_input_dimensions)), 0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 self.at_predictions[0].detach().reshape(1, frame_size)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, frame_size)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        at_final_predictions = self._collect_remaining_predictions(
            at_final_predictions)

        # get final binding matrix
        final_binding_matrix = self.binder.scale_binding_matrix(
            self.Bs[-1].clone().detach())
        print(f'final binding matrix: {final_binding_matrix}')
        final_binding_entries = self.Bs[-1].clone().detach()
        print(f'final binding entires: {final_binding_entries}')

        return at_final_predictions, final_binding_matrix, final_binding_entries

    ############################################################################
    ##########  EVALUATION  ####################################################

    def evaluate(self, observations, at_final_predictions, feature_names,
                 final_binding_matrix, final_binding_entries):
        """Compute the prediction errors and plot all recorded histories.

        Args:
            observations: Observations that were passed to ``run_inference``.
            at_final_predictions: Predictions returned by ``run_inference``.
            feature_names: Labels passed to the binding-matrix plots.
            final_binding_matrix: Scaled binding matrix to plot.
            final_binding_entries: Raw binding entries to plot.
        """
        pred_errors = self.evaluator.prediction_errors(
            observations, at_final_predictions, self.at_loss_function)

        self.evaluator.plot_prediction_errors(pred_errors)
        self.evaluator.plot_at_losses(
            self.at_losses, 'History of overall losses during active tuning')
        self.evaluator.plot_at_losses(
            self.bm_losses, 'History of binding matrix loss (FBE)')
        self.evaluator.plot_at_losses(
            self.bm_dets, 'History of binding matrix determinante')

        self.evaluator.plot_binding_matrix(
            final_binding_matrix,
            feature_names,
            'Binding matrix showing relative contribution of observed feature to input feature'
        )
        self.evaluator.plot_binding_matrix(
            final_binding_entries,
            feature_names,
            'Binding matrix entries showing contribution of observed feature to input feature'
        )

    def get_result_history(self, observations, at_final_predictions):
        """Return prediction errors (MSE) and the recorded histories.

        Returns:
            List ``[pred_errors, at_losses, bm_losses, bm_dets]``.
        """
        pred_errors = self.evaluator.prediction_errors(
            observations, at_final_predictions, self.mse)

        return [pred_errors, self.at_losses, self.bm_losses, self.bm_dets]