Example #1
def test(loss_weights, dataloader, output_vis, num_to_vis):
    model.eval()

    num_vis = 0
    num_batches = len(dataloader)
    with torch.no_grad():
        gc.collect()
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            input_dim = np.array(sample['sdf'].shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d)    ' %
                             (num_vis, num_to_vis, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()
            hierarchy_factor = pow(2, args.num_hierarchy_levels - 1)
            model.update_sizes(input_dim, input_dim // hierarchy_factor)
            try:
                if not args.cpu:
                    inputs[1] = inputs[1].cuda()
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                continue

            # remove padding
            dims = sample['orig_dims'][0]
            mask = (output_sdf[0][:, 0] < dims[0]) & \
                   (output_sdf[0][:, 1] < dims[1]) & \
                   (output_sdf[0][:, 2] < dims[2])
            output_sdf[0] = output_sdf[0][mask]
            output_sdf[1] = output_sdf[1][mask]
            mask = (inputs[0][:, 0] < dims[0]) & \
                   (inputs[0][:, 1] < dims[1]) & \
                   (inputs[0][:, 2] < dims[2])
            inputs[0] = inputs[0][mask]
            inputs[1] = inputs[1][mask]
            vis_pred_sdf = [None]
            if len(output_sdf[0]) > 0:
                vis_pred_sdf[0] = [
                    output_sdf[0].cpu().numpy(),
                    output_sdf[1].squeeze().cpu().numpy()
                ]
            inputs = [inputs[0].numpy(), inputs[1].cpu().numpy()]
            target_sdf_save = sample['sdf'].numpy()
            data_util.save_predictions(output_vis, sample['name'], inputs,
                                       target_sdf_save, None, vis_pred_sdf,
                                       None, sample['world2grid'],
                                       args.truncation)
            num_vis += 1
            if num_vis >= num_to_vis:
                break
    sys.stdout.write('\n')
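A minimal, self-contained sketch of the padding-removal step above: sparse predictions are a set of (z, y, x) locations plus per-location values, and cropping to the original (unpadded) extent is a single coordinate mask. The tensor names and sizes below are illustrative, not taken from the pipeline above.

import torch

locs = torch.randint(0, 96, (1000, 3))   # locations in the padded volume
vals = torch.randn(1000, 1)              # one sdf value per location
orig_dims = torch.tensor([80, 64, 90])   # original (unpadded) scene extent

keep = (locs[:, 0] < orig_dims[0]) & \
       (locs[:, 1] < orig_dims[1]) & \
       (locs[:, 2] < orig_dims[2])
locs, vals = locs[keep], vals[keep]
assert bool((locs < orig_dims).all())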
Example #2
def test(dataloader, output_vis, num_to_vis):
    model.eval()

    chunk_dim = args.input_dim
    args.max_input_height = chunk_dim[0]
    if args.stride == 0:
        args.stride = chunk_dim[1]
    pad = 2

    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)
    print('starting testing...')
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            sdfs = sample['sdf']
            mask = sample['mask']
            colors = sample['colors']

            max_input_dim = np.array(sdfs.shape[2:])
            if (args.max_input_height > 0
                    and max_input_dim[UP_AXIS] > args.max_input_height):
                max_input_dim[UP_AXIS] = args.max_input_height
                inputs = inputs[:, :, :args.max_input_height]
                if mask is not None:
                    mask = mask[:, :, :args.max_input_height]
                if sdfs is not None:
                    sdfs = sdfs[:, :, :args.max_input_height]
                if colors is not None:
                    colors = colors[:, :args.max_input_height]
            sys.stdout.write(
                '\r[ %d | %d ] %s (%d, %d, %d)    ' %
                (num_proc, args.max_to_process, sample['name'],
                 max_input_dim[0], max_input_dim[1], max_input_dim[2]))

            output_colors = torch.zeros(colors.shape)
            output_sdfs = torch.zeros(sdfs.shape)
            output_norms = torch.zeros(sdfs.shape)
            output_occs = torch.zeros(sdfs.shape, dtype=torch.uint8)

            # chunk up the scene
            chunk_input = torch.ones(1, args.input_nf, chunk_dim[0],
                                     chunk_dim[1], chunk_dim[2]).cuda()
            chunk_mask = torch.ones(1, 1, chunk_dim[0], chunk_dim[1],
                                    chunk_dim[2]).cuda()
            chunk_target_sdf = torch.ones(1, 1, chunk_dim[0], chunk_dim[1],
                                          chunk_dim[2]).cuda()
            chunk_target_colors = torch.zeros(1,
                                              chunk_dim[0],
                                              chunk_dim[1],
                                              chunk_dim[2],
                                              3,
                                              dtype=torch.uint8).cuda()

            for y in range(0, max_input_dim[1], args.stride):
                for x in range(0, max_input_dim[2], args.stride):
                    chunk_input_mask = torch.abs(
                        inputs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                               x:x + chunk_dim[2]]) < args.truncation
                    if torch.sum(chunk_input_mask).item() == 0:
                        continue
                    sys.stdout.write(
                        '\r[ %d | %d ] %s (%d, %d, %d) (%d, %d)    ' %
                        (num_proc, args.max_to_process, sample['name'],
                         max_input_dim[0], max_input_dim[1], max_input_dim[2],
                         y, x))

                    fill_dim = [
                        min(sdfs.shape[2], chunk_dim[0]),
                        min(sdfs.shape[3] - y, chunk_dim[1]),
                        min(sdfs.shape[4] - x, chunk_dim[2])
                    ]
                    chunk_target_sdf.fill_(float('inf'))
                    chunk_target_colors.fill_(0)
                    chunk_input[:, 0].fill_(-args.truncation)
                    chunk_input[:, 1:].fill_(0)
                    chunk_mask.fill_(0)
                    chunk_input[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]] = \
                        inputs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                               x:x + chunk_dim[2]]
                    chunk_mask[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]] = \
                        mask[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                             x:x + chunk_dim[2]]
                    chunk_target_sdf[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]] = \
                        sdfs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                             x:x + chunk_dim[2]]
                    chunk_target_colors[:, :fill_dim[0], :fill_dim[1], :fill_dim[2], :] = \
                        colors[:, :chunk_dim[0], y:y + chunk_dim[1],
                               x:x + chunk_dim[2]]

                    target_for_sdf, target_for_colors = loss_util.compute_targets(
                        chunk_target_sdf.cuda(), args.truncation, False, None,
                        chunk_target_colors.cuda())

                    output_occ, output_sdf, output_color = model(
                        chunk_input, chunk_mask, pred_sdf=[True, True],
                        pred_color=args.weight_color_loss > 0)

                    if output_occ is not None:
                        occ = torch.nn.Sigmoid()(output_occ.detach()) > 0.5
                        locs = torch.nonzero(
                            (torch.abs(output_sdf.detach()[:, 0]) <
                             args.truncation) & occ[:, 0]).cpu()
                    else:
                        locs = torch.nonzero(
                            torch.abs(output_sdf[:, 0]) < args.truncation).cpu()
                    locs = torch.cat([locs[:, 1:], locs[:, :1]], 1)
                    output_sdf = [
                        locs,
                        output_sdf[locs[:, -1], :, locs[:, 0], locs[:, 1],
                                   locs[:, 2]].detach().cpu()
                    ]
                    if args.weight_color_loss == 0:
                        output_color = None
                    else:
                        output_color = [
                            locs,
                            output_color[locs[:, -1], :, locs[:, 0],
                                         locs[:, 1], locs[:, 2]]
                        ]

                    output_locs = output_sdf[0] + torch.LongTensor(
                        [0, y, x, 0])
                    if args.stride < chunk_dim[1]:
                        min_dim = [0, y, x]
                        max_dim = [
                            0 + chunk_dim[0], y + chunk_dim[1],
                            x + chunk_dim[2]
                        ]
                        if y > 0:
                            min_dim[1] += pad
                        if x > 0:
                            min_dim[2] += pad
                        if y + chunk_dim[1] < max_input_dim[1]:
                            max_dim[1] -= pad
                        if x + chunk_dim[2] < max_input_dim[2]:
                            max_dim[2] -= pad
                        for k in range(3):
                            max_dim[k] = min(max_dim[k], sdfs.shape[k + 2])
                        outmask = (output_locs[:, 0] >= min_dim[0]) & \
                                  (output_locs[:, 1] >= min_dim[1]) & \
                                  (output_locs[:, 2] >= min_dim[2]) & \
                                  (output_locs[:, 0] < max_dim[0]) & \
                                  (output_locs[:, 1] < max_dim[1]) & \
                                  (output_locs[:, 2] < max_dim[2])
                    else:
                        outmask = (output_locs[:, 0] < output_sdfs.shape[2]) & \
                                  (output_locs[:, 1] < output_sdfs.shape[3]) & \
                                  (output_locs[:, 2] < output_sdfs.shape[4])
                    output_locs = output_locs[outmask]
                    output_sdf = [
                        output_sdf[0][outmask], output_sdf[1][outmask]
                    ]
                    if output_color is not None:
                        output_color = [
                            output_color[0][outmask], output_color[1][outmask]
                        ]
                        output_color = (output_color[1] + 1) * 0.5
                        output_colors[0, output_locs[:, 0], output_locs[:, 1],
                                      output_locs[:, 2], :] += \
                            output_color.detach().cpu()
                    if output_occ is not None:
                        output_occs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                                    x:x + chunk_dim[2]] = \
                            occ[:, :, :fill_dim[0], :fill_dim[1], :fill_dim[2]]

                    output_sdfs[0, 0, output_locs[:, 0], output_locs[:, 1],
                                output_locs[:, 2]] += \
                        output_sdf[1][:, 0].detach().cpu()
                    output_norms[0, 0, output_locs[:, 0], output_locs[:, 1],
                                 output_locs[:, 2]] += 1

            # normalize
            mask = output_norms > 0
            output_norms = output_norms[mask]
            output_sdfs[mask] = output_sdfs[mask] / output_norms
            output_sdfs[~mask] = -float('inf')
            mask = mask.view(1, mask.shape[2], mask.shape[3], mask.shape[4])
            output_colors[mask, :] = \
                output_colors[mask, :] / output_norms.view(-1, 1)
            output_colors = torch.clamp(output_colors * 255, 0, 255)

            sdfs = torch.clamp(sdfs, -args.truncation, args.truncation)
            output_sdfs = torch.clamp(output_sdfs, -args.truncation,
                                      args.truncation)

            if num_vis < num_to_vis:
                inputs = inputs.cpu().numpy()
                locs = torch.nonzero(
                    torch.abs(output_sdfs[0, 0]) < args.truncation)
                vis_pred_sdf = [None]
                vis_pred_color = [None]
                sdf_vals = output_sdfs[0, 0, locs[:, 0], locs[:, 1],
                                       locs[:, 2]].view(-1)
                vis_pred_sdf[0] = [locs.cpu().numpy(), sdf_vals.cpu().numpy()]
                if args.weight_color_loss > 0:
                    vals = output_colors[0, locs[:, 0], locs[:, 1], locs[:, 2]]
                    vis_pred_color[0] = vals.cpu().numpy()
                if output_occs is not None:
                    pred_occ = output_occs.cpu().numpy().astype(np.float32)
                data_util.save_predictions(output_vis,
                                           np.arange(1),
                                           sample['name'],
                                           inputs,
                                           sdfs.cpu().numpy(),
                                           colors.cpu().numpy(),
                                           None,
                                           None,
                                           vis_pred_sdf,
                                           vis_pred_color,
                                           None,
                                           None,
                                           sample['world2grid'],
                                           args.truncation,
                                           args.color_space,
                                           pred_occ=pred_occ)
                num_vis += 1
            num_proc += 1
            gc.collect()

    sys.stdout.write('\n')
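Example #2 fuses overlapping chunks by accumulating per-chunk predictions into output_sdfs and counting contributions per voxel in output_norms, then dividing. A toy 1D sketch of that overlap-averaging scheme, with illustrative sizes only:

import torch

scene = torch.arange(10, dtype=torch.float32)  # stand-in per-voxel values
accum = torch.zeros(10)
norm = torch.zeros(10)
chunk, stride = 4, 2                           # stride < chunk => overlap
for x in range(0, 10, stride):
    accum[x:x + chunk] += scene[x:x + chunk]   # fake per-chunk "prediction"
    norm[x:x + chunk] += 1
covered = norm > 0
accum[covered] = accum[covered] / norm[covered]
assert torch.allclose(accum[covered], scene[covered])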
Example #3
def test(epoch, iter, loss_weights, dataloader, log_file, output_save):
    val_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    val_l1preds = []
    val_l1tgts = []
    val_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.eval()

    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
            if args.use_loss_masking:
                known = known.cuda()
            inputs[0] = inputs[0].cuda()
            inputs[1] = inputs[1].cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs.cuda(), hierarchy, args.num_hierarchy_levels,
                args.truncation, args.use_loss_masking, known)

            output_sdf, output_occs = model(inputs, loss_weights)
            loss, losses = loss_util.compute_loss(
                output_sdf, output_occs, target_for_sdf, target_for_occs,
                target_for_hier, loss_weights, args.truncation,
                args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
                args.use_loss_masking, known)

            output_visual = output_save and t + 2 == num_batches
            compute_pred_occs = (t % 20 == 0) or output_visual
            if compute_pred_occs:
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    factor = 2**(args.num_hierarchy_levels - h - 1)
                    pred_occs[h] = [None] * args.batch_size
                    if len(output_occs[h][0]) == 0:
                        continue
                    for b in range(args.batch_size):
                        batchmask = output_occs[h][0][:, -1] == b
                        locs = output_occs[h][0][batchmask][:, :-1]
                        vals = torch.nn.Sigmoid()(
                            output_occs[h][1][:, 0].detach()[batchmask]) > 0.5
                        pred_occs[h][b] = locs[vals.view(-1)]
            val_losses[0].append(loss.item())
            for h in range(args.num_hierarchy_levels):
                val_losses[h + 1].append(losses[h])
                target = target_for_occs[h].byte()
                if compute_pred_occs:
                    iou = loss_util.compute_iou_sparse_dense(
                        pred_occs[h], target, args.use_loss_masking)
                    val_ious[h].append(iou)
            val_losses[args.num_hierarchy_levels + 1].append(losses[-1])
            if len(output_sdf[0]) > 0:
                output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
            if loss_weights[-1] > 0 and t % 20 == 0:
                val_l1preds.append(
                    loss_util.compute_l1_predsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf, None,
                        False, args.use_loss_masking, known).item())
                val_l1tgts.append(
                    loss_util.compute_l1_tgtsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf,
                        args.truncation, args.use_loss_masking, known))
            if output_visual:
                vis_pred_sdf = [None] * args.batch_size
                if len(output_sdf[0]) > 0:
                    for b in range(args.batch_size):
                        mask = output_sdf[0][:, -1] == b
                        if mask.any():
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
                for h in range(args.num_hierarchy_levels):
                    for b in range(args.batch_size):
                        if pred_occs[h][b] is not None:
                            pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
                data_util.save_predictions(
                    os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch),
                                 'val'), sample['name'], inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy() for x in target_for_occs],
                    vis_pred_sdf, pred_occs, sample['world2grid'],
                    args.vis_dfs, args.truncation)

    return val_losses, val_l1preds, val_l1tgts, val_ious
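loss_util.compute_iou_sparse_dense is not shown in this file; the sketch below is an assumption about its behavior, for illustration only: rasterize the sparse predicted locations into a dense grid and take intersection over union against the dense target.

import torch

def iou_sparse_dense(pred_locs, target):
    # pred_locs: (N, 3) long tensor of occupied voxels; target: dense 0/1 (D, H, W)
    pred = torch.zeros_like(target)
    pred[pred_locs[:, 0], pred_locs[:, 1], pred_locs[:, 2]] = 1
    union = (pred | target).sum().item()
    return (pred & target).sum().item() / union if union > 0 else float('nan')

target = (torch.rand(8, 8, 8) > 0.7).byte()
pred_locs = torch.nonzero(torch.rand(8, 8, 8) > 0.7)
print(iou_sparse_dense(pred_locs, target))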
Example #4
def train(epoch, iter, dataloader, log_file, output_save):
    train_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    train_l1preds = []
    train_l1tgts = []
    train_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.train()
    start = time.time()

    if args.scheduler_step_size == 0:
        scheduler.step()

    num_batches = len(dataloader)
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(iter, args.num_hierarchy_levels,
                                        args.num_iters_per_level,
                                        args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('[iter %d/epoch %d] loss_weights' % (iter, epoch),
                  loss_weights)

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if args.use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, args.num_hierarchy_levels, args.truncation,
            args.use_loss_masking, known)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(
            output_sdf, output_occs, target_for_sdf, target_for_occs,
            target_for_hier, loss_weights, args.truncation,
            args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
            args.use_loss_masking, known)
        loss.backward()
        optimizer.step()

        output_visual = output_save and t + 2 == num_batches
        compute_pred_occs = (iter % 20 == 0) or output_visual
        if compute_pred_occs:
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                factor = 2**(args.num_hierarchy_levels - h - 1)
                pred_occs[h] = [None] * args.batch_size
                if len(output_occs[h][0]) == 0:
                    continue
                output_occs[h][1] = torch.nn.Sigmoid()(
                    output_occs[h][1][:, 0].detach()) > 0.5
                for b in range(args.batch_size):
                    batchmask = output_occs[h][0][:, -1] == b
                    locs = output_occs[h][0][batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    pred_occs[h][b] = locs[vals.view(-1)]
        train_losses[0].append(loss.item())
        for h in range(args.num_hierarchy_levels):
            train_losses[h + 1].append(losses[h])
            target = target_for_occs[h].byte()
            if compute_pred_occs:
                iou = loss_util.compute_iou_sparse_dense(
                    pred_occs[h], target, args.use_loss_masking)
                train_ious[h].append(iou)
        train_losses[args.num_hierarchy_levels + 1].append(losses[-1])
        if len(output_sdf[0]) > 0:
            output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
        if loss_weights[-1] > 0 and iter % 20 == 0:
            train_l1preds.append(
                loss_util.compute_l1_predsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf, None, False,
                    args.use_loss_masking, known).item())
            train_l1tgts.append(
                loss_util.compute_l1_tgtsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf,
                    args.truncation, args.use_loss_masking, known))

        iter += 1
        if args.scheduler_step_size > 0 and iter % args.scheduler_step_size == 0:
            scheduler.step()
        if iter % 20 == 0:
            took = time.time() - start
            print_log(log_file, epoch, iter, train_losses, train_l1preds,
                      train_l1tgts, train_ious, None, None, None, None, took)
        if iter % 2000 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                os.path.join(args.save,
                             'model-iter%s-epoch%s.pth' % (iter, epoch)))
        if output_visual:
            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    if mask.any():
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]
            inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
            for h in range(args.num_hierarchy_levels):
                for b in range(args.batch_size):
                    if pred_occs[h][b] is not None:
                        pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
            data_util.save_predictions(
                os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch),
                             'train'), sample['name'], inputs,
                target_for_sdf.cpu().numpy(),
                [x.cpu().numpy() for x in target_for_occs],
                vis_pred_sdf, pred_occs, sample['world2grid'].numpy(),
                args.vis_dfs, args.truncation)

    return train_losses, train_l1preds, train_l1tgts, train_ious, iter, loss_weights
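get_loss_weights is defined elsewhere in this codebase; the sketch below is one plausible curriculum, not the actual implementation: enable one hierarchy level every num_iters_per_level iterations, then ramp in the sdf term with weight_sdf_loss once all levels are active.

def get_loss_weights_sketch(iteration, num_levels, iters_per_level, sdf_weight):
    weights = [0.0] * (num_levels + 1)  # one weight per hierarchy level + sdf
    active = min(iteration // iters_per_level, num_levels - 1)
    for h in range(active + 1):
        weights[h] = 1.0
    if active == num_levels - 1:        # all levels on -> enable sdf loss
        weights[-1] = sdf_weight
    return weights

print(get_loss_weights_sketch(0, 4, 2000, 1.0))     # coarsest level only
print(get_loss_weights_sketch(9000, 4, 2000, 1.0))  # all levels plus sdf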
Example #5
def test(loss_weights, dataloader, output_vis, num_to_vis):
    model.eval()

    num_vis = 0
    num_batches = len(dataloader)
    print("num_batches: ", num_batches)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            input_dim = np.array(sample['sdf'].shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d)    ' %
                             (num_vis, num_to_vis, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()

            # hierarchy_factor = pow(2, args.num_hierarchy_levels-1)
            # model.update_sizes(input_dim, input_dim // hierarchy_factor)

            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            if args.use_loss_masking:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs,
                hierarchy,
                args.num_hierarchy_levels,
                args.truncation,
                args.use_loss_masking,
                known,
                flipped=args.flipped)

            start_time = time.time()
            try:
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                continue

            end_time = time.time()
            print('TIME: %.4fs' % (end_time - start_time))

            # padding removal (as in Example #1) is disabled in this variant

            # vis occ & sdf
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                factor = 2**(args.num_hierarchy_levels - h - 1)
                pred_occs[h] = [None] * args.batch_size
                if len(output_occs[h][0]) == 0:
                    continue
                # filter: occ > 0
                for b in range(args.batch_size):
                    batchmask = output_occs[h][0][:, -1] == b
                    locs = output_occs[h][0][batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    occ_mask = torch.nn.Sigmoid()(vals[:, 0].detach()) > 0.5
                    if args.flipped:
                        pred_occs[h][b] = [
                            locs[occ_mask.view(-1)].cpu().numpy(),
                            vals[occ_mask.view(-1)].cpu().numpy()
                        ]
                    else:
                        pred_occs[h][b] = locs[occ_mask.view(-1)].cpu().numpy()

            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    if len(mask) > 0:
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]

            data_util.save_predictions(output_vis,
                                       sample['name'],
                                       inputs,
                                       target_for_sdf.cpu().numpy(),
                                       None,
                                       vis_pred_sdf,
                                       pred_occs,
                                       sample['world2grid'],
                                       args.truncation,
                                       flipped=args.flipped)
            num_vis += 1
            if num_vis >= num_to_vis:
                break
    sys.stdout.write('\n')
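The per-level filtering above keeps voxels whose occupancy logit passes sigmoid(x) > 0.5 (equivalently x > 0), splitting batches via the trailing batch-index column. A small self-contained sketch with toy data:

import torch

locs = torch.cat([torch.randint(0, 16, (6, 3)),     # (z, y, x)
                  torch.randint(0, 2, (6, 1))], 1)  # batch index
logits = torch.randn(6, 1)
b = 0
batchmask = locs[:, -1] == b
occ_mask = torch.sigmoid(logits[batchmask][:, 0]) > 0.5
kept = locs[batchmask][:, :-1][occ_mask]
print(kept.shape)  # (num kept voxels, 3)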
Example #6
def test(dataloader, output_vis, num_to_vis):
    model.eval()

    hierarchy_factor = 4
    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            mask = sample['mask']
            max_input_dim = np.array(inputs.shape[2:])
            if (args.max_input_height > 0
                    and max_input_dim[UP_AXIS] > args.max_input_height):
                max_input_dim[UP_AXIS] = args.max_input_height
                inputs = inputs[:, :, :args.max_input_height]
                if mask is not None:
                    mask = mask[:, :, :args.max_input_height]
            # round each spatial dim up to a multiple of the hierarchy factor
            max_input_dim = ((max_input_dim + hierarchy_factor - 1) //
                             hierarchy_factor) * hierarchy_factor
            sys.stdout.write(
                '\r[ %d | %d ] %s (%d, %d, %d)    ' %
                (num_proc, args.max_to_process, sample['name'],
                 max_input_dim[0], max_input_dim[1], max_input_dim[2]))
            # pad input and mask (same height here) to max_input_dim
            num_h = (inputs.shape[2] if args.max_input_height <= 0 else
                     min(args.max_input_height, inputs.shape[2]))
            padded = torch.zeros(1, inputs.shape[1], max_input_dim[0],
                                 max_input_dim[1], max_input_dim[2])
            padded[:, 0].fill_(-args.truncation)
            padded[:, :, :num_h, :inputs.shape[3], :inputs.shape[4]] = \
                inputs[:, :, :num_h, :, :]
            inputs = padded
            padded_mask = torch.zeros(1, 1, max_input_dim[0],
                                      max_input_dim[1], max_input_dim[2])
            padded_mask[:, :, :num_h, :mask.shape[3], :mask.shape[4]] = \
                mask[:, :, :num_h, :, :]
            mask = padded_mask

            model.update_sizes(max_input_dim)
            output_occ = None
            try:
                if not args.cpu:
                    inputs = inputs.cuda()
                    mask = mask.cuda()
                output_occ, output_sdf, output_color = model(
                    inputs, mask, pred_sdf=[True, True],
                    pred_color=args.weight_color_loss > 0)

                if output_occ is not None:
                    occ = torch.nn.Sigmoid()(output_occ.detach()) > 0.5
                    locs = torch.nonzero(
                        (torch.abs(output_sdf.detach()[:, 0]) <
                         args.truncation) & occ[:, 0]).cpu()
                else:
                    locs = torch.nonzero(
                        torch.abs(output_sdf[:, 0]) < args.truncation).cpu()
                locs = torch.cat([locs[:, 1:], locs[:, :1]], 1)
                output_sdf = [
                    locs,
                    output_sdf[locs[:, -1], :, locs[:, 0], locs[:, 1],
                               locs[:, 2]].detach().cpu()
                ]
                if args.weight_color_loss == 0:
                    output_color = None
                else:
                    output_color = [
                        locs,
                        output_color[locs[:, -1], :, locs[:, 0], locs[:, 1],
                                     locs[:, 2]]
                    ]
                if output_color is not None:
                    output_color = (output_color[1] + 1) * 0.5
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                continue

            num_proc += 1
            if num_vis < num_to_vis:
                vis_pred_sdf = [None]
                vis_pred_color = [None]
                if len(output_sdf[0]) > 0:
                    vis_pred_sdf[0] = [
                        output_sdf[0].cpu().numpy(),
                        output_sdf[1].squeeze().cpu().numpy()
                    ]
                    if output_color is not None:  # convert colors to vec3uc
                        output_color = torch.clamp(
                            output_color.detach() * 255, 0, 255)
                        vis_pred_color[0] = output_color.cpu().numpy()
                vis_pred_images_color = None
                vis_tgt_images_color = None
                data_util.save_predictions(
                    output_vis, np.arange(1), sample['name'],
                    inputs.cpu().numpy(), None, None, None,
                    vis_tgt_images_color, vis_pred_sdf, vis_pred_color, None,
                    vis_pred_images_color, sample['world2grid'],
                    args.truncation, args.color_space)
                num_vis += 1
            gc.collect()
    sys.stdout.write('\n')
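The round-up padding above makes every spatial dimension divisible by hierarchy_factor so each downsampling level divides evenly, and fills the sdf channel of the padding with -truncation (empty space in this convention). A toy sketch with illustrative numbers:

import numpy as np
import torch

hierarchy_factor, truncation = 4, 3.0
dims = np.array([62, 97, 81])
padded_dims = ((dims + hierarchy_factor - 1) //
               hierarchy_factor) * hierarchy_factor
print(padded_dims)  # [ 64 100  84]

x = torch.zeros(1, 2, *padded_dims.tolist())
x[:, 0].fill_(-truncation)  # sdf channel defaults to "empty"
x[:, :, :dims[0], :dims[1], :dims[2]] = 0  # real data would be copied here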
Example #7
def test(loss_weights, dataloader, output_pred, output_vis, num_to_vis):
    model.eval()
    missing = []

    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            sdfs = sample['sdf']
            known = sample['known']
            hierarchy = sample['hierarchy']
            input_dim = np.array(sdfs.shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d)    ' %
                             (num_proc, args.max_to_process, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, args.num_hierarchy_levels, args.truncation,
                args.use_loss_masking, known)
            hierarchy_factor = pow(2, args.num_hierarchy_levels - 1)
            model.update_sizes(input_dim, input_dim // hierarchy_factor)
            try:
                if not args.cpu:
                    inputs[1] = inputs[1].cuda()
                    target_for_sdf = target_for_sdf.cuda()
                    for h in range(len(target_for_occs)):
                        target_for_occs[h] = target_for_occs[h].cuda()
                        target_for_hier[h] = target_for_hier[h].cuda()
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue
            # remove padding
            dims = sample['orig_dims'][0]
            mask = (output_sdf[0][:, 0] < dims[0]) & \
                   (output_sdf[0][:, 1] < dims[1]) & \
                   (output_sdf[0][:, 2] < dims[2])
            output_sdf[0] = output_sdf[0][mask]
            output_sdf[1] = output_sdf[1][mask]
            for h in range(len(output_occs)):
                dims = target_for_occs[h].shape[2:]
                mask = (output_occs[h][0][:, 0] < dims[0]) & \
                       (output_occs[h][0][:, 1] < dims[1]) & \
                       (output_occs[h][0][:, 2] < dims[2])
                output_occs[h][0] = output_occs[h][0][mask]
                output_occs[h][1] = output_occs[h][1][mask]

            # save prediction files
            data_util.save_predictions_to_file(
                output_sdf[0][:, :3].cpu().numpy(),
                output_sdf[1].cpu().numpy(),
                os.path.join(output_pred, sample['name'][0] + '.pred'))

            try:
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    pred_occs[h] = [None]
                    if len(output_occs[h][0]) == 0:
                        continue
                    locs = output_occs[h][0][:, :-1]
                    vals = torch.nn.Sigmoid()(
                        output_occs[h][1][:, 0].detach()) > 0.5
                    pred_occs[h][0] = locs[vals.view(-1)]
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue

            num_proc += 1
            if num_vis < num_to_vis:
                num = min(num_to_vis - num_vis, 1)
                vis_pred_sdf = [None] * num
                if len(output_sdf[0]) > 0:
                    for b in range(num):
                        mask = output_sdf[0][:, -1] == b
                        if mask.any():
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].numpy(), inputs[1].cpu().numpy()]
                data_util.save_predictions(
                    output_vis, np.arange(num), inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy() for x in target_for_occs],
                    vis_pred_sdf, pred_occs, sample['world2grid'],
                    args.vis_dfs, args.truncation)
                num_vis += 1
            gc.collect()

    sys.stdout.write('\n')
    print('missing', missing)
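data_util.save_predictions_to_file is not shown here; the sketch below persists sparse predictions (locations plus sdf values) with plain numpy. The .npz layout is an assumption for illustration, not the actual .pred format written above.

import numpy as np

def save_sparse_pred(locs, vals, path):
    # locs: (N, 3) integer voxel locations; vals: (N,) or (N, 1) sdf values
    np.savez_compressed(path, locs=locs.astype(np.int32),
                        vals=np.asarray(vals, dtype=np.float32).reshape(
                            len(locs), -1))

locs = np.random.randint(0, 64, (100, 3))
vals = np.random.randn(100, 1)
save_sparse_pred(locs, vals, '/tmp/example_pred.npz')
loaded = np.load('/tmp/example_pred.npz')
assert loaded['locs'].shape == (100, 3)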