def forward_pass_teacher(self, net, data, run_box_head, run_cls_head):
    feat_dict_list = []
    # process the templates
    for i in range(self.settings.num_template):
        template_img_i = data['template_images'][i].view(
            -1, *data['template_images'].shape[2:])  # (batch, 3, 128, 128)
        template_att_i = data['template_att'][i].view(
            -1, *data['template_att'].shape[2:])  # (batch, 128, 128)
        feat_dict_list.append(
            net(img=template_img_i, mask=template_att_i,
                mode='backbone', zx="template%d" % i))

    # process the search regions (t-th frame)
    search_img = data['search_images'].view(
        -1, *data['search_images'].shape[2:])  # (batch, 3, 320, 320)
    search_att = data['search_att'].view(
        -1, *data['search_att'].shape[2:])  # (batch, 320, 320)
    feat_dict_list.append(
        net(img=search_img, mask=search_att, mode='backbone', zx="search"))

    # run the transformer and compute losses
    seq_dict = merge_template_search(feat_dict_list)
    out_dict, _, _ = net(seq_dict=seq_dict, mode="transformer",
                         run_box_head=run_box_head, run_cls_head=run_cls_head)
    # out_dict: (B, N, C), outputs_coord: (1, B, N, C), target_query: (1, B, N, C)
    return out_dict
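# NOTE (added for context): merge_template_search is imported from the STARK
# utilities and is not defined in this file. The function below is only a
# minimal sketch of its expected behaviour, assuming each backbone output dict
# holds "feat" (HW, B, C), "mask" (B, HW) and "pos" (HW, B, C) tensors; the
# real helper may differ in details.
import torch

def merge_template_search_sketch(feat_dict_list):
    # Concatenate template and search features along the token (sequence)
    # dimension so the transformer attends over all of them jointly. The mask
    # is batch-first, so it is concatenated along dim=1 instead.
    return {
        "feat": torch.cat([d["feat"] for d in feat_dict_list], dim=0),  # (sum HW, B, C)
        "mask": torch.cat([d["mask"] for d in feat_dict_list], dim=1),  # (B, sum HW)
        "pos": torch.cat([d["pos"] for d in feat_dict_list], dim=0),    # (sum HW, B, C)
    }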
def track(self, image, info: dict = None):
    H, W, _ = image.shape
    self.frame_id += 1
    x_patch_arr, resize_factor, x_amask_arr = sample_target(
        image, self.state, self.params.search_factor,
        output_sz=self.params.search_size)  # (x1, y1, w, h)
    search = self.preprocessor.process(x_patch_arr, x_amask_arr)
    with torch.no_grad():
        x_dict = self.network.forward_backbone(search)
        # merge the template and the search
        feat_dict_list = [self.z_dict1, x_dict]
        seq_dict = merge_template_search(feat_dict_list)
        # run the transformer
        out_dict, _, _ = self.network.forward_transformer(
            seq_dict=seq_dict, run_box_head=True)
    pred_boxes = out_dict['pred_boxes'].view(-1, 4)
    # Baseline: take the mean of all pred boxes as the final result
    pred_box = (pred_boxes.mean(dim=0) * self.params.search_size /
                resize_factor).tolist()  # (cx, cy, w, h) [0,1]
    # get the final box result
    self.state = clip_box(self.map_box_back(pred_box, resize_factor),
                          H, W, margin=10)
    # Clipping helps to improve robustness. Experiments show that it doesn't influence performance.
    # self.state = self.map_box_back(pred_box, resize_factor)

    # for debug
    if self.debug:
        x1, y1, w, h = self.state
        image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.rectangle(image_BGR, (int(x1), int(y1)),
                      (int(x1 + w), int(y1 + h)),
                      color=(0, 0, 255), thickness=2)
        save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
        cv2.imwrite(save_path, image_BGR)

    if self.save_all_boxes:
        '''save all 10 predictions'''
        all_boxes = self.map_box_back_batch(
            pred_boxes * self.params.search_size / resize_factor, resize_factor)
        all_boxes_save = all_boxes.view(-1).tolist()  # (4N, )
        return {"target_bbox": self.state,
                "all_boxes": all_boxes_save}
    else:
        return {"target_bbox": self.state}
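# NOTE (added for context): map_box_back converts the averaged (cx, cy, w, h)
# prediction from search-region coordinates back to full-image coordinates.
# This is a minimal sketch only, assuming the search crop is centred on the
# previous target state and scaled by `resize_factor`; the tracker's own
# method may differ in details.
def map_box_back_sketch(pred_box, prev_state, search_size, resize_factor):
    # centre of the previous box == centre of the search crop in image coords
    cx_prev = prev_state[0] + 0.5 * prev_state[2]
    cy_prev = prev_state[1] + 0.5 * prev_state[3]
    cx, cy, w, h = pred_box
    half_side = 0.5 * search_size / resize_factor  # half crop side in image scale
    # shift the predicted centre from crop coordinates to image coordinates
    cx_real = cx + (cx_prev - half_side)
    cy_real = cy + (cy_prev - half_side)
    return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]  # (x1, y1, w, h)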
'''import stark network module'''
model_module = importlib.import_module('lib.models.stark')
if args.script == "stark_s":
    model_constructor = model_module.build_starks
    model = model_constructor(cfg)
    # get the template and search
    template = get_data(bs, z_sz)
    search = get_data(bs, x_sz)
    # transfer to device
    model = model.to(device)
    template = template.to(device)
    search = search.to(device)
    # forward template and search
    oup_t = model.forward_backbone(template)
    oup_s = model.forward_backbone(search)
    seq_dict = merge_template_search([oup_t, oup_s])
    # evaluate the model properties
    evaluate(model, search, seq_dict, run_box_head=True, run_cls_head=False)
elif args.script == "stark_st2":
    model_constructor = model_module.build_starkst
    model = model_constructor(cfg)
    # get the templates and search
    template1 = get_data(bs, z_sz)
    template2 = get_data(bs, z_sz)
    search = get_data(bs, x_sz)
    # transfer to device
    model = model.to(device)
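# NOTE (added for context): get_data builds dummy inputs for this profiling
# script. The sketch below is an assumption: it wraps a random image patch and
# an all-valid attention mask in a NestedTensor, which is what forward_backbone
# consumes above. The import path and exact mask convention are assumptions
# and may differ from the bundled helper.
import torch
from lib.utils.misc import NestedTensor  # assumed location of NestedTensor

def get_data_sketch(bs, sz):
    img_patch = torch.randn(bs, 3, sz, sz)                # random image patch
    att_mask = torch.zeros(bs, sz, sz, dtype=torch.bool)  # False = valid pixel
    return NestedTensor(img_patch, att_mask)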
def track(self, image, info: dict = None):
    H, W, _ = image.shape
    self.frame_id += 1
    # get the t-th search region
    x_patch_arr, resize_factor, x_amask_arr = sample_target(
        image, self.state, self.params.search_factor,
        output_sz=self.params.search_size)  # (x1, y1, w, h)
    search = self.preprocessor.process(x_patch_arr, x_amask_arr)
    with torch.no_grad():
        x_dict = self.network.forward_backbone(search)
        # merge the template and the search
        feat_dict_list = self.z_dict_list + [x_dict]
        seq_dict = merge_template_search(feat_dict_list)
        # run the transformer
        out_dict, _, _ = self.network.forward_transformer(
            seq_dict=seq_dict, run_box_head=True, run_cls_head=True)
    # get the final result
    pred_boxes = out_dict['pred_boxes'].view(-1, 4)
    # Baseline: take the mean of all pred boxes as the final result
    pred_box = (pred_boxes.mean(dim=0) * self.params.search_size /
                resize_factor).tolist()  # (cx, cy, w, h) [0,1]
    # get the final box result
    self.state = clip_box(self.map_box_back(pred_box, resize_factor),
                          H, W, margin=10)
    # Clipping helps to improve robustness. Experiments show that it doesn't influence performance.
    # self.state = self.map_box_back(pred_box, resize_factor)

    # get confidence score (whether the search region is reliable)
    conf_score = out_dict["pred_logits"].view(-1).sigmoid().item()
    # update the templates
    for idx, update_i in enumerate(self.update_intervals):
        if self.frame_id % update_i == 0 and conf_score > 0.5:
            z_patch_arr, _, z_amask_arr = sample_target(
                image, self.state, self.params.template_factor,
                output_sz=self.params.template_size)  # (x1, y1, w, h)
            template_t = self.preprocessor.process(z_patch_arr, z_amask_arr)
            with torch.no_grad():
                z_dict_t = self.network.forward_backbone(template_t)
            # the 1st element of z_dict_list is the template from the 1st frame
            self.z_dict_list[idx + 1] = z_dict_t

    # for debug
    if self.debug:
        x1, y1, w, h = self.state
        image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.rectangle(image_BGR, (int(x1), int(y1)),
                      (int(x1 + w), int(y1 + h)),
                      color=(0, 0, 255), thickness=2)
        save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
        cv2.imwrite(save_path, image_BGR)

    if self.save_all_boxes:
        '''save all 10 predictions'''
        all_boxes = self.map_box_back_batch(
            pred_boxes * self.params.search_size / resize_factor, resize_factor)
        all_boxes_save = all_boxes.view(-1).tolist()  # (4N, )
        return {"target_bbox": self.state,
                "all_boxes": all_boxes_save,
                "conf_score": conf_score}
    else:
        return {"target_bbox": self.state, "conf_score": conf_score}
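# NOTE (added for context): clip_box keeps the estimated box inside the image
# with a small margin so that degenerate (zero-size or fully out-of-frame)
# boxes are avoided. A minimal sketch is given below, assuming (x1, y1, w, h)
# boxes; the bundled utility may handle edge cases differently.
def clip_box_sketch(box, H, W, margin=0):
    x1, y1, w, h = box
    x2, y2 = x1 + w, y1 + h
    # clamp the corners so the box stays inside the frame and keeps at least
    # `margin` pixels of width and height
    x1 = min(max(0, x1), W - margin)
    x2 = min(max(margin, x2), W)
    y1 = min(max(0, y1), H - margin)
    y2 = min(max(margin, y2), H)
    w = max(margin, x2 - x1)
    h = max(margin, y2 - y1)
    return [x1, y1, w, h]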