def finetune_detector(self, box_roi_pool, fpn_features, gt_box, bbox_pred_decoder, image,
                      finetuning_config, box_head, box_predictor, plot=False):
    """Fine-tune the Fast R-CNN box predictor head on boxes sampled around this track.

    Only ``box_predictor`` is optimized (its parameters alone go into the Adam
    optimizer); ``box_head`` is run in train mode but never updated. Training
    proposals are jittered copies of the track's box (``generate_training_set``)
    and the regression target is the ground-truth box, compared with SmoothL1.

    Args:
        box_roi_pool: RoI pooling module mapping (features, proposals, im_info)
            to per-proposal feature maps.
        fpn_features: FPN feature maps (dict) or a single feature tensor, which
            is wrapped into an OrderedDict as level 0.
        gt_box: ground-truth box for this track, shape (4,), in image coords.
        bbox_pred_decoder: decodes raw regression deltas + proposals into boxes.
        image: source image, only used for plotting.
        finetuning_config: dict of hyper-parameters (learning_rate, iterations,
            batch_size, batch_size_val, max_displacement, validate,
            validation_over_time, checkpoint_interval,
            iterations_per_validation).
        box_head: feature head feeding the predictor (not optimized).
        box_predictor: FastRCNNPredictor to be fine-tuned in place.
        plot: forward flag for generate_training_set plotting.

    Side effects: stores head/predictor on ``self``, fills ``self.checkpoints``
    with [box_head, predictor-snapshot] entries, may create ``self.plotter``,
    and leaves both modules in eval mode on return.
    """
    self.box_head = box_head
    self.box_predictor = box_predictor
    self.box_predictor.train()
    self.box_head.train()
    # Only the predictor's parameters are optimized; box_head stays frozen.
    optimizer = torch.optim.Adam(
        list(self.box_predictor.parameters()),
        lr=float(finetuning_config["learning_rate"]))
    criterion = torch.nn.SmoothL1Loss()

    # Normalize a bare tensor into the dict-of-levels form RoI pooling expects.
    if isinstance(fpn_features, torch.Tensor):
        fpn_features = OrderedDict([(0, fpn_features)])

    if finetuning_config["validate"]:
        if not self.plotter:
            self.plotter = VisdomLinePlotter()
        # A fixed validation set, sampled once; its RoI features are reused
        # every validation step below.
        validation_boxes = self.generate_training_set(
            float(finetuning_config["max_displacement"]),
            batch_size=int(finetuning_config["batch_size_val"]),
            plot=plot,
            plot_args=(image, "val", self.id)).to(device)
        validation_boxes_resized = resize_boxes(
            validation_boxes, self.im_info, self.transformed_image_size[0])
        proposals_val = [validation_boxes_resized]
        roi_pool_feat_val = box_roi_pool(fpn_features, proposals_val, self.im_info)
        # Fix: the plotter is only used when validating; previously it was
        # created unconditionally, opening an unused Visdom connection.
        plotter = VisdomLinePlotter()

    # Seed checkpoint 0 with a snapshot of the predictor's initial weights.
    # load_state_dict copies the tensors, so later training does not mutate it.
    save_state_box_predictor = FastRCNNPredictor(1024, 2).to(device)
    save_state_box_predictor.load_state_dict(self.box_predictor.state_dict())
    self.checkpoints[0] = [box_head, save_state_box_predictor]

    # Hoisted out of the loop: the scaled ground-truth box is loop-invariant.
    scaled_gt_box = resize_boxes(
        gt_box.unsqueeze(0), self.im_info,
        self.transformed_image_size[0]).squeeze(0)

    for i in range(int(finetuning_config["iterations"])):
        if finetuning_config["validation_over_time"]:
            if not self.plotter:
                self.plotter = VisdomLinePlotter()
                print("Making Plotter")
        # NOTE(review): checkpointing is taken to be independent of
        # validation_over_time (the pre-loop checkpoint above is
        # unconditional) — confirm against the original formatting.
        if np.mod(i + 1, finetuning_config["checkpoint_interval"]) == 0:
            self.box_predictor.eval()
            save_state_box_predictor = FastRCNNPredictor(1024, 2).to(device)
            save_state_box_predictor.load_state_dict(
                self.box_predictor.state_dict())
            self.checkpoints[i + 1] = [box_head, save_state_box_predictor]
            self.box_predictor.train()

        optimizer.zero_grad()
        # Fresh jittered training proposals every iteration.
        training_boxes = self.generate_training_set(
            float(finetuning_config["max_displacement"]),
            batch_size=int(finetuning_config["batch_size"]),
            plot=plot,
            plot_args=(image, i, self.id)).to(device)
        boxes = resize_boxes(
            training_boxes, self.im_info, self.transformed_image_size[0])
        proposals = [boxes]
        # RoI pooling itself needs no gradients — only the head/predictor do.
        with torch.no_grad():
            roi_pool_feat = box_roi_pool(fpn_features, proposals, self.im_info)

        # Feed pooled features through the (frozen) head, then predict offsets.
        pooled_feat = self.box_head(roi_pool_feat)
        _, bbox_pred = self.box_predictor(pooled_feat)
        pred_boxes = bbox_pred_decoder(bbox_pred, proposals)
        # Keep only the foreground-class box (column 1 of the per-class boxes).
        pred_boxes = pred_boxes[:, 1:].squeeze(dim=1)

        if np.mod(i, int(finetuning_config["iterations_per_validation"])) == 0 \
                and finetuning_config["validate"]:
            # Fix: the validation loss is never back-propagated, so run the
            # whole validation forward pass without building an autograd graph.
            with torch.no_grad():
                pooled_feat_val = self.box_head(roi_pool_feat_val)
                _, bbox_pred_val = self.box_predictor(pooled_feat_val)
                pred_boxes_val = bbox_pred_decoder(bbox_pred_val, proposals_val)
                pred_boxes_val = pred_boxes_val[:, 1:].squeeze(dim=1)
                val_loss = criterion(
                    pred_boxes_val,
                    scaled_gt_box.repeat(
                        int(finetuning_config["batch_size_val"]), 1))
            plotter.plot('loss', 'val', "Bbox Loss Track {}".format(self.id),
                         i, val_loss.item())

        loss = criterion(
            pred_boxes,
            scaled_gt_box.repeat(int(finetuning_config["batch_size"]), 1))
        print('Finished iteration {} --- Loss {}'.format(i, loss.item()))
        loss.backward()
        optimizer.step()

    self.box_predictor.eval()
    self.box_head.eval()
def get_box_predictor(self):
    """Rebuild this track's fine-tuned predictor head.

    Constructs a fresh ``FastRCNNPredictor(1024, 2)`` on ``device`` and loads
    the weights previously stashed in ``self.bbox_predictor_weights``.

    Returns:
        The restored FastRCNNPredictor instance.
    """
    restored = FastRCNNPredictor(1024, 2)
    restored = restored.to(device)
    restored.load_state_dict(self.bbox_predictor_weights)
    return restored