Example #1
    def forward(self,
                search,
                template,
                samp_idx=None,
                labels=None,
                settings=None):
        """ The forward expects a NestedTensor, which consists of:
               - search.tensors: batched images, of shape [batch_size x 3 x H_search x W_search]
               - search.mask: a binary mask of shape [batch_size x H_search x W_search], containing 1 on padded pixels
               - template.tensors: batched images, of shape [batch_size x 3 x H_template x W_template]
               - template.mask: a binary mask of shape [batch_size x H_template x W_template], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits for all feature vectors.
                                Shape= [batch_size x num_vectors x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all feature vectors, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image.

        """
        if not isinstance(search, NestedTensor):
            search = nested_tensor_from_tensor(search)
        if not isinstance(template, NestedTensor):
            template = nested_tensor_from_tensor(template)
        feature_search, pos_search = self.backbone(search)
        feature_template, pos_template = self.backbone(template)
        src_search, mask_search = feature_search[-1].decompose()
        assert mask_search is not None
        src_template, mask_template = feature_template[-1].decompose()
        assert mask_template is not None
        hs = self.featurefusion_network(self.input_proj(src_template),
                                        mask_template,
                                        self.input_proj(src_search),
                                        mask_search, pos_template[-1],
                                        pos_search[-1])

        outputs_class = self.new_class_embed(hs)
        outputs_coord = self.new_bbox_embed(hs).sigmoid()
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1]
        }
        return out
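
The docstring above specifies the NestedTensor contract: a padded image batch plus a binary mask with 1 on padded pixels. Below is a minimal sketch, using only torch, of how such a padded batch and mask can be built from variable-sized crops; the helper name make_padded_batch is hypothetical and not part of the repo.

import torch

def make_padded_batch(images):
    # Pad a list of [3, H, W] crops to a common size and build the pad mask
    # (1/True on padded pixels), mirroring the tensors/mask pair described above.
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    batch = torch.zeros(len(images), 3, max_h, max_w)
    mask = torch.ones(len(images), max_h, max_w, dtype=torch.bool)
    for i, img in enumerate(images):
        _, h, w = img.shape
        batch[i, :, :h, :w] = img
        mask[i, :h, :w] = False  # valid pixels are 0; padding stays 1
    return batch, mask

# Two search crops of different sizes share one padded batch and mask.
search_batch, search_mask = make_padded_batch(
    [torch.rand(3, 255, 255), torch.rand(3, 224, 240)])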
Example #2
    def forward(self, search, template, samp_idx, labels, settings, boxes):
        """ The forward expects a NestedTensor, which consists of:
               - search.tensors: batched images, of shape [batch_size x 3 x H_search x W_search]
               - search.mask: a binary mask of shape [batch_size x H_search x W_search], containing 1 on padded pixels
               - template.tensors: batched images, of shape [batch_size x 3 x H_template x W_template]
               - template.mask: a binary mask of shape [batch_size x H_template x W_template], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits for all feature vectors.
                                Shape= [batch_size x num_vectors x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all feature vectors, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image.

        """
        # Reshape search into a 4D tensor
        search_shape = [int(x) for x in search.shape]
        searcht = search.view([search_shape[0] * search_shape[1]] +
                              search_shape[2:])
        if not isinstance(searcht, NestedTensor):
            searcht = nested_tensor_from_tensor(searcht)
        if not isinstance(template, NestedTensor):
            templatet = nested_tensor_from_tensor(template)
        else:
            templatet = template
        with torch.no_grad():
            feature_search, pos_search = self.backbone(searcht)
            feature_template, pos_template = self.backbone(templatet)
        src_search, mask_search = feature_search[-1].decompose()

        assert mask_search is not None
        src_template, mask_template = feature_template[-1].decompose()
        assert mask_template is not None
        circuit_input, _ = feature_search[-2].decompose()

        post_src_search_shape = [int(x) for x in src_search.shape]
        post_mask_search_shape = [int(x) for x in mask_search.shape]
        src_search = src_search.view(search_shape[:2] +
                                     post_src_search_shape[1:])
        src_template = self.input_proj(src_template)
        mask_search = mask_search.view(search_shape[:2] +
                                       post_mask_search_shape[1:])

        circuit_input = circuit_input.view(search_shape[0], search_shape[1],
                                           -1, circuit_input.shape[2],
                                           circuit_input.shape[3]).permute(
                                               0, 2, 1, 3, 4)
        proc_labels = F.interpolate(labels, circuit_input.shape[-2:])
        proc_label_shape = [int(x) for x in proc_labels.shape]
        proc_labels = proc_labels.view([search_shape[0]] +
                                       proc_label_shape[1:]).to(
                                           src_search.device)  # .mean(2)
        inh_1 = self.cnl(self.circuit_inh_1_init(1 - proc_labels))
        exc_1 = self.cnl(self.circuit_exc_1_init(proc_labels))

        nl_src_search = self.cnl(self.circuit_proj(circuit_input))
        pos_search_shape = [int(x) for x in pos_search[-1].shape]
        pos_search = pos_search[-1].view(search_shape[:2] +
                                         pos_search_shape[1:])
        mask_search = mask_search.view(search_shape[:2] +
                                       post_mask_search_shape[1:])
        excs = []
        for t in range(nl_src_search.shape[2]):

            # Step 1, saliency
            pre_exc_1, inh_1 = self.circuit_1(nl_src_search[:, :, t],
                                              excitation=exc_1,
                                              inhibition=inh_1,
                                              activ=self.cnl)

            # Step 2, tracking
            post_exc_1 = F.max_pool2d(self.cnl(
                self.circuit_step2_trans(pre_exc_1)),
                                      kernel_size=(2, 2),
                                      stride=(2, 2),
                                      padding=(0, 0))
            if t == 0:
                exc_2 = self.cnl(self.circuit_exc_2_init(post_exc_1))
                inh_2 = self.cnl(self.circuit_inh_2_init(post_exc_1))
            exc_2, inh_2 = self.circuit_2(
                post_exc_1, excitation=exc_2, inhibition=inh_2,
                activ=self.cnl)  # , label=proc_labels)

            # Split off a pair of TransT features and then use the circuit to gate the Qs in the decoder.
            interp_exc_2 = F.interpolate(exc_2,
                                         pre_exc_1.shape[2:],
                                         mode="bilinear",
                                         align_corners=False)
            dec_rnn = self.circuit_rnn_decode_1(interp_exc_2)
            prev_hs = self.circuit_rnn_decode_2(
                self.nl(dec_rnn))  # * torch.sigmoid(self.rnn_gate(dec_rnn))
            # src_search = src_search * proc_rnn.sigmoid()

            # Pass activities through transformer
            # hs, _ = self.featurefusion_network(src_template, mask_template, self.input_proj(src_search[:, t]), mask_search[:, t], pos_template[-1], pos_search[:, t], exc=prev_hs)
            src_input = self.input_proj(src_search[:, t]) * prev_hs.sigmoid()
            hs = self.featurefusion_network(src_template, mask_template,
                                            src_input, mask_search[:, t],
                                            pos_template[-1], pos_search[:, t])

            # Step 3, TD-FB incorporating hs
            res_hs = hs.squeeze().view(search_shape[0], self.height,
                                       self.height,
                                       self.hidden_dim).permute(0, 3, 1, 2)

            # TD from Trans to Circuit
            res_hs = F.max_pool2d(self.cnl(self.circuit_tf_codeswitch(res_hs)),
                                  kernel_size=(2, 2),
                                  stride=(2, 2),
                                  padding=(0, 0))
            # This pooling only matches the spatial resolution of circuit 2; ideally it would not be needed.
            if t == 0:
                td_inh_2 = self.cnl(self.circuit_td_inh_2_init(res_hs))
            exc_2, td_inh_2 = self.circuit_td_2(exc_2,
                                                excitation=res_hs,
                                                inhibition=td_inh_2,
                                                activ=self.cnl)

            # TD from circuit2 to Circuit 1
            interp_exc_2 = self.cnl(
                self.circuit_step3_trans(interp_exc_2)
            )  # F.interpolate(exc_2, pre_exc_1.shape[2:], mode="bilinear", align_corners=False)))
            if t == 0:
                td_inh_1 = self.cnl(self.circuit_td_inh_1_init(interp_exc_2))
            exc_1, td_inh_1 = self.circuit_td_1(pre_exc_1,
                                                excitation=interp_exc_2,
                                                inhibition=td_inh_1,
                                                activ=self.cnl)
            # from matplotlib import pyplot as plt
            # plt.subplot(141);plt.imshow(search[0, t].squeeze().permute(1, 2, 0).cpu());plt.subplot(142);
            # plt.imshow((self.input_proj(src_search[:, t])[0].squeeze() ** 2).mean(0).detach().cpu());
            # plt.subplot(143);plt.imshow((prev_hs.sigmoid().squeeze() ** 2)[0].mean(0).detach().cpu());
            # plt.subplot(144);plt.imshow(((self.input_proj(src_search[:, t]) * prev_hs.sigmoid())[0].squeeze() ** 2).mean(0).detach().cpu());plt.show()
            # if t > src_search.shape[1] - 5:
            #     plt.subplot(141);plt.title("Resnet");plt.imshow(search[0, t].squeeze().permute(1, 2, 0).cpu());plt.title("L1");plt.subplot(142);plt.imshow((pre_exc_1[0] ** 2).squeeze().mean(0).detach().cpu());plt.subplot(143);plt.title("L2");plt.imshow((exc_2[0] ** 2).squeeze().mean(0).detach().cpu());plt.subplot(144);plt.title("L1 after TD");plt.imshow((exc_1[0] ** 2).squeeze().mean(0).detach().cpu());
            #     plt.show()
            excs.append(exc_1)  # Also try exc_1?

        # Concat exc to hs too
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        excs = torch.stack(excs, 2)
        proc_rnn = self.circuit_exc_bbox_3(
            self.cnl(self.circuit_exc_bbox_1(excs)).view(
                excs.shape[0], excs.shape[2], -1)).sigmoid()

        # Note: this variant never computes `vj_penalty` or `exc_bbox`, so only
        # the standard outputs are returned.
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1],
            'hgru_boxes': proc_rnn
        }
        return out
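
This variant runs the backbone frame-by-frame over a video clip by folding the time axis into the batch axis and unfolding it again afterwards. A minimal sketch of that fold/unfold pattern on toy tensors; the avg_pool2d call is only a stand-in for the backbone (run under no_grad above).

import torch
import torch.nn.functional as F

clip = torch.rand(2, 4, 3, 64, 64)            # [B, T, C, H, W]
b, t = clip.shape[:2]
flat = clip.view(b * t, *clip.shape[2:])      # [B*T, C, H, W] for the backbone
feat = F.avg_pool2d(flat, 4)                  # stand-in for backbone features
feat = feat.view(b, t, *feat.shape[1:])       # restore the time axis
assert feat.shape[:2] == (2, 4)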
Example #3
    def forward(self, search, template, samp_idx, labels, settings):
        """ The forward expects a NestedTensor, which consists of:
               - search.tensors: batched images, of shape [batch_size x 3 x H_search x W_search]
               - search.mask: a binary mask of shape [batch_size x H_search x W_search], containing 1 on padded pixels
               - template.tensors: batched images, of shape [batch_size x 3 x H_template x W_template]
               - template.mask: a binary mask of shape [batch_size x H_template x W_template], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits for all feature vectors.
                                Shape= [batch_size x num_vectors x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all feature vectors, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image.

        """
        # Reshape search into a 4D tensor
        search_shape = [int(x) for x in search.shape]
        search = search.view([search_shape[0] * search_shape[1]] + search_shape[2:])
        if not isinstance(search, NestedTensor):
            search = nested_tensor_from_tensor(search)
        if not isinstance(template, NestedTensor):
            template = nested_tensor_from_tensor(template)
        with torch.no_grad():
            feature_search, pos_search = self.backbone(search)
            feature_template, pos_template = self.backbone(template)
        src_search, mask_search = feature_search[-1].decompose()
        assert mask_search is not None
        src_template, mask_template = feature_template[-1].decompose()
        assert mask_template is not None

        # Use the circuit to track through the features
        src_search = self.input_proj(src_search)
        post_src_search_shape = [int(x) for x in src_search.shape]
        post_mask_search_shape = [int(x) for x in mask_search.shape]
        src_search = src_search.view(search_shape[:2] + post_src_search_shape[1:])
        mask_search = mask_search.view(search_shape[:2] + post_mask_search_shape[1:])
        proc_labels = self._generate_label_function(
                labels,
                settings.sigma,
                settings.kernel,
                settings.feature,  # post_src_search_shape[-1],  # settings.feature,
                settings.output_sz,  # post_src_search_shape[-1],  # settings.output_sz,
                settings.end_pad_if_even)
        exc, inh = None, None
        # pos_search_shape = [int(x) for x in pos_search[-1].shape]
        proc_label_shape = [int(x) for x in proc_labels.shape]
        # proc_labels = pos_search[-1].view(search_shape[:2] + pos_search_shape[1:]).mean(2)
        proc_labels = proc_labels.view(search_shape[:2] + [1] + proc_label_shape[1:]).to(src_search.device)  # .mean(2)
        # from matplotlib import pyplot as plt
        # im=2;ti=6;plt.subplot(121);plt.imshow(proc_labels[im, ti].cpu());plt.subplot(122);plt.imshow(src_search[im, ti].mean(0).detach().cpu());plt.show()
        for t in range(samp_idx):
            exc, inh = self.circuit(src_search[:, t], excitation=exc, inhibition=inh, label=proc_labels[:, t])
        # Reshape exc for the transformers
        exc = exc.flatten(2).permute(2, 0, 1)

        # Reshape pos_search
        pos_search_shape = [int(x) for x in pos_search[-1].shape]
        pos_search = pos_search[-1].view(search_shape[:2] + pos_search_shape[1:])

        # Split off a pair of TransT features and then use the circuit to gate the Qs in the decoder.
        src_search = src_search[:, samp_idx]
        mask_search = mask_search[:, samp_idx]
        pos_search = pos_search[:, samp_idx]
        hs = self.featurefusion_network(self.input_proj(src_template), mask_template, src_search, mask_search, pos_template[-1], pos_search, exc=exc)

        # Concat exc to hs too
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        return out
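
Here the recurrent excitation map is handed to the feature-fusion network after `exc.flatten(2).permute(2, 0, 1)`, i.e. converted from a [B, C, H, W] map into the [H*W, B, C] sequence layout transformer modules expect. A minimal shape sketch of that conversion on toy sizes:

import torch

exc = torch.rand(2, 256, 32, 32)          # [B, C, H, W] excitation map
seq = exc.flatten(2).permute(2, 0, 1)     # [H*W, B, C] token sequence
assert seq.shape == (32 * 32, 2, 256)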
Example #4
    def forward(self, search, template, samp_idx, labels, settings, boxes):
        """ The forward expects a NestedTensor, which consists of:
               - search.tensors: batched images, of shape [batch_size x 3 x H_search x W_search]
               - search.mask: a binary mask of shape [batch_size x H_search x W_search], containing 1 on padded pixels
               - template.tensors: batched images, of shape [batch_size x 3 x H_template x W_template]
               - template.mask: a binary mask of shape [batch_size x H_template x W_template], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits for all feature vectors.
                                Shape= [batch_size x num_vectors x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all feature vectors, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image.

        """
        # Reshape search into a 4D tensor
        search_shape = [int(x) for x in search.shape]
        searcht = search.view([search_shape[0] * search_shape[1]] + search_shape[2:])
        if not isinstance(searcht, NestedTensor):
            searcht = nested_tensor_from_tensor(searcht)
        if not isinstance(template, NestedTensor):
            templatet = nested_tensor_from_tensor(template)
        else:
            templatet = template
        with torch.no_grad():
            feature_search, pos_search = self.backbone(searcht)
            feature_template, pos_template = self.backbone(templatet)
        src_search, mask_search = feature_search[-1].decompose()

        assert mask_search is not None
        src_template, mask_template = feature_template[-1].decompose()
        assert mask_template is not None
        circuit_input, _ = feature_search[-1].decompose()

        post_src_search_shape = [int(x) for x in src_search.shape]
        post_mask_search_shape = [int(x) for x in mask_search.shape]
        src_search = src_search.view(search_shape[:2] + post_src_search_shape[1:])
        src_template = self.input_proj(src_template)
        mask_search = mask_search.view(search_shape[:2] + post_mask_search_shape[1:])

        circuit_input = circuit_input.view(search_shape[0], search_shape[1], -1, circuit_input.shape[2], circuit_input.shape[3]).permute(0, 2, 1, 3, 4)
        proc_labels = F.interpolate(labels, circuit_input.shape[-2:])
        proc_label_shape = [int(x) for x in proc_labels.shape]
        proc_labels = proc_labels.view([search_shape[0]] + proc_label_shape[1:]).to(src_search.device)  # .mean(2)
        inh_1 = self.cnl(self.circuit_inh_1_init(1 - proc_labels))
        exc_1 = self.cnl(self.circuit_exc_1_init(proc_labels))

        nl_src_search = self.cnl(self.circuit_proj(circuit_input))
        pos_search_shape = [int(x) for x in pos_search[-1].shape]
        pos_search = pos_search[-1].view(search_shape[:2] + pos_search_shape[1:])
        mask_search = mask_search.view(search_shape[:2] + post_mask_search_shape[1:])
        excs, rnn_gates = [], []
        for t in range(nl_src_search.shape[2]):

            # Step 1, saliency
            pre_exc_1, inh_1 = self.circuit_1(nl_src_search[:, :, t], excitation=exc_1, inhibition=inh_1, activ=self.cnl)

            # Split off a pair of TransT features and then use the circuit to gate the Qs in the decoder.
            dec_rnn = self.circuit_rnn_decode_1(pre_exc_1)
            prev_hs = self.circuit_rnn_decode_2(self.cnl(dec_rnn))  # * torch.sigmoid(self.rnn_gate(dec_rnn))

            # Pass activities through transformer
            # prev_hs = 1 + (prev_hs.tanh())  # in [0, 2]
            if t > 0:
                # cost_vol = torch.exp(-(res_hs - pre_exc_1) ** 2)  # Correspondence between FB and the predicted modulation
                # This cost vol is all-to-all spatial maps. Then Reshaped so that all features are aggregated.
                cost_vol = torch.einsum('bik,bjk->bijk', res_hs.view(pre_exc_1.shape[0], pre_exc_1.shape[1], -1), pre_exc_1.view(pre_exc_1.shape[0], pre_exc_1.shape[1], -1)).view(pre_exc_1.shape[0], pre_exc_1.shape[1] ** 2, pre_exc_1.shape[2], pre_exc_1.shape[3])
                # cost_vol = torch.cat([res_hs, prev_hs], 1)
                rnn_gate = self.circuit_gate(cost_vol).sigmoid()
                # rnn_gates.append((rnn_gate ** 2).mean((1, 2, 3)))
                src_input = self.input_proj(src_search[:, t]) + prev_hs * rnn_gate
            else:
                rnn_gate = 0.  # Don't allow the RNN on the first frame
                src_input = self.input_proj(src_search[:, t])
            hs = self.featurefusion_network(src_template, mask_template, src_input, mask_search[:, t], pos_template[-1], pos_search[:, t])

            # Step 3, TD-FB incorporating hs
            pre_res_hs = hs.squeeze().view(search_shape[0], self.height, self.height, self.hidden_dim).permute(0, 3, 1, 2)

            # TD from Trans to Circuit
            res_hs = self.cnl(self.circuit_tf_codeswitch(pre_res_hs))
            if t == 0:
                td_inh_1 = self.cnl(self.circuit_td_inh_1_init(res_hs))
            exc_1, td_inh_1 = self.circuit_td_1(pre_exc_1, excitation=res_hs, inhibition=td_inh_1, activ=self.cnl)
            # from matplotlib import pyplot as plt
            # plt.subplot(141);plt.imshow(search[0, t].squeeze().permute(1, 2, 0).cpu());plt.subplot(142);
            # plt.imshow((self.input_proj(src_search[:, t])[0].squeeze() ** 2).mean(0).detach().cpu());
            # plt.subplot(143);plt.imshow((prev_hs.sigmoid().squeeze() ** 2)[0].mean(0).detach().cpu());
            # plt.subplot(144);plt.imshow(((self.input_proj(src_search[:, t]) * prev_hs.sigmoid())[0].squeeze() ** 2).mean(0).detach().cpu());plt.show()
            # if t > src_search.shape[1] - 5:
            #     plt.subplot(141);plt.title("Resnet");plt.imshow(search[0, t].squeeze().permute(1, 2, 0).cpu());plt.title("L1");plt.subplot(142);plt.imshow((pre_exc_1[0] ** 2).squeeze().mean(0).detach().cpu());plt.subplot(143);plt.title("L2");plt.imshow((exc_2[0] ** 2).squeeze().mean(0).detach().cpu());plt.subplot(144);plt.title("L1 after TD");plt.imshow((exc_1[0] ** 2).squeeze().mean(0).detach().cpu());
            #     plt.show()
            excs.append(exc_1)

        # Concat exc to hs too
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        excs = torch.stack(excs, 2)
        proc_rnn = self.circuit_exc_bbox_3(self.cnl(self.circuit_exc_bbox_1(excs)).view(excs.shape[0], excs.shape[2], -1)).sigmoid()

        rnn_gate = self.circuit_gate_readout(self.cnl(cost_vol)).mean((1, 2, 3))  # [:, None]
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'hgru_boxes': proc_rnn, 'rnn_gate': rnn_gate}
        return out
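
The einsum above builds an all-to-all channel correlation ("cost volume") between the top-down feedback res_hs and the bottom-up excitation pre_exc_1, then folds the C x C channel pairs into a single channel axis. A minimal sketch of the same contraction on toy tensors (fb and bu are stand-ins for res_hs and pre_exc_1):

import torch

b, c, h, w = 1, 4, 8, 8
fb = torch.rand(b, c, h, w)    # stand-in for res_hs (top-down feedback)
bu = torch.rand(b, c, h, w)    # stand-in for pre_exc_1 (bottom-up excitation)
cost_vol = torch.einsum('bik,bjk->bijk',
                        fb.view(b, c, -1),
                        bu.view(b, c, -1)).view(b, c * c, h, w)
assert cost_vol.shape == (b, c * c, h, w)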
Example #5
    def forward(self, search, template, samp_idx, labels, settings, boxes):
        """ The forward expects a NestedTensor, which consists of:
               - search.tensors: batched images, of shape [batch_size x 3 x H_search x W_search]
               - search.mask: a binary mask of shape [batch_size x H_search x W_search], containing 1 on padded pixels
               - template.tensors: batched images, of shape [batch_size x 3 x H_template x W_template]
               - template.mask: a binary mask of shape [batch_size x H_template x W_template], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits for all feature vectors.
                                Shape= [batch_size x num_vectors x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all feature vectors, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image.

        """
        # Reshape search into a 4D tensor
        search_shape = [int(x) for x in search.shape]
        searcht = search.view([search_shape[0] * search_shape[1]] +
                              search_shape[2:])
        if not isinstance(searcht, NestedTensor):
            searcht = nested_tensor_from_tensor(searcht)
        if not isinstance(template, NestedTensor):
            templatet = nested_tensor_from_tensor(template)
        else:
            templatet = template
        with torch.no_grad():
            feature_search, pos_search = self.backbone(searcht)
            feature_template, pos_template = self.backbone(templatet)
        src_search, mask_search = feature_search[-1].decompose()
        assert mask_search is not None
        src_template, mask_template = feature_template[-1].decompose()
        assert mask_template is not None

        post_src_search_shape = [int(x) for x in src_search.shape]
        post_mask_search_shape = [int(x) for x in mask_search.shape]
        src_search = src_search.view(search_shape[:2] +
                                     post_src_search_shape[1:])
        mask_search = mask_search.view(search_shape[:2] +
                                       post_mask_search_shape[1:])

        # Reshape pos_search
        pos_search_shape = [int(x) for x in pos_search[-1].shape]
        pos_search = pos_search[-1].view(search_shape[:2] +
                                         pos_search_shape[1:])

        # Split off a pair of TransT features and then use the circuit to gate the Qs in the decoder.
        src_search = self.input_proj(src_search[:, -1])  # samp_idx
        src_template = self.input_proj(src_template)
        mask_search = mask_search[:, -1]  # samp_idx]
        pos_search = pos_search[:, -1]  # samp_idx]
        hs = self.featurefusion_network(
            src_template, mask_template, src_search, mask_search,
            pos_template[-1],
            pos_search)  # , exc=proc_rnn[:, :, -1])  # .sigmoid())

        # Concat exc to hs too
        # hs = torch.cat([hs, self.rnn_embed(exc)], -1)
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1],
            'hgru_boxes': src_search
        }
        return out
Example #6
    def forward(self, search, template, samp_idx, labels, settings, boxes):
        """ The forward expects a NestedTensor, which consists of:
               - search.tensors: batched images, of shape [batch_size x 3 x H_search x W_search]
               - search.mask: a binary mask of shape [batch_size x H_search x W_search], containing 1 on padded pixels
               - template.tensors: batched images, of shape [batch_size x 3 x H_template x W_template]
               - template.mask: a binary mask of shape [batch_size x H_template x W_template], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits for all feature vectors.
                                Shape= [batch_size x num_vectors x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all feature vectors, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image.

        """
        # Reshape search into a 4D tensor
        search_shape = [int(x) for x in search.shape]
        searcht = search.view([search_shape[0] * search_shape[1]] +
                              search_shape[2:])
        if not isinstance(searcht, NestedTensor):
            searcht = nested_tensor_from_tensor(searcht)
        if not isinstance(template, NestedTensor):
            templatet = nested_tensor_from_tensor(template)
        else:
            templatet = template
        with torch.no_grad():
            feature_search, pos_search = self.backbone(searcht)
            feature_template, pos_template = self.backbone(templatet)
        src_search, mask_search = feature_search[-1].decompose()
        assert mask_search is not None
        src_template, mask_template = feature_template[-1].decompose()
        assert mask_template is not None

        post_src_search_shape = [int(x) for x in src_search.shape]
        post_mask_search_shape = [int(x) for x in mask_search.shape]
        src_search = src_search.view(search_shape[:2] +
                                     post_src_search_shape[1:])
        mask_search = mask_search.view(search_shape[:2] +
                                       post_mask_search_shape[1:])
        proc_labels = F.interpolate(labels, src_search.shape[-2:])
        # Double-check that the code below is correct
        # rnn_src_search = self.nl(self.rnn_proj(src_search.permute(0, 2, 1, 3, 4)))
        # psrc = src_search.permute(0, 2, 1, 3, 4)
        # rnn_channels = self.rnn_proj_channel_2(self.nl(self.rnn_proj_channel_1(psrc.mean(dim=(3, 4), keepdim=True))))
        # rnn_space = self.rnn_proj_spatial_2(self.nl(self.rnn_proj_spatial_1(psrc)))
        # rnn_src_search = F.softplus(rnn_space + rnn_channels)  # Used to be multiplicative

        proc_label_shape = [int(x) for x in proc_labels.shape]
        proc_labels = proc_labels.view([search_shape[0]] +
                                       proc_label_shape[1:]).to(
                                           src_search.device)  # .mean(2)
        exc_1, inh_1 = None, None
        exc_2, inh_2 = None, None
        td_inh = None
        excs = []
        # for t in range(rnn_src_search.shape[2]):
        # inh_1 = self.nl(self.inh_1_proj(src_search[:, 0]))
        inh_1 = self.nl(self.inh_1_init(src_search[:, 0]))
        exc_1 = self.nl(self.exc_1_init(src_search[:, 0]))

        for t in range(src_search.shape[1]):

            # Step 1, saliency
            pre_exc_1, inh_1 = self.circuit_1(self.nl(src_search[:, t]),
                                              excitation=exc_1,
                                              inhibition=inh_1)

            # Step 2, tracking
            exc_2, inh_2 = self.circuit_2(self.step2_trans(pre_exc_1),
                                          excitation=exc_2,
                                          inhibition=inh_2,
                                          label=proc_labels)

            # Step 3, TD-FB
            if t == 0:
                # td_inh = self.nl(self.td_inh_proj(exc_2))
                td_inh = self.nl(torch.zeros_like(exc_2))
            td_exc = self.nl(self.step3_trans(exc_2))

            # exc_1, td_inh = self.circuit_td(pre_exc_1, excitation=self.nl(self.step3_trans(exc_2)), inhibition=td_inh)
            exc_1, td_inh = self.circuit_td(pre_exc_1,
                                            excitation=td_exc,
                                            inhibition=td_inh)

            # plt.subplot(141);plt.imshow(search[0, t].squeeze().permute(1, 2, 0).cpu());plt.subplot(142);plt.imshow((rnn_src_search[0, :, t].squeeze() ** 2).mean(0).detach().cpu());plt.subplot(143);plt.imshow((exc[0].squeeze() ** 2).mean(0).detach().cpu());plt.subplot(144);plt.imshow((inh[0].squeeze() ** 2).mean(0).detach().cpu());plt.show()
            # Debug visualization (requires `from matplotlib import pyplot as plt`):
            # plt.subplot(141)
            # plt.imshow(search[0, t].squeeze().permute(1, 2, 0).cpu())
            # plt.subplot(142)
            # plt.imshow((pre_exc_1[0]**2).squeeze().mean(0).detach().cpu())
            # plt.subplot(143)
            # plt.imshow((exc_2[0]**2).squeeze().mean(0).detach().cpu())
            # plt.subplot(144)
            # plt.imshow((exc_1[0]**2).squeeze().mean(0).detach().cpu())
            # plt.show()
            excs.append(exc_1)  # Also try exc_1?
            # ros.append(self.rnn_embed(exc_2.flatten(2).permute(2, 0, 1)).sigmoid())
            if t == samp_idx - 2 and self.vj_pen:
                # Keep the penultimate excitation for the vector-Jacobian penalty below
                penultimate = exc_1.clone()
        # ros = self.rnn_embed(exc_2.flatten(2).permute(2, 0, 1).unsqueeze(0).transpose(1, 2)).sigmoid()

        if self.vj_pen:
            norm_1_vect = torch.ones_like(exc_1)
            norm_1_vect.requires_grad = False
            vj_prod = torch.autograd.grad(exc_1,
                                          penultimate,
                                          grad_outputs=[norm_1_vect],
                                          retain_graph=True,
                                          create_graph=True,
                                          allow_unused=True)[0]
            vj_penalty = (vj_prod -
                          0.95).clamp(0)**2  # Squared to emphasize outliers
            vj_penalty = vj_penalty.sum()  # Sum into a scalar penalty

        # from matplotlib import pyplot as plt
        # im=0;ti=0;plt.subplot(141);plt.imshow(proc_labels[im].squeeze().cpu());plt.subplot(142);plt.imshow((rnn_src_search[im, :, ti] ** 2).mean(0).detach().cpu());plt.subplot(143);plt.imshow((exc[im] ** 2).squeeze().mean(0).detach().cpu());plt.subplot(144);plt.imshow(search[im, ti].squeeze().permute(1, 2, 0).cpu());plt.show()

        # Reshape pos_search
        pos_search_shape = [int(x) for x in pos_search[-1].shape]
        pos_search = pos_search[-1].view(search_shape[:2] +
                                         pos_search_shape[1:])

        # Split off a pair of TransT features and then use the circuit to gate the Qs in the decoder.
        src_search = self.input_proj(src_search[:, -1])  # samp_idx
        src_template = self.input_proj(src_template)
        mask_search = mask_search[:, -1]  # samp_idx]
        pos_search = pos_search[:, -1]  # samp_idx]
        dec_rnn = self.rnn_decode_1(self.bn(torch.stack(excs, 2)))
        proc_rnn = self.rnn_decode_2(
            self.nl(dec_rnn))  # * torch.sigmoid(self.rnn_gate(dec_rnn))
        hs, _ = self.featurefusion_network(src_template,
                                           mask_template,
                                           src_search,
                                           mask_search,
                                           pos_template[-1],
                                           pos_search,
                                           exc=proc_rnn[:, :, -1])

        # Concat exc to hs too
        # hs = torch.cat([hs, self.rnn_embed(exc)], -1)
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        # exc_bbox = self.exc_bbox(exc).sigmoid()
        # coords = torch.stack(torch.meshgrid(torch.arange(excs.shape[-2]), torch.arange(excs.shape[-1])), 0)
        # coords = coords / coords.max()
        # coords = coords.float().to(excs.device)[None, :, None].repeat(excs.shape[0], 1, excs.shape[2], 1, 1)
        # excs = torch.cat([coords, excs], 1)
        # exc = exc.view(exc.shape[1], self.height, self.height, -1).permute(0, 3, 1, 2)
        # excs = self.exc_bbox_2(F.relu(self.exc_bbox_1(excs)))
        # exc_bbox = self.exc_bbox_3(excs.flatten(3).squeeze(1)).sigmoid()
        proc_rnn = self.exc_bbox_3(
            self.exc_bbox_1(proc_rnn).view(proc_rnn.shape[0],
                                           proc_rnn.shape[2], -1)).sigmoid()
        if self.vj_pen:
            out = {
                'pred_logits': outputs_class[-1],
                'pred_boxes': outputs_coord[-1],
                'hgru_boxes': proc_rnn,
                "vj_penalty": vj_penalty
            }
        else:
            out = {
                'pred_logits': outputs_class[-1],
                'pred_boxes': outputs_coord[-1],
                'hgru_boxes': proc_rnn
            }
        return out
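
The vj_pen branch penalizes the vector-Jacobian product of the final recurrent state with respect to an earlier one, pushing the recurrence toward a contraction. A minimal, self-contained sketch of the same torch.autograd.grad pattern on a toy recurrent step; the 0.95 threshold follows the code above, and the linear cell is only illustrative.

import torch

cell = torch.nn.Linear(8, 8)
penultimate = torch.rand(2, 8, requires_grad=True)   # earlier hidden state
final = torch.tanh(cell(penultimate))                 # final hidden state

ones = torch.ones_like(final)
ones.requires_grad = False
vj_prod = torch.autograd.grad(final, penultimate,
                              grad_outputs=[ones],
                              retain_graph=True,
                              create_graph=True,
                              allow_unused=True)[0]
vj_penalty = ((vj_prod - 0.95).clamp(0) ** 2).sum()   # penalize expansion beyond 0.95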