Exemple #1
0
    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute focal and dice losses for the predicted segmentation masks.

        Args:
            outputs: dict of model outputs; must contain "pred_masks" with
                per-query mask logits.
            targets: list of per-sample dicts, each holding a "masks" tensor.
            indices: matcher result — list of (src, tgt) index-tensor pairs.
            num_boxes: normalization factor shared by both losses.

        Returns:
            dict with "loss_mask" (sigmoid focal loss) and "loss_dice".
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]
        # Pad target masks to a common shape; `valid` marks real (non-pad) area.
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = NestedTensor.from_tensor_list(
            [t["masks"] for t in targets]).decompose()
        target_masks = target_masks.to(src_masks)

        # Keep only predictions matched to a target, then upsample them to the
        # target resolution (trilinear — masks appear volumetric here).
        src_masks = src_masks[src_idx]
        src_masks = misc_ops.interpolate(src_masks[:, None],
                                         size=target_masks.shape[-3:],
                                         mode="trilinear",
                                         align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks,
                                            num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses
Exemple #2
0
def train(args, model, criterion, data_loader, optimizer, device, epoch, max_norm, scheduler, data_loader_eval):
    """Run one training epoch of the captioning model.

    Args:
        args: namespace with log_interval, scheduler_updates, eval_interval,
            do_eval.
        model / criterion: modules to train.
        data_loader: yields (images, masks, caps, cap_masks) batches.
        optimizer / scheduler: optimization state; scheduler is stepped every
            args.scheduler_updates batches.
        device: target device for the batch tensors.
        epoch: current epoch index (unused here; kept for caller compatibility).
        max_norm: gradient-clipping threshold.
        data_loader_eval: loader passed to evaluate() every eval_interval
            batches when args.do_eval is set.

    Returns:
        (mean per-batch loss for the epoch, scheduler).
    """
    model.train()
    criterion.train()
    epoch_loss = 0.0
    total = len(data_loader)
    with tqdm.tqdm(total=total) as pbar:
        for i, (images, masks, caps, cap_masks) in enumerate(data_loader):
            samples = NestedTensor(images, masks).to(device)
            caps = caps.to(device)
            cap_masks = cap_masks.to(device)

            # Teacher forcing: feed tokens shifted right, predict the next one.
            outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
            loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])
            loss_value = loss.item()
            if i % args.log_interval == 0:
                wandb.log({"loss": loss_value})
            epoch_loss += loss_value

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            pbar.update(1)

            if i % args.scheduler_updates == 0:
                scheduler.step()
            if i % args.eval_interval == 0 and args.do_eval:
                eval_loss = evaluate(model, criterion, data_loader_eval, device)
                print("Eval loss after {} batches is {}".format(i, eval_loss))
                # BUG FIX: evaluate() switches model/criterion to eval mode;
                # without restoring train mode, dropout/batchnorm stayed
                # disabled for the remainder of the epoch.
                model.train()
                criterion.train()
    return epoch_loss / total, scheduler
Exemple #3
0
 def forward(self, tensor_list: NestedTensor):
     """Run the backbone body and pair every feature map with a padding mask
     resized to that map's spatial resolution."""
     feature_maps = self.body(tensor_list.tensors)
     out: Dict[str, NestedTensor] = {}
     for level_name, feature in feature_maps.items():
         padding_mask = tensor_list.mask
         assert padding_mask is not None
         # Downsample the boolean mask to this feature level's (H, W).
         resized = F.interpolate(padding_mask[None].float(),
                                 size=feature.shape[-2:])
         out[level_name] = NestedTensor(feature, resized.to(torch.bool)[0])
     return out
Exemple #4
0
 def forward(self, tensor_list):
     """Run the backbone and attach a resized padding mask to each feature map.

     Args:
         tensor_list: NestedTensor with batched images (.tensors) and a
             boolean padding mask (.mask).

     Returns:
         OrderedDict mapping feature level name -> NestedTensor holding the
         feature map and its mask downsampled to the same spatial size.
     """
     xs = self.body(tensor_list.tensors)
     out = OrderedDict()
     for name, x in xs.items():
         # Resize the padding mask to this feature level's (H, W).
         mask = F.interpolate(tensor_list.mask[None].float(),
                              size=x.shape[-2:]).bool()[0]
         out[name] = NestedTensor(x, mask)
     # Leftover debug prints of the full output dict removed.
     return out
Exemple #5
0
 def forward(self, tensor_list):
     """Run a single-output backbone and pair its feature with a resized mask.

     Unlike the multi-scale variant, self.backbone returns one tensor, so the
     result dict has a single entry under the key "default".

     Args:
         tensor_list: NestedTensor; .tensors is cast to long before the
             backbone call (presumably token/index input — TODO confirm),
             .mask is the padding mask.

     Returns:
         OrderedDict with one NestedTensor under "default".
     """
     x = self.backbone(tensor_list.tensors.long())
     out = OrderedDict()
     # Feature is volumetric here: the mask matches the last THREE dims.
     mask = F.interpolate(tensor_list.mask[None].float(),
                          size=x.shape[-3:]).bool()[0]
     out["default"] = NestedTensor(x, mask)
     return out
Exemple #6
0
def evaluate(model, criterion, data_loader, device):
    """Compute the mean caption loss over a validation loader.

    Switches model and criterion to eval mode, disables gradient tracking
    (no backward pass is needed, so this saves memory and time), logs the
    mean loss to wandb under "eval_loss", and returns it.
    """
    model.eval()
    criterion.eval()
    validation_loss = 0.0
    total = len(data_loader)
    # FIX: run under no_grad — the original tracked gradients for every
    # validation batch for no reason.
    with torch.no_grad(), tqdm.tqdm(total=total) as pbar:
        for images, masks, caps, cap_masks in data_loader:
            samples = NestedTensor(images, masks).to(device)
            caps = caps.to(device)
            cap_masks = cap_masks.to(device)
            outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
            loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])
            validation_loss += loss.item()
            pbar.update(1)
    wandb.log({"eval_loss": validation_loss / total})
    return validation_loss / total
    def forward(self, samples: NestedTensor):
        """Detection + segmentation forward over a wrapped DETR (self.detr).

        Runs the base DETR pipeline (backbone -> transformer -> class/box
        heads), then predicts one segmentation mask per query from the
        transformer memory via bounding-box attention.

        Returns:
            dict with "pred_logits", "pred_boxes", "pred_masks" and, if
            self.detr.aux_loss, "aux_outputs" for intermediate decoder layers.
        """
        if not isinstance(samples, NestedTensor):
            samples = NestedTensor.from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        bs = features[-1].tensors.shape[0]

        # Use only the last (highest-level) feature map and its padding mask.
        src, mask = features[-1].decompose()
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask,
                                           self.detr.query_embed.weight,
                                           pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        # Final decoder layer provides the main predictions.
        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord[-1]
        }
        if self.detr.aux_loss:
            # One aux entry per intermediate decoder layer (all but the last).
            out["aux_outputs"] = [{
                "pred_logits": a,
                "pred_boxes": b
            } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

        # FIXME h_boxes takes the last one computed, keep this in mind
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
        seg_masks = self.mask_head(src_proj, bbox_mask, [features[-1].tensors])
        # Reshape flat mask output to (bs, num_queries, D, H, W); the three
        # trailing dims suggest volumetric masks — TODO confirm.
        outputs_seg_masks = seg_masks.view(
            bs,
            self.detr.num_queries,
            seg_masks.shape[-3],
            seg_masks.shape[-2],
            seg_masks.shape[-1],
        )

        out["pred_masks"] = outputs_seg_masks
        return out
Exemple #8
0
 def forward(self, samples: NestedTensor):
     """DETR forward pass: backbone -> transformer -> class/box heads.

     Args:
         samples: NestedTensor of batched images plus padding mask; a plain
             tensor list is converted first.

     Returns:
         dict with "pred_logits" and "pred_boxes" from the final decoder
         layer, plus "aux_outputs" (one per intermediate layer) when
         self.aux_loss is set.
     """
     # FIX: removed leftover debug prints; one accessed samples.tensors
     # BEFORE this conversion and crashed on raw tensor-list input.
     if not isinstance(samples, NestedTensor):
         samples = NestedTensor.from_tensor_list(samples)
     features, pos = self.backbone(samples)
     src, mask = features[-1].decompose()
     # hs: (num_decoder_layers, bs, num_queries, hidden_dim)
     hs = self.transformer(self.input_proj(src), mask,
                           self.query_embed.weight, pos[-1])[0]
     outputs_class = self.class_embed(hs)
     outputs_coord = self.bbox_embed(hs).sigmoid()
     out = {
         "pred_logits": outputs_class[-1],
         "pred_boxes": outputs_coord[-1]
     }
     if self.aux_loss:
         # One aux entry per intermediate decoder layer (all but the last).
         out["aux_outputs"] = [{
             "pred_logits": a,
             "pred_boxes": b
         } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
     return out