Example #1
def living_reward(state, action):
    if torch.is_tensor(state):
        cartpole = np.shape(state)[-1] == 4
    else:
        cartpole = np.shape(state)[0] == 4
    if cartpole:
        # Cartpole
        if torch.is_tensor(state):
            pos = state[:, 0]
            angle = torch.rad2deg(state[:, 2])
            rew = ((torch.lt(torch.abs(pos),2.4) & torch.lt(torch.abs(angle),12))).int()
        else:
            pos = state[0]
            angle = np.rad2deg(state[2])
            rew = int((np.abs(pos) < 2.4 and np.abs(angle) < 12))

    else:
        print("In the wrong place")
        if torch.is_tensor(state):
            pitch = state[:, 0]
            roll = state[:, 1]
            rew = (torch.abs(pitch) < np.deg2rad(5)).float() + (torch.abs(roll) < np.deg2rad(5)).float()
        else:
            pitch = state[0]
            roll = state[1]
            flag1 = np.abs(pitch) < np.deg2rad(5)
            flag2 = np.abs(roll) < np.deg2rad(5)
            rew = int(flag1) + int(flag2)
    return rew
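A minimal usage sketch for living_reward with hypothetical inputs; the Gym CartPole state layout [cart position, cart velocity, pole angle in radians, pole angular velocity] is assumed, and the action argument is ignored by the function:

# Usage sketch for living_reward (assumed CartPole state layout, hypothetical values).
import numpy as np
import torch

state_np = np.array([0.1, 0.0, 0.05, 0.0])        # single NumPy state
print(living_reward(state_np, action=None))        # -> 1 (cart and pole within bounds)

state_t = torch.tensor([[0.1, 0.0, 0.05, 0.0],
                        [3.0, 0.0, 0.05, 0.0]])    # batch of tensor states
print(living_reward(state_t, action=None))         # -> tensor of 0/1 rewards, here [1, 0]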
Example #2
def squ_cost(state, action):
    if torch.is_tensor(state):
        cartpole = np.shape(state)[-1] == 4
    else:
        cartpole = np.shape(state)[0] == 4
    if cartpole:
        if torch.is_tensor(state):
            pos = state[:, 0]
            angle = torch.rad2deg(state[:, 2])/2
            cost = pos ** 2 + angle ** 2
        else:
            pos = state[0]
            angle = np.rad2deg(state[2])/2
            cost = pos ** 2 + angle ** 2
        # Cartpole
    else:
        if torch.is_tensor(state):
            pitch = state[:, 0]
            roll = state[:, 1]
            cost = pitch ** 2 + roll ** 2
        else:
            pitch = state[0]
            roll = state[1]
            cost = pitch ** 2 + roll ** 2
    return -cost
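An analogous sketch for squ_cost (same assumed CartPole state layout); the function returns the negated quadratic cost, so values closer to zero correspond to states closer to the upright center:

# Usage sketch for squ_cost (assumed CartPole state layout, hypothetical values).
import numpy as np
import torch

state_np = np.array([0.5, 0.0, 0.1, 0.0])
print(squ_cost(state_np, action=None))             # -(pos**2 + (angle_deg / 2)**2)

state_t = torch.tensor([[0.5, 0.0, 0.1, 0.0]])
print(squ_cost(state_t, action=None))              # batched variant, shape (1,)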
Example #3
def get_sda(xyz, M, triple_mask):
    """
    Input: (xyz) in list type, (M) = get_edge_matrix, (triple_mask) = mask of 0's and 1's
    Output: SDA of angles specified in triple_mask in degrees
    """
    xyz = torch.tensor(xyz)
    edges = torch.matmul(M, xyz)
    edges = F.normalize(edges, dim=1)  # [42 x 3]
    gram = torch.matmul(edges, edges.T)  # [42 x 42]
    gram = torch.clamp(gram, -1., 1.)  # [42 x 42]
    angles = torch.masked_select(gram, triple_mask > 0.)
    angles = torch.rad2deg(torch.arccos(angles))
    return torch.std(angles)
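A small synthetic check for get_sda. The edge matrix M and triple_mask below are hand-built for a single triangle and are only illustrative assumptions about the shapes produced by get_edge_matrix, not its actual output:

# Synthetic check for get_sda (hand-built M and triple_mask, assumed shapes).
import torch
import torch.nn.functional as F

xyz = [[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]]    # triangle vertices
M = torch.tensor([[-1., 1., 0.],                     # maps vertices to edge vectors
                  [0., -1., 1.],
                  [1., 0., -1.]])
triple_mask = torch.tensor([[0., 1., 1.],            # select the three pairwise angles
                            [0., 0., 1.],
                            [0., 0., 0.]])
print(get_sda(xyz, M, triple_mask))                  # std of the angles 135, 135, 90 deg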
Example #4
def get_planarity_values(b, face_mask, gram, ps):
    """
    Inputs: b = number of batches, face_mask = face mask, gram = gram matrix, ps = total internal angles
    of object, usually 1080
    Output: List of deviations from planarity for each object 
    """
    planarity = torch.zeros(b)
    P = torch.sum(face_mask > 0, dim=(1, 2))
    P = torch.cumsum(P, dim=0)
    P = torch.hstack((torch.tensor(0), P)).int()
    cos_angles_P = torch.masked_select(gram, face_mask > 0.)
    angles_P = torch.rad2deg(torch.arccos(cos_angles_P))
    for i in range(b):
        planarity[i] = torch.abs(ps[i] - torch.sum(angles_P[P[i]:P[i + 1]]))
    return planarity
Example #5
def get_symmetry_values(b, sym_mask, gram):
    """
    Inputs: b = number of batches, sym_mask = symmetry_mask, gram = gram matrix
    Output: list of deviations from symmetry, angles selected in sym_mask 
    """
    symmetry = torch.zeros(b)
    for i in range(b):
        sym_max = torch.max(sym_mask[i, :, :]).int()
        for sm in range(sym_max):
            cos_angles_S = torch.masked_select(gram[i, :, :],
                                               sym_mask[i, :, :] == sm + 1)
            angles_S = torch.rad2deg(torch.arccos(cos_angles_S))
            #print(angles_S[0] , angles_S[1])
            symmetry[i] += torch.abs(angles_S[0] - angles_S[1])
    return symmetry
Example #6
    def set_comparison_values(self, ideal_rotation_values,
                              ideal_rotation_matrix):
        self.identity_matrix = torch.Tensor(
            np.identity(self.num_input_dimensions))
        self.ideal_rotation = ideal_rotation_matrix.to(self.device)
        if self.rotation_type == 'qrotate':
            self.ideal_quat = ideal_rotation_values.to(self.device)
            self.ideal_angle = torch.rad2deg(
                self.perspective_taker.qeuler(self.ideal_quat,
                                              'xyz')).to(self.device)
        elif self.rotation_type == 'eulrotate':
            self.ideal_angle = ideal_rotation_values.to(self.device)
        else:
            print(
                f'ERROR: Received unknown rotation type!\n\trotation type: {self.rotation_type}'
            )
            exit()
Example #7
def angular_metric(u, v, metric="cosine"):
    """
    Angular metric: calculates the angle and distance between two
    vectors using different distance metrics: euclidean, cosine,
    Triangle's Area Similarity (TS), Sector's Area Similarity (SS),
    and TS-SS. More details in the paper at
    https://github.com/taki0112/Vector_Similarity
    :param u: 2D Tensor of shape (1, n) (CosineSimilarity(dim=1) and torch.cdist expect 2D inputs)
    :param v: 2D Tensor of shape (1, n)
    :param metric: choices are: cosine, euclidean, TS, SS, TS-SS
    :return: angle, distance
    """
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    similarity = cos(u, v)
    rad = torch.acos(similarity)
    angle = torch.rad2deg(rad).item()
    if metric == "cosine":
        distance = 1 - similarity.item()
    elif metric == "euclidean":
        distance = torch.cdist(u, v).item()
    elif metric == "TS":
        # Triangle's Area Similarity (TS)
        rad_10 = rad + torch.deg2rad(torch.tensor(10.0))
        distance = (torch.norm(u) * torch.norm(v)) * torch.sin(rad_10) / 2
        distance = distance.item()
    elif metric == "SS":
        # Sector's Area Similarity (SS)
        ed_md = torch.cdist(u, v) + torch.abs(torch.norm(u) - torch.norm(v))
        rad_10 = rad + torch.deg2rad(torch.tensor(10.0))
        distance = pi * torch.pow(ed_md, 2) * rad_10 / 360
        distance = distance.item()
    elif metric == "TS-SS":
        _, triangle = angular_metric(u, v, metric="TS")
        _, sector = angular_metric(u, v, metric="SS")
        distance = triangle * sector
    else:
        raise Exception(f"Distance metric {metric} unsupported")
    if similarity:
        distance = 1 - distance

    return angle, distance
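A short usage sketch for angular_metric, assuming torch, nn, and pi are in scope as in the snippet; the inputs are row vectors of shape (1, n) because CosineSimilarity(dim=1) and torch.cdist require 2D tensors:

# Usage sketch for angular_metric with 2D row vectors (hypothetical values).
import torch

u = torch.tensor([[1.0, 0.0, 0.0]])
v = torch.tensor([[0.0, 1.0, 0.0]])

angle, dist = angular_metric(u, v, metric="cosine")
print(angle)   # ~90.0 degrees between orthogonal vectors
print(dist)    # 1 - cosine similarity = ~1.0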
Example #8
def get_raw_quat_distance(q0, q1):

    # Catching zero data case
    if q0.shape[0] == 0:
        return torch.tensor([float('nan')], device=q0.device)

    # Determine the difference
    q0_minus_q1 = q0 - q1
    q0_plus_q1 = q0 + q1

    # Obtain the norm
    d_minus = q0_minus_q1.norm(dim=-1)
    d_plus = q0_plus_q1.norm(dim=-1)

    # Compare the norms and select the one with the smallest norm
    ds = torch.stack((d_minus, d_plus))
    rad_distance = torch.min(ds, dim=0).values

    # Converting the rad to degree
    degree_distance = torch.rad2deg(rad_distance)

    return degree_distance
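A quick sketch for get_raw_quat_distance with batched quaternions; note the function treats the smaller of ||q0 - q1|| and ||q0 + q1|| as radians and converts it to degrees:

# Usage sketch for get_raw_quat_distance (hypothetical unit quaternions).
import torch

q0 = torch.tensor([[1.0, 0.0, 0.0, 0.0]])     # identity rotation
q1 = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
print(get_raw_quat_distance(q0, q1))           # min-norm (sqrt(2)) converted to degrees

empty = torch.empty(0, 4)
print(get_raw_quat_distance(empty, empty))     # -> tensor([nan]) for empty input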
Example #9
def angle_loss(a, b):
    return MSE(torch.rad2deg(a), torch.rad2deg(b))
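angle_loss relies on an MSE helper that is not defined in the snippet; a minimal stand-in (assumed to be plain mean squared error) and a usage example:

# Assumed stand-in for the undefined MSE helper.
import math
import torch

def MSE(x, y):
    return torch.mean((x - y) ** 2)

a = torch.tensor([0.0, math.pi / 2])
b = torch.tensor([0.1, math.pi / 2])
print(angle_loss(a, b))   # MSE of the angles after converting radians to degrees
# Note: MSE in degree space does not handle wrap-around (e.g. 359 deg vs 1 deg).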
Example #10
    def run_inference(self, observations, grad_calculation):

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        if self.rotation_type == 'qrotate':
            ## Rotation quaternion
            rq = self.perspective_taker.init_quaternion()
            # print(rq)

            for i in range(self.tuning_length + 1):
                quat = rq.clone().to(self.device)
                quat.requires_grad_()
                self.Rs.append(quat)

        elif self.rotation_type == 'eulrotate':
            ## Rotation euler angles
            # ra = perspective_taker.init_angles_()
            # ra = torch.Tensor([[309.89], [82.234], [95.765]])
            ra = torch.Tensor([[75.0], [6.0], [128.0]])
            print(ra)

            for i in range(self.tuning_length + 1):
                angles = []
                for j in range(self.num_input_dimensions):
                    angle = ra[j].clone()
                    angle.requires_grad_()
                    angles.append(angle)
                self.Rs.append(angles)

        else:
            print('ERROR: Received unknown rotation type!')
            exit()

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[i])
            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly factor the following out into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[-1])

            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters
                with torch.no_grad():

                    ## get gradients
                    if self.rotation_type == 'qrotate':
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            self.R_grads[i] = self.Rs[i].grad
                    else:
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            grad = []
                            for j in range(self.num_input_dimensions):
                                grad.append(self.Rs[i][j].grad)
                            self.R_grads[i] = torch.stack(grad)

                    # print(self.R_grads[self.tuning_length])

                    # Calculate overall gradients
                    if grad_calculation == 'lastOfTunHor':
                        ### version 1
                        grad_R = self.R_grads[0]
                    elif grad_calculation == 'meanOfTunHor':
                        ### version 2 / 3
                        grad_R = torch.mean(torch.stack(self.R_grads), dim=0)
                    elif grad_calculation == 'weightedInTunHor':
                        ### version 4
                        weighted_grads_R = [None] * (self.tuning_length + 1)
                        for i in range(self.tuning_length + 1):
                            weighted_grads_R[i] = np.power(self.grad_bias,
                                                           i) * self.R_grads[i]
                        grad_R = torch.mean(torch.stack(weighted_grads_R),
                                            dim=0)

                    # print(f'grad_R: {grad_R}')

                    grad_R = grad_R.to(self.device)
                    if self.rotation_type == 'qrotate':
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_quaternion(
                            self.Rs[0], grad_R, self.at_learning_rate,
                            self.r_momentum)
                        print(f'updated quaternion: {upd_R}')

                        # Compare quaternion values
                        # quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
                        quat_loss = 2 * torch.arccos(
                            torch.abs(
                                torch.sum(torch.mul(self.ideal_quat, upd_R))))
                        quat_loss = torch.rad2deg(quat_loss)
                        print(f'loss of quaternion: {quat_loss}')
                        self.rv_losses.append(quat_loss)

                        # Compare quaternion angles
                        ang = torch.rad2deg(
                            self.perspective_taker.qeuler(upd_R, 'zyx'))
                        ang_diff = ang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of quaternion angles: {ang_loss} \nwith norm: {torch.norm(ang_loss)}'
                        )
                        self.ra_losses.append(torch.norm(ang_loss))

                        # Compute rotation matrix
                        rotmat = self.perspective_taker.quaternion2rotmat(
                            upd_R)

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Rs[i].requires_grad = False
                            self.Rs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            quat = upd_R.clone()
                            quat.requires_grad_()
                            self.Rs[i] = quat

                    else:
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_rotation_angles_(
                            self.Rs[0], grad_R, self.at_learning_rate)
                        print(f'updated angles: {upd_R}')

                        # Save rotation angles
                        rotang = torch.stack(upd_R)
                        # angles:
                        ang_diff = rotang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
                        )
                        self.rv_losses.append(torch.norm(ang_loss))
                        # Compute rotation matrix
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            upd_R[0], upd_R[1], upd_R[2])[0]

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            for j in range(self.num_input_dimensions):
                                self.Rs[i][j].requires_grad = False
                                self.Rs[i][j].grad.data.zero_()

                        # print(Rs[0])
                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            angles = []
                            for j in range(3):
                                angle = upd_R[j].clone()
                                angle.requires_grad_()
                                angles.append(angle)
                            self.Rs[i] = angles
                        # print(Rs[0])

                    # Calculate and save rotation losses
                    # matrix:
                    # mat_loss = self.mse(
                    #     (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))),
                    #     self.identity_matrix
                    # )
                    dif_R = torch.mm(self.ideal_rotation,
                                     torch.transpose(rotmat, 0, 1))
                    mat_loss = torch.arccos(0.5 * (torch.trace(dif_R) - 1))
                    mat_loss = torch.rad2deg(mat_loss)

                    print(f'loss of rotation matrix: {mat_loss}')
                    self.rm_losses.append(mat_loss)

                    # print(Rs[0])

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # state_optimizer.step()
                    # print(f'updated init_state: {init_state}')

                ## REORGANIZE FOR MULTIPLE CYCLES!!!!!!!!!!!!!

                # forward pass from t-H to t with new parameters
                # Update init state???
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    if self.rotation_type == 'qrotate':
                        x_R = self.perspective_taker.qrotate(
                            self.at_inputs[i], self.Rs[i])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                        x_R = self.perspective_taker.rotate(
                            self.at_inputs[i], rotmat)

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone(
                            ).detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_()
                        at_c = state[1].clone().detach().requires_grad_()
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current rotation
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(o, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(o, rotmat)

                x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)
            if self.rotation_type == 'qrotate':
                x_i = self.perspective_taker.qrotate(self.at_inputs[i],
                                                     self.Rs[-1])
            else:
                x_i = self.perspective_taker.rotate(self.at_inputs[i], rotmat)
            at_final_inputs = torch.cat(
                (at_final_inputs, x_i.reshape(1, self.input_per_frame)), 0)

        # get final rotation matrix
        if self.rotation_type == 'qrotate':
            final_rotation_values = self.Rs[0].clone().detach()
            # get final quaternion
            print(f'final quaternion: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.quaternion2rotmat(
                final_rotation_values)

        else:
            final_rotation_values = [
                self.Rs[0][i].clone().detach()
                for i in range(self.num_input_dimensions)
            ]
            print(f'final euler angles: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                final_rotation_values[0], final_rotation_values[1],
                final_rotation_values[2])

        print(f'final rotation matrix: \n{final_rotation_matrix}')

        return at_final_inputs, at_final_predictions, final_rotation_values, final_rotation_matrix
Example #11
def test(colorImgModel, depthImgModel, regressorModel, maxPool, dataLoader):

    simpleDistanceSE3 = np.empty(0)
    # Check if log directories are available
    if not os.path.exists(config.logsDirsTesting):
        os.makedirs(config.logsDirsTesting)

    for j, data in tqdm(enumerate(dataLoader, 0), total=len(dataLoader)):

        # Expand the data into its respective components
        srcClrT, srcDepthT, ptCldT, __, targetTransformT, optional = data

        # Move the tensors to the GPU
        srcClrT = srcClrT.to('cuda')
        srcDepthT = srcDepthT.to('cuda')

        # Cuda 0
        featureMapClrImg = colorImgModel(srcClrT)
        # Cuda 1
        maxPooledDepthImg = maxPool(srcDepthT)
        featureMapDepthImg = depthImgModel(maxPooledDepthImg)

        # Cuda 0
        aggClrDepthFeatureMap = torch.cat(
            [featureMapDepthImg.to('cuda'), featureMapClrImg], dim=1)

        # Cuda 0
        predTransform = regressorModel(aggClrDepthFeatureMap)

        # Simple SE3 distance
        predRot = quaternion_to_matrix(predTransform[:, :4])
        predT = predTransform[:, 4:]

        targetTransformT = moveToDevice(targetTransformT,
                                        predTransform.get_device())

        gtRot = targetTransformT[:, :3, :3].type(torch.float32)
        gtT = targetTransformT[:, :3, 3].type(torch.float32)

        RtR = torch.matmul(calculateInvRTTensor(predRot, predT),
                           targetTransformT.type(torch.float32))

        I = moveToDevice(torch.eye(4, dtype=torch.float32),
                         predTransform.get_device())
        simpleDistanceSE3 = np.append(
            simpleDistanceSE3,
            torch.norm(RtR - I, 'fro').to('cpu').numpy())

        # Calculate the Euler angles from the rotation matrix
        predEulerAngles = matrix_to_euler_angles(predRot, "ZXY")
        targetEulerAngles = matrix_to_euler_angles(gtRot, "ZXY")
        errorEulerAngle = torch.abs(targetEulerAngles - predEulerAngles)
        errorEulerAngle = torch.rad2deg(torch.mean(errorEulerAngle, dim=0))
        errorTranslation = torch.abs(gtT - predT)
        errorTranslation = torch.mean(errorTranslation, dim=0)
        """
        print(errorEulerAngle)
        print(errorTranslation)
        
        print("Breakpoint")
        """

    return (np.mean(simpleDistanceSE3), errorEulerAngle.to('cpu').numpy(),
            errorTranslation.to('cpu').numpy())
Example #12
    def run_inference(self, observations, grad_calculations, do_binding,
                      do_rotation, do_translation, order, reorder):
        [grad_calc_binding, grad_calc_rotation,
         grad_calc_translation] = grad_calculations

        if reorder is not None:
            reorder = reorder.to(self.device)

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        ###########################  BINDING  #################################
        if do_binding:
            ## Binding matrices
            # Init binding entries
            bm = self.binder.init_binding_matrix_det_()
            # bm = binder.init_binding_matrix_rand_()
            # print(bm)
            dummie_line = torch.ones(1, self.num_observations).to(
                self.device) * self.dummie_init

            for i in range(self.tuning_length + 1):
                matrix = bm.clone().to(self.device)
                if self.nxm:
                    matrix = torch.cat([matrix, dummie_line])
                matrix.requires_grad_()
                self.Bs.append(matrix)

        ###########################  ROTATION  ################################
        if do_rotation:
            if self.rotation_type == 'qrotate':
                ## Rotation quaternion
                rq = self.perspective_taker.init_quaternion()
                # print(rq)

                for i in range(self.tuning_length + 1):
                    quat = rq.clone().to(self.device)
                    quat.requires_grad_()
                    self.Rs.append(quat)

            elif self.rotation_type == 'eulrotate':
                ## Rotation euler angles
                # ra = perspective_taker.init_angles_()
                # ra = torch.Tensor([[309.89], [82.234], [95.765]])
                ra = torch.Tensor([[75.0], [6.0], [128.0]])
                # print(ra)

                for i in range(self.tuning_length + 1):
                    angles = []
                    for j in range(self.num_spatial_dimensions):
                        angle = ra[j].clone().to(self.device)
                        angle.requires_grad_()
                        angles.append(angle)
                    self.Rs.append(angles)

            else:
                print('ERROR: Received unknown rotation type!')
                exit()

        ###########################  TRANSLATION  #############################
        if do_translation:
            tb = self.perspective_taker.init_translation_bias_()
            # print(tb)

            for i in range(self.tuning_length + 1):
                transba = tb.clone().to(self.device)
                transba.requires_grad = True
                self.Cs.append(transba)

        #######################################################################

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_observations,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(self.Bs[i],
                                                      self.scale_mode,
                                                      self.scale_combo,
                                                      self.nxm_enhance,
                                                      self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
            else:
                x_B = o

            if self.gestalten:
                if self.dir_mag_gest:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions],
                    x_B[:, self.num_spatial_dimensions:]
                ])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[i])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else:
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[i])
            else:
                x_C = x_R

            if self.gestalten:
                if self.dir_mag_gest:
                    x_C = torch.cat([x_C, dir, mag], dim=1)
                else:
                    x_C = torch.cat([x_C, dir], dim=1)
            #######################################################################

            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly factor the following out into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(self.Bs[-1],
                                                      self.scale_mode,
                                                      self.scale_combo,
                                                      self.nxm_enhance,
                                                      self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
            else:
                x_B = o

            if self.gestalten:
                if self.dir_mag_gest:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions],
                    x_B[:, self.num_spatial_dimensions:]
                ])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else:
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
            else:
                x_C = x_R

            if self.gestalten:
                if self.dir_mag_gest:
                    x_C = torch.cat([x_C, dir, mag], dim=1)
                else:
                    x_C = torch.cat([x_C, dir], dim=1)
            #######################################################################

            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                # Update parameters
                with torch.no_grad():
                    # self.at_losses.append(loss.clone().detach())
                    self.at_losses.append(loss.clone().detach().cpu())
                    # self.at_losses.append(loss.clone().detach().cpu().numpy())
                    print(
                        f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                    ###########################  BINDING  #################################
                    if do_binding:
                        # Calculate gradients with respect to the entries
                        for i in range(self.tuning_length + 1):
                            self.B_grads[i] = self.Bs[i].grad
                        # print(B_grads[tuning_length])

                        # Calculate overall gradients
                        if grad_calc_binding == 'lastOfTunHor':
                            ### version 1
                            grad_B = self.B_grads[0]
                        elif grad_calc_binding == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_B = torch.mean(torch.stack(self.B_grads),
                                                dim=0)
                        elif grad_calc_binding == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_B = [None
                                                ] * (self.tuning_length + 1)
                            for i in range(self.tuning_length + 1):
                                weighted_grads_B[i] = np.power(
                                    self.grad_bias_binding,
                                    i) * self.B_grads[i]
                            grad_B = torch.mean(torch.stack(weighted_grads_B),
                                                dim=0)
                        # print(f'grad_B: {grad_B}')

                        # Update parameters in time step t-H with saved gradients
                        grad_B = grad_B.to(self.device)
                        # upd_B = self.binder.update_binding_matrix_(
                        upd_B = self.binder.decay_update_binding_matrix_(
                            self.Bs[0], grad_B, self.at_learning_rate_binding,
                            self.bm_momentum)

                        # Compare binding matrix to ideal matrix
                        # NOTE: ideal matrix is always identity, because then the FBE and determinant can be calculated => provide reorder
                        c_bm = self.binder.scale_binding_matrix(
                            upd_B, self.scale_mode, self.scale_combo)
                        if order is not None:
                            c_bm = c_bm.gather(
                                1,
                                reorder.unsqueeze(0).expand(c_bm.shape))

                        if self.nxm:
                            self.oc_grads.append(grad_B[-1])
                            FBE = self.evaluator.FBE_nxm_additional_features(
                                c_bm, self.ideal_binding,
                                self.additional_features)
                            c_bm = self.evaluator.clear_nxm_binding_matrix(
                                c_bm, self.additional_features)

                        mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)

                        if self.nxm:
                            mat_loss = torch.stack(
                                [mat_loss, FBE, mat_loss + FBE])
                        self.bm_losses.append(mat_loss)
                        print(f'loss of binding matrix (FBE): {mat_loss}')

                        # Compute determinant of binding matrix
                        det = torch.det(c_bm)
                        self.bm_dets.append(det)
                        print(f'determinant of binding matrix: {det}')

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Bs[i].requires_grad = False
                            self.Bs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            self.Bs[i].data = upd_B.clone().data
                            self.Bs[i].requires_grad = True

                    ###########################  ROTATION  ################################
                    if do_rotation:
                        ## get gradients
                        if self.rotation_type == 'qrotate':
                            for i in range(self.tuning_length + 1):
                                # save grads for all parameters in all time steps of tuning horizon
                                self.R_grads[i] = self.Rs[i].grad
                        else:
                            for i in range(self.tuning_length + 1):
                                # save grads for all parameters in all time steps of tuning horizon
                                grad = []
                                for j in range(self.num_input_dimensions):
                                    grad.append(self.Rs[i][j].grad)
                                self.R_grads[i] = torch.stack(grad)
                        # print(self.R_grads[self.tuning_length])

                        # Calculate overall gradients
                        if grad_calc_rotation == 'lastOfTunHor':
                            ### version 1
                            grad_R = self.R_grads[0]
                        elif grad_calc_rotation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_R = torch.mean(torch.stack(self.R_grads),
                                                dim=0)
                        elif grad_calc_rotation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_R = [None
                                                ] * (self.tuning_length + 1)
                            for i in range(self.tuning_length + 1):
                                weighted_grads_R[i] = np.power(
                                    self.grad_bias_rotation,
                                    i) * self.R_grads[i]
                            grad_R = torch.mean(torch.stack(weighted_grads_R),
                                                dim=0)
                        # print(f'grad_R: {grad_R}')

                        grad_R = grad_R.to(self.device)
                        if self.rotation_type == 'qrotate':
                            # Update parameters in time step t-H with saved gradients
                            upd_R = self.perspective_taker.update_quaternion(
                                self.Rs[0], grad_R,
                                self.at_learning_rate_rotation,
                                self.r_momentum)
                            print(f'updated quaternion: {upd_R}')

                            # Compare quaternion values
                            # quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
                            quat_loss = 2 * torch.arccos(
                                torch.abs(
                                    torch.sum(torch.mul(
                                        self.ideal_quat, upd_R))))
                            quat_loss = torch.rad2deg(quat_loss)
                            print(f'loss of quaternion: {quat_loss}')
                            self.rv_losses.append(quat_loss)
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.quaternion2rotmat(
                                upd_R)

                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length + 1):
                                self.Rs[i].requires_grad = False
                                self.Rs[i].grad.data.zero_()

                            # Update all parameters for all time steps
                            for i in range(self.tuning_length + 1):
                                quat = upd_R.clone()
                                quat.requires_grad_()
                                self.Rs[i] = quat

                        else:
                            # Update parameters in time step t-H with saved gradients
                            upd_R = self.perspective_taker.update_rotation_angles_(
                                self.Rs[0], grad_R,
                                self.at_learning_rate_rotation)
                            print(f'updated angles: {upd_R}')

                            # Save rotation angles
                            rotang = torch.stack(upd_R)
                            # angles:
                            ang_diff = rotang - self.ideal_angle
                            ang_loss = 2 - (
                                torch.cos(torch.deg2rad(ang_diff)) + 1)
                            print(
                                f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
                            )
                            self.rv_losses.append(torch.norm(ang_loss))
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                upd_R[0], upd_R[1], upd_R[2])[0]

                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length + 1):
                                for j in range(self.num_input_dimensions):
                                    self.Rs[i][j].requires_grad = False
                                    self.Rs[i][j].grad.data.zero_()

                            # Update all parameters for all time steps
                            for i in range(self.tuning_length + 1):
                                angles = []
                                for j in range(3):
                                    angle = upd_R[j].clone()
                                    angle.requires_grad_()
                                    angles.append(angle)
                                self.Rs[i] = angles

                        # Calculate and save rotation losses
                        # matrix:
                        # mat_loss = self.mse(
                        #     (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))),
                        #     self.identity_matrix
                        # )
                        dif_R = torch.mm(self.ideal_rotation,
                                         torch.transpose(rotmat, 0, 1))
                        mat_loss = torch.arccos(0.5 * (torch.trace(dif_R) - 1))
                        mat_loss = torch.rad2deg(mat_loss)
                        print(f'loss of rotation matrix: {mat_loss}')
                        self.rm_losses.append(mat_loss)

                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        ## Get gradients
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            self.C_grads[i] = self.Cs[i].grad

                        # print(self.C_grads[self.tuning_length])

                        # Calculate overall gradients
                        if grad_calc_translation == 'lastOfTunHor':
                            ### version 1
                            grad_C = self.C_grads[0]
                        elif grad_calc_translation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_C = torch.mean(torch.stack(self.C_grads),
                                                dim=0)
                        elif grad_calc_translation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_C = [None
                                                ] * (self.tuning_length + 1)
                            for i in range(self.tuning_length + 1):
                                weighted_grads_C[i] = np.power(
                                    self.grad_bias_translation,
                                    i) * self.C_grads[i]
                            grad_C = torch.mean(torch.stack(weighted_grads_C),
                                                dim=0)

                        # Update parameters in time step t-H with saved gradients
                        grad_C = grad_C.to(self.device)
                        upd_C = self.perspective_taker.update_translation_bias_(
                            self.Cs[0], grad_C,
                            self.at_learning_rate_translation, self.c_momentum)

                        # Compare translation bias to ideal bias
                        trans_loss = self.mse(self.ideal_translation, upd_C)
                        self.c_losses.append(trans_loss)
                        print(f'loss of translation bias (MSE): {trans_loss}')

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Cs[i].requires_grad = False
                            self.Cs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            translation = upd_C.clone()
                            translation.requires_grad_()
                            self.Cs[i] = translation

                    #######################################################################

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # print(f'updated init_state: {init_state}')

                # forward pass from t-H to t with new parameters
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    ###########################  BINDING  #################################
                    if do_binding:
                        bm = self.binder.scale_binding_matrix(
                            self.Bs[i], self.scale_mode, self.scale_combo,
                            self.nxm_enhance, self.nxm_last_line_scale)
                        if self.nxm:
                            bm = bm[:-1]
                        x_B = self.binder.bind(self.at_inputs[i], bm)
                    else:
                        x_B = self.at_inputs[i]

                    if self.gestalten:
                        if self.dir_mag_gest:
                            mag = x_B[:, -1].view(self.num_observations, 1)
                            x_B = x_B[:, :-1]
                        x_B = torch.cat([
                            x_B[:, :self.num_spatial_dimensions],
                            x_B[:, self.num_spatial_dimensions:]
                        ])
                    ###########################  ROTATION  ################################
                    if do_rotation:
                        if self.rotation_type == 'qrotate':
                            x_R = self.perspective_taker.qrotate(
                                x_B, self.Rs[i])
                        else:
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                            x_R = self.perspective_taker.rotate(x_B, rotmat)
                    else:
                        x_R = x_B

                    if self.gestalten:
                        dir = x_R[-self.num_observations:, :]
                        x_R = x_R[:-self.num_observations, :]
                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        x_C = self.perspective_taker.translate(x_R, self.Cs[i])
                    else:
                        x_C = x_R

                    if self.gestalten:
                        if self.dir_mag_gest:
                            x_C = torch.cat([x_C, dir, mag], dim=1)
                        else:
                            x_C = torch.cat([x_C, dir], dim=1)
                    #######################################################################

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone(
                            ).detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_().to(
                            self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(
                            self.device)
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current input
                ###########################  BINDING  #################################
                if do_binding:
                    bm = self.binder.scale_binding_matrix(
                        self.Bs[-1], self.scale_mode, self.scale_combo,
                        self.nxm_enhance, self.nxm_last_line_scale)
                    if self.nxm:
                        bm = bm[:-1]
                    x_B = self.binder.bind(o, bm)
                else:
                    x_B = o

                if self.gestalten:
                    if self.dir_mag_gest:
                        mag = x_B[:, -1].view(self.num_observations, 1)
                        x_B = x_B[:, :-1]
                    x_B = torch.cat([
                        x_B[:, :self.num_spatial_dimensions],
                        x_B[:, self.num_spatial_dimensions:]
                    ])
                ###########################  ROTATION  ################################
                if do_rotation:
                    if self.rotation_type == 'qrotate':
                        x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                        x_R = self.perspective_taker.rotate(x_B, rotmat)
                else:
                    x_R = x_B

                if self.gestalten:
                    dir = x_R[-self.num_observations:, :]
                    x_R = x_R[:-self.num_observations, :]
                ###########################  TRANSLATION  #############################
                if do_translation:
                    x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
                else:
                    x_C = x_R

                if self.gestalten:
                    if self.dir_mag_gest:
                        x_C = torch.cat([x_C, dir, mag], dim=1)
                    else:
                        x_C = torch.cat([x_C, dir], dim=1)
                #######################################################################

                x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_observations,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)

            inp_i = self.at_inputs[i]
            if do_binding:
                x_B = self.binder.bind(inp_i, bm)
            else:
                x_B = inp_i

            if self.gestalten:
                if self.dir_mag_gest:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions],
                    x_B[:, self.num_spatial_dimensions:]
                ])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                else:
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else:
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
            else:
                x_C = x_R

            if self.gestalten:
                if self.dir_mag_gest:
                    x_i = torch.cat([x_C, dir, mag], dim=1)
                else:
                    x_i = torch.cat([x_C, dir], dim=1)
            else:
                x_i = x_C
            #######################################################################

            at_final_inputs = torch.cat(
                (at_final_inputs, x_i.reshape(1, self.input_per_frame)), 0)

        ###########################  BINDING  #################################
        # get final binding matrix
        if do_binding:
            final_binding_matrix = self.binder.scale_binding_matrix(
                self.Bs[-1].clone().detach(), self.scale_mode,
                self.scale_combo)
            print(f'final binding matrix: {final_binding_matrix}')
            final_binding_entries = self.Bs[-1].clone().detach()
            print(f'final binding entries: {final_binding_entries}')

        else:
            final_binding_entries, final_binding_matrix = None, None

        ###########################  ROTATION  ################################
        # get final rotation matrix
        if do_rotation:
            if self.rotation_type == 'qrotate':
                final_rotation_values = self.Rs[0].clone().detach()
                # get final quaternion
                print(f'final quaternion: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.quaternion2rotmat(
                    final_rotation_values)
            else:
                final_rotation_values = [
                    self.Rs[0][i].clone().detach()
                    for i in range(self.num_input_dimensions)
                ]
                print(f'final euler angles: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                    final_rotation_values[0], final_rotation_values[1],
                    final_rotation_values[2])

            print(f'final rotation matrix: \n{final_rotation_matrix}')

        else:
            final_rotation_matrix, final_rotation_values = None, None

        ###########################  TRANSLATION  #############################
        # get final translation bias
        if do_translation:
            final_translation_values = self.Cs[0].clone().detach()
            print(f'final translation bias: {final_translation_values}')

        else:
            final_translation_values = None

        #######################################################################

        return [
            at_final_inputs, at_final_predictions, final_binding_matrix,
            final_binding_entries, final_rotation_values,
            final_rotation_matrix, final_translation_values
        ]
Example #13
0
    def analyze_angle(self, conf, name):
        '''
        Only works on the age-labeled VGG and AgeDB datasets.
        '''

        angle_table = [{
            0: set(),
            1: set(),
            2: set(),
            3: set(),
            4: set(),
            5: set(),
            6: set(),
            7: set()
        } for i in range(self.class_num)]
        # batch = 0
        # _angle_table = torch.zeros(self.class_num, 8, len(self.loader)//conf.batch_size).to(conf.device)
        if conf.resume_analysis:
            self.loader = []
        for imgs, labels, ages in tqdm(iter(self.loader)):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            ages = ages.to(conf.device)

            embeddings = self.model(imgs)
            if conf.use_dp:
                kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                cos_theta = torch.mm(embeddings, kernel_norm)
                cos_theta = cos_theta.clamp(-1, 1)
            else:
                cos_theta = self.head.get_angle(embeddings)

            # Angle (in degrees) between each embedding and every class weight vector
            thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))

            # Age bins: 0: <13, 1: 13-18, 2: 19-25, 3: 26-35, 4: 36-45,
            # 5: 46-55, 6: 56-65, 7: 66+
            for i in range(len(thetas)):
                age_bin = 7
                if ages[i] < 26:
                    age_bin = 0 if ages[i] < 13 else 1 if ages[i] < 19 else 2
                elif ages[i] < 66:
                    age_bin = int(((ages[i] + 4) // 10).item())
                angle_table[labels[i]][age_bin].add(
                    thetas[i][labels[i]].item())

        if conf.resume_analysis:
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)

        count, avg_angle = [], []
        for i in range(self.class_num):
            count.append(
                [len(single_set) for single_set in angle_table[i].values()])
            avg_angle.append([
                sum(list(single_set)) / len(single_set)
                if len(single_set) else 0  # if set() size is zero, avg is zero
                for single_set in angle_table[i].values()
            ])

        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)

        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')
Example #14
0
def compute_metrics(data, pred_transforms):
    gt_transforms = data['transform_gt']
    igt_transforms = torch.eye(4, device=pred_transforms.device).repeat(
        gt_transforms.shape[0], 1, 1)
    igt_transforms[:, :3, :3] = gt_transforms[:, :3, :3].transpose(2, 1)
    igt_transforms[:, :3, 3] = -(
        igt_transforms[:, :3, :3] @ gt_transforms[:, :3, 3].view(-1, 3, 1)).view(-1, 3)
    # points_src = data['points_src'][..., :3]
    # points_ref = data['points_ref'][..., :3]
    # if 'points_raw' in data:
    #     points_raw = data['points_raw'][..., :3]
    # else:
    #     points_raw = points_ref

    # Euler angles, Individual translation errors (Deep Closest Point convention)
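    # NOTE: Rotation.from_dcm is the older SciPy name for this constructor; newer
    # SciPy releases expose it as Rotation.from_matrix.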
    r_gt_euler_deg = np.stack([
        Rotation.from_dcm(r.cpu().numpy()).as_euler('xyz', degrees=True)
        for r in gt_transforms[:, :3, :3]
    ])
    r_pred_euler_deg = np.stack([
        Rotation.from_dcm(r.cpu().numpy()).as_euler('xyz', degrees=True)
        for r in pred_transforms[:, :3, :3]
    ])
    t_gt = gt_transforms[:, :3, 3]
    t_pred = pred_transforms[:, :3, 3]
    r_mae = np.abs(r_gt_euler_deg - r_pred_euler_deg).mean(axis=1)
    t_mae = torch.abs(t_gt - t_pred).mean(dim=1)

    # Rotation, translation errors (isotropic, i.e. doesn't depend on error
    # direction, which is more representative of the actual error)
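    # For a perfect prediction, igt_transforms @ pred_transforms is the identity; the
    # rotation angle theta of the residual satisfies trace(R) = 1 + 2*cos(theta),
    # hence the arccos of (trace - 1) / 2 below.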
    concatenated = igt_transforms @ pred_transforms
    rot_trace = (concatenated[:, 0, 0] + concatenated[:, 1, 1] +
                 concatenated[:, 2, 2])
    r_iso = torch.rad2deg(
        torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)))
    t_iso = concatenated[:, :3, 3].norm(dim=-1)

    # # Modified Chamfer distance
    # src_transformed = (pred_transforms[:, :3, :3] @ points_src.transpose(2, 1)).transpose(2, 1)\
    #                   + pred_transforms[:, :3, 3][:, None, :]
    # ref_clean = points_raw
    # residual_transforms = pred_transforms @ igt_transforms
    # src_clean = (residual_transforms[:, :3, :3] @ points_raw.transpose(2, 1)).transpose(2, 1)\
    #             + residual_transforms[:, :3, 3][:, None, :]
    # dist_src = torch.min(tra.square_distance(src_transformed, ref_clean), dim=-1)[0]
    # dist_ref = torch.min(tra.square_distance(points_ref, src_clean), dim=-1)[0]
    # chamfer_dist = torch.mean(dist_src, dim=1) + torch.mean(dist_ref, dim=1)
    #
    # # ADD/ADI
    # src_diameters = torch.sqrt(tra.square_distance(src_clean, src_clean).max(dim=-1)[0]).max(dim=-1)[0]
    # dist_add = torch.norm(src_clean - ref_clean, p=2, dim=-1).mean(dim=1) / src_diameters
    # dist_adi = torch.sqrt(tra.square_distance(ref_clean, src_clean)).min(dim=-1)[0].mean(dim=-1) / src_diameters

    metrics = {
        'r_mae': r_mae,
        't_mae': t_mae.cpu().numpy(),
        'r_iso': r_iso.cpu().numpy(),
        't_iso': t_iso.cpu().numpy(),
        # 'chamfer_dist': chamfer_dist.cpu().numpy(),
        # 'add': dist_add.cpu().numpy(),
        # 'adi': dist_adi.cpu().numpy()
    }
    return metrics
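
# Minimal usage sketch (not part of the original source): with identical ground-truth
# and predicted transforms all errors should come out ~0. Assumes the imports used by
# compute_metrics (torch, numpy as np, scipy's Rotation) and a SciPy version that
# still provides Rotation.from_dcm.
dummy_transforms = torch.eye(4).repeat(2, 1, 1)  # batch of two identity transforms
metrics = compute_metrics({'transform_gt': dummy_transforms}, dummy_transforms)
print(metrics['r_iso'], metrics['t_iso'])  # both approximately zero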
Example #15
0
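# The snippets below assume a float tensor 'a' has already been defined, e.g.:
a = torch.randn(4)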
eps = torch.finfo(torch.float32).eps
torch.nextafter(torch.tensor([1, 2]),
                torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps])

# polygamma
torch.polygamma(1, torch.tensor([1, 0.5]))
torch.polygamma(2, torch.tensor([1, 0.5]))
torch.polygamma(3, torch.tensor([1, 0.5]))
torch.polygamma(4, torch.tensor([1, 0.5]))

# pow
torch.pow(a, 2)
torch.pow(torch.arange(1., 5.), torch.arange(1., 5.))

# rad2deg
torch.rad2deg(torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]))

# real
torch.randn(4, dtype=torch.cfloat).real

# reciprocal
torch.reciprocal(a)

# remainder
torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5)

# round
torch.round(a)

# rsqrt
torch.rsqrt(a)
Example #16
0
File: math_ops.py Project: malfet/pytorch
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
def evaluate(colorImgModel, depthImgModel, regressorModel, maxPool, dataLoader,
             calibFileRootDir):

    simpleDistanceSE3 = np.empty(0)
    RMSETranslationVec = np.empty((len(dataLoader), 3))
    RMSEEulerAngleVec = np.empty((len(dataLoader), 3))
    MAETranslationVec = np.empty((len(dataLoader), 3))
    MAEEulerAngleVec = np.empty((len(dataLoader), 3))

    for j, data in tqdm(enumerate(dataLoader, 0), total=len(dataLoader)):

        # Expand the data into its respective components
        srcClrT, srcDepthT, ptCldT, __, targetTransformT, optional = data

        # Transpose the tensors so that the number of channels is the second dimension
        srcClrT = srcClrT.to('cuda')
        srcDepthT = srcDepthT.to('cuda')

        # Cuda 0
        featureMapClrImg = colorImgModel(srcClrT)
        # Cuda 1
        maxPooledDepthImg = maxPool(srcDepthT)
        featureMapDepthImg = depthImgModel(maxPooledDepthImg)

        # Cuda 0
        aggClrDepthFeatureMap = torch.cat(
            [featureMapDepthImg.to('cuda'), featureMapClrImg], dim=1)

        # Cuda 0
        predTransform = regressorModel(aggClrDepthFeatureMap)

        # Simple SE3 distance
        predRot = quaternion_to_matrix(predTransform[:, :4])
        predT = predTransform[:, 4:]

        targetTransformT = moveToDevice(targetTransformT,
                                        predTransform.get_device())

        gtRot = targetTransformT[:, :3, :3].type(torch.float32)
        gtT = targetTransformT[:, :3, 3].type(torch.float32)

        RtR = torch.matmul(calculateInvRTTensor(predRot, predT),
                           targetTransformT.type(torch.float32))
        #RtR = torch.matmul(calculateInvRTTensor(targetTransformT.type(torch.float32)[:,:3,:3], targetTransformT.type(torch.float32)[:,:3,3]), targetTransformT.type(torch.float32))

        I = moveToDevice(torch.eye(4, dtype=torch.float32),
                         predTransform.get_device())
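        # ||T_pred^{-1} @ T_gt - I||_F (assuming calculateInvRTTensor returns the inverse
        # of the predicted transform); zero when the prediction matches the ground truth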
        simpleDistanceSE3 = np.append(
            simpleDistanceSE3,
            torch.norm(RtR - I, 'fro').to('cpu').numpy())

        # Calculate the Euler angles from the rotation matrices
        predEulerAngles = matrix_to_euler_angles(predRot, "ZXY")
        targetEulerAngles = matrix_to_euler_angles(gtRot, "ZXY")
        '''
        #####################################################################################
        Root Mean Squared Error
        #####################################################################################
        '''
        RMSEEulerAngleVec[j, :] = torch.square(
            torch.rad2deg(predEulerAngles -
                          targetEulerAngles)).to('cpu').numpy()
        RMSETranslationVec[j, :] = torch.square(predT - gtT).to('cpu').numpy()
        '''
        #####################################################################################
        Mean Absolute Error
        #####################################################################################
        '''
        MAEEulerAngleVec[j, :] = torch.abs(
            torch.rad2deg(predEulerAngles -
                          targetEulerAngles)).to('cpu').numpy()
        MAETranslationVec[j, :] = torch.abs(predT - gtT).to('cpu').numpy()

    RMSEEulerAngleVec = np.sqrt(np.mean(RMSEEulerAngleVec, axis=0))
    RMSETranslationVec = np.sqrt(np.mean(RMSETranslationVec, axis=0))

    MAEEulerAngleVec = np.mean(MAEEulerAngleVec, axis=0)
    MAETranslationVec = np.mean(MAETranslationVec, axis=0)

    return (np.mean(simpleDistanceSE3), RMSEEulerAngleVec, RMSETranslationVec,
            MAEEulerAngleVec, MAETranslationVec)