    def finetune_detector(self,
                          box_roi_pool,
                          fpn_features,
                          gt_box,
                          bbox_pred_decoder,
                          image,
                          finetuning_config,
                          box_head,
                          box_predictor,
                          plot=False):
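        """Fine-tune the box regression head on boxes sampled around this
        track's ground-truth box.

        Only ``box_predictor`` receives gradient updates; ``box_head`` is run
        under ``torch.no_grad()`` and only supplies pooled features.
        """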
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.box_predictor.train()
        self.box_head.train()
        optimizer = torch.optim.Adam(list(self.box_predictor.parameters()),
                                     lr=float(
                                         finetuning_config["learning_rate"]))
        criterion = torch.nn.SmoothL1Loss()

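        # box_roi_pool expects a dict of feature maps, so wrap a bare tensor.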
        if isinstance(fpn_features, torch.Tensor):
            fpn_features = OrderedDict([(0, fpn_features)])

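        # Optionally build a fixed set of displaced boxes to monitor the
        # regression loss on held-out samples during fine-tuning.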
        if finetuning_config["validate"]:
            if not self.plotter:
                self.plotter = VisdomLinePlotter()
            validation_boxes = self.generate_training_set(
                float(finetuning_config["max_displacement"]),
                batch_size=int(finetuning_config["batch_size_val"]),
                plot=plot,
                plot_args=(image, "val", self.id)).to(device)
            validation_boxes_resized = resize_boxes(
                validation_boxes, self.im_info, self.transformed_image_size[0])
            proposals_val = [validation_boxes_resized]
            roi_pool_feat_val = box_roi_pool(fpn_features, proposals_val,
                                             self.im_info)
            plotter = VisdomLinePlotter()

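        # Snapshot the predictor weights before any update as checkpoint 0
        # (the box head is stored by reference, not copied).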
        save_state_box_predictor = FastRCNNPredictor(1024, 2).to(device)
        save_state_box_predictor.load_state_dict(
            self.box_predictor.state_dict())

        self.checkpoints[0] = [box_head, save_state_box_predictor]

        for i in range(int(finetuning_config["iterations"])):

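            # Periodically snapshot the predictor so its behaviour over the
            # course of fine-tuning can be evaluated later.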
            if finetuning_config["validation_over_time"]:
                if not self.plotter:
                    self.plotter = VisdomLinePlotter()
                    print("Making Plotter")
                if np.mod(i + 1,
                          int(finetuning_config["checkpoint_interval"])) == 0:
                    self.box_predictor.eval()
                    save_state_box_predictor = FastRCNNPredictor(1024,
                                                                 2).to(device)
                    save_state_box_predictor.load_state_dict(
                        self.box_predictor.state_dict())
                    self.checkpoints[i + 1] = [box_head,
                                               save_state_box_predictor]
                    self.box_predictor.train()

            optimizer.zero_grad()
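            # Sample a fresh batch of boxes jittered around the ground-truth
            # box by at most max_displacement.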
            training_boxes = self.generate_training_set(
                float(finetuning_config["max_displacement"]),
                batch_size=int(finetuning_config["batch_size"]),
                plot=plot,
                plot_args=(image, i, self.id)).to(device)

            boxes = resize_boxes(training_boxes, self.im_info,
                                 self.transformed_image_size[0])
            scaled_gt_box = resize_boxes(
                gt_box.unsqueeze(0), self.im_info,
                self.transformed_image_size[0]).squeeze(0)
            proposals = [boxes]

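            # Pool features and run the box head without gradients; only the
            # predictor below is being optimized.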
            with torch.no_grad():
                roi_pool_feat = box_roi_pool(fpn_features, proposals,
                                             self.im_info)
                # feed pooled features to top model
                pooled_feat = self.box_head(roi_pool_feat)

            # compute bbox offset
            _, bbox_pred = self.box_predictor(pooled_feat)

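            # Decode the regression offsets into boxes and drop the
            # background (class 0) predictions.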
            pred_boxes = bbox_pred_decoder(bbox_pred, proposals)
            pred_boxes = pred_boxes[:, 1:].squeeze(dim=1)

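            # Evaluate the current predictor on the fixed validation boxes
            # every iterations_per_validation steps and plot the loss.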
            if np.mod(i, int(finetuning_config["iterations_per_validation"])
                      ) == 0 and finetuning_config["validate"]:
                pooled_feat_val = self.box_head(roi_pool_feat_val)
                _, bbox_pred_val = self.box_predictor(pooled_feat_val)
                pred_boxes_val = bbox_pred_decoder(bbox_pred_val,
                                                   proposals_val)
                pred_boxes_val = pred_boxes_val[:, 1:].squeeze(dim=1)
                val_loss = criterion(
                    pred_boxes_val,
                    scaled_gt_box.repeat(
                        int(finetuning_config["batch_size_val"]), 1))
                plotter.plot('loss', 'val',
                             "Bbox Loss Track {}".format(self.id), i,
                             val_loss.item())

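            # Regress every sampled box back onto the (resized) ground-truth
            # box.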
            loss = criterion(
                pred_boxes,
                scaled_gt_box.repeat(int(finetuning_config["batch_size"]), 1))
            print('Finished iteration {} --- Loss {}'.format(i, loss.item()))

            loss.backward()
            optimizer.step()

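        # Switch the fine-tuned heads back to inference mode.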
        self.box_predictor.eval()
        self.box_head.eval()

    def get_box_predictor(self):
        box_predictor = FastRCNNPredictor(1024, 2).to(device)
        box_predictor.load_state_dict(self.bbox_predictor_weights)
        return box_predictor