Example #1
def rotation_tensor(angle):
    angle = torch.deg2rad(-angle)
    cos = torch.tensor([torch.cos(angle)], dtype=torch.float32, device=0)
    sin = torch.tensor([torch.sin(angle)], dtype=torch.float32, device=0)
    return torch.stack([torch.stack([cos, -sin]),
                        torch.stack([sin, cos])]).reshape(2, 2)
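A minimal usage sketch (not part of the original example; it assumes torch is imported and a CUDA device is available, since the helper pins its tensors to device 0):

# the input angle is negated before conversion, so 90 degrees yields a -90 degree rotation matrix
R = rotation_tensor(torch.tensor(90.0))
p = torch.tensor([1.0, 0.0], device=0)
print(R @ p)  # approximately tensor([0., -1.], device='cuda:0')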
Example #2
    def forward(self,
                img,
                sol,
                steps,
                reset_threshold=None,
                max_steps=None,
                disturb_sol=True,
                height_disturbance=0.5,
                angle_disturbance=30,
                translate_disturbance=10):

        img = img.cuda()  # .cuda() is not in-place; keep the returned tensor

        # tensor([tsa, channels, width, height])
        input = torch.zeros(
            (self.tsa_size, 3, self.patch_size, self.patch_size)).cuda()

        steps_ran = 0

        sol = {
            "upper_point": sol[0],
            "base_point": sol[1],
            "angle": sol[3][0],
        }

        if disturb_sol:
            x = random.uniform(0, translate_disturbance)
            y = random.uniform(0, translate_disturbance)
            sol["upper_point"][0] += x
            sol["upper_point"][1] += y
            sol["base_point"][0] += x
            sol["base_point"][1] += y
            sol["angle"] += random.uniform(0, angle_disturbance)

        current_height = torch.dist(sol["upper_point"].clone(),
                                    sol["base_point"].clone()).cuda()
        current_height = current_height * (1 if not disturb_sol else (
            1 + random.uniform(0, height_disturbance)))
        current_angle = sol["angle"].clone().cuda()
        current_base = sol["base_point"].clone().cuda()

        results = []
        tsa_sequence = []

        while (max_steps is None or steps_ran < max_steps):

            current_scale = (self.patch_ratio * current_height /
                             self.patch_size).cuda()

            img_width = img.shape[2]
            img_height = img.shape[3]

            # current_angle = torch.mul(current_angle, -1)

            if img_width < current_base[0].item() or current_base[0].item(
            ) < 0:
                break
            if img_height < current_base[1].item() or current_base[1].item(
            ) < 0:
                break
            patch_parameters = torch.stack([
                current_base[0],  # x
                current_base[1],  # y
                torch.mul(torch.deg2rad(current_angle), -1),  # angle
                current_height
            ]).unsqueeze(0)
            patch_parameters = patch_parameters.cuda()
            try:
                patch = extract_tensor_patch(img,
                                             patch_parameters,
                                             size=self.patch_size)  # size
            except Exception:  # stop if patch extraction fails
                break
            # patch = interpolate(patch, (self.patch_size, self.patch_size, 3))

            # Shift input left
            input = torch.stack([pic
                                 for pic in input[1:]] + [patch.squeeze(0)])

            y = input.cuda().unsqueeze(0)
            y = self.tsa(y)
            after_tsa_copy = y.detach().cpu().clone()
            tsa_sequence.append(after_tsa_copy)
            y = y.squeeze(0)
            y = self.initial_convolutions(y)
            y = y.unsqueeze(0)
            y = self.memory_layer(y)
            y = y.unsqueeze(0)
            y = self.final_convolutions(y)
            y = y.unsqueeze(0)
            y = torch.flatten(y, 1)
            y = torch.flatten(y, 0)
            y = self.fully_connected(y)

            size = input[-1, :, :, :].shape[1] / self.patch_ratio
            upper_prior_x = size
            upper_prior_y = -size
            base_prior_x = size
            base_prior_y = 0
            lower_prior_x = size
            lower_prior_y = 0
            y[0] = torch.add(y[0], upper_prior_x)
            y[1] = torch.add(y[1], upper_prior_y)
            y[2] = torch.add(y[2], base_prior_x)
            y[3] = torch.add(y[3], base_prior_y)
            y[4] = torch.add(y[4], lower_prior_x)
            y[5] = torch.add(y[5], lower_prior_y)

            upper_point = torch.stack([y[0], y[1]])
            base_point = torch.stack([y[2], y[3]])
            lower_point = torch.stack([y[4], y[5]])
            current_angle = torch.add(current_angle, y[6])
            stop_confidence = torch.sigmoid(y[7])

            rotation = torch.tensor(
                [[
                    torch.cos(torch.deg2rad(current_angle)),
                    -1.0 * torch.sin(torch.deg2rad(current_angle))
                ],
                 [
                     1.0 * torch.sin(torch.deg2rad(current_angle)),
                     torch.cos(torch.deg2rad(current_angle))
                 ]]).cuda()
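            # Note: torch.tensor(...) copies values into a new leaf tensor, so this rotation
            # matrix is detached from autograd; no gradient flows back into current_angle here.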
            upper_point = torch.matmul(upper_point, rotation.t())
            base_point = torch.matmul(base_point, rotation.t())
            lower_point = torch.matmul(lower_point, rotation.t())

            scaling = torch.tensor([[current_scale, 0.],
                                    [0., current_scale]]).cuda()

            upper_point = torch.matmul(upper_point, scaling)
            base_point = torch.matmul(base_point, scaling)
            lower_point = torch.matmul(lower_point, scaling)

            upper_point = upper_point + current_base
            base_point = base_point + current_base
            lower_point = lower_point + current_base

            current_base = base_point
            current_height = torch.dist(base_point, upper_point)
            current_height = torch.max(current_height, torch.tensor(16).cuda())
            # current_height = torch.min(current_height, torch.tensor(80).cuda())

            results.append(
                torch.stack([
                    upper_point,
                    current_base.clone(), lower_point,
                    torch.stack(
                        [current_angle.clone(),
                         torch.tensor(0).cuda()]),
                    torch.stack([stop_confidence,
                                 torch.tensor(0).cuda()])
                ],
                            dim=0))

            steps_ran += 1

            if max_steps is None and steps_ran >= len(steps) - 1:
                break
            elif reset_threshold is not None:
                upper_distance = torch.dist(upper_point.clone().detach().cpu(),
                                            steps[steps_ran][0])
                base_distance = torch.dist(base_point.clone().detach().cpu(),
                                           steps[steps_ran][1])
                current_threshold = reset_threshold
                if upper_distance > current_threshold \
                        or base_distance > current_threshold:
                    break

        return None if len(results) == 0 else torch.stack(
            results), steps_ran, tsa_sequence
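For reference, a self-contained sketch of the rotate, scale, translate mapping used above to move a predicted point into image coordinates (all values are illustrative, not taken from the example):

import torch

angle = torch.tensor(30.0)                      # degrees
scale = torch.tensor(2.0)
base  = torch.tensor([100.0, 50.0])             # current base point in image coordinates
local = torch.tensor([5.0, -3.0])               # point predicted in patch coordinates
rad = torch.deg2rad(angle)
R = torch.stack([torch.stack([torch.cos(rad), -torch.sin(rad)]),
                 torch.stack([torch.sin(rad),  torch.cos(rad)])])
S = torch.diag(torch.stack([scale, scale]))
print(local @ R.t() @ S + base)                 # point mapped into image coordinates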
Example #3
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return len((  # single tuple argument, so len() counts the ops exercised
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, b, out=a),
         b.add(a),
         b.add(a, out=b),
         b.add_(a),
         b.add(1),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.ceil(float(torch.tensor(0.5))),
         torch.ceil(torch.tensor(0.5).item()),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         a.div(b),
         a.div(1),
         a.div_(b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp(float(torch.tensor(1))),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         torch.floor(float(torch.tensor(1))),
         torch.floor_divide(torch.tensor([4.0, 3.0]),
                            torch.tensor([2.0, 2.0])),
         torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         b.mul(a),
         b.mul(5),
         b.mul(a, out=b),
         b.mul_(a),
         b.mul_(5),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num_(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(2, float(torch.tensor(0.5))),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.round(torch.tensor(0.5).item()),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         b.sub(a),
         b.sub_(a),
         b.sub(5),
         torch.sum(5),
         torch.tan(a),
         torch.tanh(a),
         torch.true_divide(a, a),
         torch.trunc(a),
         torch.trunc_(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     ))
Example #4
    def forward(self,
                img,
                sol_tensor,
                steps,
                reset_threshold=None,
                max_steps=None,
                min_run_tsa=False,
                disturb_sol=True,
                height_disturbance=0.5,
                angle_disturbance=30,
                translate_disturbance=10):

        desired_polygon_steps = torch.stack([sol_tensor.cuda()] +
                                            [p.cuda() for p in steps])

        img = img.cuda()  # keep the returned CUDA tensor

        # tensor([tsa, channels, width, height])
        input = ((255 / 128) - 1) * torch.ones(
            (1, 3, self.patch_size, self.patch_size)).cuda()

        steps_ran = 0

        sol = {
            "upper_point": sol_tensor[0],
            "base_point": sol_tensor[1],
            "angle": sol_tensor[3][0],
        }

        if disturb_sol:
            x = random.uniform(0, translate_disturbance)
            y = random.uniform(0, translate_disturbance)
            sol["upper_point"][0] += x
            sol["upper_point"][1] += y
            sol["base_point"][0] += x
            sol["base_point"][1] += y
            sol["angle"] += random.uniform(-angle_disturbance,
                                           angle_disturbance)

        current_height = torch.dist(sol["upper_point"].clone(),
                                    sol["base_point"].clone()).cuda()
        current_height = current_height * (1 if not disturb_sol else (
            1 + random.uniform(0, height_disturbance)))

        current_angle = sol["angle"].clone().cuda()
        current_base = sol["base_point"].clone().cuda()

        results = []
        tsa_sequence = []

        while (max_steps is None or steps_ran < max_steps):

            current_scale = (self.patch_ratio * current_height /
                             self.patch_size).cuda()

            img_width = img.shape[2]
            img_height = img.shape[3]

            # current_angle = torch.mul(current_angle, -1)

            if img_width < current_base[0].item() or current_base[0].item(
            ) < 0:
                break
            if img_height < current_base[1].item() or current_base[1].item(
            ) < 0:
                break
            patch_parameters = torch.stack([
                current_base[0],  # x
                current_base[1],  # y
                torch.mul(torch.deg2rad(current_angle), -1),  # angle
                current_height
            ]).unsqueeze(0)
            patch_parameters = patch_parameters.cuda()
            try:
                patch = extract_tensor_patch(img,
                                             patch_parameters,
                                             size=self.patch_size)  # size
            except Exception:  # stop if patch extraction fails
                break

            # Shift input left
            # input = torch.stack([pic for pic in input[1:]] + [patch.squeeze(0)])

            input = torch.stack([pic for pic in input[-self.tsa_size:]] +
                                [patch.squeeze(0)])

            y = input.cuda().unsqueeze(0)
            y = self.tsa(y)
            y = y[:, 1:, :, :, :]
            after_tsa_copy = y.detach().cpu().clone()
            tsa_sequence.append(after_tsa_copy)
            y = y.squeeze(0)
            y = self.initial_convolutions(y)
            y = y.unsqueeze(0)
            y = self.memory_layer(y)
            y = y.unsqueeze(0)
            y = self.final_convolutions(y)
            y = y.unsqueeze(0)
            y = torch.flatten(y, 1)
            y = torch.flatten(y, 0)
            y = self.fully_connected(y)

            # Biases
            # size = input[-1, :, :, :].shape[1] / self.patch_ratio
            # 0 first_angle
            # 1 next_angle
            # 2 upper_height
            # 3 lower_height
            # y[0] = torch.add(y[1], -size)
            # y[1] = torch.add(y[2], base_prior_x)

            # Predicts variations in size, not absolute sizes
            y[2] = torch.add(torch.sigmoid(y[2]), 0.5)
            y[3] = torch.sigmoid(y[3])

            # y[3] = torch.add(y[5], 5)

            scale_matrix = torch.stack([
                torch.stack([current_scale,
                             torch.tensor(0.).cuda()]),
                torch.stack([torch.tensor(0.).cuda(), current_scale])
            ]).cuda()

            # Create a vector to represent the new base

            # Finds the next base point
            angle_to_base = torch.add(current_angle, y[0])
            base_rotation_matrix = torch.stack([
                torch.stack([
                    torch.cos(torch.deg2rad(angle_to_base)),
                    -1.0 * torch.sin(torch.deg2rad(angle_to_base))
                ]),
                torch.stack([
                    1.0 * torch.sin(torch.deg2rad(angle_to_base)),
                    torch.cos(torch.deg2rad(angle_to_base))
                ])
            ]).cuda()

            base_unity = torch.stack([current_height, torch.tensor(0.).cuda()])
            base_point = torch.matmul(base_unity, base_rotation_matrix.t())
            # base_point = torch.matmul(base_point, scale_matrix)
            current_base = torch.add(base_point, current_base)
            current_angle = torch.add(current_angle, y[1])

            points_angle = torch.mean(
                torch.stack([angle_to_base, current_angle]))
            point_rotation_matrix = torch.stack([
                torch.stack([
                    torch.cos(torch.deg2rad(points_angle)),
                    -1.0 * torch.sin(torch.deg2rad(points_angle))
                ]),
                torch.stack([
                    1.0 * torch.sin(torch.deg2rad(points_angle)),
                    torch.cos(torch.deg2rad(points_angle))
                ])
            ]).cuda()

            current_height = torch.mul(current_height, y[2])
            upper_unity = torch.stack(
                [torch.tensor(0.).cuda(), -current_height])
            upper_point = torch.matmul(upper_unity, point_rotation_matrix.t())
            # upper_point = torch.matmul(upper_point, scale_matrix)
            upper_point = torch.add(upper_point, current_base)

            current_lower_height = torch.mul(current_height, y[3])
            lower_unity = torch.stack(
                [torch.tensor(0.).cuda(), current_lower_height])
            lower_point = torch.matmul(lower_unity, point_rotation_matrix.t())
            # lower_point = torch.matmul(lower_point, scale_matrix)
            lower_point = torch.add(lower_point, current_base)

            current_height = torch.max(
                torch.stack([
                    torch.dist(upper_point, current_base),
                    torch.tensor(8.).cuda()
                ]))
            # Rotate outlines
            # upper_point = torch.matmul(upper_point, scale_matrix)
            # lower_point = torch.matmul(lower_point, scale_matrix)
            # upper_point = torch.matmul(upper_point, rotation_matrix.t())
            # lower_point = torch.matmul(lower_point, rotation_matrix.t())
            # upper_point = torch.add(upper_point, current_base)
            # lower_point = torch.add(lower_point, current_base)

            # stop_confidence = torch.sigmoid(y[5])

            look_ahead_ratio = 3
            look_ahead_base = current_base.clone().detach()
            look_ahead_angle = current_angle.clone().detach()
            look_ahead_height = current_height.clone().detach()
            extraction_params = []
            for i in range(look_ahead_ratio):
                look_ahead_params = torch.stack([
                    look_ahead_base[0],  # x
                    look_ahead_base[1],  # y
                    torch.mul(torch.deg2rad(look_ahead_angle), -1),  # angle
                    current_height
                ]).unsqueeze(0)
                extraction_params.append(look_ahead_params.cuda())
                base_point = Point(look_ahead_base[0].item(),
                                   look_ahead_base[1].item())
                next_point = get_new_point(base_point, look_ahead_angle.item(),
                                           look_ahead_height.item())
                look_ahead_base = torch.tensor([next_point.x,
                                                next_point.y]).cuda()
            patches = [
                extract_tensor_patch(img, p, size=self.patch_size)
                for p in extraction_params
            ]
            concatenated_patches = torch.cat(patches, dim=2)

            results.append(
                torch.stack(
                    [
                        upper_point,  # .clone(),
                        current_base,  # .clone(),
                        lower_point,  # .clone(),
                        torch.stack(
                            [current_angle.clone(),
                             torch.tensor(0).cuda()]),
                        # torch.stack([stop_confidence.clone(), torch.tensor(0).cuda()])
                    ],
                    dim=0))

            # Decide whether to stop based on last step DICE AFFINITY
            steps_ran += 1

            # Minimum steps to fill TSA
            # if min_run_tsa and steps_ran < self.tsa_size:
            #    continue

            if reset_threshold is not None:
                upper_as_point = Point(upper_point[0].item(),
                                       upper_point[1].item())
                base_as_point = Point(current_base[0].item(),
                                      current_base[1].item())
                lower_as_point = Point(lower_point[0].item(),
                                       lower_point[1].item())
                gt_step = steps[steps_ran]
                gt_upper_point = Point(gt_step[0][0].item(),
                                       gt_step[0][1].item())
                gt_base_point = Point(gt_step[1][0].item(),
                                      gt_step[1][1].item())
                gt_lower_point = Point(gt_step[2][0].item(),
                                       gt_step[2][1].item())

                if base_as_point.distance(
                        gt_base_point
                ) > reset_threshold or upper_as_point.distance(
                        gt_upper_point) > reset_threshold:
                    break

            if max_steps is None and steps_ran >= len(steps) - 1:
                break

        if len(results) == 0:
            return torch.zeros((0, 5, 2)).cuda(), 0, []
        else:
            return torch.stack(results), steps_ran, tsa_sequence
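For reference, a self-contained sketch of the base-point advance used above: rotate a step of length current_height by the predicted angle and add it to the current base (values are illustrative):

import torch

height = torch.tensor(40.0)                     # current line height in pixels
angle  = torch.tensor(15.0)                     # angle to the next base point, in degrees
base   = torch.tensor([200.0, 120.0])           # current base point
rad = torch.deg2rad(angle)
R = torch.stack([torch.stack([torch.cos(rad), -torch.sin(rad)]),
                 torch.stack([torch.sin(rad),  torch.cos(rad)])])
step = torch.stack([height, torch.tensor(0.0)])
print(step @ R.t() + base)                      # next base point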
Example #5
def get_object_rotation_matrix(rot):
    #######
    # Construct the left-handed coordinate-system rotation matrix. Ref: ???
    #######

    r_y = torch.Tensor([[torch.cos(torch.deg2rad(rot[1])), 0, torch.sin(torch.deg2rad(rot[1]))],
                        [0, 1, 0],
                        [-torch.sin(torch.deg2rad(rot[1])), 0, torch.cos(torch.deg2rad(rot[1]))]])
    r_x = torch.Tensor([[1, 0, 0],
                        [0, torch.cos(torch.deg2rad(rot[0])), -torch.sin(torch.deg2rad(rot[0]))],
                        [0, torch.sin(torch.deg2rad(rot[0])), torch.cos(torch.deg2rad(rot[0]))]])
    r_z = torch.Tensor([[torch.cos(torch.deg2rad(rot[2])), -torch.sin(torch.deg2rad(rot[2])), 0],
                        [torch.sin(torch.deg2rad(rot[2])), torch.cos(torch.deg2rad(rot[2])), 0],
                        [0, 0, 1]])
    r = torch.mm(torch.mm(r_y, r_x), r_z)
    return r.to(rot.device)
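A minimal usage sketch (an assumption: rot holds x, y and z rotations in degrees in a 3-element tensor):

import torch

rot = torch.tensor([0.0, 90.0, 0.0])
R = get_object_rotation_matrix(rot)             # composed as Ry @ Rx @ Rz
print(R @ torch.tensor([1.0, 0.0, 0.0]))        # approximately tensor([0., 0., -1.])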
Example #6
            weighted_grads_R = [None] * (tuning_length + 1)
            for i in range(tuning_length + 1):
                weighted_grads_R[i] = np.power(bias, i) * R_grads[i]
            grad_R = torch.mean(torch.stack(weighted_grads_R), dim=0)

            print(f'grad_R: {grad_R}')

            # Update parameters in time step t-H with saved gradients
            upd_R = perspective_taker.update_rotation_angles_(
                Rs[0], grad_R, at_learning_rate)
            # for i in range(tuning_length+1):
            #     B_upd[i] = binder.update_binding_matrix_(Bs[i], grad_B, at_learning_rate)
            print(f'updated angles: {upd_R}')

            # Compare binding matrix to ideal matrix
            ang_loss = 2 - (torch.cos(torch.deg2rad(torch.stack(upd_R))) + 1)
            # ang_loss = at_loss_function(ideal_angle, torch.stack(upd_R))
            # mat_loss = at_loss_function(ideal_binding, B_upd[0])
            print(
                f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
            )
            ra_losses.append(torch.norm(ang_loss))

            rotmat = perspective_taker.compute_rotation_matrix_(
                upd_R[0], upd_R[1], upd_R[2])
            mat_loss = mse(ideal_rotation, rotmat[0])
            print(f'loss of rotation matrix: {mat_loss}')
            rm_losses.append(mat_loss)

            # print(Rs[0][0].grad)
            # print(Rs[0][1].grad)
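For reference, a self-contained sketch of the biased gradient averaging used above (tuning_length, bias and the gradient shape are illustrative assumptions; the original uses np.power(bias, i)):

import torch

tuning_length, bias = 3, 0.9
R_grads = [torch.randn(3, 1) for _ in range(tuning_length + 1)]
weighted_grads_R = [bias ** i * g for i, g in enumerate(R_grads)]
grad_R = torch.mean(torch.stack(weighted_grads_R), dim=0)   # single update applied at time step t-H
print(grad_R)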
Example #7
    def run_inference(self,
        observations, 
        grad_calculations, 
        do_binding, 
        do_rotation,
        do_translation,
        order, 
        reorder):
        [grad_calc_binding, grad_calc_rotation, grad_calc_translation] = grad_calculations

        if reorder is not None:
            reorder = reorder.to(self.device)
        
        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        ###########################  BINDING  #################################
        if do_binding:
            ## Binding matrices 
            # Init binding entries 
            bm = self.binder.init_binding_matrix_det_()
            # bm = binder.init_binding_matrix_rand_()
            # print(bm)
            dummie_line = torch.ones(1,self.num_observations).to(self.device) * self.dummie_init

            for i in range(self.tuning_length+1):
                matrix = bm.clone().to(self.device)
                if self.nxm:
                    matrix = torch.cat([matrix, dummie_line])
                matrix.requires_grad_()
                self.Bs.append(matrix)
            
        ###########################  ROTATION  ################################
        if do_rotation:
            if self.rotation_type == 'qrotate': 
                ## Rotation quaternion 
                rq = self.perspective_taker.init_quaternion()
                # print(rq)

                for i in range(self.tuning_length+1):
                    quat = rq.clone().to(self.device)
                    quat.requires_grad_()
                    self.Rs.append(quat)

            elif self.rotation_type == 'eulrotate':
                ## Rotation euler angles 
                # ra = perspective_taker.init_angles_()
                # ra = torch.Tensor([[309.89], [82.234], [95.765]])
                ra = torch.Tensor([[75.0], [6.0], [128.0]])
                # print(ra)

                for i in range(self.tuning_length+1):
                    angles = []
                    for j in range(self.num_spatial_dimensions):
                        angle = ra[j].clone().to(self.device)
                        angle.requires_grad_()
                        angles.append(angle)
                    self.Rs.append(angles)

            else: 
                print('ERROR: Received unknown rotation type!')
                exit()

        ###########################  TRANSLATION  #############################
        if do_translation:
            tb = self.perspective_taker.init_translation_bias_()
            # print(tb)

            for i in range(self.tuning_length+1):
                transba = tb.clone().to(self.device)
                transba.requires_grad = True
                self.Cs.append(transba)

        #######################################################################

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((
                self.at_inputs, 
                o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(
                    self.Bs[i], 
                    self.scale_mode, 
                    self.scale_combo, 
                    self.nxm_enhance, 
                    self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
            else: 
                x_B = o

            if self.gestalten:
                mag = x_B[:, -1].view(self.num_observations, 1)
                x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions], 
                    x_B[:, self.num_spatial_dimensions:]])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate': 
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[i])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else: 
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[i])
            else:
                x_C = x_R

            if self.gestalten: 
                x_C = torch.cat([x_C, dir, mag], dim=1)
            #######################################################################
            
            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)  
            self.at_states.append(state)
            self.at_predictions = torch.cat((self.at_predictions, new_prediction.reshape(1,self.input_per_frame)), 0)
        
        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly factor the following out into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(
                    self.Bs[-1], 
                    self.scale_mode, 
                    self.scale_combo,
                    self.nxm_enhance, 
                    self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)    
            else: 
                x_B = o   

            if self.gestalten:
                mag = x_B[:, -1].view(self.num_observations, 1)
                x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions], 
                    x_B[:, self.num_spatial_dimensions:]])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate': 
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else: 
                x_R = x_B
    
            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
            else: 
                x_C = x_R

            if self.gestalten: 
                x_C = torch.cat([x_C, dir, mag], dim=1)
            #######################################################################

            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            ## Generate current prediction 
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles 
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error 
                loss = self.at_loss(p,x[0])

                # Propagate error back through tuning horizon 
                loss.backward(retain_graph = True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters 
                with torch.no_grad():
                    ###########################  BINDING  #################################
                    if do_binding:
                        # Calculate gradients with respect to the entries
                        for i in range(self.tuning_length+1):
                            self.B_grads[i] = self.Bs[i].grad
                        # print(B_grads[tuning_length])                    
                        
                        # Calculate overall gradients 
                        if grad_calc_binding == 'lastOfTunHor':
                            ### version 1
                            grad_B = self.B_grads[0]
                        elif grad_calc_binding == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_B = torch.mean(torch.stack(self.B_grads), dim=0)
                        elif grad_calc_binding == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_B = [None] * (self.tuning_length+1)
                            for i in range(self.tuning_length+1):
                                weighted_grads_B[i] = np.power(self.grad_bias_binding, i) * self.B_grads[i]
                            grad_B = torch.mean(torch.stack(weighted_grads_B), dim=0)
                        # print(f'grad_B: {grad_B}')
                        
                        # Update parameters in time step t-H with saved gradients 
                        grad_B = grad_B.to(self.device)
                        upd_B = self.binder.update_binding_matrix_(
                            self.Bs[0], grad_B, self.at_learning_rate_binding, self.bm_momentum)

                        upd_B = self.binder.clampBM(upd_B)

                        # Compare binding matrix to ideal matrix
                        # NOTE: the ideal matrix is always the identity, because then the FBE and determinant can be calculated => provide reorder
                        c_bm = self.binder.scale_binding_matrix(upd_B, self.scale_mode, self.scale_combo)
                        if order is not None: 
                            c_bm = c_bm.gather(1, reorder.unsqueeze(0).expand(c_bm.shape))

                        if self.nxm:
                            self.oc_grads.append(grad_B[-1])
                            FBE = self.evaluator.FBE_nxm_additional_features(
                                c_bm, self.ideal_binding, self.additional_features)
                            c_bm = self.evaluator.clear_nxm_binding_matrix(c_bm, self.additional_features)
                        
                        mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)

                        if self.nxm:
                            mat_loss = torch.stack([mat_loss, FBE, mat_loss+FBE])
                        self.bm_losses.append(mat_loss)
                        print(f'loss of binding matrix (FBE): {mat_loss}')

                        # Compute determinant of binding matrix
                        det = torch.det(c_bm)
                        self.bm_dets.append(det)
                        print(f'determinant of binding matrix: {det}')


                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length+1):
                            self.Bs[i].requires_grad = False
                            self.Bs[i].grad.data.zero_()

                        # Update all parameters for all time steps 
                        for i in range(self.tuning_length+1):
                            self.Bs[i].data = upd_B.clone().data
                            self.Bs[i].requires_grad = True

                    ###########################  ROTATION  ################################
                    if do_rotation:
                        ## get gradients
                        if self.rotation_type == 'qrotate': 
                            for i in range(self.tuning_length+1):
                                # save grads for all parameters in all time steps of tuning horizon
                                self.R_grads[i] = self.Rs[i].grad
                        else: 
                            for i in range(self.tuning_length+1):
                                # save grads for all parameters in all time steps of tuning horizon
                                grad = []
                                for j in range(self.num_input_dimensions):
                                    grad.append(self.Rs[i][j].grad) 
                                self.R_grads[i] = torch.stack(grad)
                        # print(self.R_grads[self.tuning_length])
                        
                        # Calculate overall gradients 
                        if grad_calc_rotation == 'lastOfTunHor':
                            ### version 1
                            grad_R = self.R_grads[0]
                        elif grad_calc_rotation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_R = torch.mean(torch.stack(self.R_grads), dim=0)
                        elif grad_calc_rotation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_R = [None] * (self.tuning_length+1)
                            for i in range(self.tuning_length+1):
                                weighted_grads_R[i] = np.power(self.grad_bias_rotation, i) * self.R_grads[i]
                            grad_R = torch.mean(torch.stack(weighted_grads_R), dim=0)
                        # print(f'grad_R: {grad_R}')

                        grad_R = grad_R.to(self.device)
                        if self.rotation_type == 'qrotate': 
                            # Update parameters in time step t-H with saved gradients 
                            upd_R = self.perspective_taker.update_quaternion(
                                self.Rs[0], grad_R, self.at_learning_rate_rotation, self.r_momentum)
                            print(f'updated quaternion: {upd_R}')

                            # Compare quaternion values
                            quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
                            print(f'loss of quaternion: {quat_loss}')
                            self.rv_losses.append(quat_loss)
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.quaternion2rotmat(upd_R)

                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length+1):
                                self.Rs[i].requires_grad = False
                                self.Rs[i].grad.data.zero_()

                            # Update all parameters for all time steps 
                            for i in range(self.tuning_length+1):
                                quat = upd_R.clone()
                                quat.requires_grad_()
                                self.Rs[i] = quat

                        else: 
                            # Update parameters in time step t-H with saved gradients 
                            upd_R = self.perspective_taker.update_rotation_angles_(
                                self.Rs[0], grad_R, self.at_learning_rate_rotation)
                            print(f'updated angles: {upd_R}')

                            # Save rotation angles
                            rotang = torch.stack(upd_R)
                            # angles:
                            ang_diff = rotang - self.ideal_angle
                            ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                            print(f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}')
                            self.rv_losses.append(torch.norm(ang_loss))
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                upd_R[0], upd_R[1], upd_R[2])[0]
                            
                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length+1):
                                for j in range(self.num_input_dimensions):
                                    self.Rs[i][j].requires_grad = False
                                    self.Rs[i][j].grad.data.zero_()

                            # Update all parameters for all time steps 
                            for i in range(self.tuning_length+1):
                                angles = []
                                for j in range(3):
                                    angle = upd_R[j].clone()
                                    angle.requires_grad_()
                                    angles.append(angle)
                                self.Rs[i] = angles

                        # Calculate and save rotation losses
                        # matrix: 
                        mat_loss = self.mse(
                            (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))), 
                            self.identity_matrix
                        )
                        print(f'loss of rotation matrix: {mat_loss}')
                        self.rm_losses.append(mat_loss)
                    
                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        ## Get gradients 
                        for i in range(self.tuning_length+1):
                            # save grads for all parameters in all time steps of tuning horizon 
                            self.C_grads[i] = self.Cs[i].grad

                        # print(self.C_grads[self.tuning_length])
                        
                        # Calculate overall gradients 
                        if grad_calc_translation == 'lastOfTunHor':
                            ### version 1
                            grad_C = self.C_grads[0]
                        elif grad_calc_translation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_C = torch.mean(torch.stack(self.C_grads), dim=0)
                        elif grad_calc_translation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_C = [None] * (self.tuning_length+1)
                            for i in range(self.tuning_length+1):
                                weighted_grads_C[i] = np.power(self.grad_bias_translation, i) * self.C_grads[i]
                            grad_C = torch.mean(torch.stack(weighted_grads_C), dim=0)
                        
                        # Update parameters in time step t-H with saved gradients 
                        grad_C = grad_C.to(self.device)
                        upd_C = self.perspective_taker.update_translation_bias_(
                            self.Cs[0], grad_C, self.at_learning_rate_translation, self.c_momentum)
                        
                        # Compare translation bias to ideal bias
                        trans_loss = self.mse(self.ideal_translation, upd_C)
                        self.c_losses.append(trans_loss)
                        print(f'loss of translation bias (MSE): {trans_loss}')
                        
                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length+1):
                            self.Cs[i].requires_grad = False
                            self.Cs[i].grad.data.zero_()
                        
                        # Update all parameters for all time steps 
                        for i in range(self.tuning_length+1):
                            translation = upd_C.clone()
                            translation.requires_grad_()
                            self.Cs[i] = translation

                    #######################################################################

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # print(f'updated init_state: {init_state}')
                

                # forward pass from t-H to t with new parameters 
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    ###########################  BINDING  #################################
                    if do_binding:
                        bm = self.binder.scale_binding_matrix(
                            self.Bs[i], 
                            self.scale_mode, 
                            self.scale_combo, 
                            self.nxm_enhance, 
                            self.nxm_last_line_scale)
                        if self.nxm:
                            bm = bm[:-1]
                        x_B = self.binder.bind(self.at_inputs[i], bm)
                    else:
                        x_B = self.at_inputs[i]

                    if self.gestalten:
                        mag = x_B[:, -1].view(self.num_observations, 1)
                        x_B = x_B[:, :-1]
                        x_B = torch.cat([
                            x_B[:, :self.num_spatial_dimensions], 
                            x_B[:, self.num_spatial_dimensions:]])
                    ###########################  ROTATION  ################################
                    if do_rotation:
                        if self.rotation_type == 'qrotate': 
                            x_R = self.perspective_taker.qrotate(x_B, self.Rs[i])
                        else:
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                            x_R = self.perspective_taker.rotate(x_B, rotmat)
                    else: 
                        x_R = x_B
                
                    if self.gestalten:
                        dir = x_R[-self.num_observations:, :]
                        x_R = x_R[:-self.num_observations, :]
                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        x_C = self.perspective_taker.translate(x_R, self.Cs[i])
                    else:
                        x_C = x_R

                    if self.gestalten: 
                        x_C = torch.cat([x_C, dir, mag], dim=1)
                    #######################################################################

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat((self.at_predictions, upd_prediction.reshape(1,self.input_per_frame)), 0)
                    
                    # for last tuning cycle update initial state to track gradients 
                    if cycle==(self.tuning_cycles-1) and i==0: 
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone().detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)
                        
                        at_h = state[0].clone().detach().requires_grad_().to(self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(self.device)
                        
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state 

                # Update current input
                ###########################  BINDING  #################################
                if do_binding:
                    bm = self.binder.scale_binding_matrix(
                        self.Bs[-1], 
                        self.scale_mode, 
                        self.scale_combo, 
                        self.nxm_enhance, 
                        self.nxm_last_line_scale)
                    if self.nxm:
                        bm = bm[:-1]
                    x_B = self.binder.bind(o, bm)
                else:
                    x_B = o

                if self.gestalten:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                    x_B = torch.cat([
                        x_B[:, :self.num_spatial_dimensions], 
                        x_B[:, self.num_spatial_dimensions:]])
                ###########################  ROTATION  ################################
                if do_rotation: 
                    if self.rotation_type == 'qrotate': 
                        x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                        x_R = self.perspective_taker.rotate(x_B, rotmat)
                else: 
                    x_R = x_B

                if self.gestalten:
                    dir = x_R[-self.num_observations:, :]
                    x_R = x_R[:-self.num_observations, :]
                ###########################  TRANSLATION  #############################
                if do_translation: 
                    x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
                else: 
                    x_C = x_R

                if self.gestalten: 
                    x_C = torch.cat([x_C, dir, mag], dim=1)
                #######################################################################

                x = self.preprocessor.convert_data_AT_to_LSTM(x_C)


            # END tuning cycle        

            ## Generate updated prediction 
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables            
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, 
                final_input.reshape(1,self.input_per_frame)), 0)
            self.at_inputs = torch.cat(
                (self.at_inputs[1:], 
                o.reshape(1, self.num_observations, self.num_input_dimensions)), 0)
            
            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions, 
                final_prediction.reshape(1,self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:], 
                new_prediction.reshape(1,self.input_per_frame)), 0)

        # END active tuning
        
        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length): 
            at_final_predictions = torch.cat(
                (at_final_predictions, 
                self.at_predictions[i].reshape(1,self.input_per_frame)), 0)
            at_final_inputs = torch.cat(
                (at_final_inputs, 
                self.at_inputs[i].reshape(1,self.input_per_frame)), 0)

        ###########################  BINDING  #################################
        # get final binding matrix
        if do_binding:
            final_binding_matrix = self.binder.scale_binding_matrix(
                self.Bs[-1].clone().detach(), self.scale_mode, self.scale_combo)
            print(f'final binding matrix: {final_binding_matrix}')
            final_binding_entries = self.Bs[-1].clone().detach()
            print(f'final binding entries: {final_binding_entries}')

        else: 
            final_binding_entries, final_binding_matrix = None, None

        ###########################  ROTATION  ################################
        # get final rotation matrix
        if do_rotation:
            if self.rotation_type == 'qrotate': 
                final_rotation_values = self.Rs[0].clone().detach()
                # get final quaternion
                print(f'final quaternion: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.quaternion2rotmat(final_rotation_values)
            else:
                final_rotation_values = [
                    self.Rs[0][i].clone().detach() 
                    for i in range(self.num_input_dimensions)]
                print(f'final euler angles: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                    final_rotation_values[0], 
                    final_rotation_values[1], 
                    final_rotation_values[2])
            
            print(f'final rotation matrix: \n{final_rotation_matrix}')

        else: 
            final_rotation_matrix, final_rotation_values = None, None

        ###########################  TRANSLATION  #############################
        # get final translation bias
        if do_translation:
            final_translation_values = self.Cs[0].clone().detach()
            print(f'final translation bias: {final_translation_values}')

        else: 
            final_translation_values = None

        #######################################################################


        return [at_final_inputs,
                at_final_predictions, 
                final_binding_matrix, 
                final_binding_entries,
                final_rotation_values, 
                final_rotation_matrix, 
                final_translation_values]
Example #8
import torch

# example input tensors for the snippets below (illustrative values)
a = torch.randn(4)
b = torch.randn(4)

# conj
torch.conj(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))

# copysign
torch.copysign(a, 1)
torch.copysign(a, b)

# cos
torch.cos(a)

# cosh
torch.cosh(a)

# deg2rad
torch.deg2rad(torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]))
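# ≈ tensor([[ 3.1416, -3.1416], [ 6.2832, -6.2832], [ 1.5708, -1.5708]])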

# div/divide/true_divide
x = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
torch.div(x, 0.5)
p = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917],
                  [0.1815, -1.0111, 0.9805, -1.5923],
                  [0.1062, 1.4581, 0.7759, -1.2344],
                  [-0.1830, -0.0313, 1.1908, -1.4757]])
q = torch.tensor([0.8032, 0.2930, -0.8113, -0.2308])
torch.div(p, q)
torch.divide(p, q, rounding_mode='trunc')
torch.divide(p, q, rounding_mode='floor')

# digamma
torch.digamma(torch.tensor([1, 0.5]))
Example #9
    def run_inference(self, observations, grad_calculation):

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        if self.rotation_type == 'qrotate':
            ## Rotation quaternion
            rq = self.perspective_taker.init_quaternion()
            # print(rq)

            for i in range(self.tuning_length + 1):
                quat = rq.clone().to(self.device)
                quat.requires_grad_()
                self.Rs.append(quat)

        elif self.rotation_type == 'eulrotate':
            ## Rotation euler angles
            # ra = perspective_taker.init_angles_()
            # ra = torch.Tensor([[309.89], [82.234], [95.765]])
            ra = torch.Tensor([[75.0], [6.0], [128.0]])
            print(ra)

            for i in range(self.tuning_length + 1):
                angles = []
                for j in range(self.num_input_dimensions):
                    angle = ra[j].clone()
                    angle.requires_grad_()
                    angles.append(angle)
                self.Rs.append(angles)

        else:
            print('ERROR: Received unknown rotation type!')
            exit()

        ## Core state
        # define scaler: the hidden and cell state are multiplied by this factor before every core_model call
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[i])
            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly move the following into a separate function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[-1])

            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
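                # (retain_graph=True because the same graph is backpropagated once per tuning cycle)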
                loss.backward(retain_graph=True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters
                with torch.no_grad():

                    ## get gradients
                    if self.rotation_type == 'qrotate':
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            self.R_grads[i] = self.Rs[i].grad
                    else:
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            grad = []
                            for j in range(self.num_input_dimensions):
                                grad.append(self.Rs[i][j].grad)
                            self.R_grads[i] = torch.stack(grad)

                    # print(self.R_grads[self.tuning_length])

                    # Calculate the overall gradient from the gradients saved over the tuning horizon
                    if grad_calculation == 'lastOfTunHor':
                        ### version 1: use only the gradient at the oldest step (t-H)
                        grad_R = self.R_grads[0]
                    elif grad_calculation == 'meanOfTunHor':
                        ### version 2 / 3: average the gradients over the whole horizon
                        grad_R = torch.mean(torch.stack(self.R_grads), dim=0)
                    elif grad_calculation == 'weightedInTunHor':
                        ### version 4: weight each gradient by grad_bias**i before averaging
                        weighted_grads_R = [None] * (self.tuning_length + 1)
                        for i in range(self.tuning_length + 1):
                            weighted_grads_R[i] = np.power(self.grad_bias,
                                                           i) * self.R_grads[i]
                        grad_R = torch.mean(torch.stack(weighted_grads_R),
                                            dim=0)

                    # print(f'grad_R: {grad_R}')

                    grad_R = grad_R.to(self.device)
                    if self.rotation_type == 'qrotate':
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_quaternion(
                            self.Rs[0], grad_R, self.at_learning_rate,
                            self.r_momentum)
                        print(f'updated quaternion: {upd_R}')

                        # Compare quaternion values
                        # quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
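                        # angular distance between unit quaternions: 2 * arccos(|<q_ideal, q_upd>|), reported in degrees below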
                        quat_loss = 2 * torch.arccos(
                            torch.abs(
                                torch.sum(torch.mul(self.ideal_quat, upd_R))))
                        quat_loss = torch.rad2deg(quat_loss)
                        print(f'loss of quaternion: {quat_loss}')
                        self.rv_losses.append(quat_loss)

                        # Compare quaternion angles
                        ang = torch.rad2deg(
                            self.perspective_taker.qeuler(upd_R, 'zyx'))
                        ang_diff = ang - self.ideal_angle
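                        # per-angle loss 1 - cos(diff): zero for an exact match, at most 2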
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of quaternion angles: {ang_loss} \nwith norm: {torch.norm(ang_loss)}'
                        )
                        self.ra_losses.append(torch.norm(ang_loss))

                        # Compute rotation matrix
                        rotmat = self.perspective_taker.quaternion2rotmat(
                            upd_R)

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Rs[i].requires_grad = False
                            self.Rs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            quat = upd_R.clone()
                            quat.requires_grad_()
                            self.Rs[i] = quat

                    else:
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_rotation_angles_(
                            self.Rs[0], grad_R, self.at_learning_rate)
                        print(f'updated angles: {upd_R}')

                        # Save rotation angles
                        rotang = torch.stack(upd_R)
                        # angles:
                        ang_diff = rotang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
                        )
                        self.rv_losses.append(torch.norm(ang_loss))
                        # Compute rotation matrix
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            upd_R[0], upd_R[1], upd_R[2])[0]

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            for j in range(self.num_input_dimensions):
                                self.Rs[i][j].requires_grad = False
                                self.Rs[i][j].grad.data.zero_()

                        # print(Rs[0])
                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            angles = []
                            for j in range(3):
                                angle = upd_R[j].clone()
                                angle.requires_grad_()
                                angles.append(angle)
                            self.Rs[i] = angles
                        # print(Rs[0])

                    # Calculate and save rotation losses
                    # matrix:
                    # mat_loss = self.mse(
                    #     (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))),
                    #     self.identity_matrix
                    # )
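                    # geodesic distance between rotation matrices: arccos((trace(R_ideal @ R^T) - 1) / 2), reported in degrees below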
                    dif_R = torch.mm(self.ideal_rotation,
                                     torch.transpose(rotmat, 0, 1))
                    mat_loss = torch.arccos(0.5 * (torch.trace(dif_R) - 1))
                    mat_loss = torch.rad2deg(mat_loss)

                    print(f'loss of rotation matrix: {mat_loss}')
                    self.rm_losses.append(mat_loss)

                    # print(Rs[0])

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # state_optimizer.step()
                    # print(f'updated init_state: {init_state}')

                ## TODO: reorganize for multiple tuning cycles

                # forward pass from t-H to t with new parameters
                # Update init state???
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    if self.rotation_type == 'qrotate':
                        x_R = self.perspective_taker.qrotate(
                            self.at_inputs[i], self.Rs[i])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                        x_R = self.perspective_taker.rotate(
                            self.at_inputs[i], rotmat)

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone(
                            ).detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_()
                        at_c = state[1].clone().detach().requires_grad_()
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current rotation
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(o, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(o, rotmat)

                x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)
            if self.rotation_type == 'qrotate':
                x_i = self.perspective_taker.qrotate(self.at_inputs[i],
                                                     self.Rs[-1])
            else:
                x_i = self.perspective_taker.rotate(self.at_inputs[i], rotmat)
            at_final_inputs = torch.cat(
                (at_final_inputs, x_i.reshape(1, self.input_per_frame)), 0)

        # get final rotation matrix
        if self.rotation_type == 'qrotate':
            final_rotation_values = self.Rs[0].clone().detach()
            # get final quaternion
            print(f'final quaternion: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.quaternion2rotmat(
                final_rotation_values)

        else:
            final_rotation_values = [
                self.Rs[0][i].clone().detach()
                for i in range(self.num_input_dimensions)
            ]
            print(f'final euler angles: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                final_rotation_values[0], final_rotation_values[1],
                final_rotation_values[2])

        print(f'final rotation matrix: \n{final_rotation_matrix}')

        return at_final_inputs, at_final_predictions, final_rotation_values, final_rotation_matrix
Example #10
    def forward(self,
                img,
                steps,
                sol_index=0,
                disturb_sol=True,
                min_height=64,
                height_disturbance=0.5,
                angle_disturbance=45,
                translate_disturbance=15):

        self.visualization = None
        img = Variable(img, requires_grad=False).cuda()

        # Take last 5 positions before current step

        steps_used = steps[max(0, sol_index - self.tsa_size):sol_index +
                           1].cuda()
        heights = [
            Point(s[0][0].item(), s[0][1].item()).distance(
                Point(s[1][0].item(), s[1][1].item())) for s in steps_used
        ]
        heights = [max(min_height, h) for h in heights]
        heights = [
            torch.tensor(h, dtype=torch.float32).cuda() for h in heights
        ]

        if any([h == 0 for h in heights]):
            return None

        if disturb_sol:
            base_disturbance = torch.tensor([
                0.0, 0.0
            ] if not disturb_sol else [
                random.uniform(-translate_disturbance, translate_disturbance),
                random.uniform(-translate_disturbance, translate_disturbance)
            ],
                                            dtype=torch.float32).cuda()
            angle_disturbance = torch.tensor(
                0.0 if not disturb_sol else random.uniform(
                    -angle_disturbance, angle_disturbance),
                dtype=torch.float32).cuda()
            steps_used[-1][3][0] = torch.add(steps_used[-1][3][0],
                                             angle_disturbance)
            steps_used[-1][1] = torch.add(steps_used[-1][1], base_disturbance)
            heights[-1] = heights[-1] * (
                1 + random.uniform(-height_disturbance, height_disturbance))

        current_angle = steps_used[-1][3][0]
        current_scale = (self.patch_ratio * heights[-1] /
                         self.patch_size).cuda()
        current_base = steps_used[-1][1]

        patches = [
            extract_tensor_patch(
                img,
                torch.stack([
                    step[1][0],  # x
                    step[1][1],  # y
                    torch.mul(torch.deg2rad(step[3][0]), -1).cuda(),
                    heights[index]
                ]).unsqueeze(0).cuda(),
                size=self.patch_size).squeeze(0)
            for index, step in enumerate(steps_used)
        ]

        y = self.stepper(patches)

        # size = input[-1, :, :, :].shape[1] / self.patch_ratio

        # 2x2 rotation matrix for the current angle: rows [[cos, -sin], [sin, cos]]
        base_rotation_matrix = torch.stack([
            torch.stack([
                torch.cos(torch.deg2rad(current_angle)),
                -1.0 * torch.sin(torch.deg2rad(current_angle))
            ]),
            torch.stack([
                1.0 * torch.sin(torch.deg2rad(current_angle)),
                torch.cos(torch.deg2rad(current_angle))
            ])
        ]).cuda()

        scale_matrix = torch.stack([
            torch.stack([
                current_scale,
                torch.tensor(0., dtype=torch.float32,
                             requires_grad=True).cuda()
            ]),
            torch.stack([
                torch.tensor(0., dtype=torch.float32,
                             requires_grad=True).cuda(), current_scale
            ])
        ]).cuda()

        base_point = torch.stack([y[0], y[1]])
        base_point = torch.matmul(base_point, base_rotation_matrix.t())
        base_point = torch.matmul(base_point, scale_matrix)

        # Assuming a sigmoid indicating how much % of the patch is the upper height
        #
        upper_point = torch.stack([y[3], y[4]])
        upper_point = torch.matmul(upper_point, base_rotation_matrix.t())
        upper_point = torch.matmul(upper_point, scale_matrix)
        upper_point = torch.add(upper_point, current_base)

        lower_point = torch.stack([y[5], y[6]])
        lower_point = torch.matmul(lower_point, base_rotation_matrix.t())
        lower_point = torch.matmul(lower_point, scale_matrix)
        lower_point = torch.add(lower_point, current_base)

        # current_height = torch.mul(current_height, y[3])
        current_base = torch.add(current_base, base_point)
        current_angle = torch.add(current_angle, y[2])

        # up = self.up_measurer(img, current_base, current_angle, base_height)
        # down = self.down_measurer(img, current_base, current_angle, base_height)
        # stop = self.stop_measurer(img, current_base, current_angle, base_height)

        del patches

        return torch.stack([
            upper_point,
            current_base,
            lower_point,
            torch.stack(
                [current_angle,
                 torch.tensor(0.0, dtype=torch.float32).cuda()]).cuda(),
            torch.stack([y[7],
                         torch.tensor(0.0,
                                      dtype=torch.float32).cuda()]).cuda(),
        ])
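
The base-point update above rotates a patch-local offset by the current angle and rescales it before adding it to the current base. A minimal standalone sketch of that mapping follows; the angle, scale, base point, and offset values are purely illustrative:

import torch

current_angle = torch.tensor(30.0)           # illustrative values
current_scale = torch.tensor(0.5)
current_base = torch.tensor([100.0, 200.0])

rad = torch.deg2rad(current_angle)
rotation = torch.stack([
    torch.stack([torch.cos(rad), -torch.sin(rad)]),
    torch.stack([torch.sin(rad), torch.cos(rad)])
])

offset = torch.tensor([12.0, -3.0])          # hypothetical network output (y[0], y[1])
offset = torch.matmul(offset, rotation.t())  # rotate into image orientation
offset = offset * current_scale              # isotropic scale (equivalent to the scale_matrix above)
next_base = current_base + offset
print(next_base)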
Example #11
 def create_output_past(self, index):
     start_deg = self.indicies[index]
     start_deg = start_deg - self.output_len
     degrees = torch.arange(start_deg, start_deg + self.output_len)
     rads = torch.deg2rad(degrees)
     return (degrees, self.F(rads))
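
For context, a minimal self-contained sketch of how a helper like create_output_past could be exercised; the host class, attribute values, and the choice of torch.sin for F are assumptions made purely for illustration:

import torch

class _Demo:
    def __init__(self):
        self.indicies = [180.0]   # hypothetical start positions (degrees)
        self.output_len = 90      # hypothetical window length
        self.F = torch.sin        # hypothetical sampled function

    def create_output_past(self, index):
        start_deg = self.indicies[index] - self.output_len
        degrees = torch.arange(start_deg, start_deg + self.output_len)
        rads = torch.deg2rad(degrees)
        return (degrees, self.F(rads))

degrees, values = _Demo().create_output_past(0)
print(degrees.shape, values.shape)  # torch.Size([90]) torch.Size([90])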
Example #12
training_set_list = load_file_list_direct(training_set_list_path)
train_dataset = LolDataset(training_set_list, augmentation=True)
train_dataloader = DataLoader(train_dataset,
                              batch_size=1,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=lol_dataset.collate)
batches_per_epoch = int(args.images_per_epoch / args.batch_size)
train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

dtype = torch.cuda.FloatTensor

initial = torch.tensor([10.0, 10.0]).cuda()
current_angle = torch.tensor(45.0).cuda()
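# sanity check: rotating the point (10, 10) clockwise by 45 degrees maps it onto the x-axis,
# and scaling by 1/sqrt(2) brings its x-coordinate back to 10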
rotation = torch.tensor(
    [[torch.cos(torch.deg2rad(current_angle)), 1.0 * torch.sin(torch.deg2rad(current_angle))],
     [-1.0 * torch.sin(torch.deg2rad(current_angle)), torch.cos(torch.deg2rad(current_angle))]]).cuda()
rotated = torch.matmul(initial, rotation.t())

current_scale = 1 / (sqrt(2))
scaling = torch.tensor([[current_scale, 0.], [0., current_scale]]).cuda()

scaled = torch.matmul(rotated, scaling)

assert scaled[0].item() == 10

t1 = torch.tensor([[[10, 10], [20, 20], [30, 30], [10, 0], [0.5, 0]]])
t2 = torch.tensor([[[20, 10], [20, 20], [30, 30], [5, 0], [0.5, 0]]])

loss = torch.nn.MSELoss(reduction="sum")(t1, t2)
Example #13
import numpy as np
import torch

# import matplotlib.pyplot as plt

# KUKA LBR iiwa joint position limits, specified in degrees and converted to radians
IIWA_JOINT_MIN_LIMITS = torch.deg2rad(
    torch.tensor([-170., -120., -170., -120., -170., -120., -175.]))
IIWA_JOINT_MAX_LIMITS = torch.deg2rad(
    torch.tensor([170., 120., 170., 120., 170., 120., 175.]))


class Rotation:
    def __init__(self, quat, normalize=True):
        self._quat = quat if isinstance(
            quat, torch.Tensor) else torch.tensor(quat).double()
        if normalize:
            self._quat /= torch.norm(self._quat)

    @classmethod
    def from_quat(cls, quat):
        return cls(quat, normalize=True)

    @classmethod
    def from_matrix(cls, matrix):
        trace = torch.trace(matrix)
        quat = torch.zeros(4).double()
        if trace > 1e-5:
            s = torch.sqrt(trace + 1) * 2
            quat[0] = s / 4
            quat[1] = (matrix[2, 1] - matrix[1, 2]) / s
            quat[2] = (matrix[0, 2] - matrix[2, 0]) / s
Example #14
    def forward(self,
                img,
                sol_tensor,
                steps,
                reset_threshold=None,
                max_steps=None,
                disturb_sol=True,
                confidence_threshold=None,
                height_disturbance=0.5,
                angle_disturbance=30,
                translate_disturbance=10):

        desired_polygon_steps = torch.stack([sol_tensor.cuda()] +
                                            [p.cuda() for p in steps])

        img.cuda()

        # tensor([tsa, channels, width, height])
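        # (255 / 128) - 1 ≈ 0.99 presumably fills the history with "white" patches under an (x / 128) - 1 normalization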
        input = ((255 / 128) - 1) * torch.ones(
            (1, 3, self.patch_size, self.patch_size)).cuda()

        steps_ran = 0

        sol = {
            "upper_point": sol_tensor[0],
            "base_point": sol_tensor[1],
            "angle": sol_tensor[3][0],
        }

        if disturb_sol:
            x = random.uniform(0, translate_disturbance)
            y = random.uniform(0, translate_disturbance)
            sol["upper_point"][0] += x
            sol["upper_point"][1] += y
            sol["base_point"][0] += x
            sol["base_point"][1] += y
            sol["angle"] += random.uniform(-angle_disturbance,
                                           angle_disturbance)

        current_height = torch.dist(sol["upper_point"].clone(),
                                    sol["base_point"].clone()).cuda()
        current_height = current_height * (1 if not disturb_sol else (
            1 + random.uniform(0, height_disturbance)))

        current_angle = sol["angle"].clone().cuda()
        current_base = sol["base_point"].clone().cuda()

        results = []
        tsa_sequence = []

        upper_height_stack = []
        lower_height_stack = []
        baseline_stack = []  # [sol_tensor[1].cuda()]
        angle_stack = []
        while (max_steps is None or steps_ran < max_steps):

            current_scale = (self.patch_ratio * current_height /
                             self.patch_size).cuda()

            img_width = img.shape[2]
            img_height = img.shape[3]

            # current_angle = torch.mul(current_angle, -1)

            if img_width < current_base[0].item() or current_base[0].item(
            ) < 0:
                break
            if img_height < current_base[1].item() or current_base[0].item(
            ) < 0:
                break
            patch_parameters = torch.stack([
                current_base[0],  # x
                current_base[1],  # y
                torch.mul(torch.deg2rad(current_angle), -1),  # angle
                current_height
            ]).unsqueeze(0)
            patch_parameters = patch_parameters.cuda()
            try:
                patch = extract_tensor_patch(img,
                                             patch_parameters,
                                             size=self.patch_size)  # size
            except:
                break

            # Shift input left
            # input = torch.stack([pic for pic in input[1:]] + [patch.squeeze(0)])

            input = torch.stack([pic for pic in input[-self.tsa_size:]] +
                                [patch.squeeze(0)])

            y = input.cuda().unsqueeze(0)
            y = self.tsa(y)
            y = y[:, 1:, :, :, :]
            after_tsa_copy = y.detach().cpu().clone()
            tsa_sequence.append(after_tsa_copy)
            y = y.squeeze(0)
            y = self.initial_convolutions(y)
            y = y.unsqueeze(0)
            y = self.memory_layer(y)
            y = y.unsqueeze(0)
            y = self.final_convolutions(y)
            y = y.unsqueeze(0)
            y = torch.flatten(y, 1)
            y = torch.flatten(y, 0)
            y = self.fully_connected(y)

            size = input[0, :, :, :].shape[1] / self.patch_ratio

            y[0] = torch.add(y[0], size)

            y[2] = torch.sigmoid(y[2])
            y[3] = torch.sigmoid(y[3])
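            # y[2] and y[3] become fractions of half the patch height (see upper_height / lower_height below)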
            # y[2] = torch.add(y[2], size)
            # y[3] = torch.add(y[3], -size)
            # y[4] = torch.add(y[4], size)

            scale_matrix = torch.stack([
                torch.stack([current_scale,
                             torch.tensor(0.).cuda()]),
                torch.stack([torch.tensor(0.).cuda(), current_scale])
            ]).cuda()

            # Finds the next base point
            base_rotation_matrix = torch.stack([
                torch.stack([
                    torch.cos(torch.deg2rad(current_angle)),
                    -1.0 * torch.sin(torch.deg2rad(current_angle))
                ]),
                torch.stack([
                    1.0 * torch.sin(torch.deg2rad(current_angle)),
                    torch.cos(torch.deg2rad(current_angle))
                ])
            ]).cuda()

            # Create a vector to represent the new base

            base_point = torch.stack([y[0], y[1]])

            # upper_point = torch.stack([torch.tensor(0, dtype=torch.float32).cuda(), -upper_height])
            # lower_point = torch.stack([torch.tensor(0, dtype=torch.float32).cuda(), lower_height])

            base_point = torch.matmul(base_point, base_rotation_matrix.t())
            base_point = torch.matmul(base_point, scale_matrix)
            current_base = torch.add(base_point, current_base)
            upper_height = torch.mul(y[2], self.patch_size / 2)
            lower_height = torch.mul(y[3], self.patch_size / 2)
            current_angle = torch.add(current_angle, y[4])

            angle_stack.append(current_angle.clone().detach())
            baseline_stack.append(current_base)
            upper_height_stack.append(torch.mul(upper_height, current_scale))
            lower_height_stack.append(torch.mul(lower_height, current_scale))

            # Finds the next base point
            # point_rotation_matrix = torch.stack(
            #    [torch.stack([torch.cos(torch.deg2rad(current_angle)), -1.0 * torch.sin(torch.deg2rad(current_angle))]),
            #     torch.stack(
            #         [1.0 * torch.sin(torch.deg2rad(current_angle)), torch.cos(torch.deg2rad(current_angle))])]).cuda()

            # lower_point = torch.matmul(lower_point, point_rotation_matrix.t())
            # lower_point = torch.matmul(lower_point, scale_matrix)
            # lower_point = torch.add(lower_point, current_base)

            # upper_point = torch.matmul(upper_point, point_rotation_matrix.t())
            # upper_point = torch.matmul(upper_point, scale_matrix)
            # upper_point = torch.add(upper_point, current_base)

            # look_ahead_ratio = 3
            # look_ahead_base = current_base.clone().detach()
            # look_ahead_angle = current_angle.clone().detach()
            # look_ahead_height = current_height.clone().detach()
            # extraction_params = []
            # for i in range(look_ahead_ratio):
            #     look_ahead_params = torch.stack([look_ahead_base[0],  # x
            #                                      look_ahead_base[1],  # y
            #                                      torch.mul(torch.deg2rad(look_ahead_angle), -1),  # angle
            #                                      current_height]).unsqueeze(0)
            #     extraction_params.append(look_ahead_params.cuda())
            #     base_point = Point(look_ahead_base[0].item(), look_ahead_base[1].item())
            #     next_point = get_new_point(base_point, look_ahead_angle.item(), look_ahead_height.item())
            #     look_ahead_base = torch.tensor([next_point.x, next_point.y]).cuda()
            # patches = [extract_tensor_patch(img, p, size=self.patch_size) for p in extraction_params]
            # concatenated_patches = torch.cat(patches, dim=2)
            # stop_result = self.stop(concatenated_patches.clone().detach())
            #
            # gt_polygon = to_polygon(torch.stack([s.cuda() for s in steps]))
            # upper_p = Point(upper_point[0].item(), upper_point[1].item())
            # lower_p = Point(lower_point[0].item(), lower_point[1].item())
            # step_line = LineString([upper_p, lower_p])
            # total_length = upper_p.distance(lower_p)
            # intersection = step_line.intersection(gt_polygon)
            # desired_confidence = 1
            # if intersection is not None and isinstance(intersection, LineString):
            #     desired_confidence = 1 - (intersection.length / total_length)
            #
            # confidence_loss = torch.nn.MSELoss()(stop_result, torch.tensor(desired_confidence, dtype=torch.float32).cuda())
            # confidence_loss.backward(retain_graph=True)

            # Decide whether to stop based on last step DICE AFFINITY
            steps_ran += 1

            # if confidence_threshold is not None and stop_result.item() > confidence_threshold:
            #    break

            # Minimum steps to fill TSA
            # if min_run_tsa and steps_ran < self.tsa_size:
            #    continue

            if reset_threshold is not None:
                base_as_point = Point(current_base[0].item(),
                                      current_base[1].item())
                # upper_as_point = Point(upper_point[0].item(), upper_point[1].item())
                gt_step = steps[steps_ran]
                # gt_upper = Point(gt_step[0][0].item(), gt_step[0][1].item())
                gt_base_point = Point(gt_step[1][0].item(),
                                      gt_step[1][1].item())
                # is_upper_violated = upper_as_point.distance(gt_upper) > reset_threshold

                if base_as_point.distance(gt_base_point) > reset_threshold:
                    break

            if max_steps is None and steps_ran >= len(steps) - 1:
                break

        for i in range(len(baseline_stack)):

            if (i == 0 and len(baseline_stack)
                    == 1) or i == len(baseline_stack) - 1:
                angle_to_next = torch.deg2rad(angle_stack[i])
            else:
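                # interior points: use the direction toward the next baseline point instead of the predicted angle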
                difference = baseline_stack[i + 1] - baseline_stack[i]
                angle_to_next = torch.atan2(difference[1], difference[0])

            point_rotation_matrix = torch.stack([
                torch.stack([
                    torch.cos(angle_to_next), -1.0 * torch.sin(angle_to_next)
                ]),
                torch.stack(
                    [1.0 * torch.sin(angle_to_next),
                     torch.cos(angle_to_next)])
            ]).cuda()

            upper_point = torch.stack(
                [torch.tensor(0.).cuda(), -upper_height_stack[i]])
            upper_point = torch.matmul(upper_point, point_rotation_matrix.t())
            upper_point = torch.add(upper_point,
                                    baseline_stack[i].clone().detach())

            lower_point = torch.stack(
                [torch.tensor(0.).cuda(), lower_height_stack[i]])
            lower_point = torch.matmul(lower_point, point_rotation_matrix.t())
            lower_point = torch.add(lower_point,
                                    baseline_stack[i].clone().detach())

            results.append(
                torch.stack([
                    upper_point, baseline_stack[i], lower_point,
                    torch.tensor([0.0, 0.0]).cuda()
                ]))
            # point_rotation_matrix = torch.stack(
            #    [torch.stack([torch.cos(torch.deg2rad(current_angle)), -1.0 * torch.sin(torch.deg2rad(current_angle))]),
            #     torch.stack(
            #         [1.0 * torch.sin(torch.deg2rad(current_angle)), torch.cos(torch.deg2rad(current_angle))])]).cuda()

        if len(results) == 0:
            return torch.zeros((0, 5, 2)).cuda(), 0, []
        else:
            return torch.stack(results), steps_ran, tsa_sequence