Example #1
0
    def test_gpu(self):
        """Run a strided convolution followed by a transposed convolution on
        CUDA, then gradcheck ``MinkowskiConvolutionTransposeFunction``.

        Reported as skipped (not silently passed) when CUDA is unavailable.
        """
        print(f"{self.__class__.__name__}: test_gpu")
        if not torch.cuda.is_available():
            # skipTest makes the missing GPU coverage visible in the report
            # instead of counting this test as a pass.
            self.skipTest("CUDA is not available")

        device = torch.device('cuda')
        in_channels, out_channels, D = 2, 3, 2
        coords, feats, labels = data_loader(in_channels)
        feats = feats.double()
        feats.requires_grad_()
        input = SparseTensor(feats, coords=coords).to(device)
        # Initialize context
        conv = MinkowskiConvolution(in_channels,
                                    out_channels,
                                    kernel_size=3,
                                    stride=2,
                                    has_bias=True,
                                    dimension=D).to(device)
        conv = conv.double()
        conv_tr = MinkowskiConvolutionTranspose(out_channels,
                                                in_channels,
                                                kernel_size=3,
                                                stride=2,
                                                has_bias=True,
                                                dimension=D).to(device)
        conv_tr = conv_tr.double()
        # `input` is rebound to the strided conv output; the gradcheck below
        # runs the transposed conv on that downsampled tensor.
        input = conv(input)
        output = conv_tr(input)
        print(output)

        # Check backward
        fn = MinkowskiConvolutionTransposeFunction()

        self.assertTrue(
            gradcheck(fn,
                      (input.F, conv_tr.kernel, input.tensor_stride,
                       conv_tr.stride, conv_tr.kernel_size, conv_tr.dilation,
                       conv_tr.region_type_, conv_tr.region_offset_, False,
                       input.coords_key, None, input.coords_man)))
Example #2
0
  def forward(self, x):
    """Run mean-field inference on the features of sparse tensor ``x``.

    Each of ``self.meanfield_iterations`` steps normalizes the current
    estimate with ``self.softmaxes[step]``, applies the pairwise potential
    ``self.convs[step]`` (configured from ``self.conv``), and adds the unary
    term back. When ``self.requires_mapping`` is set, the features are mapped
    in with ``self.in_mapping`` before inference and out with
    ``self.out_mapping`` afterwards.

    Returns:
      A ``SparseTensor`` that reuses ``x``'s coordinate key and manager.
    """
    unary = x.F
    if self.requires_mapping:
      # Bring the network output into the CRF input space
      unary = SparseMM()(Variable(self.in_mapping), unary)

    estimate = unary
    for step in range(self.meanfield_iterations):
      estimate = self.softmaxes[step](estimate)  # normalization
      pairwise = self.convs[step]
      kernel_cfg = self.conv
      estimate = pairwise.apply(
          estimate, kernel_cfg.kernel, x.pixel_dist, kernel_cfg.stride,
          kernel_cfg.kernel_size, kernel_cfg.dilation, self.region_type_,
          self.region_offset_, x.coords_key, x.coords_key, x.C)
      estimate += unary  # re-add the unary potential

    if self.requires_mapping:
      # Bring the CRF output back to the original space
      estimate = SparseMM()(Variable(self.out_mapping), estimate)

    return SparseTensor(estimate, coords_key=x.coords_key, coords_manager=x.C)
    def test_decomposition_gpu(self):
        """Decompose a GPU sparse tensor built from ``batch_size`` copies of
        one point cloud and check that there is one coordinate/feature pair
        per batch, with matching lengths.

        Reported as skipped when CUDA is unavailable.
        """
        print(f"{self.__class__.__name__}: test_decomposition_gpu")
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        coords, colors, pcd = load_file("1.ply")
        colors = torch.from_numpy(colors)

        for batch_size in [5, 10, 20, 40]:
            for voxel_size in [0.02]:
                dcoords = torch.from_numpy(np.floor(coords / voxel_size)).int()
                bcoords = batched_coordinates([dcoords for i in range(batch_size)])
                feats = torch.cat([colors for b in range(batch_size)], 0)
                sinput = SparseTensor(feats.to(0), bcoords.to(0))
                (
                    decomposed_coords,
                    decomposed_feats,
                ) = sinput.decomposed_coordinates_and_features
                print([len(c) for c in decomposed_coords])
                print([len(f) for f in decomposed_feats])
                self.assertEqual(len(decomposed_coords), batch_size)
                self.assertEqual(len(decomposed_feats), batch_size)
                # Every batch's coordinates must pair one-to-one with its
                # feature rows.
                for c, f in zip(decomposed_coords, decomposed_feats):
                    self.assertEqual(len(c), len(f))
Example #4
0
    def test_kernelmap_gpu(self):
        """Print the GPU-computed kernel map between the input and the
        stride-2 output coordinates, and check the total number of in/out
        pairs.

        Reported as skipped when CUDA is unavailable.
        """
        print(f"{self.__class__.__name__}: test_kernelmap_gpu")
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        in_channels, out_channels, D = 2, 3, 2
        coords, feats, labels = data_loader(in_channels)
        feats = feats.double()
        feats.requires_grad_()
        input = SparseTensor(feats, coords=coords)
        cm = input.coords_man
        ikey = cm._get_coords_key(1)
        print('Input coords: ')
        cm.print_diagnostics(ikey)

        print('Convolution: ')

        # Initialize context
        conv = MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            has_bias=True,
            dimension=D).double()
        output = conv(input)

        iC = input.C.numpy()
        oC = output.C.numpy()
        print(iC)
        print(oC)
        # Kernel map from tensor stride 1 to tensor stride 2, computed on GPU
        in_maps, out_maps = output.coords_man.get_kernel_map(
            1, 2, stride=2, kernel_size=3, on_gpu=True)
        for kernel_index, (in_map, out_map) in enumerate(zip(in_maps, out_maps)):
            for i, o in zip(in_map, out_map):
                print(kernel_index, iC[i], '->', oC[o])
        # assertEqual reports the actual count on failure.
        self.assertEqual(sum(len(in_map) for in_map in in_maps), 26)
Example #5
0
    def test(self):
        """Convert a strided convolution's output to dense form, first with
        default bounds and then with explicit ``min_coords``/``max_coords``,
        and backpropagate through the bounded dense tensor.

        Results are printed rather than asserted.
        """
        print(f"{self.__class__.__name__}: test_dense")
        in_channels, out_channels, D = 2, 3, 2
        coords, feats, _ = data_loader(in_channels)
        feats = feats.double()
        feats.requires_grad_()
        sinput = SparseTensor(feats, coords=coords)
        # Initialize context
        conv = MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            has_bias=True,
            dimension=D,
        ).double()
        soutput = conv(sinput)
        print(sinput.C, soutput.C)

        # Dense conversion with default bounds
        dense, min_coord, tensor_stride = soutput.dense()
        print(dense.shape)
        for item in (dense, min_coord, tensor_stride):
            print(item)

        # Dense conversion with explicit bounds
        dense, min_coord, tensor_stride = soutput.dense(
            min_coords=torch.IntTensor([-2, -2]),
            max_coords=torch.IntTensor([4, 4]))
        for item in (dense, min_coord, tensor_stride):
            print(item)

        print(feats.grad)  # printed before any backward pass has run
        dense.sum().backward()
        print(feats.grad)
    def test_extraction(self):
        """Check the per-batch extraction helpers of ``SparseTensor``.

        The first coordinate column of the five points is 0, 0, 0, 2, 2. The
        test expects batch 0 to hold three points, batch 1 none and batch 2
        two. The decomposition check is repeated with ``device=0`` when CUDA
        is available.
        """
        def build(**kwargs):
            # Fresh tensors on each call keep the CPU and GPU cases independent.
            pts = torch.IntTensor([[0, 0], [0, 1], [0, 2], [2, 0], [2, 2]])
            vals = torch.FloatTensor([[1.1, 2.1, 3.1, 4.1, 5.1]]).t()
            return SparseTensor(vals, pts, **kwargs)

        def check_decomposition(tensor):
            per_coords, per_feats = tensor.decomposed_coordinates_and_features
            for c, f in zip(per_coords, per_feats):
                self.assertEqual(c.numel(), f.numel())
                print(c, f)
            for batch, count in enumerate((3, 0, 2)):
                self.assertEqual(len(per_coords[batch]), count)

        X = build()
        C0 = X.coordinates_at(0)
        F0 = X.features_at(0)
        for value in (0, 1, 2):
            self.assertTrue(value in C0)
        for value in (1.1, 2.1, 3.1):
            self.assertTrue(value in F0)

        CC0, FC0 = X.coordinates_and_features_at(0)
        self.assertTrue((C0 == CC0).all())
        self.assertTrue((F0 == FC0).all())

        check_decomposition(X)

        if not is_cuda_available():
            return

        check_decomposition(build(device=0))
Example #7
0
    def test_with_convtr(self):
        """Upsample twice with coordinate-generating transposed convolutions.

        After each stage, points are kept where ``torch.rand`` is below 0.5.
        The final features are then summed and backpropagated, and the input
        gradient is printed.
        """
        channels, D = [2, 3, 4], 2
        coords, feats, _ = data_loader(channels[0], batch_size=1)
        feats = feats.double()
        feats.requires_grad_()
        # Create a sparse tensor with large tensor strides for upsampling
        start_tensor_stride = 4
        sinput = SparseTensor(feats,
                              coords=coords * start_tensor_stride,
                              tensor_stride=start_tensor_stride)

        def upsampler(c_in, c_out):
            return MinkowskiConvolutionTranspose(c_in,
                                                 c_out,
                                                 kernel_size=3,
                                                 stride=2,
                                                 generate_new_coords=True,
                                                 dimension=D).double()

        # Both layers are constructed before any random mask is drawn, which
        # fixes the order in which the RNG is consumed.
        stages = [upsampler(channels[0], channels[1]),
                  upsampler(channels[1], channels[2])]
        pruning = MinkowskiPruning(D)

        out = sinput
        for conv_tr in stages:
            out = conv_tr(out)
            keep = torch.rand(len(out)) < 0.5
            out = pruning(out, keep)

        print(out)

        out.F.sum().backward()

        # Check gradient flow
        print(sinput.F.grad)
Example #8
0
    def test(self):
        """Gradcheck ``MinkowskiGlobalPoolingFunction`` with the pooling mode
        of a ``MinkowskiGlobalAvgPooling`` layer."""
        in_channels, D = 2, 2
        coords, feats, _ = data_loader(in_channels)
        feats = feats.double()
        feats.requires_grad_()
        sinput = SparseTensor(feats, coords)
        pool = MinkowskiGlobalAvgPooling()
        soutput = pool(sinput)
        print(soutput)

        # Check backward
        fn = MinkowskiGlobalPoolingFunction()
        fn_args = (
            sinput.F,
            pool.pooling_mode,
            sinput.coordinate_map_key,
            soutput.coordinate_map_key,
            sinput._manager,
        )
        self.assertTrue(gradcheck(fn, fn_args))
Example #9
0
    def test_broadcast(self):
        """Run the broadcast modules (plain, concat, add, mul) on a tensor and
        its global-pooled version, then gradcheck ``MinkowskiBroadcastFunction``
        for addition and multiplication."""
        in_channels, D = 2, 2
        # Coordinates and features come from one data_loader call, so they are
        # guaranteed to describe the same points.
        coords, feats, labels = data_loader(in_channels)
        feats = feats.double()
        input = SparseTensor(feats, coords=coords)
        pool = MinkowskiGlobalPooling()
        input_glob = pool(input)
        # Only the global features are differentiated by the gradchecks below.
        input_glob.F.requires_grad_()
        broadcast = MinkowskiBroadcast()
        broadcast_cat = MinkowskiBroadcastConcatenation()
        broadcast_add = MinkowskiBroadcastAddition()
        broadcast_mul = MinkowskiBroadcastMultiplication()
        output = broadcast(input, input_glob)
        print(output)
        output = broadcast_cat(input, input_glob)
        print(output)
        output = broadcast_add(input, input_glob)
        print(output)
        output = broadcast_mul(input, input_glob)
        print(output)

        # Check backward
        fn = MinkowskiBroadcastFunction()

        self.assertTrue(
            gradcheck(
                fn,
                (input.F, input_glob.F, OperationType.ADDITION,
                 input.coords_key, input_glob.coords_key, input.coords_man)))

        self.assertTrue(
            gradcheck(
                fn,
                (input.F, input_glob.F, OperationType.MULTIPLICATION,
                 input.coords_key, input_glob.coords_key, input.coords_man)))
Example #10
0
    def test_analytic(self):
        """Apply a 2x2, stride-2 convolution and then a transposed convolution,
        both with hand-set kernels, to a three-point input and print the
        results."""
        print(f"{self.__class__.__name__}: test")
        in_channels, out_channels, D = 2, 2, 2
        coords = torch.IntTensor([[0, 0, 0], [0, 1, 1], [0, 2, 1]])
        feats = torch.FloatTensor([[0, 1], [1, 0], [1, 1]])
        input = SparseTensor(feats, coordinates=coords)
        # One 2x2 weight matrix per kernel offset (2x2 kernel in 2D -> 4)
        kernel = torch.FloatTensor(
            [[[1, 2], [2, 1]], [[0, 1], [1, 0]], [[0, 1], [1, 1]], [[1, 1], [1, 0]]]
        )
        # Initialize context
        conv = MinkowskiConvolution(
            in_channels, out_channels, kernel_size=2, stride=2, bias=False, dimension=D
        )
        # The kernel is a leaf parameter that requires grad; assigning into it
        # in place is only permitted with autograd recording disabled.
        with torch.no_grad():
            conv.kernel[:] = kernel
        output = conv(input)
        print(output)

        conv_tr = MinkowskiConvolutionTranspose(
            in_channels, out_channels, kernel_size=2, stride=2, bias=False, dimension=D
        )
        with torch.no_grad():
            conv_tr.kernel[:] = kernel
        output_tr = conv_tr(output)
        print(output_tr)
Example #11
0
    def test_zero(self):
        """Regression test for MinkowskiEngine issue #383.

        Interpolates CUDA features at the original coordinates and then at
        all-zero sample coordinates, and checks that both results have one row
        per sample and can be backpropagated.

        Reported as skipped when CUDA is unavailable, since every tensor here
        is allocated on ``'cuda'``.
        """
        # Issue #383 https://github.com/NVIDIA/MinkowskiEngine/issues/383
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        # create point and features, all with batch 0
        pc = torch.randint(-10, 10, size=(32, 4), dtype=torch.float32, device='cuda')
        pc[:, 0] = 0
        feat = torch.randn(32, 3, dtype=torch.float32, device='cuda', requires_grad=True)

        # feature to interpolate
        x = SparseTensor(feat, pc, device='cuda')
        interp = MinkowskiInterpolation()

        # samples with original coordinates
        samples = pc
        y = interp(x, samples)
        print(y.shape, y.stride())
        self.assertEqual(y.shape[0], samples.shape[0])
        torch.sum(y).backward()

        # samples with all zeros (batch column included), which triggered the
        # inconsistent shape / backward error reported in the issue
        samples = torch.zeros_like(pc)
        y = interp(x, samples)
        print(y.shape, y.stride())
        self.assertEqual(y.shape[0], samples.shape[0])
        torch.sum(y).backward()
Example #12
0
    def test(self):
        """Run a channel-wise convolution on CPU and backpropagate the sum of
        its output features, printing intermediate results."""
        print(f"{self.__class__.__name__}: test")
        in_channels, D = 3, 2
        coords, feats, _ = data_loader(in_channels, batch_size=2)

        # Create random coordinates with tensor stride == 2
        # NOTE(review): out_coords and tensor_stride are not used below.
        out_coords, tensor_stride = get_random_coords()

        feats = feats.double()
        feats.requires_grad_()
        sinput = SparseTensor(feats, coords=coords)

        conv = MinkowskiChannelwiseConvolution(
            in_channels,
            kernel_size=3,
            stride=1,
            has_bias=False,
            dimension=D,
        ).double()

        print('Initial input: ', sinput)
        soutput = conv(sinput)
        print('Conv output: ', soutput)

        soutput.F.sum().backward()
        print(sinput.F.grad)
Example #13
0
    def test(self):
        """Gradcheck ``MinkowskiLocalPoolingFunction`` through a max-pooling
        layer with kernel size 3 and stride 2."""
        in_channels, D = 2, 2
        coords, feats, _ = data_loader(in_channels)
        feats = feats.double()
        feats.requires_grad_()
        sinput = SparseTensor(feats, coordinates=coords)
        pool = MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)
        soutput = pool(sinput)
        print(soutput)

        # Check backward
        fn = MinkowskiLocalPoolingFunction()
        fn_args = (
            sinput.F,
            pool.pooling_mode,
            pool.kernel_generator,
            sinput.coordinate_map_key,
            soutput.coordinate_map_key,
            sinput._manager,
        )
        self.assertTrue(gradcheck(fn, fn_args))
Example #14
0
    def test_gpu(self):
        """Run a channel-wise convolution on CUDA and print its input and
        output.

        Reported as skipped (not silently passed) when CUDA is unavailable.
        """
        print(f"{self.__class__.__name__}: test_gpu")
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        device = torch.device('cuda')
        in_channels, D = 3, 2
        coords, feats, labels = data_loader(in_channels, batch_size=2)

        # Create random coordinates with tensor stride == 2
        # NOTE(review): out_coords and tensor_stride are not used below.
        out_coords, tensor_stride = get_random_coords()

        feats = feats.double()
        feats.requires_grad_()
        input = SparseTensor(feats, coords=coords).to(device)
        conv = MinkowskiChannelwiseConvolution(in_channels,
                                               kernel_size=3,
                                               stride=1,
                                               has_bias=False,
                                               dimension=D).double().to(device)

        print('Initial input: ', input)
        output = conv(input)
        print('Conv output: ', output)
 def test(self):
     """Smoke-test construction and printing of a ``SparseTensor``."""
     print(f"{self.__class__.__name__}: test SparseTensor")
     coords, feats, _ = data_loader(nchannel=2)
     print(SparseTensor(feats, coordinates=coords))
    def test_operation_mode(self):
        """Exercise arithmetic between sparse tensors created in the
        shared-coordinate-manager operation mode.

        The mode is set via ``set_sparse_tensor_operation_mode`` and appears to
        be global state (the two tensors below end up sharing one manager). It
        is therefore switched back to ``SEPARATE_COORDINATE_MANAGER`` in a
        ``finally`` block so it cannot leak into other tests, even if an
        assertion here fails.
        """
        # Set to use the global sparse tensor coords manager by default
        set_sparse_tensor_operation_mode(
            SparseTensorOperationMode.SHARE_COORDINATE_MANAGER)
        try:
            coords, feats, labels = data_loader(nchannel=2)

            # Create a sparse tensor on two different coordinates.
            A = SparseTensor(torch.rand(feats.shape), coordinates=coords)
            B = SparseTensor(
                torch.rand(4, 2),
                coordinates=torch.IntTensor([[0, 0, 0], [1, 1, 1], [0, 1, 0],
                                             [1, 0, 1]]),
            )

            self.assertTrue(A.coordinate_manager == B.coordinate_manager)

            A.requires_grad_(True)
            B.requires_grad_(True)

            C = A + B

            C.F.sum().backward()

            # d(sum(A + B)) / dA and / dB are all ones.
            self.assertTrue(torch.all(A.F.grad == 1).item())
            self.assertTrue(torch.all(B.F.grad == 1).item())

            # Smoke-test the remaining binary operators.
            C = A - B
            C = A * B
            C = A / B

            # Inplace
            A.requires_grad_(False)
            D = SparseTensor(
                torch.rand(feats.shape),
                coordinate_map_key=A.coordinate_map_key,
                coordinate_manager=A.coordinate_manager,
            )
            A -= D
            A *= D
            A /= D
        finally:
            # Restore the library default so later tests get independent
            # coordinate managers.
            set_sparse_tensor_operation_mode(
                SparseTensorOperationMode.SEPARATE_COORDINATE_MANAGER)
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` until ``config.max_iter`` iterations.

    Uses gradient accumulation over ``config.iter_size`` sub-batches, with
    periodic logging, checkpointing and validation.

    Args:
        model: network taking a ``SparseTensor``. The ``.F`` of its output is
            scored with ``nn.CrossEntropyLoss`` against the per-point targets.
        data_loader: training loader. One iterator is created before the epoch
            loop and reused across epochs, so it presumably wraps an infinite
            sampler -- TODO confirm.
        val_data_loader: loader handed to ``validate``.
        config: run configuration (``log_dir``, ``iter_size``, ``max_iter``,
            ``stat_freq``, ``save_freq``, ``val_freq``, ``resume``, ...).
        transform_data_fn: optional callable forwarded to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``<resume>/weights.pth``
            does not exist.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if not osp.isfile(checkpoint_fn):
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))
        logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
        # Deserialize onto the CPU so a GPU-saved checkpoint can be resumed on
        # any machine; load_state_dict copies values onto the devices of the
        # model / optimizer parameters.
        state = torch.load(checkpoint_fn, map_location='cpu')
        curr_iter = state['iteration'] + 1
        epoch = state['epoch']
        model.load_state_dict(state['state_dict'])
        if config.resume_optimizer:
            scheduler = initialize_scheduler(optimizer,
                                             config,
                                             last_step=curr_iter)
            optimizer.load_state_dict(state['optimizer'])
        if 'best_val' in state:
            best_val_miou = state['best_val']
            best_val_iter = state['best_val_iter']
        logging.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_fn, state['epoch']))

    data_iter = iter(data_loader)
    while is_training:
        for _ in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for _ in range(config.iter_size):
                # Get training data. The builtin next() works for any
                # iterator; the Python-2 style `.next()` alias is not
                # guaranteed to exist on DataLoader iterators.
                data_timer.tic()
                if config.return_transformation:
                    coords, feats, target, pointcloud, transformation = next(
                        data_iter)
                else:
                    coords, feats, target = next(data_iter)

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / 255. - 0.5
                sinput = SparseTensor(feats, coords).to(device)

                data_time += data_timer.toc(False)

                soutput = model(sinput)
                # The output of the network is not sorted
                target = target.long().to(device)

                loss = criterion(soutput.F, target)

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            # Precision is measured on the last sub-batch only.
            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg,
                                  curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, writer, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))

    # Flush pending events and release the event file.
    writer.close()
Example #18
0
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """(Optionally distributed) training loop until ``config.max_iter``.

    Uses gradient accumulation over ``config.iter_size`` sub-batches and
    per-phase timing (data, forward, backward, DDP gather). Logging,
    checkpointing and validation run periodically.

    Args:
        model: network called with ``(sinput,)`` or, when
            ``config.wrapper_type != 'None'``, ``(sinput, coords, color)``.
            The ``.F`` of its output is scored with ``nn.CrossEntropyLoss``.
        data_loader: training loader; a single iterator is reused across
            epochs (commented as a "(distributed) infinite sampler").
        val_data_loader: loader handed to ``validate``.
        config: run configuration (``device_id``, ``iter_size``,
            ``max_iter``, ``stat_freq``, ``save_freq``, ``val_freq``,
            ``empty_cache_freq``, ``resume``, ...).
        transform_data_fn: optional callable forwarded to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``<resume>/weights.pth``
            does not exist.
    """
    device = config.device_id
    distributed = get_world_size() > 1

    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer()

    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    fw_time_avg, bw_time_avg, ddp_time_avg = AverageMeter(), AverageMeter(
    ), AverageMeter()

    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training on {} GPUs, batch-size={}'.format(
        get_world_size(), config.batch_size * get_world_size()))
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if not osp.isfile(checkpoint_fn):
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))
        logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
        state = torch.load(
            checkpoint_fn,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))
        curr_iter = state['iteration'] + 1
        epoch = state['epoch']
        load_state(model, state['state_dict'])

        if config.resume_optimizer:
            scheduler = initialize_scheduler(optimizer,
                                             config,
                                             last_step=curr_iter)
            optimizer.load_state_dict(state['optimizer'])
        if 'best_val' in state:
            best_val_miou = state['best_val']
            best_val_iter = state['best_val_iter']
        logging.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_fn, state['epoch']))

    data_iter = iter(data_loader)  # (distributed) infinite sampler
    while is_training:
        for _ in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss, batch_score = 0, 0, 0
            iter_timer.tic()

            # set random seed for every iteration for trackability
            _set_seed(config, curr_iter)

            for _ in range(config.iter_size):
                # Get training data. The builtin next() works for any
                # iterator; the Python-2 style `.next()` alias is not
                # guaranteed to exist on DataLoader iterators.
                data_timer.tic()
                coords, feats, target = next(data_iter)

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input; raw colors are captured before
                # normalization for wrapper models.
                color = feats[:, :3].int()
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / 255. - 0.5
                sinput = SparseTensor(feats, coords).to(device)

                data_time += data_timer.toc(False)

                # Feed forward
                fw_timer.tic()

                inputs = (sinput, ) if config.wrapper_type == 'None' else (
                    sinput, coords, color)
                soutput = model(*inputs)
                # The output of the network is not sorted
                target = target.long().to(device)

                loss = criterion(soutput.F, target)

                # Compute and accumulate gradient
                loss /= config.iter_size

                pred = get_prediction(data_loader.dataset, soutput.F, target)
                score = precision_at_one(pred, target)

                fw_timer.toc(False)
                bw_timer.tic()

                # bp the loss
                loss.backward()

                bw_timer.toc(False)

                # gather information
                logging_output = {
                    'loss': loss.item(),
                    'score': score / config.iter_size
                }

                ddp_timer.tic()
                if distributed:
                    # Average each metric across all workers.
                    logging_output = all_gather_list(logging_output)
                    logging_output = {
                        w: np.mean([a[w] for a in logging_output])
                        for w in logging_output[0]
                    }

                batch_loss += logging_output['loss']
                batch_score += logging_output['score']
                ddp_timer.toc(False)

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))
            fw_time_avg.update(fw_timer.diff)
            bw_time_avg.update(bw_timer.diff)
            ddp_time_avg.update(ddp_timer.diff)

            losses.update(batch_loss, target.size(0))
            scores.update(batch_score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Forward time: {:.4f}, Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, fw_time_avg.avg,
                    bw_time_avg.avg, ddp_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset all timing meters so every logged time covers the
                # same stat window (fw/bw/ddp previously kept accumulating
                # from the start of training).
                data_time_avg.reset()
                iter_time_avg.reset()
                fw_time_avg.reset()
                bw_time_avg.reset()
                ddp_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg,
                                  curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, writer, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))

    # Flush pending events and release the event file.
    writer.close()
Example #19
0
def _sum_block_reg(net, attr, device):
    """Sum ``child[0].<attr>`` over the direct children of ``net`` whose name
    contains ``'block'``.

    Each matching child is indexed with ``[0]`` (the training code describes
    it as an ``nn.Sequential`` whose first element is the TRBlock). The sum is
    built out-of-place, starting from a 0-dim zero tensor on ``device``.
    """
    total = torch.zeros((), device=device)
    for name, child in net.named_children():
        if 'block' in name:
            total = total + getattr(child[0], attr)
    return total


def train_worker(gpu,
                 num_devices,
                 NetClass,
                 data_loader,
                 val_data_loader,
                 config,
                 transform_data_fn=None):
    """Run the distributed (one process per GPU) training loop.

    Args:
        gpu: CUDA device index of this worker; also used as the process rank.
        num_devices: world size passed to ``dist.init_process_group``.
        NetClass: model class; its constructor signature is selected from
            ``config.pure_point`` / ``config.model``.
        data_loader: training loader; rebuilt around ``InfSampler`` when
            ``config.multiprocess`` is set.
        val_data_loader: loader used for periodic validation and final test.
        config: experiment configuration. Note that this function mutates
            ``config.val_freq`` and, for some datasets, ``normalize_color`` /
            ``xyz_input``.
        transform_data_fn: forwarded to ``validate``.

    Raises:
        ValueError: if ``gpu`` is None, or ``config.resume`` is set but
            ``<resume>/weights.pth`` does not exist.
    """
    if gpu is None:
        # rank, set_device and DDP device_ids all need a concrete index;
        # previously this surfaced as an unbound-local error on `rank`.
        raise ValueError("train_worker requires a GPU index, got None")
    print("Use GPU: {} for training".format(gpu))
    rank = gpu
    addr = 23491
    dist.init_process_group(backend="nccl",
                            init_method="tcp://127.0.0.1:{}".format(addr),
                            world_size=num_devices,
                            rank=rank)

    # replace with DistributedSampler
    if config.multiprocess:
        from lib.dataloader_dist import InfSampler
        sampler = InfSampler(data_loader.dataset)
        data_loader = DataLoader(dataset=data_loader.dataset,
                                 num_workers=data_loader.num_workers,
                                 batch_size=data_loader.batch_size,
                                 collate_fn=data_loader.collate_fn,
                                 worker_init_fn=data_loader.worker_init_fn,
                                 sampler=sampler)

    if data_loader.dataset.NUM_IN_CHANNEL is not None:
        num_in_channel = data_loader.dataset.NUM_IN_CHANNEL
    else:
        num_in_channel = 3
    num_labels = data_loader.dataset.NUM_LABELS

    # load model
    if config.pure_point:
        model = NetClass(num_class=config.num_labels,
                         N=config.num_points,
                         normal_channel=config.num_in_channel)
    elif config.model == 'MixedTransformer':
        model = NetClass(config,
                         num_class=num_labels,
                         N=config.num_points,
                         normal_channel=num_in_channel)
    elif config.model in ('MinkowskiVoxelTransformer',
                          'MinkowskiTransformerNet'):
        model = NetClass(config, num_in_channel, num_labels)
    else:
        # Covers the "Res*" models and every other class: same signature.
        model = NetClass(num_in_channel, num_labels, config)

    if config.weights == 'modelzoo':
        model.preload_modelzoo()
    elif config.weights.lower() != 'none':
        state = torch.load(config.weights)
        # delete the keys containing the attn since it raises size mismatch
        d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        elif config.lenient_weight_loading:
            matched_weights = load_state_with_same_shape(
                model, state['state_dict'])
            model_dict = model.state_dict()
            model_dict.update(matched_weights)
            model.load_state_dict(model_dict)
        else:
            model.load_state_dict(d, strict=False)

    torch.cuda.set_device(gpu)
    # Synchronized batch norm: convert *before* wrapping in DDP (the order
    # PyTorch documents for SyncBatchNorm) and move to the GPU afterwards so
    # any newly created modules land on the device as well.
    model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(gpu)
    # use model with DDP
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=False)

    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    if rank == 0:
        setup_logger(config)
        logging.info('===> Start training')

    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test loaded ckpt first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)
        # test() switches the model to eval mode; restore training mode so
        # the resumed run does not train with frozen batch-norm statistics.
        model.train()

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # we skip attention maps because the shape won't match because voxel number is different
            # e.g. copying a param with shape (23385, 8, 4) to (43529, 8, 4)
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            # handle those attn maps we don't load from saved dict
            current_state = model.state_dict()
            for k in current_state:
                if k not in d:
                    d[k] = current_state[k]
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
    device = gpu  # multitrain fed in the device
    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 10  # original val_freq_
    elif config.dataset == 'S3DIS':
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 50
    else:
        val_freq_ = config.val_freq
        num_class = 20

    while is_training:

        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for _ in range(len(data_loader) // config.iter_size):

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                config.val_freq = val_freq_ * 2  # valid more freq on lower half

            for _ in range(config.iter_size):

                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation = data_iter.next(
                    )
                else:
                    coords, input, target, _, _ = data_iter.next(
                    )  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important.
                # Column 0 is the batch index; shift every spatial column.
                coords[:, 1:] += (torch.rand(coords.shape[1] - 1) *
                                  100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    # Computed here (not only under normalize_color) so that
                    # xyz_input works on its own instead of raising NameError.
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)

                data_time += data_timer.toc(False)
                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=curr_iter / config.max_iter,
                                    enable_point_branch=True)
                else:
                    # feed in the progress of training for annealing inside the model
                    soutput = model(sinput, iter_=curr_iter / config.max_iter)
                # The output of the network is not sorted
                target = target.view(-1).long().to(device)

                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                # The old check `hasattr(model, 'module.block1')` was always
                # False (hasattr does not follow dotted paths), so these
                # regularizers were silently skipped. Look on the module
                # wrapped by DDP instead.
                cur_loss = torch.zeros((), device=device)
                net = model.module
                if hasattr(net, 'block1'):
                    first_block = net.block1[0]
                    for attr, log_fmt in (
                        ('vq_loss', 'Cur Loss: {}, Cur vq_loss: {}'),
                        ('diverse_loss',
                         'Cur Loss: {}, Cur diverse _loss: {}'),
                        ('label_reg', None),
                    ):
                        if getattr(first_block, attr, None) is None:
                            continue
                        reg = _sum_block_reg(net, attr, device)
                        if log_fmt is not None:
                            logging.info(log_fmt.format(loss, reg))
                        # Out-of-place so the scalar CE loss is never
                        # broadcast in place with a differently-shaped reg.
                        loss = loss + reg
                        cur_loss = cur_loss + reg.detach()

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                if not config.use_sam:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                optimizer.first_step(zero_grad=True)
                # Second SAM pass on the last sub-batch; `aux` replaces the
                # previously undefined `starget`.
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=aux)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            elif curr_iter >= config.lr_warmup:
                scheduler.step()
            else:
                # Ramp on the global step (the same counter that ends the
                # warmup); the per-epoch index restarted the ramp each epoch.
                for g in optimizer.param_groups:
                    g['lr'] = config.lr * curr_iter / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(g['lr']) for g in optimizer.param_groups])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)

                if rank == 0:
                    logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # only save status on the 1st gpu
            if rank == 0:

                # Save current status, save before val to prevent occasional mem overflow
                if curr_iter % config.save_freq == 0:
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               save_inter=True)

                # Validation
                if curr_iter % config.val_freq == 0:
                    val_miou = validate(model, val_data_loader, None,
                                        curr_iter, config, transform_data_fn
                                        )  # feed in None for SummaryWriter args
                    if val_miou > best_val_miou:
                        best_val_miou = val_miou
                        best_val_iter = curr_iter
                        checkpoint(model,
                                   optimizer,
                                   epoch,
                                   curr_iter,
                                   config,
                                   best_val_miou,
                                   best_val_iter,
                                   "best_val",
                                   save_inter=True)
                    logging.info(
                        "Current best mIoU: {:.3f} at iter {}".format(
                            best_val_miou, best_val_iter))

                    # Recover back
                    model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        if rank == 0:
            logging.info('train point avg class IoU: %f' %
                         ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    if rank == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter)
        # Compare the mIoU of *this* final evaluation. The old code discarded
        # the test result and compared a stale `val_miou` from the last
        # periodic validation (undefined if none had run).
        _, _, _, val_miou = test(model, val_data_loader, config)

        if val_miou > best_val_miou:
            best_val_miou = val_miou
            best_val_iter = curr_iter
            logging.info("Final best miou: {}  at iter {} ".format(
                val_miou, curr_iter))
            checkpoint(model, optimizer, epoch, curr_iter, config,
                       best_val_miou, best_val_iter, "best_val")

            logging.info("Current best mIoU: {:.3f} at iter {}".format(
                best_val_miou, best_val_iter))
Example #20
0
    def test_unpooling_gpu(self):
        """Gradcheck ``MinkowskiPoolingTransposeFunction`` on CPU, then again
        after moving the input to CUDA device 0."""
        if not torch.cuda.is_available():
            # Report the test as skipped instead of letting it silently pass.
            self.skipTest("CUDA is not available")

        in_channels, out_channels, D = 2, 3, 2
        coords, feats, labels = data_loader(in_channels)
        feats = feats.double()
        input = SparseTensor(feats, coords=coords)
        conv = MinkowskiConvolution(
            in_channels, out_channels, kernel_size=3, stride=2, dimension=D
        )
        conv = conv.double()
        unpool = MinkowskiPoolingTranspose(kernel_size=3, stride=2, dimension=D)
        input = conv(input)
        output = unpool(input)
        print(output)

        fn = MinkowskiPoolingTransposeFunction()

        def gradcheck_inputs(x, flag):
            # Argument tuple for `fn`; only the sparse tensor and the 8th
            # (boolean) argument differ between the CPU and CUDA checks.
            # NOTE(review): that flag is presumably `average` in this ME
            # version -- confirm against MinkowskiPoolingTransposeFunction.
            return (
                x.F,
                x.tensor_stride,
                unpool.stride,
                unpool.kernel_size,
                unpool.dilation,
                unpool.region_type_,
                unpool.region_offset_,
                flag,
                x.coords_key,
                None,
                x.coords_man,
            )

        # Check backward (CPU)
        self.assertTrue(gradcheck(fn, gradcheck_inputs(input, False)))

        device = torch.device("cuda")
        with torch.cuda.device(0):
            input = input.to(device)
            output = unpool(input)
            print(output)

        # Check backward (CUDA)
        self.assertTrue(gradcheck(fn, gradcheck_inputs(input, True)))
Example #21
0
def _save_confusion_matrix(all_preds, all_labels, path, ignore_value=255):
    """Write the row-normalized confusion matrix of predictions vs. labels to
    ``path`` (via ``np.savetxt``), dropping points labelled ``ignore_value``.

    ``all_preds`` / ``all_labels`` are lists of per-batch numpy arrays that are
    concatenated and flattened. Filtering uses one boolean mask; the previous
    ``i not in to_ignore`` list scan was quadratic in the number of points.
    """
    preds = np.concatenate(all_preds).reshape(-1)
    targets = np.concatenate(all_labels).reshape(-1)
    keep = targets != ignore_value
    cm = confusion_matrix(targets[keep], preds[keep], normalize='true')
    np.savetxt(path, cm)


def test(model,
         data_loader,
         config,
         transform_data_fn=None,
         has_gt=True,
         validation=None,
         epoch=None):
    """Evaluate ``model`` on ``data_loader`` and report loss, accuracy, AP and mIoU.

    When ``config.is_train`` is false, weights are first loaded from
    ``config.resume + '/weights.pth'``. Predictions may be written out when
    ``config.save_prediction`` or ``config.test_original_pointcloud`` is set.

    Returns:
        ``(loss_avg, score_avg, mean_AP, mIoU)``. When ``validation`` is truthy
        a fifth element, the ``{epoch: accumulated loss}`` dict, is appended
        and the losses are also appended to ``<log_dir>/val_loss.txt``.

    Raises:
        ValueError: missing checkpoint, or a non-empty prediction directory.
        NotImplementedError: ``config.evaluate_original_pointcloud`` with GT.
    """
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))
    # Defaults so the report below does not hit a NameError when there is no
    # ground truth (has_gt=False) or the loader yields no batches.
    ap_class = np.full(num_labels, np.nan)
    iteration, data_time, iter_time = 0, 0., 0.

    if not config.is_train:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            model.load_state_dict(state['state_dict'])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))
    if validation:
        logging.info('===> Start validating')
    else:
        logging.info('===> Start testing')

    global_timer.tic()
    data_iter = data_loader.__iter__()
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    all_preds, all_labels, batch_losses, batch_loss = [], [], {}, 0

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            if config.return_transformation:
                coords, input, target, transformation = data_iter.next()
            else:
                coords, input, target = data_iter.next()
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()

            if config.wrapper_type != 'None':
                color = input[:, :3].int()
            if config.normalize_color:
                input[:, :3] = input[:, :3] / 255. - 0.5
            sinput = SparseTensor(input, coords).to(device)

            # Feed forward
            inputs = (sinput, ) if config.wrapper_type == 'None' else (sinput,
                                                                       coords,
                                                                       color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            all_preds.append(pred.cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset, config,
                                 iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    # Evaluation in the original point-cloud space is not
                    # implemented (the old unreachable code after this raise
                    # referenced an undefined `pointcloud`).
                    raise NotImplementedError('pointcloud')

                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)

                loss = criterion(output, target.long())

                batch_loss += loss

                losses.update(float(loss), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(),
                                  target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                # NOTE(review): 255 is the hard-coded ignore value here, while
                # the loss uses config.ignore_label -- confirm they agree.
                _save_confusion_matrix(
                    all_preds, all_labels,
                    config.log_dir + '/cm_epoch_{0}.txt'.format(epoch))

                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration,
                           max_iter_unique,
                           data_time,
                           iter_time,
                           has_gt,
                           losses,
                           scores,
                           reordered_ious,
                           hist,
                           reordered_ap_class,
                           class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            batch_losses[epoch] = batch_loss

    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration,
               max_iter_unique,
               data_time,
               iter_time,
               has_gt,
               losses,
               scores,
               reordered_ious,
               hist,
               reordered_ap_class,
               class_names=class_names)

    if not config.is_train:
        _save_confusion_matrix(all_preds, all_labels,
                               config.log_dir + '/cm.txt')

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)
    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    if validation:
        loss_file_name = "/val_loss.txt"
        # The `with` block closes the file; no explicit close() needed.
        with open(config.log_dir + loss_file_name, 'a') as val_loss_file:
            for key in batch_losses:
                val_loss_file.write('{0}, {1}\n'.format(
                    batch_losses[key], key))
        return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
            per_class_iu(hist)) * 100, batch_losses

    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
        per_class_iu(hist)) * 100
Example #22
0
 def forward(self, x: ME.SparseTensor):
     """Densify ``x`` between ``self.min_coords`` and ``self.max_coords``.

     Returns only the first element of ``x.dense(...)`` (presumably the dense
     feature tensor -- confirm for the installed MinkowskiEngine version).
     """
     densified = x.dense(self.min_coords, self.max_coords)
     return densified[0]
Example #23
0
	def visualize(self, options, model: Model, writer: SummaryWriter, step):
		"""Log t-SNE / embedding images for selected scenes to ``writer``.

		Scenes are either ``num_scene_samples`` evenly spaced dataset indices
		or the named ``scene_names``; ``overfit_one_ex`` overrides both. The
		model runs in eval mode under ``torch.no_grad`` and its previous
		train/eval mode is restored once every scene has been processed.

		Args:
			options: visualization kinds; 'tsne' and 'embs' are handled.
			model: network applied to each scene; must expose ``.device``.
			writer: receives the images through ``add_image``.
			step: global step recorded with each image.
		"""
		training = model.training
		model.eval()
		try:
			vis_config = self.config['vis']

			if vis_config.get('num_scene_samples'):
				# sample k data points from n data points with equal interval
				n = len(self)
				k = vis_config.get('num_scene_samples')
				vis_indices = torch.linspace(0, n - 1, k) \
					.type(torch.IntTensor).tolist()
			else:
				vis_indices = [self.dir2idx[i] for i in vis_config.get('scene_names')]

			if self.config['overfit_one_ex']:
				vis_scene = self.config['overfit_one_ex']
				vis_indices = [self.dir2idx[vis_scene]]

			for i in vis_indices:
				coords, feats, labels, _ = self[i]
				coords, feats, = sparse_collate([coords], [feats])
				x = SparseTensor(feats, coords)

				x = x.to(model.device)
				with torch.no_grad():
					y_hat = model(x)

				embs = y_hat
				insts = labels[:, 1]

				for option in options:
					# visualize tsne
					if option == 'tsne':
						tsne_img = visualization.visualize_tsne(
							embs.cpu(), insts.cpu(),
							config=self.config['vis']['tsne']
						)
						writer.add_image('tsne/{}'.format(self.idx2dir[i]), tsne_img, step)

					elif option == 'embs':
						# Separate name so the scene-selection config is not shadowed.
						embs_config = self.config['vis']['embs']

						# visualize embs with background
						emb_imgs, axis_range = visualization.visualize_embs(
							embs.cpu(), insts.cpu(),
							remove_bg=False, max_sample=embs_config['max_sample'],
							num_view=embs_config['num_view']
						)
						for view_num, img in enumerate(emb_imgs):
							writer.add_image(
								'emb/with_bg/{}_{}'.format(self.idx2dir[i], view_num),
								img, step
							)

						# visualize embs without background
						not_bg_emb_imgs, _ = visualization.visualize_embs(
							embs.cpu(), insts.cpu(),
							remove_bg=True, max_sample=embs_config['max_sample'],
							num_view=embs_config['num_view'], axis_range=axis_range
						)
						for view_num, img in enumerate(not_bg_emb_imgs):
							writer.add_image(
								'emb/no_bg/{}_{}'.format(self.idx2dir[i], view_num),
								img, step
							)
		finally:
			# Restore the caller's mode once, after *all* scenes (and even on
			# error). Previously this ran inside the scene loop, so every
			# scene after the first was rendered in training mode, and an
			# empty selection left the model stuck in eval mode.
			model.train(training)
 def test_empty(self):
     """Smoke test: a SparseTensor with zero points can be built and printed."""
     print(f"{self.__class__.__name__}: test_empty SparseTensor")
     feat_channels, coord_columns = 16, 4
     empty_feats = torch.FloatTensor(0, feat_channels)
     empty_coords = torch.IntTensor(0, coord_columns)
     empty_tensor = SparseTensor(empty_feats, coordinates=empty_coords)
     print(empty_tensor)
Example #25
0
	def collate_fn(self, batch):
		"""Collate ``(coords, features, labels)`` samples into one batch.

		Returns a ``SparseTensor`` built from the collated features and
		coordinates, together with the collated labels.
		"""
		per_sample_coords, per_sample_feats, per_sample_labels = zip(*batch)
		batched = sparse_collate(per_sample_coords, per_sample_feats, per_sample_labels)
		batch_coords, batch_feats, batch_labels = batched
		return SparseTensor(batch_feats, coords=batch_coords), batch_labels
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` with gradient accumulation, periodic validation and checkpoints.

    Each outer iteration accumulates gradients over ``config.iter_size``
    sub-iterations, then takes one optimizer step and one scheduler step.
    Statistics are logged every ``config.stat_freq`` iterations. A checkpoint
    is written every ``config.save_freq`` iterations. ``validate`` runs every
    ``config.val_freq`` iterations, and a "best_val" checkpoint is saved on a
    new best mIoU. Training stops once ``curr_iter`` reaches
    ``config.max_iter``. A final checkpoint and validation pass follow.

    Args:
        model: network called as ``model(*inputs)``. The ``.F`` of its output
            is used as per-point logits for ``nn.CrossEntropyLoss``.
        data_loader: training loader. A single iterator is created and drawn
            from across epochs. NOTE(review): this presumes the loader never
            exhausts (e.g. an infinite sampler); confirm.
        val_data_loader: loader forwarded to ``validate``.
        config: training options (``log_dir``, ``iter_size``, ``max_iter``,
            ``resume``, ``wrapper_type``, ``normalize_color``, ...).
        transform_data_fn: forwarded unchanged to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``weights.pth`` is missing.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration. Create exactly one SummaryWriter for log_dir. Opening a
    # second one would leave the first writer's event file unclosed.
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    # batch_losses maps epoch -> accumulated loss of that epoch's latest iteration.
    losses, scores, batch_losses = AverageMeter(), AverageMeter(), {}

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    # Focal-loss constants. The focal loss is computed only for debug logging.
    alpha, gamma, eps = 1, 2, 1e-6

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            model.load_state_dict(state['state_dict'])
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = iter(data_loader)
    # Costly diagnostics (GPU->CPU copies, np.unique, focal loss) only run
    # when DEBUG logging is enabled.
    debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)

    while is_training:
        logging.info(
            "********************************** epoch N° {0} ************************"
            .format(epoch))
        for iteration in range(len(data_loader) // config.iter_size):
            logging.debug("####### Iteration N° %d", iteration)
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()
            for sub_iter in range(config.iter_size):
                logging.debug("------------- Sub_iteration N° %d", sub_iter)
                # Get training data
                data_timer.tic()
                coords, feats, target = next(data_iter)
                logging.debug("len of coords : %d", len(coords))

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                color = feats[:, :3].int()
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / 255. - 0.5
                sinput = SparseTensor(feats, coords).to(device)

                data_time += data_timer.toc(False)

                # Feed forward
                inputs = (sinput, ) if config.wrapper_type == 'None' else (
                    sinput, coords, color)
                soutput = model(*inputs)
                # The output of the network is not sorted
                target = target.long().to(device)
                loss = criterion(soutput.F, target)

                if debug_enabled:
                    # Diagnostic only: the focal loss is logged but never
                    # optimized, so no autograd graph is built for it.
                    with torch.no_grad():
                        input_soft = nn.functional.softmax(soutput.F, dim=1) + eps
                        focal_weight = torch.pow(1. - input_soft, gamma)
                        focal_loss = (-alpha * focal_weight *
                                      torch.log(input_soft)).mean()
                    logging.debug(
                        "count of classes : %s",
                        np.unique(target.cpu().numpy(), return_counts=True))
                    logging.debug("target : %s\ntarget_len : %d", target,
                                  len(target))
                    logging.debug("input_soft[0] : %s", input_soft[0])
                    logging.debug("focal_weight[0] : %s", focal_weight[0])
                    logging.debug("focal_loss :%s\nloss : %s", focal_loss, loss)

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                logging.debug("batch_loss : %s", batch_loss)
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))
            # Record before the max_iter break so the per-epoch log below finds
            # an entry even when an epoch ends on its very first iteration.
            batch_losses[epoch] = batch_loss

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg,
                                  curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou, val_losses = validate(model, val_data_loader, writer,
                                                curr_iter, config,
                                                transform_data_fn, epoch)

                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            curr_iter += 1

        # One "loss, epoch" line per epoch (the `with` block closes the file).
        if epoch in batch_losses:
            with open(osp.join(config.log_dir, "train_loss.txt"),
                      'a') as train_loss_log:
                train_loss_log.write('{0}, {1}\n'.format(
                    batch_losses[epoch], epoch))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn, epoch)[0]
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
    # Flush and release the TensorBoard event file.
    writer.close()
Example #27
0
    def forward(self, x):
        """Run the encoder/decoder network over sparse tensor ``x``.

        Encoder: five conv -> BN -> ReLU stages (the last four followed by a
        residual ``block*``). Each stage output is kept as a skip connection.
        Decoder: four transposed-conv -> BN -> ReLU stages. Each one is
        concatenated (``me.cat``) with the matching encoder output and refined
        by a ``block*``. Then ``self.final`` is applied.

        If ``self.normalize_feature`` is set, every output feature row is
        scaled to unit L2 norm. The coordinates are shared with the final
        output via its ``coords_key`` / ``coords_man``.
        """
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)

        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)

        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)

        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)

        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        encoder_out = self.block4(out)

        out = self.convtr4p16s2(encoder_out)
        out = self.bntr4(out)
        out = self.relu(out)

        out = me.cat(out, out_b3p8)
        out = self.block5(out)

        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)

        out = me.cat(out, out_b2p4)
        out = self.block6(out)

        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)

        out = me.cat(out, out_b1p2)
        out = self.block7(out)

        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)

        out = me.cat(out, out_p1)
        out = self.block8(out)

        out = self.final(out)

        if self.normalize_feature:
            # Clamp the norm so an all-zero feature row maps to zeros instead
            # of NaN (0 / 0). This is the same eps that
            # torch.nn.functional.normalize uses.
            norm = torch.norm(out.F, p=2, dim=1, keepdim=True).clamp(min=1e-12)
            return SparseTensor(out.F / norm,
                                coords_key=out.coords_key,
                                coords_manager=out.coords_man)
        return out
Example #28
0
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` with gradient accumulation, auxiliary regularizers,
    optional SAM, optional LR warm-up, periodic validation and checkpointing.

    Each outer iteration runs ``config.iter_size`` sub-iterations. Each
    sub-iteration does a forward pass and computes cross-entropy. It then adds
    any ``vq_loss`` / ``diverse_loss`` / ``label_reg`` terms exposed by the
    model's ``block*`` children, and back-propagates. After the sub-iterations
    comes one optimizer step (``first_step``/``second_step`` when
    ``config.use_sam``) and one scheduler step. A running per-class train IoU
    is logged every ``config.stat_freq`` iterations and once per epoch.

    Side effects: ``config.normalize_color``, ``config.xyz_input`` and
    ``config.val_freq`` are mutated in place depending on ``config.dataset``
    and on training progress.

    Args:
        model: network whose output's ``.F`` holds per-point logits.
        data_loader: training loader. A single iterator is created and drawn
            from across epochs. NOTE(review): this presumes the loader never
            exhausts (e.g. an infinite sampler); confirm.
        val_data_loader: loader used by ``validate`` and ``test``.
        config: training options namespace.
        transform_data_fn: forwarded unchanged to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``weights.pth`` is missing.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if not osp.isfile(checkpoint_fn):
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))
        logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
        state = torch.load(checkpoint_fn)
        curr_iter = state['iteration'] + 1
        epoch = state['epoch']
        # Skip attention maps: their shape depends on the voxel count and will
        # not match, e.g. copying a param with shape (23385, 8, 4) into
        # (43529, 8, 4). Keep the model's own values for those keys.
        current_state = model.state_dict()
        d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
        for k, v in current_state.items():
            d.setdefault(k, v)
        model.load_state_dict(d)
        if config.resume_optimizer:
            scheduler = initialize_scheduler(optimizer,
                                             config,
                                             last_step=curr_iter)
            optimizer.load_state_dict(state['optimizer'])
        if 'best_val' in state:
            best_val_miou = state['best_val']
            best_val_iter = state['best_val_iter']
        logging.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_fn, state['epoch']))
        # Evaluate the *restored* weights. test() puts the model in eval mode,
        # so switch back before training resumes (BN must be in train mode).
        test(model, val_data_loader, config)
        model.train()

    data_iter = iter(data_loader)
    # Per-dataset class count and config overrides. val_freq_ keeps the
    # user-configured value from before any dataset-specific scaling.
    val_freq_ = config.val_freq
    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        config.val_freq = config.val_freq * 10
    elif config.dataset == "S3DIS":
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        config.val_freq = config.val_freq * 50
    else:
        num_class = 20

    # Auxiliary terms read from m[0] of every child whose name contains
    # 'block' (m is an nn.Sequential). Each pair is (attribute, info-log
    # template or None).
    reg_terms = (
        ('vq_loss', 'Cur Loss: {}, Cur vq_loss: {}'),
        ('diverse_loss', 'Cur Loss: {}, Cur diverse _loss: {}'),
        ('label_reg', None),
    )

    while is_training:
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss, reg_loss = 0, 0, 0.
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                # NOTE(review): this sets the interval to twice the configured
                # value. It validates more often only for datasets whose
                # val_freq was scaled up (x10 / x50) above.
                config.val_freq = val_freq_ * 2

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, feats, target, _, _, _, _, _ = next(data_iter)
                else:
                    # ignore unique_map and inverse_map
                    coords, feats, target, _, _, _ = next(data_iter)

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / feats[:, :3].max() - 0.5

                # cat xyz into the rgb feature. Normalized xyz is computed here
                # so that xyz_input also works without normalize_color.
                if config.xyz_input:
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
                    feats = torch.cat([coords_norm, feats], dim=1)
                sinput = SparseTensor(feats, coords, device=device)
                starget = SparseTensor(
                    target.unsqueeze(-1).float(),
                    coordinate_map_key=sinput.coordinate_map_key,
                    coordinate_manager=sinput.coordinate_manager,
                    device=device
                )  # must share the same coord-manager to align for sinput

                data_time += data_timer.toc(False)

                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=curr_iter / config.max_iter,
                                    enable_point_branch=True)
                else:
                    # label-aux, feed it in as additional reg
                    soutput = model(
                        sinput, iter_=curr_iter / config.max_iter, aux=starget
                    )  # feed in the progress of training for annealing inside the model

                # The output of the network is not sorted
                target = target.view(-1).long().to(device)
                loss = criterion(soutput.F, target)

                # ====== other loss regs =====
                # Out-of-place adds: an in-place `loss += t` would fail when t
                # has shape [1], because the 0-dim loss cannot be broadcast-resized.
                reg_loss = 0.
                if hasattr(model, 'block1'):
                    for attr, log_fmt in reg_terms:
                        if getattr(model.block1[0], attr, None) is None:
                            continue
                        cur_loss = sum(
                            getattr(child[0], attr)
                            for name, child in model.named_children()
                            if 'block' in name)
                        if log_fmt is not None:
                            logging.info(log_fmt.format(loss, cur_loss))
                        loss = loss + cur_loss
                        reg_loss += float(cur_loss)

                # Compute and accumulate gradient
                loss = loss / config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                optimizer.first_step(zero_grad=True)
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None or curr_iter >= config.lr_warmup:
                scheduler.step()
            else:
                # Linear warm-up keyed to the global iteration. After warm-up,
                # the scheduler alone controls the learning rate.
                for g in optimizer.param_groups:
                    g['lr'] = config.lr * curr_iter / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            # Sum of all auxiliary regularizers from the last sub-iteration.
            regs.update(reg_loss, target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for cls in range(num_class):
                total_correct_class[cls] += ((pred == cls) & (target == cls)).sum()
                total_iou_deno_class[cls] += (((pred == cls) & (target != -1)) |
                                              (target == cls)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir.split('/')[-2], epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    _, _, _, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
Example #29
0
def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
    """Evaluate ``model`` over every batch of ``data_loader``.

    Puts the model in eval mode and leaves it there; callers that continue
    training must call ``model.train()``. Inference runs under
    ``torch.no_grad()``. Predictions are saved when ``config.save_prediction``
    or ``config.test_original_pointcloud`` is set. With ``has_gt``, the
    function accumulates cross-entropy loss, precision@1, a confusion
    histogram (for per-class IoU) and per-class average precision.

    Args:
        model: network whose output's ``.F`` holds per-point logits.
        data_loader: evaluation loader. ``data_loader.dataset`` must provide
            ``NUM_LABELS``, ``reorder_result``, ``get_classnames``, and
            ``test_pointcloud`` when used.
        config: evaluation options namespace.
        transform_data_fn: unused; kept for interface compatibility.
        has_gt: whether batches carry ground-truth labels to score against.

    Returns:
        tuple: ``(mean loss, mean precision@1, nan-mean AP in percent,
        nan-mean IoU in percent)``. Without ground truth the AP entry is nan.

    Raises:
        ValueError: if the prediction directory is not empty.
        NotImplementedError: if ``config.evaluate_original_pointcloud`` is set
            together with ``has_gt``.
    """
    import shutil

    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores = AverageMeter(), AverageMeter()
    # Pre-initialized so the summary below works when has_gt is False (no
    # metric is ever computed) or the loader is empty (loop never runs).
    ious = np.zeros(num_labels)
    ap_class = np.full(num_labels, np.nan)
    iteration, data_time, iter_time = 0, 0., 0.
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))

    logging.info('===> Start testing')

    global_timer.tic()
    data_iter = iter(data_loader)
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    # A temp dir is used only when predictions are needed for point-cloud
    # testing but were not requested to be kept.
    use_temp_dir = config.test_original_pointcloud and not config.save_prediction
    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            if config.return_transformation:
                coords, feats, target, transformation = next(data_iter)
            else:
                coords, feats, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()

            if config.wrapper_type != 'None':
                color = feats[:, :3].int()
            if config.normalize_color:
                feats[:, :3] = feats[:, :3] / 255. - 0.5
            sinput = SparseTensor(feats, coords).to(device)

            # Feed forward
            inputs = (sinput, ) if config.wrapper_type == 'None' else (sinput,
                                                                       coords,
                                                                       color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset, config,
                                 iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    # Scoring in the original point-cloud space is not supported.
                    raise NotImplementedError('pointcloud')

                target_np = target.numpy()

                num_sample = target_np.shape[0]

                target = target.to(device)

                cross_ent = criterion(output, target.long())
                losses.update(float(cross_ent), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(),
                                  target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration,
                           max_iter_unique,
                           data_time,
                           iter_time,
                           has_gt,
                           losses,
                           scores,
                           reordered_ious,
                           hist,
                           reordered_ap_class,
                           class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration,
               max_iter_unique,
               data_time,
               iter_time,
               has_gt,
               losses,
               scores,
               reordered_ious,
               hist,
               reordered_ap_class,
               class_names=class_names)

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)

    # The temp dir path is never exposed to callers, so remove it rather than leak it.
    if use_temp_dir:
        shutil.rmtree(save_pred_dir, ignore_errors=True)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    # ap_class is all-nan without ground truth; silence the empty-slice warning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_ap = np.nanmean(ap_class)
    return losses.avg, scores.avg, mean_ap, np.nanmean(
        per_class_iu(hist)) * 100
Example #30
0
def train_distill(model,
                  data_loader,
                  val_data_loader,
                  config,
                  transform_data_fn=None):
    '''Train ``model`` (the student) with distillation from a frozen teacher.

    A pretrained teacher network is loaded from a checkpoint and always run
    under ``torch.no_grad()``.  In single-stage mode the student is optimized
    with cross-entropy plus ``distill_lambda * DistillLoss`` between the
    teacher's and the student's anchor activations (obtained with
    ``save_anchor=True``).  With ``TWO_STAGE`` enabled, the student is first
    trained on the distillation loss alone and then fine-tuned on
    cross-entropy alone.

    The loop periodically logs statistics, saves checkpoints, validates the
    student (keeping a ``"best_val"`` checkpoint), and at the end saves and
    tests the student.

    Args:
        model: student network; called as ``model(sinput)`` and, when
            distilling, as ``model(sinput, save_anchor=True)`` returning
            ``(output, anchor)``.
        data_loader: training loader.  Its iterator is created once and
            consumed across epochs, so it is presumably infinite (e.g. an
            infinite sampler) -- TODO confirm.
        val_data_loader: loader used for validation and the final test.
        config: experiment configuration.  Optional attributes
            ``distill_teacher_model`` and ``distill_teacher_ckpt`` override the
            default teacher architecture and checkpoint path.
        transform_data_fn: forwarded to ``validate``.

    Raises:
        FileNotFoundError: if the teacher checkpoint does not exist.
        NotImplementedError: if ``config.resume`` is set or auxiliary targets
            are requested.
    '''
    # Weight of the distillation term in the combined (single-stage) loss.
    distill_lambda = 0.67

    # TWO_STAGE=True: the student is first trained with the distillation loss
    # only (to match the teacher's activations), then fine-tuned with
    # cross-entropy only.  TWO_STAGE=False: train with the combined loss.
    TWO_STAGE = False
    # Fraction of ``config.max_iter`` spent in stage 1 when TWO_STAGE is set.
    STAGE_PERCENTAGE = 0.7
    # Number of initial global iterations trained without the distillation
    # term (it can be very large early on).  0 disables the warm-up.
    DISTILL_WARMUP_ITERS = 0
    # Number of classes of the teacher head and of the train-IoU statistics.
    num_class = 20

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        raise NotImplementedError(
            'Resuming is not supported by train_distill yet.')

    # Build and load the teacher.  The defaults reproduce the original
    # hard-coded setup.
    # NOTE(review): the teacher is built with 3 input channels but receives the
    # same ``sinput`` as the student, so ``config.xyz_input`` (xyz prepended
    # to the features) presumably breaks the teacher forward -- confirm.
    tch_model_name = getattr(config, 'distill_teacher_model', 'Res16UNet18A')
    checkpoint_fn = getattr(
        config, 'distill_teacher_ckpt',
        "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/Res18A/weights.pth"  # voxel-size: 0.05
    )
    tch_model_cls = load_model(tch_model_name)
    tch_model = tch_model_cls(3, num_class, config).to(device)

    if not osp.isfile(checkpoint_fn):
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(checkpoint_fn))
    logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
    state = torch.load(checkpoint_fn, map_location=device)
    # Drop every entry whose key contains 'map' before loading the weights.
    d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
    tch_model.load_state_dict(d)
    # The stored best score belongs to the teacher.  Report it, but do not
    # seed the student's best-val tracking with it: that would block the
    # student's "best_val" checkpoints and mislabel the final log.
    if 'best_val' in state:
        logging.info("=> teacher best_val {} at iter {}".format(
            state['best_val'], state.get('best_val_iter')))
    logging.info("=> loaded checkpoint '{}' (epoch {})".format(
        checkpoint_fn, state['epoch']))

    # test after loading the ckpt
    v_loss, v_score, v_mAP, v_mIoU = test(tch_model, val_data_loader, config)
    logging.info('Tch model tested, bes_miou: {}'.format(v_mIoU))
    # The teacher is frozen: force eval mode so its batch-norm running
    # statistics are never updated, whatever state ``test`` leaves it in.
    tch_model.eval()

    if TWO_STAGE:
        # Boundary in *global* iterations, so stage 1 happens once at the
        # start of training instead of at the start of every epoch.
        stage_boundary = int(config.max_iter * STAGE_PERCENTAGE)

    data_iter = iter(data_loader)
    while is_training:

        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        total_iteration = len(data_loader) // config.iter_size
        for _ in range(total_iteration):

            # Optional warm-up without the distillation term (see above).
            use_distill = curr_iter > DISTILL_WARMUP_ITERS

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            distill_loss = None
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, feats, target, _, _, _, _ = next(data_iter)
                else:
                    # ignore unique_map and inverse_map
                    coords, feats, target, _, _ = next(data_iter)

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / 255. - 0.5

                # Concatenate normalized xyz in front of the color features.
                # Computed here so xyz_input works without normalize_color.
                if config.xyz_input:
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
                    feats = torch.cat([coords_norm, feats], dim=1)

                sinput = SparseTensor(feats, coords, device=device)

                data_time += data_timer.toc(False)
                if aux is not None:
                    raise NotImplementedError

                # flatten ground truth tensor
                target = target.view(-1).long().to(device)

                if TWO_STAGE:
                    if curr_iter <= stage_boundary:
                        # Stage 1: train the student on the distillation loss only.
                        soutput, anchor = model(sinput, save_anchor=True)
                        # Make sure gradients don't flow into the teacher.
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                        loss = DistillLoss(tch_anchor, anchor)
                    else:
                        # Stage 2: fine-tune the student on cross-entropy only.
                        soutput = model(sinput)
                        loss = criterion(soutput.F, target)
                else:
                    if use_distill:
                        soutput, anchor = model(sinput, save_anchor=True)
                        # The teacher is pretrained: keep its params out of the graph.
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                    else:  # warming up
                        soutput = model(sinput)
                    # The output of the network is not sorted
                    loss = criterion(soutput.F, target)
                    # Add the weighted distillation loss after warm-up.
                    if use_distill:
                        distill_loss = DistillLoss(tch_anchor,
                                                   anchor) * distill_lambda
                        loss = loss + distill_loss

                # Compute and accumulate gradient
                loss = loss / config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            # Use the same ignore label as the loss for every metric.
            score = precision_at_one(pred, target,
                                     ignore_label=config.ignore_label)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Accumulate per-class intersection / union for the train IoU.
            valid = target != config.ignore_label
            for cls in range(num_class):
                pred_is_cls = pred == cls
                target_is_cls = target == cls
                total_correct_class[cls] += (pred_is_cls & target_is_cls).sum()
                total_iou_deno_class[cls] += ((pred_is_cls & valid) |
                                              target_is_cls).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir, epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Only set in single-stage mode after warm-up.
                if distill_loss is not None:
                    logging.info('Loss {} Distill Loss:{}'.format(
                        loss, distill_loss))
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occational mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))