def loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
               cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:
    """Run one batch through `model`, optionally take an optimizer step, and return the loss.

    Args:
        model: called as `model(images, targets)` when `model.training` is
            set (expected to return a dict with a 'total_loss' entry) and as
            `model(images)` otherwise.
        xb: batch of images; its length is the batch size and its device is
            where the placeholder loss is created.
        yb: indexable where `yb[0][i]` holds the boxes and `yb[1][i]` the
            labels for image `i`.
        loss_func: accepted for interface compatibility; not used here.
        opt: optimizer. When None, no backward pass or step is performed.
        cb_handler: callback handler; a fresh `CallbackHandler()` is used
            when None.

    Returns:
        The loss, detached and moved to the CPU. In eval mode this is the
        placeholder tensor `[0.0]`.
    """
    handler = ifnone(cb_handler, CallbackHandler())
    device = xb.device

    def _make_target(idx):
        # Translate from fastai box format to torchvision.
        # Assumes each image is laid out (C, H, W), so shape[1:3] is (H, W)
        # -- TODO confirm.
        raw_boxes = yb[0][idx]
        raw_labels = yb[1][idx]
        pixel_boxes = to_box_pixel(raw_boxes, *xb[0].shape[1:3])
        return BoxList(pixel_boxes, labels=raw_labels)

    targets = [_make_target(idx) for idx in range(len(xb))]

    out = None
    loss = torch.Tensor([0.0]).to(device=device)
    if model.training:
        loss_dict = model(xb, targets)
        loss = loss_dict['total_loss']
        # Expose the individual loss terms to callbacks.
        handler.state_dict['loss_dict'] = loss_dict
    else:
        out = model(xb)

    out = handler.on_loss_begin(out)

    if opt is not None:
        # Each callback hook can veto the step that follows it.
        loss, skip_bwd = handler.on_backward_begin(loss)
        if not skip_bwd:
            loss.backward()
        if not handler.on_backward_end():
            opt.step()
        if not handler.on_step_end():
            opt.zero_grad()

    return loss.detach().cpu()
def test_encode_decode(self):
    """Encode a BoxList, decode it again, and check the round trip is lossless.

    Also writes visualisations of the encoded and decoded boxes to PNG files
    via `plot_encoded`.
    """
    height = width = 256
    stride = 4
    num_labels = 2

    # shift by two since only numbers that fall in the middle of the stride
    # can be represented exactly
    boxes = torch.tensor([
        [0., 0., 128., 128.],
        [128., 128., 192., 192.],
        [64., 64., 128., 128.],
    ]) + 2
    expected = BoxList(boxes, labels=torch.tensor([0, 0, 1]))

    positions = get_positions(height, width, stride, boxes.device)
    keypoint, reg = encode([expected], positions, stride, num_labels)
    plot_encoded(expected, stride, keypoint[0], reg[0],
                 '/opt/data/pascal2007/encoded.png')

    # NOTE(review): a `decode` defined elsewhere in this file takes a required
    # `cfg` argument before `prob_thresh`; confirm this call matches the
    # signature of the `decode` that is actually imported here.
    decoded_boxlists = decode(
        keypoint, reg, positions, stride, prob_thresh=0.05)
    decoded = decoded_boxlists[0]
    plot_encoded(decoded, stride, keypoint[0], reg[0],
                 '/opt/data/pascal2007/decoded.png')

    # The input has no scores, so drop them before comparing.
    del decoded.extras['scores']
    self.assertTrue(expected.equal(decoded))
def forward(self, input, targets=None):
    """Run the wrapped model, converting between BoxList and dict formats.

    Args:
        input: batch of images, passed unchanged to `self.model`
            (assumed tensor<n, 3, h, w> -- TODO confirm)
        targets: None, or list<BoxList> of length n with boxes and labels

    Returns:
        If `targets` is a non-empty list: the dict of losses returned by
        `self.model`, with an extra 'total_loss' entry holding the sum of all
        the other entries.

        Otherwise (`targets` is None or empty): list<BoxList> built from the
        'boxes', 'labels' and 'scores' of each output of `self.model`, with
        `.yxyx()` applied. Any score thresholding is whatever the wrapped
        model applies; none is done here.
    """
    if not targets:
        predictions = self.model(input)
        return [
            BoxList(
                pred['boxes'], labels=pred['labels'],
                scores=pred['scores']).yxyx()
            for pred in predictions
        ]

    # The wrapped model takes one {'boxes', 'labels'} dict per image, built
    # from the `.xyxy()` form of each BoxList.
    converted = [boxlist.xyxy() for boxlist in targets]
    model_targets = [
        {'boxes': boxlist.boxes, 'labels': boxlist.get_field('labels')}
        for boxlist in converted
    ]
    losses = self.model(input, model_targets)
    # The right-hand side is fully evaluated before the new key is inserted,
    # so 'total_loss' is the sum of the model's own loss terms only.
    losses['total_loss'] = sum(losses.values())
    return losses
def __getitem__(self, ind):
    """Load the image at index `ind` together with its boxes and labels.

    Args:
        ind: index into `self.imgs`

    Returns:
        tuple (img, boxlist). `img` is the image as a numpy array, or
        whatever `self.transforms` returns under 'image' when transforms are
        set. `boxlist` is a BoxList whose box columns are
        (y, x, y + h, x + w), computed from input columns read as
        (x, y, w, h). It is an empty BoxList when there are no boxes.
    """
    filename = self.imgs[ind]
    image_id = self.img2id[filename]
    img = np.array(Image.open(join(self.img_dir, filename)))
    boxes = self.id2boxes[image_id]
    labels = self.id2labels[image_id]

    if self.transforms:
        transformed = self.transforms(image=img, bboxes=boxes, labels=labels)
        img = transformed['image']
        boxes = torch.tensor(transformed['bboxes'])
        labels = torch.tensor(transformed['labels'])

    if len(boxes) == 0:
        no_boxes = BoxList(torch.empty((0, 4)), labels=torch.empty((0, )))
        return (img, no_boxes)

    # NOTE(review): the slicing below needs `boxes` to be a 2-D tensor, which
    # is only guaranteed here when transforms are applied -- confirm what
    # `self.id2boxes` stores.
    # Width-1 slices keep each column 2-D so they can be concatenated.
    x, y = boxes[:, 0:1], boxes[:, 1:2]
    w, h = boxes[:, 2:3], boxes[:, 3:4]
    corner_boxes = torch.cat([y, x, y + h, x + w], dim=1)
    return (img, BoxList(corner_boxes, labels=labels))
def test_fcos_with_targets(self):
    """Calling FCOS with targets returns a dict holding all three loss terms."""
    num_labels = 3
    h = w = 64
    # Only the keys of the result are checked, so the image content is
    # left uninitialised.
    x = torch.empty((1, 3, h, w))
    model = FCOS('resnet18', num_labels, pretrained=False)

    targets = [
        BoxList(
            torch.tensor([[0, 0, 16, 16], [8, 8, 12, 12]]),
            labels=torch.tensor([0, 1]))
    ]
    loss_dict = model(x, targets)

    for loss_name in ('label_loss', 'reg_loss', 'center_loss'):
        self.assertTrue(loss_name in loss_dict)
def test_backwards(self):
    """Back-propagating the summed losses gives every parameter a non-zero gradient."""
    num_labels = 3
    h = w = 64
    # Random image with values in [-1, 1).
    x = 2.0 * torch.rand((1, 3, h, w)) - 1.0
    model = FCOS('resnet18', num_labels, pretrained=False)

    targets = [
        BoxList(
            torch.tensor([[0, 0, 16, 16], [16, 16, 32, 32]]),
            labels=torch.tensor([0, 1]))
    ]

    model.train()
    model.zero_grad()
    total_loss = sum(model(x, targets).values())
    total_loss.backward()

    for param in model.parameters():
        num_nonzero = len(torch.nonzero(param.grad))
        self.assertTrue(num_nonzero > 0)
def decode_batch_output(output, pyramid_shape, img_height, img_width,
                        iou_thresh=0.5):
    """Decode output for batch of images.

    For each image the per-level logits are turned into probabilities, the
    levels are decoded with `decode_single_output`, each score is multiplied
    by its centerness, and the boxes are clamped to the image and passed
    through NMS.

    Args:
        output: list of tuples where each tuple corresponds to a pyramid level
            tuple is of form (reg_arr, label_arr, center_arr) where
                - reg_arr is tensor<n, 4, h, w>,
                - label_arr is tensor<n, num_labels, h, w>
                - center_arr is tensor<n, 1, h, w>
            and label_arr and center_arr are logits
        pyramid_shape: forwarded unchanged to `decode_single_output`
        img_height: passed to `BoxList.clamp`
        img_width: passed to `BoxList.clamp`
        iou_thresh: (float) iou threshold passed to NMS

    Returns:
        list of n BoxLists
    """
    num_images = output[0][0].shape[0]
    results = []
    for img_ind in range(num_images):
        # Slice out this image and convert logits in label_arr and
        # center_arr to probabilities.
        per_image_out = [
            (reg_arr[img_ind],
             torch.sigmoid(label_arr[img_ind]),
             torch.sigmoid(center_arr[img_ind]))
            for reg_arr, label_arr, center_arr in output
        ]
        decoded = decode_single_output(per_image_out, pyramid_shape)

        # Rebuild the BoxList so the score used downstream is
        # class probability * centerness.
        rescored = BoxList(
            decoded.boxes,
            labels=decoded.get_field('labels'),
            scores=decoded.get_field('scores') * decoded.get_field('centerness'),
            centerness=decoded.get_field('centerness'))

        results.append(
            rescored.clamp(img_height, img_width).nms(iou_thresh=iou_thresh))
    return results
def decode_single_output(output, pyramid_shape, score_thresh=0.05):
    """Decode output of heads for all levels of pyramid for one image.

    Args:
        output: list of tuples where each tuple corresponds to a pyramid level
            tuple is of form (reg_arr, label_arr, center_arr) where
                - reg_arr is tensor<4, h, w>,
                - label_arr is tensor<num_labels, h, w>
                - center_arr is tensor<1, h, w>
            and label_arr and center_arr are probabilities
        pyramid_shape: indexable by level; element `[level][0]` is used as
            the stride for that level
        score_thresh: (float) probability score threshold used to determine
            if a box is present at a cell

    Returns:
        BoxList formed by concatenating the decoded boxes of every level
    """
    per_level = [
        decode_level_output(
            *level_out, pyramid_shape[level][0], score_thresh=score_thresh)
        for level, level_out in enumerate(output)
    ]
    return BoxList.cat(per_level)
def decode_level_output(reg_arr, label_arr, center_arr, stride,
                        score_thresh=0.05):
    """Decode output of head for one level of the pyramid for one image.

    Every cell (i, j) of the level is assigned the image position
    (stride//2 + i*stride, stride//2 + j*stride). The box for a cell is that
    position minus the first two regression channels and plus the last two.

    Args:
        reg_arr: (tensor) with shape (4, h, w). The first dimension ranges
            over t, l, b, r (ie. top, left, bottom, right).
        label_arr: (tensor) with shape (num_labels, h, w) containing
            probabilities
        center_arr: (tensor) with shape (1, h, w) containing values between
            0 and 1
        stride: (int) the stride of the level of the pyramid
        score_thresh: (float) probability score threshold used to determine
            if a box is present at a cell

    Returns:
        BoxList with one box per cell (labels, scores and centerness
        attached), after `score_filter(score_thresh)` has been applied
    """
    device = reg_arr.device
    h, w = reg_arr.shape[1:]
    offset = stride // 2

    # Channel 0 holds the row coordinate of each cell centre, channel 1 the
    # column coordinate.
    row_coords = torch.arange(offset, stride * h, stride, device=device)
    col_coords = torch.arange(offset, stride * w, stride, device=device)
    pos_arr = torch.empty((2, h, w), device=device)
    pos_arr[0] = row_coords[:, None]
    pos_arr[1] = col_coords[None, :]

    # top/left are distances back from the cell centre, bottom/right are
    # distances forward from it.
    boxes = torch.empty((4, h, w), device=device)
    boxes[:2] = pos_arr - reg_arr[:2]
    boxes[2:] = pos_arr + reg_arr[2:]

    # Best label and its probability for every cell.
    scores, labels = label_arr.max(dim=0)

    # Flatten the spatial grid: boxes become (h*w, 4), the rest (h*w,).
    flat_boxes = boxes.reshape(4, -1).transpose(0, 1)
    candidates = BoxList(
        flat_boxes,
        labels=labels.reshape(-1),
        scores=scores.reshape(-1),
        centerness=center_arr.reshape(-1))
    return candidates.score_filter(score_thresh)
def decode(keypoint, reg, positions, stride, cfg, prob_thresh=0.05):
    """Decode keypoint heatmaps and size regressions into one BoxList per image.

    A cell is kept when, for at least one label, its keypoint value is above
    `prob_thresh` and (if `cfg.model.centernet.max_pool_nms` is set) equals
    the maximum of its 3x3 neighbourhood. Each kept cell yields one box
    centred on the cell's position with the regressed size, labelled with
    the highest-scoring label at that cell.

    Args:
        keypoint: tensor<N, num_labels, h, w> of per-label keypoint scores
        reg: tensor<N, 2, h, w> of box sizes
        positions: tensor<2, h, w> of cell-centre coordinates, in the same
            axis order as `reg` (presumably (y, x) -- TODO confirm)
        stride: (int) size of a cell in pixels; the image size is taken to
            be (h * stride, w * stride)
        cfg: config object; only `cfg.model.centernet.max_pool_nms` is read
        prob_thresh: (float) minimum keypoint score for a cell to be kept

    Returns:
        list of N BoxLists with `labels` and `scores`. Boxes are clamped to
        the image. Images with no detections get an empty BoxList on the
        same device as `keypoint`.
    """
    N = keypoint.shape[0]
    device = keypoint.device
    boxlists = []
    flat_positions = positions.permute((1, 2, 0)).reshape((-1, 2))
    img_height = positions.shape[1] * stride
    img_width = positions.shape[2] * stride

    for n in range(N):
        per_keypoint = keypoint[n]
        per_reg = reg[n]
        num_labels = per_keypoint.shape[0]

        is_over_thresh = per_keypoint > prob_thresh
        is_local_max = torch.ones_like(is_over_thresh)
        if cfg.model.centernet.max_pool_nms:
            # A cell is a local maximum when pooling leaves it unchanged.
            is_local_max = per_keypoint == F.max_pool2d(
                per_keypoint, kernel_size=3, stride=1, padding=1)
        is_pos = is_local_max * is_over_thresh
        num_pos = is_pos.sum()

        if num_pos == 0:
            # Create the empty result on the input's device so it matches
            # the non-empty branch.
            bl = BoxList(
                torch.empty((0, 4), device=device),
                labels=torch.empty((0,), device=device),
                scores=torch.empty((0,), device=device))
        else:
            # Collapse the label axis: a cell is positive if any label is.
            flat_is_pos, _ = is_pos.permute((1, 2, 0)).reshape(
                (-1, num_labels)).max(1)

            flat_per_reg = per_reg.permute((1, 2, 0)).reshape((-1, 2))
            sizes = flat_per_reg[flat_is_pos, :]
            centers = flat_positions[flat_is_pos]
            boxes = torch.cat(
                [centers - sizes / 2, centers + sizes / 2], dim=1)

            flat_per_keypoint = per_keypoint.permute((1, 2, 0)).reshape(
                (-1, num_labels))
            flat_per_keypoint = flat_per_keypoint[flat_is_pos, :]
            scores, labels = flat_per_keypoint.max(1)

            bl = BoxList(boxes, labels=labels, scores=scores)
            # `clamp` returns a new BoxList, so the result must be kept;
            # previously it was discarded and boxes were never clamped.
            bl = bl.clamp(img_height, img_width)
        boxlists.append(bl)
    return boxlists