Example 1
    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """

        # batch_size = self.hparams.train.batch_size
        num_hierarchy_levels = self.hparams.train.num_hierarchy_levels
        truncation = self.hparams.train.truncation
        use_loss_masking = self.hparams.train.use_loss_masking
        logweight_target_sdf = self.hparams.model.logweight_target_sdf
        weight_missing_geo = self.hparams.train.weight_missing_geo

        sample = batch

        sdfs = sample['sdf']
        # TODO: restore the batch-size guard; the original training-loop
        # `continue` has no equivalent inside a Lightning step:
        # if sdfs.shape[0] < batch_size:
        #     continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, num_hierarchy_levels, truncation,
            use_loss_masking, known)

        # update loss weights
        _iter = self._iter_counter
        loss_weights = get_loss_weights(
            _iter, self.hparams.train.num_hierarchy_levels,
            self.hparams.train.num_iters_per_level,
            self.hparams.train.weight_sdf_loss)

        output_sdf, output_occs = self(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(output_sdf, output_occs,
                                              target_for_sdf, target_for_occs,
                                              target_for_hier, loss_weights,
                                              truncation, logweight_target_sdf,
                                              weight_missing_geo, inputs[0],
                                              use_loss_masking, known)

        output = OrderedDict({
            'val_loss': loss,
        })

        losses_dict = dict([(f'val_loss_{i}', l)
                            for (i, l) in enumerate(losses)])
        output.update(losses_dict)

        return output
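
A note on get_loss_weights, which this and several later examples call but never define: from the call sites we only know its signature and that loss_weights[-1] gates the SDF loss term. The sketch below is one plausible coarse-to-fine schedule under those assumptions, not the project's actual implementation.

import numpy as np

def get_loss_weights(_iter, num_hierarchy_levels, num_iters_per_level,
                     weight_sdf_loss):
    """Hypothetical schedule: enable one hierarchy level per
    num_iters_per_level iterations, then switch on the SDF loss."""
    weights = np.zeros(num_hierarchy_levels + 1, dtype=np.float32)
    cur_level = min(_iter // num_iters_per_level, num_hierarchy_levels - 1)
    weights[:cur_level + 1] = 1.0  # coarse-to-fine occupancy weights
    if cur_level == num_hierarchy_levels - 1:
        weights[-1] = weight_sdf_loss  # SDF term enabled last
    return weights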
Example 2
def check(dataloader, output_save):

    num_batches = len(dataloader)
    print("num_batches: ", num_batches)
    for t, sample in enumerate(dataloader):
        print("START")
        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            print("empty sdf: ", sample['name'])
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']

        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs,
            hierarchy,
            args.num_hierarchy_levels,
            args.truncation,
            True,
            known,
            flipped=FLIP)
        # print("target_for_occs: ", target_for_occs[-1].shape, (target_for_occs[-1]>0).sum(), (target_for_occs[-1]==0).sum())
        dims = sample['orig_dims'][0]
        inputs = inputs.cpu().numpy()

        # print(inputs[0].shape, inputs[0][0], target_for_sdf.shape)
        data_util.save_input_target_pred(args.save,
                                         sample['name'],
                                         inputs,
                                         target_for_sdf,
                                         None,
                                         truncation=2)

        print("END")
    return
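
Illustrative usage of check(), assuming a dataset whose samples are dicts with the keys read above ('sdf', 'input', 'known', 'hierarchy', 'name', 'orig_dims'); the loader construction, including the collate_fn, is an assumption:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                    num_workers=2, collate_fn=dataset.collate_fn)
check(loader, output_save=args.save)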
Example 3
def test(dataloader, output_vis, num_to_vis):
    model.eval()

    chunk_dim = args.input_dim
    args.max_input_height = chunk_dim[0]
    if args.stride == 0:
        args.stride = chunk_dim[1]
    pad = 2

    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)
    print('starting testing...')
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            sdfs = sample['sdf']
            mask = sample['mask']
            colors = sample['colors']

            max_input_dim = np.array(sdfs.shape[2:])
            if args.max_input_height > 0 and max_input_dim[
                    UP_AXIS] > args.max_input_height:
                max_input_dim[UP_AXIS] = args.max_input_height
                inputs = inputs[:, :, :args.max_input_height]
                if mask is not None:
                    mask = mask[:, :, :args.max_input_height]
                if sdfs is not None:
                    sdfs = sdfs[:, :, :args.max_input_height]
                if colors is not None:
                    colors = colors[:, :args.max_input_height]
            sys.stdout.write(
                '\r[ %d | %d ] %s (%d, %d, %d)    ' %
                (num_proc, args.max_to_process, sample['name'],
                 max_input_dim[0], max_input_dim[1], max_input_dim[2]))

            output_colors = torch.zeros(colors.shape)
            output_sdfs = torch.zeros(sdfs.shape)
            output_norms = torch.zeros(sdfs.shape)
            output_occs = torch.zeros(sdfs.shape, dtype=torch.uint8)

            # chunk up the scene
            chunk_input = torch.ones(1, args.input_nf, chunk_dim[0],
                                     chunk_dim[1], chunk_dim[2]).cuda()
            chunk_mask = torch.ones(1, 1, chunk_dim[0], chunk_dim[1],
                                    chunk_dim[2]).cuda()
            chunk_target_sdf = torch.ones(1, 1, chunk_dim[0], chunk_dim[1],
                                          chunk_dim[2]).cuda()
            chunk_target_colors = torch.zeros(1,
                                              chunk_dim[0],
                                              chunk_dim[1],
                                              chunk_dim[2],
                                              3,
                                              dtype=torch.uint8).cuda()

            for y in range(0, max_input_dim[1], args.stride):
                for x in range(0, max_input_dim[2], args.stride):
                    chunk_input_mask = torch.abs(
                        inputs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                               x:x + chunk_dim[2]]) < args.truncation
                    if torch.sum(chunk_input_mask).item() == 0:
                        continue
                    sys.stdout.write(
                        '\r[ %d | %d ] %s (%d, %d, %d) (%d, %d)    ' %
                        (num_proc, args.max_to_process, sample['name'],
                         max_input_dim[0], max_input_dim[1], max_input_dim[2],
                         y, x))

                    fill_dim = [
                        min(sdfs.shape[2], chunk_dim[0]),
                        min(sdfs.shape[3] - y, chunk_dim[1]),
                        min(sdfs.shape[4] - x, chunk_dim[2])
                    ]
                    chunk_target_sdf.fill_(float('inf'))
                    chunk_target_colors.fill_(0)
                    chunk_input[:, 0].fill_(-args.truncation)
                    chunk_input[:, 1:].fill_(0)
                    chunk_mask.fill_(0)
                    chunk_input[:, :, :fill_dim[0], :fill_dim[1], :
                                fill_dim[2]] = inputs[:, :, :chunk_dim[0],
                                                      y:y + chunk_dim[1],
                                                      x:x + chunk_dim[2]]
                    chunk_mask[:, :, :fill_dim[0], :fill_dim[1], :
                               fill_dim[2]] = mask[:, :, :chunk_dim[0],
                                                   y:y + chunk_dim[1],
                                                   x:x + chunk_dim[2]]
                    chunk_target_sdf[:, :, :fill_dim[0], :fill_dim[1], :
                                     fill_dim[2]] = sdfs[:, :, :chunk_dim[0],
                                                         y:y + chunk_dim[1],
                                                         x:x + chunk_dim[2]]
                    chunk_target_colors[:, :fill_dim[0], :fill_dim[
                        1], :fill_dim[2], :] = colors[:, :chunk_dim[0],
                                                      y:y + chunk_dim[1],
                                                      x:x + chunk_dim[2]]

                    target_for_sdf, target_for_colors = loss_util.compute_targets(
                        chunk_target_sdf.cuda(), args.truncation, False, None,
                        chunk_target_colors.cuda())

                    output_occ, output_sdf, output_color = model(
                        chunk_input,
                        chunk_mask,
                        pred_sdf=[True, True],
                        pred_color=args.weight_color_loss > 0)

                    if output_occ is not None:
                        occ = torch.nn.Sigmoid()(output_occ.detach()) > 0.5
                        locs = torch.nonzero((torch.abs(
                            output_sdf.detach()[:, 0]) < args.truncation)
                                             & occ[:, 0]).cpu()
                    else:
                        locs = torch.nonzero(
                            torch.abs(output_sdf[:,
                                                 0]) < args.truncation).cpu()
                    locs = torch.cat([locs[:, 1:], locs[:, :1]], 1)
                    output_sdf = [
                        locs, output_sdf[locs[:, -1], :, locs[:, 0],
                                         locs[:, 1], locs[:,
                                                          2]].detach().cpu()
                    ]
                    if args.weight_color_loss == 0:
                        output_color = None
                    else:
                        output_color = [
                            locs, output_color[locs[:, -1], :, locs[:, 0],
                                               locs[:, 1], locs[:, 2]]
                        ]

                    output_locs = output_sdf[0] + torch.LongTensor(
                        [0, y, x, 0])
                    if args.stride < chunk_dim[1]:
                        min_dim = [0, y, x]
                        max_dim = [
                            0 + chunk_dim[0], y + chunk_dim[1],
                            x + chunk_dim[2]
                        ]
                        if y > 0:
                            min_dim[1] += pad
                        if x > 0:
                            min_dim[2] += pad
                        if y + chunk_dim[1] < max_input_dim[1]:
                            max_dim[1] -= pad
                        if x + chunk_dim[2] < max_input_dim[2]:
                            max_dim[2] -= pad
                        for k in range(3):
                            max_dim[k] = min(max_dim[k], sdfs.shape[k + 2])
                        outmask = (output_locs[:, 0] >= min_dim[0]) & (
                            output_locs[:, 1] >=
                            min_dim[1]) & (output_locs[:, 2] >= min_dim[2]) & (
                                output_locs[:, 0] < max_dim[0]) & (
                                    output_locs[:, 1] < max_dim[1]) & (
                                        output_locs[:, 2] < max_dim[2])
                    else:
                        outmask = (
                            output_locs[:, 0] < output_sdfs.shape[2]) & (
                                output_locs[:, 1] < output_sdfs.shape[3]) & (
                                    output_locs[:, 2] < output_sdfs.shape[4])
                    output_locs = output_locs[outmask]
                    output_sdf = [
                        output_sdf[0][outmask], output_sdf[1][outmask]
                    ]
                    if output_color is not None:
                        output_color = [
                            output_color[0][outmask], output_color[1][outmask]
                        ]
                        output_color = (output_color[1] + 1) * 0.5

                        output_colors[
                            0, output_locs[:, 0], output_locs[:, 1],
                            output_locs[:,
                                        2], :] += output_color.detach().cpu()
                    if output_occ is not None:
                        output_occs[:, :, :chunk_dim[0], y:y + chunk_dim[1],
                                    x:x + chunk_dim[2]] = occ[:, :, :fill_dim[
                                        0], :fill_dim[1], :fill_dim[2]]

                    output_sdfs[
                        0, 0, output_locs[:, 0], output_locs[:, 1],
                        output_locs[:, 2]] += output_sdf[1][:,
                                                            0].detach().cpu()
                    output_norms[0, 0, output_locs[:, 0], output_locs[:, 1],
                                 output_locs[:, 2]] += 1

            # normalize
            mask = output_norms > 0
            output_norms = output_norms[mask]
            output_sdfs[mask] = output_sdfs[mask] / output_norms
            output_sdfs[~mask] = -float('inf')
            mask = mask.view(1, mask.shape[2], mask.shape[3], mask.shape[4])
            output_colors[
                mask, :] = output_colors[mask, :] / output_norms.view(-1, 1)
            output_colors = torch.clamp(output_colors * 255, 0, 255)

            sdfs = torch.clamp(sdfs, -args.truncation, args.truncation)
            output_sdfs = torch.clamp(output_sdfs, -args.truncation,
                                      args.truncation)

            if num_vis < num_to_vis:
                inputs = inputs.cpu().numpy()
                locs = torch.nonzero(
                    torch.abs(output_sdfs[0, 0]) < args.truncation)
                vis_pred_sdf = [None]
                vis_pred_color = [None]
                sdf_vals = output_sdfs[0, 0, locs[:, 0], locs[:, 1],
                                       locs[:, 2]].view(-1)
                vis_pred_sdf[0] = [locs.cpu().numpy(), sdf_vals.cpu().numpy()]
                if args.weight_color_loss > 0:
                    vals = output_colors[0, locs[:, 0], locs[:, 1], locs[:, 2]]
                    vis_pred_color[0] = vals.cpu().numpy()
                if output_occs is not None:
                    pred_occ = output_occs.cpu().numpy().astype(np.float32)
                data_util.save_predictions(output_vis,
                                           np.arange(1),
                                           sample['name'],
                                           inputs,
                                           sdfs.cpu().numpy(),
                                           colors.cpu().numpy(),
                                           None,
                                           None,
                                           vis_pred_sdf,
                                           vis_pred_color,
                                           None,
                                           None,
                                           sample['world2grid'],
                                           args.truncation,
                                           args.color_space,
                                           pred_occ=pred_occ)
                num_vis += 1
            num_proc += 1
            gc.collect()

    sys.stdout.write('\n')
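
The stitching above accumulates chunk predictions into output_sdfs while output_norms counts how many chunks wrote to each voxel; the normalize step then divides the sums by the counts. A minimal standalone sketch of that idea (names are illustrative):

import torch

def average_overlapping(pred_sum, counts, empty_val=-float('inf')):
    """Mean of overlapping chunk predictions; untouched voxels get empty_val."""
    out = torch.full_like(pred_sum, empty_val)
    written = counts > 0
    out[written] = pred_sum[written] / counts[written]
    return out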
Example 4
def test(epoch, iter, loss_weights, dataloader, log_file, output_save):
    val_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    val_l1preds = []
    val_l1tgts = []
    val_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.eval()
    #start = time.time()

    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
            if args.use_loss_masking:
                known = known.cuda()
            inputs[0] = inputs[0].cuda()
            inputs[1] = inputs[1].cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs.cuda(), hierarchy, args.num_hierarchy_levels,
                args.truncation, args.use_loss_masking, known)

            output_sdf, output_occs = model(inputs, loss_weights)
            loss, losses = loss_util.compute_loss(
                output_sdf, output_occs, target_for_sdf, target_for_occs,
                target_for_hier, loss_weights, args.truncation,
                args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
                args.use_loss_masking, known)

            output_visual = output_save and t + 2 == num_batches
            compute_pred_occs = (t % 20 == 0) or output_visual
            if compute_pred_occs:
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    factor = 2**(args.num_hierarchy_levels - h - 1)
                    pred_occs[h] = [None] * args.batch_size
                    if len(output_occs[h][0]) == 0:
                        continue
                    for b in range(args.batch_size):
                        batchmask = output_occs[h][0][:, -1] == b
                        locs = output_occs[h][0][batchmask][:, :-1]
                        vals = torch.nn.Sigmoid()(
                            output_occs[h][1][:, 0].detach()[batchmask]) > 0.5
                        pred_occs[h][b] = locs[vals.view(-1)]
            val_losses[0].append(loss.item())
            for h in range(args.num_hierarchy_levels):
                val_losses[h + 1].append(losses[h])
                target = target_for_occs[h].byte()
                if compute_pred_occs:
                    iou = loss_util.compute_iou_sparse_dense(
                        pred_occs[h], target, args.use_loss_masking)
                    val_ious[h].append(iou)
            val_losses[args.num_hierarchy_levels + 1].append(losses[-1])
            if len(output_sdf[0]) > 0:
                output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
            if loss_weights[-1] > 0 and t % 20 == 0:
                val_l1preds.append(
                    loss_util.compute_l1_predsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf, None,
                        False, args.use_loss_masking, known).item())
                val_l1tgts.append(
                    loss_util.compute_l1_tgtsurf_sparse_dense(
                        output_sdf[0], output_sdf[1], target_for_sdf,
                        args.truncation, args.use_loss_masking, known))
            if output_visual:
                vis_pred_sdf = [None] * args.batch_size
                if len(output_sdf[0]) > 0:
                    for b in range(args.batch_size):
                        mask = output_sdf[0][:, -1] == b
                        if mask.any():
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
                for h in range(args.num_hierarchy_levels):
                    for b in range(args.batch_size):
                        if pred_occs[h][b] is not None:
                            pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
                data_util.save_predictions(
                    os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch),
                                 'val'), sample['name'], inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy()
                     for x in target_for_occs], vis_pred_sdf, pred_occs,
                    sample['world2grid'], args.vis_dfs, args.truncation)

    #took = time.time() - start
    return val_losses, val_l1preds, val_l1tgts, val_ious
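
compute_iou_sparse_dense is not shown; from the call site it compares sparse predicted voxel coordinates against a dense occupancy target. A hypothetical per-sample version over a 3-D grid, ignoring the loss-masking flag:

import torch

def iou_sparse_dense(pred_locs, target):
    """IoU between (N, 3) predicted coordinates and a dense 0/1 grid."""
    pred = torch.zeros_like(target, dtype=torch.bool)
    pred[pred_locs[:, 0], pred_locs[:, 1], pred_locs[:, 2]] = True
    tgt = target.bool()
    union = (pred | tgt).sum().item()
    return (pred & tgt).sum().item() / union if union > 0 else 0.0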
Example 5
def train(epoch, iter, dataloader, log_file, output_save):
    train_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    train_l1preds = []
    train_l1tgts = []
    train_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.train()
    start = time.time()

    if args.scheduler_step_size == 0:
        scheduler.step()

    num_batches = len(dataloader)
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(iter, args.num_hierarchy_levels,
                                        args.num_iters_per_level,
                                        args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('[iter %d/epoch %d] loss_weights' % (iter, epoch),
                  loss_weights)

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if args.use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, args.num_hierarchy_levels, args.truncation,
            args.use_loss_masking, known)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(
            output_sdf, output_occs, target_for_sdf, target_for_occs,
            target_for_hier, loss_weights, args.truncation,
            args.logweight_target_sdf, args.weight_missing_geo, inputs[0],
            args.use_loss_masking, known)
        loss.backward()
        optimizer.step()

        output_visual = output_save and t + 2 == num_batches
        compute_pred_occs = (iter % 20 == 0) or output_visual
        if compute_pred_occs:
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                factor = 2**(args.num_hierarchy_levels - h - 1)
                pred_occs[h] = [None] * args.batch_size
                if len(output_occs[h][0]) == 0:
                    continue
                output_occs[h][1] = torch.nn.Sigmoid()(
                    output_occs[h][1][:, 0].detach()) > 0.5
                for b in range(args.batch_size):
                    batchmask = output_occs[h][0][:, -1] == b
                    locs = output_occs[h][0][batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    pred_occs[h][b] = locs[vals.view(-1)]
        train_losses[0].append(loss.item())
        for h in range(args.num_hierarchy_levels):
            train_losses[h + 1].append(losses[h])
            target = target_for_occs[h].byte()
            if compute_pred_occs:
                iou = loss_util.compute_iou_sparse_dense(
                    pred_occs[h], target, args.use_loss_masking)
                train_ious[h].append(iou)
        train_losses[args.num_hierarchy_levels + 1].append(losses[-1])
        if len(output_sdf[0]) > 0:
            output_sdf = [output_sdf[0].detach(), output_sdf[1].detach()]
        if loss_weights[-1] > 0 and iter % 20 == 0:
            train_l1preds.append(
                loss_util.compute_l1_predsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf, None, False,
                    args.use_loss_masking, known).item())
            train_l1tgts.append(
                loss_util.compute_l1_tgtsurf_sparse_dense(
                    output_sdf[0], output_sdf[1], target_for_sdf,
                    args.truncation, args.use_loss_masking, known))

        iter += 1
        if args.scheduler_step_size > 0 and iter % args.scheduler_step_size == 0:
            scheduler.step()
        if iter % 20 == 0:
            took = time.time() - start
            print_log(log_file, epoch, iter, train_losses, train_l1preds,
                      train_l1tgts, train_ious, None, None, None, None, took)
        if iter % 2000 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                os.path.join(args.save,
                             'model-iter%s-epoch%s.pth' % (iter, epoch)))
        if output_visual:
            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    if mask.any():
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]
            inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]
            for h in range(args.num_hierarchy_levels):
                for b in range(args.batch_size):
                    if pred_occs[h][b] is not None:
                        pred_occs[h][b] = pred_occs[h][b].cpu().numpy()
            data_util.save_predictions(
                os.path.join(args.save, 'iter%d-epoch%d' % (iter, epoch),
                             'train'), sample['name'], inputs,
                target_for_sdf.cpu().numpy(),
                [x.cpu().numpy()
                 for x in target_for_occs], vis_pred_sdf, pred_occs,
                sample['world2grid'].numpy(), args.vis_dfs, args.truncation)

    return train_losses, train_l1preds, train_l1tgts, train_ious, iter, loss_weights
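
How these train/test functions are typically wired together across epochs, judging from their signatures (args.max_epoch and the dataloader names are assumptions):

iter_count = 0
for epoch in range(args.start_epoch, args.max_epoch):
    (train_losses, train_l1preds, train_l1tgts, train_ious,
     iter_count, loss_weights) = train(epoch, iter_count, train_dataloader,
                                       log_file, output_save=(epoch % 10 == 0))
    val_losses, val_l1preds, val_l1tgts, val_ious = test(
        epoch, iter_count, loss_weights, val_dataloader, log_file,
        output_save=(epoch % 10 == 0))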
Example 6
def test(loss_weights, dataloader, output_vis, num_to_vis):
    model.eval()

    num_vis = 0
    num_batches = len(dataloader)
    print("num_batches: ", num_batches)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            print("input_dim: ", np.array(sample['sdf'].shape))
            input_dim = np.array(sample['sdf'].shape[2:])

            print(input_dim)
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d)    ' %
                             (num_vis, num_to_vis, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()

            # hierarchy_factor = pow(2, args.num_hierarchy_levels-1)
            # model.update_sizes(input_dim, input_dim // hierarchy_factor)

            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size for training
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            if args.use_loss_masking:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs,
                hierarchy,
                args.num_hierarchy_levels,
                args.truncation,
                args.use_loss_masking,
                known,
                flipped=args.flipped)

            start_time = time.time()
            try:
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                continue

            end_time = time.time()
            print('TIME: %.4fs' % (end_time - start_time))

            # remove padding
            # dims = sample['orig_dims'][0]
            # mask = (output_sdf[0][:,0] < dims[0]) & (output_sdf[0][:,1] < dims[1]) & (output_sdf[0][:,2] < dims[2])
            # output_sdf[0] = output_sdf[0][mask]
            # output_sdf[1] = output_sdf[1][mask]
            # mask = (inputs[0][:,0] < dims[0]) & (inputs[0][:,1] < dims[1]) & (inputs[0][:,2] < dims[2])
            # inputs[0] = inputs[0][mask]
            # inputs[1] = inputs[1][mask]
            # # vis_pred_sdf = [None]
            # # if len(output_sdf[0]) > 0:
            # #     vis_pred_sdf[0] = [output_sdf[0].cpu().numpy(), output_sdf[1].squeeze().cpu().numpy()]
            # inputs = [inputs[0].cpu().numpy(), inputs[1].cpu().numpy()]

            # vis occ & sdf
            pred_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                factor = 2**(args.num_hierarchy_levels - h - 1)
                pred_occs[h] = [None] * args.batch_size
                if len(output_occs[h][0]) == 0:
                    continue
                # filter: occ > 0
                for b in range(args.batch_size):
                    batchmask = output_occs[h][0][:, -1] == b
                    locs = output_occs[h][0][batchmask][:, :-1]
                    vals = output_occs[h][1][batchmask]
                    occ_mask = torch.nn.Sigmoid()(vals[:, 0].detach()) > 0.5
                    if args.flipped:
                        pred_occs[h][b] = [
                            locs[occ_mask.view(-1)].cpu().numpy(),
                            vals[occ_mask.view(-1)].cpu().numpy()
                        ]
                    else:
                        pred_occs[h][b] = locs[occ_mask.view(-1)].cpu().numpy()

            vis_pred_sdf = [None] * args.batch_size
            if len(output_sdf[0]) > 0:
                for b in range(args.batch_size):
                    mask = output_sdf[0][:, -1] == b
                    if mask.any():
                        vis_pred_sdf[b] = [
                            output_sdf[0][mask].cpu().numpy(),
                            output_sdf[1][mask].squeeze().cpu().numpy()
                        ]

            data_util.save_predictions(output_vis,
                                       sample['name'],
                                       inputs,
                                       target_for_sdf.cpu().numpy(),
                                       None,
                                       vis_pred_sdf,
                                       pred_occs,
                                       sample['world2grid'],
                                       args.truncation,
                                       flipped=args.flipped)
            num_vis += 1
            if num_vis >= num_to_vis:
                break
    sys.stdout.write('\n')
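
Several of these snippets split sparse outputs per sample with the same pattern: the last column of the location tensor carries the batch index, and a boolean mask recovers each sample's points. A minimal illustration with made-up tensors:

import torch

locs = torch.tensor([[1, 2, 3, 0],  # (x, y, z, batch index)
                     [4, 5, 6, 1],
                     [7, 8, 9, 0]])
vals = torch.randn(3, 1)
for b in range(2):
    m = locs[:, -1] == b
    coords, values = locs[m][:, :-1], vals[m]  # per-sample sparse output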
Example 7
    def vis(self):
        count = 0
        with torch.no_grad():
            for t, sample in enumerate(self.val_dataloader):
                if rospy.is_shutdown():
                    exit()
                if count > self.max_num:
                    return
                else:
                    count += 1

                print("name: ", sample['name'])
                sdfs = sample['sdf']
                inputs = sample['input']
                known = sample['known']
                hierarchy = sample['hierarchy']
                for h in range(len(hierarchy)):
                    hierarchy[h] = hierarchy[h].cuda()
                    hierarchy[h].unsqueeze_(1)
                known = known.cuda()
                inputs = inputs.cuda()
                sdfs = sdfs.cuda()
                predz = int(np.floor(inputs.shape[-1] / 8) * 8)
                # embed()
                inputs = inputs[:, :, :, :predz]
                sdfs = sdfs[:, :, :, :, :predz]
                # embed()
                target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                    sdfs, hierarchy, self.num_hierarchy_levels, 2.0, True,
                    known, True)
                # embed()

                output_sdf, output_occs = self.model(inputs, self.weights)
                # bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_dense(output_sdf, output_occs, target_for_sdf, target_for_occs,
                #     target_for_hier, self.weights, 3.0, True, 3.0, inputs,
                #     True, known, flipped=True)
                # loss = bce_loss + sdf_loss

                # val_bceloss.append(last_loss)
                # val_sdfloss.append(sdf_loss)
                # val_losses[0].append(loss)

                # output_sdf[known_mask] = inputs[known_mask]

                if output_occs is not None:
                    start_time = time.time()
                    # pred_occ = output_sdf[:,0].squeeze()
                    pred_occ = output_occs[-1][:, 0].squeeze()
                    pred_occ = self.sigmoid_func(pred_occ) > 0.45
                    # pred_occ = output_sdf[:,1].squeeze()
                    # pred_occ = (pred_occ).abs() < 2
                    target_occ = target_for_occs[-1].squeeze() > 0
                    input_occ = inputs.squeeze() > 0
                    # embed()
                    # precision, recall, f1score = loss_util.compute_dense_occ_accuracy(input_occ, target_occ, pred_occ, truncation=3)

                    dimx = pred_occ.shape[-3]
                    dimy = pred_occ.shape[-2]
                    dimz = pred_occ.shape[-1]
                    # if not self.size_init:
                    self.init_size(dimx, dimy, dimz)

                    known_mask = (input_occ >= 0)
                    # fix conflicts with input in known grids
                    # pred_occ[known_mask] = input_occ[known_mask]

                    # num = (input_occ | pred_occ | target_occ).float().sum().cpu().numpy().tolist()
                    # print("total size: ", num)

                    # color coding: input (red), prediction (blue), target (green)
                    # vis pub
                    # embed()
                    input_occ_coords = self.pointers[input_occ]
                    input_colors = torch.FloatTensor([255, 50, 50]).repeat(
                        [input_occ_coords.shape[0], 1])
                    input_points = torch.cat([input_occ_coords, input_colors],
                                             dim=1)
                    input_points[:, :3] *= self.voxel_size
                    # print("input size: ", input_points.shape)

                    # blue: voxels predicted but not present in the input
                    pred_occ_coords = self.pointers[pred_occ & (~input_occ)]
                    pred_colors = torch.FloatTensor([50, 50, 255]).repeat(
                        [pred_occ_coords.shape[0], 1])
                    pred_points = torch.cat([pred_occ_coords, pred_colors],
                                            dim=1)
                    pred_points[:, :3] *= self.voxel_size
                    # print("pred size: ", pred_points.shape)

                    target_occ_coords = self.pointers[
                        target_occ & (~input_occ)]  # &(~pred_occ)
                    target_colors = torch.FloatTensor([50, 255, 50]).repeat(
                        [target_occ_coords.shape[0], 1])
                    target_points = torch.cat(
                        [target_occ_coords, target_colors], dim=1)
                    target_points[:, :3] *= self.voxel_size

                    target_colors_blue = torch.FloatTensor(
                        [255, 50, 50]).repeat([target_occ_coords.shape[0], 1])
                    target_points_blue = torch.cat(
                        [target_occ_coords, target_colors_blue], dim=1)
                    target_points_blue[:, :3] *= self.voxel_size
                    # print("target size: ", target_points.shape)

                    # points = torch.cat([input_points, pred_points, target_points], dim=0)
                    inout_points = torch.cat([input_points, target_points],
                                             dim=0)
                    inout_points_blue = torch.cat(
                        [input_points, target_points_blue], dim=0)
                    inpred_points = torch.cat([input_points, pred_points],
                                              dim=0)

                    #
                    msg1 = PointCloud2()
                    msg1.header.frame_id = "map"
                    msg1.height = 1
                    msg1.width = input_points.shape[0]
                    msg1.fields = [
                        PointField('x', 0, PointField.FLOAT32, 1),
                        PointField('y', 4, PointField.FLOAT32, 1),
                        PointField('z', 8, PointField.FLOAT32, 1),
                        PointField('r', 12, PointField.FLOAT32, 1),
                        PointField('g', 16, PointField.FLOAT32, 1),
                        PointField('b', 20, PointField.FLOAT32, 1)
                    ]
                    msg1.is_bigendian = False
                    msg1.point_step = 24  # six float32 fields = 24 bytes per point
                    msg1.row_step = 24 * msg1.width
                    msg1.is_dense = False
                    msg1.data = input_points.reshape(
                        [-1]).cpu().numpy().tobytes()
                    #
                    msg2 = deepcopy(msg1)
                    msg2.width = inout_points.shape[0]
                    msg2.row_step = 24 * msg2.width
                    msg2.data = inout_points.reshape(
                        [-1]).cpu().numpy().tobytes()

                    #
                    msg0 = deepcopy(msg1)
                    msg0.width = inout_points_blue.shape[0]
                    msg0.row_step = 24 * msg0.width
                    msg0.data = inout_points_blue.reshape(
                        [-1]).cpu().numpy().tobytes()

                    #
                    msg3 = deepcopy(msg1)
                    msg3.width = inpred_points.shape[0]
                    msg3.row_step = 24 * msg3.width
                    msg3.data = inpred_points.reshape(
                        [-1]).cpu().numpy().tobytes()

                    print("post process time: %fs" %
                          (time.time() - start_time))

                    self.output_pub.publish(msg0)
                    # rospy.sleep(2)
                    raw_input('press key to continue')

                    self.output_pub.publish(msg1)
                    # rospy.sleep(2)
                    raw_input('press key to continue')
                    self.output_pub.publish(msg2)
                    # rospy.sleep(2)
                    raw_input('press key to continue')

                    self.output_pub.publish(msg3)
                    # rospy.sleep(2)
                    raw_input('press key to continue')
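
The four messages above differ only in their point buffers, so the construction can be factored out. A hypothetical helper under the same layout (six float32 fields, 24-byte points):

from sensor_msgs.msg import PointCloud2, PointField

def make_cloud_msg(points, frame_id="map"):
    """Pack an (N, 6) tensor of (x, y, z, r, g, b) floats into a PointCloud2."""
    msg = PointCloud2()
    msg.header.frame_id = frame_id
    msg.height = 1
    msg.width = points.shape[0]
    msg.fields = [PointField(name, 4 * i, PointField.FLOAT32, 1)
                  for i, name in enumerate(('x', 'y', 'z', 'r', 'g', 'b'))]
    msg.is_bigendian = False
    msg.point_step = 24
    msg.row_step = msg.point_step * msg.width
    msg.is_dense = False
    msg.data = points.reshape([-1]).cpu().numpy().tobytes()
    return msg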
Example 8
def test(loss_weights, dataloader, output_vis, num_to_vis):
    model.eval()
    val_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    val_bceloss = []
    val_sdfloss = []
    val_precision = []
    val_recall = []
    val_f1score = []
    val_ious = []
    val_time = []
    num_vis = 0
    num_batches = len(dataloader)
    print("num_batches: ", num_batches)

    sigmoid_func = torch.nn.Sigmoid()
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            # print("sample = ",sample)
            sdfs = sample['sdf']
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            # print("check  hie=",hierarchy)
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            # loss masking is always applied here, regardless of args.use_loss_masking:
            if True:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs,
                hierarchy,
                args.num_hierarchy_levels,
                args.truncation,
                True,
                known,
                flipped=args.flipped)
            # target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(sdfs, hierarchy,
            #     1, args.truncation, True, known, flipped=args.flipped)

            start_time = time.time()
            output_sdf, output_occs = model(inputs, loss_weights)
            known_mask = inputs > -1

            val_time.append(time.time() - start_time)
            # loss, losses = loss_util.compute_loss(output_sdf, output_occs, target_for_sdf, target_for_occs, target_for_hier, loss_weights, args.truncation, args.logweight_target_sdf, args.weight_missing_geo, inputs[0], args.use_loss_masking, known)
            bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_nosdf(
                output_sdf,
                output_occs,
                target_for_sdf,
                target_for_occs,
                target_for_hier,
                loss_weights,
                args.truncation,
                args.logweight_target_sdf,
                args.weight_missing_geo,
                inputs,
                args.use_loss_masking,
                known,
                flipped=args.flipped)
            loss = bce_loss * 2  # + sdf_loss

            val_bceloss.append(last_loss)
            # val_sdfloss.append(sdf_loss)
            val_losses[0].append(loss.item())

            # output_sdf[known_mask] = inputs[known_mask]

            if output_sdf is not None:
                pred_occ = output_occs[-1][:, 0].squeeze()
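                # caution: sigmoid output is strictly positive, so the > 0.0 threshold below marks every voxel occupied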
                pred_occ = sigmoid_func(pred_occ) > 0.0
                # pred_occ = output_sdf[:,1].squeeze()
                # pred_occ = (pred_occ).abs() < 2
                target_occ = target_for_occs[-1].squeeze()
                input_occ = inputs.squeeze()
                precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                    input_occ,
                    target_occ,
                    pred_occ,
                    truncation=args.truncation)
            else:
                precision, recall, f1score = 0, 0, 0

            val_precision.append(precision)
            val_recall.append(recall)
            val_f1score.append(f1score)
            val_losses[args.num_hierarchy_levels + 1].append(losses[-1])

            if True:
                data_util.save_dense_predictions(args.output,
                                                 sample['name'],
                                                 input_occ.cpu().numpy(),
                                                 target_occ.cpu().numpy(),
                                                 pred_occ.cpu().numpy(),
                                                 args.truncation,
                                                 flipped=args.flipped)
                print("saved: ", sample['name'])

    print("epoch_loss/total: ", np.mean(val_losses[0]))
    print("epoch_loss/bce: ", np.mean(val_bceloss))
    print("epoch_loss/precision: ", np.mean(val_precision))
    print("epoch_loss/recall: ", np.mean(val_recall))
    print("epoch_loss/f1: ", np.mean(val_f1score))
    print("average_time: ", np.mean(val_time))

    return
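
compute_dense_occ_accuracy is not shown; a hypothetical version of the dense occupancy precision/recall/F1 it reports, with pred and target as boolean grids of the same shape (the input grid and truncation handling are omitted):

def dense_occ_accuracy(pred, target):
    tp = (pred & target).sum().item()
    precision = tp / max(pred.sum().item(), 1)
    recall = tp / max(target.sum().item(), 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    return precision, recall, f1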
Example 9
    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        #x, y = batch
        #x = x.view(x.size(0), -1)
        #y_hat = self(x)

        #loss_val = self.loss(y, y_hat)

        ## acc
        #labels_hat = torch.argmax(y_hat, dim=1)
        #val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        #val_acc = torch.tensor(val_acc)

        #if self.on_gpu:
            #val_acc = val_acc.cuda(loss_val.device.index)

        #output = OrderedDict({
            #'val_loss': loss_val,
            #'val_acc': val_acc,
        #})

        ## can also return just a scalar instead of a dict (return loss_val)
        #return output

        batch_size = self.hparams.batch_size
        num_hierarchy_levels = self.hparams.num_hierarchy_levels
        truncation = self.hparams.truncation
        use_loss_masking = self.hparams.use_loss_masking
        logweight_target_sdf = self.hparams.logweight_target_sdf
        weight_missing_geo = self.hparams.weight_missing_geo

        sample = batch

        sdfs = sample['sdf']
        # TODO: restore the batch-size guard; the original training-loop
        # `continue` has no equivalent inside a Lightning step:
        #if sdfs.shape[0] < batch_size:
        #    continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, num_hierarchy_levels, truncation,
            use_loss_masking, known)

        # update loss weights for the current iteration
        _iter = self._iter_counter
        loss_weights = get_loss_weights(_iter,
                                        self.hparams.num_hierarchy_levels,
                                        self.hparams.num_iters_per_level,
                                        self.hparams.weight_sdf_loss)

        output_sdf, output_occs = self(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(output_sdf, output_occs, target_for_sdf, target_for_occs, target_for_hier, loss_weights, truncation,
                                              logweight_target_sdf, weight_missing_geo, inputs[0], use_loss_masking, known)


        output = OrderedDict({
            'val_loss': loss,
        })
        return output
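
In the older Lightning API these dict-returning steps pair with an epoch-end hook that averages the collected values; a sketch of such a companion method (an assumption, not shown in the source):

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}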
Example 10
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        ## forward pass
        #x, y = batch
        #x = x.view(x.size(0), -1)

        #y_hat = self(x)

        ## calculate loss
        #loss_val = self.loss(y, y_hat)

        #tqdm_dict = {'train_loss': loss_val}
        #output = OrderedDict({
            #'loss': loss_val,
            #'progress_bar': tqdm_dict,
            #'log': tqdm_dict
        #})

        ## can also return just a scalar instead of a dict (return loss_val)
        #return output

        batch_size = self.hparams.batch_size
        num_hierarchy_levels = self.hparams.num_hierarchy_levels
        truncation = self.hparams.truncation
        use_loss_masking = self.hparams.use_loss_masking
        logweight_target_sdf = self.hparams.logweight_target_sdf
        weight_missing_geo = self.hparams.weight_missing_geo

        sample = batch

        sdfs = sample['sdf']
        # TODO: restore the batch-size guard; the original training-loop
        # `continue` has no equivalent inside a Lightning step:
        #if sdfs.shape[0] < batch_size:
        #    continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
        if use_loss_masking:
            known = known.cuda()
        inputs[0] = inputs[0].cuda()
        inputs[1] = inputs[1].cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs.cuda(), hierarchy, num_hierarchy_levels, truncation,
            use_loss_masking, known)

        # update loss weights for the current iteration
        #loss_weights = self.model._loss_weights
        _iter = self._iter_counter
        loss_weights = get_loss_weights(_iter,
                                        self.hparams.num_hierarchy_levels,
                                        self.hparams.num_iters_per_level,
                                        self.hparams.weight_sdf_loss)


        output_sdf, output_occs = self(inputs, loss_weights)
        loss, losses = loss_util.compute_loss(output_sdf, output_occs, target_for_sdf, target_for_occs, target_for_hier, loss_weights, truncation,
                                              logweight_target_sdf, weight_missing_geo, inputs[0], use_loss_masking, known)

        tqdm_dict = {'train_loss': loss}
        output = OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

        self._iter_counter += 1

        return output
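
Hypothetical wiring for the module these steps belong to, matching the older dict-returning Lightning API used above (the class name is illustrative):

import pytorch_lightning as pl

model = CompletionModule(hparams)  # LightningModule defining the steps above
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(model)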
Example 11
def test(epoch, iter, loss_weights, dataloader, log_file, output_save):
    start_time = time.time()
    val_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    val_bceloss = []
    val_sdfloss = []
    val_precision = []
    val_recall = []
    val_f1score = []

    val_ious = [[] for i in range(args.num_hierarchy_levels)]
    model.eval()

    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            sdfs = sample['sdf']
            if sdfs.shape[0] < args.batch_size:
                continue  # maintain same batch size
            inputs = sample['input']
            known = sample['known']
            hierarchy = sample['hierarchy']
            for h in range(len(hierarchy)):
                hierarchy[h] = hierarchy[h].cuda()
                hierarchy[h].unsqueeze_(1)
            if args.use_loss_masking:
                known = known.cuda()
            inputs = inputs.cuda()
            sdfs = sdfs.cuda()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs,
                hierarchy,
                args.num_hierarchy_levels,
                args.truncation,
                args.use_loss_masking,
                known,
                flipped=args.flipped)

            output_sdf, output_occs = model(inputs, loss_weights)
            # loss, losses = loss_util.compute_loss(output_sdf, output_occs, target_for_sdf, target_for_occs, target_for_hier, loss_weights, args.truncation, args.logweight_target_sdf, args.weight_missing_geo, inputs[0], args.use_loss_masking, known)
            bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_dense(
                output_sdf,
                output_occs,
                target_for_sdf,
                target_for_occs,
                target_for_hier,
                loss_weights,
                args.truncation,
                args.logweight_target_sdf,
                args.weight_missing_geo,
                inputs,
                args.use_loss_masking,
                known,
                flipped=args.flipped)
            loss = bce_loss * 10 + sdf_loss

            val_losses[0].append(loss.item())
            if last_loss > 0:
                val_bceloss.append(last_loss)
            WRITER.add_scalar("loss/val", loss.item(), iter + t)
            output_visual = (output_save and t + 1 == num_batches
                             and output_sdf is not None)
            compute_pred_occs = (t % 20 == 0) or output_visual

            # if compute_pred_occs:
            pred_occs = [None] * args.num_hierarchy_levels
            vis_occs = [None] * args.num_hierarchy_levels
            for h in range(args.num_hierarchy_levels):
                factor = 2**(args.num_hierarchy_levels - h - 1)
                if isinstance(output_occs[h], list):
                    continue
                # output_occs[h][1] = torch.nn.Sigmoid()(output_occs[h][1][:,0].detach()) > 0.5
                output_occs[h] = output_occs[h].detach()
                vis_occs[h] = output_occs[h][:, 1, :, :, :].unsqueeze(1)

            for h in range(args.num_hierarchy_levels):
                val_losses[h + 1].append(losses[h])

            # if len(output_occs[-1]) is not 0:
            if output_sdf is not None:
                pred_occ = output_sdf[:, 0].squeeze()
                # pred_occ = pred_occ.abs() < (args.truncation / 2)
                pred_occ = pred_occ > 0.5
                target_occ = target_for_occs[-1].squeeze()
                input_occ = inputs.squeeze()
                precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                    input_occ,
                    target_occ,
                    pred_occ,
                    truncation=args.truncation)
            else:
                precision, recall, f1score = 0, 0, 0

            val_precision.append(precision)
            val_recall.append(recall)
            val_f1score.append(f1score)
            val_losses[args.num_hierarchy_levels + 1].append(losses[-1])

            if output_visual:
                data_util.save_dense_predictions(os.path.join(
                    args.save, 'iter%d-epoch%d' % (iter, epoch), 'val'),
                                                 sample['name'],
                                                 input_occ.cpu().numpy(),
                                                 target_occ.cpu().numpy(),
                                                 pred_occ.cpu().numpy(),
                                                 args.truncation,
                                                 flipped=args.flipped)

    WRITER.add_scalar("epoch_loss/val/total", np.mean(val_losses[0]), epoch)
    if len(val_bceloss) > 0:
        WRITER.add_scalar("epoch_loss/val/bce", np.mean(val_bceloss), epoch)

    WRITER.add_scalar("epoch_loss/val/precision", np.mean(val_precision),
                      epoch)
    WRITER.add_scalar("epoch_loss/val/recall", np.mean(val_recall), epoch)
    WRITER.add_scalar("epoch_loss/val/f1", np.mean(val_f1score), epoch)

    print(" test epoch {} used {} seconds".format(epoch,
                                                  time.time() - start_time))
    return val_losses, val_precision, val_recall, val_f1score
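
WRITER is used but never defined in these snippets; the add_scalar calls match a TensorBoard SummaryWriter created at module scope, e.g. (an assumption):

from torch.utils.tensorboard import SummaryWriter

WRITER = SummaryWriter(log_dir=args.save)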
Example 12
def train(epoch, iter, dataloader, log_file, output_save):
    start_time = time.time()

    train_losses = [[] for i in range(args.num_hierarchy_levels + 2)]
    train_bceloss = []
    train_sdfloss = []
    train_precision = []
    train_recall = []
    train_f1score = []

    model.train()
    start = time.time()

    num_batches = len(dataloader)
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(iter, args.num_hierarchy_levels,
                                        args.num_iters_per_level,
                                        args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('[iter %d/epoch %d] loss_weights' % (iter, epoch),
                  loss_weights)

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
            hierarchy[h].unsqueeze_(1)
        if args.use_loss_masking:
            known = known.cuda()
        inputs = inputs.cuda()
        sdfs = sdfs.cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs,
            hierarchy,
            args.num_hierarchy_levels,
            args.truncation,
            args.use_loss_masking,
            known,
            flipped=args.flipped)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        # loss, losses = loss_util.compute_loss(output_sdf, output_occs, target_for_sdf, target_for_occs, target_for_hier, loss_weights, args.truncation, args.logweight_target_sdf, args.weight_missing_geo, inputs[0], args.use_loss_masking, known)
        bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_dense(
            output_sdf,
            output_occs,
            target_for_sdf,
            target_for_occs,
            target_for_hier,
            loss_weights,
            args.truncation,
            args.logweight_target_sdf,
            args.weight_missing_geo,
            inputs,
            args.use_loss_masking,
            known,
            flipped=args.flipped)
        loss = bce_loss * 10 + sdf_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        train_losses[0].append(loss.item())
        if last_loss > 0:
            train_bceloss.append(last_loss)
        train_sdfloss.append(sdf_loss.item())

        output_visual = (output_save and t + 1 == num_batches
                         and output_sdf is not None)
        compute_pred_occs = (iter % 20 == 0) or output_visual

        # NOTE: the compute_pred_occs guard is disabled here, so the
        # visualization tensors are gathered on every iteration
        pred_occs = [None] * args.num_hierarchy_levels
        vis_occs = [None] * args.num_hierarchy_levels
        for h in range(args.num_hierarchy_levels):
            if isinstance(output_occs[h], list):
                continue
            output_occs[h] = output_occs[h].detach()
            vis_occs[h] = output_occs[h][:, 1, :, :, :].unsqueeze(1)

        for h in range(args.num_hierarchy_levels):
            train_losses[h + 1].append(losses[h])

        if output_sdf is not None:
            output_sdf = output_sdf.detach()  # detach() returns a new tensor
            # treat channel 0 of the dense output as occupancy and threshold it
            pred_occ = output_sdf[:, 0].squeeze()
            # (sdf variant: pred_occ = pred_occ.abs() < args.truncation / 2)
            pred_occ = pred_occ > 0.1

            target_occ = target_for_occs[-1].squeeze()
            input_occ = inputs.detach().squeeze()
            precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                input_occ, target_occ, pred_occ, truncation=args.truncation)
        else:
            precision, recall, f1score = 0, 0, 0

        train_precision.append(precision)
        train_recall.append(recall)
        train_f1score.append(f1score)

        train_losses[args.num_hierarchy_levels + 1].append(losses[-1])

        iter += 1
        WRITER.add_scalar("loss/train", loss.item(), iter)

        if iter % 20 == 0:
            took = time.time() - start
            print_log(log_file, epoch, iter, train_losses, None, None,
                      train_precision, None, None, None, None, took)
        if iter % 2000 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                os.path.join(args.save,
                             'model-iter%s-epoch%s.pth' % (iter, epoch)))

    # args.scheduler_step_size is used as a countdown: the LR scheduler only
    # starts stepping once it has been decremented to zero
    if args.scheduler_step_size == 0:
        scheduler.step()
    else:
        args.scheduler_step_size -= 1

    WRITER.add_scalar("epoch_loss/train/total", np.mean(train_losses[0]),
                      epoch)
    if len(train_bceloss) > 0:
        WRITER.add_scalar("epoch_loss/train/bce", np.mean(train_bceloss),
                          epoch)
    WRITER.add_scalar("epoch_loss/train/precision", np.mean(train_precision),
                      epoch)
    WRITER.add_scalar("epoch_loss/train/recall", np.mean(train_recall), epoch)
    WRITER.add_scalar("epoch_loss/train/f1", np.mean(train_f1score), epoch)

    print(" test epoch {} used {} seconds".format(epoch,
                                                  time.time() - start_time))

    return train_losses, train_precision, train_recall, train_f1score, iter, loss_weights
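The training loops above rely on get_loss_weights to fade the hierarchy levels in over the course of training; its body is not shown in this excerpt. A minimal sketch of such a schedule, assuming each level ramps linearly to 1 over num_iters_per_level iterations and the last slot carries the sdf-loss weight (so that loss_weights[-2] reaches 1 once the finest occupancy level is fully active), could look like this:

import numpy as np

def get_loss_weights_sketch(iteration, num_hierarchy_levels,
                            num_iters_per_level, weight_sdf_loss):
    # hypothetical schedule: level h starts ramping at h * num_iters_per_level
    weights = np.zeros(num_hierarchy_levels + 1, dtype=np.float32)
    for h in range(num_hierarchy_levels):
        start = h * num_iters_per_level
        weights[h] = np.clip((iteration - start) / num_iters_per_level, 0.0, 1.0)
    weights[-1] = weight_sdf_loss  # final slot weights the sdf regression loss
    return weights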
Example n. 13
def test(loss_weights, dataloader, output_pred, output_vis, num_to_vis):
    model.eval()
    missing = []

    num_proc = 0
    num_vis = 0
    num_batches = len(dataloader)
    with torch.no_grad():
        for t, sample in enumerate(dataloader):
            inputs = sample['input']
            sdfs = sample['sdf']
            known = sample['known']
            hierarchy = sample['hierarchy']
            input_dim = np.array(sdfs.shape[2:])
            sys.stdout.write('\r[ %d | %d ] %s (%d, %d, %d)    ' %
                             (num_proc, args.max_to_process, sample['name'],
                              input_dim[0], input_dim[1], input_dim[2]))
            sys.stdout.flush()
            target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
                sdfs, hierarchy, args.num_hierarchy_levels, args.truncation,
                args.use_loss_masking, known)
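            # test scenes come in varying sizes, so resize the model's
            # dense buffers to this sample's (padded) dimensions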
            hierarchy_factor = pow(2, args.num_hierarchy_levels - 1)
            model.update_sizes(input_dim, input_dim // hierarchy_factor)
            try:
                if not args.cpu:
                    inputs[1] = inputs[1].cuda()
                    target_for_sdf = target_for_sdf.cuda()
                    for h in range(len(target_for_occs)):
                        target_for_occs[h] = target_for_occs[h].cuda()
                        target_for_hier[h] = target_for_hier[h].cuda()
                output_sdf, output_occs = model(inputs, loss_weights)
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue
            # remove padding
            dims = sample['orig_dims'][0]
            mask = ((output_sdf[0][:, 0] < dims[0])
                    & (output_sdf[0][:, 1] < dims[1])
                    & (output_sdf[0][:, 2] < dims[2]))
            output_sdf[0] = output_sdf[0][mask]
            output_sdf[1] = output_sdf[1][mask]
            for h in range(len(output_occs)):
                dims = target_for_occs[h].shape[2:]
                mask = ((output_occs[h][0][:, 0] < dims[0])
                        & (output_occs[h][0][:, 1] < dims[1])
                        & (output_occs[h][0][:, 2] < dims[2]))
                output_occs[h][0] = output_occs[h][0][mask]
                output_occs[h][1] = output_occs[h][1][mask]

            # save prediction files
            data_util.save_predictions_to_file(
                output_sdf[0][:, :3].cpu().numpy(),
                output_sdf[1].cpu().numpy(),
                os.path.join(output_pred, sample['name'][0] + '.pred'))

            try:
                pred_occs = [None] * args.num_hierarchy_levels
                for h in range(args.num_hierarchy_levels):
                    pred_occs[h] = [None]
                    if len(output_occs[h][0]) == 0:
                        continue
                    locs = output_occs[h][0][:, :-1]
                    vals = torch.nn.Sigmoid()(
                        output_occs[h][1][:, 0].detach()) > 0.5
                    pred_occs[h][0] = locs[vals.view(-1)]
            except Exception:
                print('exception at %s' % sample['name'])
                gc.collect()
                missing.extend(sample['name'])
                continue

            num_proc += 1
            if num_vis < num_to_vis:
                num = min(num_to_vis - num_vis, 1)
                vis_pred_sdf = [None] * num
                if len(output_sdf[0]) > 0:
                    for b in range(num):
                        mask = output_sdf[0][:, -1] == b
                        if mask.any():
                            vis_pred_sdf[b] = [
                                output_sdf[0][mask].cpu().numpy(),
                                output_sdf[1][mask].squeeze().cpu().numpy()
                            ]
                inputs = [inputs[0].numpy(), inputs[1].cpu().numpy()]
                data_util.save_predictions(
                    output_vis, np.arange(num), inputs,
                    target_for_sdf.cpu().numpy(),
                    [x.cpu().numpy()
                     for x in target_for_occs], vis_pred_sdf, pred_occs,
                    sample['world2grid'], args.vis_dfs, args.truncation)
                num_vis += 1
            gc.collect()

    sys.stdout.write('\n')
    print('missing', missing)
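The padding-removal masks above repeat the same idiom for the sdf output and for every occupancy level: keep only the sparse coordinates that fall inside the original volume. Factored out, the pattern is roughly this (a sketch, assuming (N, 4) location tensors of x, y, z, batch and per-row value tensors):

def remove_padding(locs, vals, dims):
    # keep only entries whose x/y/z coordinates lie inside `dims`
    mask = (locs[:, 0] < dims[0]) & (locs[:, 1] < dims[1]) & (locs[:, 2] < dims[2])
    return locs[mask], vals[mask]

# e.g. output_sdf[0], output_sdf[1] = remove_padding(output_sdf[0], output_sdf[1], dims)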
Example n. 14
def train(epoch, iter, dataloader, output_save):
    global STEP_COUNTER
    start_time = time.time()

    train_losses = [[] for _ in range(args.num_hierarchy_levels + 2)]
    train_bceloss = []
    train_sdfloss = []
    train_precision = []
    train_recall = []
    train_f1score = []

    model.train()
    start = time.time()

    num_batches = len(dataloader)
    for t, sample in enumerate(dataloader):
        loss_weights = get_loss_weights(iter, args.num_hierarchy_levels,
                                        args.num_iters_per_level,
                                        args.weight_sdf_loss)
        if epoch == args.start_epoch and t == 0:
            print('-------------- epoch %d --------------' % epoch)

        sdfs = sample['sdf']
        if sdfs.shape[0] < args.batch_size:
            continue  # maintain same batch size for training
        inputs = sample['input']
        known = sample['known']
        hierarchy = sample['hierarchy']
        for h in range(len(hierarchy)):
            hierarchy[h] = hierarchy[h].cuda()
            hierarchy[h].unsqueeze_(1)
        if args.use_loss_masking:
            known = known.cuda()
        inputs = inputs.cuda()
        sdfs = sdfs.cuda()
        target_for_sdf, target_for_occs, target_for_hier = loss_util.compute_targets(
            sdfs,
            hierarchy,
            args.num_hierarchy_levels,
            args.truncation,
            args.use_loss_masking,
            known,
            flipped=args.flipped)

        optimizer.zero_grad()
        output_sdf, output_occs = model(inputs, loss_weights)
        # loss, losses = loss_util.compute_loss(output_sdf, output_occs, target_for_sdf, target_for_occs, target_for_hier, loss_weights, args.truncation, args.logweight_target_sdf, args.weight_missing_geo, inputs[0], args.use_loss_masking, known)
        bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_nosdf(
            output_sdf,
            output_occs,
            target_for_sdf,
            target_for_occs,
            target_for_hier,
            loss_weights,
            args.truncation,
            args.logweight_target_sdf,
            args.weight_missing_geo,
            inputs,
            args.use_loss_masking,
            known,
            flipped=args.flipped)
        # bce_loss, sdf_loss, losses, last_loss = loss_util.compute_loss_nosurf(output_sdf, output_occs, target_for_sdf, target_for_occs,
        #     target_for_hier, loss_weights, args.truncation, args.logweight_target_sdf, args.weight_missing_geo, inputs,
        #      args.use_loss_masking, known, flipped=args.flipped)
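        # occupancy-only objective: the sdf term is deliberately left out
        # and the BCE term is scaled by 2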
        loss = bce_loss * 2  # + sdf_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        train_losses[0].append(loss.item())
        if last_loss > 0:
            train_bceloss.append(last_loss)
        train_sdfloss.append(sdf_loss.item())

        # note: output_visual, compute_pred_occs, pred_occs and vis_occs are
        # computed but never used later in this example
        output_visual = (output_save and t + 1 == num_batches
                         and loss_weights[-2] >= 1)
        compute_pred_occs = (iter % 20 == 0) or output_visual

        pred_occs = [None] * args.num_hierarchy_levels
        vis_occs = [None] * args.num_hierarchy_levels

        for h in range(args.num_hierarchy_levels):
            train_losses[h + 1].append(losses[h])

        if loss_weights[-2] >= 1:  # finest occupancy level fully weighted
            output_occs[-1] = output_occs[-1].detach()  # detach() returns a new tensor
            # treat channel 0 of the final occupancy output as logits and threshold
            pred_occ = output_occs[-1][:, 0].squeeze()
            # (sdf variant: pred_occ = pred_occ.abs() < args.truncation / 2)
            pred_occ = pred_occ > 0

            target_occ = target_for_occs[-1].squeeze()
            input_occ = inputs.detach().squeeze()
            precision, recall, f1score = loss_util.compute_dense_occ_accuracy(
                input_occ, target_occ, pred_occ, truncation=args.truncation)
        else:
            precision, recall, f1score = 0, 0, 0

        train_precision.append(precision)
        train_recall.append(recall)
        train_f1score.append(f1score)
        train_losses[args.num_hierarchy_levels + 1].append(losses[-1])

        STEP_COUNTER += 1
        iter += 1
        WRITER.add_scalar("loss/train", loss.item(), STEP_COUNTER)

    if args.scheduler_step_size == 0:
        scheduler.step()
    else:
        args.scheduler_step_size -= 1

    WRITER.add_scalar("epoch_loss/train/total", np.mean(train_losses[0]),
                      epoch)
    if len(train_bceloss) > 0:
        WRITER.add_scalar("epoch_loss/train/bce", np.mean(train_bceloss),
                          epoch)
    WRITER.add_scalar("epoch_loss/train/precision", np.mean(train_precision),
                      epoch)
    WRITER.add_scalar("epoch_loss/train/recall", np.mean(train_recall), epoch)
    WRITER.add_scalar("epoch_loss/train/f1", np.mean(train_f1score), epoch)

    print(" test epoch {} used {} seconds".format(epoch,
                                                  time.time() - start_time))

    return loss_weights
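Both train variants score their thresholded occupancies with loss_util.compute_dense_occ_accuracy, whose body is not shown here. A plausible sketch of the precision/recall/F1 part, assuming boolean occupancy grid tensors, is:

def occ_precision_recall_f1(pred_occ, target_occ):
    # hypothetical re-implementation: counts over boolean occupancy grids
    pred, target = pred_occ.bool(), target_occ.bool()
    tp = (pred & target).sum().item()
    fp = (pred & ~target).sum().item()
    fn = (~pred & target).sum().item()
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1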