Example #1
def compute_triplet_loss(model,
                         batch,
                         random_proposal=False,
                         similarity_function=""):  # unused in this snippet

    model.train()
    O_dict = model(batch["images"].cuda())
    O = O_dict["embedding_mask"]

    n, c, h, w = O.shape

    points = batch["points"]
    batch["maskObjects"] = None
    batch["maskClasses"] = None
    batch["maskVoid"] = None

    pointList = au.mask2pointList(points)["pointList"]

    loss = torch.tensor(0.).cuda()
    if len(pointList) == 0:
        return loss

    propDict = au.pointList2propDict(pointList,
                                     batch,
                                     single_point=True,
                                     thresh=0.5)

    Embeddings, Labels = propDict2EmbedLabels(O,
                                              propDict,
                                              random_proposal=random_proposal)

    loss = triplet(Embeddings, torch.LongTensor(Labels))
    return loss
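A minimal usage sketch (not from the original repository): the stub model and random batch below are assumptions standing in for the repo's networks and data loaders, and the repository's `au` helpers, `propDict2EmbedLabels`, and a CUDA device are still required for the call to run.

import torch
import torch.nn as nn

class StubEmbedder(nn.Module):
    """Stand-in model: returns the {"embedding_mask": ...} dict the loss reads."""
    def __init__(self, c_out=64):
        super().__init__()
        self.conv = nn.Conv2d(3, c_out, 1)

    def forward(self, x):
        return {"embedding_mask": self.conv(x)}

batch = {"images": torch.randn(1, 3, 128, 128),
         "points": torch.zeros(1, 128, 128).long()}
loss = compute_triplet_loss(StubEmbedder().cuda(), batch)
print(float(loss))  # 0.0 here: the all-zero point mask yields an empty pointList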
Example #2
    def predict(self, batch, predict_method="counts", **options):
        self.eval()

        n, c, h, w = batch["images"].shape
        x_input = batch["images"].cuda()
        O = self(x_input)
        if predict_method == "counts":
            return {"counts": ms.t2n(torch.sigmoid(O) > 0.5)}

        elif predict_method == "points":
            pred_dict = self.forward_test(x_input, class_threshold=0, peak_threshold=30)
            return pred_dict["points"]

        elif predict_method == "BestDice":
            
            points = self.get_points(batch)
            pointList = au.mask2pointList(points[None])["pointList"]
            if len(pointList) == 0:
                return {"annList":[]}
            pred_dict = au.pointList2BestObjectness(pointList, batch)

            return {"annList":pred_dict["annList"]}

        else:
            return self.forward_test(x_input, class_threshold=0, peak_threshold=30)
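A hedged call-site sketch: `model` is assumed to be an instance of the class above, with `batch` shaped as in the other examples. The method is evaluation-only, so the calls are wrapped in no_grad.

import torch

with torch.no_grad():
    counts = model.predict(batch, predict_method="counts")["counts"]
    anns = model.predict(batch, predict_method="BestDice")["annList"]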
Example #3
    def predict(self, batch, predict_method="counts"):
        self.eval()
        img = batch["images"]
        
        padded_size = (int(np.ceil(img.shape[2]/8)*8), int(np.ceil(img.shape[3]/8)*8))

        p2d = (0, padded_size[1] - img.shape[3], 0, padded_size[0] - img.shape[2])
        img = F.pad(img, p2d)

        dheight = int(np.ceil(img.shape[2]/8))
        dwidth = int(np.ceil(img.shape[3]/8))
        n, c, h, w = img.shape

        lcfcn_pointList = au.mask2pointList(batch["points"])["pointList"]
        counts = np.zeros(self.n_classes-1)
        if len(lcfcn_pointList) == 0:
            return {"blobs": np.zeros((h,w), int), "annList":[], "counts":counts}


        propDict = au.pointList2propDict(lcfcn_pointList,
                                         batch,
                                         proposal_type="sharp",
                                         thresh=0.5)

        aff_mat = torch.pow(self.aff.forward(img.cuda(), True), self.beta)
        trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)

        for _ in range(self.logt):
            trans_mat = torch.matmul(trans_mat, trans_mat)

        # Propagate each proposal mask through the random-walk transition
        # matrix and stack the per-proposal responses.
        # NOTE: in the original snippet the loop result was discarded and
        # `cam_rw` was never defined; the accumulation below is a hedged
        # reconstruction of the intended aggregation.
        mask_rw_list = []
        for prop in propDict["propDict"]:
            mask = prop["annList"][0]["mask"]
            mask = torch.FloatTensor(mask)[None]
            mask = F.pad(mask, p2d)
            mask_arr = F.avg_pool2d(mask, 8, 8)

            mask_vec = mask_arr.view(1, -1)

            mask_rw = torch.matmul(mask_vec.cuda(), trans_mat)
            mask_rw_list += [mask_rw.view(1, dheight, dwidth)]

        cam_rw = torch.stack(mask_rw_list, 1)  # 1 x n_props x dheight x dwidth
        cam_rw = F.interpolate(cam_rw,
                               size=(img.shape[2], img.shape[3]),
                               mode="bilinear")
        _, cam_rw_pred = torch.max(cam_rw, 1)

        res = np.uint8(cam_rw_pred.cpu().data[0])[:h, :w]

        if predict_method == "annList":
            pass  # annList branch left unimplemented in the original
        else:
            return img, res
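The padding arithmetic at the top of this method rounds H and W up to multiples of 8 so the 1/8-resolution affinity grid tiles the image exactly. A self-contained sketch of just that step:

import numpy as np
import torch
import torch.nn.functional as F

img = torch.randn(1, 3, 250, 375)
padded_size = (int(np.ceil(img.shape[2] / 8) * 8), int(np.ceil(img.shape[3] / 8) * 8))
p2d = (0, padded_size[1] - img.shape[3], 0, padded_size[0] - img.shape[2])
img = F.pad(img, p2d)                          # pads right edge by 1, bottom edge by 6
assert img.shape[2:] == torch.Size([256, 376])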
Example #4
def compute_lc_loss(counts, points, probs, probs_log):
    n, k, h, w = probs.size()

    with torch.no_grad():
        annList = au.probs2GtAnnList(probs, points)

    # IMAGE LOSS
    probs_max = probs.view(n, k, h * w).max(2)[0].view(-1)

    loss = F.binary_cross_entropy(probs_max[1:],
                                  (counts.squeeze() != 0).float(),
                                  reduction="sum")
    loss += F.binary_cross_entropy(probs_max[:1],
                                   torch.ones(1).cuda(),
                                   reduction="sum")

    # Point Loss
    loss += F.nll_loss(probs_log, points, ignore_index=0, reduction="sum")

    for ann in annList:

        if ann["status"] == "SP":
            scale = len(ann["gt_pointList"])
            T = 1 - au.probs2splitMask_all(probs,
                                           ann["gt_pointList"])["background"]
            T = 1 - T * au.ann2mask(ann)["mask"]
            loss += scale * F.nll_loss(probs_log,
                                       torch.LongTensor(T).cuda(),
                                       ignore_index=1,
                                       reduction="mean")

        if ann["status"] == "FP":
            T = 1 - au.ann2mask(ann)["mask"]
            loss += F.nll_loss(probs_log,
                               torch.LongTensor(T).cuda()[None],
                               ignore_index=1,
                               reduction="mean")

    # Global loss
    pointList = au.mask2pointList(points)["pointList"]
    if len(pointList) > 1:
        T = au.probs2splitMask_all(probs, pointList)["background"]
        loss += F.nll_loss(probs_log,
                           torch.LongTensor(T).cuda(),
                           ignore_index=1,
                           reduction="mean")
    # print("n_points", len(annList), loss.item())
    return loss
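A shape sketch for calling the loss (hypothetical values; the repository's `au` module and a CUDA device are assumed). `probs` is a softmax over k classes with class 0 as background, and `points` uses 0 for unannotated pixels:

import torch
import torch.nn.functional as F

n, k, h, w = 1, 21, 64, 64
logits = torch.randn(n, k, h, w).cuda()
probs = F.softmax(logits, 1)
probs_log = F.log_softmax(logits, 1)
points = torch.zeros(n, h, w).long().cuda()   # no clicks in this toy batch
counts = torch.zeros(k - 1, 1).cuda()          # per-class object counts
loss = compute_lc_loss(counts, points, probs, probs_log)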
Example #5
def metric_base(O, batch, pointList=None):
    n, c, h, w = O.shape

    if pointList is None:
        points = batch["points"]
        batch["maskObjects"] = None
        batch["maskClasses"] = None
        batch["maskVoid"] = None

        pointList = au.mask2pointList(points)["pointList"]

    if len(pointList) == 0:
        return None

    if "single_point" in batch:
        single_point = True
    else:
        single_point = False

    propDict = au.pointList2propDict(pointList,
                                     batch,
                                     single_point=single_point,
                                     thresh=0.5)
    background = propDict["background"]

    propDict = propDict["propDict"]

    yList = []
    xList = []
    for p in pointList:
        yList += [p["y"]]
        xList += [p["x"]]

    return {
        "xList": xList,
        "yList": yList,
        "background": background,
        "propDict": propDict
    }
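The returned fields mirror what compute_metric_loss_mean (next example) consumes: per-click seed coordinates plus proposals and background. A hedged sketch of reading them, assuming `O` and `batch` exist as above:

base = metric_base(O, batch)
if base is not None:
    # one embedding column per annotated click
    fg_seeds = O[:, :, base["yList"], base["xList"]]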
Example #6
def compute_metric_loss_mean(O, batch, random_proposal=False):

    n, c, h, w = O.shape

    similarity_function = au.log_pairwise_mean

    points = batch["points"]
    batch["maskObjects"] = None
    batch["maskClasses"] = None
    batch["maskVoid"] = None

    pointList = au.mask2pointList(points)["pointList"]

    loss = torch.tensor(0.).cuda()
    if len(pointList) == 0:
        return loss

    if "single_point" in batch:
        single_point = True
    else:
        single_point = False

    propDict = au.pointList2propDict(pointList,
                                     batch,
                                     single_point=single_point,
                                     thresh=0.5)

    background = propDict["background"]

    propDict = propDict["propDict"]

    yList = []
    xList = []
    for p in pointList:
        yList += [p["y"]]
        xList += [p["x"]]

    fg_seeds = O[:, :, yList, xList]
    n_seeds = fg_seeds.shape[-1]
    prop_mask = np.zeros((h, w))

    for i in range(n_seeds):
        annList = propDict[i]["annList"]

        if len(annList) == 0:
            mask = np.zeros(points.squeeze().shape)
            mask[propDict[i]["point"]["y"], propDict[i]["point"]["x"]] = 1
        else:

            if random_proposal:
                ann_i = np.random.randint(0, len(annList))
                mask = annList[ann_i]["mask"]
            else:
                mask = annList[0]["mask"]

        mask_ind = np.where(mask)
        prop_mask[mask != 0] = (i + 1)

        f_A = fg_seeds[:, :, [i]]

        # Positive Embeddings
        n_pixels = mask_ind[0].shape[0]
        P_ind = np.random.randint(0, n_pixels, 100)
        yList = mask_ind[0][P_ind]
        xList = mask_ind[1][P_ind]
        fg_P = O[:, :, yList, xList]

        ap = -torch.log(similarity_function(f_A, fg_P))
        loss += ap.mean()

        # Get Negatives
        if n_seeds > 1:
            N_ind = [j for j in range(n_seeds) if j != i]
            f_N = fg_seeds[:, :, N_ind]
            an = -torch.log(1. - similarity_function(f_A, f_N))
            loss += an.mean()

    # Extract background seeds
    bg = np.where(background.squeeze())

    n_pixels = bg[0].shape[0]
    bg_ind = np.random.randint(0, n_pixels, n_seeds)
    yList = bg[0][bg_ind]
    xList = bg[1][bg_ind]
    f_A = O[:, :, yList, xList]

    bg_ind = np.random.randint(0, n_pixels, 100)
    yList = bg[0][bg_ind]
    xList = bg[1][bg_ind]
    f_P = O[:, :, yList, xList]

    # BG seeds towards BG pixels, BG seeds away from FG seeds
    ap = -torch.log(similarity_function(f_A[:, :, None], f_P[:, :, :, None]))
    an = -torch.log(1. - similarity_function(f_A[:, :, None],
                                             fg_seeds[:, :, :, None]))

    loss += ap.mean()
    loss += an.mean()

    if batch["dataset"][0] in ("cityscapes", "coco2014"):
        n_max = 6
    else:
        n_max = 12

    if f_A.shape[2] < n_max:
        with torch.no_grad():
            diff = similarity_function(
                O.view(1, c, -1)[:, :, :, None],
                torch.cat([fg_seeds, f_A], 2)[:, :, None])
            labels = diff.max(2)[1] + 1
            labels = labels <= n_seeds
            labels = labels.squeeze().reshape(h, w)
            bg = labels.cpu().long() * torch.from_numpy(background)
            # ms.images(labels.cpu().long()*torch.from_numpy(background))

        # Extract false positive pixels
        bg_ind = np.where(bg.squeeze())
        n_P = bg_ind[0].shape[0]
        if n_P != 0:
            A_ind = np.random.randint(0, n_P, n_seeds)
            f_P = O[:, :, bg_ind[0][A_ind], bg_ind[1][A_ind]]

            ap = -torch.log(
                similarity_function(f_A[:, :, None], f_P[:, :, :, None]))
            an = -torch.log(1. - similarity_function(f_P[:, :, None],
                                                     fg_seeds[:, :, :, None]))

            loss += ap.mean()
            loss += an.mean()


    return loss / max(n_seeds, 1)
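The attract/repel pattern used throughout this loss, reduced to toy tensors. The `sim` below is a stand-in similarity squashed into (0, 1); it is an assumption, not the repository's au.log_pairwise_mean:

import torch

def sim(a, b):
    # stand-in pairwise similarity in (0, 1)
    return torch.sigmoid((a * b).sum(1))

f_A = torch.randn(1, 64, 1)     # anchor seed embedding
fg_P = torch.randn(1, 64, 100)  # pixels sampled from the anchor's proposal
f_N = torch.randn(1, 64, 5)     # the other seeds

loss = (-torch.log(sim(f_A, fg_P))).mean()             # pull positives toward the seed
loss = loss + (-torch.log(1. - sim(f_A, f_N))).mean()  # push other seeds away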
Example #7
def OneHeadLoss_prototypes(model, batch, visualize=False):
    n, c, h, w = batch["images"].shape

    model.eval()
    O_dict = model(batch["images"].cuda())
    O = O_dict["embedding_mask"]

    loss = torch.tensor(0.).cuda()

    base_dict = helpers.metric_base(O, batch)
    if base_dict is None:
        return loss

    points = batch["points"]
    yList = base_dict["yList"]
    xList = base_dict["xList"]
    propDict = base_dict["propDict"]
    background = base_dict["background"]

    # foreground = distance_transform_cdt(1 - background)
    ###################################
    n, c, h, w = O.shape

    fg_seeds = O[:, :, yList, xList]
    n_seeds = fg_seeds.shape[-1]

    for i in range(n_seeds):
        annList = propDict[i]["annList"]

        if len(annList) == 0:
            mask = np.zeros(points.squeeze().shape)
            mask[propDict[i]["point"]["y"], propDict[i]["point"]["x"]] = 1
        else:
            mask = annList[0]["mask"]

        mask_ind = np.where(mask)

        f_A = fg_seeds[:, :, [i]]

        # Positive Embeddings
        fg_P = O[:, :, mask_ind[0], mask_ind[1]]
        ap = -torch.log(au.log_pairwise_sum(f_A, fg_P))
        loss += ap.mean()

        # Get Negatives
        mask_ind = np.where(1 - mask)
        f_N = O[:, :, mask_ind[0], mask_ind[1]]
        an = -torch.log(1. - au.log_pairwise_sum(f_A, f_N))
        loss += an.mean()

    n_sp = 0
    blobs, categoryDict, propDict = models.pairwise.prototype_predict(
        model,
        batch,
        pointList=au.mask2pointList(points)["pointList"],
        visualize=False)
    if background.mean() != 1:
        bg_dict = helpers.get_bg_dict(base_dict["background"])
        test_mask = np.zeros(base_dict["background"].shape).squeeze()

        f_P = O[:, :, bg_dict["mask_pos"][0], bg_dict["mask_pos"][1]]
        f_N = O[:, :, bg_dict["mask_neg"][0], bg_dict["mask_neg"][1]]

        for y, x in zip(bg_dict["yList"], bg_dict["xList"]):
            # print(y, x)
            n_sp += 1

            test_mask[y, x] = 1
            f_A = O[:, :, [y], [x]]

            # Positive Embeddings

            ap = -torch.log(au.log_pairwise_sum(f_A, f_P))
            loss += ap.mean()

            # Get Negatives
            an = -torch.log(1. - au.log_pairwise_sum(f_A, f_N))
            loss += an.mean()

        # Refinement

        ind = np.where((blobs != 0).squeeze() & (background == 1).squeeze())
        f_P = O[:, :, ind[0], ind[1]]
        if ind[0].size != 0:
            for y, x in zip(bg_dict["yList"], bg_dict["xList"]):
                test_mask[y, x] = 1
                f_A = O[:, :, [y], [x]]

                # Positive Embeddings
                ap = -torch.log(au.log_pairwise_sum(f_A, f_P))
                loss += ap.mean()

    # Connected components: push spurious components of each blob away from its seed

    for l in np.unique(blobs):
        if l == 0:
            continue
        cc = label(blobs == l)

        point = propDict["propDict"][l - 1]["point"]
        y, x = point["y"], point["x"]
        true = cc[y, x]
        f_A = O[:, :, [y], [x]]
        for lc in np.unique(cc):
            if lc == 0 or lc == true:
                continue
            else:
                # Get Negatives
                n_sp += 1
                ind = np.where((cc == lc).squeeze())
                f_N = O[:, :, ind[0], ind[1]]
                an = -torch.log(1. - au.log_pairwise_sum(f_A, f_N))
                loss += an.max()

    return loss / max(n_seeds + n_sp, 1)
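A toy illustration of the connected-component penalty at the end of this loss, assuming `label` is skimage.measure.label (consistent with its use above): pixels of any component that does not contain the annotated click become negatives.

import numpy as np
from skimage.measure import label

blob = np.zeros((8, 8), int)
blob[1:3, 1:3] = 1              # component containing the click at (1, 1)
blob[6:8, 6:8] = 1              # spurious second component
cc = label(blob == 1)
true_id = cc[1, 1]
spurious = [lc for lc in np.unique(cc) if lc not in (0, true_id)]
print(spurious)                 # [2] -> these pixels would be pushed away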
Example #8
def wiseaffinity_loss(model, batch):
    # model.lcfcn
    images = batch["images"].cuda()
    n, c, h, w = images.shape
    pointList = au.mask2pointList(batch["points"])["pointList"]

    loss = torch.tensor(0.).cuda()
    loss_bg = torch.tensor(0.).cuda()
    loss_fg = torch.tensor(0.).cuda()

    if len(pointList) == 0:
        return loss

    propDict = au.pointList2propDict(pointList,
                                     batch,
                                     single_point=True,
                                     thresh=0.5)
    propList = propDict["propDict"]

    pred_dict = model.lcfcn.predict(batch, predict_method="pointList")
    blobs = pred_dict["blobs"]
    probs = pred_dict["probs"]
    blobs_components = morph.label(blobs != 0)
    image_pad = ms.pad_image(images)
    _, _, dheight, dwidth = image_pad.shape
    trans_mat = model.aff.forward_trans(image_pad)

    # bg_probs = probs[:,[0]]
    for i in range(len(propList)):
        prop = propList[i]
        if not len(prop["annList"]):
            continue
        proposal_mask = torch.LongTensor(prop["annList"][0]["mask"]).cuda()
        # proposal_mask = F.interpolate(proposal_mask, size=(dheight//8, dwidth//8))
        y, x = prop["point"]["y"], prop["point"]["x"]

        category = blobs[:, y, x][0]
        instance = blobs_components[:, y, x][0]
        blob_mask_ind = blobs_components == instance

        O = torch.FloatTensor(probs[:, [0, category]]).detach()
        O[:, 1] = O[:, 1] * torch.FloatTensor(blob_mask_ind.astype(float))
        O[:, 1] = O[:, 1].clamp(1e-10)

        O_scale = F.interpolate(O, size=(dheight // 8, dwidth // 8))
        O_scale = O_scale.view(1, 2, -1).cuda()

        O_rw = torch.matmul(O_scale, trans_mat)
        O_rw = O_rw.view(1, 2, dheight // 8, dwidth // 8)
        O_final = F.interpolate(O_rw, size=(h, w))

        S_log = F.log_softmax(O_final, 1)

        loss_bg += F.nll_loss(S_log,
                              proposal_mask[None],
                              ignore_index=1,
                              reduction="mean")
        loss_fg += F.nll_loss(S_log,
                              proposal_mask[None],
                              ignore_index=0,
                              reduction="mean")

    return (loss_bg + loss_fg) / len(propList)
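A shape sketch of the random-walk refinement step above, with a toy transition matrix standing in for model.aff.forward_trans (the real matrix is column-normalized, which the softmax below imitates):

import torch
import torch.nn.functional as F

dh, dw = 32, 40
O = torch.rand(1, 2, 256, 320)                               # [background, category] maps
O_scale = F.interpolate(O, size=(dh, dw)).view(1, 2, -1)     # 1 x 2 x (dh*dw)
trans_mat = torch.softmax(torch.randn(dh * dw, dh * dw), 0)  # columns sum to 1
O_rw = torch.matmul(O_scale, trans_mat).view(1, 2, dh, dw)
O_final = F.interpolate(O_rw, size=(256, 320))               # back to image resolution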
Example #9
def MaskRCNNLoss(model,
                 batch,
                 prm_points=False,
                 true_annList=False,
                 visualize=False):
    """
    - images: [batch, H, W, C]
    - image_metas: [batch, size of image meta]
    - rpn_match: [batch, N] Integer (1=positive anchor, -1=negative, 0=neutral)
    - rpn_bbox: [batch, N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
    - gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs
    - gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)]
    - gt_masks: [batch, height, width, MAX_GT_INSTANCES]. The height and width
                are those of the image unless use_mini_mask is True, in which
                case they are defined in MINI_MASK_SHAPE.
    """
    # 1. Data Generator
    ###################
    images = batch["images"].cuda()
    if true_annList:
        gt_annList = base_dataset.batch2annList(batch)
    else:
        if prm_points:
            fname = "{}/prm{}.pkl".format(batch["path"][0], batch["name"][0])
            points = torch.LongTensor(ms.load_pkl(fname))[None].cuda()

        else:
            points = batch["points"].cuda()

        pointList = au.mask2pointList(points)["pointList"]

        if len(pointList) == 0:
            return 0

        gt_annList = au.pointList2BestObjectness(
            pointList, batch, proposal_type="sharp")["annList"]

    if len(gt_annList) == 0:
        return 0

    # ms.images(batch["images"], au.annList2mask(gt_annList)["mask"], denorm=1)
    # image_id = batch["name"][0].replace("_","")
    image_id = 1

    image, image_metas, gt_class_ids, gt_boxes, gt_masks = model.load_image_gt(
        images, image_id, gt_annList, augment=False)
    h, w = images.shape[-2:]
    rpn_match, rpn_bbox = build_rpn_targets(
        (h, w, 3), model.anchors, gt_class_ids, gt_boxes, model.config)


    # 2. RPN Stuff
    ##############

    # If more instances than fits in the array, sub-sample from them.
    if gt_boxes.shape[0] > model.config.MAX_GT_INSTANCES:
        ids = np.random.choice(np.arange(gt_boxes.shape[0]),
                               model.config.MAX_GT_INSTANCES,
                               replace=False)
        gt_class_ids = gt_class_ids[ids]
        gt_boxes = gt_boxes[ids]
        gt_masks = gt_masks[:, :, ids]

    # Add to batch
    rpn_match = rpn_match[:, np.newaxis]
    images = model.mold_image(image.astype(np.float32), model.config)

    # Convert
    images = torch.from_numpy(images.transpose(2, 0, 1)).float()[None].cuda()
    image_metas = torch.from_numpy(image_metas)[None].cuda()
    rpn_match = torch.from_numpy(rpn_match)[None].cuda()
    rpn_bbox = torch.from_numpy(rpn_bbox).float()[None].cuda()
    gt_class_ids = torch.from_numpy(gt_class_ids)[None].cuda().int()
    gt_boxes = torch.from_numpy(gt_boxes).float()[None].cuda()
    gt_masks = torch.from_numpy(gt_masks.astype(int).transpose(
        2, 0, 1)).float()[None].cuda()


    (rpn_class_logits, rpn_pred_bbox, target_class_ids, mrcnn_class_logits,
     target_deltas, mrcnn_bbox, target_mask, mrcnn_mask) = model.forward_train(
         {"images": images,
          "gt_class_ids": gt_class_ids,
          "gt_boxes": gt_boxes,
          "gt_masks": gt_masks})

    # 3. Compute Losses
    ###################

    rpn_class_loss = compute_rpn_class_loss(rpn_match, rpn_class_logits)
    rpn_bbox_loss = compute_rpn_bbox_loss(rpn_bbox, rpn_match, rpn_pred_bbox)

    mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids,
                                                mrcnn_class_logits)
    mrcnn_bbox_loss = compute_mrcnn_bbox_loss(target_deltas, target_class_ids,
                                              mrcnn_bbox)
    mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, target_class_ids,
                                              mrcnn_mask)

    loss = rpn_class_loss + rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + mrcnn_mask_loss

    return loss
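A hedged training-step sketch: `model`, `optimizer`, and `loader` are assumed to exist. The guard reflects that the function returns a plain 0 (not a tensor) when no points or ground-truth annotations survive:

import torch

for batch in loader:
    loss = MaskRCNNLoss(model, batch)
    if isinstance(loss, torch.Tensor):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()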