def loss_masks(self, outputs, targets, indices, num_boxes):
    """Compute the mask losses (sigmoid focal loss + dice loss).

    Args:
        outputs: model output dict; must contain "pred_masks".
        targets: list of per-image target dicts, each with a "masks" tensor.
        indices: matcher output pairing predictions to targets.
        num_boxes: normalization factor for the losses.

    Returns:
        dict with "loss_mask" (focal) and "loss_dice" entries.
    """
    assert "pred_masks" in outputs
    src_idx = self._get_src_permutation_idx(indices)
    tgt_idx = self._get_tgt_permutation_idx(indices)
    src_masks = outputs["pred_masks"]
    # TODO use valid to mask invalid areas due to padding in loss
    target_masks, valid = NestedTensor.from_tensor_list(
        [t["masks"] for t in targets]).decompose()
    target_masks = target_masks.to(src_masks)
    # Keep only the matched predictions, then upsample them to the target
    # mask resolution. "trilinear" + shape[-3:] suggests 3-D (volumetric)
    # masks — confirm against the dataset.
    src_masks = src_masks[src_idx]
    src_masks = misc_ops.interpolate(src_masks[:, None],
                                     size=target_masks.shape[-3:],
                                     mode="trilinear",
                                     align_corners=False)
    src_masks = src_masks[:, 0].flatten(1)
    target_masks = target_masks[tgt_idx].flatten(1)
    losses = {
        "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
        "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
    }
    return losses
def train(args, model, criterion, data_loader, optimizer, device, epoch,
          max_norm, scheduler, data_loader_eval):
    """Run one training epoch.

    Args:
        args: namespace with log_interval, scheduler_updates, eval_interval,
            do_eval.
        model, criterion: network and loss module (set to train mode here).
        data_loader: yields (images, masks, caps, cap_masks) batches.
        optimizer, device, epoch, max_norm, scheduler: usual training knobs;
            max_norm is the gradient-clipping threshold.
        data_loader_eval: loader used for periodic mid-epoch evaluation.

    Returns:
        (mean epoch loss, scheduler) tuple.
    """
    model.train()
    criterion.train()
    epoch_loss = 0.0
    total = len(data_loader)
    with tqdm.tqdm(total=total) as pbar:
        for i, (images, masks, caps, cap_masks) in enumerate(data_loader):
            samples = NestedTensor(images, masks).to(device)
            caps = caps.to(device)
            cap_masks = cap_masks.to(device)
            # Teacher forcing: feed caption prefix, predict shifted caption.
            outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
            loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])
            loss_value = loss.item()
            if i % args.log_interval == 0:
                wandb.log({"loss": loss_value})
            epoch_loss += loss_value
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            pbar.update(1)
            if i % args.scheduler_updates == 0:
                scheduler.step()
            if i % args.eval_interval == 0 and args.do_eval:
                eval_loss = evaluate(model, criterion, data_loader_eval, device)
                # evaluate() flips model/criterion into eval mode; restore
                # train mode so dropout/batchnorm behave correctly for the
                # remaining batches (previously left in eval mode — bug).
                model.train()
                criterion.train()
                print("Eval loss after {} batches is {}".format(i, eval_loss))
    return epoch_loss / total, scheduler
def forward(self, tensor_list: NestedTensor):
    """Run the backbone body and pair each feature map with a padding mask
    resized (nearest interpolation) to that map's spatial shape."""
    features = self.body(tensor_list.tensors)
    result: Dict[str, NestedTensor] = {}
    for level, feature in features.items():
        padding_mask = tensor_list.mask
        assert padding_mask is not None
        resized = F.interpolate(padding_mask[None].float(),
                                size=feature.shape[-2:])
        result[level] = NestedTensor(feature, resized.to(torch.bool)[0])
    return result
def forward(self, tensor_list):
    """Run the backbone body and attach a padding mask, downsampled to each
    feature map's spatial size, as a NestedTensor per feature level.

    Removed leftover debug prints ("backbone out" / print(out)) that spammed
    stdout on every forward pass.
    """
    xs = self.body(tensor_list.tensors)
    out = OrderedDict()
    for name, x in xs.items():
        # Resize the boolean padding mask to match this feature level.
        mask = F.interpolate(tensor_list.mask[None].float(),
                             size=x.shape[-2:]).bool()[0]
        out[name] = NestedTensor(x, mask)
    return out
def forward(self, tensor_list):
    """Run the backbone on integer-cast inputs and return a single "default"
    feature level with a padding mask resized to the feature's last three
    dims (3-D features — confirm against the backbone).

    Cleaned up: removed commented-out multi-level loop and the redundant
    xs/name intermediates.
    """
    x = self.backbone(tensor_list.tensors.long())
    out = OrderedDict()
    mask = F.interpolate(tensor_list.mask[None].float(),
                         size=x.shape[-3:]).bool()[0]
    out["default"] = NestedTensor(x, mask)
    return out
def evaluate(model, criterion, data_loader, device):
    """Compute the mean validation loss over data_loader and log it to wandb.

    Wrapped in torch.no_grad(): the original built autograd graphs during
    validation, wasting memory and compute for no benefit.

    Note: leaves model/criterion in eval mode; callers that keep training
    must switch back to train mode themselves.
    """
    model.eval()
    criterion.eval()
    validation_loss = 0.0
    total = len(data_loader)
    with torch.no_grad():
        with tqdm.tqdm(total=total) as pbar:
            for images, masks, caps, cap_masks in data_loader:
                samples = NestedTensor(images, masks).to(device)
                caps = caps.to(device)
                cap_masks = cap_masks.to(device)
                outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
                loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:])
                validation_loss += loss.item()
                pbar.update(1)
    wandb.log({"eval_loss": validation_loss / total})
    return validation_loss / total
def forward(self, samples: NestedTensor):
    """Segmentation forward pass: run the wrapped DETR detector, then predict
    a mask per query from the decoder states and encoder memory.

    Returns a dict with "pred_logits", "pred_boxes", "pred_masks" and,
    when aux_loss is on, "aux_outputs" for the intermediate decoder layers.
    """
    # Accept a raw list of tensors for convenience.
    if not isinstance(samples, NestedTensor):
        samples = NestedTensor.from_tensor_list(samples)
    features, pos = self.detr.backbone(samples)
    bs = features[-1].tensors.shape[0]
    # Use the last (coarsest) feature level for the transformer.
    src, mask = features[-1].decompose()
    src_proj = self.detr.input_proj(src)
    hs, memory = self.detr.transformer(src_proj, mask,
                                       self.detr.query_embed.weight, pos[-1])
    outputs_class = self.detr.class_embed(hs)
    outputs_coord = self.detr.bbox_embed(hs).sigmoid()
    # Final decoder layer provides the main predictions.
    out = {
        "pred_logits": outputs_class[-1],
        "pred_boxes": outputs_coord[-1]
    }
    if self.detr.aux_loss:
        # One aux dict per intermediate decoder layer.
        out["aux_outputs"] = [{
            "pred_logits": a,
            "pred_boxes": b
        } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
    # FIXME h_boxes takes the last one computed, keep this in mind
    bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
    seg_masks = self.mask_head(src_proj, bbox_mask, [features[-1].tensors])
    # Reshape flat mask batch back to (bs, num_queries, D, H, W) —
    # three trailing spatial dims suggest volumetric masks; confirm.
    outputs_seg_masks = seg_masks.view(
        bs,
        self.detr.num_queries,
        seg_masks.shape[-3],
        seg_masks.shape[-2],
        seg_masks.shape[-1],
    )
    out["pred_masks"] = outputs_seg_masks
    return out
def forward(self, samples: NestedTensor):
    """DETR detection forward pass.

    Returns a dict with "pred_logits" and "pred_boxes" (final decoder layer)
    plus "aux_outputs" per intermediate layer when aux_loss is enabled.

    Removed debug prints; the old print(samples.tensors.shape) ran BEFORE
    the NestedTensor conversion and crashed when a raw tensor list was
    passed in.
    """
    if not isinstance(samples, NestedTensor):
        samples = NestedTensor.from_tensor_list(samples)
    features, pos = self.backbone(samples)
    src, mask = features[-1].decompose()
    # hs: (num_decoder_layers, bs, num_queries, hidden_dim)
    hs = self.transformer(self.input_proj(src), mask,
                          self.query_embed.weight, pos[-1])[0]
    outputs_class = self.class_embed(hs)
    outputs_coord = self.bbox_embed(hs).sigmoid()
    out = {
        "pred_logits": outputs_class[-1],
        "pred_boxes": outputs_coord[-1]
    }
    if self.aux_loss:
        out["aux_outputs"] = [{
            "pred_logits": a,
            "pred_boxes": b
        } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
    return out