def test_gpu(self):
    """Gradcheck ``MinkowskiConvolutionTransposeFunction`` on a CUDA device.

    Returns immediately, without asserting anything, when CUDA is unavailable.
    """
    print(f"{self.__class__.__name__}: test_gpu")
    if not torch.cuda.is_available():
        return

    cuda = torch.device('cuda')
    in_channels, out_channels, D = 2, 3, 2
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coords=coords).to(cuda)

    # Initialize context: a strided convolution and the transposed convolution
    # that undoes its stride, both moved to the GPU and cast to double.
    conv = MinkowskiConvolution(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        has_bias=True,
        dimension=D,
    ).to(cuda).double()
    conv_tr = MinkowskiConvolutionTranspose(
        out_channels,
        in_channels,
        kernel_size=3,
        stride=2,
        has_bias=True,
        dimension=D,
    ).to(cuda).double()

    strided = conv(sparse_in)
    output = conv_tr(strided)
    print(output)

    # Check backward
    fn = MinkowskiConvolutionTransposeFunction()
    gradcheck_inputs = (
        strided.F,
        conv_tr.kernel,
        strided.tensor_stride,
        conv_tr.stride,
        conv_tr.kernel_size,
        conv_tr.dilation,
        conv_tr.region_type_,
        conv_tr.region_offset_,
        False,
        strided.coords_key,
        None,
        strided.coords_man,
    )
    self.assertTrue(gradcheck(fn, gradcheck_inputs))
def forward(self, x):
    """Run ``self.meanfield_iterations`` mean-field steps on the features of ``x``.

    Each step applies ``self.softmaxes[i]``, then ``self.convs[i].apply`` with
    the kernel and geometry of ``self.conv``, and finally adds the unary
    features back in place. When ``self.requires_mapping`` is set, the
    features are multiplied by ``self.in_mapping`` before the iterations and
    by ``self.out_mapping`` afterwards.

    Returns:
        A ``SparseTensor`` built on the coordinate key and manager of ``x``.
    """
    unary = x.F
    if self.requires_mapping:
        # Map the network output to CRF input
        unary = SparseMM()(Variable(self.in_mapping), unary)

    out = unary
    for step in range(self.meanfield_iterations):
        # Meanfield iteration
        # Normalization
        out = self.softmaxes[step](out)
        # Pairwise potential
        out = self.convs[step].apply(
            out,
            self.conv.kernel,
            x.pixel_dist,
            self.conv.stride,
            self.conv.kernel_size,
            self.conv.dilation,
            self.region_type_,
            self.region_offset_,
            x.coords_key,
            x.coords_key,
            x.C,
        )
        # Add unary
        out += unary

    if self.requires_mapping:
        # Map the CRF output to the original space
        out = SparseMM()(Variable(self.out_mapping), out)

    return SparseTensor(out, coords_key=x.coords_key, coords_manager=x.C)
def test_decomposition_gpu(self):
    """Check that a GPU sparse tensor decomposes into one entry per batch item.

    The point cloud loaded from ``1.ply`` is replicated ``batch_size`` times;
    the decomposed coordinate and feature lists must both have that length.
    Returns immediately when CUDA is unavailable.
    """
    print(f"{self.__class__.__name__}: test_decomposition_gpu")
    if not torch.cuda.is_available():
        return

    coords, colors, pcd = load_file("1.ply")
    colors = torch.from_numpy(colors)

    for batch_size in (5, 10, 20, 40):
        for voxel_size in (0.02,):
            # Quantize the raw coordinates to integer voxel indices.
            voxel_coords = torch.from_numpy(np.floor(coords / voxel_size)).int()
            batched = batched_coordinates([voxel_coords] * batch_size)
            feats = torch.cat([colors] * batch_size, 0)
            sinput = SparseTensor(feats.to(0), batched.to(0))

            parts = sinput.decomposed_coordinates_and_features
            decomposed_coords, decomposed_feats = parts
            print([len(c) for c in decomposed_coords])
            print([len(f) for f in decomposed_feats])
            self.assertEqual(len(decomposed_coords), batch_size)
            self.assertEqual(len(decomposed_feats), batch_size)
def test_kernelmap_gpu(self):
    """Print the GPU kernel map of a stride-2 convolution and check its size.

    The total number of input->output pairs over all kernel offsets must
    be 26. Returns immediately when CUDA is unavailable.
    """
    print(f"{self.__class__.__name__}: test_kernelmap_gpu")
    if not torch.cuda.is_available():
        return

    in_channels, out_channels, D = 2, 3, 2
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coords=coords)

    manager = sparse_in.coords_man
    in_key = manager._get_coords_key(1)
    print('Input coords: ')
    manager.print_diagnostics(in_key)
    print('Convolution: ')

    # Initialize context
    conv = MinkowskiConvolution(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        has_bias=True,
        dimension=D,
    ).double()
    output = conv(sparse_in)

    in_coords = sparse_in.C.numpy()
    out_coords = output.C.numpy()
    print(in_coords)
    print(out_coords)

    in_maps, out_maps = output.coords_man.get_kernel_map(
        1, 2, stride=2, kernel_size=3, on_gpu=True)

    # One (in_map, out_map) pair per kernel offset.
    for kernel_index, (in_map, out_map) in enumerate(zip(in_maps, out_maps)):
        for src, dst in zip(in_map, out_map):
            print(kernel_index, in_coords[src], '->', out_coords[dst])

    pair_count = sum(len(in_map) for in_map in in_maps)
    self.assertTrue(pair_count == 26)
def test(self):
    """Convert a convolution output to a dense tensor and backpropagate.

    ``output.dense()`` is called once with default bounds and once with
    explicit ``min_coords``/``max_coords``; the sum of the second dense
    tensor is backpropagated and ``feats.grad`` is printed before and after.
    """
    print(f"{self.__class__.__name__}: test_dense")
    in_channels, out_channels, D = 2, 3, 2
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coords=coords)

    # Initialize context
    conv = MinkowskiConvolution(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        has_bias=True,
        dimension=D,
    ).double()
    output = conv(sparse_in)
    print(sparse_in.C, output.C)

    # Convert to a dense tensor
    dense_output, min_coord, tensor_stride = output.dense()
    print(dense_output.shape)
    for item in (dense_output, min_coord, tensor_stride):
        print(item)

    lower = torch.IntTensor([-2, -2])
    upper = torch.IntTensor([4, 4])
    dense_output, min_coord, tensor_stride = output.dense(
        min_coords=lower, max_coords=upper)
    for item in (dense_output, min_coord, tensor_stride):
        print(item)

    print(feats.grad)
    dense_output.sum().backward()
    print(feats.grad)
def test_extraction(self):
    """Check per-batch extraction of coordinates and features.

    The tensor holds three points in batch 0, none in batch 1 and two in
    batch 2. The decomposition checks are repeated on device 0 when
    ``is_cuda_available()`` reports a GPU.
    """

    def make_inputs():
        # First column is the batch index.
        points = torch.IntTensor([[0, 0], [0, 1], [0, 2], [2, 0], [2, 2]])
        values = torch.FloatTensor([[1.1, 2.1, 3.1, 4.1, 5.1]]).t()
        return points, values

    def check_decomposition(tensor):
        coords_list, feats_list = tensor.decomposed_coordinates_and_features
        for c, f in zip(coords_list, feats_list):
            self.assertEqual(c.numel(), f.numel())
            print(c, f)
        self.assertEqual(len(coords_list[0]), 3)
        self.assertEqual(len(coords_list[1]), 0)
        self.assertEqual(len(coords_list[2]), 2)

    coords, feats = make_inputs()
    X = SparseTensor(feats, coords)

    C0 = X.coordinates_at(0)
    F0 = X.features_at(0)
    for expected_coord in (0, 1, 2):
        self.assertTrue(expected_coord in C0)
    for expected_feat in (1.1, 2.1, 3.1):
        self.assertTrue(expected_feat in F0)

    CC0, FC0 = X.coordinates_and_features_at(0)
    self.assertTrue((C0 == CC0).all())
    self.assertTrue((F0 == FC0).all())

    check_decomposition(X)

    if not is_cuda_available():
        return

    coords, feats = make_inputs()
    check_decomposition(SparseTensor(feats, coords, device=0))
def test_with_convtr(self):
    """Chain two coordinate-generating transposed convolutions with pruning.

    After each transposed convolution a random boolean mask (``rand < 0.5``)
    is passed to ``MinkowskiPruning``; the final features are summed and
    backpropagated, and the gradient on the input features is printed.
    """
    channels, D = [2, 3, 4], 2
    coords, feats, _ = data_loader(channels[0], batch_size=1)
    feats = feats.double()
    feats.requires_grad_()

    # Create a sparse tensor with large tensor strides for upsampling
    start_tensor_stride = 4
    sparse_in = SparseTensor(
        feats,
        coords=coords * start_tensor_stride,
        tensor_stride=start_tensor_stride,
    )

    def make_conv_tr(in_ch, out_ch):
        return MinkowskiConvolutionTranspose(
            in_ch,
            out_ch,
            kernel_size=3,
            stride=2,
            generate_new_coords=True,
            dimension=D,
        ).double()

    conv_tr1 = make_conv_tr(channels[0], channels[1])
    conv_tr2 = make_conv_tr(channels[1], channels[2])
    pruning = MinkowskiPruning(D)

    out1 = conv_tr1(sparse_in)
    keep_mask = torch.rand(len(out1)) < 0.5
    out1 = pruning(out1, keep_mask)

    out2 = conv_tr2(out1)
    keep_mask = torch.rand(len(out2)) < 0.5
    out2 = pruning(out2, keep_mask)
    print(out2)

    out2.F.sum().backward()

    # Check gradient flow
    print(sparse_in.F.grad)
def test(self):
    """Gradcheck ``MinkowskiGlobalPoolingFunction`` through global average pooling."""
    in_channels, D = 2, 2
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coords)

    pool = MinkowskiGlobalAvgPooling()
    output = pool(sparse_in)
    print(output)

    # Check backward
    fn = MinkowskiGlobalPoolingFunction()
    gradcheck_inputs = (
        sparse_in.F,
        pool.pooling_mode,
        sparse_in.coordinate_map_key,
        output.coordinate_map_key,
        sparse_in._manager,
    )
    self.assertTrue(gradcheck(fn, gradcheck_inputs))
def test_broadcast(self):
    """Run every broadcast module forward and gradcheck the broadcast function.

    A globally pooled copy of the input is broadcast back onto the input with
    the plain, concatenation, addition and multiplication modules. The
    backward pass of ``MinkowskiBroadcastFunction`` is then checked for the
    ADDITION and MULTIPLICATION operation types.
    """
    in_channels, D = 2, 2
    # Load the data once; a second load whose features were never used has
    # been dropped.
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    sparse_in = SparseTensor(feats, coords=coords)

    pool = MinkowskiGlobalPooling()
    pooled = pool(sparse_in)
    pooled.F.requires_grad_()

    broadcast = MinkowskiBroadcast()
    broadcast_cat = MinkowskiBroadcastConcatenation()
    broadcast_add = MinkowskiBroadcastAddition()
    broadcast_mul = MinkowskiBroadcastMultiplication()

    for module in (broadcast, broadcast_cat, broadcast_add, broadcast_mul):
        output = module(sparse_in, pooled)
        print(output)

    # Check backward
    fn = MinkowskiBroadcastFunction()
    for op_type in (OperationType.ADDITION, OperationType.MULTIPLICATION):
        gradcheck_inputs = (
            sparse_in.F,
            pooled.F,
            op_type,
            sparse_in.coords_key,
            pooled.coords_key,
            sparse_in.coords_man,
        )
        self.assertTrue(gradcheck(fn, gradcheck_inputs))
def test_analytic(self):
    """Run a convolution and a transposed convolution with hand-set kernels.

    Both layers use kernel size 2, stride 2 and no bias, and are given the
    same fixed kernel values so their printed outputs can be checked by hand.
    """
    print(f"{self.__class__.__name__}: test")
    in_channels, out_channels, D = 2, 2, 2
    coords = torch.IntTensor([[0, 0, 0], [0, 1, 1], [0, 2, 1]])
    feats = torch.FloatTensor([[0, 1], [1, 0], [1, 1]])
    input = SparseTensor(feats, coordinates=coords)

    kernel_values = torch.FloatTensor(
        [[[1, 2], [2, 1]], [[0, 1], [1, 0]], [[0, 1], [1, 1]], [[1, 1], [1, 0]]]
    )

    # Initialize context
    conv = MinkowskiConvolution(
        in_channels, out_channels, kernel_size=2, stride=2, bias=False, dimension=D
    )
    # The kernel is presumably a leaf parameter that requires grad; writing
    # into it in place is only allowed with autograd recording disabled.
    with torch.no_grad():
        conv.kernel[:] = kernel_values
    output = conv(input)
    print(output)

    conv_tr = MinkowskiConvolutionTranspose(
        in_channels, out_channels, kernel_size=2, stride=2, bias=False, dimension=D
    )
    with torch.no_grad():
        conv_tr.kernel[:] = kernel_values
    output_tr = conv_tr(output)
    print(output_tr)
def test_zero(self):
    """Interpolate at original and at all-zero sample points, then backpropagate.

    Regression test for
    Issue #383 https://github.com/NVIDIA/MinkowskiEngine/issues/383

    Everything is allocated on ``'cuda'``, so the test returns immediately
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    # create point and features, all with batch 0
    pc = torch.randint(-10, 10, size=(32, 4), dtype=torch.float32, device='cuda')
    pc[:, 0] = 0
    feat = torch.randn(32, 3, dtype=torch.float32, device='cuda', requires_grad=True)

    # feature to interpolate
    x = SparseTensor(feat, pc, device='cuda')
    interp = MinkowskiInterpolation()

    # samples with original coordinates, OK for now
    samples = pc
    y = interp(x, samples)
    print(y.shape, y.stride())
    torch.sum(y).backward()

    # samples with all zeros, shape is inconsistent and backward gives error
    samples = torch.zeros_like(pc)
    samples[:, 0] = 0
    y = interp(x, samples)
    print(y.shape, y.stride())
    torch.sum(y).backward()
def test(self):
    """Run a channelwise convolution forward and backward on the CPU.

    Prints the input, the output and the gradient that reaches the input
    features after summing the output features.
    """
    print(f"{self.__class__.__name__}: test")
    in_channels, D = 3, 2
    coords, feats, _ = data_loader(in_channels, batch_size=2)

    # Create random coordinates with tensor stride == 2
    # (the returned values are not used further in this test).
    out_coords, tensor_stride = get_random_coords()

    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coords=coords)

    conv = MinkowskiChannelwiseConvolution(
        in_channels,
        kernel_size=3,
        stride=1,
        has_bias=False,
        dimension=D,
    ).double()

    print('Initial input: ', sparse_in)
    output = conv(sparse_in)
    print('Conv output: ', output)

    output.F.sum().backward()
    print(sparse_in.F.grad)
def test(self):
    """Gradcheck ``MinkowskiLocalPoolingFunction`` through a strided max pooling."""
    in_channels, D = 2, 2
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coordinates=coords)

    pool = MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)
    output = pool(sparse_in)
    print(output)

    # Check backward
    fn = MinkowskiLocalPoolingFunction()
    gradcheck_inputs = (
        sparse_in.F,
        pool.pooling_mode,
        pool.kernel_generator,
        sparse_in.coordinate_map_key,
        output.coordinate_map_key,
        sparse_in._manager,
    )
    self.assertTrue(gradcheck(fn, gradcheck_inputs))
def test_gpu(self):
    """Run a channelwise convolution forward on a CUDA device.

    Only the forward pass is exercised; the input and output are printed.
    Returns immediately when CUDA is unavailable.
    """
    print(f"{self.__class__.__name__}: test_gpu")
    if not torch.cuda.is_available():
        return

    cuda = torch.device('cuda')
    in_channels, D = 3, 2
    coords, feats, _ = data_loader(in_channels, batch_size=2)

    # Create random coordinates with tensor stride == 2
    # (the returned values are not used further in this test).
    out_coords, tensor_stride = get_random_coords()

    feats = feats.double()
    feats.requires_grad_()
    sparse_in = SparseTensor(feats, coords=coords).to(cuda)

    conv = MinkowskiChannelwiseConvolution(
        in_channels,
        kernel_size=3,
        stride=1,
        has_bias=False,
        dimension=D,
    ).double().to(cuda)

    print('Initial input: ', sparse_in)
    output = conv(sparse_in)
    print('Conv output: ', output)
def test(self):
    """Build a ``SparseTensor`` from the two-channel sample data and print it."""
    print(f"{self.__class__.__name__}: test SparseTensor")
    coords, feats, _ = data_loader(nchannel=2)
    sparse = SparseTensor(feats, coordinates=coords)
    print(sparse)
def test_operation_mode(self):
    """Exercise arithmetic between sparse tensors that share a coordinate manager.

    Two tensors on different coordinates are added and backpropagated (every
    feature gradient must equal 1); the remaining binary and in-place
    operators are then exercised for errors only.
    """
    # Set to use the global sparse tensor coords manager by default
    # NOTE(review): the mode is not restored afterwards - confirm that later
    # tests do not depend on the previous setting.
    set_sparse_tensor_operation_mode(
        SparseTensorOperationMode.SHARE_COORDINATE_MANAGER)

    coords, feats, _ = data_loader(nchannel=2)

    # Create a sparse tensor on two different coordinates.
    A = SparseTensor(torch.rand(feats.shape), coordinates=coords)
    other_coords = torch.IntTensor([[0, 0, 0], [1, 1, 1], [0, 1, 0], [1, 0, 1]])
    B = SparseTensor(torch.rand(4, 2), coordinates=other_coords)

    self.assertTrue(A.coordinate_manager == B.coordinate_manager)

    A.requires_grad_(True)
    B.requires_grad_(True)

    total = A + B
    total.F.sum().backward()
    for operand in (A, B):
        self.assertTrue(torch.all(operand.F.grad == 1).item())

    # Results are discarded; these only need to run without raising.
    _ = A - B
    _ = A * B
    _ = A / B

    # Inplace
    A.requires_grad_(False)
    D = SparseTensor(
        torch.rand(feats.shape),
        coordinate_map_key=A.coordinate_map_key,
        coordinate_manager=A.coordinate_manager,
    )
    A -= D
    A *= D
    A /= D
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` until ``config.max_iter`` iterations, validating periodically.

    Gradients are accumulated over ``config.iter_size`` batches per optimizer
    step. Scalars are logged every ``config.stat_freq`` iterations, a
    checkpoint is written every ``config.save_freq`` iterations, and
    ``validate`` runs every ``config.val_freq`` iterations; the best
    validation mIoU is saved as ``"best_val"``. A final checkpoint and
    validation pass run after the loop ends.

    Args:
        model: network called as ``model(SparseTensor)``; its output exposes ``.F``.
        data_loader: training loader; ``len(data_loader)`` and ``.dataset`` are used.
        val_data_loader: loader forwarded to ``validate``.
        config: options object (``is_cuda``, ``log_dir``, ``ignore_label``,
            ``resume``, ``resume_optimizer``, ``iter_size``,
            ``return_transformation``, ``normalize_color``, ``max_iter``,
            ``stat_freq``, ``save_freq``, ``val_freq``).
        transform_data_fn: forwarded unchanged to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``<resume>/weights.pth``
            does not exist.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    # One writer for the whole run; it is also handed to validate().
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            model.load_state_dict(state['state_dict'])
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    # A single iterator is reused across epochs.
    data_iter = data_loader.__iter__()
    while is_training:
        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                # The built-in next() works on every iterator; the ``.next()``
                # method is not available on newer DataLoader iterators.
                if config.return_transformation:
                    coords, input, target, pointcloud, transformation = next(data_iter)
                else:
                    coords, input, target = next(data_iter)

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                sinput = SparseTensor(input, coords).to(device)

                data_time += data_timer.toc(False)

                # model.initialize_coords(*init_args)
                soutput = model(sinput)
                # The output of the network is not sorted
                target = target.long().to(device)

                loss = criterion(soutput.F, target.long())

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            # The score is computed on the last sub-batch of this iteration.
            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, writer, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` (optionally across several processes) until ``config.max_iter``.

    Gradients are accumulated over ``config.iter_size`` batches per optimizer
    step. When ``get_world_size() > 1`` the per-batch loss and score are
    gathered from all workers and averaged for logging. Forward, backward and
    gather times are tracked separately. Scalars are logged every
    ``config.stat_freq`` iterations, checkpoints are written every
    ``config.save_freq`` iterations, ``validate`` runs every
    ``config.val_freq`` iterations and the CUDA cache is emptied every
    ``config.empty_cache_freq`` iterations.

    Args:
        model: network called as ``model(*inputs)``; its output exposes ``.F``.
        data_loader: training loader; ``len(data_loader)`` and ``.dataset`` are used.
        val_data_loader: loader forwarded to ``validate``.
        config: options object (``device_id``, ``log_dir``, ``batch_size``,
            ``ignore_label``, ``resume``, ``resume_optimizer``, ``iter_size``,
            ``normalize_color``, ``wrapper_type``, ``max_iter``, ``stat_freq``,
            ``save_freq``, ``val_freq``, ``empty_cache_freq``).
        transform_data_fn: forwarded unchanged to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``<resume>/weights.pth``
            does not exist.
    """
    device = config.device_id
    distributed = get_world_size() > 1

    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    # One writer for the whole run; it is also handed to validate().
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer()

    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    fw_time_avg, bw_time_avg, ddp_time_avg = (
        AverageMeter(), AverageMeter(), AverageMeter())

    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training on {} GPUs, batch-size={}'.format(
        get_world_size(), config.batch_size * get_world_size()))
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            # Restore every storage on the CPU first.
            state = torch.load(
                checkpoint_fn,
                map_location=lambda s, l: default_restore_location(s, 'cpu'))
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            load_state(model, state['state_dict'])

            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()  # (distributed) infinite sampler
    while is_training:
        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss, batch_score = 0, 0, 0
            iter_timer.tic()

            # set random seed for every iteration for trackability
            _set_seed(config, curr_iter)

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                # The built-in next() works on every iterator; the ``.next()``
                # method is not available on newer DataLoader iterators.
                coords, input, target = next(data_iter)

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                # Keep the integer colors before the features are normalized.
                color = input[:, :3].int()
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                sinput = SparseTensor(input, coords).to(device)

                data_time += data_timer.toc(False)

                # Feed forward
                fw_timer.tic()

                if config.wrapper_type == 'None':
                    inputs = (sinput, )
                else:
                    inputs = (sinput, coords, color)
                # model.initialize_coords(*init_args)
                soutput = model(*inputs)
                # The output of the network is not sorted
                target = target.long().to(device)

                loss = criterion(soutput.F, target.long())

                # Compute and accumulate gradient
                loss /= config.iter_size

                pred = get_prediction(data_loader.dataset, soutput.F, target)
                score = precision_at_one(pred, target)

                fw_timer.toc(False)

                bw_timer.tic()

                # bp the loss
                loss.backward()

                bw_timer.toc(False)

                # gather information
                logging_output = {
                    'loss': loss.item(),
                    'score': score / config.iter_size
                }

                ddp_timer.tic()
                if distributed:
                    # Average each logged value over all workers.
                    logging_output = all_gather_list(logging_output)
                    logging_output = {
                        w: np.mean([a[w] for a in logging_output])
                        for w in logging_output[0]
                    }

                batch_loss += logging_output['loss']
                batch_score += logging_output['score']
                ddp_timer.toc(False)

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))
            fw_time_avg.update(fw_timer.diff)
            bw_time_avg.update(bw_timer.diff)
            ddp_time_avg.update(ddp_timer.diff)

            losses.update(batch_loss, target.size(0))
            scores.update(batch_score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Forward time: {:.4f}, Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, fw_time_avg.avg,
                    bw_time_avg.avg, ddp_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                # NOTE(review): the forward/backward/DDP meters are not reset
                # here, unlike the data and iteration meters - confirm intent.
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg, curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, writer, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            # End of iteration
            curr_iter += 1

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def train_worker(gpu, num_devices, NetClass, data_loader, val_data_loader,
                 config, transform_data_fn=None):
    """Per-process entry point for distributed training on device ``gpu``.

    Joins an NCCL process group (rank = ``gpu``), builds the model from
    ``NetClass``, optionally loads weights, wraps the model in
    ``DistributedDataParallel`` and runs the training loop until
    ``config.max_iter``. Checkpointing and validation inside the loop run on
    rank 0 only; a running training IoU is accumulated per epoch.

    Args:
        gpu: CUDA device index of this process; also used as its rank.
        num_devices: world size passed to ``init_process_group``.
        NetClass: model class; its constructor signature depends on
            ``config.pure_point`` / ``config.model``.
        data_loader: training loader; rebuilt with ``InfSampler`` when
            ``config.multiprocess`` is set.
        val_data_loader: loader used by ``validate`` and ``test``.
        config: options object; note that ``val_freq``, ``normalize_color``
            and ``xyz_input`` are overwritten for some datasets.
        transform_data_fn: forwarded unchanged to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but ``<resume>/weights.pth``
            does not exist.
    """
    if gpu is not None:
        print("Use GPU: {} for training".format(gpu))
    rank = gpu

    addr = 23491
    dist.init_process_group(backend="nccl",
                            init_method="tcp://127.0.0.1:{}".format(addr),
                            world_size=num_devices,
                            rank=rank)

    # replace with DistributedSampler
    if config.multiprocess:
        from lib.dataloader_dist import InfSampler
        sampler = InfSampler(data_loader.dataset)
        data_loader = DataLoader(dataset=data_loader.dataset,
                                 num_workers=data_loader.num_workers,
                                 batch_size=data_loader.batch_size,
                                 collate_fn=data_loader.collate_fn,
                                 worker_init_fn=data_loader.worker_init_fn,
                                 sampler=sampler)

    if data_loader.dataset.NUM_IN_CHANNEL is not None:
        num_in_channel = data_loader.dataset.NUM_IN_CHANNEL
    else:
        num_in_channel = 3
    num_labels = data_loader.dataset.NUM_LABELS

    # load model
    if config.pure_point:
        model = NetClass(num_class=config.num_labels,
                         N=config.num_points,
                         normal_channel=config.num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config,
                             num_class=num_labels,
                             N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model == 'MinkowskiVoxelTransformer':
            model = NetClass(config, num_in_channel, num_labels)
        elif config.model == 'MinkowskiTransformerNet':
            model = NetClass(config, num_in_channel, num_labels)
        elif "Res" in config.model:
            model = NetClass(num_in_channel, num_labels, config)
        else:
            model = NetClass(num_in_channel, num_labels, config)

    if config.weights == 'modelzoo':
        model.preload_modelzoo()
    elif config.weights.lower() != 'none':
        state = torch.load(config.weights)
        # delete the keys containing the attn since it raises size mismatch
        d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=False)

    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    # use model with DDP
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=False)
    # Synchronized batch norm
    model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)

    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    if rank == 0:
        setup_logger(config)
        logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test loaded ckpt first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # we skip attention maps because the shape won't match because voxel number is different
            # e.g. copying a param with shape (23385, 8, 4) to (43529, 8, 4)
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            # handle those attn maps we don't load from saved dict
            for k in model.state_dict().keys():
                if k in d.keys():
                    continue
                d[k] = model.state_dict()[k]
            model.load_state_dict(d)

            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
    device = gpu  # multitrain fed in the device

    # Dataset-specific class count and validation frequency; val_freq_ keeps
    # the value config.val_freq had before it was scaled.
    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 10  # original val_freq_
    elif config.dataset == 'S3DIS':
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 50
    else:
        val_freq_ = config.val_freq
        num_class = 20

    while is_training:
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                # if curr_iter >= max(config.max_iter, config.epochs*(len(data_loader) // config.iter_size):
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                config.val_freq = val_freq_ * 2  # valid more freq on lower half

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                # The built-in next() works on every iterator; the ``.next()``
                # method is not available on newer DataLoader iterators.
                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation = next(data_iter)
                else:
                    # ignore unique_map and inverse_map
                    coords, input, target, _, _ = next(data_iter)

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    # Computed here so it is defined whenever it is used.
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)

                data_time += data_timer.toc(False)

                # model.initialize_coords(*init_args)
                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=curr_iter / config.max_iter,
                                    enable_point_branch=True)
                else:
                    # feed in the progress of training for annealing inside the model
                    soutput = model(sinput, iter_=curr_iter / config.max_iter)
                # The output of the network is not sorted
                target = target.view(-1).long().to(device)

                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                cur_loss = torch.tensor([0.], device=device)
                # NOTE(review): hasattr() does not resolve dotted paths, so this
                # looks up an attribute literally named 'module.block1' and is
                # presumably always False, leaving the regularizers below
                # unused. The body also mixes model.module.block1 and
                # model.block1 - confirm the intent before enabling it.
                if hasattr(model, 'module.block1'):
                    cur_loss = torch.tensor([0.], device=device)

                    if hasattr(model.module.block1[0], 'vq_loss'):
                        if model.block1[0].vq_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    # m is the nn.Sequential obj, m[0] is the TRBlock
                                    cur_loss += m[0].vq_loss
                            logging.info(
                                'Cur Loss: {}, Cur vq_loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.module.block1[0], 'diverse_loss'):
                        if model.block1[0].diverse_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    # m is the nn.Sequential obj, m[0] is the TRBlock
                                    cur_loss += m[0].diverse_loss
                            logging.info(
                                'Cur Loss: {}, Cur diverse _loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.module.block1[0], 'label_reg'):
                        if model.block1[0].label_reg is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    # m is the nn.Sequential obj, m[0] is the TRBlock
                                    cur_loss += m[0].label_reg
                            loss += cur_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                if not config.use_sam:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                optimizer.first_step(zero_grad=True)
                # NOTE(review): ``starget`` is not defined anywhere in this
                # function, so this branch raises NameError - confirm which
                # value (possibly ``aux``) was intended.
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            else:
                if curr_iter >= config.lr_warmup:
                    scheduler.step()
                else:
                    # Linear warm-up of the learning rate.
                    for g in optimizer.param_groups:
                        g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for cls_idx in range(num_class):
                total_correct_class[cls_idx] += (
                    (pred == cls_idx) & (target == cls_idx)).sum()
                total_iou_deno_class[cls_idx] += (
                    ((pred == cls_idx) & (target != -1)) |
                    (target == cls_idx)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(g['lr']) for g in optimizer.param_groups])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)
                if rank == 0:
                    logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # only save status on the 1st gpu
            if rank == 0:
                # Save current status, save before val to prevent occasional mem overflow
                if curr_iter % config.save_freq == 0:
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               save_inter=True)

                # Validation
                if curr_iter % config.val_freq == 0:
                    # feed in None for the SummaryWriter argument
                    val_miou = validate(model, val_data_loader, None,
                                        curr_iter, config, transform_data_fn)
                    if val_miou > best_val_miou:
                        best_val_miou = val_miou
                        best_val_iter = curr_iter
                        checkpoint(model,
                                   optimizer,
                                   epoch,
                                   curr_iter,
                                   config,
                                   best_val_miou,
                                   best_val_iter,
                                   "best_val",
                                   save_inter=True)
                    if rank == 0:
                        logging.info(
                            "Current best mIoU: {:.3f} at iter {}".format(
                                best_val_miou, best_val_iter))

                    # Recover back
                    model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        if rank == 0:
            logging.info('train point avg class IoU: %f' %
                         ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    if rank == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter)

    # Bind the final test mIoU to the name compared below; binding it to a
    # differently-cased name would leave the comparison on a stale or
    # undefined value.
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)

    if val_miou > best_val_miou and rank == 0:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        logging.info("Final best miou: {} at iter {} ".format(
            val_miou, curr_iter))
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")

    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def test_unpooling_gpu(self): if not torch.cuda.is_available(): return in_channels, out_channels, D = 2, 3, 2 coords, feats, labels = data_loader(in_channels) feats = feats.double() input = SparseTensor(feats, coords=coords) conv = MinkowskiConvolution( in_channels, out_channels, kernel_size=3, stride=2, dimension=D ) conv = conv.double() unpool = MinkowskiPoolingTranspose(kernel_size=3, stride=2, dimension=D) input = conv(input) output = unpool(input) print(output) # Check backward fn = MinkowskiPoolingTransposeFunction() self.assertTrue( gradcheck( fn, ( input.F, input.tensor_stride, unpool.stride, unpool.kernel_size, unpool.dilation, unpool.region_type_, unpool.region_offset_, False, input.coords_key, None, input.coords_man, ), ) ) device = torch.device("cuda") with torch.cuda.device(0): input = input.to(device) output = unpool(input) print(output) # Check backward self.assertTrue( gradcheck( fn, ( input.F, input.tensor_stride, unpool.stride, unpool.kernel_size, unpool.dilation, unpool.region_type_, unpool.region_offset_, True, input.coords_key, None, input.coords_man, ), ) )
def test(model,
         data_loader,
         config,
         transform_data_fn=None,
         has_gt=True,
         validation=None,
         epoch=None):
    """Evaluate ``model`` on ``data_loader`` and report loss, accuracy, mAP and mIoU.

    When ``config.is_train`` is false the weights are first loaded from
    ``config.resume + '/weights.pth'``. Row-normalised confusion matrices
    (labels equal to 255 excluded) are written under ``config.log_dir``.

    Args:
        model: network to evaluate; it is put into eval mode.
        data_loader: loader whose ``dataset`` exposes ``NUM_LABELS``,
            ``reorder_result`` and ``get_classnames``.
        config: run configuration object.
        transform_data_fn: accepted for interface compatibility; unused here.
        has_gt: when true, loss and metrics are computed against ``target``.
        validation: when truthy, the accumulated loss is appended to
            ``val_loss.txt`` and a five-element tuple is returned.
        epoch: key under which the accumulated loss is stored, and suffix of
            the periodic confusion-matrix file name.

    Returns:
        ``(loss_avg, score_avg, mean_ap, mean_iou)``, followed by the
        ``{epoch: accumulated_loss}`` dict when ``validation`` is truthy.

    Raises:
        ValueError: if the checkpoint is missing or the prediction directory
            is not empty.
        NotImplementedError: if ``config.evaluate_original_pointcloud`` is set.
    """
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))

    if not config.is_train:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            model.load_state_dict(state['state_dict'])
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    if validation:
        logging.info('===> Start validating')
    else:
        logging.info('===> Start testing')

    global_timer.tic()
    data_iter = iter(data_loader)
    max_iter = len(data_loader)
    max_iter_unique = max_iter
    all_preds, all_labels, batch_losses, batch_loss = [], [], {}, 0

    def confusion_without_ignored():
        # One vectorised pass instead of a per-index ``in list`` lookup,
        # which was quadratic in the number of points. Predictions and
        # labels are collected pairwise, so one mask indexes both.
        preds = np.concatenate(all_preds)
        targets = np.concatenate(all_labels)
        keep = targets != 255
        return confusion_matrix(targets[keep], preds[keep], normalize='true')

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            # ``next()`` works on every loader iterator; the ``.next()``
            # method is missing on recent PyTorch versions.
            if config.return_transformation:
                coords, feats, target, transformation = next(data_iter)
            else:
                coords, feats, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()
            if config.wrapper_type != 'None':
                color = feats[:, :3].int()
            if config.normalize_color:
                feats[:, :3] = feats[:, :3] / 255. - 0.5
            sinput = SparseTensor(feats, coords).to(device)

            # Feed forward
            inputs = (sinput, ) if config.wrapper_type == 'None' else (
                sinput, coords, color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            all_preds.append(pred.cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset,
                                 config, iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    raise NotImplementedError('pointcloud')

                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)

                loss = criterion(output, target.long())
                batch_loss += loss
                losses.update(float(loss), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(),
                                  target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                cm = confusion_without_ignored()
                np.savetxt(config.log_dir + '/cm_epoch_{0}.txt'.format(epoch),
                           cm)
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration,
                           max_iter_unique,
                           data_time,
                           iter_time,
                           has_gt,
                           losses,
                           scores,
                           reordered_ious,
                           hist,
                           reordered_ap_class,
                           class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

    batch_losses[epoch] = batch_loss
    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration,
               max_iter_unique,
               data_time,
               iter_time,
               has_gt,
               losses,
               scores,
               reordered_ious,
               hist,
               reordered_ap_class,
               class_names=class_names)

    if not config.is_train:
        np.savetxt(config.log_dir + '/cm.txt', confusion_without_ignored())

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    mean_ap = np.nanmean(ap_class)
    mean_iou = np.nanmean(per_class_iu(hist)) * 100
    if validation:
        loss_file_name = "/val_loss.txt"
        # The ``with`` block closes the file; no explicit close() is needed.
        with open(config.log_dir + loss_file_name, 'a') as val_loss_file:
            for key in batch_losses:
                val_loss_file.write('{0}, {1}\n'.format(
                    batch_losses[key], key))
        return losses.avg, scores.avg, mean_ap, mean_iou, batch_losses
    return losses.avg, scores.avg, mean_ap, mean_iou
def forward(self, x: ME.SparseTensor):
    """Densify ``x`` and return the first item of the result.

    ``x.dense`` is called with ``self.min_coords`` and ``self.max_coords``;
    only element ``[0]`` of what it returns is passed on to the caller.
    """
    dense_result = x.dense(self.min_coords, self.max_coords)
    return dense_result[0]
def visualize(self, options, model: Model, writer: SummaryWriter, step):
    """Render t-SNE and/or embedding images for selected scenes and log them.

    Scenes are picked from ``self.config['vis']``: either
    ``num_scene_samples`` evenly spaced indices or the listed
    ``scene_names``. ``self.config['overfit_one_ex']`` overrides both when
    set.

    Args:
        options: iterable of visualisation names; ``'tsne'`` and ``'embs'``
            are handled, any other value is ignored.
        model: network to run. It is switched to eval mode for the call and
            put back into its previous mode afterwards, even when an
            exception is raised.
        writer: receives the images through ``add_image``.
        step: global step recorded with every image.
    """
    was_training = model.training
    model.eval()
    try:
        vis_config = self.config['vis']
        num_samples = vis_config.get('num_scene_samples')
        if num_samples:
            # sample k data points from n data points with equal interval
            vis_indices = torch.linspace(0, len(self) - 1, num_samples) \
                .type(torch.IntTensor).tolist()
        else:
            vis_indices = [
                self.dir2idx[name] for name in vis_config.get('scene_names')
            ]

        overfit_scene = self.config['overfit_one_ex']
        if overfit_scene:
            vis_indices = [self.dir2idx[overfit_scene]]

        # De-duplicate so that every scene is rendered once.
        for i in set(vis_indices):
            coords, feats, labels, _ = self[i]
            coords, feats, = sparse_collate([coords], [feats])
            x = SparseTensor(feats, coords).to(model.device)
            with torch.no_grad():
                embs = model(x)
            insts = labels[:, 1]

            for option in options:
                if option == 'tsne':
                    tsne_img = visualization.visualize_tsne(
                        embs.cpu(), insts.cpu(),
                        config=self.config['vis']['tsne']
                    )
                    writer.add_image(
                        'tsne/{}'.format(self.idx2dir[i]), tsne_img, step
                    )
                elif option == 'embs':
                    # Separate name: the outer ``vis_config`` stays intact.
                    embs_config = self.config['vis']['embs']

                    # visualize embs with background
                    emb_imgs, axis_range = visualization.visualize_embs(
                        embs.cpu(), insts.cpu(),
                        remove_bg=False,
                        max_sample=embs_config['max_sample'],
                        num_view=embs_config['num_view']
                    )
                    for view_num, img in enumerate(emb_imgs):
                        writer.add_image(
                            'emb/with_bg/{}_{}'.format(
                                self.idx2dir[i], view_num),
                            img, step
                        )

                    # visualize embs without background, reusing the axis
                    # range returned by the first call
                    not_bg_emb_imgs, _ = visualization.visualize_embs(
                        embs.cpu(), insts.cpu(),
                        remove_bg=True,
                        max_sample=embs_config['max_sample'],
                        num_view=embs_config['num_view'],
                        axis_range=axis_range
                    )
                    for view_num, img in enumerate(not_bg_emb_imgs):
                        writer.add_image(
                            'emb/no_bg/{}_{}'.format(
                                self.idx2dir[i], view_num),
                            img, step
                        )
    finally:
        # Restore the caller's train/eval mode on success and on failure.
        model.train(was_training)
def test_empty(self):
    """Build a ``SparseTensor`` from zero-row features/coordinates and print it."""
    print(f"{self.__class__.__name__}: test_empty SparseTensor")
    # Zero rows: 16 feature columns and 4 coordinate columns.
    empty_feats = torch.FloatTensor(0, 16)
    empty_coords = torch.IntTensor(0, 4)
    sparse = SparseTensor(empty_feats, coordinates=empty_coords)
    print(sparse)
def collate_fn(self, batch):
    """Collate ``(coords, features, labels)`` samples into one sparse batch.

    Returns:
        A ``(SparseTensor, labels)`` pair built from what ``sparse_collate``
        returns for the per-sample sequences.
    """
    # Transpose the list of samples into three per-field sequences.
    sample_coords, sample_feats, sample_labels = tuple(zip(*batch))
    batch_coords, batch_feats, batch_labels = sparse_collate(
        sample_coords, sample_feats, sample_labels)
    sparse_batch = SparseTensor(batch_feats, coords=batch_coords)
    return sparse_batch, batch_labels
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` with cross-entropy until ``config.max_iter`` is reached.

    Training statistics go to a tensorboard ``SummaryWriter`` in
    ``config.log_dir``. The last accumulated batch loss of every epoch is
    appended to ``<log_dir>/train_loss.txt``. Checkpoints are written every
    ``config.save_freq`` iterations and whenever validation mIoU improves.

    Args:
        model: network to optimise.
        data_loader: training loader; a single iterator over it is consumed
            across all epochs.
        val_data_loader: loader handed to ``validate``.
        config: run configuration object.
        transform_data_fn: forwarded to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but no checkpoint file exists.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration. The writer is created once: a second instance on the
    # same log_dir would only leave an extra, never-closed event file.
    writer = SummaryWriter(log_dir=config.log_dir)
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores, batch_losses = AverageMeter(), AverageMeter(), {}

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    alpha, gamma, eps = 1, 2, 1e-6  # used only by the printed focal loss

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            model.load_state_dict(state['state_dict'])
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = iter(data_loader)
    iters_per_epoch = len(data_loader) // config.iter_size
    while is_training:
        print(
            "********************************** epoch N° {0} ************************"
            .format(epoch))
        for iteration in range(iters_per_epoch):
            print("####### Iteration N° {0}".format(iteration))
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                print("------------- Sub_iteration N° {0}".format(sub_iter))
                # Get training data. ``next()`` works on every loader
                # iterator; ``.next()`` is missing on recent PyTorch.
                data_timer.tic()
                coords, feats, target = next(data_iter)
                print("len of coords : {0}".format(len(coords)))

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, :3] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                color = feats[:, :3].int()
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / 255. - 0.5
                sinput = SparseTensor(feats, coords).to(device)

                data_time += data_timer.toc(False)

                # Feed forward
                inputs = (sinput, ) if config.wrapper_type == 'None' else (
                    sinput, coords, color)
                soutput = model(*inputs)

                # The output of the network is not sorted
                target = target.long().to(device)
                print("count of classes : {0}".format(
                    np.unique(target.cpu().numpy(), return_counts=True)))
                print("target : {0}\ntarget_len : {1}".format(
                    target, len(target)))
                print("target [0]: {0}".format(target[0]))

                # The focal loss is only printed for comparison; the
                # cross-entropy ``loss`` below is what gets optimised.
                input_soft = nn.functional.softmax(soutput.F, dim=1) + eps
                print("input_soft[0] : {0}".format(input_soft[0]))
                focal_weight = torch.pow(-input_soft + 1., gamma)
                print("focal_weight : {0}\nweight[0] : {1}".format(
                    focal_weight, focal_weight[0]))
                focal_loss = (-alpha * focal_weight *
                              torch.log(input_soft)).mean()

                loss = criterion(soutput.F, target.long())
                print("focal_loss :{0}\nloss : {1}".format(focal_loss, loss))

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                print("batch_loss : {0}".format(batch_loss))
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter, iters_per_epoch, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Total iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                writer.add_scalar('training/loss', losses.avg, curr_iter)
                writer.add_scalar('training/precision_at_1', scores.avg,
                                  curr_iter)
                writer.add_scalar('training/learning_rate',
                                  scheduler.get_lr()[0], curr_iter)
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou, val_losses = validate(model, val_data_loader,
                                                writer, curr_iter, config,
                                                transform_data_fn, epoch)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val")
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            if curr_iter % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

            batch_losses[epoch] = batch_loss
            # End of iteration
            curr_iter += 1

        # The max_iter break can fire before this epoch recorded a loss
        # (e.g. when resuming at the limit); skip the log line then instead
        # of failing with a KeyError.
        if epoch in batch_losses:
            with open(config.log_dir + "/train_loss.txt",
                      'a') as train_loss_log:
                train_loss_log.write('{0}, {1}\n'.format(
                    batch_losses[epoch], epoch))
        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    val_miou = validate(model, val_data_loader, writer, curr_iter, config,
                        transform_data_fn, epoch)[0]
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))

    # Flush pending events and release the event file.
    writer.close()
def forward(self, x):
    """Run the encoder/decoder network with skip connections on ``x``.

    Each encoder stage output is kept and concatenated (``me.cat``) with the
    matching decoder stage. When ``self.normalize_feature`` is set, the
    result is returned as a new ``SparseTensor`` whose features are divided
    by their L2 norm along dim 1; otherwise the output of ``self.final`` is
    returned as is.
    """
    # Encoder: keep every stage output for the skip connections.
    skip_p1 = self.relu(self.bn0(self.conv0p1s1(x)))
    skip_b1 = self.block1(self.relu(self.bn1(self.conv1p1s2(skip_p1))))
    skip_b2 = self.block2(self.relu(self.bn2(self.conv2p2s2(skip_b1))))
    skip_b3 = self.block3(self.relu(self.bn3(self.conv3p4s2(skip_b2))))
    features = self.block4(self.relu(self.bn4(self.conv4p8s2(skip_b3))))

    # Decoder: transposed conv -> norm -> relu -> concat skip -> block.
    decoder_stages = (
        (self.convtr4p16s2, self.bntr4, skip_b3, self.block5),
        (self.convtr5p8s2, self.bntr5, skip_b2, self.block6),
        (self.convtr6p4s2, self.bntr6, skip_b1, self.block7),
        (self.convtr7p2s2, self.bntr7, skip_p1, self.block8),
    )
    for up_conv, norm, skip, block in decoder_stages:
        upsampled = self.relu(norm(up_conv(features)))
        features = block(me.cat(upsampled, skip))

    result = self.final(features)
    if not self.normalize_feature:
        return result

    unit_feats = result.F / torch.norm(result.F, p=2, dim=1, keepdim=True)
    return SparseTensor(unit_feats,
                        coords_key=result.coords_key,
                        coords_manager=result.coords_man)
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
    """Train ``model`` with cross-entropy plus optional block regularisers.

    Dataset-specific settings are written back into ``config``:
    ``normalize_color`` and ``xyz_input`` are disabled for the three named
    datasets and ``val_freq`` is rescaled. Per-epoch training IoU is logged,
    checkpoints are saved every ``config.save_freq`` iterations and whenever
    the validation mIoU improves.

    Args:
        model: network to optimise.
        data_loader: training loader; a single iterator over it is consumed
            across all epochs.
        val_data_loader: loader used by ``validate`` and ``test``.
        config: run configuration object (mutated, see above).
        transform_data_fn: forwarded to ``validate``.

    Raises:
        ValueError: if ``config.resume`` is set but no checkpoint file exists.
    """
    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test loaded ckpt first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # We skip attention maps because their shape depends on the voxel
            # count, e.g. copying a param with shape (23385, 8, 4) to
            # (43529, 8, 4).
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            # Entries not loaded from the checkpoint keep the model's
            # current values.
            for k, v in model.state_dict().items():
                d.setdefault(k, v)
            model.load_state_dict(d)

            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = iter(data_loader)

    # dataset name -> (number of classes, validation-frequency multiplier)
    dataset_settings = {
        "SemanticKITTI": (19, 10),
        "S3DIS": (13, 1),
        "Nuscenes": (16, 50),
    }
    val_freq_ = config.val_freq
    if config.dataset in dataset_settings:
        num_class, val_freq_scale = dataset_settings[config.dataset]
        config.normalize_color = False
        config.xyz_input = False
        config.val_freq = val_freq_ * val_freq_scale
    else:
        num_class = 20

    # (attribute on each block's first module, log format or None)
    reg_terms = (
        ('vq_loss', 'Cur Loss: {}, Cur vq_loss: {}'),
        ('diverse_loss', 'Cur Loss: {}, Cur diverse _loss: {}'),
        ('label_reg', None),
    )

    iters_per_epoch = len(data_loader) // config.iter_size
    while is_training:
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(iters_per_epoch):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                config.val_freq = val_freq_ * 2  # valid more freq on lower half

            for sub_iter in range(config.iter_size):
                # Get training data. ``next()`` works on every loader
                # iterator; ``.next()`` is missing on recent PyTorch.
                data_timer.tic()
                if config.return_transformation:
                    coords, feats, target, _, _, _, _, _ = next(data_iter)
                else:
                    # ignore unique_map and inverse_map
                    coords, feats, target, _, _, _ = next(data_iter)

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / feats[:, :3].max() - 0.5
                coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    feats = torch.cat([coords_norm, feats], dim=1)
                sinput = SparseTensor(feats, coords, device=device)
                # must share the same coord-manager to align with sinput
                starget = SparseTensor(
                    target.unsqueeze(-1).float(),
                    coordinate_map_key=sinput.coordinate_map_key,
                    coordinate_manager=sinput.coordinate_manager,
                    device=device)

                data_time += data_timer.toc(False)

                # ``iter_`` feeds the training progress into the model for
                # annealing inside the model.
                progress = curr_iter / config.max_iter
                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=progress,
                                    enable_point_branch=True)
                else:
                    # label-aux, feed it in as additional reg
                    soutput = model(sinput, iter_=progress, aux=starget)

                # The output of the network is not sorted
                target = target.view(-1).long().to(device)
                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                # Always bound, so that models without ``block1`` no longer
                # hit a NameError at ``regs.update`` below.
                cur_loss = torch.tensor([0.], device=device)
                if hasattr(model, 'block1'):
                    first_block = model.block1[0]
                    for attr, log_fmt in reg_terms:
                        if getattr(first_block, attr, None) is None:
                            continue
                        cur_loss = torch.tensor([0.], device=device)
                        for n, m in model.named_children():
                            if 'block' in n:
                                # m is the nn.Sequential obj, m[0] is the TRBlock
                                cur_loss += getattr(m[0], attr)
                        if log_fmt is not None:
                            logging.info(log_fmt.format(loss, cur_loss))
                        loss += cur_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                optimizer.first_step(zero_grad=True)
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            else:
                if curr_iter >= config.lr_warmup:
                    scheduler.step()
                # NOTE(review): this override also runs after the warmup
                # phase and so replaces the scheduler's learning rate —
                # confirm whether it was meant to sit in an ``else`` branch.
                for g in optimizer.param_groups:
                    g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir.split('/')[-2], epoch, curr_iter,
                    iters_per_epoch, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
    """Evaluate ``model`` on ``data_loader`` and report loss, accuracy, mAP and mIoU.

    Args:
        model: network to evaluate; it is put into eval mode.
        data_loader: loader whose ``dataset`` exposes ``NUM_LABELS``,
            ``reorder_result`` and ``get_classnames``.
        config: run configuration object.
        transform_data_fn: accepted for interface compatibility; unused here.
        has_gt: when true, loss and metrics are computed against ``target``.

    Returns:
        ``(loss_avg, score_avg, mean_ap, mean_iou)``.

    Raises:
        ValueError: if the prediction directory is not empty.
        NotImplementedError: if ``config.evaluate_original_pointcloud`` is set.
    """
    device = get_torch_device(config.is_cuda)
    dataset = data_loader.dataset
    num_labels = dataset.NUM_LABELS
    global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
    losses, scores, ious = AverageMeter(), AverageMeter(), 0
    aps = np.zeros((0, num_labels))
    hist = np.zeros((num_labels, num_labels))

    logging.info('===> Start testing')

    global_timer.tic()
    data_iter = iter(data_loader)
    max_iter = len(data_loader)
    max_iter_unique = max_iter

    # Fix batch normalization running mean and std
    model.eval()

    # Clear cache (when run in val mode, cleanup training cache)
    torch.cuda.empty_cache()

    if config.save_prediction or config.test_original_pointcloud:
        if config.save_prediction:
            save_pred_dir = config.save_pred_dir
            os.makedirs(save_pred_dir, exist_ok=True)
        else:
            save_pred_dir = tempfile.mkdtemp()
        if os.listdir(save_pred_dir):
            raise ValueError(f'Directory {save_pred_dir} not empty. '
                             'Please remove the existing prediction.')

    with torch.no_grad():
        for iteration in range(max_iter):
            data_timer.tic()
            # ``next()`` works on every loader iterator; the ``.next()``
            # method is missing on recent PyTorch versions.
            if config.return_transformation:
                coords, feats, target, transformation = next(data_iter)
            else:
                coords, feats, target = next(data_iter)
                transformation = None
            data_time = data_timer.toc(False)

            # Preprocess input
            iter_timer.tic()
            if config.wrapper_type != 'None':
                color = feats[:, :3].int()
            if config.normalize_color:
                feats[:, :3] = feats[:, :3] / 255. - 0.5
            sinput = SparseTensor(feats, coords).to(device)

            # Feed forward
            inputs = (sinput, ) if config.wrapper_type == 'None' else (
                sinput, coords, color)
            soutput = model(*inputs)
            output = soutput.F

            pred = get_prediction(dataset, output, target).int()
            iter_time = iter_timer.toc(False)

            if config.save_prediction or config.test_original_pointcloud:
                save_predictions(coords, pred, transformation, dataset,
                                 config, iteration, save_pred_dir)

            if has_gt:
                if config.evaluate_original_pointcloud:
                    raise NotImplementedError('pointcloud')

                target_np = target.numpy()
                num_sample = target_np.shape[0]
                target = target.to(device)

                cross_ent = criterion(output, target.long())
                losses.update(float(cross_ent), num_sample)
                scores.update(precision_at_one(pred, target), num_sample)
                hist += fast_hist(pred.cpu().numpy().flatten(),
                                  target_np.flatten(), num_labels)
                ious = per_class_iu(hist) * 100

                prob = torch.nn.functional.softmax(output, dim=1)
                ap = average_precision(prob.cpu().detach().numpy(), target_np)
                aps = np.vstack((aps, ap))
                # Due to heavy bias in class, there exists class with no test label at all
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    ap_class = np.nanmean(aps, 0) * 100.

            if iteration % config.test_stat_freq == 0 and iteration > 0:
                reordered_ious = dataset.reorder_result(ious)
                reordered_ap_class = dataset.reorder_result(ap_class)
                class_names = dataset.get_classnames()
                print_info(iteration,
                           max_iter_unique,
                           data_time,
                           iter_time,
                           has_gt,
                           losses,
                           scores,
                           reordered_ious,
                           hist,
                           reordered_ap_class,
                           class_names=class_names)

            if iteration % config.empty_cache_freq == 0:
                # Clear cache
                torch.cuda.empty_cache()

    global_time = global_timer.toc(False)

    reordered_ious = dataset.reorder_result(ious)
    reordered_ap_class = dataset.reorder_result(ap_class)
    class_names = dataset.get_classnames()
    print_info(iteration,
               max_iter_unique,
               data_time,
               iter_time,
               has_gt,
               losses,
               scores,
               reordered_ious,
               hist,
               reordered_ap_class,
               class_names=class_names)

    if config.test_original_pointcloud:
        logging.info('===> Start testing on original pointcloud space.')
        dataset.test_pointcloud(save_pred_dir)

    logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

    return losses.avg, scores.avg, np.nanmean(ap_class), np.nanmean(
        per_class_iu(hist)) * 100
def train_distill(model, data_loader, val_data_loader, config, transform_data_fn=None):
    '''Train ``model`` (the student) with feature distillation from a teacher.

    A ``Res16UNet18A`` teacher is built through ``load_model`` and its weights
    are read from a hard-coded checkpoint path. The student is then trained
    either with a combined loss (cross-entropy plus ``distill_lambda`` times
    ``DistillLoss(teacher_anchor, student_anchor)``) or, when ``TWO_STAGE`` is
    enabled, first with the distillation loss only and afterwards with
    cross-entropy only.

    Args:
        model: student network. It is called as ``model(sinput)`` and as
            ``model(sinput, save_anchor=True)``; the latter is unpacked as
            ``(soutput, anchor)``.
        data_loader: training loader. Each batch is unpacked as
            ``(coords, input, target, _, _)``, with two extra trailing items
            when ``config.return_transformation`` is set.
        val_data_loader: loader handed to ``test`` and ``validate``.
        config: run configuration. Attributes read here: ``is_cuda``,
            ``ignore_label``, ``resume``, ``iter_size``,
            ``return_transformation``, ``use_aux``, ``normalize_color``,
            ``xyz_input``, ``max_iter``, ``stat_freq``, ``log_dir``,
            ``save_freq`` and ``val_freq``. It is also forwarded to the
            optimizer/scheduler/checkpoint helpers and the teacher constructor.
        transform_data_fn: forwarded unchanged to ``validate``.

    Returns:
        None. Progress is reported through ``logging`` and ``checkpoint``.

    Raises:
        AssertionError: if the teacher checkpoint file does not exist, or if
            ``config.use_aux`` is set and ``target`` does not have two columns.
        NotImplementedError: if ``config.resume`` or ``config.use_aux`` is set.
    '''
    # Weight of the distillation term in the combined (single-stage) loss.
    # Previously tried values: 1 and 0.33.
    distill_lambda = 0.67

    # TWO_STAGE=True: the student is first trained with the distillation loss
    # only, to match the teacher's activations, and is then finetuned with
    # plain cross-entropy for the rest of each epoch.
    # TWO_STAGE=False: the student trains with the combined loss.
    TWO_STAGE = False
    # Fraction of each epoch spent in stage 1. It has to be defined, otherwise
    # enabling TWO_STAGE fails with a NameError.
    STAGE_PERCENTAGE = 0.7
    # Per-epoch iterations trained without the distillation term. With 0 the
    # warm-up is disabled and distillation is used from the first iteration.
    WARMUP_ITERATIONS = 0

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    # TODO: load the sub-model only
    # FIXME: hard-coded teacher architecture, channel/class counts and
    # checkpoint path; only the current experiment setup is supported.
    tch_model_cls = load_model('Res16UNet18A')
    tch_model = tch_model_cls(3, 20, config).to(device)

    # Alternative teacher checkpoint used before: .../Res16UNet18A/resnet_base/weights.pth
    checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/Res18A/weights.pth"  # voxel-size: 0.05
    assert osp.isfile(checkpoint_fn)
    logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
    state = torch.load(checkpoint_fn)
    # Entries whose key contains 'map' are not loaded into the teacher.
    teacher_weights = {
        k: v for k, v in state['state_dict'].items() if 'map' not in k
    }
    tch_model.load_state_dict(teacher_weights)
    if 'best_val' in state:
        # Only report the teacher's record. Seeding the student's
        # best_val_miou with it would suppress the student's "best_val"
        # checkpoints until it outperforms the teacher, and would make the
        # "Current best mIoU" log show the teacher's numbers.
        logging.info("=> teacher best val mIoU: {} at iter {}".format(
            state['best_val'], state['best_val_iter']))
    logging.info("=> loaded checkpoint '{}' (epoch {})".format(
        checkpoint_fn, state['epoch']))

    if config.resume:
        # TODO: resuming is not supported yet. It would need to restore the
        # student weights, the iteration/epoch counters, the best-val
        # bookkeeping and (optionally) the optimizer/scheduler state from
        # config.resume + '/weights.pth'.
        raise NotImplementedError(
            'resuming is not supported for distillation training')

    # Evaluate the teacher once so its quality is visible in the log.
    _, _, _, v_mIoU = test(tch_model, val_data_loader, config)
    logging.info('Tch model tested, bes_miou: {}'.format(v_mIoU))
    # Make the teacher's mode explicit: its outputs are fixed targets, so they
    # must not depend on whichever mode test() leaves the module in.
    tch_model.eval()

    data_iter = data_loader.__iter__()
    while is_training:
        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)
        total_iteration = len(data_loader) // config.iter_size
        # Stage 1 / Stage 2 boundary (only consulted when TWO_STAGE is set).
        stage_boundary = int(total_iteration * STAGE_PERCENTAGE)

        for iteration in range(total_iteration):
            # NOTE: for single stage distillation, the L2 loss might be too
            # large at first, so a warm-up without it can be configured.
            use_distill = iteration >= WARMUP_ITERATIONS

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for _ in range(config.iter_size):
                # Get training data
                data_timer.tic()
                # next(it) is the Python 3 iterator protocol; it.next() is not
                # guaranteed to exist on loader iterators.
                if config.return_transformation:
                    coords, feats, target, _, _, _, _ = next(data_iter)
                else:
                    # ignore unique_map and inverse_map
                    coords, feats, target, _, _ = next(data_iter)

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even,
                # odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    feats[:, :3] = feats[:, :3] / 255. - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    # Computed only where it is consumed, so it is always
                    # defined when needed and never computed for nothing.
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5
                    feats = torch.cat([coords_norm, feats], dim=1)
                sinput = SparseTensor(feats, coords, device=device)

                # TODO: return both-models. In order not to break the valid
                # interface, use a get_loss to get the registered loss.
                data_time += data_timer.toc(False)

                if aux is not None:
                    raise NotImplementedError

                # flatten ground truth tensor
                target = target.view(-1).long().to(device)

                if TWO_STAGE:
                    if iteration < stage_boundary:
                        # Stage 1: train the student on the distillation loss
                        soutput, anchor = model(sinput, save_anchor=True)
                        # Make sure gradients don't flow to the teacher
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                        loss = DistillLoss(tch_anchor, anchor)
                    else:
                        # Stage 2: finetune the student on cross-entropy
                        soutput = model(sinput)
                        loss = criterion(soutput.F, target.long())
                else:
                    if use_distill:  # after warm up
                        soutput, anchor = model(sinput, save_anchor=True)
                        # The teacher is pretrained: do not let gradients
                        # flow into it and update its parameters.
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                    else:  # warming up
                        soutput = model(sinput)
                    # The output of the network is not sorted
                    loss = criterion(soutput.F, target.long())
                    # Add the distillation loss if it is in use
                    if use_distill:
                        distill_loss = DistillLoss(tch_anchor, anchor) * distill_lambda
                        loss += distill_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            # Statistics use the last sub-batch of this iteration.
            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou; points labelled -1 are excluded from the
            # denominator unless they are predicted-and-labelled pairs.
            for cls_idx in range(num_class):
                pred_is_cls = pred == cls_idx
                target_is_cls = target == cls_idx
                total_correct_class[cls_idx] += (pred_is_cls & target_is_cls).sum()
                total_iou_deno_class[cls_idx] += (
                    (pred_is_cls & (target != -1)) | target_is_cls).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir, epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                if use_distill and not TWO_STAGE:
                    logging.info('Loss {} Distill Loss:{}'.format(
                        loss, distill_loss))
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                losses.reset()
                scores.reset()

            # Save current status; save before validation to guard against
            # an occasional memory overflow during validation.
            if curr_iter % config.save_freq == 0:
                checkpoint(model, optimizer, epoch, curr_iter, config,
                           best_val_miou, best_val_iter, save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model, optimizer, epoch, curr_iter, config,
                               best_val_miou, best_val_iter, "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = total_correct_class / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % (IoU.mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    _, _, _, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))