def forward(self, emb0, emb1, valid, summ_writer, suffix=''):
    """Compute 2D embedding losses between two feature maps.

    Adds a margin (metric-learning) loss from ``self.compute_margin_loss`` and
    a masked squared-L2 loss between ``emb0`` and ``emb1``, logging both and a
    PCA view of the two maps through ``summ_writer``.

    Args:
        emb0, emb1: feature maps unpacked as ``B, C, H, W``.
        valid: mask flattened alongside the embeddings; presumably
            B x 1 x H x W (it is reshaped to B x H*W x 1) -- TODO confirm.
        summ_writer: summary writer; used unconditionally.
        suffix: appended to every loss/summary tag.

    Returns:
        (total_loss, emb1)
    """
    loss = torch.tensor(0.0).cuda()
    B, C, H, W = list(emb0.shape)

    def _flatten(x, chans):
        # B x chans x H x W  ->  B x (H*W) x chans
        return x.permute(0, 2, 3, 1).reshape(B, H * W, chans)

    emb0_vec = _flatten(emb0, C)
    emb1_vec = _flatten(emb1, C)
    valid_vec = _flatten(valid, 1)

    # num_samples are taken from each flattened map
    assert (self.num_samples < (B * H * W))
    margin = self.compute_margin_loss(B, C, H, W, emb0_vec, emb1_vec, valid_vec,
                                      'all', True, summ_writer)
    loss = utils_misc.add_loss('emb2D/emb_2D_ml_loss%s' % suffix, loss, margin,
                               hyp.emb_2D_ml_coeff, summ_writer)

    l2_im = sql2_on_axis(emb0 - emb1, 1, keepdim=True)
    summ_writer.summ_oned('emb2D/emb_2D_l2_loss%s' % suffix, l2_im)
    l2 = reduce_masked_mean(l2_im, valid)
    loss = utils_misc.add_loss('emb2D/emb_2D_l2_loss%s' % suffix, loss, l2,
                               hyp.emb_2D_l2_coeff, summ_writer)

    summ_writer.summ_feats('emb2D/embs_2D%s' % suffix, [emb0, emb1], pca=True)
    return loss, emb1
def forward(self, feat, obj_g=None, bkg_g=None, valid_g=None, summ_writer=None):
    """Predict a motion map from ``feat`` and compute its losses.

    A smoothness penalty on the raw logits is always applied. When ``obj_g``
    is given, ``self.compute_loss`` is added as the probability loss.

    Args:
        feat: input passed straight to ``self.net``.
        obj_g, bkg_g, valid_g: optional targets forwarded to
            ``self.compute_loss``; the probability loss is skipped when
            ``obj_g`` is None.
        summ_writer: optional summary writer; visualizations are skipped
            when None.

    Returns:
        (total_loss, mot_e) where ``mot_e`` is the sigmoid of the logits.
    """
    total_loss = torch.tensor(0.0).cuda()

    mot_e_ = self.net(feat)  # raw logits
    # torch.sigmoid replaces the deprecated F.sigmoid (same result)
    mot_e = torch.sigmoid(mot_e_)

    # smoothness loss on the logits
    dz, dy, dx = gradient3D(mot_e_, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('mot/smooth_loss', total_loss, smooth_loss,
                                     hyp.mot_smooth_coeff, summ_writer)

    if summ_writer is not None:
        summ_writer.summ_occ('mot/mot_e', mot_e)
        # averaged over dim 3 for a 2D view
        summ_writer.summ_oned('mot/smooth_loss', torch.mean(smooth_vox, dim=3))

    if obj_g is not None:
        prob_loss = self.compute_loss(mot_e_, obj_g, bkg_g, valid_g, summ_writer)
        total_loss = utils_misc.add_loss('mot/prob_loss', total_loss, prob_loss,
                                         hyp.mot_prob_coeff, summ_writer)

    return total_loss, mot_e
def forward(self, emb_g, emb_e, valid, summ_writer):
    """Compute margin, masked squared-L2, and smoothness losses for 2D embeddings.

    ``valid`` is passed through ``utils_basic.downsample(valid, 2)`` and
    re-binarized with ``torch.round`` before use. It is therefore presumably
    supplied at twice the embeddings' resolution -- TODO confirm.

    Args:
        emb_g, emb_e: embedding maps; the shape ``B, C, H, W`` is taken
            from ``emb_e``.
        valid: validity mask (see above).
        summ_writer: summary writer; used unconditionally.

    Returns:
        (total_loss, emb_g)
    """
    loss = torch.tensor(0.0).cuda()
    valid = torch.round(utils_basic.downsample(valid, 2))
    B, C, H, W = list(emb_e.shape)

    def _as_rows(x, chans):
        # B x chans x H x W  ->  B x (H*W) x chans
        return x.permute(0, 2, 3, 1).reshape(B, H * W, chans)

    emb_e_vec = _as_rows(emb_e, C)
    emb_g_vec = _as_rows(emb_g, C)
    valid_vec = _as_rows(valid, 1)

    # num_samples are taken from each flattened map
    assert (self.num_samples < (B * H * W))
    margin = self.compute_margin_loss(B, C, H, W, emb_e_vec, emb_g_vec, valid_vec,
                                      'all', True, summ_writer)
    loss = utils_misc.add_loss('emb2D/emb_2D_ml_loss', loss, margin,
                               hyp.emb_2D_ml_coeff, summ_writer)

    l2_im = sql2_on_axis(emb_e - emb_g, 1, keepdim=True)
    summ_writer.summ_oned('emb2D/emb_2D_l2_loss', l2_im)
    l2 = reduce_masked_mean(l2_im, valid)
    loss = utils_misc.add_loss('emb2D/emb_2D_l2_loss', loss, l2,
                               hyp.emb_2D_l2_coeff, summ_writer)

    # NOTE(review): the smoothness term is taken on emb_g, not emb_e; confirm intent
    dy, dx = gradient2D(emb_g, absolute=True)
    smooth_im = torch.sum(dy + dx, dim=1, keepdim=True)
    summ_writer.summ_oned('emb2D/emb_2D_smooth_loss', smooth_im)
    smooth = torch.mean(smooth_im)
    loss = utils_misc.add_loss('emb2D/emb_2D_smooth_loss', loss, smooth,
                               hyp.emb_2D_smooth_coeff, summ_writer)

    summ_writer.summ_feats('emb2D/embs_2D', [emb_e, emb_g], pca=True)
    return loss, emb_g
def forward(self, corrs, occ, iou=None, summ_writer=None):
    """Predict confidence (IoU) scores from the matcher's correlation heatmaps.

    The heatmaps are detached, concatenated with the broadcast mean of
    ``occ``, and passed through ``self.bottle``. The result is split into
    ``self.num_replicas`` groups, and predictor ``self.preds[i]`` scores
    group ``i``.

    Args:
        corrs: correlation heatmaps unpacked as ``B, C, Z, Y, X``.
        occ: occupancy volume; averaged over dims 2-4 and tiled to Z x Y x X.
        iou: optional target of shape (B,); when given, an MSE loss per
            predictor is added.
        summ_writer: optional summary writer forwarded to ``add_loss``.

    Returns:
        (ious, total_loss) where ``ious`` stacks the clamped predictions along
        dim 1 (B x num_predictors).
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(corrs.shape)

    # do not backprop into the corrs
    corrs = corrs.detach()
    occ_mean = torch.mean(occ, dim=[2, 3, 4], keepdim=True).repeat(1, 1, Z, Y, X)
    bottled = self.bottle(torch.cat([corrs, occ_mean], dim=1))
    # group the channels as (bottle_chans, num_replicas, rest); one group per predictor
    groups = bottled.reshape(B, self.bottle_chans, self.num_replicas, -1)

    scores = []
    for idx, head in enumerate(self.preds):
        iou_e = head(groups[:, :, idx].reshape(B, -1)).reshape(B)
        if iou is not None:
            l2_loss = torch.mean(self.mse(iou_e, iou))
            l2_norm = torch.mean(torch.abs(iou_e - iou))
            total_loss = utils_misc.add_loss('conf/l2_loss_%d' % idx, total_loss,
                                             l2_loss, hyp.conf_coeff, summ_writer)
            # coefficient 0 and result discarded: presumably logging only
            utils_misc.add_loss('conf/l2_norm_%d' % idx, 0, l2_norm, 0, summ_writer)
        # capped below 1.0 to distinguish predictions from hardcoded true/false gt scores
        scores.append(iou_e.clamp(0.0, 0.999))

    ious = torch.stack(scores, dim=1)
    return ious, total_loss
def forward(self, feat_mem, clist_cam, summ_writer, suffix=''):
    """Predict a 2D energy map from a feature volume and score point lists under it.

    The Y axis of ``feat_mem`` is folded into the channels to get a
    B x C*Y x Z x X map. Normalized grid coordinates are appended, and
    ``self.net`` produces ``energy_map`` and ``mask``. The loss combines:

    - the negative mean log-likelihood of ``clist_cam`` under the map
      (``utils_misc.get_traj_loglike``);
    - an L1 magnitude penalty;
    - a gradient smoothness penalty.

    Args:
        feat_mem: feature volume unpacked as ``B, C, Z, Y, X``.
        clist_cam: B x S x 3 points, converted via ``utils_vox.Ref2Mem``.
        summ_writer: summary writer; used unconditionally.
        suffix: accepted for interface compatibility; not used here.

    Returns:
        (total_loss, energy_map)
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat_mem.shape)
    B2, S, D = list(clist_cam.shape)
    assert (B == B2)
    assert (D == 3)

    clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)  # still B x S x 3

    feat_ = feat_mem.permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
    # 0 where every channel is exactly zero, 1 elsewhere
    mask_ = 1.0 - (feat_ == 0).all(dim=1, keepdim=True).float().cuda()
    grid_ = utils_basic.meshgrid2D(B, Z, X, stack=True, norm=True).permute(0, 3, 1, 2)
    halfgrid_ = utils_basic.meshgrid2D(B, Z // 2, X // 2, stack=True,
                                       norm=True).permute(0, 3, 1, 2)
    feat_ = torch.cat([feat_, grid_], dim=1)

    # energy_map is B x 1 x Z x X (at the network's output resolution).
    # Intentionally NOT filling the masked-out region with the map's minimum energy.
    energy_map, mask = self.net(feat_, mask_, halfgrid_)

    summ_writer.summ_feat('pri/energy_input', feat_)
    summ_writer.summ_oned('pri/energy_map', energy_map)
    summ_writer.summ_oned('pri/mask', mask, norm=False)
    summ_writer.summ_histogram('pri/energy_map_hist', energy_map)

    # coords scaled by 0.5 since the energy map is half res; result is B x K
    loglike_per_traj = utils_misc.get_traj_loglike(clist_mem * 0.5, energy_map)
    ce_loss = -1.0 * torch.mean(loglike_per_traj)
    total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss,
                                     hyp.pri2D_ce_coeff, summ_writer)

    reg_loss = torch.sum(torch.abs(energy_map))
    total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss,
                                     hyp.pri2D_reg_coeff, summ_writer)

    dz, dx = utils_basic.gradient2D(energy_map, absolute=True)
    smooth_vox = torch.mean(dz + dx, dim=1, keepdim=True)
    summ_writer.summ_oned('pri/smooth_loss', smooth_vox)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss, smooth_loss,
                                     hyp.pri2D_smooth_coeff, summ_writer)
    return total_loss, energy_map
def forward(self, sensor_imgs, sampled_embeddings, do_ml, summ_writer):
    """Embed touch-sensor images and compare them with sampled target embeddings.

    An MSE loss between the L2-normalized sensor embeddings and
    ``sampled_embeddings`` is always added. When ``do_ml`` is true, a margin
    loss from ``self.compute_margin_loss`` is added as well.

    Args:
        sensor_imgs: input to ``self.touch_net``.
        sampled_embeddings: target embeddings. If the sensor embeddings are
            2D (B x C), both are viewed as B x C x 1 x 1.
        do_ml: whether to add the margin loss.
        summ_writer: summary writer; used unconditionally.

    Returns:
        (total_loss, sensor_embeddings). For 2D inputs the returned
        embeddings are the B x C x 1 x 1 view.
    """
    total_loss = torch.tensor(0.0).cuda()
    sensor_feats = self.touch_net(sensor_imgs)
    sensor_embeddings = l2_normalize(sensor_feats, dim=1)

    simple_l2_loss = F.mse_loss(sensor_embeddings, sampled_embeddings)
    total_loss = utils_misc.add_loss('embtouch/emb_touch_l2_loss', total_loss,
                                     simple_l2_loss, hyp.emb_2D_touch_l2_coeff,
                                     summ_writer)

    if sensor_embeddings.dim() == 2:
        prev_B, prev_C = list(sensor_embeddings.shape)
        # compare batch sizes directly instead of materializing the rows as lists
        assert len(sampled_embeddings) == len(sensor_embeddings)
        sampled_embeddings = sampled_embeddings.view(prev_B, prev_C, 1, 1)
        sensor_embeddings = sensor_embeddings.view(prev_B, prev_C, 1, 1)

    if do_ml:
        # (a leftover ipdb.set_trace() breakpoint used to halt here)
        B, C, H, W = list(sensor_embeddings.shape)
        # Flatten to B x H*W x C for compute_margin_loss. Use reshape, not view:
        # the permuted tensor is non-contiguous whenever H*W > 1.
        emb_e_vec = sensor_embeddings.permute(0, 2, 3, 1).reshape(B, H * W, C)
        emb_g_vec = sampled_embeddings.permute(0, 2, 3, 1).reshape(B, H * W, C)
        assert B == hyp.B, "batch should be same"
        assert C == hyp.touch_emb_dim, "this is the network output I specified"
        assert self.num_samples < (B * H * W), "num samples in hyp is problem"
        # NOTE(review): no valid_vec argument is passed here, unlike other
        # compute_margin_loss call sites; confirm this class's signature.
        margin_loss = self.compute_margin_loss(B, C, H, W, emb_e_vec, emb_g_vec,
                                               'all', True, summ_writer)
        total_loss = utils_misc.add_loss('embtouch/emb_touch_ml_loss', total_loss,
                                         margin_loss, hyp.emb_2D_touch_ml_coeff,
                                         summ_writer)
        print('L2Loss: {}\t MarginLoss: {}\t total_loss: {}'.format(
            simple_l2_loss.item(), margin_loss.item(), total_loss.item()))

    summ_writer.summ_feats('embtouch/embs_touch',
                           [sensor_embeddings, sampled_embeddings], pca=True)
    return total_loss, sensor_embeddings
def forward(self, feat0, feat1, valid0, valid1, summ_writer=None):
    """Train a rejection head to separate corresponding from shuffled feature pairs.

    Two inputs are built for ``self.net`` and trained with ``self.criterion``:

    - Negative input (corresponding voxels, should NOT be rejected, target 0):
      ``feat1 - feat0``.
    - Positive input (non-corresponding, should be rejected, target 1): the
      same difference after randomly permuting ``feat1``'s voxels.

    Args:
        feat0, feat1: feature volumes unpacked as ``B, C, D, H, W``.
        valid0, valid1: validity masks, reshaped to B x 1 x D x H x W /
            B x D*H*W below.
        summ_writer: optional summary writer.

    Returns:
        (total_loss, neg_sig, pos_sig): the loss and the sigmoid rejection
        scores for the corresponding and shuffled inputs.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, D, H, W = list(feat0.shape)
    N = D * H * W

    # corresps should NOT be rejected
    neg_input = feat1 - feat0

    # permute feat1 (and its validity) over voxels to build non-corresponding pairs
    perm = np.random.permutation(N)
    feat1_shuf = feat1.reshape(B, C, -1)[:, :, perm].reshape(B, C, D, H, W)
    valid1_shuf = valid1.reshape(B, 1, -1)[:, :, perm].reshape(B, 1, D, H, W)
    # noncorresps should be rejected
    pos_input = feat1_shuf - feat0

    pos_output = self.net(pos_input)
    neg_output = self.net(neg_input)
    pos_sig = torch.sigmoid(pos_output)
    neg_sig = torch.sigmoid(neg_output)

    if summ_writer is not None:
        summ_writer.summ_feat('reject/pos_input', pos_input, pca=True)
        summ_writer.summ_feat('reject/neg_input', neg_input, pca=True)
        summ_writer.summ_oned('reject/pos_sig', pos_sig, bev=True, norm=False)
        summ_writer.summ_oned('reject/neg_sig', neg_sig, bev=True, norm=False)

    pos_output_vec = pos_output.reshape(B, N)
    neg_output_vec = neg_output.reshape(B, N)
    pos_target_vec = torch.ones([B, N]).float().cuda()
    neg_target_vec = torch.zeros([B, N]).float().cuda()

    # if feat1_shuf is valid, it is practically guaranteed to mismatch feat0
    pos_valid_vec = valid1_shuf.reshape(B, N)
    # both sides have to be valid for a true correspondence
    neg_valid_vec = (valid0 * valid1).reshape(B, N)

    pos_loss_vec = self.criterion(pos_output_vec, pos_target_vec)
    neg_loss_vec = self.criterion(neg_output_vec, neg_target_vec)
    pos_loss = utils_basic.reduce_masked_mean(pos_loss_vec, pos_valid_vec)
    # Mask the negative term with its own validity. It was previously masked
    # with pos_valid_vec, which left neg_valid_vec unused.
    neg_loss = utils_basic.reduce_masked_mean(neg_loss_vec, neg_valid_vec)
    ce_loss = pos_loss + neg_loss

    utils_misc.add_loss('reject3D/ce_loss_pos', 0, pos_loss, 0, summ_writer)
    utils_misc.add_loss('reject3D/ce_loss_neg', 0, neg_loss, 0, summ_writer)
    total_loss = utils_misc.add_loss('reject3D/ce_loss', total_loss, ce_loss,
                                     hyp.reject3D_ce_coeff, summ_writer)
    return total_loss, neg_sig, pos_sig
def forward(self, feat, obj_g=None, bkg_g=None, valid_g=None, summ_writer=None):
    """Predict a subtraction/object map from ``feat`` and compute its losses.

    A smoothness penalty on the raw logits is always applied. When ``obj_g``
    is given, ``self.compute_loss`` is added as the probability loss.

    Args:
        feat: input passed straight to ``self.net``.
        obj_g, bkg_g, valid_g: optional targets forwarded to
            ``self.compute_loss``; skipped when ``obj_g`` is None.
        summ_writer: optional summary writer forwarded to the loss helpers.

    Returns:
        (total_loss, sub_e) where ``sub_e`` is the sigmoid of the logits.
    """
    total_loss = torch.tensor(0.0).cuda()
    sub_e_ = self.net(feat)  # raw logits
    # torch.sigmoid replaces the deprecated F.sigmoid (same result)
    sub_e = torch.sigmoid(sub_e_)

    # smoothness loss on the logits
    dz, dy, dx = utils_basic.gradient3D(sub_e_, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('sub/smooth_loss', total_loss, smooth_loss,
                                     hyp.sub_smooth_coeff, summ_writer)

    if obj_g is not None:
        prob_loss = self.compute_loss(sub_e_, obj_g, bkg_g, valid_g, summ_writer)
        total_loss = utils_misc.add_loss('sub/prob_loss', total_loss, prob_loss,
                                         hyp.sub_coeff, summ_writer)

    return total_loss, sub_e
def compute_samp_loss(self, obj_lrt_e, obj_lrt_g, summ_writer=None):
    """MSE between the sampling coordinates induced by the estimated and gt boxes.

    Both ``obj_lrt_e`` and ``obj_lrt_g`` go through
    ``utils_vox.convert_lrt_to_sampling_coords`` with the same scene and
    object resolutions. The results are scaled by ``1 / sce_Z`` and compared
    with ``self.mse``.

    Returns:
        total_loss (only the weighted sampling loss).
    """
    total_loss = torch.tensor(0.0).cuda()
    scene_res = (self.sce_Z, self.sce_Y, self.sce_X)
    obj_res = (self.obj_Z, self.obj_Y, self.obj_X)
    coords_e, coords_g = [
        utils_vox.convert_lrt_to_sampling_coords(lrt, *scene_res, *obj_res,
                                                 additive_pad=0.0)
        for lrt in (obj_lrt_e, obj_lrt_g)
    ]

    # normalize by resolution
    # NOTE(review): every coordinate is divided by sce_Z, including any along
    # Y/X; confirm this is intended when sce_Z differs from sce_Y / sce_X.
    scale = float(self.sce_Z)
    coords_e = coords_e / scale
    coords_g = coords_g / scale

    samp_loss = self.mse(coords_e, coords_g)
    total_loss = utils_misc.add_loss('loc/samp_loss', total_loss, samp_loss,
                                     hyp.loc_samp_coeff, summ_writer)
    if summ_writer is not None:
        summ_writer.summ_histogram('coords_e', coords_e)
        summ_writer.summ_histogram('coords_g', coords_g)
    return total_loss
def forward(self, feat, occ, rgb_g, valid, summ_writer):
    """Render RGB and embeddings from a feature volume and compute an L1 RGB loss.

    ``self.rgb_layer(feat)`` and ``feat`` are each rendered with
    ``self.accu_render`` using ``occ``. The embedding is L2-normalized along
    dim 1 and the RGB is squashed to [-0.5, 0.5] with ``tanh * 0.5``.

    Args:
        feat: feature volume.
        occ: occupancy passed to ``self.accu_render``.
        rgb_g: target image.
        valid: mask for the masked-mean L1 loss.
        summ_writer: summary writer; used unconditionally.

    Returns:
        (total_loss, rgb_e, emb_e)
    """
    total_loss = torch.tensor(0.0).cuda()
    rgb_feat = self.rgb_layer(feat)
    rgb_e = self.accu_render(rgb_feat, occ)
    emb_e = self.accu_render(feat, occ)

    # postproc; torch.tanh replaces the deprecated F.tanh (same result)
    emb_e = l2_normalize(emb_e, dim=1)
    rgb_e = torch.tanh(rgb_e) * 0.5

    loss_im = l1_on_axis(rgb_e - rgb_g, 1, keepdim=True)
    summ_writer.summ_oned('render/rgb_loss', loss_im)
    # NOTE(review): all three projections share the tag 'render/occ'; confirm
    # that summ_occs keeps them apart, otherwise give each its own tag.
    for axis in (2, 3, 4):
        summ_writer.summ_occs('render/occ', occ.unsqueeze(0), reduce_axes=[axis])

    rgb_loss = utils_basic.reduce_masked_mean(loss_im, valid)
    total_loss = utils_misc.add_loss('render/rgb_l1_loss', total_loss, rgb_loss,
                                     hyp.render_l1_coeff, summ_writer)

    # vis
    summ_writer.summ_rgbs('render/rgb', [rgb_e, rgb_g])
    return total_loss, rgb_e, emb_e
def forward(self, image_input, summ_writer=None):
    """Predict each entry of a discrete 2D code map with cross-entropy.

    Args:
        image_input: integer code map unpacked as ``B, H, W``. It is both
            embedded (``self.embed``) and used as the CE target.
        summ_writer: optional summary writer. Visualizations are skipped
            when None; previously a None writer raised AttributeError.

    Returns:
        total_loss
    """
    B, H, W = list(image_input.shape)  # unpacking also enforces a rank-3 input
    total_loss = torch.tensor(0.0).cuda()
    # NOTE(review): codes are divided by 512.0 for display; presumably the codebook size
    if summ_writer is not None:
        summ_writer.summ_oned('gen2dvq/image_input',
                              image_input.unsqueeze(1) / 512.0, norm=False)

    emb = self.embed(image_input)
    image_output_logits = self.net(emb)
    ce_loss_image = F.cross_entropy(image_output_logits, image_input, reduction='none')

    if summ_writer is not None:
        # the argmax prediction is only needed for visualization
        image_output = torch.argmax(image_output_logits, dim=1, keepdim=True)
        summ_writer.summ_oned('gen2dvq/image_output', image_output.float() / 512.0,
                              norm=False)
        summ_writer.summ_oned('gen2dvq/ce_loss', ce_loss_image.unsqueeze(1))

    ce_loss = torch.mean(ce_loss_image)
    total_loss = utils_misc.add_loss('gen2dvq/ce_loss', total_loss, ce_loss,
                                     hyp.gen2dvq_coeff, summ_writer)
    return total_loss
def forward(self, feat, summ_writer=None):
    """Vector-quantize a 3D feature volume and track codebook usage.

    Args:
        feat: feature volume unpacked as ``B, C, Z, Y, X``.
        summ_writer: optional summary writer; scalar and feature summaries
            are skipped when None.

    Returns:
        (total_loss, quantized_vox, ind_vox) where ``ind_vox`` is the code
        index map reshaped to B x Z x Y x X.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat.shape)
    latent_loss, quantized_vox, perplexity, _, inds = self._vq_vae(feat)
    total_loss = utils_misc.add_loss('vq3dnet/latent', total_loss, latent_loss,
                                     hyp.vq3d_latent_coeff, summ_writer)

    # Record this batch's distinct code indices in the pool, then count the
    # distinct indices over whatever the pool currently holds.
    self.ind_pool.update(np.unique(inds.detach().cpu().numpy()))
    num_used = len(np.unique(self.ind_pool.fetch()))

    ind_vox = inds.reshape(B, Z, Y, X)
    if summ_writer is not None:
        summ_writer.summ_scalar('unscaled_vq3dnet/perplexity', perplexity.cpu().item())
        summ_writer.summ_scalar('unscaled_vq3dnet/num_used_inds', float(num_used))
        summ_writer.summ_feat('vq3dnet/quantized', quantized_vox, pca=True)
    return total_loss, quantized_vox, ind_vox
def forward(self, vox_input, summ_writer=None):
    """Predict each entry of a discrete 3D code volume with cross-entropy.

    Args:
        vox_input: integer code volume unpacked as ``B, Z, Y, X``. It is both
            embedded (``self.embed``) and used as the CE target.
        summ_writer: optional summary writer. Visualizations are skipped
            when None; previously a None writer raised AttributeError.

    Returns:
        (total_loss, vox_output) where ``vox_output`` is the per-voxel argmax.
    """
    B, Z, Y, X = list(vox_input.shape)  # unpacking also enforces a rank-4 input
    total_loss = torch.tensor(0.0).cuda()
    # NOTE(review): codes are divided by 512.0 for display; presumably the codebook size
    if summ_writer is not None:
        summ_writer.summ_oned('gen3d/vox_input', vox_input.unsqueeze(1) / 512.0,
                              bev=True, norm=False)

    emb = self.embed(vox_input)
    vox_output_logits = self.net(emb)
    vox_output = torch.argmax(vox_output_logits, dim=1)
    ce_loss_vox = F.cross_entropy(vox_output_logits, vox_input, reduction='none')

    if summ_writer is not None:
        summ_writer.summ_oned('gen3d/vox_output',
                              vox_output.unsqueeze(1).float() / 512.0,
                              bev=True, norm=False)
        summ_writer.summ_oned('gen3d/ce_loss', ce_loss_vox.unsqueeze(1), bev=True)

    ce_loss = torch.mean(ce_loss_vox)
    total_loss = utils_misc.add_loss('gen3d/ce_loss', total_loss, ce_loss,
                                     hyp.gen3d_coeff, summ_writer)
    return total_loss, vox_output
def forward(self, rgb, summ_writer=None):
    """Encode an RGB image into 2D features with a gradient-smoothness penalty.

    Args:
        rgb: image unpacked as ``B, C, H, W``, passed to ``self.net``.
        summ_writer: optional summary writer; summaries are skipped when None.

    Returns:
        (total_loss, feat) where ``feat`` is L2-normalized along dim 1.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, H, W = list(rgb.shape)
    has_writer = summ_writer is not None

    if has_writer:
        summ_writer.summ_rgb('feat2D/rgb', rgb)
    feat = self.net(rgb)

    # smoothness is measured on the features before normalization
    dy, dx = utils_basic.gradient2D(feat, absolute=True)
    smooth_im = torch.mean(dy + dx, dim=1, keepdim=True)
    if has_writer:
        summ_writer.summ_oned('feat2D/smooth_loss', smooth_im)
    smooth_loss = torch.mean(smooth_im)
    total_loss = utils_misc.add_loss('feat2D/smooth_loss', total_loss, smooth_loss,
                                     hyp.feat2D_smooth_coeff, summ_writer)

    feat = utils_basic.l2_normalize(feat, dim=1)
    if has_writer:
        summ_writer.summ_feat('feat2D/feat_output', feat, pca=True)
    return total_loss, feat
def forward(self, image_input, summ_writer=None):
    """Train a 256-way per-pixel classifier on channel 0 of a grayscale image.

    The target is ``((image_input[:, 0] + 0.5) * 255).long()``. This assumes
    intensities in [-0.5, 0.5] (the visualization applies the inverse map)
    -- TODO confirm.

    Args:
        image_input: image unpacked as ``B, C, H, W``, fed to ``self.net``.
        summ_writer: optional summary writer. Visualizations are skipped
            when None; previously a None writer raised AttributeError.

    Returns:
        total_loss
    """
    B, C, H, W = list(image_input.shape)  # unpacking also enforces a rank-4 input
    total_loss = torch.tensor(0.0).cuda()
    if summ_writer is not None:
        summ_writer.summ_oned('gengray/image_input', image_input)

    # Integer class targets from channel 0. detach() (instead of the legacy
    # .data) keeps the target out of the autograd graph.
    target = ((image_input.detach()[:, 0] + 0.5) * 255).long().cuda()

    image_output_logits = self.net(image_input)
    ce_loss_image = F.cross_entropy(image_output_logits, target, reduction='none')

    if summ_writer is not None:
        image_output = torch.argmax(image_output_logits, dim=1, keepdim=True)
        summ_writer.summ_oned('gengray/image_output', image_output.float() / 255.0 - 0.5)
        summ_writer.summ_oned('gengray/ce_loss', ce_loss_image.unsqueeze(1))

    ce_loss = torch.mean(ce_loss_image)
    total_loss = utils_misc.add_loss('gengray/ce_loss', total_loss, ce_loss,
                                     hyp.gengray_coeff, summ_writer)
    return total_loss
def forward(self, feat, occ_g, free_g, summ_writer=None):
    """Discrete occupancy generator trained with a balanced cross-entropy.

    ``feat`` is cast to long and run through ``self.run_net``. The logits get
    a 3D smoothness penalty and two per-voxel ``self.criterion`` terms:

    - one against ``occ_g[:, 0]``, averaged over ``occ_g``;
    - one against ``1 - free_g[:, 0]``, averaged over ``free_g``.

    Args:
        feat: discrete input volume unpacked as ``B, C, Z, Y, X``.
        occ_g, free_g: target masks (channel 0 is used as the class target).
        summ_writer: optional summary writer. Visualizations are skipped
            when None; previously a None writer raised AttributeError.

    Returns:
        (logit, total_loss) -- note the logits come first.
    """
    B, C, Z, Y, X = list(feat.shape)
    total_loss = torch.tensor(0.0).cuda()
    if summ_writer is not None:
        summ_writer.summ_feat('genoccnet/feat_input', feat, pca=False)

    feat = feat.long()  # discrete input
    logit = self.run_net(feat)

    # Smoothness over all three gradients. The old code summed dx twice and
    # dropped dz (dx+dy+dx).
    dz, dy, dx = gradient3D(logit, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('genoccnet/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('genoccnet/smooth_loss', total_loss, smooth_loss,
                                     hyp.genocc_smooth_coeff, summ_writer)

    loss_pos = self.criterion(logit, (occ_g[:, 0]).long())
    loss_neg = self.criterion(logit, (1 - free_g[:, 0]).long())

    if summ_writer is not None:
        summ_writer.summ_feat('genoccnet/feat_output', logit, pca=False)
        occ_e = torch.argmax(logit, dim=1, keepdim=True)
        summ_writer.summ_occ('genoccnet/occ_e', occ_e)
        summ_writer.summ_occ('genoccnet/occ_g', occ_g)
        summ_writer.summ_oned('genoccnet/loss',
                              torch.mean((loss_pos + loss_neg), dim=1, keepdim=True) *
                              torch.clamp(occ_g + (1 - free_g), 0, 1),
                              bev=True)

    loss_pos = utils_basic.reduce_masked_mean(loss_pos.unsqueeze(1), occ_g)
    loss_neg = utils_basic.reduce_masked_mean(loss_neg.unsqueeze(1), free_g)
    loss_bal = loss_pos + loss_neg
    total_loss = utils_misc.add_loss('genoccnet/loss_bal', total_loss, loss_bal,
                                     hyp.genocc_coeff, summ_writer)
    return logit, total_loss
def forward(self, feat, seg_g=None, occ_g=None, free_g=None, summ_writer=None):
    """3D segmentation head with a smoothness penalty and optional supervision.

    Args:
        feat: input passed to ``self.conv3d``.
        seg_g: optional target; when given, ``self.compute_loss`` is added.
        occ_g: accepted for interface compatibility; not used here.
        free_g: required whenever ``seg_g`` is given (``.long().squeeze(1)``
            is called on it).
        summ_writer: optional summary writer.

    Returns:
        (total_loss, seg_e) with ``seg_e`` the raw ``self.conv3d`` output.
    """
    total_loss = torch.tensor(0.0).cuda()
    seg_e = self.conv3d(feat)
    has_writer = summ_writer is not None

    dz, dy, dx = gradient3D(seg_e, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('seg/smooth_loss', total_loss, smooth_loss,
                                     hyp.seg_smooth_coeff, summ_writer)

    if has_writer:
        summ_writer.summ_oned('seg/smooth_loss', torch.mean(smooth_vox, dim=3))
        # per-voxel argmax over dim 1, then max over dim 2 for a 2D view
        argmax_e = torch.max(seg_e, dim=1)[1]
        summ_writer.summ_seg('seg/seg_e', torch.max(argmax_e, dim=2)[0])

    if seg_g is None:
        return total_loss, seg_e

    prob_loss = self.compute_loss(seg_e, seg_g, free_g.long().squeeze(1), summ_writer)
    total_loss = utils_misc.add_loss('seg/prob_loss', total_loss, prob_loss,
                                     hyp.seg_prob_coeff, summ_writer)
    seg_g_vis = torch.max(seg_g, dim=2)[0]
    if has_writer:
        summ_writer.summ_seg('seg/seg_g', seg_g_vis)
    return total_loss, seg_e
def forward(self, feat, summ_writer=None):
    """Run the 3D upsampling net and add a gradient-smoothness penalty.

    Args:
        feat: input passed to ``self.net``.
        summ_writer: optional summary writer. All summaries are now skipped
            when None; the smoothness summary used to be unconditional and
            raised AttributeError for the default ``None``.

    Returns:
        (total_loss, feat) with ``feat`` the raw ``self.net`` output.
    """
    total_loss = torch.tensor(0.0).cuda()
    feat = self.net(feat)

    dz, dy, dx = utils_basic.gradient3D(feat, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('up3D/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('up3D/smooth_loss', total_loss, smooth_loss,
                                     hyp.up3D_smooth_coeff, summ_writer)

    if summ_writer is not None:
        summ_writer.summ_feat('up3D/feat_output', feat, pca=True)
    return total_loss, feat
def forward(self, feat, summ_writer=None):
    """3D feature encoder: run ``self.net``, add a smoothness loss, L2-normalize.

    Args:
        feat: input volume unpacked as ``B, C, Z, Y, X``.
        summ_writer: optional summary writer.

    Returns:
        (total_loss, feat, mask). ``mask`` is all ones over the output's
        spatial grid (one channel).
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat.shape)
    if summ_writer is not None:
        summ_writer.summ_feat('feat3D/feat_input', feat, pca=(C > 3))

    feat = self.net(feat)
    # An input-derived mask (feat[:, 0:1] > 0) used to be computed above and
    # then overwritten here unused; only this ones-mask was ever returned.
    mask = torch.ones_like(feat[:, 0:1])

    # smoothness loss on the pre-normalization features
    dz, dy, dx = utils_basic.gradient3D(feat, absolute=True)
    smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('feat3D/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('feat3D/smooth_loss', total_loss, smooth_loss,
                                     hyp.feat3D_smooth_coeff, summ_writer)

    feat = utils_basic.l2_normalize(feat, dim=1)
    if hyp.feat3D_sparse:
        # NOTE(review): mask is all ones, so this multiply is currently a no-op
        feat = feat * mask
    if summ_writer is not None:
        summ_writer.summ_feat('feat3D/feat_output', feat, pca=True)
    return total_loss, feat, mask
def forward(self, emb0, emb1, summ_writer, mod=''):
    """Contrastive loss between two 2D embedding maps via ``self.moc_trainer``.

    Embeddings are sampled with ``self.sample_embs``. ``emb1``'s samples are
    detached for the loss and then passed to ``self.moc_trainer.enqueue``.

    Args:
        emb0, emb1: embedding maps unpacked as ``B, C, H, W``; must be NaN-free.
        summ_writer: summary writer; used unconditionally.
        mod: appended to the loss/summary tags.

    Returns:
        total_loss
    """
    total_loss = torch.tensor(0.0).cuda()
    assert not (torch.isnan(emb0).any() or torch.isnan(emb1).any())

    B, C, H, W = list(emb0.shape)
    # samples are drawn across the whole batch
    assert (self.num_samples < (B * H * W))
    emb0_vec, emb1_vec = self.sample_embs(emb0, emb1)

    moc_loss = self.moc_trainer.forward(emb0_vec, emb1_vec.detach())
    self.moc_trainer.enqueue(emb1_vec)
    total_loss = utils_misc.add_loss('moc2D/moc2D_loss%s' % mod, total_loss, moc_loss,
                                     hyp.moc2D_coeff, summ_writer)

    summ_writer.summ_feats('moc2D/embs%s' % mod, [emb0, emb1], pca=True)
    return total_loss
def forward(self, codes, obj_inds, bkg_inds, summ_writer=None):
    """Binary object-vs-background classifier on selected codes.

    ``codes[obj_inds]`` are trained toward target 1 and ``codes[bkg_inds]``
    toward target 0 with ``self.crit``; the two losses are summed.

    Args:
        codes: code tensor indexed by ``obj_inds`` / ``bkg_inds``.
        obj_inds, bkg_inds: index selections for the two classes.
        summ_writer: optional summary writer. Accuracy summaries are skipped
            when None; previously a None writer raised AttributeError.

    Returns:
        total_loss
    """
    total_loss = torch.tensor(0.0).cuda()

    obj_logits = self.net(codes[obj_inds])
    obj_loss = self.crit(obj_logits, torch.ones_like(obj_logits))
    bkg_logits = self.net(codes[bkg_inds])
    bkg_loss = self.crit(bkg_logits, torch.zeros_like(bkg_logits))

    bal_loss = obj_loss + bkg_loss
    total_loss = utils_misc.add_loss('linclass/bce_loss', total_loss, bal_loss,
                                     hyp.linclass_coeff, summ_writer)

    if summ_writer is not None:
        # accuracy stats from the rounded sigmoid (torch.sigmoid replaces deprecated F.sigmoid)
        obj_bin = torch.round(torch.sigmoid(obj_logits))
        bkg_bin = torch.round(torch.sigmoid(bkg_logits))
        obj_acc = torch.mean(torch.eq(obj_bin, 1).float())
        bkg_acc = torch.mean(torch.eq(bkg_bin, 0).float())
        bal_acc = (obj_acc + bkg_acc) * 0.5
        summ_writer.summ_scalar('unscaled_linclass/acc_obj', obj_acc.cpu().item())
        summ_writer.summ_scalar('unscaled_linclass/acc_bkg', bkg_acc.cpu().item())
        summ_writer.summ_scalar('unscaled_linclass/acc_bal', bal_acc.cpu().item())
    return total_loss
def forward(self, emb0, emb1, valid0, valid1, summ_writer):
    """Contrastive loss between two 3D embedding volumes via ``self.moc_trainer``.

    The product ``valid0 * valid1`` is passed to ``self.sample_embs``.
    ``emb1``'s samples are detached for the loss and then passed to
    ``self.moc_trainer.enqueue``.

    Args:
        emb0, emb1: embedding volumes unpacked as ``B, C, Z, Y, X``; must be NaN-free.
        valid0, valid1: validity masks for the two volumes.
        summ_writer: summary writer; used unconditionally.

    Returns:
        total_loss
    """
    total_loss = torch.tensor(0.0).cuda()
    assert not (torch.isnan(emb0).any() or torch.isnan(emb1).any())

    B, C, Z, Y, X = list(emb0.shape)
    # samples are drawn across the whole batch
    assert (self.num_samples < (B * Z * Y * X))
    joint_valid = valid0 * valid1
    emb0_vec, emb1_vec, _ = self.sample_embs(emb0, emb1, joint_valid)

    moc_loss = self.moc_trainer.forward(emb0_vec, emb1_vec.detach())
    self.moc_trainer.enqueue(emb1_vec)
    total_loss = utils_misc.add_loss('moc3D/moc3D_loss', total_loss, moc_loss,
                                     hyp.moc3D_coeff, summ_writer)

    summ_writer.summ_feats('moc3D/embs', [emb0, emb1], valids=[valid0, valid1], pca=True)
    return total_loss
def forward(self, feat, rgb_g, valid, summ_writer, name):
    """Predict an RGB view and an embedding map from features; masked L1 RGB loss.

    ``self.net(feat)`` feeds both ``self.emb_layer`` (L2-normalized along
    dim 1) and ``self.rgb_layer`` (squashed to [-0.5, 0.5] via
    ``tanh * 0.5``).

    Args:
        feat: input features.
        rgb_g: target image.
        valid: mask for the L1 loss; replaced with all ones when
            ``hyp.dataset_name == "clevr"``.
        summ_writer: summary writer; used unconditionally.
        name: suffix of the ``view/<name>`` RGB summary tag.

    Returns:
        (total_loss, rgb_e, emb_e)
    """
    total_loss = torch.tensor(0.0).cuda()
    if hyp.dataset_name == "clevr":
        valid = torch.ones_like(valid)

    feat = self.net(feat)
    emb_e = self.emb_layer(feat)
    rgb_e = self.rgb_layer(feat)

    # postproc; torch.tanh replaces the deprecated F.tanh (same result)
    emb_e = l2_normalize(emb_e, dim=1)
    rgb_e = torch.tanh(rgb_e) * 0.5

    loss_im = l1_on_axis(rgb_e - rgb_g, 1, keepdim=True)
    summ_writer.summ_oned('view/rgb_loss', loss_im * valid)
    rgb_loss = utils_basic.reduce_masked_mean(loss_im, valid)
    total_loss = utils_misc.add_loss('view/rgb_l1_loss', total_loss, rgb_loss,
                                     hyp.view_l1_coeff, summ_writer)

    # vis
    summ_writer.summ_rgbs(f'view/{name}', [rgb_e, rgb_g])
    return total_loss, rgb_e, emb_e
def compute_feat_loss(self, obj_feat, sce_feat, obj_lrt_e, summ_writer=None, suffix=''):
    """Reward agreement between the object features and the scene crop at ``obj_lrt_e``.

    The scene features are cropped and resized to the object resolution with
    ``utils_vox.crop_zoom_from_mem``. The loss is the negative mean of the
    elementwise product with ``obj_feat``: a Frobenius inner product,
    averaged so it is normalized for resolution.

    Returns:
        total_loss (only the weighted feature loss).
    """
    total_loss = torch.tensor(0.0).cuda()
    obj_crop = utils_vox.crop_zoom_from_mem(sce_feat, obj_lrt_e,
                                            self.obj_Z, self.obj_Y, self.obj_X,
                                            additive_pad=0.0)
    # whatever we match to, the feature distance should be small
    feat_loss = torch.mean(obj_feat * obj_crop).neg()
    total_loss = utils_misc.add_loss('loc/feat_loss%s' % suffix, total_loss, feat_loss,
                                     hyp.loc_feat_coeff, summ_writer)
    return total_loss
def forward(self, rgb_g, summ_writer=None):
    """Encode, vector-quantize, and decode an RGB image; track codebook usage.

    Args:
        rgb_g: image unpacked as ``B, C, H, W``.
        summ_writer: optional summary writer; image summaries are skipped
            when None.

    Returns:
        (total_loss, rgb_e, ind_image) where ``ind_image`` is the code index
        map reshaped to B x H/8 x W/8.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, H, W = list(rgb_g.shape)

    z = self._pre_vq_conv(self._encoder(rgb_g))
    latent_loss, quantized, perplexity, _, inds = self._vq_vae(z)
    rgb_e = self._decoder(quantized)

    recon_loss = F.mse_loss(rgb_g, rgb_e)
    total_loss = utils_misc.add_loss('vqrgbnet/recon', total_loss, recon_loss,
                                     hyp.vqrgb_recon_coeff, summ_writer)
    # zero coefficient and result discarded: presumably logging only
    utils_misc.add_loss('vqrgbnet/perplexity', 0.0, perplexity, 0.0, summ_writer)

    if summ_writer is not None:
        summ_writer.summ_rgb('vqrgbnet/rgb_e', rgb_e.clamp(-0.5, 0.5))
        summ_writer.summ_rgb('vqrgbnet/rgb_g', rgb_g)

    total_loss = utils_misc.add_loss('vqrgbnet/latent', total_loss, latent_loss,
                                     hyp.vqrgb_latent_coeff, summ_writer)

    # record this batch's distinct code indices, then count distinct entries in the pool
    self.ind_pool.update(np.unique(inds.detach().cpu().numpy()))
    num_used = len(np.unique(self.ind_pool.fetch()))
    utils_misc.add_loss('vqrgbnet/num_used_inds', 0.0, num_used, 0.0, summ_writer)

    # this reshape presumes the encoder downsamples H and W by 8
    ind_image = inds.reshape(B, H // 8, W // 8)
    return total_loss, rgb_e, ind_image
def forward(self, template, template_mask, search_region, xyz, r_delta, sampled_corners,
            sampled_centers, vox_util=None, summ_writer=None):
    """Regress the rigid motion (translation + rotation) of a template inside a search region.

    Args:
        template: the thing we are searching for; B x C x ZZ x ZY x ZX.
        template_mask: marks the object voxels within the template
            (B x C x ZZ x ZY x ZX per the original notes); only visualized here.
        search_region: the featuremap where we are searching; B x C x Z x Y x X.
        xyz: location of the answer in the search region; B x 3.
        r_delta: ground-truth rotation in radians, converted to degrees and compared
            with the estimated (rx, ry, rz); presumably B x 3 -- TODO confirm.
        sampled_corners: (cube mode only) corner voxel coords in x, y, z order,
            one row per (batch, repeat) pair indexed b*R + r; (B*R) x 8 x 3.
        sampled_centers: (cube mode only) center of each sampled cube, reshaped to
            B x R x 3, same b*R + r ordering.
        vox_util: (cube mode only) used to voxelize the sampled boxes for visualization.
        summ_writer: optional summary writer.

    Returns:
        (xyz_e, rad_e, total_loss): estimated translation (B x 3), estimated euler
        angles in radians (B x 3, rx/ry/rz), and the summed rotation+translation loss.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, ZZ, ZY, ZX = list(template.shape)
    _, _, Z, Y, X = list(search_region.shape)
    _, D = list(xyz.shape)
    assert D == 3

    if hyp.rigid_use_cubes:
        R = hyp.rigid_repeats
        # Repeat each batch item R times consecutively (index b*R + r). This matches the
        # b*R + r layout that sampled_centers.reshape(B, R, 3) and rigid.reshape(B, R, 9)
        # assume. A plain .repeat(R, ...) tiles the whole batch (index r*B + b) and
        # would average guesses from different batch items whenever B > 1.
        template = template.repeat_interleave(R, dim=0)
        template_mask = template_mask.repeat_interleave(R, dim=0)
        search_region = search_region.repeat_interleave(R, dim=0)

        if summ_writer is not None:
            box_vox = vox_util.voxelize_xyz(
                torch.cat([sampled_corners, sampled_centers], dim=1),
                ZZ, ZY, ZX, already_mem=True)
            summ_writer.summ_occ('rigid/box', box_vox, reduce_axes=[2, 3, 4])
            summ_writer.summ_oned('rigid/mask', template_mask, bev=True, norm=False)

        # Sample template features at the (up to 8) sampled corner voxels.
        template_vec = torch.zeros([B * R, C, 8]).float().cuda()
        for b in range(B * R):
            corners_b = sampled_corners[b].long()
            for ci, corner in enumerate(corners_b):
                # corners are stored x, y, z; the volume is indexed z, y, x
                template_vec[b, :, ci] = template[b, :, corner[2], corner[1], corner[0]]

        search_vec = search_region.view(B * R, C, -1).permute(0, 2, 1)  # B*R x huge x C
        corr_vec = torch.matmul(search_vec, template_vec)  # B*R x huge x 8
        corr = corr_vec.reshape(B * R, Z, Y, X, 8).permute(0, 4, 1, 2, 3)  # B*R x 8 x Z x Y x X

        feat = self.predictor1(corr)
        feat = feat.reshape(B * R, -1)
        rigid = self.predictor2(feat).reshape(B, R, 9)

        # The translation task is only well formed relative to a fixed centroid: shift
        # each guess from its sampled center back to the nominal (grid-center) frame.
        normal_center = torch.tensor(
            [ZX / 2, ZY / 2, ZZ / 2], dtype=torch.float32).cuda().reshape(1, 1, 3)
        center_offset = sampled_centers.reshape(B, R, 3) - normal_center
        rigid = torch.cat([rigid[:, :, :3] - center_offset, rigid[:, :, 3:]], dim=2)
        rigid = torch.mean(rigid, dim=1)  # B x 9
    else:
        # Correlate every template voxel with every search-region voxel (one big matmul).
        search_vec = search_region.view(B, C, -1).permute(0, 2, 1)  # B x huge x C
        template_vec = template.view(B, C, -1)  # B x C x med
        corr_vec = torch.matmul(search_vec, template_vec)  # B x huge x med
        corr = corr_vec.reshape(B, Z, Y, X, ZZ * ZY * ZX).permute(0, 4, 1, 2, 3)  # B x med x Z x Y x X

        feat = self.predictor1(corr)
        feat = feat.reshape(B, -1)
        rigid = self.predictor2(feat)

    # Both branches produce rigid as B x 9: [translation(3), sin(3), cos-1(3)].
    # Previously xyz_e was only assigned in the non-cube branch, so cube mode crashed.
    xyz_e = rigid[:, :3]

    # sines and cosines are in x, y, z order
    sin_e = rigid[:, 3:6]
    cos_e = 1.0 + rigid[:, 6:9]
    sin_e, cos_e = utils_geom.sincos_norm(sin_e, cos_e)
    rot_e = utils_geom.sincos2rotm(
        sin_e[:, 2],  # z
        sin_e[:, 1],  # y
        sin_e[:, 0],  # x
        cos_e[:, 2],  # z
        cos_e[:, 1],  # y
        cos_e[:, 0])  # x
    rx_e, ry_e, rz_e = utils_geom.rotm2eul(rot_e)
    rad_e = torch.stack([rx_e, ry_e, rz_e], dim=1)
    deg_e = utils_geom.rad2deg(rad_e)
    deg_g = utils_geom.rad2deg(r_delta)

    if summ_writer is not None:
        rx_e_deg, ry_e_deg, rz_e_deg = torch.unbind(deg_e, dim=1)
        summ_writer.summ_histogram('rx_e', rx_e_deg)
        summ_writer.summ_histogram('ry_e', ry_e_deg)
        summ_writer.summ_histogram('rz_e', rz_e_deg)
        # ground-truth histograms come from deg_g (previously deg_e was used by mistake)
        rx_g, ry_g, rz_g = torch.unbind(deg_g, dim=1)
        summ_writer.summ_histogram('rx_g', rx_g)
        summ_writer.summ_histogram('ry_g', ry_g)
        summ_writer.summ_histogram('rz_g', rz_g)

    # rotation is compared in degrees
    r_loss = self.smoothl1(deg_e, deg_g)
    t_loss = self.smoothl1(xyz_e, xyz)
    total_loss = utils_misc.add_loss('rigid/r_loss', total_loss, r_loss, hyp.rigid_r_coeff, summ_writer)
    total_loss = utils_misc.add_loss('rigid/t_loss', total_loss, t_loss, hyp.rigid_t_coeff, summ_writer)

    if summ_writer is not None:
        summ_writer.summ_feat('rigid/input_template', template, pca=False)
        summ_writer.summ_feat('rigid/input_search_region', search_region, pca=False)

    return xyz_e, rad_e, total_loss
def forward(self, template_feat, search_feat, template_mask, template_lrt, search_lrt,
            vox_util, lrt_cam0s, summ_writer=None):
    """Estimate an object's cam0->cam1 rigid motion by feature matching, then score it.

    For each batch item, exactly self.num_pts voxels inside the template mask are drawn
    (tiling first if too few). Each is located in the search region with a soft 3D argmax
    over its feature correlations. Both point sets are lifted to camera coordinates with
    vox_util.Zoom2Ref, and utils_track.rigid_transform_3D fits a transform to them.

    Args:
        template_feat: the thing we are searching for; B x C x ZZ x ZY x ZX.
        search_feat: the featuremap where we are searching; B x C x Z x Y x X.
        template_mask: object mask over the template grid (flattened to B x ZZ*ZY*ZX).
        template_lrt: box mapping the template grid to camera coords.
        search_lrt: box mapping the search grid to camera coords.
        vox_util: provides Zoom2Ref.
        lrt_cam0s: ground-truth boxes; [:, 0] and [:, 1] are the two frames
            (ref_T_obj per the original notes).
        summ_writer: optional summary writer, forwarded to add_loss.

    Returns:
        (lrt_cam1_e, total_loss): the frame-0 box moved by the estimated transform,
        and the corner + rotation + translation loss.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, ZZ, ZY, ZX = list(template_feat.shape)
    _, _, Z, Y, X = list(search_feat.shape)

    # camera coordinates of every template voxel: B x med x 3
    grid0 = utils_basic.gridcloud3D(B, ZZ, ZY, ZX)
    xyz0_cam = vox_util.Zoom2Ref(grid0, template_lrt, ZZ, ZY, ZX)

    search_flat = search_feat.view(B, C, -1)  # B x C x huge
    template_flat = template_feat.view(B, C, -1)  # B x C x med
    mask_flat = template_mask.view(B, -1)  # B x med

    cam1_T_cam0_e = utils_geom.eye_4x4(B)
    # iterate over the batch dim to keep the implementation simple
    for b in range(B):
        keep = torch.where(mask_flat[b] > 0)
        feats0 = template_flat[b].permute(1, 0)[keep]  # N x C
        pts0 = xyz0_cam[b][keep]  # N x 3
        assert (len(pts0) > 8)  # callers should have returned early otherwise

        # tile the points so there are at least self.num_pts of them
        if len(pts0) < self.num_pts:
            reps = int(self.num_pts / len(pts0)) + 1
            print('only have %d pts; repeating %d times...' % (len(pts0), reps))
            pts0 = pts0.repeat(reps, 1)
            feats0 = feats0.repeat(reps, 1)
        assert (len(pts0) >= self.num_pts)

        # random subset of exactly self.num_pts points
        chosen = np.random.permutation(len(pts0))[:self.num_pts]
        pts0 = pts0[chosen]
        feats0 = feats0[chosen]

        # each point's correlation heatmap over the search region: num_pts x huge
        heat = torch.matmul(feats0, search_flat[b])
        heat = heat - heat.min(dim=1).values.unsqueeze(1)  # make the min zero
        heat = heat * float(heat.shape[1])  # scale up, for numerical stability
        heat = heat.reshape(self.num_pts, 1, Z, Y, X)

        pts1_search = utils_basic.argmax3D(heat, hard=False, stack=True)  # num_pts x 3
        pts1 = vox_util.Zoom2Ref(
            pts1_search.unsqueeze(0), search_lrt[b:b + 1], Z, Y, X).squeeze(0)
        cam1_T_cam0_e[b] = utils_track.rigid_transform_3D(pts0, pts1)

    # ground-truth relative motion from the two ground-truth poses
    _, rt_cam0_g = utils_geom.split_lrt(lrt_cam0s[:, 0])
    _, rt_cam1_g = utils_geom.split_lrt(lrt_cam0s[:, 1])
    cam1_T_cam0_g = torch.matmul(rt_cam1_g, rt_cam0_g.inverse())

    lrt_cam1_e = utils_geom.apply_4x4_to_lrtlist(cam1_T_cam0_e, lrt_cam0s[:, 0:1]).squeeze(1)

    # Corner loss: move the 8 corners of a unit cube (1 x 8 x 3, broadcast over the
    # batch by apply_4x4's matmul) by both transforms and compare.
    signs = np.array([
        [1, 1, 1], [1, 1, -1], [-1, 1, -1], [-1, 1, 1],
        [1, -1, 1], [1, -1, -1], [-1, -1, -1], [-1, -1, 1],
    ], dtype=np.float64)
    cube = torch.from_numpy(signs * 0.5).float().cuda().reshape(1, 8, 3)
    corners_e = utils_geom.apply_4x4(cam1_T_cam0_e, cube)
    corners_g = utils_geom.apply_4x4(cam1_T_cam0_g, cube)
    corner_loss = self.smoothl1(corners_e, corners_g)
    total_loss = utils_misc.add_loss(
        'robust/corner_loss', total_loss, corner_loss, hyp.robust_corner_coeff, summ_writer)

    # rotation (compared as euler angles in degrees) and translation losses
    rot_e, t_e = utils_geom.split_rt(cam1_T_cam0_e)
    rot_g, t_g = utils_geom.split_rt(cam1_T_cam0_g)
    rx_e, ry_e, rz_e = utils_geom.rotm2eul(rot_e)
    rx_g, ry_g, rz_g = utils_geom.rotm2eul(rot_g)
    deg_e = utils_geom.rad2deg(torch.stack([rx_e, ry_e, rz_e], dim=1))
    deg_g = utils_geom.rad2deg(torch.stack([rx_g, ry_g, rz_g], dim=1))

    r_loss = self.smoothl1(deg_e, deg_g)
    t_loss = self.smoothl1(t_e, t_g)
    total_loss = utils_misc.add_loss('robust/r_loss', total_loss, r_loss, hyp.robust_r_coeff, summ_writer)
    total_loss = utils_misc.add_loss('robust/t_loss', total_loss, t_loss, hyp.robust_t_coeff, summ_writer)

    return lrt_cam1_e, total_loss
def forward(self, feat, occ_g, free_g, summ_writer):
    """Predict occupancy from `feat`, compute occupancy losses, and build a computation mask.

    Args:
        feat: B x C x Z x Y x X input features.
        occ_g: B x 1 x Z x Y x X ground-truth occupied mask.
        free_g: ground-truth free-space mask, same shape as occ_g.
        summ_writer: summary writer (required; used unconditionally).

    Returns:
        (total_loss, comp_mask). comp_mask is a rounded (0/1) mask of voxels for later
        nets to compute on. It starts from the detached prediction, removes known-free
        voxels, and adds a noisy dilated copy of occ_g. When hyp.preocc_density_coeff > 0
        it is then thinned (lowest values first) to at most that density.
    """
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat.shape)

    if hyp.preocc_do_flip:
        # random flip/transpose augmentation; undone on the network output below
        flip0 = torch.rand(1)
        flip1 = torch.rand(1)
        flip2 = torch.rand(1)
        if flip0 > 0.5:
            feat = feat.permute(0, 1, 4, 3, 2)  # transpose width/depth (rotate 90deg)
        if flip1 > 0.5:
            feat = feat.flip(2)  # flip depth
        if flip2 > 0.5:
            feat = feat.flip(4)  # flip width

    feat = self.net(feat)

    if hyp.preocc_do_flip:
        # undo the augmentation in reverse order
        if flip2 > 0.5:
            feat = feat.flip(4)  # unflip width
        if flip1 > 0.5:
            feat = feat.flip(2)  # unflip depth
        if flip0 > 0.5:
            feat = feat.permute(0, 1, 4, 3, 2)  # untranspose width/depth

    # the net's output is half res, so bring it up; occ_e_ is B x 1 x Z x Y x X
    occ_e_ = F.interpolate(feat, scale_factor=2)

    # Smooth loss on the logits: sum gradient magnitudes along all three axes.
    # (Previously dx was summed twice and dz was dropped.)
    dz, dy, dx = gradient3D(occ_e_, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    smooth_loss = torch.mean(smooth_vox)
    summ_writer.summ_oned('preocc/smooth_loss', torch.mean(smooth_vox, dim=3))
    total_loss = utils_misc.add_loss(
        'preocc/smooth_loss', total_loss, smooth_loss, hyp.preocc_smooth_coeff, summ_writer)

    occ_e = torch.sigmoid(occ_e_)
    occ_e_binary = torch.round(occ_e)

    # regularizer: penalize the mean predicted occupancy
    summ_writer.summ_oned('preocc/reg_loss', torch.mean(occ_e, dim=3))
    total_loss = utils_misc.add_loss(
        'preocc/regularizer_loss', total_loss, torch.mean(occ_e), hyp.preocc_reg_coeff, summ_writer)

    # collect some accuracy stats
    occ_match = occ_g * torch.eq(occ_e_binary, occ_g).float()
    free_match = free_g * torch.eq(1.0 - occ_e_binary, free_g).float()
    either_match = torch.clamp(occ_match + free_match, 0.0, 1.0)
    either_have = torch.clamp(occ_g + free_g, 0.0, 1.0)
    acc_occ = reduce_masked_mean(occ_match, occ_g)
    acc_free = reduce_masked_mean(free_match, free_g)
    acc_total = reduce_masked_mean(either_match, either_have)
    summ_writer.summ_scalar('preocc/acc_occ', acc_occ.cpu().item())
    summ_writer.summ_scalar('preocc/acc_free', acc_free.cpu().item())
    summ_writer.summ_scalar('preocc/acc_total', acc_total.cpu().item())

    amount_occ = torch.mean(occ_e_binary)
    summ_writer.summ_scalar('preocc/amount_occ', amount_occ.cpu().item())

    # vis
    summ_writer.summ_occ('preocc/occ_g', occ_g, reduce_axes=[2, 3])
    summ_writer.summ_occ('preocc/free_g', free_g, reduce_axes=[2, 3])
    summ_writer.summ_occ('preocc/occ_e', occ_e, reduce_axes=[2, 3])
    summ_writer.summ_occ('preocc/occ_e_binary', occ_e_binary, reduce_axes=[2, 3])

    prob_loss = self.compute_loss(occ_e_, occ_g, free_g, summ_writer)
    total_loss = utils_misc.add_loss(
        'preocc/prob_loss', total_loss, prob_loss, hyp.preocc_coeff, summ_writer)

    # --- computation mask for later nets ---
    # first fatten the gt (3x3x3 dilation); all of it is included
    weights = torch.ones(1, 1, 3, 3, 3, device=occ_g.device, dtype=occ_g.dtype)
    occ_g_fat = torch.clamp(F.conv3d(occ_g, weights, padding=1), 0, 1)
    # Scale the fattened gt by noise in [0.8, 1.0). If occ_g_fat alone already exceeds
    # the target density, the trimming loop below then has a chance to drop some of it.
    occ_g_noise = torch.FloatTensor(B, 1, Z, Y, X).uniform_(0.8, 1.0).to(occ_g_fat.device)
    occ_g_fat *= occ_g_noise
    summ_writer.summ_occ('preocc/occ_g_fat', occ_g_fat)

    # definitely exclude the known free voxels, then definitely include the known occ ones
    comp_mask = torch.clamp(occ_e.detach() - free_g, 0, 1)
    comp_mask = torch.clamp(comp_mask + occ_g_fat, 0, 1)
    summ_writer.summ_occ('preocc/comp_mask', comp_mask.round())

    comp_mask[comp_mask < 0.5] = 0.0
    if hyp.preocc_density_coeff > 0:
        # Repeatedly zero out the weakest remaining voxels until the density target is met.
        # Each pass removes at least the current minimum, so the loop terminates.
        while torch.mean(comp_mask.round()) > hyp.preocc_density_coeff:
            comp_vec = comp_mask.reshape(-1)
            nonzero_min = torch.min(comp_vec[comp_vec > 0])
            comp_mask[comp_mask < (nonzero_min + 0.05)] = 0.0
    comp_mask = torch.round(comp_mask)
    summ_writer.summ_occ('preocc/comp_mask_trimmed', comp_mask)

    amount_comp = torch.mean(comp_mask)
    summ_writer.summ_scalar('preocc/amount_comp', amount_comp.cpu().item())

    return total_loss, comp_mask
def forward(self, image_input, summ_writer=None, is_train=True):
    """Predict the hidden pixels of a 2D map of integer indices from their surroundings.

    self.num_choices random pixels are drawn per batch item, and their 3x3 neighbourhoods
    are masked out of the embedding. Cross-entropy is applied only at the drawn pixels.

    Args:
        image_input: B x H x W integer indices (the cross_entropy target).
        summ_writer: optional summary writer; summaries are skipped when None
            (previously the method crashed with the default None).
        is_train: currently unused.

    Returns:
        total_loss: the weighted masked cross-entropy.
    """
    B, H, W = list(image_input.shape)
    total_loss = torch.tensor(0.0).cuda()
    if summ_writer is not None:
        summ_writer.summ_oned('sigen2d/image_input', image_input.unsqueeze(1) / 512.0, norm=False)

    emb = self.embed(image_input)

    y = torch.randint(low=0, high=H, size=[B, self.num_choices, 1])
    x = torch.randint(low=0, high=W, size=[B, self.num_choices, 1])
    choice_mask = utils_improc.xy2mask(torch.cat([x, y], dim=2), H, W, norm=False)
    if summ_writer is not None:
        summ_writer.summ_oned('sigen2d/choice_mask', choice_mask, norm=False)

    # cover up the 3x3 region surrounding each choice (rows top to bottom, cols left to right)
    xy = torch.cat(
        [torch.cat([x + dx, y + dy], dim=2) for dy in (-1, 0, 1) for dx in (-1, 0, 1)],
        dim=1)
    input_mask = 1.0 - utils_improc.xy2mask(xy, H, W, norm=False)

    emb = emb * input_mask
    if summ_writer is not None:
        summ_writer.summ_oned('sigen2d/input_mask', input_mask, norm=False)

    image_output_logits, _ = self.net(emb, input_mask)

    if summ_writer is not None:
        image_output = torch.argmax(image_output_logits, dim=1, keepdim=True)
        summ_writer.summ_feat('sigen2d/emb', emb, pca=True)
        summ_writer.summ_oned('sigen2d/image_output', image_output.float() / 512.0, norm=False)

    ce_loss_image = F.cross_entropy(
        image_output_logits, image_input, reduction='none').unsqueeze(1)
    # only apply loss at the choice pixels
    ce_loss = utils_basic.reduce_masked_mean(ce_loss_image, choice_mask)
    total_loss = utils_misc.add_loss(
        'sigen2d/ce_loss', total_loss, ce_loss, hyp.sigen2d_coeff, summ_writer)
    return total_loss
def forward(self, feat, occ_g, free_g, valid, summ_writer, prefix="", log_summ=True, only_pred=False):
    """Predict occupancy logits from `feat` and compute the occupancy losses.

    Args:
        feat: input features, passed to self.conv3d.
        occ_g: ground-truth occupied mask; must match the prediction, which is
            B x 1 x Z x Y x X.
        free_g: ground-truth free-space mask, same shape as occ_g.
        valid: validity mask; restricts the accuracy stats and is passed to compute_loss.
        summ_writer: summary writer (required; the smooth-loss summary is unconditional).
        prefix: inserted into every summary/loss name right after 'occ/'.
        log_summ: when False, accuracy scalars and occupancy visualizations are skipped.
        only_pred: currently unused.

    Returns:
        (total_loss, occ_e): the summed loss and the sigmoid of the predicted logits.
    """
    total_loss = torch.tensor(0.0).cuda()
    occ_e_ = self.conv3d(feat)
    # occ_e_ is B x 1 x Z x Y x X

    # Smooth loss: sum gradient magnitudes along all three axes.
    # (Previously dx was summed twice and dz was dropped.)
    dz, dy, dx = gradient3D(occ_e_, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    summ_writer.summ_oned(f'occ/{prefix}smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss(
        f'occ/{prefix}smooth_loss', total_loss, smooth_loss, hyp.occ_smooth_coeff, summ_writer)

    occ_e = torch.sigmoid(occ_e_)
    occ_e_binary = torch.round(occ_e)

    # collect some accuracy stats, restricted to valid voxels
    occ_match = occ_g * torch.eq(occ_e_binary, occ_g).float()
    free_match = free_g * torch.eq(1.0 - occ_e_binary, free_g).float()
    either_match = torch.clamp(occ_match + free_match, 0.0, 1.0)
    either_have = torch.clamp(occ_g + free_g, 0.0, 1.0)
    acc_occ = reduce_masked_mean(occ_match, occ_g * valid)
    acc_free = reduce_masked_mean(free_match, free_g * valid)
    acc_total = reduce_masked_mean(either_match, either_have * valid)

    if log_summ:
        summ_writer.summ_scalar(f'occ/{prefix}acc_occ', acc_occ.cpu().item())
        summ_writer.summ_scalar(f'occ/{prefix}acc_free', acc_free.cpu().item())
        summ_writer.summ_scalar(f'occ/{prefix}acc_total', acc_total.cpu().item())

        # vis
        summ_writer.summ_occ(f'occ/{prefix}occ_g', occ_g)
        summ_writer.summ_occ(f'occ/{prefix}free_g', free_g)
        summ_writer.summ_occ(f'occ/{prefix}occ_e', occ_e)
        summ_writer.summ_occ(f'occ/{prefix}valid', valid)

    prob_loss = self.compute_loss(occ_e_, occ_g, free_g, valid, summ_writer, prefix=prefix)
    total_loss = utils_misc.add_loss(
        f'occ/{prefix}prob_loss', total_loss, prob_loss, hyp.occ_coeff, summ_writer)

    return total_loss, occ_e