def validation_step(self, batch, batch_idx):
    """
    Evaluate one validation batch and report its losses.

    Lightning calls this inside the validation loop; `batch` is the item
    produced by the validation dataloader.

    Returns an OrderedDict holding the total loss under 'val_loss', followed
    by one 'val_loss_<i>' entry for every element of the per-term loss list
    returned by `loss_util.compute_loss`.
    """
    train_cfg = self.hparams.train
    levels = train_cfg.num_hierarchy_levels
    trunc = train_cfg.truncation
    masked = train_cfg.use_loss_masking
    log_weight = self.hparams.model.logweight_target_sdf
    missing_geo_weight = train_cfg.weight_missing_geo

    # TODO: batches smaller than the configured batch size are not skipped.
    sdfs = batch['sdf']
    inputs = batch['input']
    known = batch['known']
    hierarchy = batch['hierarchy']

    # Move what the loss needs onto the GPU; both lists are updated in place.
    for level, grid in enumerate(hierarchy):
        hierarchy[level] = grid.cuda()
    if masked:
        known = known.cuda()
    for slot in (0, 1):
        inputs[slot] = inputs[slot].cuda()

    targets = loss_util.compute_targets(
        sdfs.cuda(), hierarchy, levels, trunc, masked, known)
    target_for_sdf, target_for_occs, target_for_hier = targets

    # Loss weights follow the current training iteration counter.
    loss_weights = get_loss_weights(
        self._iter_counter,
        train_cfg.num_hierarchy_levels,
        train_cfg.num_iters_per_level,
        train_cfg.weight_sdf_loss)

    output_sdf, output_occs = self(inputs, loss_weights)
    loss, losses = loss_util.compute_loss(
        output_sdf, output_occs, target_for_sdf, target_for_occs,
        target_for_hier, loss_weights, trunc, log_weight,
        missing_geo_weight, inputs[0], masked, known)

    result = OrderedDict()
    result['val_loss'] = loss
    for idx, term in enumerate(losses):
        result[f'val_loss_{idx}'] = term
    return result
def check(dataloader, output_save):
    """
    Dump the input and the computed SDF target of every full-size batch.

    For each batch the targets are computed with `loss_util.compute_targets`
    and handed, together with the inputs, to
    `data_util.save_input_target_pred` (written under `args.save`).
    Batches with fewer than `args.batch_size` items are reported and skipped.

    `output_save` is accepted for interface compatibility but not used here.
    """
    print("num_batches: ", len(dataloader))
    for sample in dataloader:
        print("START")
        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            # maintain same batch size for training
            print("empty sdf: ", sample['name'])
            continue

        target_for_sdf, _occs, _hier = loss_util.compute_targets(
            sdfs,
            sample['hierarchy'],
            args.num_hierarchy_levels,
            args.truncation,
            True,
            sample['known'],
            flipped=FLIP)

        # Read but not used further; kept so a sample without 'orig_dims'
        # still fails here as before.
        dims = sample['orig_dims'][0]

        inputs_np = sample['input'].cpu().numpy()
        data_util.save_input_target_pred(
            args.save, sample['name'], inputs_np, target_for_sdf, None,
            truncation=2)
        print("END")
    return
def test(dataloader, output_vis, num_to_vis):
    """
    Run the model over every sample in fixed-size chunks and merge the results.

    Each sample's volume is tiled along dims 3 and 4 of the sdf tensor with
    step `args.stride`. Every non-empty chunk is passed through the
    module-level `model`, and the predicted SDF (and optionally color)
    values are accumulated into scene-sized buffers together with a
    per-voxel hit count, which is then used to average overlapping
    predictions. The first `num_to_vis` samples are written out through
    `data_util.save_predictions`.

    Args:
        dataloader: iterable of sample dicts with keys 'input', 'sdf',
            'mask', 'colors', 'name' and 'world2grid'.
        output_vis: forwarded as first argument to
            `data_util.save_predictions`.
        num_to_vis: maximum number of samples for which predictions are saved.

    Side effects: overwrites `args.max_input_height`, and `args.stride` when
    it is 0. Relies on module-level `model`, `args` and `UP_AXIS`.
    """
    model.eval()
    chunk_dim = args.input_dim
    args.max_input_height = chunk_dim[0]
    # A stride of 0 means "no overlap": step by one full chunk.
    if args.stride == 0:
        args.stride = chunk_dim[1]
    # Margin trimmed from interior chunk borders when chunks overlap.
    pad = 2
    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)  # NOTE(review): not used below
    print('starting testing...')
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            sdfs = sample['sdf']
            mask = sample['mask']
            colors = sample['colors']
            max_input_dim = np.array(sdfs.shape[2:])
            # Crop everything to the maximum height along UP_AXIS.
            if args.max_input_height > 0 and max_input_dim[
                    UP_AXIS] > args.max_input_height:
                max_input_dim[UP_AXIS] = args.max_input_height
                inputs = inputs[:, :, :args.max_input_height]
                if mask is not None:
                    mask = mask[:, :, :args.max_input_height]
                if sdfs is not None:
                    sdfs = sdfs[:, :, :args.max_input_height]
                if colors is not None:
                    colors = colors[:, :args.max_input_height]
            sys.stdout.write(
                '\r[ %d | %d ] %s (%d, %d, %d) ' %
                (num_proc, args.max_to_process, sample['name'],
                 max_input_dim[0], max_input_dim[1], max_input_dim[2]))
            # Scene-sized CPU accumulators; output_norms counts how many
            # chunk predictions landed on each voxel.
            output_colors = torch.zeros(colors.shape)
            output_sdfs = torch.zeros(sdfs.shape)
            output_norms = torch.zeros(sdfs.shape)
            output_occs = torch.zeros(sdfs.shape, dtype=torch.uint8)
            # chunk up the scene
            # Chunk buffers are allocated once per sample and refilled for
            # every (y, x) tile.
            chunk_input = torch.ones(1, args.input_nf, chunk_dim[0],
                                     chunk_dim[1], chunk_dim[2]).cuda()
            chunk_mask = torch.ones(1, 1, chunk_dim[0], chunk_dim[1],
                                    chunk_dim[2]).cuda()
            chunk_target_sdf = torch.ones(1, 1, chunk_dim[0], chunk_dim[1],
                                          chunk_dim[2]).cuda()
            chunk_target_colors = torch.zeros(1,
                                              chunk_dim[0],
                                              chunk_dim[1],
                                              chunk_dim[2],
                                              3,
                                              dtype=torch.uint8).cuda()
            for y in range(0, max_input_dim[1], args.stride):
                for x in range(0, max_input_dim[2], args.stride):
                    # Skip tiles with no input value inside the truncation band.
                    chunk_input_mask = torch.abs(
                        inputs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                               x:x + chunk_dim[2]]) < args.truncation
                    if torch.sum(chunk_input_mask).item() == 0:
                        continue
                    sys.stdout.write(
                        '\r[ %d | %d ] %s (%d, %d, %d) (%d, %d) ' %
                        (num_proc, args.max_to_process, sample['name'],
                         max_input_dim[0], max_input_dim[1],
                         max_input_dim[2], y, x))
                    # Extent of real data in this tile (tiles at the border
                    # may be smaller than chunk_dim).
                    fill_dim = [
                        min(sdfs.shape[2], chunk_dim[0]),
                        min(sdfs.shape[3] - y, chunk_dim[1]),
                        min(sdfs.shape[4] - x, chunk_dim[2])
                    ]
                    # Reset buffers to their "empty" values before copying.
                    chunk_target_sdf.fill_(float('inf'))
                    chunk_target_colors.fill_(0)
                    chunk_input[:, 0].fill_(-args.truncation)
                    chunk_input[:, 1:].fill_(0)
                    chunk_mask.fill_(0)
                    chunk_input[:, :, :fill_dim[0], :fill_dim[1], :
                                fill_dim[2]] = inputs[:, :, :chunk_dim[0],
                                                      y:y + chunk_dim[1],
                                                      x:x + chunk_dim[2]]
                    chunk_mask[:, :, :fill_dim[0], :fill_dim[1], :
                               fill_dim[2]] = mask[:, :, :chunk_dim[0],
                                                   y:y + chunk_dim[1],
                                                   x:x + chunk_dim[2]]
                    chunk_target_sdf[:, :, :fill_dim[0], :fill_dim[1], :
                                     fill_dim[2]] = sdfs[:, :, :chunk_dim[0],
                                                         y:y + chunk_dim[1],
                                                         x:x + chunk_dim[2]]
                    chunk_target_colors[:, :fill_dim[0], :fill_dim[
                        1], :fill_dim[2], :] = colors[:, :chunk_dim[0],
                                                      y:y + chunk_dim[1],
                                                      x:x + chunk_dim[2]]
                    # NOTE(review): the returned targets are not used
                    # anywhere below in this function.
                    target_for_sdf, target_for_colors = loss_util.compute_targets(
                        chunk_target_sdf.cuda(), args.truncation, False, None,
                        chunk_target_colors.cuda())
                    output_occ = None
                    output_occ, output_sdf, output_color = model(
                        chunk_input,
                        chunk_mask,
                        pred_sdf=[True, True],
                        pred_color=args.weight_color_loss > 0)
                    # Keep voxels whose predicted |sdf| is inside truncation
                    # (and, when available, whose occupancy sigmoid > 0.5).
                    if output_occ is not None:
                        occ = torch.nn.Sigmoid()(output_occ.detach()) > 0.5
                        locs = torch.nonzero((torch.abs(
                            output_sdf.detach()[:, 0]) < args.truncation)
                                             & occ[:, 0]).cpu()
                    else:
                        locs = torch.nonzero(
                            torch.abs(output_sdf[:, 0]) <
                            args.truncation).cpu()
                    # Move the leading (batch) index column to the end.
                    locs = torch.cat([locs[:, 1:], locs[:, :1]], 1)
                    # From here output_sdf is a [locations, values] pair.
                    output_sdf = [
                        locs, output_sdf[locs[:, -1], :, locs[:, 0],
                                         locs[:, 1], locs[:, 2]].detach().cpu()
                    ]
                    if args.weight_color_loss == 0:
                        output_color = None
                    else:
                        output_color = [
                            locs, output_color[locs[:, -1], :, locs[:, 0],
                                               locs[:, 1], locs[:, 2]]
                        ]
                    # Shift chunk-local coordinates into scene coordinates.
                    output_locs = output_sdf[0] + torch.LongTensor(
                        [0, y, x, 0])
                    if args.stride < chunk_dim[1]:
                        # Overlapping chunks: drop a `pad`-wide margin on
                        # every chunk side that has a neighbouring chunk.
                        min_dim = [0, y, x]
                        max_dim = [
                            0 + chunk_dim[0], y + chunk_dim[1],
                            x + chunk_dim[2]
                        ]
                        if y > 0:
                            min_dim[1] += pad
                        if x > 0:
                            min_dim[2] += pad
                        if y + chunk_dim[1] < max_input_dim[1]:
                            max_dim[1] -= pad
                        if x + chunk_dim[2] < max_input_dim[2]:
                            max_dim[2] -= pad
                        for k in range(3):
                            max_dim[k] = min(max_dim[k], sdfs.shape[k + 2])
                        outmask = (output_locs[:, 0] >= min_dim[0]) & (
                            output_locs[:, 1] >= min_dim[1]) & (
                                output_locs[:, 2] >= min_dim[2]) & (
                                    output_locs[:, 0] < max_dim[0]) & (
                                        output_locs[:, 1] < max_dim[1]) & (
                                            output_locs[:, 2] < max_dim[2])
                    else:
                        # Non-overlapping chunks: only clip to scene bounds.
                        outmask = (
                            output_locs[:, 0] < output_sdfs.shape[2]) & (
                                output_locs[:, 1] < output_sdfs.shape[3]) & (
                                    output_locs[:, 2] < output_sdfs.shape[4])
                    output_locs = output_locs[outmask]
                    output_sdf = [
                        output_sdf[0][outmask], output_sdf[1][outmask]
                    ]
                    if output_color is not None:
                        output_color = [
                            output_color[0][outmask], output_color[1][outmask]
                        ]
                        # Rescale color values: v -> (v + 1) / 2.
                        output_color = (output_color[1] + 1) * 0.5
                        output_colors[
                            0, output_locs[:, 0], output_locs[:, 1],
                            output_locs[:,
                                        2], :] += output_color.detach().cpu()
                    if output_occ is not None:
                        output_occs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                                    x:x + chunk_dim[2]] = occ[:, :, :fill_dim[
                                        0], :fill_dim[1], :fill_dim[2]]
                    # Accumulate SDF sums and hit counts for later averaging.
                    output_sdfs[
                        0, 0, output_locs[:, 0], output_locs[:, 1],
                        output_locs[:,
                                    2]] += output_sdf[1][:, 0].detach().cpu()
                    output_norms[0, 0, output_locs[:, 0], output_locs[:, 1],
                                 output_locs[:, 2]] += 1
            # normalize
            # `mask` is reused here as "voxel received at least one prediction".
            mask = output_norms > 0
            output_norms = output_norms[mask]
            output_sdfs[mask] = output_sdfs[mask] / output_norms
            output_sdfs[~mask] = -float('inf')
            mask = mask.view(1, mask.shape[2], mask.shape[3], mask.shape[4])
            output_colors[
                mask, :] = output_colors[mask, :] / output_norms.view(-1, 1)
            output_colors = torch.clamp(output_colors * 255, 0, 255)
            sdfs = torch.clamp(sdfs, -args.truncation, args.truncation)
            output_sdfs = torch.clamp(output_sdfs, -args.truncation,
                                      args.truncation)
            if num_vis < num_to_vis:
                inputs = inputs.cpu().numpy()
                # Voxels strictly inside the truncation band are exported.
                locs = torch.nonzero(
                    torch.abs(output_sdfs[0, 0]) < args.truncation)
                vis_pred_sdf = [None]
                vis_pred_color = [None]
                sdf_vals = output_sdfs[0, 0, locs[:, 0], locs[:, 1],
                                       locs[:, 2]].view(-1)
                vis_pred_sdf[0] = [locs.cpu().numpy(), sdf_vals.cpu().numpy()]
                if args.weight_color_loss > 0:
                    vals = output_colors[0, locs[:, 0], locs[:, 1], locs[:, 2]]
                    vis_pred_color[0] = vals.cpu().numpy()
                # NOTE(review): output_occs is always a tensor at this point,
                # so this condition is always true.
                if output_occs is not None:
                    pred_occ = output_occs.cpu().numpy().astype(np.float32)
                data_util.save_predictions(output_vis,
                                           np.arange(1),
                                           sample['name'],
                                           inputs,
                                           sdfs.cpu().numpy(),
                                           colors.cpu().numpy(),
                                           None,
                                           None,
                                           vis_pred_sdf,
                                           vis_pred_color,
                                           None,
                                           None,
                                           sample['world2grid'],
                                           args.truncation,
                                           args.color_space,
                                           pred_occ=pred_occ)
                num_vis += 1
            num_proc += 1
            gc.collect()
    sys.stdout.write('\n')
def test(epoch, iter, loss_weights, dataloader, log_file, output_save):
    """
    Evaluate the module-level `model` on `dataloader` without gradients.

    Args:
        epoch, iter: only used to name the visualisation output directory.
        loss_weights: passed to the model and to `loss_util.compute_loss`;
            its last entry gates the L1 surface metrics.
        dataloader: iterable of sample dicts with keys 'sdf', 'input',
            'known', 'hierarchy', 'name' and 'world2grid'.
        log_file: not used in this function.
        output_save: when truthy, predictions of the second-to-last batch
            are saved under `args.save`.

    Returns:
        (val_losses, val_l1preds, val_l1tgts, val_ious) where val_losses[0]
        is the total loss per batch, val_losses[1:] are the per-term losses,
        and val_ious[h] holds the IoU per hierarchy level (only recorded on
        batches where predicted occupancies are computed).
    """
    val_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    val_l1preds = []
    val_l1tgts = []
    val_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.eval()
    #start = time.time()
    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            # Move data to the GPU; the lists are modified in place.
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
            if args.use_loss_masking:
                known = known.cuda()
            inputs[0] = inputs[0].cuda()
            inputs[1] = inputs[1].cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs.cuda(), hierarchy, args.num_hierarchy_levels,
                args.truncation, args.use_loss_masking, known)
            output_sdf, output_occs = model(inputs, loss_weights)
            loss, losses = loss_util.compute_loss(
                output_sdf, output_occs, target_for_sdf, target_for_occs,
                target_for_hier, loss_weights, args.truncation,
                args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
                args.use_loss_masking, known)
            # Visualise the second-to-last batch (t + 2 == num_batches).
            output_visual = output_save and t + 2 == num_batches
            # Predicted occupancies are only extracted every 20th batch or
            # when this batch is going to be visualised.
            compute_pred_occs = (t % 20 == 0) or output_visual
            if compute_pred_occs:
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    factor = 2**(args.num_hierarchy_levels - h - 1)  # NOTE(review): unused
                    pred_occs[h] = [None] * args.batch_size
                    if len(output_occs[h][0]) == 0:
                        continue
                    for b in range(args.batch_size):
                        # Last column of the location tensor is compared
                        # against the batch index b.
                        batchmask = output_occs[h][0][:, -1] == b
                        locs = output_occs[h][0][batchmask][:, :-1]
                        vals = torch.nn.Sigmoid()(
                            output_occs[h][1][:, 0].detach()[batchmask]) > 0.5
                        pred_occs[h][b] = locs[vals.view(-1)]
            val_losses[0].append(loss.item())
            for h in range(args.num_hierarchy_levels):
                val_losses[h + 1].append(losses[h])
                target = target_for_occs[h].byte()
                if compute_pred_occs:
                    iou = loss_util.compute_iou_sparse_dense(
                        pred_occs[h], target, args.use_loss_masking)
                    val_ious[h].append(iou)
            val_losses[args.num_hierarchy_levels + 1].append(losses[-1])
            if len(output_sdf[0]) > 0:
                output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
            # L1 surface metrics: only when the last loss weight is active,
            # and only every 20th batch.
            if loss_weights[-1] > 0 and t % 20 == 0:
                val_l1preds.append(
                    loss_util.compute_l1_predsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf, None,
                        False, args.use_loss_masking, known).item())
                val_l1tgts.append(
                    loss_util.compute_l1_tgtsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf,
                        args.truncation, args.use_loss_masking, known))
            if output_visual:
                vis_pred_sdf = [None] * args.batch_size
                if len(output_sdf[0]) > 0:
                    for b in range(args.batch_size):
                        mask = output_sdf[0][:, -1] == b
                        # NOTE(review): len(mask) is the number of rows, not
                        # the number of True entries.
                        if len(mask) > 0:
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
                # Convert predicted occupancies to numpy in place.
                for h in range(args.num_hierarchy_levels):
                    for b in range(args.batch_size):
                        if pred_occs[h][b] is not None:
                            pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
                data_util.save_predictions(
                    os.path.join(args.save,
                                 'iter%d-epoch%d' % (iter, epoch), 'val'),
                    sample['name'], inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy() for x in target_for_occs], vis_pred_sdf,
                    pred_occs, sample['world2grid'], args.vis_dfs,
                    args.truncation)
    #took = time.time() - start
    return val_losses, val_l1preds, val_l1tgts, val_ious
def train(epoch, iter, dataloader, log_file, output_save):
    """
    Train the module-level `model` for one pass over `dataloader`.

    Args:
        epoch: current epoch number (used for logging and file names).
        iter: global iteration counter at entry; incremented once per
            processed batch and returned.
        dataloader: iterable of sample dicts with keys 'sdf', 'input',
            'known', 'hierarchy', 'name' and 'world2grid'.
        log_file: forwarded to `print_log`.
        output_save: when truthy, predictions of the second-to-last batch
            are saved under `args.save`.

    Returns:
        (train_losses, train_l1preds, train_l1tgts, train_ious, iter,
        loss_weights), with `loss_weights` being the weights computed for
        the last batch seen.

    NOTE(review): `loss_weights` is only assigned inside the loop, so an
    empty dataloader would make the return statement fail.
    """
    train_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    train_l1preds = []
    train_l1tgts = []
    train_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.train()
    start = time.time()
    # A step size of 0 means the scheduler is stepped once per call here
    # instead of every N iterations inside the loop.
    if args.scheduler_step_size == 0:
        scheduler.step()
    num_batches = len(dataloader)
    for t, sample in enumerate(dataloader):
        # Loss weights depend on the global iteration counter.
        loss_weights = get_loss_weights(iter, args.num_hierarchy_levels,
                                        args.num_iters_per_level,
                                        args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('[iter %d/epoch %d] loss_weights' % (iter, epoch),
                  loss_weights)
        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        # Move data to the GPU; the lists are modified in place.
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if args.use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, args.num_hierarchy_levels,
            args.truncation, args.use_loss_masking, known)
        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(
            output_sdf, output_occs, target_for_sdf, target_for_occs,
            target_for_hier, loss_weights, args.truncation,
            args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
            args.use_loss_masking, known)
        loss.backward()
        optimizer.step()
        # Visualise the second-to-last batch (t + 2 == num_batches).
        output_visual = output_save and t + 2 == num_batches
        # Predicted occupancies are only extracted every 20th iteration or
        # when this batch is going to be visualised.
        compute_pred_occs = (iter % 20 == 0) or output_visual
        if compute_pred_occs:
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                factor = 2**(args.num_hierarchy_levels - h - 1)  # NOTE(review): unused
                pred_occs[h] = [None] * args.batch_size
                if len(output_occs[h][0]) == 0:
                    continue
                # Overwrites the model output in place with a boolean
                # "sigmoid > 0.5" mask.
                output_occs[h][1] = torch.nn.Sigmoid()(
                    output_occs[h][1][:, 0].detach()) > 0.5
                for b in range(args.batch_size):
                    # Last column of the location tensor is compared against
                    # the batch index b.
                    batchmask = output_occs[h][0][:, -1] == b
                    locs = output_occs[h][0][batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    pred_occs[h][b] = locs[vals.view(-1)]
        train_losses[0].append(loss.item())
        for h in range(args.num_hierarchy_levels):
            train_losses[h + 1].append(losses[h])
            target = target_for_occs[h].byte()
            if compute_pred_occs:
                iou = loss_util.compute_iou_sparse_dense(
                    pred_occs[h], target, args.use_loss_masking)
                train_ious[h].append(iou)
        train_losses[args.num_hierarchy_levels + 1].append(losses[-1])
        if len(output_sdf[0]) > 0:
            output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
        # L1 surface metrics: only when the last loss weight is active, and
        # only every 20th iteration.
        if loss_weights[-1] > 0 and iter % 20 == 0:
            train_l1preds.append(
                loss_util.compute_l1_predsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf, None, False,
                    args.use_loss_masking, known).item())
            train_l1tgts.append(
                loss_util.compute_l1_tgtsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf,
                    args.truncation, args.use_loss_masking, known))
        iter += 1
        if args.scheduler_step_size > 0 and iter % args.scheduler_step_size == 0:
            scheduler.step()
        if iter % 20 == 0:
            # NOTE(review): `start` is never reset, so `took` is the time
            # since this function was entered.
            took = time.time() - start
            print_log(log_file, epoch, iter, train_losses, train_l1preds,
                      train_l1tgts, train_ious, None, None, None, None, took)
        # Checkpoint every 2000 iterations.
        if iter % 2000 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                os.path.join(args.save,
                             'model-iter%s-epoch%s.pth' % (iter, epoch)))
        if output_visual:
            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    # NOTE(review): len(mask) is the number of rows, not the
                    # number of True entries.
                    if len(mask) > 0:
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]
            inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
            # Convert predicted occupancies to numpy in place.
            for h in range(args.num_hierarchy_levels):
                for b in range(args.batch_size):
                    if pred_occs[h][b] is not None:
                        pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
            data_util.save_predictions(
                os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch),
                             'train'), sample['name'], inputs,
                target_for_sdf.cpu().numpy(),
                [x.cpu().numpy() for x in target_for_occs], vis_pred_sdf,
                pred_occs, sample['world2grid'].numpy(), args.vis_dfs,
                args.truncation)
    return train_losses, train_l1preds, train_l1tgts, train_ious, iter, loss_weights
def test(loss_weights, dataloader, output_vis, num_to_vis):
    """
    Run the module-level `model` on samples and save their predictions.

    For every full-size batch the targets are computed, the model is run,
    and the predicted occupancies (sigmoid > 0.5) and SDF values are split
    per batch item and written through `data_util.save_predictions`.
    Processing stops after `num_to_vis` batches have been saved.

    Args:
        loss_weights: forwarded to the model.
        dataloader: iterable of sample dicts with keys 'input', 'sdf',
            'known', 'hierarchy', 'name' and 'world2grid'.
        output_vis: forwarded as first argument to
            `data_util.save_predictions`.
        num_to_vis: number of batches to save before stopping.

    A sample whose forward pass raises is reported and skipped. Only
    `Exception` subclasses are caught, so Ctrl-C and `SystemExit` still
    stop the run.
    """
    model.eval()
    num_vis = 0
    num_batches = len(dataloader)
    print("num_batches: ", num_batches)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            print("input_dim: ", np.array(sample['sdf'].shape))
            input_dim = np.array(sample['sdf'].shape[2:])
            print(input_dim)
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d) ' %
                             (num_vis, num_to_vis, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()

            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size for training
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            # Move each hierarchy level to the GPU and insert a dim at
            # position 1; the list is updated in place.
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            if args.use_loss_masking:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs,
                hierarchy,
                args.num_hierarchy_levels,
                args.truncation,
                args.use_loss_masking,
                known,
                flipped=args.flipped)

            start_time = time.time()
            try:
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception as err:
                # Best effort: report the failing sample and move on, but
                # say what went wrong instead of hiding it.
                print('exception at %s' % sample['name'])
                print('  %s: %s' % (type(err).__name__, err))
                gc.collect()
                continue
            end_time = time.time()
            print('TIME: %.4fs' % (end_time - start_time))

            # vis occ & sdf
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                pred_occs[h] = [None] * args.batch_size
                if len(output_occs[h][0]) == 0:
                    continue
                # filter: occ > 0
                for b in range(args.batch_size):
                    # Last column of the location tensor is compared against
                    # the batch index b.
                    batchmask = output_occs[h][0][:, -1] == b
                    locs = output_occs[h][0][batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    occ_mask = torch.nn.Sigmoid()(vals[:, 0].detach()) > 0.5
                    if args.flipped:
                        # Flipped mode keeps the raw values next to the
                        # locations.
                        pred_occs[h][b] = [
                            locs[occ_mask.view(-1)].cpu().numpy(),
                            vals[occ_mask.view(-1)].cpu().numpy()
                        ]
                    else:
                        pred_occs[h][b] = locs[occ_mask.view(-1)].cpu().numpy()

            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    # NOTE(review): len(mask) is the number of rows, not the
                    # number of True entries.
                    if len(mask) > 0:
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]

            data_util.save_predictions(output_vis,
                                       sample['name'],
                                       inputs,
                                       target_for_sdf.cpu().numpy(),
                                       None,
                                       vis_pred_sdf,
                                       pred_occs,
                                       sample['world2grid'],
                                       args.truncation,
                                       flipped=args.flipped)
            num_vis += 1
            if num_vis >= num_to_vis:
                break
    sys.stdout.write('\n')
def vis(self):
    """
    Publish input / target / prediction point clouds for validation samples.

    For each sample from `self.val_dataloader` the model is run, occupancy
    masks are derived for the input (> 0), the target (> 0) and the
    prediction (sigmoid > 0.45), and four PointCloud2 messages are published
    on `self.output_pub`, waiting for a key press after each one:

        msg0: input + target, target drawn with the same colour as the input
        msg1: input only
        msg2: input + target (target in (50, 255, 50))
        msg3: input + prediction (prediction in (50, 50, 255))

    Every point is six float32 values (x, y, z, r, g, b), i.e. 24 bytes.
    Returns once more than `self.max_num` samples have been seen; exits the
    process if ROS is shutting down.
    """
    count = 0
    with torch.no_grad():
        for t, sample in enumerate(self.val_dataloader):
            if rospy.is_shutdown():
                exit()
            # NOTE(review): with `>` this handles max_num + 1 samples;
            # left unchanged, confirm whether that is intended.
            if count > self.max_num:
                return
            else:
                count += 1
            print("name: ", sample['name'])

            sdfs = sample['sdf']
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            # Move each hierarchy level to the GPU and insert a dim at
            # position 1; the list is updated in place.
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()

            # Truncate the last axis to a multiple of 8.
            predz = int(np.floor(inputs.shape[-1] / 8) * 8)
            inputs = inputs[:, :, :, :predz]
            sdfs = sdfs[:, :, :, :, :predz]

            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, self.num_hierarchy_levels, 2.0, True, known,
                True)
            output_sdf, output_occs = self.model(inputs, self.weights)

            if output_occs is not None:
                start_time = time.time()
                pred_occ = output_occs[-1][:, 0].squeeze()
                pred_occ = self.sigmoid_func(pred_occ) > 0.45
                target_occ = target_for_occs[-1].squeeze() > 0
                input_occ = inputs.squeeze() > 0

                dimx = pred_occ.shape[-3]
                dimy = pred_occ.shape[-2]
                dimz = pred_occ.shape[-1]
                # NOTE(review): in the collapsed original it is ambiguous
                # whether this call was live or part of the commented-out
                # `if not self.size_init:` line; kept live because
                # dimx/dimy/dimz have no other use. Confirm.
                self.init_size(dimx, dimy, dimz)

                # Each point row is [coords..., c0, c1, c2]; the first three
                # columns are scaled by the voxel size.
                # presumably self.pointers[mask] yields (N, 3) voxel
                # coordinates -- verify against init_size.
                input_occ_coords = self.pointers[input_occ]
                input_colors = torch.FloatTensor([255, 50, 50]).repeat(
                    [input_occ_coords.shape[0], 1])
                input_points = torch.cat([input_occ_coords, input_colors],
                                         dim=1)
                input_points[:, :3] *= self.voxel_size

                # Predicted voxels that are not already part of the input.
                pred_occ_coords = self.pointers[pred_occ & (~input_occ)]
                pred_colors = torch.FloatTensor([50, 50, 255]).repeat(
                    [pred_occ_coords.shape[0], 1])
                pred_points = torch.cat([pred_occ_coords, pred_colors], dim=1)
                pred_points[:, :3] *= self.voxel_size

                # Target voxels that are not already part of the input.
                target_occ_coords = self.pointers[target_occ & (~input_occ)]
                target_colors = torch.FloatTensor([50, 255, 50]).repeat(
                    [target_occ_coords.shape[0], 1])
                target_points = torch.cat([target_occ_coords, target_colors],
                                          dim=1)
                target_points[:, :3] *= self.voxel_size

                # Same target voxels, drawn with the input colour triple
                # (despite the "_blue" name).
                target_colors_blue = torch.FloatTensor([255, 50, 50]).repeat(
                    [target_occ_coords.shape[0], 1])
                target_points_blue = torch.cat(
                    [target_occ_coords, target_colors_blue], dim=1)
                target_points_blue[:, :3] *= self.voxel_size

                inout_points = torch.cat([input_points, target_points], dim=0)
                inout_points_blue = torch.cat(
                    [input_points, target_points_blue], dim=0)
                inpred_points = torch.cat([input_points, pred_points], dim=0)

                # msg1 is the template: input points only.
                msg1 = PointCloud2()
                msg1.header.frame_id = "map"
                msg1.height = 1
                msg1.width = input_points.shape[0]
                msg1.fields = [
                    PointField('x', 0, PointField.FLOAT32, 1),
                    PointField('y', 4, PointField.FLOAT32, 1),
                    PointField('z', 8, PointField.FLOAT32, 1),
                    PointField('r', 12, PointField.FLOAT32, 1),
                    PointField('g', 16, PointField.FLOAT32, 1),
                    PointField('b', 20, PointField.FLOAT32, 1)
                ]
                msg1.is_bigendian = False
                msg1.point_step = 24  # 6 float32 values per point
                msg1.row_step = msg1.point_step * msg1.width
                msg1.is_dense = False
                msg1.data = input_points.reshape([-1]).cpu().numpy().tobytes()

                # Every derived message must use its OWN width for row_step.
                msg2 = deepcopy(msg1)
                msg2.width = inout_points.shape[0]
                msg2.row_step = msg2.point_step * msg2.width
                msg2.data = inout_points.reshape([-1]).cpu().numpy().tobytes()

                msg0 = deepcopy(msg1)
                msg0.width = inout_points_blue.shape[0]
                msg0.row_step = msg0.point_step * msg0.width
                msg0.data = inout_points_blue.reshape(
                    [-1]).cpu().numpy().tobytes()

                msg3 = deepcopy(msg1)
                msg3.width = inpred_points.shape[0]
                msg3.row_step = msg3.point_step * msg3.width
                msg3.data = inpred_points.reshape(
                    [-1]).cpu().numpy().tobytes()

                print("post process time: %fs" % (time.time() - start_time))

                # Publish one cloud at a time, pausing for the operator.
                self.output_pub.publish(msg0)
                raw_input('press key to continue')
                self.output_pub.publish(msg1)
                raw_input('press key to continue')
                self.output_pub.publish(msg2)
                raw_input('press key to continue')
                self.output_pub.publish(msg3)
                raw_input('press key to continue')
def test(loss_weights, dataloader, output_vis, num_to_vis):
    """
    Evaluate the module-level `model` with dense occupancy metrics.

    For every sample the model is run (timed), the loss is computed with
    `loss_util.compute_loss_nosdf`, precision / recall / F1 are computed with
    `loss_util.compute_dense_occ_accuracy`, and the dense input, target and
    predicted occupancy are written via `data_util.save_dense_predictions`
    under `args.output`. Mean statistics are printed at the end.

    Args:
        loss_weights: forwarded to the model and the loss.
        dataloader: iterable of sample dicts with keys 'sdf', 'input',
            'known', 'hierarchy' and 'name'.
        output_vis, num_to_vis: accepted for interface compatibility; not
            used in this function.

    When the model returns no SDF output for a sample, the metrics for that
    sample are recorded as 0 and nothing is saved for it.
    """
    model.eval()
    val_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    val_bceloss = []
    val_precision = []
    val_recall = []
    val_f1score = []
    val_time = []
    num_batches = len(dataloader)
    print("num_batches: ", num_batches)
    sigmoid_func = torch.nn.Sigmoid()
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            # Move each hierarchy level to the GPU and insert a dim at
            # position 1; the list is updated in place.
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            # Loss masking is forced on here (args.use_loss_masking ignored).
            known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs,
                hierarchy,
                args.num_hierarchy_levels,
                args.truncation,
                True,
                known,
                flipped=args.flipped)

            # NOTE(review): no CUDA synchronisation around the timed call, so
            # this may under-report GPU time -- confirm if timings matter.
            start_time = time.time()
            output_sdf, output_occs = model(inputs, loss_weights)
            val_time.append(time.time() - start_time)

            bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_nosdf(
                output_sdf,
                output_occs,
                target_for_sdf,
                target_for_occs,
                target_for_hier,
                loss_weights,
                args.truncation,
                args.logweight_target_sdf,
                args.weight_missing_geo,
                inputs,
                args.use_loss_masking,
                known,
                flipped=args.flipped)
            loss = bce_loss * 2  # + sdf_loss
            val_bceloss.append(last_loss)
            val_losses[0].append(loss)

            has_prediction = output_sdf is not None
            if has_prediction:
                pred_occ = output_occs[-1][:, 0].squeeze()
                # NOTE(review): a sigmoid output is essentially always > 0,
                # so this marks (almost) every voxel as occupied; a threshold
                # such as 0.5 may have been intended. Left unchanged.
                pred_occ = sigmoid_func(pred_occ) > 0.0
                target_occ = target_for_occs[-1].squeeze()
                input_occ = inputs.squeeze()
                precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                    input_occ,
                    target_occ,
                    pred_occ,
                    truncation=args.truncation)
            else:
                precision, recall, f1score = 0, 0, 0
            val_precision.append(precision)
            val_recall.append(recall)
            val_f1score.append(f1score)
            val_losses[args.num_hierarchy_levels + 1].append(losses[-1])

            # Only save when this sample produced a prediction; otherwise the
            # occupancy grids are undefined (first sample) or left over from
            # the previous sample.
            if has_prediction:
                data_util.save_dense_predictions(args.output,
                                                 sample['name'],
                                                 input_occ.cpu().numpy(),
                                                 target_occ.cpu().numpy(),
                                                 pred_occ.cpu().numpy(),
                                                 args.truncation,
                                                 flipped=args.flipped)
                print("saved: ", sample['name'])

    print("epoch_loss/total: ", np.mean(val_losses[0]))
    print("epoch_loss/bce: ", np.mean(val_bceloss))
    print("epoch_loss/precision: ", np.mean(val_precision))
    print("epoch_loss/recall: ", np.mean(val_recall))
    print("epoch_loss/f1: ", np.mean(val_f1score))
    print("average_time: ", np.mean(val_time))
    return
def validation_step(self, batch, batch_idx):
    """
    Compute the validation loss for one batch.

    Lightning calls this inside the validation loop; `batch` is the item
    produced by the validation dataloader.

    Returns an OrderedDict with the total loss under 'val_loss'.
    """
    hp = self.hparams
    # Read for parity with the training configuration; the batch-size guard
    # below is currently disabled, so the value is not used.
    batch_size = hp.batch_size
    levels = hp.num_hierarchy_levels
    trunc = hp.truncation
    masked = hp.use_loss_masking
    log_weight = hp.logweight_target_sdf
    missing_geo_weight = hp.weight_missing_geo

    # TODO: batches smaller than batch_size are not skipped yet.
    sdfs = batch['sdf']
    inputs = batch['input']
    known = batch['known']
    hierarchy = batch['hierarchy']

    # Move what the loss needs onto the GPU; both lists are updated in place.
    for level, grid in enumerate(hierarchy):
        hierarchy[level] = grid.cuda()
    if masked:
        known = known.cuda()
    for slot in (0, 1):
        inputs[slot] = inputs[slot].cuda()

    targets = loss_util.compute_targets(
        sdfs.cuda(), hierarchy, levels, trunc, masked, known)
    target_for_sdf, target_for_occs, target_for_hier = targets

    # TODO: update -- weights are derived from the training iteration counter.
    loss_weights = get_loss_weights(
        self._iter_counter,
        hp.num_hierarchy_levels,
        hp.num_iters_per_level,
        hp.weight_sdf_loss)

    output_sdf, output_occs = self(inputs, loss_weights)
    loss, _per_term = loss_util.compute_loss(
        output_sdf, output_occs, target_for_sdf, target_for_occs,
        target_for_hier, loss_weights, trunc, log_weight,
        missing_geo_weight, inputs[0], masked, known)

    result = OrderedDict()
    result['val_loss'] = loss
    return result
def training_step(self, batch, batch_idx):
    """
    Compute the training loss for one batch.

    Lightning calls this inside the training loop; `batch` is the item
    produced by the training dataloader.

    Returns an OrderedDict with the loss under 'loss' and a
    {'train_loss': loss} dict under both 'progress_bar' and 'log'.
    Increments `self._iter_counter` once per call.
    """
    hp = self.hparams
    # Read for parity with the configuration; the batch-size guard below is
    # currently disabled, so the value is not used.
    batch_size = hp.batch_size
    levels = hp.num_hierarchy_levels
    trunc = hp.truncation
    masked = hp.use_loss_masking
    log_weight = hp.logweight_target_sdf
    missing_geo_weight = hp.weight_missing_geo

    # TODO: batches smaller than batch_size are not skipped yet.
    sdfs = batch['sdf']
    inputs = batch['input']
    known = batch['known']
    hierarchy = batch['hierarchy']

    # Move what the loss needs onto the GPU; both lists are updated in place.
    for level, grid in enumerate(hierarchy):
        hierarchy[level] = grid.cuda()
    if masked:
        known = known.cuda()
    for slot in (0, 1):
        inputs[slot] = inputs[slot].cuda()

    targets = loss_util.compute_targets(
        sdfs.cuda(), hierarchy, levels, trunc, masked, known)
    target_for_sdf, target_for_occs, target_for_hier = targets

    # TODO: update -- weights follow the per-module iteration counter.
    loss_weights = get_loss_weights(
        self._iter_counter,
        hp.num_hierarchy_levels,
        hp.num_iters_per_level,
        hp.weight_sdf_loss)

    output_sdf, output_occs = self(inputs, loss_weights)
    loss, _per_term = loss_util.compute_loss(
        output_sdf, output_occs, target_for_sdf, target_for_occs,
        target_for_hier, loss_weights, trunc, log_weight,
        missing_geo_weight, inputs[0], masked, known)

    # The same dict object feeds both the progress bar and the logger.
    progress = {'train_loss': loss}
    result = OrderedDict()
    result['loss'] = loss
    result['progress_bar'] = progress
    result['log'] = progress

    self._iter_counter += 1
    return result
def test(epoch, iter, loss_weights, dataloader, log_file, output_save):
    """Run one validation pass of the dense model and log its metrics.

    Batches with fewer than ``args.batch_size`` samples are skipped. For each
    remaining batch the total loss is written to the "loss/val" scalar; after
    the pass, the mean total loss, mean BCE loss (if any was recorded),
    precision, recall and F1 are written as "epoch_loss/val/*" scalars.

    Args:
        epoch: Epoch index; x-axis of the per-epoch scalars and part of the
            prediction output directory name.
        iter: Iteration count at the start of the pass; ``iter + t`` is the
            x-axis of the per-batch "loss/val" scalar.
        loss_weights: Passed through to the model and to
            ``loss_util.compute_loss_dense``.
        dataloader: Iterable of sample dicts with 'sdf', 'input', 'known',
            'hierarchy' and 'name' entries; must support ``len()``.
        log_file: Unused; kept so existing callers keep working.
        output_save: When truthy, the predictions of the final batch are
            saved to disk (only if that batch is not skipped and the model
            returned an SDF output).

    Returns:
        Tuple ``(val_losses, val_precision, val_recall, val_f1score)``.
        ``val_losses`` is a list of ``num_hierarchy_levels + 2`` lists:
        index 0 holds the total loss per batch, indices 1..num_levels the
        per-level entries of ``losses``, and the last index ``losses[-1]``.
    """
    start_time = time.time()
    num_levels = args.num_hierarchy_levels
    val_losses = [[] for _ in range(num_levels + 2)]
    val_bceloss = []
    val_precision = []
    val_recall = []
    val_f1score = []

    model.eval()
    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size

            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            if args.use_loss_masking:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()

            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, num_levels, args.truncation,
                args.use_loss_masking, known, flipped=args.flipped)

            output_sdf, output_occs = model(inputs, loss_weights)
            bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_dense(
                output_sdf, output_occs, target_for_sdf, target_for_occs,
                target_for_hier, loss_weights, args.truncation,
                args.logweight_target_sdf, args.weight_missing_geo, inputs,
                args.use_loss_masking, known, flipped=args.flipped)
            loss = bce_loss * 10 + sdf_loss

            val_losses[0].append(loss.item())
            if last_loss > 0:
                val_bceloss.append(last_loss)
            WRITER.add_scalar("loss/val", loss.item(), iter + t)

            for h in range(num_levels):
                val_losses[h + 1].append(losses[h])
            val_losses[num_levels + 1].append(losses[-1])

            if output_sdf is not None:
                # Channel 0 of the SDF output is thresholded as an occupancy score.
                pred_occ = output_sdf[:, 0].squeeze() > 0.5
                target_occ = target_for_occs[-1].squeeze()
                input_occ = inputs.squeeze()
                precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                    input_occ, target_occ, pred_occ, truncation=args.truncation)
            else:
                precision, recall, f1score = 0, 0, 0
            val_precision.append(precision)
            val_recall.append(recall)
            val_f1score.append(f1score)

            # Only the final batch is visualised, and only when there is an
            # SDF output (which guarantees the *_occ tensors above exist).
            if output_save and t + 1 == num_batches and output_sdf is not None:
                data_util.save_dense_predictions(
                    os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch), 'val'),
                    sample['name'], input_occ.cpu().numpy(),
                    target_occ.cpu().numpy(), pred_occ.cpu().numpy(),
                    args.truncation, flipped=args.flipped)

    WRITER.add_scalar("epoch_loss/val/total", np.mean(val_losses[0]), epoch)
    if len(val_bceloss) > 0:
        WRITER.add_scalar("epoch_loss/val/bce", np.mean(val_bceloss), epoch)
    WRITER.add_scalar("epoch_loss/val/precision", np.mean(val_precision), epoch)
    WRITER.add_scalar("epoch_loss/val/recall", np.mean(val_recall), epoch)
    WRITER.add_scalar("epoch_loss/val/f1", np.mean(val_f1score), epoch)
    print(" test epoch {} used {} seconds".format(epoch, time.time() - start_time))
    return val_losses, val_precision, val_recall, val_f1score
def train(epoch, iter, dataloader, log_file, output_save):
    """Train the dense model for one pass over `dataloader`.

    Batches with fewer than ``args.batch_size`` samples are skipped. Each
    remaining batch performs one optimizer step on
    ``bce_loss * 10 + sdf_loss`` with the gradient norm clipped to 0.1, logs
    the loss to "loss/train", prints progress every 20 iterations and writes
    a checkpoint every 2000 iterations.

    Args:
        epoch: Epoch index; x-axis of the per-epoch scalars and part of the
            checkpoint file name.
        iter: Global iteration count on entry; advanced once per trained
            batch and returned.
        dataloader: Iterable of sample dicts with 'sdf', 'input', 'known'
            and 'hierarchy' entries.
        log_file: Passed to ``print_log``.
        output_save: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(train_losses, train_precision, train_recall, train_f1score,
        iter, loss_weights)``. ``train_losses`` is a list of
        ``num_hierarchy_levels + 2`` lists: index 0 holds the total loss per
        batch, indices 1..num_levels the per-level entries of ``losses``,
        and the last index ``losses[-1]``. ``loss_weights`` is the value
        computed for the last batch seen.

    Side effects:
        Steps ``optimizer``; steps ``scheduler`` or decrements
        ``args.scheduler_step_size``; writes TensorBoard scalars and
        checkpoint files under ``args.save``.
    """
    start_time = time.time()
    num_levels = args.num_hierarchy_levels
    train_losses = [[] for _ in range(num_levels + 2)]
    train_bceloss = []
    train_precision = []
    train_recall = []
    train_f1score = []

    model.train()
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(
            iter, num_levels, args.num_iters_per_level, args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('[iter %d/epoch %d] loss_weights' % (iter, epoch), loss_weights)

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training

        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
            hierarchy[h].unsqueeze_(1)
        if args.use_loss_masking:
            known = known.cuda()
        inputs = inputs.cuda()
        sdfs = sdfs.cuda()

        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs, hierarchy, num_levels, args.truncation,
            args.use_loss_masking, known, flipped=args.flipped)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_dense(
            output_sdf, output_occs, target_for_sdf, target_for_occs,
            target_for_hier, loss_weights, args.truncation,
            args.logweight_target_sdf, args.weight_missing_geo, inputs,
            args.use_loss_masking, known, flipped=args.flipped)
        loss = bce_loss * 10 + sdf_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        train_losses[0].append(loss.item())
        if last_loss > 0:
            train_bceloss.append(last_loss)
        for h in range(num_levels):
            train_losses[h + 1].append(losses[h])
        train_losses[num_levels + 1].append(losses[-1])

        if output_sdf is not None:
            # detach() returns a new tensor, so it has to be part of the
            # expression; channel 0 is thresholded as an occupancy score.
            pred_occ = output_sdf.detach()[:, 0].squeeze() > 0.1
            target_occ = target_for_occs[-1].squeeze()
            input_occ = inputs.detach().squeeze()
            precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                input_occ, target_occ, pred_occ, truncation=args.truncation)
        else:
            precision, recall, f1score = 0, 0, 0
        train_precision.append(precision)
        train_recall.append(recall)
        train_f1score.append(f1score)

        iter += 1
        WRITER.add_scalar("loss/train", loss.item(), iter)
        if iter % 20 == 0:
            took = time.time() - start_time
            print_log(log_file, epoch, iter, train_losses, None, None,
                      train_precision, None, None, None, None, took)
        if iter % 2000 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                os.path.join(args.save, 'model-iter%s-epoch%s.pth' % (iter, epoch)))

    # The scheduler only starts stepping once the countdown stored in
    # args.scheduler_step_size has reached zero.
    # NOTE(review): treated as a once-per-call (per-epoch) step, next to the
    # epoch-level scalars below -- confirm against the intended schedule.
    if args.scheduler_step_size == 0:
        scheduler.step()
    else:
        args.scheduler_step_size -= 1

    WRITER.add_scalar("epoch_loss/train/total", np.mean(train_losses[0]), epoch)
    if len(train_bceloss) > 0:
        WRITER.add_scalar("epoch_loss/train/bce", np.mean(train_bceloss), epoch)
    WRITER.add_scalar("epoch_loss/train/precision", np.mean(train_precision), epoch)
    WRITER.add_scalar("epoch_loss/train/recall", np.mean(train_recall), epoch)
    WRITER.add_scalar("epoch_loss/train/f1", np.mean(train_f1score), epoch)
    print(" train epoch {} used {} seconds".format(epoch, time.time() - start_time))
    return train_losses, train_precision, train_recall, train_f1score, iter, loss_weights
def test(loss_weights, dataloader, output_pred, output_vis, num_to_vis):
    """Run the model over `dataloader` and export its predictions.

    For every sample the model output is cropped to the original (unpadded)
    dimensions, the SDF prediction is written to
    ``<output_pred>/<name>.pred`` and, for the first `num_to_vis` successfully
    processed samples, a visualisation is written via
    ``data_util.save_predictions``. Samples whose forward pass or occupancy
    post-processing raises are reported, recorded in `missing` and skipped.

    Args:
        loss_weights: Passed through to the model.
        dataloader: Iterable of sample dicts with 'input', 'sdf', 'known',
            'hierarchy', 'name', 'orig_dims' and 'world2grid' entries.
        output_pred: Directory that receives the ``.pred`` files.
        output_vis: Output location passed to ``data_util.save_predictions``.
        num_to_vis: Maximum number of samples to visualise.

    Returns:
        None. Progress is written to stdout and the names of failed samples
        are printed at the end.
    """
    model.eval()
    missing = []

    num_proc = 0
    num_vis = 0
    with torch.no_grad():
        for sample in dataloader:
            inputs = sample['input']
            sdfs = sample['sdf']
            known = sample['known']
            hierarchy = sample['hierarchy']
            input_dim = np.array(sdfs.shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d) ' %
                             (num_proc, args.max_to_process, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()

            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, args.num_hierarchy_levels, args.truncation,
                args.use_loss_masking, known)
            hierarchy_factor = pow(2, args.num_hierarchy_levels - 1)
            model.update_sizes(input_dim, input_dim // hierarchy_factor)

            try:
                if not args.cpu:
                    inputs[1] = inputs[1].cuda()
                    target_for_sdf = target_for_sdf.cuda()
                    for h in range(len(target_for_occs)):
                        target_for_occs[h] = target_for_occs[h].cuda()
                        target_for_hier[h] = target_for_hier[h].cuda()
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                # Best effort: skip this sample but keep going. Only
                # Exception is caught so Ctrl-C / SystemExit still stop
                # the run.
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue

            # remove padding: drop predictions whose first three location
            # columns fall outside the original dimensions
            dims = sample['orig_dims'][0]
            mask = ((output_sdf[0][:, 0] < dims[0])
                    & (output_sdf[0][:, 1] < dims[1])
                    & (output_sdf[0][:, 2] < dims[2]))
            output_sdf[0] = output_sdf[0][mask]
            output_sdf[1] = output_sdf[1][mask]
            for h in range(len(output_occs)):
                dims = target_for_occs[h].shape[2:]
                mask = ((output_occs[h][0][:, 0] < dims[0])
                        & (output_occs[h][0][:, 1] < dims[1])
                        & (output_occs[h][0][:, 2] < dims[2]))
                output_occs[h][0] = output_occs[h][0][mask]
                output_occs[h][1] = output_occs[h][1][mask]

            # save prediction files
            data_util.save_predictions_to_file(
                output_sdf[0][:, :3].cpu().numpy(),
                output_sdf[1].cpu().numpy(),
                os.path.join(output_pred, sample['name'][0] + '.pred'))

            try:
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    pred_occs[h] = [None]
                    if len(output_occs[h][0]) == 0:
                        continue
                    # Keep the locations (without the last column) whose
                    # sigmoid score exceeds 0.5.
                    locs = output_occs[h][0][:, :-1]
                    vals = torch.nn.Sigmoid()(output_occs[h][1][:, 0].detach()) > 0.5
                    pred_occs[h][0] = locs[vals.view(-1)]
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue

            num_proc += 1
            if num_vis < num_to_vis:
                num = min(num_to_vis - num_vis, 1)
                vis_pred_sdf = [None] * num
                if len(output_sdf[0]) > 0:
                    for b in range(num):
                        # The last location column is compared against the
                        # index b to select that entry's predictions.
                        mask = output_sdf[0][:, -1] == b
                        # NOTE(review): len(mask) equals len(output_sdf[0]),
                        # which is already > 0 here, so this check always
                        # passes; mask.any() may have been intended.
                        if len(mask) > 0:
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy(),
                            ]
                inputs = [inputs[0].numpy(), inputs[1].cpu().numpy()]
                data_util.save_predictions(
                    output_vis, np.arange(num), inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy() for x in target_for_occs],
                    vis_pred_sdf, pred_occs, sample['world2grid'],
                    args.vis_dfs, args.truncation)
                num_vis += 1
            gc.collect()

    sys.stdout.write('\n')
    print('missing', missing)
def train(epoch, iter, dataloader, output_save):
    """Train the occupancy-only model for one pass over `dataloader`.

    Batches with fewer than ``args.batch_size`` samples are skipped. Each
    remaining batch performs one optimizer step on ``bce_loss * 2`` (the SDF
    term returned by the loss helper is not used) with the gradient norm
    clipped to 0.1, and logs the loss to "loss/train" against the global
    ``STEP_COUNTER``.

    Args:
        epoch: Epoch index; x-axis of the per-epoch scalars.
        iter: Iteration count on entry; used to compute the loss weights and
            advanced locally once per trained batch.
        dataloader: Iterable of sample dicts with 'sdf', 'input', 'known'
            and 'hierarchy' entries.
        output_save: Unused; kept so existing callers keep working.

    Returns:
        The `loss_weights` computed for the last batch seen.

    Side effects:
        Increments the module-level ``STEP_COUNTER``; steps ``optimizer``;
        steps ``scheduler`` or decrements ``args.scheduler_step_size``;
        writes TensorBoard scalars.
    """
    global STEP_COUNTER
    start_time = time.time()
    num_levels = args.num_hierarchy_levels
    train_losses = [[] for _ in range(num_levels + 2)]
    train_bceloss = []
    train_precision = []
    train_recall = []
    train_f1score = []

    model.train()
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(
            iter, num_levels, args.num_iters_per_level, args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('-------------- epoch %d -------------]' % (epoch))

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training

        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
            hierarchy[h].unsqueeze_(1)
        if args.use_loss_masking:
            known = known.cuda()
        inputs = inputs.cuda()
        sdfs = sdfs.cuda()

        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs, hierarchy, num_levels, args.truncation,
            args.use_loss_masking, known, flipped=args.flipped)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        bce_loss, _sdf_loss, losses, last_loss = loss_util.compute_loss_nosdf(
            output_sdf, output_occs, target_for_sdf, target_for_occs,
            target_for_hier, loss_weights, args.truncation,
            args.logweight_target_sdf, args.weight_missing_geo, inputs,
            args.use_loss_masking, known, flipped=args.flipped)
        loss = bce_loss * 2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        train_losses[0].append(loss.item())
        if last_loss > 0:
            train_bceloss.append(last_loss)
        for h in range(num_levels):
            train_losses[h + 1].append(losses[h])
        train_losses[num_levels + 1].append(losses[-1])

        # Accuracy is only measured once the second-to-last loss weight has
        # reached 1; before that the metrics are recorded as 0.
        if loss_weights[-2] >= 1:
            # detach() returns a new tensor, so it has to be part of the
            # expression; channel 0 of the last occupancy output is
            # thresholded at 0.
            pred_occ = output_occs[-1].detach()[:, 0].squeeze() > 0
            target_occ = target_for_occs[-1].squeeze()
            input_occ = inputs.detach().squeeze()
            precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                input_occ, target_occ, pred_occ, truncation=args.truncation)
        else:
            precision, recall, f1score = 0, 0, 0
        train_precision.append(precision)
        train_recall.append(recall)
        train_f1score.append(f1score)

        STEP_COUNTER += 1
        iter += 1
        WRITER.add_scalar("loss/train", loss.item(), STEP_COUNTER)

    # The scheduler only starts stepping once the countdown stored in
    # args.scheduler_step_size has reached zero.
    # NOTE(review): treated as a once-per-call (per-epoch) step, next to the
    # epoch-level scalars below -- confirm against the intended schedule.
    if args.scheduler_step_size == 0:
        scheduler.step()
    else:
        args.scheduler_step_size -= 1

    WRITER.add_scalar("epoch_loss/train/total", np.mean(train_losses[0]), epoch)
    if len(train_bceloss) > 0:
        WRITER.add_scalar("epoch_loss/train/bce", np.mean(train_bceloss), epoch)
    WRITER.add_scalar("epoch_loss/train/precision", np.mean(train_precision), epoch)
    WRITER.add_scalar("epoch_loss/train/recall", np.mean(train_recall), epoch)
    WRITER.add_scalar("epoch_loss/train/f1", np.mean(train_f1score), epoch)
    print(" train epoch {} used {} seconds".format(epoch, time.time() - start_time))
    return loss_weights