def test(loss_weights, dataloader, output_vis, num_to_vis):
    """Run the sparse model on whole scenes and save visualizations.

    For every sample the model is resized to the scene dimensions, run on the
    sparse input, and both the prediction and the input are cropped back to
    the un-padded extent given by ``sample['orig_dims']``. The result is
    written with ``data_util.save_predictions``. Stops once ``num_to_vis``
    scenes have been saved.

    Args:
        loss_weights: per-level weights forwarded to ``model(...)``.
        dataloader: iterable of sample dicts; reads 'input', 'sdf', 'name',
            'orig_dims' and 'world2grid'.
        output_vis: output location passed to ``data_util.save_predictions``.
        num_to_vis: number of scenes to save before stopping.
    """
    model.eval()

    num_vis = 0
    with torch.no_grad():
        gc.collect()
        for sample in dataloader:
            inputs = sample['input']
            input_dim = np.array(sample['sdf'].shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d) ' %
                             (num_vis, num_to_vis, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()
            hierarchy_factor = pow(2, args.num_hierarchy_levels - 1)
            model.update_sizes(input_dim, input_dim // hierarchy_factor)
            try:
                if not args.cpu:
                    inputs[1] = inputs[1].cuda()
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                # Best-effort: skip scenes the model fails on. Catching
                # Exception (not a bare except) lets Ctrl-C still abort.
                print('exception at %s' % sample['name'])
                gc.collect()
                continue

            # remove padding: keep only voxels inside the original scene extent
            dims = sample['orig_dims'][0]
            pred_locs = output_sdf[0]
            mask = ((pred_locs[:, 0] < dims[0]) & (pred_locs[:, 1] < dims[1])
                    & (pred_locs[:, 2] < dims[2]))
            output_sdf[0] = output_sdf[0][mask]
            output_sdf[1] = output_sdf[1][mask]
            in_locs = inputs[0]
            mask = ((in_locs[:, 0] < dims[0]) & (in_locs[:, 1] < dims[1])
                    & (in_locs[:, 2] < dims[2]))
            inputs[0] = inputs[0][mask]
            inputs[1] = inputs[1][mask]

            vis_pred_sdf = [None]
            if len(output_sdf[0]) > 0:
                vis_pred_sdf[0] = [
                    output_sdf[0].cpu().numpy(),
                    output_sdf[1].squeeze().cpu().numpy()
                ]
            inputs = [inputs[0].numpy(), inputs[1].cpu().numpy()]
            target_sdf_save = sample['sdf'].numpy()
            data_util.save_predictions(output_vis, sample['name'], inputs,
                                       target_sdf_save, None, vis_pred_sdf,
                                       None, sample['world2grid'],
                                       args.truncation)
            num_vis += 1
            if num_vis >= num_to_vis:
                break
    sys.stdout.write('\n')
def test(dataloader, output_vis, num_to_vis):
    """Run the model on each scene chunk-by-chunk and stitch the predictions.

    Each scene is cropped to ``args.max_input_height`` along the first spatial
    axis and tiled over the two remaining spatial axes with step
    ``args.stride``. Chunks where no input value has magnitude below
    ``args.truncation`` are skipped. Per-chunk SDF predictions (and color
    predictions, if enabled) are summed into scene-sized buffers. The sums are
    then divided by the number of chunks that contributed to each voxel. The
    first ``num_to_vis`` scenes are written with
    ``data_util.save_predictions``.

    Side effect: overwrites the global ``args.max_input_height`` with the
    chunk height. If ``args.stride == 0``, it also sets ``args.stride`` to the
    chunk width.

    Args:
        dataloader: iterable of sample dicts; reads 'input', 'sdf', 'mask',
            'colors', 'name' and 'world2grid'.
        output_vis: output location forwarded to ``data_util.save_predictions``.
        num_to_vis: number of scenes to save.

    NOTE(review): chunk buffers are allocated with ``.cuda()`` unconditionally,
    so this path appears to require a GPU.
    """
    model.eval()

    chunk_dim = args.input_dim
    args.max_input_height = chunk_dim[0]
    if args.stride == 0:
        args.stride = chunk_dim[1]
    # Voxels trimmed from chunk borders that face a neighbouring chunk.
    # Only used when chunks overlap (args.stride < chunk_dim[1]).
    pad = 2

    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)
    print('starting testing...')
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            sdfs = sample['sdf']
            mask = sample['mask']
            colors = sample['colors']

            # crop the scene height (first spatial axis) to the chunk height
            max_input_dim = np.array(sdfs.shape[2:])
            if args.max_input_height > 0 and max_input_dim[UP_AXIS] > args.max_input_height:
                max_input_dim[UP_AXIS] = args.max_input_height
                inputs = inputs[:, :, :args.max_input_height]
                if mask is not None:
                    mask = mask[:, :, :args.max_input_height]
                if sdfs is not None:
                    sdfs = sdfs[:, :, :args.max_input_height]
                if colors is not None:
                    # colors are indexed as (batch, D, H, W, 3) further below
                    colors = colors[:, :args.max_input_height]
            sys.stdout.write(
                '\r[ %d | %d ] %s (%d, %d, %d) ' %
                (num_proc, args.max_to_process, sample['name'],
                 max_input_dim[0], max_input_dim[1], max_input_dim[2]))

            # Scene-sized accumulators: summed predictions, plus per-voxel
            # contribution counts (output_norms) used for averaging below.
            # NOTE(review): colors/sdfs are used unconditionally here even
            # though the crop above guards them against None.
            output_colors = torch.zeros(colors.shape)
            output_sdfs = torch.zeros(sdfs.shape)
            output_norms = torch.zeros(sdfs.shape)
            output_occs = torch.zeros(sdfs.shape, dtype=torch.uint8)

            # chunk up the scene
            chunk_input = torch.ones(1, args.input_nf, chunk_dim[0], chunk_dim[1], chunk_dim[2]).cuda()
            chunk_mask = torch.ones(1, 1, chunk_dim[0], chunk_dim[1], chunk_dim[2]).cuda()
            chunk_target_sdf = torch.ones(1, 1, chunk_dim[0], chunk_dim[1], chunk_dim[2]).cuda()
            chunk_target_colors = torch.zeros(1, chunk_dim[0], chunk_dim[1], chunk_dim[2], 3, dtype=torch.uint8).cuda()
            for y in range(0, max_input_dim[1], args.stride):
                for x in range(0, max_input_dim[2], args.stride):
                    # skip chunks with no input value inside the truncation band
                    chunk_input_mask = torch.abs(
                        inputs[:, :, :chunk_dim[0], y:y + chunk_dim[1], x:x + chunk_dim[2]]) < args.truncation
                    if torch.sum(chunk_input_mask).item() == 0:
                        continue
                    sys.stdout.write(
                        '\r[ %d | %d ] %s (%d, %d, %d) (%d, %d) ' %
                        (num_proc, args.max_to_process, sample['name'],
                         max_input_dim[0], max_input_dim[1], max_input_dim[2], y, x))

                    # part of the scene that falls inside this chunk (smaller at scene edges)
                    fill_dim = [
                        min(sdfs.shape[2], chunk_dim[0]),
                        min(sdfs.shape[3] - y, chunk_dim[1]),
                        min(sdfs.shape[4] - x, chunk_dim[2])
                    ]
                    # Reset the reusable chunk buffers. Everything outside
                    # fill_dim keeps these fill values.
                    chunk_target_sdf.fill_(float('inf'))
                    chunk_target_colors.fill_(0)
                    chunk_input[:, 0].fill_(-args.truncation)
                    chunk_input[:, 1:].fill_(0)
                    chunk_mask.fill_(0)
                    chunk_input[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]] = inputs[:, :, :chunk_dim[0], y:y + chunk_dim[1], x:x + chunk_dim[2]]
                    chunk_mask[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]] = mask[:, :, :chunk_dim[0], y:y + chunk_dim[1], x:x + chunk_dim[2]]
                    chunk_target_sdf[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]] = sdfs[:, :, :chunk_dim[0], y:y + chunk_dim[1], x:x + chunk_dim[2]]
                    chunk_target_colors[:, :fill_dim[0], :fill_dim[1], :fill_dim[2], :] = colors[:, :chunk_dim[0], y:y + chunk_dim[1], x:x + chunk_dim[2]]
                    # NOTE(review): target_for_sdf / target_for_colors are not used below
                    target_for_sdf, target_for_colors = loss_util.compute_targets(
                        chunk_target_sdf.cuda(), args.truncation, False, None,
                        chunk_target_colors.cuda())

                    output_occ = None
                    output_occ, output_sdf, output_color = model(
                        chunk_input, chunk_mask, pred_sdf=[True, True],
                        pred_color=args.weight_color_loss > 0)

                    # Chunk-local locations of kept voxels: |sdf| < truncation,
                    # and sigmoid(occupancy) > 0.5 when occupancy is predicted.
                    if output_occ is not None:
                        occ = torch.nn.Sigmoid()(output_occ.detach()) > 0.5
                        locs = torch.nonzero((torch.abs(output_sdf.detach()[:, 0]) < args.truncation) & occ[:, 0]).cpu()
                    else:
                        locs = torch.nonzero(torch.abs(output_sdf[:, 0]) < args.truncation).cpu()
                    # reorder columns from (batch, d, h, w) to (d, h, w, batch)
                    locs = torch.cat([locs[:, 1:], locs[:, :1]], 1)
                    output_sdf = [
                        locs,
                        output_sdf[locs[:, -1], :, locs[:, 0], locs[:, 1], locs[:, 2]].detach().cpu()
                    ]
                    if args.weight_color_loss == 0:
                        output_color = None
                    else:
                        output_color = [
                            locs,
                            output_color[locs[:, -1], :, locs[:, 0], locs[:, 1], locs[:, 2]]
                        ]

                    # shift chunk-local coordinates into scene coordinates
                    output_locs = output_sdf[0] + torch.LongTensor([0, y, x, 0])
                    if args.stride < chunk_dim[1]:
                        # Overlapping chunks: keep only the interior, trimming
                        # `pad` voxels on each side that borders another chunk.
                        min_dim = [0, y, x]
                        max_dim = [0 + chunk_dim[0], y + chunk_dim[1], x + chunk_dim[2]]
                        if y > 0:
                            min_dim[1] += pad
                        if x > 0:
                            min_dim[2] += pad
                        if y + chunk_dim[1] < max_input_dim[1]:
                            max_dim[1] -= pad
                        if x + chunk_dim[2] < max_input_dim[2]:
                            max_dim[2] -= pad
                        for k in range(3):
                            max_dim[k] = min(max_dim[k], sdfs.shape[k + 2])
                        outmask = (output_locs[:, 0] >= min_dim[0]) & (
                            output_locs[:, 1] >= min_dim[1]) & (
                            output_locs[:, 2] >= min_dim[2]) & (
                            output_locs[:, 0] < max_dim[0]) & (
                            output_locs[:, 1] < max_dim[1]) & (
                            output_locs[:, 2] < max_dim[2])
                    else:
                        # non-overlapping chunks: only drop locations outside the scene
                        outmask = (
                            output_locs[:, 0] < output_sdfs.shape[2]) & (
                            output_locs[:, 1] < output_sdfs.shape[3]) & (
                            output_locs[:, 2] < output_sdfs.shape[4])
                    output_locs = output_locs[outmask]
                    output_sdf = [output_sdf[0][outmask], output_sdf[1][outmask]]
                    if output_color is not None:
                        output_color = [output_color[0][outmask], output_color[1][outmask]]
                        # (c + 1) * 0.5 presumably maps model colors from [-1, 1] to [0, 1]
                        output_color = (output_color[1] + 1) * 0.5

                        output_colors[0, output_locs[:, 0], output_locs[:, 1], output_locs[:, 2], :] += output_color.detach().cpu()
                    if output_occ is not None:
                        # Occupancy is assigned, not averaged: later chunks
                        # overwrite earlier ones where they overlap.
                        output_occs[:, :, :chunk_dim[0], y:y + chunk_dim[1], x:x + chunk_dim[2]] = occ[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]]

                    output_sdfs[0, 0, output_locs[:, 0], output_locs[:, 1], output_locs[:, 2]] += output_sdf[1][:, 0].detach().cpu()
                    output_norms[0, 0, output_locs[:, 0], output_locs[:, 1], output_locs[:, 2]] += 1

            # Normalize: average the summed predictions over contributing
            # chunks. Voxels no chunk predicted get -inf, which the clamp
            # below turns into -truncation.
            mask = output_norms > 0
            output_norms = output_norms[mask]
            output_sdfs[mask] = output_sdfs[mask] / output_norms
            output_sdfs[~mask] = -float('inf')
            mask = mask.view(1, mask.shape[2], mask.shape[3], mask.shape[4])
            output_colors[mask, :] = output_colors[mask, :] / output_norms.view(-1, 1)
            output_colors = torch.clamp(output_colors * 255, 0, 255)

            sdfs = torch.clamp(sdfs, -args.truncation, args.truncation)
            output_sdfs = torch.clamp(output_sdfs, -args.truncation, args.truncation)

            if num_vis < num_to_vis:
                inputs = inputs.cpu().numpy()
                locs = torch.nonzero(torch.abs(output_sdfs[0, 0]) < args.truncation)
                vis_pred_sdf = [None]
                vis_pred_color = [None]
                sdf_vals = output_sdfs[0, 0, locs[:, 0], locs[:, 1], locs[:, 2]].view(-1)
                vis_pred_sdf[0] = [locs.cpu().numpy(), sdf_vals.cpu().numpy()]
                if args.weight_color_loss > 0:
                    vals = output_colors[0, locs[:, 0], locs[:, 1], locs[:, 2]]
                    vis_pred_color[0] = vals.cpu().numpy()
                # NOTE(review): output_occs is always a tensor here, so this
                # branch is always taken and pred_occ is always defined.
                if output_occs is not None:
                    pred_occ = output_occs.cpu().numpy().astype(np.float32)
                data_util.save_predictions(output_vis, np.arange(1), sample['name'], inputs, sdfs.cpu().numpy(), colors.cpu().numpy(), None, None, vis_pred_sdf, vis_pred_color, None, None, sample['world2grid'], args.truncation, args.color_space, pred_occ=pred_occ)
                num_vis += 1
            num_proc += 1
            gc.collect()
    sys.stdout.write('\n')
def test(epoch, iter, loss_weights, dataloader, log_file, output_save):
    """Evaluate the hierarchical model on a validation dataloader.

    Batches smaller than ``args.batch_size`` are skipped. Extra metrics are
    computed on every 20th batch and on the visualization batch:
    - per-level occupancy IoU;
    - when ``loss_weights[-1] > 0``, the ``compute_l1_*_sparse_dense`` metrics
      (every 20th batch only).
    If ``output_save`` is true, the second-to-last batch is written under
    ``<args.save>/iter<iter>-epoch<epoch>/val`` via
    ``data_util.save_predictions``.

    Args:
        epoch: epoch number, used only for the visualization directory name.
        iter: iteration number, used only for the visualization directory name.
        loss_weights: per-level loss weights passed to the model and the loss.
        dataloader: validation dataloader yielding sample dicts.
        log_file: not used here; kept for interface compatibility.
        output_save: whether to save visualizations for one batch.

    Returns:
        ``(val_losses, val_l1preds, val_l1tgts, val_ious)``, where:
        - ``val_losses[0]`` holds the total losses;
        - ``val_losses[1..L]`` hold the per-level losses;
        - ``val_losses[L + 1]`` holds ``losses[-1]``.
    """
    num_levels = args.num_hierarchy_levels
    val_losses = [[] for _ in range(num_levels + 2)]
    val_l1preds = []
    val_l1tgts = []
    val_ious = [[] for _ in range(num_levels)]
    model.eval()

    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
            if args.use_loss_masking:
                known = known.cuda()
            inputs[0] = inputs[0].cuda()
            inputs[1] = inputs[1].cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs.cuda(), hierarchy, num_levels, args.truncation,
                args.use_loss_masking, known)

            output_sdf, output_occs = model(inputs, loss_weights)
            loss, losses = loss_util.compute_loss(
                output_sdf, output_occs, target_for_sdf, target_for_occs,
                target_for_hier, loss_weights, args.truncation,
                args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
                args.use_loss_masking, known)

            output_visual = output_save and t + 2 == num_batches
            compute_pred_occs = (t % 20 == 0) or output_visual
            if compute_pred_occs:
                # pred_occs[h][b]: coordinates (batch column dropped) of the
                # level-h voxels of batch item b with sigmoid(occ) > 0.5.
                pred_occs = [None] * num_levels
                for h in range(num_levels):
                    pred_occs[h] = [None] * args.batch_size
                    occ_locs = output_occs[h][0]
                    if len(occ_locs) == 0:
                        continue
                    # threshold once per level instead of once per batch item
                    occupied = torch.sigmoid(output_occs[h][1][:, 0].detach()) > 0.5
                    for b in range(args.batch_size):
                        batchmask = occ_locs[:, -1] == b
                        locs = occ_locs[batchmask][:, :-1]
                        pred_occs[h][b] = locs[occupied[batchmask].view(-1)]

            val_losses[0].append(loss.item())
            for h in range(num_levels):
                val_losses[h + 1].append(losses[h])
                if compute_pred_occs:
                    target = target_for_occs[h].byte()
                    iou = loss_util.compute_iou_sparse_dense(
                        pred_occs[h], target, args.use_loss_masking)
                    val_ious[h].append(iou)
            val_losses[num_levels + 1].append(losses[-1])

            if len(output_sdf[0]) > 0:
                output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
            if loss_weights[-1] > 0 and t % 20 == 0:
                val_l1preds.append(
                    loss_util.compute_l1_predsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf, None,
                        False, args.use_loss_masking, known).item())
                val_l1tgts.append(
                    loss_util.compute_l1_tgtsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf,
                        args.truncation, args.use_loss_masking, known))

            if output_visual:
                vis_pred_sdf = [None] * args.batch_size
                if len(output_sdf[0]) > 0:
                    for b in range(args.batch_size):
                        mask = output_sdf[0][:, -1] == b
                        # len(mask) is always non-zero here; test for any match
                        # so items without predicted voxels stay None.
                        if mask.any():
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
                for h in range(num_levels):
                    for b in range(args.batch_size):
                        if pred_occs[h][b] is not None:
                            pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
                data_util.save_predictions(
                    os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch), 'val'),
                    sample['name'], inputs, target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy() for x in target_for_occs], vis_pred_sdf,
                    pred_occs, sample['world2grid'], args.vis_dfs,
                    args.truncation)

    return val_losses, val_l1preds, val_l1tgts, val_ious
def train(epoch, iter, dataloader, log_file, output_save):
    """Train the hierarchical model for one pass over ``dataloader``.

    Per batch:
    - loss weights come from ``get_loss_weights`` for the current iteration;
    - batches smaller than ``args.batch_size`` are skipped;
    - the model is run, the loss is back-propagated and ``optimizer.step()``
      is called.
    The scheduler steps once up front when ``args.scheduler_step_size == 0``,
    and otherwise every ``args.scheduler_step_size`` iterations. Progress is
    logged every 20 iterations. A checkpoint is saved every 2000 iterations.
    If ``output_save`` is true, the second-to-last batch is visualized under
    ``<args.save>/iter<iter>-epoch<epoch>/train``.

    Args:
        epoch: current epoch number.
        iter: global iteration counter at the start of this pass.
        dataloader: training dataloader yielding sample dicts.
        log_file: forwarded to ``print_log``.
        output_save: whether to save visualizations for one batch.

    Returns:
        ``(train_losses, train_l1preds, train_l1tgts, train_ious, iter,
        loss_weights)``. ``iter`` is the updated counter and ``loss_weights``
        holds the last weights used. For an empty dataloader, ``loss_weights``
        holds the weights for the starting ``iter``.
    """
    num_levels = args.num_hierarchy_levels
    train_losses = [[] for _ in range(num_levels + 2)]
    train_l1preds = []
    train_l1tgts = []
    train_ious = [[] for _ in range(num_levels)]
    model.train()
    start = time.time()

    if args.scheduler_step_size == 0:
        scheduler.step()

    num_batches = len(dataloader)
    # Bound up front so the final return does not raise UnboundLocalError
    # when the dataloader yields no batches.
    loss_weights = get_loss_weights(iter, num_levels, args.num_iters_per_level,
                                    args.weight_sdf_loss)
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(iter, num_levels,
                                        args.num_iters_per_level,
                                        args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('[iter %d/epoch %d] loss_weights' % (iter, epoch), loss_weights)

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if args.use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, num_levels, args.truncation,
            args.use_loss_masking, known)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(
            output_sdf, output_occs, target_for_sdf, target_for_occs,
            target_for_hier, loss_weights, args.truncation,
            args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
            args.use_loss_masking, known)
        loss.backward()
        optimizer.step()

        output_visual = output_save and t + 2 == num_batches
        compute_pred_occs = (iter % 20 == 0) or output_visual
        if compute_pred_occs:
            # pred_occs[h][b]: coordinates (batch column dropped) of the
            # level-h voxels of batch item b with sigmoid(occ) > 0.5.
            pred_occs = [None] * num_levels
            for h in range(num_levels):
                pred_occs[h] = [None] * args.batch_size
                occ_locs = output_occs[h][0]
                if len(occ_locs) == 0:
                    continue
                occupied = torch.sigmoid(output_occs[h][1][:, 0].detach()) > 0.5
                for b in range(args.batch_size):
                    batchmask = occ_locs[:, -1] == b
                    locs = occ_locs[batchmask][:, :-1]
                    pred_occs[h][b] = locs[occupied[batchmask].view(-1)]

        train_losses[0].append(loss.item())
        for h in range(num_levels):
            train_losses[h + 1].append(losses[h])
            if compute_pred_occs:
                target = target_for_occs[h].byte()
                iou = loss_util.compute_iou_sparse_dense(
                    pred_occs[h], target, args.use_loss_masking)
                train_ious[h].append(iou)
        train_losses[num_levels + 1].append(losses[-1])

        if len(output_sdf[0]) > 0:
            output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
        if loss_weights[-1] > 0 and iter % 20 == 0:
            train_l1preds.append(
                loss_util.compute_l1_predsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf, None, False,
                    args.use_loss_masking, known).item())
            train_l1tgts.append(
                loss_util.compute_l1_tgtsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf,
                    args.truncation, args.use_loss_masking, known))

        iter += 1
        if args.scheduler_step_size > 0 and iter % args.scheduler_step_size == 0:
            scheduler.step()
        if iter % 20 == 0:
            took = time.time() - start
            print_log(log_file, epoch, iter, train_losses, train_l1preds,
                      train_l1tgts, train_ious, None, None, None, None, took)
        if iter % 2000 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                os.path.join(args.save, 'model-iter%s-epoch%s.pth' % (iter, epoch)))

        if output_visual:
            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    # len(mask) is always non-zero here; test for any match
                    # so items without predicted voxels stay None.
                    if mask.any():
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]
            inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
            for h in range(num_levels):
                for b in range(args.batch_size):
                    if pred_occs[h][b] is not None:
                        pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
            data_util.save_predictions(
                os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch), 'train'),
                sample['name'], inputs, target_for_sdf.cpu().numpy(),
                [x.cpu().numpy() for x in target_for_occs], vis_pred_sdf,
                pred_occs, sample['world2grid'].numpy(), args.vis_dfs,
                args.truncation)
    return train_losses, train_l1preds, train_l1tgts, train_ious, iter, loss_weights
def test(loss_weights, dataloader, output_vis, num_to_vis):
    """Run the model on test samples, time each forward pass, and save visualizations.

    For each batch of exactly ``args.batch_size`` samples:
    - targets are built with ``loss_util.compute_targets`` (honouring
      ``args.flipped``);
    - the model is run on the GPU and the wall time of the forward pass is
      printed;
    - per-level occupancy predictions (sigmoid > 0.5) and per-item SDF
      predictions are written with ``data_util.save_predictions``.
    Stops after ``num_to_vis`` batches have been saved.

    Args:
        loss_weights: per-level weights forwarded to ``model(...)``.
        dataloader: iterable of sample dicts; reads 'input', 'sdf', 'known',
            'hierarchy', 'name' and 'world2grid'.
        output_vis: output location passed to ``data_util.save_predictions``.
        num_to_vis: number of batches to save before stopping.
    """
    model.eval()

    num_vis = 0
    num_batches = len(dataloader)
    print("num_batches: ", num_batches)
    with torch.no_grad():
        for sample in dataloader:
            print("input_dim: ", np.array(sample['sdf'].shape))
            input_dim = np.array(sample['sdf'].shape[2:])
            print(input_dim)
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d) ' %
                             (num_vis, num_to_vis, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()

            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size for training
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            if args.use_loss_masking:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, args.num_hierarchy_levels, args.truncation,
                args.use_loss_masking, known, flipped=args.flipped)

            start_time = time.time()
            try:
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                # Best-effort: skip samples the model fails on; Ctrl-C still aborts.
                print('exception at %s' % sample['name'])
                gc.collect()
                continue
            # CUDA work is queued asynchronously; wait for it so the reported
            # time covers the whole forward pass.
            torch.cuda.synchronize()
            end_time = time.time()
            print('TIME: %.4fs' % (end_time - start_time))

            # vis occ & sdf
            # pred_occs[h][b]: coordinates of level-h voxels of item b with
            # sigmoid(occ) > 0.5, plus their raw values when args.flipped.
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                pred_occs[h] = [None] * args.batch_size
                occ_locs = output_occs[h][0]
                if len(occ_locs) == 0:
                    continue
                for b in range(args.batch_size):
                    batchmask = occ_locs[:, -1] == b
                    locs = occ_locs[batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    occ_mask = (torch.sigmoid(vals[:, 0].detach()) > 0.5).view(-1)
                    if args.flipped:
                        pred_occs[h][b] = [
                            locs[occ_mask].cpu().numpy(),
                            vals[occ_mask].cpu().numpy()
                        ]
                    else:
                        pred_occs[h][b] = locs[occ_mask].cpu().numpy()

            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    # len(mask) is always non-zero here; test for any match
                    # so items without predicted voxels stay None.
                    if mask.any():
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]
            data_util.save_predictions(output_vis, sample['name'], inputs,
                                       target_for_sdf.cpu().numpy(), None,
                                       vis_pred_sdf, pred_occs,
                                       sample['world2grid'], args.truncation,
                                       flipped=args.flipped)
            num_vis += 1
            if num_vis >= num_to_vis:
                break
    sys.stdout.write('\n')
def test(dataloader, output_vis, num_to_vis):
    """Run the SDF/color model on whole scenes and save visualizations.

    For each scene:
    - the input (and mask) is cropped to ``args.max_input_height`` along the
      first spatial axis when that limit is positive;
    - it is zero-padded (SDF channel filled with ``-args.truncation``) up to a
      multiple of ``hierarchy_factor`` in every spatial dimension;
    - the model is run, keeping voxels with |sdf| < truncation that are also
      predicted occupied when occupancy is output.
    The first ``num_to_vis`` successfully processed scenes are written with
    ``data_util.save_predictions``.

    Args:
        dataloader: iterable of sample dicts; reads 'input', 'mask', 'name'
            and 'world2grid'.
        output_vis: output location passed to ``data_util.save_predictions``.
        num_to_vis: number of scenes to save.
    """
    model.eval()

    # spatial dims are rounded up to a multiple of this (presumably the
    # model's total downsampling factor - TODO confirm)
    hierarchy_factor = 4
    num_proc = 0
    num_vis = 0
    with torch.no_grad():
        for sample in dataloader:
            inputs = sample['input']
            mask = sample['mask']

            max_input_dim = np.array(inputs.shape[2:])
            if args.max_input_height > 0 and max_input_dim[UP_AXIS] > args.max_input_height:
                max_input_dim[UP_AXIS] = args.max_input_height
                inputs = inputs[:, :, :args.max_input_height]
                if mask is not None:
                    mask = mask[:, :, :args.max_input_height]
            max_input_dim = ((max_input_dim + hierarchy_factor - 1) // hierarchy_factor) * hierarchy_factor

            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d) ' %
                             (num_proc, args.max_to_process, sample['name'],
                              max_input_dim[0], max_input_dim[1], max_input_dim[2]))

            # Pad input to max_input_dim. The tensors are already cropped above,
            # so copy them whole. The old `:min(args.max_input_height, ...)`
            # copied nothing when max_input_height <= 0 (i.e. "no limit").
            padded = torch.zeros(1, inputs.shape[1], max_input_dim[0],
                                 max_input_dim[1], max_input_dim[2])
            padded[:, 0].fill_(-args.truncation)
            padded[:, :, :inputs.shape[2], :inputs.shape[3], :inputs.shape[4]] = inputs
            inputs = padded
            # NOTE(review): a None mask (allowed by the crop above) fails here.
            padded_mask = torch.zeros(1, 1, max_input_dim[0], max_input_dim[1],
                                      max_input_dim[2])
            padded_mask[:, :, :mask.shape[2], :mask.shape[3], :mask.shape[4]] = mask
            mask = padded_mask
            model.update_sizes(max_input_dim)

            try:
                if not args.cpu:
                    inputs = inputs.cuda()
                    mask = mask.cuda()
                output_occ, output_sdf, output_color = model(
                    inputs, mask, pred_sdf=[True, True],
                    pred_color=args.weight_color_loss > 0)
                if output_occ is not None:
                    occ = torch.nn.Sigmoid()(output_occ.detach()) > 0.5
                    locs = torch.nonzero(
                        (torch.abs(output_sdf.detach()[:, 0]) < args.truncation)
                        & occ[:, 0]).cpu()
                else:
                    locs = torch.nonzero(
                        torch.abs(output_sdf[:, 0]) < args.truncation).cpu()
                # reorder columns from (batch, d, h, w) to (d, h, w, batch)
                locs = torch.cat([locs[:, 1:], locs[:, :1]], 1)
                output_sdf = [
                    locs,
                    output_sdf[locs[:, -1], :, locs[:, 0], locs[:, 1], locs[:, 2]].detach().cpu()
                ]
                if args.weight_color_loss == 0:
                    output_color = None
                else:
                    output_color = [
                        locs,
                        output_color[locs[:, -1], :, locs[:, 0], locs[:, 1], locs[:, 2]]
                    ]
                if output_color is not None:
                    # (c + 1) * 0.5 presumably maps colors from [-1, 1] to [0, 1]
                    output_color = (output_color[1] + 1) * 0.5
            except Exception:
                # Best-effort: skip scenes the model fails on; Ctrl-C still aborts.
                print('exception')
                gc.collect()
                continue

            num_proc += 1
            if num_vis < num_to_vis:
                vis_pred_sdf = [None]
                vis_pred_color = [None]
                if len(output_sdf[0]) > 0:
                    vis_pred_sdf[0] = [
                        output_sdf[0].cpu().numpy(),
                        output_sdf[1].squeeze().cpu().numpy()
                    ]
                    # Colors exist only when args.weight_color_loss > 0.
                    # Previously this dereferenced None otherwise.
                    if output_color is not None:
                        # convert colors to vec3uc
                        output_color = torch.clamp(output_color.detach() * 255, 0, 255)
                        vis_pred_color[0] = output_color.cpu().numpy()
                vis_pred_images_color = None
                vis_tgt_images_color = None
                data_util.save_predictions(
                    output_vis, np.arange(1), sample['name'],
                    inputs.cpu().numpy(), None, None, None,
                    vis_tgt_images_color, vis_pred_sdf, vis_pred_color, None,
                    vis_pred_images_color, sample['world2grid'],
                    args.truncation, args.color_space)
                num_vis += 1
            gc.collect()
    sys.stdout.write('\n')
def test(loss_weights, dataloader, output_pred, output_vis, num_to_vis):
    """Run the sparse model on test scenes, write ``.pred`` files and visualizations.

    For each scene:
    - targets are computed and the model is resized to the scene dimensions,
      then run;
    - predictions are cropped to the un-padded scene extent (SDF) or to the
      per-level target extent (occupancy);
    - the SDF prediction is saved to ``<output_pred>/<name>.pred``.
    The first ``num_to_vis`` scenes are also visualized. Names of scenes that
    raised during processing are collected and printed at the end.

    Args:
        loss_weights: per-level weights forwarded to ``model(...)``.
        dataloader: iterable of sample dicts; reads 'input', 'sdf', 'known',
            'hierarchy', 'name', 'orig_dims' and 'world2grid'.
        output_pred: directory for the ``.pred`` files.
        output_vis: output location passed to ``data_util.save_predictions``.
        num_to_vis: number of scenes to visualize.
    """
    model.eval()
    missing = []
    num_proc = 0
    num_vis = 0
    with torch.no_grad():
        for sample in dataloader:
            inputs = sample['input']
            sdfs = sample['sdf']
            known = sample['known']
            hierarchy = sample['hierarchy']

            input_dim = np.array(sdfs.shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d) ' %
                             (num_proc, args.max_to_process, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, args.num_hierarchy_levels, args.truncation,
                args.use_loss_masking, known)
            hierarchy_factor = pow(2, args.num_hierarchy_levels - 1)
            model.update_sizes(input_dim, input_dim // hierarchy_factor)
            try:
                if not args.cpu:
                    inputs[1] = inputs[1].cuda()
                    target_for_sdf = target_for_sdf.cuda()
                    for h in range(len(target_for_occs)):
                        target_for_occs[h] = target_for_occs[h].cuda()
                        target_for_hier[h] = target_for_hier[h].cuda()
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                # Best-effort: record and skip failing scenes; Ctrl-C still aborts.
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue

            # remove padding: keep SDF voxels inside the original scene extent
            dims = sample['orig_dims'][0]
            sdf_locs = output_sdf[0]
            mask = ((sdf_locs[:, 0] < dims[0]) & (sdf_locs[:, 1] < dims[1])
                    & (sdf_locs[:, 2] < dims[2]))
            output_sdf[0] = output_sdf[0][mask]
            output_sdf[1] = output_sdf[1][mask]
            # ...and occupancy voxels inside each level's target extent
            for h in range(len(output_occs)):
                level_dims = target_for_occs[h].shape[2:]
                occ_locs = output_occs[h][0]
                mask = ((occ_locs[:, 0] < level_dims[0])
                        & (occ_locs[:, 1] < level_dims[1])
                        & (occ_locs[:, 2] < level_dims[2]))
                output_occs[h][0] = output_occs[h][0][mask]
                output_occs[h][1] = output_occs[h][1][mask]

            # save prediction files
            data_util.save_predictions_to_file(
                output_sdf[0][:, :3].cpu().numpy(), output_sdf[1].cpu().numpy(),
                os.path.join(output_pred, sample['name'][0] + '.pred'))

            try:
                # pred_occs[h][0]: level-h coordinates with sigmoid(occ) > 0.5
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    pred_occs[h] = [None]
                    if len(output_occs[h][0]) == 0:
                        continue
                    locs = output_occs[h][0][:, :-1]
                    vals = torch.nn.Sigmoid()(output_occs[h][1][:, 0].detach()) > 0.5
                    pred_occs[h][0] = locs[vals.view(-1)]
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue

            num_proc += 1
            if num_vis < num_to_vis:
                num = min(num_to_vis - num_vis, 1)
                vis_pred_sdf = [None] * num
                if len(output_sdf[0]) > 0:
                    for b in range(num):
                        mask = output_sdf[0][:, -1] == b
                        # len(mask) is always non-zero here; test for any match
                        # so items without predicted voxels stay None.
                        if mask.any():
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].numpy(), inputs[1].cpu().numpy()]
                data_util.save_predictions(
                    output_vis, np.arange(num), inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy() for x in target_for_occs], vis_pred_sdf,
                    pred_occs, sample['world2grid'], args.vis_dfs,
                    args.truncation)
                num_vis += 1
            gc.collect()
    sys.stdout.write('\n')
    print('missing', missing)