def nscale_forward(self, inputs, scales):
    """
    Hierarchical attention, primarily used for getting best inference
    results.

    We use attention at multiple scales, giving priority to the lower
    resolutions. For example, if we have 4 scales {0.5, 1.0, 1.5, 2.0},
    then evaluation is done as follows:

      p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0)
      p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint)
      p_joint = up(attn_0.5 * p_0.5) + (1 - up(attn_0.5)) * p_joint

    The target scale is always 1.0, and 1.0 is expected to be part of the
    list of scales. When predictions are done at greater than 1.0 scale,
    the predictions are downsampled before being combined with the next
    lower scale.

    Inputs:
      scales - a list of scales to evaluate
      inputs - dict containing 'images', the input, and 'gts', the
               ground truth mask

    Output:
      If training, return loss, else return prediction + attention
    """
    x_1x = inputs['images']

    assert 1.0 in scales, 'expected 1.0 to be the target scale'
    # Lower resolution provides attention for the higher-resolution
    # predictions, so we evaluate in order: high to low
    scales = sorted(scales, reverse=True)

    pred = None
    output_dict = {}

    for s in scales:
        x = ResizeX(x_1x, s)
        bs = x.shape[0]
        scale_float = torch.Tensor(bs).fill_(s)
        p, attn, _aspp_attn, _aspp = self._fwd(x, scale_float=scale_float)

        output_dict[fmt_scale('pred', s)] = p
        if s != 2.0:
            output_dict[fmt_scale('attn', s)] = attn

        if pred is None:
            pred = p
        elif s >= 1.0:
            # downscale previous prediction to the current scale
            pred = scale_as(pred, p)
            pred = attn * p + (1 - attn) * pred
        else:
            # upscale current prediction to the running scale
            p = attn * p
            p = scale_as(p, pred)
            attn = scale_as(attn, pred)
            pred = p + (1 - attn) * pred

    if self.training:
        assert 'gts' in inputs
        gts = inputs['gts']
        loss = self.criterion(pred, gts)
        return loss
    else:
        output_dict['pred'] = pred
        return output_dict
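# nscale_forward() above relies on the helpers ResizeX() and scale_as(). The
# definitions below are illustrative sketches only: they assume both helpers
# are thin wrappers around torch.nn.functional.interpolate with bilinear
# resampling. The project's actual implementations may choose a different
# interpolation mode or align_corners setting.
import torch
from torch.nn import functional as F


def ResizeX(x, scale_factor):
    """Resize an NCHW tensor by a float scale factor (assumed bilinear)."""
    return F.interpolate(x, scale_factor=scale_factor, mode='bilinear',
                         align_corners=False, recompute_scale_factor=True)


def scale_as(x, y):
    """Resize x to the spatial size of y (assumed bilinear)."""
    return F.interpolate(x, size=y.shape[2:], mode='bilinear',
                         align_corners=False)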
def nscale_forward(self, inputs, scales):
    """
    Hierarchical attention, primarily used for getting best inference
    results.

    We use attention at multiple scales, giving priority to the lower
    resolutions. For example, if we have 4 scales {0.5, 1.0, 1.5, 2.0},
    then evaluation is done as follows:

      p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0)
      p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint)
      p_joint = up(attn_0.5 * p_0.5) + (1 - up(attn_0.5)) * p_joint

    The target scale is always 1.0, and 1.0 is expected to be part of the
    list of scales. When predictions are done at greater than 1.0 scale,
    the predictions are downsampled before being combined with the next
    lower scale.

    Inputs:
      scales - a list of scales to evaluate
      inputs - dict containing 'images', the input, and 'gts', the
               ground truth mask

    Output:
      If training, return loss, else return prediction + attention
    """
    x_1x = inputs['images']

    assert 1.0 in scales, 'expected 1.0 to be the target scale'
    # Lower resolution provides attention for the higher-resolution
    # predictions, so we evaluate in order: high to low
    scales = sorted(scales, reverse=True)

    pred = None
    aux = None
    output_dict = {}

    for s in scales:
        x = ResizeX(x_1x, s)
        outs = self._fwd(x)
        cls_out = outs['cls_out']
        attn_out = outs['logit_attn']
        aux_out = outs['aux_out']

        output_dict[fmt_scale('pred', s)] = cls_out
        if s != 2.0:
            output_dict[fmt_scale('attn', s)] = attn_out

        if pred is None:
            pred = cls_out
            aux = aux_out
        elif s >= 1.0:
            # downscale previous prediction to the current scale
            pred = scale_as(pred, cls_out)
            pred = attn_out * cls_out + (1 - attn_out) * pred
            aux = scale_as(aux, cls_out)
            aux = attn_out * aux_out + (1 - attn_out) * aux
        else:
            # s < 1.0: upscale current prediction to the running scale
            cls_out = attn_out * cls_out
            aux_out = attn_out * aux_out
            cls_out = scale_as(cls_out, pred)
            aux_out = scale_as(aux_out, pred)
            attn_out = scale_as(attn_out, pred)
            pred = cls_out + (1 - attn_out) * pred
            aux = aux_out + (1 - attn_out) * aux

    if self.training:
        assert 'gts' in inputs
        gts = inputs['gts']
        loss = cfg.LOSS.OCR_ALPHA * self.criterion(aux, gts) + \
            self.criterion(pred, gts)
        return loss
    else:
        output_dict['pred'] = pred
        return output_dict
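# A minimal, self-contained sketch of the scale-fusion arithmetic used by
# nscale_forward(), with random tensors standing in for real network outputs.
# The scales, tensor shapes, and the _toy_fuse() name are made up for
# illustration; only the fusion rule (pred = attn * p + (1 - attn) * pred,
# with down/upsampling as needed) mirrors the method above.
import torch
from torch.nn import functional as F


def _toy_fuse(scales=(0.5, 1.0, 2.0), num_classes=19, hw=64):
    preds = None
    for s in sorted(scales, reverse=True):
        size = (int(hw * s), int(hw * s))
        p = torch.randn(1, num_classes, *size)   # stand-in for cls_out
        attn = torch.rand(1, 1, *size)            # stand-in for logit_attn
        if preds is None:
            preds = p
        elif s >= 1.0:
            # downscale the running prediction to the current (lower) scale
            preds = F.interpolate(preds, size=size, mode='bilinear',
                                  align_corners=False)
            preds = attn * p + (1 - attn) * preds
        else:
            # upscale the current (smaller) prediction to the running scale
            p = F.interpolate(attn * p, size=preds.shape[2:],
                              mode='bilinear', align_corners=False)
            attn = F.interpolate(attn, size=preds.shape[2:],
                                 mode='bilinear', align_corners=False)
            preds = p + (1 - attn) * preds
    return preds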
def eval_minibatch(data, net, criterion, val_loss, calc_metrics, args,
                   val_idx):
    """
    Evaluate a single minibatch of images.
     * calculate metrics
     * dump images

    There are two primary multi-scale inference types:
      1. 'MSCALE', or in-model multi-scale: where the multi-scale iteration
         loop is handled within the model itself (see networks/mscale.py ->
         nscale_forward())
      2. 'multi_scale_inference', where we use averaging to combine scales
    """
    torch.cuda.empty_cache()

    scales = [args.default_scale]
    if args.multi_scale_inference:
        scales.extend([float(x) for x in args.extra_scales.split(',')])
        if val_idx == 0:
            logx.msg(
                f'Using multi-scale inference (AVGPOOL) with scales {scales}')

    # input = torch.Size([1, 3, h, w])
    # gt_image = torch.Size([1, h, w])
    images, gt_image, img_names, scale_float = data
    assert len(images.size()) == 4 and len(gt_image.size()) == 3
    assert images.size()[2:] == gt_image.size()[1:]
    batch_pixel_size = images.size(0) * images.size(2) * images.size(3)
    input_size = images.size(2), images.size(3)

    if args.do_flip:
        # By ending with flip=0, we ensure that the images that are dumped
        # out correspond to the unflipped versions. A bit hacky.
        flips = [1, 0]
    else:
        flips = [0]

    with torch.no_grad():
        output = 0.0

        for flip in flips:
            for scale in scales:
                if flip == 1:
                    inputs = flip_tensor(images, 3)
                else:
                    inputs = images

                infer_size = [round(sz * scale) for sz in input_size]

                if scale != 1.0:
                    inputs = resize_tensor(inputs, infer_size)

                inputs = {'images': inputs, 'gts': gt_image}
                inputs = {k: v.cuda() for k, v in inputs.items()}

                # Expected Model outputs:
                #   required:
                #     'pred'  the network prediction, shape (1, 19, h, w)
                #
                #   optional:
                #     'pred_*' - multi-scale predictions from mscale model
                #     'attn_*' - multi-scale attentions from mscale model
                output_dict = net(inputs)

                _pred = output_dict['pred']

                # save AVGPOOL style multi-scale output for visualizing
                if not cfg.MODEL.MSCALE:
                    scale_name = fmt_scale('pred', scale)
                    output_dict[scale_name] = _pred

                # resize tensor down to 1.0x scale in order to combine
                # with other scales of prediction
                if scale != 1.0:
                    _pred = resize_tensor(_pred, input_size)

                if flip == 1:
                    output = output + flip_tensor(_pred, 3)
                else:
                    output = output + _pred

        output = output / len(scales) / len(flips)

    assert_msg = 'output_size {} gt_cuda size {}'
    gt_cuda = gt_image.cuda()
    assert_msg = assert_msg.format(output.size()[2:], gt_cuda.size()[1:])
    assert output.size()[2:] == gt_cuda.size()[1:], assert_msg
    assert output.size()[1] == cfg.DATASET.NUM_CLASSES, assert_msg

    # Update loss and scoring datastructure
    if calc_metrics:
        val_loss.update(criterion(output, gt_image.cuda()).item(),
                        batch_pixel_size)

    output_data = torch.nn.functional.softmax(output, dim=1).cpu().data
    max_probs, predictions = output_data.max(1)

    # Assemble assets to visualize
    assets = {}
    for item in output_dict:
        if 'attn_' in item:
            assets[item] = output_dict[item]
        if 'pred_' in item:
            smax = torch.nn.functional.softmax(output_dict[item], dim=1)
            _, pred = smax.data.max(1)
            assets[item] = pred.cpu().numpy()

    predictions = predictions.numpy()
    assets['predictions'] = predictions
    assets['prob_mask'] = max_probs
    if calc_metrics:
        assets['err_mask'] = calc_err_mask_all(predictions,
                                               gt_image.numpy(),
                                               cfg.DATASET.NUM_CLASSES)

    _iou_acc = fast_hist(predictions.flatten(),
                         gt_image.numpy().flatten(),
                         cfg.DATASET.NUM_CLASSES)

    return assets, _iou_acc
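# eval_minibatch() above also depends on flip_tensor() and resize_tensor().
# The definitions below are assumed stand-ins: flipping along a given tensor
# dimension and a bilinear resize to an explicit (h, w). The project's own
# utilities may differ in interpolation mode or align_corners handling.
import torch
from torch.nn import functional as F


def flip_tensor(x, dim):
    """Flip a tensor along one dimension (dim=3 flips NCHW images left-right)."""
    return torch.flip(x, dims=[dim])


def resize_tensor(x, size):
    """Resize an NCHW tensor to an explicit (h, w) size (assumed bilinear)."""
    return F.interpolate(x, size=tuple(size), mode='bilinear',
                         align_corners=False)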
def eval_minibatch(data, net, criterion, val_loss, calc_metrics, args,
                   val_idx):
    """
    Evaluate a single minibatch of images.
     * calculate metrics
     * dump images

    There are two primary multi-scale inference types:
      1. 'MSCALE', or in-model multi-scale: where the multi-scale iteration
         loop is handled within the model itself (see networks/mscale.py ->
         nscale_forward())
      2. 'multi_scale_inference', where we use averaging to combine scales
    """
    torch.cuda.empty_cache()

    scales = [args.default_scale]
    if args.multi_scale_inference:
        scales.extend([float(x) for x in args.extra_scales.split(',')])
        if val_idx == 0:
            logx.msg(
                f'Using multi-scale inference (AVGPOOL) with scales {scales}')

    # input = torch.Size([1, 3, h, w])
    # gt_image = torch.Size([1, h, w])
    ori_images, gt_image, img_names, scale_float = data
    if len(gt_image.size()) == 4:
        # If the 'gts' entry is actually an image (4 dims) rather than a
        # mask, construct a zero gt for it. This should only happen in test
        # mode, where there is no ground truth.
        gt_image = gt_image.new_zeros(gt_image.size()[:-1])
    assert len(ori_images.size()) == 4 and len(gt_image.size()) == 3
    assert ori_images.size()[2:] == gt_image.size()[1:]
    batch_pixel_size = (ori_images.size(0) * ori_images.size(2) *
                        ori_images.size(3))
    input_size = ori_images.size(2), ori_images.size(3)

    if args.do_flip:
        # By ending with flip=0, we ensure that the images that are dumped
        # out correspond to the unflipped versions. A bit hacky.
        flips = [1, 0]
    else:
        flips = [0]

    # TODO: add crop size / overlap to config.
    max_crop_size = (args.crop_size[0], args.crop_size[1])
    m_h, m_w = max_crop_size
    crop_overlaps = (args.crop_overlap[0], args.crop_overlap[1])
    h_sp = max_crop_size[0] - crop_overlaps[0]
    w_sp = max_crop_size[1] - crop_overlaps[1]
    assert h_sp > 0 and w_sp > 0, \
        'crop size should be larger than crop overlap.'

    output = 0.0
    with torch.no_grad():
        for flip in flips:
            for scale in scales:
                if flip == 1:
                    inputs = flip_tensor(ori_images, 3)
                else:
                    inputs = ori_images

                infer_size = [round(sz * scale) for sz in input_size]

                if scale != 1.0:
                    inputs = resize_tensor(inputs, infer_size)

                n, c, h, w = (inputs.size(0), inputs.size(1),
                              inputs.size(2), inputs.size(3))

                # Number of crop windows along each axis
                if crop_overlaps[0] > 0:
                    h_n = (h - max_crop_size[0] - 1) // h_sp + 2
                else:
                    h_n = (h - 1) // max_crop_size[0] + 1
                if crop_overlaps[1] > 0:
                    w_n = (w - max_crop_size[1] - 1) // w_sp + 2
                else:
                    w_n = (w - 1) // max_crop_size[1] + 1

                # With zero overlap, respace the windows evenly
                if h_n > 1 and crop_overlaps[0] == 0:
                    h_sp = (h - max_crop_size[0]) // (h_n - 1)
                if w_n > 1 and crop_overlaps[1] == 0:
                    w_sp = (w - max_crop_size[1]) // (w_n - 1)

                full_output_dict = None
                weights = None
                for i in range(h_n):
                    for j in range(w_n):
                        # The last row/column of windows is snapped to the
                        # image border.
                        if i != h_n - 1 and j != w_n - 1:
                            h0, h1 = i * h_sp, i * h_sp + m_h
                            w0, w1 = j * w_sp, j * w_sp + m_w
                        elif i != h_n - 1 and j == w_n - 1:
                            h0, h1 = i * h_sp, i * h_sp + m_h
                            w0, w1 = w - m_w, w
                        elif i == h_n - 1 and j != w_n - 1:
                            h0, h1 = h - m_h, h
                            w0, w1 = j * w_sp, j * w_sp + m_w
                        else:
                            h0, h1 = h - m_h, h
                            w0, w1 = w - m_w, w

                        temp_inputs = inputs[:, :, h0:h1, w0:w1]
                        temp_gt_image = gt_image[:, h0:h1, w0:w1]
                        temp_inputs = {'images': temp_inputs,
                                       'gts': temp_gt_image}
                        temp_inputs = {k: v.cuda()
                                       for k, v in temp_inputs.items()}

                        # Expected Model outputs:
                        #   required:
                        #     'pred'  the network prediction,
                        #             shape (1, 19, h, w)
                        #
                        #   optional:
                        #     'pred_*' - multi-scale predictions from
                        #                mscale model
                        #     'attn_*' - multi-scale attentions from
                        #                mscale model
                        output_dict = net(temp_inputs)

                        _pred = output_dict['pred']

                        if full_output_dict is None:
                            full_output_dict = {}
                            weights = _pred.new_zeros((n, 1, h, w))
                            for k, v in output_dict.items():
                                # TODO: need to finish the rest of the keys?
                                if k != 'pred':
                                    continue
                                full_output_dict[k] = v.new_zeros(
                                    *(v.shape[:-2]), h, w)

                        weights[:, :, h0:h1, w0:w1] += _pred.new_ones(1)
                        for k, v in output_dict.items():
                            # TODO: need to finish the rest of the keys?
                            if k != 'pred':
                                continue
                            full_output_dict[k][:, :, h0:h1, w0:w1] += \
                                output_dict[k]

                # Normalize by how many windows covered each pixel
                for k, v in full_output_dict.items():
                    full_output_dict[k] = v / weights

                _pred = full_output_dict['pred']

                # save AVGPOOL style multi-scale output for visualizing
                if not cfg.MODEL.MSCALE:
                    scale_name = fmt_scale('pred', scale)
                    output_dict[scale_name] = _pred

                # resize tensor down to 1.0x scale in order to combine
                # with other scales of prediction
                if scale != 1.0:
                    _pred = resize_tensor(_pred, input_size)

                if flip == 1:
                    output = output + flip_tensor(_pred, 3)
                else:
                    output = output + _pred

    output = output / len(scales) / len(flips)

    assert_msg = 'output_size {} gt_cuda size {}'
    gt_cuda = gt_image.cuda()
    assert_msg = assert_msg.format(output.size()[2:], gt_cuda.size()[1:])
    assert output.size()[2:] == gt_cuda.size()[1:], assert_msg
    assert output.size()[1] == cfg.DATASET.NUM_CLASSES, assert_msg

    # Update loss and scoring datastructure
    if calc_metrics:
        val_loss.update(criterion(output, gt_image.cuda()).item(),
                        batch_pixel_size)

    output_data = torch.nn.functional.softmax(output, dim=1).cpu().data
    max_probs, predictions = output_data.max(1)

    # Assemble assets to visualize
    assets = {}
    for item in output_dict:
        if 'attn_' in item:
            assets[item] = output_dict[item]
        if 'pred_' in item:
            smax = torch.nn.functional.softmax(output_dict[item], dim=1)
            _, pred = smax.data.max(1)
            assets[item] = pred.cpu().numpy()

    predictions = predictions.numpy()
    assets['predictions'] = predictions
    assets['prob_mask'] = max_probs
    if calc_metrics:
        assets['err_mask'] = calc_err_mask_all(predictions,
                                               gt_image.numpy(),
                                               cfg.DATASET.NUM_CLASSES)

    _iou_acc = fast_hist(predictions.flatten(),
                         gt_image.numpy().flatten(),
                         cfg.DATASET.NUM_CLASSES)

    return assets, _iou_acc
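# A small worked example of the crop-grid arithmetic in the sliding-window
# eval_minibatch() above. It assumes the image is at least as large as the
# crop and, for simplicity, omits the zero-overlap stride re-spacing step.
# For an assumed 1024x2048 input with a 768x768 crop and 256-pixel overlap,
# the stride is 512, giving h_n = (1024 - 768 - 1) // 512 + 2 = 2 rows and
# w_n = (2048 - 768 - 1) // 512 + 2 = 4 columns, with the last row/column
# snapped to the image border. The crop_grid() name and numbers are
# illustrative only.
def crop_grid(h, w, crop=(768, 768), overlap=(256, 256)):
    m_h, m_w = crop
    h_sp, w_sp = m_h - overlap[0], m_w - overlap[1]
    h_n = (h - m_h - 1) // h_sp + 2 if overlap[0] > 0 else (h - 1) // m_h + 1
    w_n = (w - m_w - 1) // w_sp + 2 if overlap[1] > 0 else (w - 1) // m_w + 1
    windows = []
    for i in range(h_n):
        for j in range(w_n):
            h0 = i * h_sp if i != h_n - 1 else h - m_h
            w0 = j * w_sp if j != w_n - 1 else w - m_w
            windows.append((h0, h0 + m_h, w0, w0 + m_w))
    return windows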
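# fast_hist() is used above to accumulate a confusion matrix for mIoU
# scoring. A common implementation, sketched here under the assumption that
# ignore labels fall outside [0, num_classes), is a single np.bincount over
# the flattened predictions and labels; the project's own version may mask
# ignore pixels differently.
import numpy as np


def fast_hist(pred, gtruth, num_classes):
    """Accumulate a num_classes x num_classes confusion matrix."""
    # mask out pixels whose label is not a valid class (e.g. ignore_label)
    mask = (gtruth >= 0) & (gtruth < num_classes)
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    return hist.reshape(num_classes, num_classes)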