import copy
from abc import ABC

import numpy as np
import torch
from torch import autograd, nn

# NOTE: BINDER_NxM, Perspective_Taker, Preprocessor, BAPTAT_evaluator and
# CORE_NET are project-specific modules; their import paths are not part of
# this file and are assumed to be available.


class COMBI_BAPTAT():

    def __init__(self):
        ## General parameters 
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)

        torch.set_printoptions(precision=8)

        
        ## Set default parameters
        ##   -> Can be changed during experiments 

        self.scale_mode = 'rcwSM'
        self.scale_combo = 'comp_mult'
        self.rotation_type = 'qrotate'
        self.grad_bias_binding = 1.5
        self.grad_bias_rotation = 1.5
        self.grad_bias_translation = 1.5
        self.nxm = False
        self.additional_features = None
        self.nxm_enhance = 'square'
        self.nxm_last_line_scale = 0.1
        self.dummie_init = 0.1



    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_scale_mode(self, mode): 
        self.scale_mode = mode
        print('Reset scale mode: ' + self.scale_mode)


    def set_scale_combination(self, combination): 
        self.scale_combo = combination
        print('Reset scale combination: ' + self.scale_combo)


    def set_additional_features(self, index_addition): 
        self.additional_features = index_addition
        print(f'Additional features to the LSTM-input at indices {self.additional_features}')


    def set_nxm_enhancement(self, enhancement): 
        self.nxm_enhance = enhancement
        print(f'Enhancement for outcast line: {self.nxm_enhance}')


    def set_nxm_last_line_scale(self, scale_factor): 
        self.nxm_last_line_scale = scale_factor
        print(f'Scaler for outcast line: {self.nxm_last_line_scale}')


    def set_dummie_init_value(self, init_value):
        self.dummie_init = init_value
        print(f'Initial value for dummie line: {self.dummie_init}')

    
    def set_rotation_type(self, rotation): 
        self.rotation_type = rotation
        print('Reset type of rotation: ' + self.rotation_type)


    def set_weighted_gradient_biases(self, biases):
        # bias > 1 => favor recent
        # bias < 1 => favor earlier
        print('Reset biases for gradient weighting:')
        self.grad_bias_binding = biases[0]
        print(f'\t> binding: {self.grad_bias_binding}')
        self.grad_bias_rotation = biases[1]
        print(f'\t> rotation: {self.grad_bias_rotation}')
        self.grad_bias_translation = biases[2]
        print(f'\t> translation: {self.grad_bias_translation}')
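
    # A rough sketch of what these biases do (see the 'weightedInTunHor'
    # branches in run_inference): the gradient at horizon index i is scaled
    # by bias**i before averaging, so bias=1.5 over indices [0, 1, 2] yields
    # weights [1.0, 1.5, 2.25] and more recent steps dominate.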

    
    def set_data_parameters_(self, 
        num_frames, 
        num_observations, 
        num_input_features, 
        num_input_dimensions):

        ## Define data parameters
        self.num_frames = num_frames
        self.num_observations = num_observations
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimensions
        self.input_per_frame = self.num_input_features * self.num_input_dimensions
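        # n-by-m case: the number of observations differs from the number of
        # input features; handled via an extra 'outcast' row in the binder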
        self.nxm = (self.num_observations != self.num_input_features)
        
        self.binder = BINDER_NxM(
            num_observations=self.num_observations, 
            num_features=self.num_input_features, 
            gradient_init=True)

        self.perspective_taker = Perspective_Taker(
            self.num_input_features, 
            self.num_input_dimensions,
            rotation_gradient_init=True, 
            translation_gradient_init=True)

        self.preprocessor = Preprocessor(
            self.num_observations, 
            self.num_input_features, 
            self.num_input_dimensions)

        self.evaluator = BAPTAT_evaluator(
            self.num_frames, 
            self.num_observations, 
            self.num_input_features, 
            self.preprocessor)
        
    
    def set_tuning_parameters_(self, 
        tuning_length, 
        num_tuning_cycles, 
        loss_function, 
        at_learning_rates_BAPTAT, 
        at_learning_rate_state, 
        at_momenta_BAPTAT): 

        ## Define tuning parameters 
        self.tuning_length = tuning_length          # length of tuning horizon 
        self.tuning_cycles = num_tuning_cycles      # number of tuning cycles in each iteration 

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        self.l2Loss = lambda x,y: self.mse(x, y) * (self.num_input_dimensions * self.num_input_features)
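        # Note: l2Loss undoes the mean reduction of MSELoss by multiplying by
        # the number of coordinates per frame, yielding a summed squared error.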

        # define learning parameters 
        self.at_learning_rate_binding = at_learning_rates_BAPTAT[0]
        self.at_learning_rate_rotation = at_learning_rates_BAPTAT[1]
        self.at_learning_rate_translation = at_learning_rates_BAPTAT[2]
        self.at_learning_rate_state = at_learning_rate_state
        self.bm_momentum = at_momenta_BAPTAT[0]
        self.r_momentum = at_momenta_BAPTAT[1]
        self.c_momentum = at_momenta_BAPTAT[2]
        self.at_loss_function = self.mse

        print('Parameters set.')


    def get_additional_features(self): 
        return self.additional_features 

    
    def get_oc_grads(self): 
        return self.oc_grads 


    def init_model_(self, model_path): 
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()
        self.core_model.to(self.device)

        print('Model loaded.')


    def init_inference_tools(self):
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state 
        self.at_states = []

        # binding
        self.Bs = []
        self.B_grads = [None] * (self.tuning_length+1)
        self.B_upd = [None] * (self.tuning_length+1)
        self.bm_losses = []
        self.bm_dets = []
        self.oc_grads = []

        # rotation
        self.Rs = []
        self.R_grads = [None] * (self.tuning_length+1)
        self.R_upd = [None] * (self.tuning_length+1)
        self.rm_losses = []
        self.rv_losses = []

        # translation
        self.Cs = []
        self.C_grads = [None] * (self.tuning_length+1)
        self.C_upd = [None] * (self.tuning_length+1)
        self.c_losses = []

    
    def set_comparison_values(self, ideal_binding, ideal_rotation, ideal_translation):
        # binding
        if ideal_binding is not None:
            self.ideal_binding = ideal_binding.to(self.device)
            if self.nxm:
                self.ideal_binding = self.binder.ideal_nxm_binding(
                    self.additional_features, self.ideal_binding).to(self.device)

        # rotation
        self.identity_matrix = torch.Tensor(np.identity(self.num_input_dimensions)).to(self.device)
        if ideal_rotation is not None:
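            # ideal_rotation is passed as a tuple: (rotation values, rotation matrix)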
            (ideal_rotation_values, self.ideal_rotation) = ideal_rotation
            self.ideal_rotation = self.ideal_rotation.to(self.device)
            if self.rotation_type == 'qrotate': 
                self.ideal_quat = ideal_rotation_values.to(self.device)
                self.ideal_angle = self.perspective_taker.qeuler(self.ideal_quat, 'xyz').to(self.device)
            elif self.rotation_type == 'eulrotate': 
                self.ideal_angle = ideal_rotation_values.to(self.device)
            else: 
                print(f'ERROR: Received unknown rotation type!\n\trotation type: {self.rotation_type}')
                exit()

        # translation
        if ideal_translation is not None:
            self.ideal_translation = ideal_translation.to(self.device)


    ############################################################################
    ##########  INFERENCE  #####################################################
    
    def run_inference(self, 
        observations, 
        grad_calculations, 
        do_binding, 
        do_rotation,
        do_translation,
        order, 
        reorder):
        [grad_calc_binding, grad_calc_rotation, grad_calc_translation] = grad_calculations

        if reorder is not None:
            reorder = reorder.to(self.device)
        
        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        ###########################  BINDING  #################################
        if do_binding:
            ## Binding matrices 
            # Init binding entries 
            bm = self.binder.init_binding_matrix_det_()
            # bm = binder.init_binding_matrix_rand_()
            # print(bm)
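            # row of dummie-initialised entries; appended to each binding
            # matrix in the n-by-m case (see torch.cat below)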
            dummie_line = torch.ones(1,self.num_observations).to(self.device) * self.dummie_init

            for i in range(self.tuning_length+1):
                matrix = bm.clone().to(self.device)
                if self.nxm:
                    matrix = torch.cat([matrix, dummie_line])
                matrix.requires_grad_()
                self.Bs.append(matrix)
            
        ###########################  ROTATION  ################################
        if do_rotation:
            if self.rotation_type == 'qrotate': 
                ## Rotation quaternion 
                rq = self.perspective_taker.init_quaternion()
                # print(rq)

                for i in range(self.tuning_length+1):
                    quat = rq.clone().to(self.device)
                    quat.requires_grad_()
                    self.Rs.append(quat)

            elif self.rotation_type == 'eulrotate':
                ## Rotation euler angles 
                # ra = perspective_taker.init_angles_()
                # ra = torch.Tensor([[309.89], [82.234], [95.765]])
                ra = torch.Tensor([[75.0], [6.0], [128.0]])
                # print(ra)

                for i in range(self.tuning_length+1):
                    angles = []
                    for j in range(self.num_input_dimensions):
                        angle = ra[j].clone().to(self.device)
                        angle.requires_grad_()
                        angles.append(angle)
                    self.Rs.append(angles)

            else: 
                print('ERROR: Received unknown rotation type!')
                exit()

        ###########################  TRANSLATION  #############################
        if do_translation:
            tb = self.perspective_taker.init_translation_bias_()
            # print(tb)

            for i in range(self.tuning_length+1):
                transba = tb.clone().to(self.device)
                transba.requires_grad = True
                self.Cs.append(transba)

        #######################################################################

        ## Core state
        # define scaler
        state_scaler = 0.95
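        # the LSTM state is damped by this factor at every step, presumably
        # to keep the recurrent activations from growing over long sequences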

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((
                self.at_inputs, 
                o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(
                    self.Bs[i], 
                    self.scale_mode, 
                    self.scale_combo, 
                    self.nxm_enhance, 
                    self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
            else: 
                x_B = o
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate': 
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[i])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else: 
                x_R = x_B
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[i])
            else:
                x_C = x_R
            #######################################################################
            
            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)  
            self.at_states.append(state)
            self.at_predictions = torch.cat((self.at_predictions, new_prediction.reshape(1,self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly move the following into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(
                    self.Bs[-1], 
                    self.scale_mode, 
                    self.scale_combo,
                    self.nxm_enhance, 
                    self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)    
            else: 
                x_B = o        
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate': 
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else: 
                x_R = x_B
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
            else: 
                x_C = x_R
            #######################################################################

            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            ## Generate current prediction 
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles 
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error 
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon 
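                # (retain_graph=True because the same graph is backpropagated
                # again in later tuning cycles of this frame)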
                loss.backward(retain_graph = True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters 
                with torch.no_grad():
                    ###########################  BINDING  #################################
                    if do_binding:
                        # Calculate gradients with respect to the entries
                        for i in range(self.tuning_length+1):
                            self.B_grads[i] = self.Bs[i].grad
                        # print(B_grads[tuning_length])                    
                        
                        # Calculate overall gradients 
                        if grad_calc_binding == 'lastOfTunHor':
                            ### version 1
                            grad_B = self.B_grads[0]
                        elif grad_calc_binding == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_B = torch.mean(torch.stack(self.B_grads), dim=0)
                        elif grad_calc_binding == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_B = [None] * (self.tuning_length+1)
                            for i in range(self.tuning_length+1):
                                weighted_grads_B[i] = np.power(self.grad_bias_binding, i) * self.B_grads[i]
                            grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0)
                        # print(f'grad_B: {grad_B}')
                        
                        # Update parameters in time step t-H with saved gradients 
                        grad_B = grad_B.to(self.device)
                        # upd_B = self.binder.update_binding_matrix_(
                        #     self.Bs[0], grad_B, self.at_learning_rate_binding, self.bm_momentum)
                        upd_B = self.binder.decay_update_binding_matrix_(
                            self.Bs[0], grad_B, self.at_learning_rate_binding, self.bm_momentum)

                        # Compare binding matrix to ideal matrix
                        # NOTE: the ideal matrix is always the identity, because only then can the FBE and determinant be computed => provide reorder
                        c_bm = self.binder.scale_binding_matrix(upd_B, self.scale_mode, self.scale_combo)
                        if order is not None: 
                            c_bm = c_bm.gather(1, reorder.unsqueeze(0).expand(c_bm.shape))

                        if self.nxm:
                            self.oc_grads.append(grad_B[-1])
                            FBE = self.evaluator.FBE_nxm_additional_features(
                                c_bm, self.ideal_binding, self.additional_features)
                            c_bm = self.evaluator.clear_nxm_binding_matrix(c_bm, self.additional_features)
                        
                        mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)

                        if self.nxm:
                            mat_loss = torch.stack([mat_loss, FBE, mat_loss+FBE])
                        self.bm_losses.append(mat_loss)
                        print(f'loss of binding matrix (FBE): {mat_loss}')

                        # Compute determinant of the binding matrix
                        det = torch.det(c_bm)
                        self.bm_dets.append(det)
                        print(f'determinant of binding matrix: {det}')


                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length+1):
                            self.Bs[i].requires_grad = False
                            self.Bs[i].grad.data.zero_()

                        # Update all parameters for all time steps 
                        for i in range(self.tuning_length+1):
                            self.Bs[i].data = upd_B.clone().data
                            self.Bs[i].requires_grad = True

                    ###########################  ROTATION  ################################
                    if do_rotation:
                        ## get gradients
                        if self.rotation_type == 'qrotate': 
                            for i in range(self.tuning_length+1):
                                # save grads for all parameters in all time steps of tuning horizon
                                self.R_grads[i] = self.Rs[i].grad
                        else: 
                            for i in range(self.tuning_length+1):
                                # save grads for all parameters in all time steps of tuning horizon
                                grad = []
                                for j in range(self.num_input_dimensions):
                                    grad.append(self.Rs[i][j].grad) 
                                self.R_grads[i] = torch.stack(grad)
                        # print(self.R_grads[self.tuning_length])
                        
                        # Calculate overall gradients 
                        if grad_calc_rotation == 'lastOfTunHor':
                            ### version 1
                            grad_R = self.R_grads[0]
                        elif grad_calc_rotation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_R = torch.mean(torch.stack(self.R_grads), dim=0)
                        elif grad_calc_rotation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_R = [None] * (self.tuning_length+1)
                            for i in range(self.tuning_length+1):
                                weighted_grads_R[i] = np.power(self.grad_bias_rotation, i) * self.R_grads[i]
                            grad_R = torch.mean(torch.stack(weighted_grads_R), dim=0)
                        # print(f'grad_R: {grad_R}')

                        grad_R = grad_R.to(self.device)
                        if self.rotation_type == 'qrotate': 
                            # Update parameters in time step t-H with saved gradients 
                            upd_R = self.perspective_taker.update_quaternion(
                                self.Rs[0], grad_R, self.at_learning_rate_rotation, self.r_momentum)
                            print(f'updated quaternion: {upd_R}')

                            # Compare quaternion values
                            quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
                            print(f'loss of quaternion: {quat_loss}')
                            self.rv_losses.append(quat_loss)
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.quaternion2rotmat(upd_R)

                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length+1):
                                self.Rs[i].requires_grad = False
                                self.Rs[i].grad.data.zero_()

                            # Update all parameters for all time steps 
                            for i in range(self.tuning_length+1):
                                quat = upd_R.clone()
                                quat.requires_grad_()
                                self.Rs[i] = quat

                        else: 
                            # Update parameters in time step t-H with saved gradients 
                            upd_R = self.perspective_taker.update_rotation_angles_(
                                self.Rs[0], grad_R, self.at_learning_rate_rotation)
                            print(f'updated angles: {upd_R}')

                            # Save rotation angles
                            rotang = torch.stack(upd_R)
                            # angles:
                            ang_diff = rotang - self.ideal_angle
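                            # 2 - (cos(diff) + 1) = 1 - cos(diff): 0 for a
                            # perfect match, approaching 2 for opposite angles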
                            ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                            print(f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}')
                            self.rv_losses.append(torch.norm(ang_loss))
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                upd_R[0], upd_R[1], upd_R[2])[0]
                            
                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length+1):
                                for j in range(self.num_input_dimensions):
                                    self.Rs[i][j].requires_grad = False
                                    self.Rs[i][j].grad.data.zero_()

                            # Update all parameters for all time steps 
                            for i in range(self.tuning_length+1):
                                angles = []
                                for j in range(3):
                                    angle = upd_R[j].clone()
                                    angle.requires_grad_()
                                    angles.append(angle)
                                self.Rs[i] = angles

                        # Calculate and save rotation losses
                        # matrix: 
                        mat_loss = self.mse(
                            (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))), 
                            self.identity_matrix
                        )
                        print(f'loss of rotation matrix: {mat_loss}')
                        self.rm_losses.append(mat_loss)
                    
                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        ## Get gradients 
                        for i in range(self.tuning_length+1):
                            # save grads for all parameters in all time steps of tuning horizon 
                            self.C_grads[i] = self.Cs[i].grad

                        # print(self.C_grads[self.tuning_length])
                        
                        # Calculate overall gradients 
                        if grad_calc_translation == 'lastOfTunHor':
                            ### version 1
                            grad_C = self.C_grads[0]
                        elif grad_calc_translation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_C = torch.mean(torch.stack(self.C_grads), dim=0)
                        elif grad_calc_translation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_C = [None] * (self.tuning_length+1)
                            for i in range(self.tuning_length+1):
                                weighted_grads_C[i] = np.power(self.grad_bias_translation, i) * self.C_grads[i]
                            grad_C = torch.mean(torch.stack(weighted_grads_C), dim=0)
                        
                        # Update parameters in time step t-H with saved gradients 
                        grad_C = grad_C.to(self.device)
                        upd_C = self.perspective_taker.update_translation_bias_(
                            self.Cs[0], grad_C, self.at_learning_rate_translation, self.c_momentum)
                        
                        # Compare translation bias to ideal bias
                        trans_loss = self.mse(self.ideal_translation, upd_C)
                        self.c_losses.append(trans_loss)
                        print(f'loss of translation bias (MSE): {trans_loss}')
                        
                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length+1):
                            self.Cs[i].requires_grad = False
                            self.Cs[i].grad.data.zero_()
                        
                        # Update all parameters for all time steps 
                        for i in range(self.tuning_length+1):
                            translation = upd_C.clone()
                            translation.requires_grad_()
                            self.Cs[i] = translation

                    #######################################################################

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # print(f'updated init_state: {init_state}')
                

                # forward pass from t-H to t with new parameters 
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    ###########################  BINDING  #################################
                    if do_binding:
                        bm = self.binder.scale_binding_matrix(
                            self.Bs[i], 
                            self.scale_mode, 
                            self.scale_combo, 
                            self.nxm_enhance, 
                            self.nxm_last_line_scale)
                        if self.nxm:
                            bm = bm[:-1]
                        x_B = self.binder.bind(self.at_inputs[i], bm)
                    else:
                        x_B = self.at_inputs[i]
                    ###########################  ROTATION  ################################
                    if do_rotation:
                        if self.rotation_type == 'qrotate': 
                            x_R = self.perspective_taker.qrotate(x_B, self.Rs[i])
                        else:
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                            x_R = self.perspective_taker.rotate(x_B, rotmat)
                    else: 
                        x_R = x_B
                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        x_C = self.perspective_taker.translate(x_R, self.Cs[i])
                    else:
                        x_C = x_R
                    #######################################################################

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat((self.at_predictions, upd_prediction.reshape(1,self.input_per_frame)), 0)
                    
                    # in the last tuning cycle, update the initial state to track gradients
                    if cycle==(self.tuning_cycles-1) and i==0: 
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone().detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_().to(self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(self.device)
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state 

                # Update current input
                ###########################  BINDING  #################################
                if do_binding:
                    bm = self.binder.scale_binding_matrix(
                        self.Bs[-1], 
                        self.scale_mode, 
                        self.scale_combo, 
                        self.nxm_enhance, 
                        self.nxm_last_line_scale)
                    if self.nxm:
                        bm = bm[:-1]
                    x_B = self.binder.bind(o, bm)
                else:
                    x_B = o
                ###########################  ROTATION  ################################
                if do_rotation: 
                    if self.rotation_type == 'qrotate': 
                        x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                        x_R = self.perspective_taker.rotate(x_B, rotmat)
                else: 
                    x_R = x_B
                ###########################  TRANSLATION  #############################
                if do_translation: 
                    x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
                else: 
                    x_C = x_R
                #######################################################################

                x = self.preprocessor.convert_data_AT_to_LSTM(x_C)


            # END tuning cycle        

            ## Generate updated prediction 
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables            
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, 
                final_input.reshape(1,self.input_per_frame)), 0)
            self.at_inputs = torch.cat(
                (self.at_inputs[1:], 
                o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            
            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions, 
                final_prediction.reshape(1,self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:], 
                new_prediction.reshape(1,self.input_per_frame)), 0)

        # END active tuning
        
        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length): 
            at_final_predictions = torch.cat(
                (at_final_predictions, 
                self.at_predictions[i].reshape(1,self.input_per_frame)), 0)
            at_final_inputs = torch.cat(
                (at_final_inputs, 
                self.at_inputs[i].reshape(1,self.input_per_frame)), 0)

        ###########################  BINDING  #################################
        # get final binding matrix
        if do_binding:
            final_binding_matrix = self.binder.scale_binding_matrix(
                self.Bs[-1].clone().detach(), self.scale_mode, self.scale_combo)
            print(f'final binding matrix: {final_binding_matrix}')
            final_binding_entries = self.Bs[-1].clone().detach()
            print(f'final binding entries: {final_binding_entries}')

        else: 
            final_binding_entries, final_binding_matrix = None, None

        ###########################  ROTATION  ################################
        # get final rotation matrix
        if do_rotation:
            if self.rotation_type == 'qrotate': 
                final_rotation_values = self.Rs[0].clone().detach()
                # get final quaternion
                print(f'final quaternion: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.quaternion2rotmat(final_rotation_values)
            else:
                final_rotation_values = [
                    self.Rs[0][i].clone().detach() 
                    for i in range(self.num_input_dimensions)]
                print(f'final euler angles: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                    final_rotation_values[0], 
                    final_rotation_values[1], 
                    final_rotation_values[2])
            
            print(f'final rotation matrix: \n{final_rotation_matrix}')

        else: 
            final_rotation_matrix, final_rotation_values = None, None

        ###########################  TRANSLATION  #############################
        # get final translation bias
        if do_translation:
            final_translation_values = self.Cs[0].clone().detach()
            print(f'final translation bias: {final_translation_values}')

        else: 
            final_translation_values = None

        #######################################################################


        return [at_final_inputs,
                at_final_predictions, 
                final_binding_matrix, 
                final_binding_entries,
                final_rotation_values, 
                final_rotation_matrix, 
                final_translation_values]




    ############################################################################
    ##########  EVALUATION #####################################################

    def get_result_history(
        self, 
        observations, 
        at_final_predictions):

        if self.nxm:
            pred_errors = self.evaluator.prediction_errors_nxm(
                observations, 
                self.additional_features, 
                self.num_observations, 
                at_final_predictions, 
                self.mse
            )
            self.bm_losses = torch.stack(self.bm_losses)

        else:
            pred_errors = self.evaluator.prediction_errors(
                observations, 
                at_final_predictions, 
                self.mse)

        return [pred_errors, 
                self.at_losses, 
                self.bm_dets, 
                self.bm_losses, 
                self.rm_losses, 
                self.rv_losses,
                self.c_losses]
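

# A minimal usage sketch (hypothetical paths and hyperparameter values; none
# of these numbers are prescribed by this file). Note that ideal_rotation is
# passed to set_comparison_values as a (rotation values, rotation matrix)
# tuple:
#
#   baptat = COMBI_BAPTAT()
#   baptat.set_data_parameters_(num_frames=100, num_observations=15,
#                               num_input_features=15, num_input_dimensions=3)
#   baptat.set_tuning_parameters_(
#       tuning_length=10, num_tuning_cycles=3, loss_function=nn.MSELoss(),
#       at_learning_rates_BAPTAT=[0.1, 0.01, 0.01],
#       at_learning_rate_state=0.0, at_momenta_BAPTAT=[0.9, 0.8, 0.8])
#   baptat.init_model_('path/to/core_model.pt')
#   baptat.init_inference_tools()
#   baptat.set_comparison_values(ideal_binding, ideal_rotation, ideal_translation)
#   results = baptat.run_inference(
#       observations, ['meanOfTunHor'] * 3, do_binding=True, do_rotation=True,
#       do_translation=True, order=None, reorder=None)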

############################################################################
##########  EXAMPLE 2  #####################################################

Cs = []
C_upd = [None] * (tuning_length + 1)
c_losses = []
c_norm = 1

############################################################################
##########  INITIALIZATIONS  ###############################################

## Load data
observations, feature_names = preprocessor.get_AT_data(data_asf_path,
                                                       data_amc_path,
                                                       num_frames)

## Load model
core_model = CORE_NET()
core_model.load_state_dict(torch.load(model_path))
core_model.eval()

## Translation biases
tb = perspective_taker.init_translation_bias_()

for i in range(tuning_length + 1):
    transba = copy.deepcopy(tb)
    transba.requires_grad = True
    Cs.append(transba)

print(f'Cs different in list: {Cs[0] is not Cs[1]}')

## Core state
at_h = torch.zeros(core_model.hidden_num, 1,
                   core_model.hidden_size).requires_grad_()
at_c = torch.zeros(core_model.hidden_num, 1,
                   core_model.hidden_size).requires_grad_()

############################################################################
##########  EXAMPLE 3  #####################################################

class BAPTAT_BASICS(ABC):
    def __init__(self, num_features, num_observations, num_dimensions):
        self.num_features = num_features
        self.num_observations = num_observations
        self.num_dimensions = num_dimensions

        print('Initialized BAPTAT basic module.')

    def set_data_parameters_(self, num_frames, num_observations,
                             num_input_features, num_input_dimensions):

        ## Define data parameters
        self.num_frames = num_frames
        self.num_observations = num_observations
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimensions
        self.input_per_frame = self.num_input_features * self.num_input_dimensions
        self.nxm = (self.num_observations != self.num_input_features)

        self.binder = BINDER_NxM(num_observations=self.num_observations,
                                 num_features=self.num_input_features,
                                 gradient_init=True)

        self.perspective_taker = Perspective_Taker(
            self.num_input_features,
            self.num_input_dimensions,
            rotation_gradient_init=True,
            translation_gradient_init=True)

        self.preprocessor = Preprocessor(self.num_observations,
                                         self.num_input_features,
                                         self.num_input_dimensions)

        self.evaluator = BAPTAT_evaluator(self.num_frames,
                                          self.num_observations,
                                          self.num_input_features,
                                          self.preprocessor)

    def init_model_(self, model_path):
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()

        print('Model loaded.')

    def init_general_inference_tools(self):
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state
        self.at_states = []

    def init_binding_inference_tools(self):
        self.Bs = []
        self.B_grads = [None] * (self.tuning_length + 1)
        self.B_upd = [None] * (self.tuning_length + 1)
        self.bm_losses = []
        self.bm_dets = []
        self.oc_grads = []

    def init_rotation_inference_tools(self):
        self.Rs = []
        self.R_grads = [None] * (self.tuning_length + 1)
        self.R_upd = [None] * (self.tuning_length + 1)
        self.rm_losses = []
        self.rv_losses = []

    def init_translation_inference_tools(self):
        self.Cs = []
        self.C_grads = [None] * (self.tuning_length + 1)
        self.C_upd = [None] * (self.tuning_length + 1)
        self.c_losses = []

    def set_comparison_values_binding(self, ideal_binding):
        self.ideal_binding = ideal_binding.to(self.device)
        if self.nxm:
            self.ideal_binding = self.binder.ideal_nxm_binding(
                self.additional_features, self.ideal_binding).to(self.device)

    def set_comparison_values_rotation(self, ideal_rotation_values,
                                       ideal_rotation_matrix):
        self.ideal_rotation = ideal_rotation_matrix
        if self.rotation_type == 'qrotate':
            self.ideal_quat = ideal_rotation_values
            self.ideal_angle = self.perspective_taker.qeuler(
                self.ideal_quat, 'xyz')
        elif self.rotation_type == 'eulrotate':
            self.ideal_angle = ideal_rotation_values
        else:
            print(
                f'ERROR: Received unknown rotation type!\n\trotation type: {self.rotation_type}'
            )
            exit()

    def set_comparison_values_translation(self, ideal_translation):
        self.ideal_translation = ideal_translation

    def perform_binding(self, data, bm_activations):
        bm = self.binder.scale_binding_matrix(bm_activations, self.scale_mode,
                                              self.scale_combo,
                                              self.nxm_enhance,
                                              self.nxm_last_line_scale)
        if self.nxm:
            bm = bm[:-1]

        return self.binder.bind(data, bm)

    def perform_rotation(self, data, rotation_vals):
        if self.rotation_type == 'qrotate':
            return self.perspective_taker.qrotate(data, rotation_vals)
        else:
            rotmat = self.perspective_taker.compute_rotation_matrix_(
                rotation_vals[0], rotation_vals[1], rotation_vals[2])
            return self.perspective_taker.rotate(data, rotmat)

    def perform_translation(self, data, translation_vals):
        return self.perspective_taker.translate(data, translation_vals)
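
# The perform_* helpers above compose into the same bind -> rotate ->
# translate pipeline as COMBI_BAPTAT.run_inference, e.g. (sketch, assuming a
# subclass has initialized Bs, Rs and Cs):
#
#   x = module.perform_binding(o, module.Bs[i])
#   x = module.perform_rotation(x, module.Rs[i])
#   x = module.perform_translation(x, module.Cs[i])
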
class SEP_BINDING_GESTALTEN():

    def __init__(self):
        ## General parameters 
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)

        torch.set_printoptions(precision=8)

        
        ## Set default parameters
        ##   -> Can be changed during experiments 

        self.scale_mode = 'rcwSM'
        self.scale_combo = 'comp_mult'
        self.grad_bias = 1.5
        self.nxm = False
        self.additional_features = None
        self.nxm_enhance = 'square'
        self.nxm_last_line_scale = 0.1
        self.dummie_init = 0.1



    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_scale_mode(self, mode): 
        self.scale_mode = mode
        print('Reset scale mode: ' + self.scale_mode)


    def set_scale_combination(self, combination): 
        self.scale_combo = combination
        print('Reset scale combination: ' + self.scale_combo)


    def set_additional_features(self, index_addition): 
        self.additional_features = index_addition
        print(f'Additional features to the LSTM-input at indices {self.additional_features}')


    def set_nxm_enhancement(self, enhancement): 
        self.nxm_enhance = enhancement
        print(f'Enhancement for outcast line: {self.nxm_enhance}')


    def set_nxm_last_line_scale(self, scale_factor): 
        self.nxm_last_line_scale = scale_factor
        print(f'Scaler for outcast line: {self.nxm_last_line_scale}')


    def set_dummie_init_value(self, init_value):
        self.dummie_init = init_value
        print(f'Initial value for dummie line: {self.dummie_init}')


    def set_weighted_gradient_bias(self, bias):
        # bias > 1 => favor recent
        # bias < 1 => favor earlier
        self.grad_bias = bias
        print(f'Reset bias for gradient weighting: {self.grad_bias}')

    
    def set_data_parameters_(self, 
        num_frames, 
        num_observations, 
        num_input_features, 
        num_input_dimensions):

        ## Define data parameters
        self.num_frames = num_frames
        self.num_observations = num_observations
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimensions
        self.input_per_frame = self.num_input_features * self.num_input_dimensions
        self.nxm = (self.num_observations != self.num_input_features)
        
        self.binder = BINDER_NxM(
            num_observations=self.num_observations, 
            num_features=self.num_input_features, 
            gradient_init=True)

        self.preprocessor = Preprocessor(self.num_observations, self.num_input_features, self.num_input_dimensions)
        self.evaluator = BAPTAT_evaluator(self.num_frames, self.num_observations, self.num_input_features, self.preprocessor)
        
    
    def set_tuning_parameters_(self, 
        tuning_length, 
        num_tuning_cycles, 
        loss_function, 
        at_learning_rate_binding, 
        at_learning_rate_state, 
        at_momentum_binding):
         
        ## Define tuning parameters 
        self.tuning_length = tuning_length          # length of tuning horizon 
        self.tuning_cycles = num_tuning_cycles      # number of tuning cycles in each iteration 

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        self.l2Loss = lambda x,y: self.mse(x, y) * (self.num_input_dimensions * self.num_input_features)

        # define learning parameters 
        self.at_learning_rate = at_learning_rate_binding
        self.at_learning_rate_state = at_learning_rate_state
        self.bm_momentum = at_momentum_binding
        self.at_loss_function = self.mse

        print('Parameters set.')


    def get_additional_features(self): 
        return self.additional_features 

    
    def get_oc_grads(self): 
        return self.oc_grads 


    def init_model_(self, model_path): 
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()
        self.core_model.to(self.device)

        print('Model loaded.')


    def init_inference_tools(self):
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state 
        self.at_states = []
        # state_optimizer = torch.optim.Adam(init_state, at_learning_rate)

        # binding
        self.ideal_binding = torch.Tensor(np.identity(self.num_input_features)).to(self.device)

        self.Bs = []
        self.B_grads = [None] * (self.tuning_length+1)
        self.B_upd = [None] * (self.tuning_length+1)
        self.bm_losses = []
        self.bm_dets = []

        self.oc_grads = []

    
    def set_comparison_values(self, ideal_binding):
        self.ideal_binding = ideal_binding.to(self.device)
        if self.nxm:
            self.ideal_binding = self.binder.ideal_nxm_binding(
                self.additional_features, self.ideal_binding).to(self.device)



    ############################################################################
    ##########  INFERENCE  #####################################################
    
    def run_inference(self, observations, grad_calculation, order, reorder):

        if reorder is not None:
            reorder = reorder.to(self.device)
        
        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        ## Binding matrices 
        # Init binding entries 
        bm = self.binder.init_binding_matrix_det_()
        # bm = binder.init_binding_matrix_rand_()
        # print(bm)
        dummie_line = torch.ones(1,self.num_observations).to(self.device) * self.dummie_init

        for i in range(self.tuning_length+1):
            matrix = bm.clone().to(self.device)
            if self.nxm:
                matrix = torch.cat([matrix, dummie_line])
            matrix.requires_grad_()
            self.Bs.append(matrix)
            
        # print(f'BMs different in list: {self.Bs[0] is not self.Bs[1]}')

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((
                self.at_inputs,
                o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            self.obs_count += 1

            bm = self.binder.scale_binding_matrix(
                self.Bs[i], 
                self.scale_mode, 
                self.scale_combo, 
                self.nxm_enhance, 
                self.nxm_last_line_scale)
            if self.nxm:
                bm = bm[:-1]
            x_B = self.binder.bind(o, bm)
            
            x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)  
            self.at_states.append(state)
            self.at_predictions = torch.cat((self.at_predictions, new_prediction.reshape(1,self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly move the following into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            bm = self.binder.scale_binding_matrix(
                self.Bs[-1], 
                self.scale_mode, 
                self.scale_combo,
                self.nxm_enhance, 
                self.nxm_last_line_scale)
            if self.nxm:
                bm = bm[:-1]
            x_B = self.binder.bind(o, bm)
            
            x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

            ## Generate current prediction 
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles 
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error 
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon 
                loss.backward(retain_graph = True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters 
                with torch.no_grad():
                    
                    # Calculate gradients with respect to the entries
                    for i in range(self.tuning_length+1):
                        self.B_grads[i] = self.Bs[i].grad

                    # print(B_grads[tuning_length])
                    
                    
                    # Calculate overall gradients 
                    if grad_calculation == 'lastOfTunHor':
                        ### version 1
                        grad_B = self.B_grads[0]
                    elif grad_calculation == 'meanOfTunHor':
                        ### version 2 / 3
                        grad_B = torch.mean(torch.stack(self.B_grads), dim=0)
                    elif grad_calculation == 'weightedInTunHor':
                        ### version 4
                        weighted_grads_B = [None] * (self.tuning_length+1)
                        for i in range(self.tuning_length+1):
                            weighted_grads_B[i] = np.power(self.grad_bias, i) * self.B_grads[i]
                        grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0)
                    
                    # print(f'grad_B: {grad_B}')
                    # print(f'grad_B: {torch.norm(grad_B, 1)}')
                    

                    # Update parameters in time step t-H with saved gradients 
                    grad_B = grad_B.to(self.device)
                    upd_B = self.binder.update_binding_matrix_(
                        self.Bs[0], grad_B, self.at_learning_rate, self.bm_momentum)

                    # Compare binding matrix to ideal matrix
                    # NOTE: the ideal matrix is always the identity, because only then can the FBE and determinant be computed => provide reorder
                    c_bm = self.binder.scale_binding_matrix(upd_B, self.scale_mode, self.scale_combo)
                    if order is not None: 
                        c_bm = c_bm.gather(1, reorder.unsqueeze(0).expand(c_bm.shape))

                    if self.nxm:
                        self.oc_grads.append(grad_B[-1])
                        FBE = self.evaluator.FBE_nxm_additional_features(c_bm, self.ideal_binding, self.additional_features)
                        c_bm = self.evaluator.clear_nxm_binding_matrix(c_bm, self.additional_features)
                    
                    mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)

                    if self.nxm:
                        mat_loss = torch.stack([mat_loss, FBE, mat_loss+FBE])
                    self.bm_losses.append(mat_loss)
                    print(f'loss of binding matrix (FBE): {mat_loss}')

                    # Compute determinant of the binding matrix
                    det = torch.det(c_bm)
                    self.bm_dets.append(det)
                    print(f'determinant of binding matrix: {det}')


                    # Zero out gradients for all parameters in all time steps of tuning horizon
                    for i in range(self.tuning_length+1):
                        self.Bs[i].requires_grad = False
                        self.Bs[i].grad.data.zero_()

                    # Update all parameters for all time steps 
                    for i in range(self.tuning_length+1):
                        self.Bs[i].data = upd_B.clone().data
                        self.Bs[i].requires_grad = True
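                    # All horizon copies are reset to the same updated value, so
                    # the subsequent forward pass replays the whole horizon from
                    # the newly tuned matrix at t-H.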
                    
                    # print(Bs[0])


                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()
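                    # Plain gradient-descent step on the initial LSTM state:
                    # h_0 <- h_0 - at_learning_rate_state * dL/dh_0 (same for c_0).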

                    # print(f'updated init_state: {init_state}')
                
                ## TODO: reorganize for multiple cycles

                # forward pass from t-H to t with new parameters 
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    bm = self.binder.scale_binding_matrix(
                        self.Bs[i], 
                        self.scale_mode, 
                        self.scale_combo, 
                        self.nxm_enhance, 
                        self.nxm_last_line_scale)
                    if self.nxm:
                        bm = bm[:-1]
                    x_B = self.binder.bind(self.at_inputs[i], bm)
                    x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat((self.at_predictions, upd_prediction.reshape(1,self.input_per_frame)), 0)
                    
                    # for last tuning cycle update initial state to track gradients 
                    if cycle==(self.tuning_cycles-1) and i==0: 
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone().detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_().to(self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(self.device)
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state 

                # Update current binding
                bm = self.binder.scale_binding_matrix(
                    self.Bs[-1], 
                    self.scale_mode, 
                    self.scale_combo, 
                    self.nxm_enhance, 
                    self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
                x = self.preprocessor.convert_data_AT_to_LSTM(x_B)


            # END tuning cycle        

            ## Generate updated prediction 
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables            
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, 
                final_input.reshape(1,self.input_per_frame)), 0)
            self.at_inputs = torch.cat(
                (self.at_inputs[1:], 
                o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            
            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions, 
                final_prediction.reshape(1,self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:], 
                new_prediction.reshape(1,self.input_per_frame)), 0)

        # END active tuning
        
        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length): 
            at_final_predictions = torch.cat(
                (at_final_predictions, 
                self.at_predictions[i].reshape(1,self.input_per_frame)), 0)
            at_final_inputs = torch.cat(
                (at_final_inputs, 
                self.at_inputs[i].reshape(1,self.input_per_frame)), 0)

        # get final binding matrix
        final_binding_matrix = self.binder.scale_binding_matrix(self.Bs[-1].clone().detach(), self.scale_mode, self.scale_combo)
        # print(f'final binding matrix: {final_binding_matrix}')
        final_binding_entries = self.Bs[-1].clone().detach()
        # print(f'final binding entries: {final_binding_entries}')

        return at_final_inputs, at_final_predictions, final_binding_matrix, final_binding_entries




    ############################################################################
    ##########  EVALUATION #####################################################

    def get_result_history(
        self, 
        observations, 
        at_final_predictions):

        if self.nxm:
            pred_errors = self.evaluator.prediction_errors_nxm(
                observations, 
                self.additional_features, 
                self.num_observations, 
                at_final_predictions, 
                self.mse
            )
            self.bm_losses = torch.stack(self.bm_losses)

        else:
            pred_errors = self.evaluator.prediction_errors(
                observations, 
                at_final_predictions, 
                self.mse)

        return [pred_errors, self.at_losses, self.bm_dets, self.bm_losses]


class SEP_ROTATION():
    def __init__(self):
        ## General parameters
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)

        torch.set_printoptions(precision=8)

        ## Set default parameters
        ##   -> Can be changed during experiments

        self.rotation_type = 'qrotate'
        self.grad_bias = 1.5

    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_rotation_type(self, rotation):
        self.rotation_type = rotation
        print('Reset type of rotation: ' + self.rotation_type)

    def set_weighted_gradient_bias(self, bias):
        # bias > 1 => favor recent
        # bias < 1 => favor earlier
        self.grad_bias = bias
        print(f'Reset bias for gradient weighting: {self.grad_bias}')

    def set_data_parameters_(self, num_frames, num_observations,
                             num_input_features, num_input_dimesions):

        ## Define data parameters
        self.num_frames = num_frames
        self.num_observations = num_observations
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimesions
        self.input_per_frame = self.num_input_features * self.num_input_dimensions

        self.perspective_taker = Perspective_Taker(
            self.num_observations,
            self.num_input_dimensions,
            rotation_gradient_init=True,
            translation_gradient_init=True)

        self.preprocessor = Preprocessor(self.num_observations,
                                         self.num_input_features,
                                         self.num_input_dimensions)

        self.evaluator = BAPTAT_evaluator(self.num_frames,
                                          self.num_observations,
                                          self.num_input_features,
                                          self.preprocessor)

    def set_tuning_parameters_(self, tuning_length, num_tuning_cycles,
                               loss_function, at_learning_rate_rotation,
                               at_learning_rate_state, at_momentum_rotation):

        ## Define tuning parameters
        self.tuning_length = tuning_length  # length of tuning horizon
        self.tuning_cycles = num_tuning_cycles  # number of tuning cycles in each iteration

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        self.l2Loss = lambda x, y: self.mse(x, y) * (self.num_input_dimensions
                                                     * self.num_input_features)

        # define learning parameters
        self.at_learning_rate = at_learning_rate_rotation
        self.at_learning_rate_state = at_learning_rate_state
        self.r_momentum = at_momentum_rotation
        self.at_loss_function = self.mse

        print('Parameters set.')

    def init_model_(self, model_path):
        ## Load model
        self.core_model = CORE_NET(45, 150)
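        # NOTE: the constructor arguments are hard-coded; presumably the input
        # size (45 = 15 features x 3 dimensions) and the hidden size (150).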
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()
        self.core_model.to(self.device)

        print('Model loaded.')

    def init_inference_tools(self):
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_inputs = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state
        self.at_states = []
        # state_optimizer = torch.optim.Adam(init_state, at_learning_rate)

        # rotation
        self.Rs = []
        self.R_grads = [None] * (self.tuning_length + 1)
        self.R_upd = [None] * (self.tuning_length + 1)
        self.rm_losses = []
        self.rv_losses = []
        self.ra_losses = []

    def set_comparison_values(self, ideal_rotation_values,
                              ideal_rotation_matrix):
        self.identity_matrix = torch.Tensor(
            np.identity(self.num_input_dimensions))
        self.ideal_rotation = ideal_rotation_matrix.to(self.device)
        if self.rotation_type == 'qrotate':
            self.ideal_quat = ideal_rotation_values.to(self.device)
            self.ideal_angle = torch.rad2deg(
                self.perspective_taker.qeuler(self.ideal_quat,
                                              'xyz')).to(self.device)
        elif self.rotation_type == 'eulrotate':
            self.ideal_angle = ideal_rotation_values.to(self.device)
        else:
            print(
                f'ERROR: Received unknown rotation type!\n\trotation type: {self.rotation_type}'
            )
            exit()

    def set_ideal_euler_angles(self, angles):
        self.ideal_angles = angles.to(self.device)

    def set_ideal_quaternion(self, quaternion):
        self.ideal_quat = quaternion.to(self.device)

    def set_ideal_roation_matrix(self, matrix):
        self.ideal_rotation = matrix.to(self.device)

    ############################################################################
    ##########  INFERENCE  #####################################################

    def run_inference(self, observations, grad_calculation):

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        if self.rotation_type == 'qrotate':
            ## Rotation quaternion
            rq = self.perspective_taker.init_quaternion()
            # print(rq)

            for i in range(self.tuning_length + 1):
                quat = rq.clone().to(self.device)
                quat.requires_grad_()
                self.Rs.append(quat)

        elif self.rotation_type == 'eulrotate':
            ## Rotation euler angles
            # ra = perspective_taker.init_angles_()
            # ra = torch.Tensor([[309.89], [82.234], [95.765]])
            ra = torch.Tensor([[75.0], [6.0], [128.0]])
            print(ra)

            for i in range(self.tuning_length + 1):
                angles = []
                for j in range(self.num_input_dimensions):
                    angle = ra[j].clone()
                    angle.requires_grad_()
                    angles.append(angle)
                self.Rs.append(angles)

        else:
            print('ERROR: Received unknown rotation type!')
            exit()

        ## Core state
        # define scaler
        state_scaler = 0.95
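        # NOTE: hidden and cell state are multiplied by this factor before every
        # LSTM step, presumably to decay stale state and keep activations bounded.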

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[i])
            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: the following could possibly be moved into a separate function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[-1])

            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters
                with torch.no_grad():

                    ## get gradients
                    if self.rotation_type == 'qrotate':
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            self.R_grads[i] = self.Rs[i].grad
                    else:
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            grad = []
                            for j in range(self.num_input_dimensions):
                                grad.append(self.Rs[i][j].grad)
                            self.R_grads[i] = torch.stack(grad)

                    # print(self.R_grads[self.tuning_length])

                    # Calculate overall gradients
                    if grad_calculation == 'lastOfTunHor':
                        ### version 1
                        grad_R = self.R_grads[0]
                    elif grad_calculation == 'meanOfTunHor':
                        ### version 2 / 3
                        grad_R = torch.mean(torch.stack(self.R_grads), dim=0)
                    elif grad_calculation == 'weightedInTunHor':
                        ### version 4
                        weighted_grads_R = [None] * (self.tuning_length + 1)
                        for i in range(self.tuning_length + 1):
                            weighted_grads_R[i] = np.power(self.grad_bias,
                                                           i) * self.R_grads[i]
                        grad_R = torch.mean(torch.stack(weighted_grads_R),
                                            dim=0)

                    # print(f'grad_R: {grad_R}')

                    grad_R = grad_R.to(self.device)
                    if self.rotation_type == 'qrotate':
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_quaternion(
                            self.Rs[0], grad_R, self.at_learning_rate,
                            self.r_momentum)
                        print(f'updated quaternion: {upd_R}')

                        # Compare quaternion values
                        # quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
                        quat_loss = 2 * torch.arccos(
                            torch.abs(
                                torch.sum(torch.mul(self.ideal_quat, upd_R))))
                        quat_loss = torch.rad2deg(quat_loss)
                        print(f'loss of quaternion: {quat_loss}')
                        self.rv_losses.append(quat_loss)
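                        # Geodesic angle between unit quaternions:
                        # theta = 2 * arccos(|<q_ideal, q>|), reported in degrees;
                        # the absolute value makes q and -q equivalent, so
                        # identical orientations yield 0.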

                        # Compare quaternion angles
                        ang = torch.rad2deg(
                            self.perspective_taker.qeuler(upd_R, 'zyx'))
                        ang_diff = ang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of quaternion angles: {ang_loss} \nwith norm: {torch.norm(ang_loss)}'
                        )
                        self.ra_losses.append(torch.norm(ang_loss))
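                        # The per-axis loss 2 - (cos(diff) + 1) lies in [0, 2]:
                        # 0 for a perfect match, 2 for a 180-degree error.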

                        # Compute rotation matrix
                        rotmat = self.perspective_taker.quaternion2rotmat(
                            upd_R)

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Rs[i].requires_grad = False
                            self.Rs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            quat = upd_R.clone()
                            quat.requires_grad_()
                            self.Rs[i] = quat

                    else:
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_rotation_angles_(
                            self.Rs[0], grad_R, self.at_learning_rate)
                        print(f'updated angles: {upd_R}')

                        # Save rotation angles
                        rotang = torch.stack(upd_R)
                        # angles:
                        ang_diff = rotang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
                        )
                        self.rv_losses.append(torch.norm(ang_loss))
                        # Compute rotation matrix
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            upd_R[0], upd_R[1], upd_R[2])[0]

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            for j in range(self.num_input_dimensions):
                                self.Rs[i][j].requires_grad = False
                                self.Rs[i][j].grad.data.zero_()

                        # print(Rs[0])
                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            angles = []
                            for j in range(3):
                                angle = upd_R[j].clone()
                                angle.requires_grad_()
                                angles.append(angle)
                            self.Rs[i] = angles
                        # print(Rs[0])

                    # Calculate and save rotation losses
                    # matrix:
                    # mat_loss = self.mse(
                    #     (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))),
                    #     self.identity_matrix
                    # )
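                    # Geodesic distance on SO(3): for R_dif = R_ideal @ R^T the
                    # rotation angle is arccos((trace(R_dif) - 1) / 2); it is 0
                    # exactly when both rotations coincide.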
                    dif_R = torch.mm(self.ideal_rotation,
                                     torch.transpose(rotmat, 0, 1))
                    mat_loss = torch.arccos(0.5 * (torch.trace(dif_R) - 1))
                    mat_loss = torch.rad2deg(mat_loss)

                    print(f'loss of rotation matrix: {mat_loss}')
                    self.rm_losses.append(mat_loss)

                    # print(Rs[0])

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # state_optimizer.step()
                    # print(f'updated init_state: {init_state}')

                ## TODO: reorganize for multiple cycles

                # forward pass from t-H to t with new parameters
                # TODO: should the initial state be updated here?
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    if self.rotation_type == 'qrotate':
                        x_R = self.perspective_taker.qrotate(
                            self.at_inputs[i], self.Rs[i])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                        x_R = self.perspective_taker.rotate(
                            self.at_inputs[i], rotmat)

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone(
                            ).detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_()
                        at_c = state[1].clone().detach().requires_grad_()
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current rotation
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(o, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(o, rotmat)

                x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)
            if self.rotation_type == 'qrotate':
                x_i = self.perspective_taker.qrotate(self.at_inputs[i],
                                                     self.Rs[-1])
            else:
                x_i = self.perspective_taker.rotate(self.at_inputs[i], rotmat)
            at_final_inputs = torch.cat(
                (at_final_inputs, x_i.reshape(1, self.input_per_frame)), 0)

        # get final rotation matrix
        if self.rotation_type == 'qrotate':
            final_rotation_values = self.Rs[0].clone().detach()
            # get final quaternion
            print(f'final quaternion: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.quaternion2rotmat(
                final_rotation_values)

        else:
            final_rotation_values = [
                self.Rs[0][i].clone().detach()
                for i in range(self.num_input_dimensions)
            ]
            print(f'final euler angles: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                final_rotation_values[0], final_rotation_values[1],
                final_rotation_values[2])

        print(f'final rotation matrix: \n{final_rotation_matrix}')

        return at_final_inputs, at_final_predictions, final_rotation_values, final_rotation_matrix

    ############################################################################
    ##########  EVALUATION #####################################################

    def get_result_history(self, observations, at_final_predictions):

        pred_errors = self.evaluator.prediction_errors(observations,
                                                       at_final_predictions,
                                                       self.mse)
        print([
            pred_errors, self.at_losses, self.rm_losses, self.rv_losses,
            self.ra_losses
        ])

        return [
            pred_errors, self.at_losses, self.rm_losses, self.rv_losses,
            self.ra_losses
        ]


class SEP_BINDING():
    def __init__(self):
        ## General parameters
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)

        torch.set_printoptions(precision=8)

    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_data_parameters(self):
        ## Define data parameters
        self.num_frames = 20
        self.num_input_features = 15
        self.num_input_dimensions = 3
        self.preprocessor = Preprocessor(self.num_input_features,
                                         self.num_input_dimensions)
        self.evaluator = BAPTAT_evaluator(self.num_frames,
                                          self.num_input_features,
                                          self.preprocessor)
        self.data_at_unlike_train = False  ## Note: sample needs to be changed in the future

        # data paths
        self.data_asf_path = 'Data_Compiler/S35T07.asf'
        self.data_amc_path = 'Data_Compiler/S35T07.amc'

    def set_data_parameters_(self, num_frames, num_input_features,
                             num_input_dimesions):
        ## Define data parameters
        self.num_frames = num_frames
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimesions
        self.preprocessor = Preprocessor(self.num_input_features,
                                         self.num_input_dimensions)
        self.evaluator = BAPTAT_evaluator(self.num_frames,
                                          self.num_input_features,
                                          self.preprocessor)

    def set_model_parameters(self):
        ## Define model parameters
        self.model_path = 'CoreLSTM/models/LSTM_46_cell.pt'

    def set_tuning_parameters(self):
        ## Define tuning parameters
        self.tuning_length = 10  # length of tuning horizon
        self.tuning_cycles = 3  # number of tuning cycles in each iteration

        # possible loss functions
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        # smL1Loss = nn.SmoothL1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        # smL1Loss = nn.SmoothL1Loss(beta=2)
        # smL1Loss = nn.SmoothL1Loss(beta=0.5)
        # smL1Loss = nn.SmoothL1Loss(reduction='sum', beta=0.5)
        self.l2Loss = lambda x, y: self.mse(x, y) * (self.num_input_dimensions
                                                     * self.num_input_features)

        self.at_loss = self.smL1Loss

        # define learning parameters
        self.at_learning_rate = 1
        self.at_learning_rate_state = 0.0
        self.bm_momentum = 0.0
        self.at_loss_function = self.mse

    def set_tuning_parameters_(self, tuning_length, num_tuning_cycles,
                               loss_function, at_learning_rate_binding,
                               at_learning_rate_state, at_momentum_binding):
        ## Define tuning parameters
        self.tuning_length = tuning_length  # length of tuning horizon
        self.tuning_cycles = num_tuning_cycles  # number of tuning cycles in each iteration

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        self.l2Loss = lambda x, y: self.mse(x, y) * (self.num_input_dimensions
                                                     * self.num_input_features)

        # define learning parameters
        self.at_learning_rate = at_learning_rate_binding
        self.at_learning_rate_state = at_learning_rate_state
        self.bm_momentum = at_momentum_binding
        self.at_loss_function = self.mse

        print('Parameters set.')

    def init_inference_tools(self):
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state
        self.at_states = []
        # state_optimizer = torch.optim.Adam(init_state, at_learning_rate)

        # binding
        self.binder = BinderExMat(num_features=self.num_input_features,
                                  gradient_init=True)
        self.ideal_binding = torch.Tensor(np.identity(
            self.num_input_features)).to(self.device)

        self.Bs = []
        self.B_grads = [None] * (self.tuning_length + 1)
        self.B_upd = [None] * (self.tuning_length + 1)
        self.bm_losses = []
        self.bm_dets = []

    def set_comparison_values(self, ideal_binding):
        self.ideal_binding = ideal_binding

    ############################################################################
    ##########  INITIALIZATIONS  ###############################################

    def load_data(self):
        ## Load data
        observations, feature_names = self.preprocessor.get_AT_data(
            self.data_asf_path, self.data_amc_path, self.num_frames)
        return observations, feature_names

    def init_model(self):
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(self.model_path))
        self.core_model.eval()

    def init_model_(self, model_path):
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()
        # print('\nModel loaded:')
        # print(self.core_model)

    def prepare_inference(self):
        self.set_data_parameters()
        self.set_model_parameters()
        self.set_tuning_parameters()
        self.init_inference_tools()
        self.init_model()
        print(
            'Ready to run AT inference for binding task! \nInitialized parameters with: \n'
            + f' - number of features: \t\t{self.num_input_features}\n' +
            f' - number of dimensions: \t{self.num_input_dimensions}\n' +
            f' - number of tuning cycles: \t{self.tuning_cycles}\n' +
            f' - size of tuning horizon: \t{self.tuning_length}\n' +
            f' - learning rate: \t\t{self.at_learning_rate}\n' +
            f' - learning rate (state): \t{self.at_learning_rate_state}\n' +
            f' - momentum: \t\t\t{self.bm_momentum}\n' +
            f' - model: \t\t\t{self.model_path}\n')

    def run_inference(self, observations, order, reorder):

        at_final_predictions = torch.tensor([]).to(self.device)

        ## Binding matrices
        # Init binding entries
        bm = self.binder.init_binding_matrix_det_()
        # bm = binder.init_binding_matrix_rand_()
        print(bm)

        for i in range(self.tuning_length + 1):
            matrix = bm.clone()
            matrix.requires_grad_()
            self.Bs.append(matrix)

        print(f'BMs different in list: {self.Bs[0] is not self.Bs[1]}')

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        # create leaf tensors on the target device first, then enable gradients
        # (calling .to() after requires_grad_() would yield non-leaf tensors on
        # GPU, whose .grad is never populated)
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        self.at_states.append(init_state)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count]
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            bm = self.binder.scale_binding_matrix(self.Bs[i])
            x_B = self.binder.bind(o, bm)
            x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
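            # NOTE: the hard-coded 45 below presumably equals
            # num_input_features * num_input_dimensions (15 * 3 by default)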
            self.at_predictions = torch.cat(
                (self.at_predictions, new_prediction.reshape(1, 45)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: the following could possibly be moved into a separate function
            o = observations[self.obs_count]
            self.obs_count += 1

            bm = self.binder.scale_binding_matrix(self.Bs[-1])
            x_B = self.binder.bind(o, bm)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                # lam = 10
                # loss = at_loss_function(p, x[0]) + l1Loss(p,x[0]) + lam / torch.norm(torch.Tensor(Bs[0].copy()))
                # loss = at_loss_function(p, x[0]) + mse(p, x[0])
                # loss = l1Loss(p,x[0]) + l2Loss(p,x[0])
                # loss_scale = torch.square(torch.mean(torch.norm(torch.tensor(Bs[-1]), dim=1, keepdim=True)) -1.) ##COPY?????
                # loss_scale = torch.square(torch.mean(torch.norm(bm.clone().detach(), dim=1, keepdim=True)) -1.) ##COPY?????
                # -> length of the vectors
                # print(f'loss scale: {loss_scale}')
                # loss_scale_factor = 0.9
                # l1scale = loss_scale_factor * loss_scale
                # l2scale = loss_scale_factor / loss_scale
                # loss = l1Loss(p,x[0]) + l2scale * l2Loss(p,x[0])
                # loss = l1scale * mse(p,x[0]) + l2scale * l2Loss(p,x[0])
                # loss = l2Loss(p,x[0]) + mse(p,x[0])
                # loss = l2Loss(p,x[0]) + loss_scale * mse(p,x[0])
                # loss = loss_scale_factor * loss_scale * l2Loss(p,x[0]) + mse(p,x[0])
                # loss = loss_scale_factor * loss_scale * l2Loss(p,x[0])
                # loss = loss_scale_factor * loss_scale * mse(p,x[0])
                # loss = self.smL1Loss(p, x[0])
                loss = self.at_loss(p, x[0])

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                # Update parameters
                with torch.no_grad():

                    # Collect gradients with respect to the entries
                    for i in range(self.tuning_length + 1):
                        self.B_grads[i] = self.Bs[i].grad

                    # print(B_grads[tuning_length])

                    # Calculate overall gradients
                    ### version 1
                    # grad_B = B_grads[0]
                    ### version 2 / 3
                    # grad_B = torch.mean(torch.stack(B_grads), 0)
                    ### version 4
                    # # # # bias > 1 => favor recent
                    # # # # bias < 1 => favor earlier
                    bias = 1.5
                    weighted_grads_B = [None] * (self.tuning_length + 1)
                    for i in range(self.tuning_length + 1):
                        weighted_grads_B[i] = np.power(bias,
                                                       i) * self.B_grads[i]
                    grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0)
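                    # Exponential recency weighting (same scheme as
                    # 'weightedInTunHor' above); with bias = 1.5 and a horizon of
                    # 10 steps the newest gradient weighs about 57x the oldest.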

                    # print(f'grad_B: {grad_B}')
                    # print(f'grad_B: {torch.norm(grad_B, 1)}')

                    # Update parameters in time step t-H with saved gradients
                    upd_B = self.binder.update_binding_matrix_(
                        self.Bs[0], grad_B, self.at_learning_rate,
                        self.bm_momentum)

                    # Compare binding matrix to ideal matrix
                    c_bm = self.binder.scale_binding_matrix(upd_B)
                    if order is not None:
                        c_bm = c_bm.gather(
                            1,
                            reorder.unsqueeze(0).expand(c_bm.shape))

                    mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)
                    self.bm_losses.append(mat_loss)
                    print(f'loss of binding matrix (FBE): {mat_loss}')

                    # Compute determinant of binding matrix
                    det = torch.det(c_bm)
                    self.bm_dets.append(det)
                    print(f'determinant of binding matrix: {det}')

                    # Zero out gradients for all parameters in all time steps of tuning horizon
                    for i in range(self.tuning_length + 1):
                        self.Bs[i].requires_grad = False
                        self.Bs[i].grad.data.zero_()

                    # Update all parameters for all time steps
                    for i in range(self.tuning_length + 1):
                        self.Bs[i].data = upd_B.clone().data
                        self.Bs[i].requires_grad = True

                    # print(Bs[0])

                    # Initial state
                    g_h = at_h.grad
                    g_c = at_c.grad

                    upd_h = self.at_states[0][
                        0] - self.at_learning_rate_state * g_h
                    upd_c = self.at_states[0][
                        1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # state_optimizer.step()
                    # print(f'updated init_state: {init_state}')

                ## TODO: reorganize for multiple cycles

                # forward pass from t-H to t with new parameters
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    bm = self.binder.scale_binding_matrix(self.Bs[i])
                    x_B = self.binder.bind(self.at_inputs[i], bm)
                    x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

                    # print(f'x_B :{x_B}')

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, 45)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        at_h = state[0].clone().detach().requires_grad_().to(
                            self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(
                            self.device)
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i + 1] = state

                # Update current binding
                bm = self.binder.scale_binding_matrix(self.Bs[-1])
                x_B = self.binder.bind(o, bm)
                x = self.preprocessor.convert_data_AT_to_LSTM(x_B)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # states
            self.at_states.append(state)
            self.at_states[0][0].requires_grad = False
            self.at_states[0][1].requires_grad = False
            self.at_states = self.at_states[1:]

            # observations
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[0].detach().reshape(
                    1, 45)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:], new_prediction.reshape(1, 45)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(1, 45)),
                0)

        # get final binding matrix
        final_binding_matrix = self.binder.scale_binding_matrix(
            self.Bs[-1].clone().detach())
        print(f'final binding matrix: {final_binding_matrix}')
        final_binding_entries = self.Bs[-1].clone().detach()
        print(f'final binding entries: {final_binding_entries}')

        return at_final_predictions, final_binding_matrix, final_binding_entries

    ############################################################################
    ##########  EVALUATION #####################################################

    def evaluate(self, observations, at_final_predictions, feature_names,
                 final_binding_matrix, final_binding_entries):

        pred_errors = self.evaluator.prediction_errors(observations,
                                                       at_final_predictions,
                                                       self.at_loss_function)

        self.evaluator.plot_prediction_errors(pred_errors)
        self.evaluator.plot_at_losses(
            self.at_losses, 'History of overall losses during active tuning')
        self.evaluator.plot_at_losses(self.bm_losses,
                                      'History of binding matrix loss (FBE)')
        self.evaluator.plot_at_losses(
            self.bm_dets, 'History of binding matrix determinant')

        self.evaluator.plot_binding_matrix(
            final_binding_matrix, feature_names,
            'Binding matrix showing relative contribution of observed feature to input feature'
        )
        self.evaluator.plot_binding_matrix(
            final_binding_entries, feature_names,
            'Binding matrix entries showing contribution of observed feature to input feature'
        )

        # evaluator.help_visualize_devel(observations, at_final_predictions)

    def get_result_history(self, observations, at_final_predictions):

        pred_errors = self.evaluator.prediction_errors(observations,
                                                       at_final_predictions,
                                                       self.mse)

        return [pred_errors, self.at_losses, self.bm_losses, self.bm_dets]


class SEP_TRANSLATION():
    def __init__(self):
        ## General parameters
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        autograd.set_detect_anomaly(True)

        torch.set_printoptions(precision=8)

        ## Set default parameters
        ##   -> Can be changed during experiments

        self.grad_bias = 1.5

    ############################################################################
    ##########  PARAMETERS  ####################################################

    def set_weighted_gradient_bias(self, bias):
        # bias > 1 => favor recent
        # bias < 1 => favor earlier
        self.grad_bias = bias
        print(f'Reset bias for gradient weighting: {self.grad_bias}')

    def set_data_parameters_(self, num_frames, num_observations,
                             num_input_features, num_input_dimesions):

        ## Define data parameters
        self.num_frames = num_frames
        self.num_observations = num_observations
        self.num_input_features = num_input_features
        self.num_input_dimensions = num_input_dimesions
        self.input_per_frame = self.num_input_features * self.num_input_dimensions

        self.perspective_taker = Perspective_Taker(
            self.num_observations,
            self.num_input_dimensions,
            rotation_gradient_init=True,
            translation_gradient_init=True)

        self.preprocessor = Preprocessor(self.num_observations,
                                         self.num_input_features,
                                         self.num_input_dimensions)

        self.evaluator = BAPTAT_evaluator(self.num_frames,
                                          self.num_observations,
                                          self.num_input_features,
                                          self.preprocessor)

    def set_tuning_parameters_(self, tuning_length, num_tuning_cycles,
                               loss_function, at_learning_rate_translation,
                               at_learning_rate_state,
                               at_momentum_translation):

        ## Define tuning parameters
        self.tuning_length = tuning_length  # length of tuning horizon
        self.tuning_cycles = num_tuning_cycles  # number of tuning cycles in each iteration

        # possible loss functions
        self.at_loss = loss_function
        self.mse = nn.MSELoss()
        self.l1Loss = nn.L1Loss()
        self.smL1Loss = nn.SmoothL1Loss(reduction='sum')
        self.l2Loss = lambda x, y: self.mse(x, y) * (self.num_input_dimensions
                                                     * self.num_input_features)

        # define learning parameters
        self.at_learning_rate = at_learning_rate_translation
        self.at_learning_rate_state = at_learning_rate_state
        self.c_momentum = at_momentum_translation
        self.at_loss_function = self.mse

        print('Parameters set.')

    def init_model_(self, model_path):
        ## Load model
        self.core_model = CORE_NET()
        self.core_model.load_state_dict(torch.load(model_path))
        self.core_model.eval()
        self.core_model.to(self.device)

        print('Model loaded.')

    def init_inference_tools(self):
        ## Define tuning variables
        # general
        self.obs_count = 0
        self.at_inputs = torch.tensor([]).to(self.device)
        self.at_predictions = torch.tensor([]).to(self.device)
        self.at_final_predictions = torch.tensor([]).to(self.device)
        self.at_losses = []

        # state
        self.at_states = []

        # translation
        self.Cs = []
        self.C_grads = [None] * (self.tuning_length + 1)
        self.C_upd = [None] * (self.tuning_length + 1)
        self.c_losses = []

    def set_comparison_values(self, ideal_translation):
        self.ideal_translation = ideal_translation.to(self.device)

    ############################################################################
    ##########  INFERENCE  #####################################################

    def run_inference(self, observations, grad_calculation):

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        tb = self.perspective_taker.init_translation_bias_()
        # print(tb)

        for i in range(self.tuning_length + 1):
            transba = tb.clone().to(self.device)
            transba.requires_grad = True
            self.Cs.append(transba)

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            x_C = self.perspective_taker.translate(o, self.Cs[i])
            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: the following could possibly be moved into a separate function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            x_C = self.perspective_taker.translate(o, self.Cs[-1])
            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters
                with torch.no_grad():

                    ## Get gradients
                    for i in range(self.tuning_length + 1):
                        # save grads for all parameters in all time steps of tuning horizon
                        self.C_grads[i] = self.Cs[i].grad

                    # print(self.C_grads[self.tuning_length])

                    # Calculate the overall gradient from the per-step
                    # gradients of the tuning horizon
                    if grad_calculation == 'lastOfTunHor':
                        ### version 1
                        grad_C = self.C_grads[0]
                    elif grad_calculation == 'meanOfTunHor':
                        ### version 2 / 3
                        grad_C = torch.mean(torch.stack(self.C_grads), dim=0)
                    elif grad_calculation == 'weightedInTunHor':
                        ### version 4
                        weighted_grads_C = [None] * (self.tuning_length + 1)
                        for i in range(self.tuning_length + 1):
                            weighted_grads_C[i] = np.power(
                                self.grad_bias_translation, i) * self.C_grads[i]
                        grad_C = torch.mean(torch.stack(weighted_grads_C),
                                            dim=0)
                    else:
                        raise ValueError(
                            f'Unknown gradient calculation: {grad_calculation}')

                    # print(f'grad_C: {grad_C}')

                    # Update parameters in time step t-H with saved gradients
                    grad_C = grad_C.to(self.device)
                    upd_C = self.perspective_taker.update_translation_bias_(
                        self.Cs[0], grad_C, self.at_learning_rate,
                        self.c_momentum)
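
                    # NOTE: update_translation_bias_ is implemented in
                    # Perspective_Taker; a minimal sketch of the kind of
                    # momentum-SGD step it presumably performs (assumption,
                    # not the verified implementation):
                    #   v     <- c_momentum * v + grad_C
                    #   upd_C <- Cs[0] - at_learning_rate * v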

                    # print(upd_C)
                    # Compare translation bias to ideal bias
                    trans_loss = self.mse(self.ideal_translation, upd_C)
                    self.c_losses.append(trans_loss)
                    print(f'loss of translation bias (MSE): {trans_loss}')

                    # Zero out gradients for all parameters in all time steps of tuning horizon
                    for i in range(self.tuning_length + 1):
                        self.Cs[i].requires_grad = False
                        self.Cs[i].grad.data.zero_()

                    # Update all parameters for all time steps
                    for i in range(self.tuning_length + 1):
                        translation = upd_C.clone()
                        translation.requires_grad_()
                        self.Cs[i] = translation

                    # print(self.Cs[0])

                    # Initial state: plain gradient-descent step on the
                    # initial hidden and cell states
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()
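
                    # The commented-out state_optimizer.step() below hints at
                    # an optimizer-based alternative to the manual update
                    # above; a sketch, assuming an optimizer over [at_h, at_c]
                    # were created once before the tuning loop (hypothetical):
                    #   state_optimizer = torch.optim.SGD(
                    #       [at_h, at_c], lr=self.at_learning_rate_state)
                    #   ...
                    #   state_optimizer.step()
                    #   state_optimizer.zero_grad()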

                    # state_optimizer.step()
                    # print(f'updated init_state: {init_state}')

                ## TODO: reorganize for multiple tuning cycles

                # Forward pass from t-H to t with the updated parameters,
                # starting from the updated initial state
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    x_C = self.perspective_taker.translate(
                        self.at_inputs[i], self.Cs[i])
                    x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # In the last tuning cycle, stash the final prediction and
                    # input, and re-anchor the initial state so its gradients
                    # are tracked from the new start of the horizon.
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = (self.at_predictions[0]
                                                .clone().detach()
                                                .to(self.device))
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_()
                        at_c = state[1].clone().detach().requires_grad_()
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current translation
                x_C = self.perspective_taker.translate(o, self.Cs[-1])
                x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # Store the remaining inputs and predictions of the tuning horizon
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)
            at_final_inputs = torch.cat(
                (at_final_inputs, self.at_inputs[i].reshape(
                    1, self.input_per_frame)), 0)

        # get final translation bias
        final_translation_bias = self.Cs[0].clone().detach()
        print(f'final translation bias: {final_translation_bias}')
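
        # Shape sketch of the returned tensors (inferred from the reshapes
        # above; treat as an assumption):
        #   at_final_inputs       : (num_frames, input_per_frame)
        #   at_final_predictions  : (num_frames, input_per_frame)
        #   final_translation_bias: same shape as the translation bias Cs[0]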

        return at_final_inputs, at_final_predictions, final_translation_bias

    ############################################################################
    ##########  EVALUATION #####################################################

    def get_result_history(self, observations, at_final_predictions):
        # Collect the frame-wise prediction errors together with the
        # recorded tuning and translation-bias losses.

        pred_errors = self.evaluator.prediction_errors(observations,
                                                       at_final_predictions,
                                                       self.mse)

        return [pred_errors, self.at_losses, self.c_losses]
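
    # Usage sketch (hypothetical driver code; `baptat`, `observations`, and
    # `at_final_predictions` are assumed to be produced by a preceding
    # tuning run such as the translation routine above):
    #
    #   baptat = COMBI_BAPTAT()
    #   ...  # set data/tuning parameters, run active tuning
    #   pred_errors, at_losses, c_losses = baptat.get_result_history(
    #       observations, at_final_predictions)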