def forward(self, outputs, targets):
    """Compute the requested losses for the final and auxiliary outputs.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. An optional 'aux_outputs' entry holds one output
            dict per intermediate layer.
        targets: list of dicts, such that len(targets) == batch_size. Each
            dict must provide 'labels'; other keys depend on the losses
            applied, see each loss' doc.

    Returns:
        dict of loss values. Entries computed on auxiliary output ``i`` carry
        the suffix ``_{i}``.
    """
    outputs_without_aux = {
        k: v for k, v in outputs.items() if k != 'aux_outputs'
    }
    # Match last-layer predictions to the targets.
    indices = self.matcher(outputs_without_aux, targets)

    # Number of target boxes, averaged over all processes, used for
    # normalisation.
    num_boxes = sum(len(t['labels']) for t in targets)
    # paddle.to_tensor takes a Paddle dtype and a `place`; the torch-style
    # `dtype=torch.float` / `device=` arguments are not part of its signature.
    num_boxes = paddle.to_tensor(
        [num_boxes],
        dtype='float32',
        place=next(iter(outputs.values())).place)
    if is_dist_avail_and_initialized():
        torch2paddle.all_reduce(num_boxes)
    num_boxes = paddle.clip(num_boxes / get_world_size(), min=1).item()

    losses = {}
    for loss in self.losses:
        losses.update(
            self.get_loss(loss, outputs, targets, indices, num_boxes))

    if 'aux_outputs' in outputs:
        for i, aux_outputs in enumerate(outputs['aux_outputs']):
            indices = self.matcher(aux_outputs, targets)
            for loss in self.losses:
                if loss == 'masks':
                    # The 'masks' loss is skipped for auxiliary outputs.
                    continue
                # 'labels' is called with log=False on auxiliary outputs.
                kwargs = {'log': False} if loss == 'labels' else {}
                l_dict = self.get_loss(loss, aux_outputs, targets, indices,
                                       num_boxes, **kwargs)
                losses.update({k + f'_{i}': v for k, v in l_dict.items()})
    return losses
def forward(self, outputs, targets):
    """Run every configured loss on the final and auxiliary outputs.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. May contain an 'aux_outputs' list.
        targets: list of dicts, such that len(targets) == batch_size. Each
            dict must hold a 'valid' entry supporting ``.item()``; other keys
            depend on the losses applied.

    Returns:
        dict of loss values; entries from auxiliary output ``i`` are suffixed
        with ``_{i}``.
    """
    # Normaliser: sum of the per-target 'valid' flags, averaged across
    # processes and clamped to at least 1.
    # TODO: reserved for negative-sample training to improve robustness
    # (as in DaSiamRPN).
    valid_total = sum(t['valid'].item() for t in targets)
    device = next(iter(outputs.values())).device
    num_boxes = torch.as_tensor([valid_total], dtype=torch.float, device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = (num_boxes / get_world_size()).clamp(min=1).item()

    # Each entry of self.losses is invoked directly as a callable.
    losses = {}
    for criterion in self.losses:
        losses.update(criterion(outputs, targets, num_boxes))

    if 'aux_outputs' not in outputs:
        return losses

    # Repeat the same losses on every intermediate-layer output.
    for i, aux_outputs in enumerate(outputs['aux_outputs']):
        suffix = f'_{i}'
        for criterion in self.losses:
            partial = criterion(aux_outputs, targets, num_boxes)
            for key, value in partial.items():
                losses[key + suffix] = value
    return losses
def forward(self, outputs, targets):
    """Compute losses on the final, auxiliary and encoder outputs.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. Optional entries: 'aux_outputs' (list of output
            dicts) and 'enc_outputs' (a single output dict).
        targets: list of dicts, such that len(targets) == batch_size. Each
            dict must provide 'labels'; other keys depend on the losses
            applied, see each loss' doc.

    Returns:
        dict of loss values. Auxiliary entries are suffixed with ``_{i}``,
        encoder entries with ``_enc``.
    """
    final_outputs = {
        name: value
        for name, value in outputs.items()
        if name not in ('aux_outputs', 'enc_outputs')
    }
    # Matching between last-layer predictions and the targets.
    indices = self.matcher(final_outputs, targets)

    # Average number of target boxes across processes, clamped to >= 1.
    box_total = sum(len(t["labels"]) for t in targets)
    device = next(iter(outputs.values())).device
    num_boxes = torch.as_tensor([box_total], dtype=torch.float, device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = (num_boxes / get_world_size()).clamp(min=1).item()

    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, indices, num_boxes))

    def side_losses(side_outputs, side_targets, suffix):
        # Re-match, then compute every loss except 'masks' (skipped as too
        # costly for intermediate outputs); 'labels' runs with log=False.
        matched = self.matcher(side_outputs, side_targets)
        collected = {}
        for name in self.losses:
            if name == 'masks':
                continue
            extra = {'log': False} if name == 'labels' else {}
            partial = self.get_loss(name, side_outputs, side_targets,
                                    matched, num_boxes, **extra)
            for key, value in partial.items():
                collected[key + suffix] = value
        return collected

    if 'aux_outputs' in outputs:
        for i, aux_outputs in enumerate(outputs['aux_outputs']):
            losses.update(side_losses(aux_outputs, targets, f'_{i}'))

    if 'enc_outputs' in outputs:
        # Encoder outputs are scored against a deep copy of the targets whose
        # labels are all replaced by zeros.
        bin_targets = copy.deepcopy(targets)
        for bt in bin_targets:
            bt['labels'] = torch.zeros_like(bt['labels'])
        losses.update(side_losses(outputs['enc_outputs'], bin_targets, '_enc'))

    return losses
def get_num_boxes(self, num_samples):
    """Return the sample count averaged over processes, never below 1.

    Parameters:
        num_samples: value converted to a float tensor on
            ``self.sample_device``.

    Returns:
        Python number obtained via ``.item()`` after the optional all-reduce,
        division by the world size and clamping to a minimum of 1.
    """
    count = torch.as_tensor(num_samples,
                            dtype=torch.float,
                            device=self.sample_device)
    if is_dist_avail_and_initialized():
        # Sum the counts of all processes in place.
        torch.distributed.all_reduce(count)
    per_process = count / get_world_size()
    return per_process.clamp(min=1).item()
def forward(self, outputs):
    """Compute the losses for one meta-task batch.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. The targets are read from
            ``outputs['meta_targets']`` rather than passed separately;
            'aux_outputs' is optional.

    Returns:
        dict of loss values; entries from auxiliary output ``i`` are suffixed
        with ``_{i}``.
    """
    # Meta-learning runs over constructed meta-tasks, so their targets travel
    # inside the outputs instead of the original targets.
    targets = outputs['meta_targets']

    excluded = ('aux_outputs', 'enc_outputs')
    final_outputs = {
        name: value for name, value in outputs.items() if name not in excluded
    }
    # Matching between last-layer predictions and the targets.
    indices = self.matcher(final_outputs, targets)

    # Average number of target boxes across processes, clamped to >= 1.
    box_total = sum(len(t["labels"]) for t in targets)
    device = next(iter(outputs.values())).device
    num_boxes = torch.as_tensor([box_total], dtype=torch.float, device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = (num_boxes / get_world_size()).clamp(min=1).item()

    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, indices, num_boxes))

    if 'aux_outputs' not in outputs:
        return losses

    for i, aux_outputs in enumerate(outputs['aux_outputs']):
        indices = self.matcher(aux_outputs, targets)
        suffix = f'_{i}'
        for name in self.losses:
            if name == 'category_codes_cls':
                # The meta-attention classification loss is not applied to
                # auxiliary outputs.
                continue
            # Logging is enabled only for the last layer.
            extra = {'log': False} if name == 'labels' else {}
            partial = self.get_loss(name, aux_outputs, targets, indices,
                                    num_boxes, **extra)
            for key, value in partial.items():
                losses[key + suffix] = value
    return losses
def forward(self, outputs, targets):
    """Loss computation.

    Args:
        outputs (dict): Dict of RTD outputs, which are tensors. May contain
            an 'aux_outputs' list of per-layer output dicts.
        targets (list): List of dicts, such that len(targets) == batch_size.
            Each dict must provide 'labels'; other keys depend on the losses
            applied.

    Returns:
        losses (dict): Dict of losses. Auxiliary entries are suffixed with
            ``_{i}``.
    """
    final_outputs = {
        name: value for name, value in outputs.items()
        if name != 'aux_outputs'
    }
    # Matching between last-layer predictions and the targets.
    indices = self.matcher(final_outputs, targets)

    # Average number of target boxes across processes, clamped to >= 1.
    box_total = sum(len(t['labels']) for t in targets)
    device = next(iter(outputs.values())).device
    num_boxes = torch.as_tensor([box_total], dtype=torch.float, device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = (num_boxes / get_world_size()).clamp(min=1).item()

    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, indices, num_boxes))

    # Auxiliary losses are skipped entirely when 'iou' is among the
    # configured losses.
    if 'aux_outputs' not in outputs or 'iou' in self.losses:
        return losses

    for i, aux_outputs in enumerate(outputs['aux_outputs']):
        indices = self.matcher(aux_outputs, targets)
        suffix = f'_{i}'
        for name in self.losses:
            if name == 'masks':
                # Intermediate mask losses are too costly to compute.
                continue
            # Logging is enabled only for the last layer.
            extra = {'log': False} if name == 'labels' else {}
            partial = self.get_loss(name, aux_outputs, targets, indices,
                                    num_boxes, **extra)
            for key, value in partial.items():
                losses[key + suffix] = value
    return losses
def forward(self, outputs, targets, origin_indices=None):
    """Compute the requested losses for the final and auxiliary outputs.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. May contain an 'aux_outputs' list.
        targets: list of dicts, such that len(targets) == batch_size. Each
            dict must provide 'labels'; other keys depend on the losses
            applied, see each loss' doc.
        origin_indices: accepted for interface compatibility; the value
            passed in is replaced by the matcher result before it is used.

    Returns:
        dict of loss values; entries from auxiliary output ``i`` are suffixed
        with ``_{i}``.
    """
    final_outputs = {
        name: value for name, value in outputs.items()
        if name != 'aux_outputs'
    }
    origin_indices = self.matcher(final_outputs, targets)

    # Average number of target items across processes, clamped to >= 1.
    item_total = sum(len(t["labels"]) for t in targets)
    device = next(iter(outputs.values())).device
    num_items = torch.as_tensor([item_total], dtype=torch.float, device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_items)
    num_items = (num_items / get_world_size()).clamp(min=1).item()

    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, num_items,
                          origin_indices=origin_indices))

    aux_name = 'aux_outputs'
    if aux_name not in outputs:
        return losses

    # Same procedure on each intermediate-layer output, with a fresh match.
    for i, aux_outputs in enumerate(outputs[aux_name]):
        origin_indices = self.matcher(aux_outputs, targets)
        suffix = f'_{i}'
        for name in self.losses:
            # Logging is enabled only for the last layer.
            extra = {'log': False} if name == 'labels' else {}
            partial = self.get_loss(name, aux_outputs, targets, num_items,
                                    origin_indices=origin_indices, **extra)
            for key, value in partial.items():
                losses[key + suffix] = value
    return losses
def gather(self, input_list):
    """Concatenate local tensors, gather them from all ranks and filter.

    Parameters:
        input_list: list of tensors that can be concatenated with
            ``torch.cat``.

    Returns:
        The concatenation of the per-rank tensors with sentinel entries
        removed: for tensors with more than one dimension, rows whose
        last-dimension sum equals ``-1234 * shape[1]``; otherwise elements
        equal to ``-1234``. With a world size of 1 only the local tensor is
        filtered and no collective is issued.
    """

    def drop_sentinel(tensor):
        # -1234 presumably marks padding entries; verify against the code
        # that fills the memories.
        if len(tensor.shape) > 1:
            return tensor[tensor.sum(-1) != -1234 * tensor.shape[1]]
        return tensor[tensor != -1234]

    local = torch.cat(input_list)
    world = utils.get_world_size()
    if world == 1:
        return drop_sentinel(local)

    # all_gather receives into buffers shaped like the local tensor.
    buffers = [torch.zeros_like(local) for _ in range(world)]
    torch.cuda.synchronize()
    dist.all_gather(buffers, local)
    return drop_sentinel(torch.cat(buffers))
def forward(self, outputs, targets):
    """Compute the requested losses for the final and auxiliary outputs.

    Parameters:
        outputs: dict of tensors; may contain an 'aux_outputs' list of
            per-layer output dicts.
        targets: list of dicts, each providing 'obj_labels'; other keys
            depend on the losses applied.

    Returns:
        dict of loss values; entries from auxiliary output ``i`` are suffixed
        with ``_{i}``.
    """
    final_outputs = {
        name: value for name, value in outputs.items()
        if name != 'aux_outputs'
    }
    # Matching between last-layer predictions and the targets.
    indices = self.matcher(final_outputs, targets)

    # Normaliser: number of 'obj_labels' entries, averaged across processes
    # and clamped to >= 1.
    interaction_total = sum(len(t['obj_labels']) for t in targets)
    device = next(iter(outputs.values())).device
    num_interactions = torch.as_tensor([interaction_total],
                                       dtype=torch.float,
                                       device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_interactions)
    num_interactions = (num_interactions /
                        get_world_size()).clamp(min=1).item()

    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, indices, num_interactions))

    if 'aux_outputs' not in outputs:
        return losses

    # Same procedure on each intermediate-layer output, with a fresh match.
    for i, aux_outputs in enumerate(outputs['aux_outputs']):
        indices = self.matcher(aux_outputs, targets)
        suffix = f'_{i}'
        for name in self.losses:
            # Logging is enabled only for the last layer.
            extra = {'log': False} if name == 'obj_labels' else {}
            partial = self.get_loss(name, aux_outputs, targets, indices,
                                    num_interactions, **extra)
            for key, value in partial.items():
                losses[key + suffix] = value
    return losses
def clustering(self, image_path=None):
    """Cluster the buffered features on rank 0 and append new pseudo ground truth.

    Steps visible in this method:
      1. Sync pseudo GT, gather the four memories from all ranks and reset them.
      2. On rank 0 (while ``cls_weight.weight.sum() < len(cls_weight.weight)``),
         run the module-level ``clustering`` function, keep low-variance /
         high-objectness clusters, select feature pairs from different paths
         inside each cluster, and assign new labels.
      3. Broadcast the resulting ``data`` rows (6 columns: path, label, box)
         from rank 0, update ``cls_weight`` and append to ``self.pseudo_gt``.
      4. Rank 0 saves ``self.pseudo_gt`` to ``pseudo_gts/{step}.pth``.

    Parameters:
        image_path: if not None and new rows were found, passed to
            ``utils.save_boxes`` together with the selected data.

    NOTE(review): block nesting was reconstructed from whitespace-collapsed
    source; confirm against the original file.
    NOTE(review): if rank 0 does not satisfy the ``cls_weight`` condition it
    takes the ``else`` branch and ``data`` looks unassigned before
    ``data[:, 1]`` / ``dist.broadcast(data, 0)`` — confirm.
    """
    # sync data
    self.sync_pseudo_gt()
    feature = self.gather(self.feature_memory)
    obj_score = self.gather(self.obj_score_memory)
    paths = self.gather(self.path_memory)
    bbox = self.gather(self.bbox_memory)
    # The memories are emptied once their content has been gathered.
    self.feature_memory = []
    self.obj_score_memory = []
    self.path_memory = []
    self.bbox_memory = []
    if utils.get_rank() == 0 and self.cls_weight.weight.sum() < len(
            self.cls_weight.weight):
        # Calls the module-level `clustering` function, not this method.
        ids, centroid, var = clustering(feature,
                                        K=self.num_centroid,
                                        step=self.step,
                                        device=feature.device,
                                        tol=1e-3,
                                        Niter=150)
        count = torch.bincount(ids)
        # Per-cluster mean objectness; 1e-6 guards against empty clusters.
        mean_obj_score = torch.bincount(
            ids, weights=obj_score.to(ids.device)) / (count + 1e-6)
        # top 10 % dense clusters.
        # The bound is the k-th smallest variance with k = min(#clusters, 13);
        # the strict `<` below keeps clusters with variance under that bound.
        dist_topk_bound = -torch.topk(
            -var.view(-1), k=min(len(mean_obj_score), 13)).values[-1]
        mask = var < dist_topk_bound
        # number of found unknown classes
        cls_weight = sum(self.cls_weight.weight) - self.num_classes
        # high objectness clusters.
        # The threshold grows with the number of found classes, capped at 0.99.
        cluster_obj_thresh = min(
            self.cluster_obj_thresh *
            (1 + cls_weight / len(self.cls_weight.weight)), 0.99)
        obj_mask = mean_obj_score.to(mask.device) > cluster_obj_thresh
        mask = torch.logical_and(mask, obj_mask.to(mask.device))
        mask = mask.bool().view(-1)
        ids = ids.long().view(-1)
        # Keep only samples assigned to a selected cluster.
        paths = paths[mask[ids]]
        bbox = bbox[mask[ids]]
        feature = feature[mask[ids]]
        obj_score = obj_score[mask[ids]]
        ids = ids[mask[ids]]
        centroid = centroid[mask]
        if len(obj_score) > 0:
            obj_thresh = min(self.coupled_obj_thresh, max(obj_score))
        else:
            obj_thresh = self.coupled_obj_thresh
        # The per-sample threshold rises with the pseudo-GT count, capped at 0.99.
        obj_thresh = obj_thresh + (self.n_pseudo_gt * 0.01 / 100)
        obj_thresh = min(obj_thresh, 0.99)
        idx = obj_score >= obj_thresh
        bbox = bbox[idx]
        feature = feature[idx]
        paths = paths[idx]
        obj_score = obj_score[idx]
        ids = ids[idx]
        feats = []
        boxes = []
        ps = []
        obj_scores = []
        new_ids = []
        cls_weight = sum(self.cls_weight.weight) - self.num_classes
        # The cosine-distance threshold shrinks as more classes are found,
        # floored at 0.01.
        coupled_cos_thresh = self.coupled_cos_thresh * (
            1 - cls_weight / len(self.cls_weight.weight))
        coupled_cos_thresh = max(coupled_cos_thresh, 0.01)
        for i, l in enumerate(sorted(ids.unique())):
            idx = ids == l
            feat = feature[idx]
            bb = bbox[idx]
            path = paths[idx]
            obj = obj_score[idx]
            # Flattened pairwise similarity; flat index v maps to the pair
            # (v // len(feat), v % len(feat)).
            cos_sim = get_cos_sim(feat, feat).view(-1)
            cos_dist = 1 - cos_sim
            idx = cos_dist.argsort()
            used = []
            used_path = []
            # NOTE(review): `printer` is computed but not read again in this
            # method.
            printer = cos_sim[idx]
            printer = printer[printer < 0.99999]
            # eliminate same element pairs
            # Walk pairs from closest to farthest; stop at the first pair
            # beyond the threshold. A pair is taken only if its two samples
            # have different paths and neither path has been used yet.
            for v in idx:
                x, y = v // len(feat), v % len(feat)
                if cos_dist[v] > coupled_cos_thresh:
                    break
                if path[x] != path[y] and path[
                        x] not in used_path and path[y] not in used_path:
                    used.append(x)
                    used.append(y)
                    used_path.append(path[x])
                    used_path.append(path[y])
            if len(used) > 0:
                idx = torch.as_tensor(used, device=feat.device)
                temp_ids = torch.ones(
                    (len(used), ), device=feat.device) * l
                feats.append(feat[idx])
                boxes.append(bb[idx])
                ps.append(path[idx])
                obj_scores.append(obj[idx])
                new_ids.append(temp_ids)
        if len(feats) > 0:
            feature = torch.cat(feats)
            bbox = torch.cat(boxes)
            paths = torch.cat(ps)
            obj_score = torch.cat(obj_scores)
            ids = torch.cat(new_ids)
            cls_weight = self.cls_weight.weight
            start_l = int(cls_weight.sum()
                          ) + self.original_num_classes - self.num_classes
            # Cluster ids are first mapped to negative values so that samples
            # not relabelled below are dropped by the `labels > 0` filter.
            labels = -ids - 1
            unique_label = labels.unique()
            # At most as many new labels as there are unset cls_weight slots.
            unique_label = unique_label[:cls_weight.shape[0] -
                                        int(cls_weight.sum())]
            for i, p in enumerate(unique_label):
                if i + start_l - self.original_num_classes == self.num_centroid:
                    break
                labels[labels == p] = i + start_l
            idx = labels > 0
            obj_score = obj_score[idx]
            labels = labels[idx]
            paths = paths[idx]
            feature = feature[idx]
            bbox = bbox[idx]
            # Row layout: [path, label, bbox...]; 6 columns per the empty
            # fallback below, which implies 4 box values.
            data = torch.cat(
                (paths.unsqueeze(1), labels.unsqueeze(1).float(), bbox),
                dim=-1)
        else:
            data = torch.zeros((0, 6), device=feature.device)
        if image_path is not None and len(data) > 0:
            utils.save_boxes(data, feature.detach(), obj_score.detach(),
                             image_path, self.pal, self.step,
                             self.num_classes, self.output_dir)
        # NOTE(review): this tensor has shape (2,), while the other branch
        # allocates (1, 2) — confirm the broadcast is intended to rely on
        # equal element counts only.
        size = torch.as_tensor([len(data), len(centroid)],
                               device=feature.device).float()
        storage = get_event_storage()
        storage.put_scalar("exemplar/obj_th", float(obj_thresh))
        storage.put_scalar("exemplar/cluster_obj_th", float(cluster_obj_thresh))
        storage.put_scalar("exemplar/sel_cluster", int(mask.sum()))
        storage.put_scalar("exemplar/coupled_cos_th", float(coupled_cos_thresh))
        storage.put_scalar("exemplar/new", len(data))
    else:
        # Receive buffer for the sizes broadcast from rank 0.
        size = torch.empty(size=(1, 2), device=feature.device)
    # gather
    if utils.get_world_size() > 1:
        torch.cuda.synchronize()
        dist.broadcast(size, 0)
        if utils.get_rank() > 0:
            # Allocate a receive buffer with the row count announced by rank 0.
            data = torch.empty(size=(int(size[0, 0]), 6),
                               device=feature.device)
        torch.cuda.synchronize()
        dist.broadcast(data, 0)
    l_cls = self.original_num_classes - 1
    # Number of newly used labels, derived from the largest label in `data`.
    l_new = int(data[:, 1].max() - l_cls) if len(data) > 0 else 0
    cls_weight = self.cls_weight.weight.data
    cls_weight[:self.num_classes + l_new] = 1
    self.cls_weight.weight.data = cls_weight
    if self.pseudo_gt is None:
        self.pseudo_gt = data
    else:
        self.pseudo_gt = torch.cat((self.pseudo_gt, data))
    self.n_pseudo_gt = len(self.pseudo_gt)
    # flush
    if utils.get_rank() == 0:
        # Best-effort save: any failure here is silently ignored.
        try:
            torch.save(
                self.pseudo_gt.cpu(),
                os.path.join(self.output_dir,
                             'pseudo_gts/{}.pth'.format(self.step)))
        except:
            pass
def sync_pseudo_gt(self, templete=None, dir_name='pseudo_gts'):
    """Save crops of new pseudo-GT boxes and share the new rows across ranks.

    Does nothing when ``self.pseudo_gt`` is None or the world size is 1.

    Parameters:
        templete: format string turned into an image path via
            ``templete.format(p)`` for each unique path id ``p``. With the
            default ``None`` the crop-saving step fails and only prints
            "FAIL TO SAVE"; the synchronisation still runs.
        dir_name: sub-directory of ``self.output_dir`` receiving the crops.

    Side effects:
        Writes JPEG crops (best effort), extends ``self.pseudo_gt`` with the
        rows gathered from every rank, updates ``self.n_pseudo_gt`` and, on
        rank 0, saves ``self.pseudo_gt`` to ``pseudo_gts/{step}.pth``.
    """
    size = utils.get_world_size()
    if self.pseudo_gt is None or size == 1:
        return

    # --- Best-effort crop export of the rows added since the last sync. ---
    try:
        data = self.pseudo_gt[self.n_pseudo_gt:].view(-1, 6)
        path = data[:, 0].long()
        label = data[:, 1]
        boxes = data[:, 2:]
        dir_name = os.path.join(self.output_dir, dir_name)
        # Several ranks run this concurrently, so creation must tolerate the
        # directory appearing between a check and the mkdir.
        os.makedirs(dir_name, exist_ok=True)
        step_dir = dir_name + '/{:05}'.format(self.step)
        for p in path.unique():
            img = cv2.imread(templete.format(p))
            if img is None:
                # An unreadable image should not abort the remaining crops.
                print("FAIL TO SAVE")
                continue
            img_h, img_w, _ = img.shape
            # Boxes are scaled by (w, h, w, h) to pixel coordinates.
            multiplier = torch.tensor([img_w, img_h, img_w, img_h],
                                      dtype=torch.float32,
                                      device=data.device)
            idx = path == p
            bbox = (boxes[idx] * multiplier).int().cpu().numpy()
            lbl = label[idx]
            os.makedirs(step_dir, exist_ok=True)
            for i, box in enumerate(bbox):
                cropped_image = img[box[1]:box[3] + 1, box[0]:box[2] + 1]
                framed_image = cropped_image.astype(np.uint8)
                cropped_img_path = os.path.join(
                    dir_name, '{:05}/{:03}_{:012}_{:03}.jpg'.format(
                        self.step, int(lbl[i]), int(p), i))
                if not cv2.imwrite(cropped_img_path, framed_image):
                    print("FAIL TO SAVE")
    except Exception:
        # Export stays best effort, but KeyboardInterrupt / SystemExit are
        # no longer swallowed.
        print("FAIL TO SAVE")

    # --- Exchange the number of new rows held by each rank. ---
    rank = utils.get_rank()
    array = torch.zeros((size, 1), device=self.pseudo_gt.device)
    array[rank] = len(self.pseudo_gt) - self.n_pseudo_gt
    dist.all_reduce(array, dist.ReduceOp.SUM)

    # --- Pad to the largest count so all_gather sees equal shapes. ---
    data = self.pseudo_gt[self.n_pseudo_gt:]
    max_size = int(array.max())
    padding = torch.zeros((max_size - len(data), data.shape[1]),
                          device=data.device)
    data = torch.cat((data, padding))
    input_list = [
        torch.empty(size=(max_size, 6), device=self.pseudo_gt.device)
        for _ in range(size)
    ]
    dist.all_gather(input_list, data)
    # Strip the padding again using each rank's true row count.
    input_list = [e[:int(array[i])] for i, e in enumerate(input_list)]
    data = torch.cat(input_list)
    print("{} data sync".format(len(data)))

    # NOTE(review): `data` includes this rank's own new rows, which are
    # already part of self.pseudo_gt — confirm the duplication is intended.
    self.pseudo_gt = torch.cat((self.pseudo_gt, data))
    self.n_pseudo_gt = len(self.pseudo_gt)
    if utils.get_rank() == 0:
        print(array)
        torch.save(
            self.pseudo_gt.cpu(),
            os.path.join(self.output_dir,
                         'pseudo_gts/{}.pth'.format(self.step)))
def forward(self, outputs, targets):
    """Compute the requested losses for the final and auxiliary outputs.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. Per the original notes the layout is
            ``{'pred_logits': (b, num_queries, num_classes),
            'pred_boxes': (b, num_queries, 4),
            'aux_outputs': [{'pred_logits': ..., 'pred_boxes': ...}, ...]}``.
        targets: list of dicts, such that len(targets) == batch_size. Each
            dict must provide 'labels'; other keys depend on the losses
            applied, see each loss' doc.

    Returns:
        dict of loss values; entries from auxiliary output ``i`` are suffixed
        with ``_{i}``.
    """
    # Drop the intermediate-layer outputs; only the last layer's predictions
    # take part in the main matching.
    final_outputs = {
        name: value for name, value in outputs.items()
        if name != 'aux_outputs'
    }

    # Match predictions to ground truth before any loss is computed. Per the
    # original notes, the matcher yields one (prediction_idx, target_idx)
    # pair per image, both of equal length.
    indices = self.matcher(final_outputs, targets)

    # Count the target objects in this batch, sum the count over all
    # distributed processes, average it and clamp to at least 1.
    box_total = sum(len(t["labels"]) for t in targets)
    device = next(iter(outputs.values())).device
    num_boxes = torch.as_tensor([box_total], dtype=torch.float, device=device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = (num_boxes / get_world_size()).clamp(min=1).item()

    # Each entry of self.losses is a loss name compared against strings such
    # as 'labels' and 'masks' below; get_loss is called with that name and
    # its returned dict is merged into `losses`.
    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, indices, num_boxes))

    if 'aux_outputs' not in outputs:
        return losses

    # When intermediate-layer outputs are present, compute their losses too,
    # re-running the matcher for every layer.
    for i, aux_outputs in enumerate(outputs['aux_outputs']):
        indices = self.matcher(aux_outputs, targets)
        suffix = f'_{i}'
        for name in self.losses:
            if name == 'masks':
                # Intermediate mask losses are too costly to compute.
                continue
            # Logging is enabled only for the last layer.
            extra = {'log': False} if name == 'labels' else {}
            partial = self.get_loss(name, aux_outputs, targets, indices,
                                    num_boxes, **extra)
            for key, value in partial.items():
                losses[key + suffix] = value
    return losses
def forward(self, outputs, targets, indices_track=None, track_on=False):
    """Compute detection (and optionally tracking) losses.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. May contain 'pred_tracks' and an 'aux_outputs'
            list of per-layer output dicts. The dict and its auxiliary dicts
            are not modified.
        targets: list of dicts, such that len(targets) == batch_size. Each
            dict must provide 'labels'; other keys depend on the losses
            applied, see each loss' doc.
        indices_track: matching forwarded to ``self.track_matcher`` when
            ``track_on`` is True; ignored otherwise.
        track_on: if True, match with ``self.track_matcher`` (which also
            returns updated targets), add the 'tracks' loss and reuse the
            track matching for all auxiliary outputs. If False, use
            ``self.matcher`` on outputs with 'pred_tracks' removed and
            re-match each auxiliary output.

    Returns:
        Tuple ``(losses, indices_track)``: the dict of loss values (auxiliary
        entries suffixed with ``_{i}``) and the matching used for the final
        layer.

    Raises:
        AssertionError: if ``track_on`` is True and 'pred_tracks' is missing
            from ``outputs``.
    """
    outputs_without_aux = {
        k: v for k, v in outputs.items() if k != 'aux_outputs'
    }
    track_exists = "pred_tracks" in outputs_without_aux

    if track_on:
        assert track_exists, "track_on=True requires 'pred_tracks' in outputs"
        # Track match; the matcher may also rewrite the targets.
        indices_track, targets = self.track_matcher(
            outputs_without_aux, targets, indices_track=indices_track)
        loss_names = self.losses + ['tracks']
    else:
        # The plain matcher is run without the track predictions.
        outputs_without_aux.pop("pred_tracks", None)
        indices_track = self.matcher(outputs_without_aux, targets)
        loss_names = self.losses

    # Average number of target boxes across processes, clamped to >= 1.
    # Computed after matching because track matching can replace `targets`.
    num_boxes = sum(len(t["labels"]) for t in targets)
    num_boxes = torch.as_tensor([num_boxes],
                                dtype=torch.float,
                                device=next(iter(outputs.values())).device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

    losses = {}
    for loss in loss_names:
        losses.update(
            self.get_loss(loss, outputs, targets, indices_track, num_boxes))

    if 'aux_outputs' in outputs:
        for i, aux_outputs in enumerate(outputs['aux_outputs']):
            if track_on:
                # The tracking matching is reused for every auxiliary output.
                aux_indices = indices_track
            else:
                if track_exists:
                    # Work on a filtered copy: popping from the caller's dict
                    # would mutate `outputs` and raise KeyError if the same
                    # outputs were passed in again.
                    aux_outputs = {
                        k: v for k, v in aux_outputs.items()
                        if k != "pred_tracks"
                    }
                aux_indices = self.matcher(aux_outputs, targets)
            for loss in loss_names:
                if loss == 'masks':
                    # Intermediate mask losses are too costly to compute.
                    continue
                # Logging is enabled only for the last layer.
                kwargs = {'log': False} if loss == 'labels' else {}
                l_dict = self.get_loss(loss, aux_outputs, targets,
                                       aux_indices, num_boxes, **kwargs)
                losses.update({k + f'_{i}': v for k, v in l_dict.items()})

    return losses, indices_track
def forward(self, outputs, bbox_tgts, clas_tgts):
    """Convert raw targets, compute all losses and return their weighted sum.

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. May contain an 'aux_outputs' list.
        bbox_tgts: per-image box targets; rows are treated as
            (y, x, y2, x2) in the range (-1, 1) according to the conversion
            steps below.
        clas_tgts: per-image class targets; entries equal to zero are
            dropped and the remaining labels are shifted down by one.

    Returns:
        Sum of ``loss * self.weight_dict[name]`` over the losses whose name
        appears in ``self.weight_dict``.

    Side effects:
        ``self.metrics`` is rebuilt with the losses named in
        ``self.metric_names`` (weighted when a weight exists).
    """
    device = next(iter(outputs.values())).device

    # Build one {"boxes", "labels"} target dict per image.
    targets = []
    for boxes, classes in zip(bbox_tgts, clas_tgts):
        # Keep only entries with a non-zero class.
        boxes = boxes[np.nonzero(classes)].squeeze(dim=1).cpu()
        labels = classes[classes > 0] - 1
        # (y, x, y2, x2) -> (y, x, h, w) -> (x, y, w, h)
        boxes[:, 2:] = boxes[:, 2:] - boxes[:, :2]
        boxes = boxes[:, [1, 0, 3, 2]]
        # (x, y, w, h) -> (cx, cy, w, h)
        boxes[:, :2] = boxes[:, :2] + 0.5 * boxes[:, 2:]
        # Rescale from the input range (-1, 1) to the expected (0, 1).
        boxes[:, 2:] = boxes[:, 2:] / 2.
        boxes[:, :2] = (boxes[:, :2] + 1) / 2.
        targets.append({
            "boxes": boxes.to(device),
            "labels": labels.to(device),
        })

    final_outputs = {
        name: value for name, value in outputs.items()
        if name != 'aux_outputs'
    }
    # Matching between last-layer predictions and the targets.
    indices = self.matcher(final_outputs, targets)

    # Average number of target boxes across processes, clamped to >= 1.
    box_total = sum(len(t["labels"]) for t in targets)
    num_boxes = torch.as_tensor([box_total],
                                dtype=torch.float,
                                device=next(iter(outputs.values())).device)
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(num_boxes)
    num_boxes = (num_boxes / get_world_size()).clamp(min=1).item()

    losses = {}
    for name in self.losses:
        losses.update(
            self.get_loss(name, outputs, targets, indices, num_boxes))

    if 'aux_outputs' in outputs:
        for i, aux_outputs in enumerate(outputs['aux_outputs']):
            indices = self.matcher(aux_outputs, targets)
            suffix = f'_{i}'
            for name in self.losses:
                if name == 'masks':
                    # Intermediate mask losses are too costly to compute.
                    continue
                # Logging is enabled only for the last layer.
                extra = {'log': False} if name == 'labels' else {}
                partial = self.get_loss(name, aux_outputs, targets, indices,
                                        num_boxes, **extra)
                for key, value in partial.items():
                    losses[key + suffix] = value

    # Expose the selected losses as metrics, weighted where a weight exists.
    self.metrics = {}
    for name in losses:
        if name not in self.metric_names:
            continue
        value = losses[name]
        if name in self.weight_dict:
            value = losses[name] * self.weight_dict[name]
        self.metrics[name] = value

    losses = sum(losses[name] * self.weight_dict[name]
                 for name in losses.keys()
                 if name in self.weight_dict)
    return losses