def train(args):
    """Train a semantic-segmentation model (PointNet or PointNet++) on S3DIS.

    Args:
        args: parsed CLI namespace. Fields read here: model_name, augment,
            batch_size, workers, gpu, pretrain, optimizer, learning_rate,
            decay_rate, epoch.

    Side effects: creates ./experiment/semseg/<model_name>/ and saves a
    checkpoint there whenever the test mean-IoU improves.
    """
    experiment_dir = mkdir('./experiment/')
    checkpoints_dir = mkdir('./experiment/semseg/%s/' % (args.model_name))

    train_data, train_label, test_data, test_label = _load()
    dataset = S3DISDataLoader(train_data, train_label, data_augmentation=args.augment)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=args.workers)
    test_dataset = S3DISDataLoader(test_data, test_label)
    testdataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True,
                                num_workers=args.workers)

    # S3DIS has 13 semantic classes.
    num_classes = 13
    if args.model_name == 'pointnet':
        model = PointNetSeg(num_classes, feature_transform=True, input_dims=9)
    else:
        model = PointNet2SemSeg(num_classes, feature_dims=6)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:', args.gpu)

    if args.pretrain is not None:
        log.debug('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        # Checkpoint filenames end in '-<epoch>.pth'; resume from that epoch.
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.debug('start epoch from', init_epoch)
    else:
        log.debug('Training from scratch')
        init_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5

    history = {'loss': []}
    best_acc = 0
    best_meaniou = 0

    for epoch in range(init_epoch, args.epoch):
        # Clamp the scheduled LR so it never decays below the clip value.
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)
        log.info(job='semseg', model=args.model_name, gpu=args.gpu,
                 epoch='%d/%s' % (epoch, args.epoch), lr=lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        model.train()  # set once per epoch; no need to re-set every batch
        for points, target in tqdm(dataloader, total=len(dataloader),
                                   smoothing=0.9, dynamic_ncols=True):
            points, target = points.float(), target.long()
            points = points.transpose(2, 1)  # (B, N, C) -> (B, C, N)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            if args.model_name == 'pointnet':
                pred, trans_feat = model(points)
            else:
                pred = model(points)
            pred = pred.contiguous().view(-1, num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = F.nll_loss(pred, target)
            if args.model_name == 'pointnet':
                # Regularize the PointNet feature-transform matrix.
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            # .item() replaces deprecated loss.cpu().data.numpy()
            history['loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Step the LR schedule AFTER the epoch's optimizer updates
        # (required ordering since PyTorch 1.1; stepping first skipped
        # the initial learning-rate value).
        scheduler.step()

        log.debug('clear cuda cache')
        torch.cuda.empty_cache()

        test_metrics, cat_mean_iou = test_semseg(
            model.eval(),
            testdataloader,
            label_id_to_name,
            args.model_name,
            num_classes,
        )
        mean_iou = np.mean(cat_mean_iou)

        # Save only when mean IoU improves; accuracy is tracked for logging.
        save_model = False
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
        if mean_iou > best_meaniou:
            best_meaniou = mean_iou
            save_model = True

        if save_model:
            fn_pth = 'semseg-%s-%.5f-%04d.pth' % (args.model_name, best_meaniou, epoch)
            log.info('Save model...', fn=fn_pth)
            torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth))
            log.warn(cat_mean_iou)
        else:
            log.info('No need to save model')
            log.warn(cat_mean_iou)

        log.warn('Curr', accuracy=test_metrics['accuracy'], meanIOU=mean_iou)
        log.warn('Best', accuracy=best_acc, meanIOU=best_meaniou)
def train(args):
    """Train a semantic-segmentation model (PointNet or PointNet++) on SemanticKITTI.

    Args:
        args: parsed CLI namespace. Fields read here: model_name, subset
            ('inview' or 'all'), batch_size, workers, gpu, pretrain,
            optimizer, learning_rate, epoch.

    Raises:
        ValueError: if args.subset is not 'inview' or 'all'.

    Side effects: creates experiment/<model_name>/ and saves a checkpoint
    there whenever the test mIoU improves.
    """
    experiment_dir = mkdir('experiment/')
    checkpoints_dir = mkdir('experiment/%s/' % (args.model_name))

    kitti_utils = Semantic_KITTI_Utils(KITTI_ROOT, subset=args.subset)
    class_names = kitti_utils.class_names
    num_classes = kitti_utils.num_classes

    # Points sampled per frame depend on the subset; fail fast on an
    # unknown subset instead of a later NameError on train_npts.
    if args.subset == 'inview':
        train_npts = 8000
        test_npts = 24000
    elif args.subset == 'all':
        train_npts = 50000
        test_npts = 100000
    else:
        raise ValueError("Unknown subset %r: expected 'inview' or 'all'" % (args.subset,))

    log.info(subset=args.subset, train_npts=train_npts, test_npts=test_npts)

    dataset = SemKITTI_Loader(KITTI_ROOT, train_npts, train=True, subset=args.subset)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=args.workers, pin_memory=True)
    test_dataset = SemKITTI_Loader(KITTI_ROOT, test_npts, train=False, subset=args.subset)
    # Evaluation uses more points per frame, so halve the batch size.
    testdataloader = DataLoader(test_dataset, batch_size=int(args.batch_size / 2),
                                shuffle=False, num_workers=args.workers, pin_memory=True)

    if args.model_name == 'pointnet':
        model = PointNetSeg(num_classes, input_dims=4, feature_transform=True)
    else:
        model = PointNet2SemSeg(num_classes, feature_dims=1)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=1e-4)

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model)
    model.cuda()
    log.info('Using gpu:', args.gpu)

    if args.pretrain is not None:
        log.info('Use pretrain model...')
        model.load_state_dict(torch.load(args.pretrain))
        # Checkpoint filenames end in '-<epoch>.pth'; resume from that epoch.
        init_epoch = int(args.pretrain[:-4].split('-')[-1])
        log.info('Restart training', epoch=init_epoch)
    else:
        log.msg('Training from scratch')
        init_epoch = 0

    best_acc = 0
    best_miou = 0
    # Loop-invariant: construct the criterion once, not once per batch.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(init_epoch, args.epoch):
        model.train()
        lr = calc_decay(args.learning_rate, epoch)
        log.info(subset=args.subset, model=args.model_name, gpu=args.gpu,
                 epoch=epoch, lr=lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for points, target in tqdm(dataloader, total=len(dataloader),
                                   smoothing=0.9, dynamic_ncols=True):
            points = points.float().transpose(2, 1).cuda()  # (B, N, C) -> (B, C, N)
            target = target.long().cuda()

            if args.model_name == 'pointnet':
                logits, trans_feat = model(points)
            else:
                logits = model(points)

            # CrossEntropyLoss expects class scores as (B, num_classes, N).
            logits = logits.transpose(2, 1)
            loss = criterion(logits, target)
            if args.model_name == 'pointnet':
                # Regularize the PointNet feature-transform matrix.
                loss += feature_transform_reguliarzer(trans_feat) * 0.001

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

        acc, miou = test_kitti_semseg(model.eval(), testdataloader,
                                      args.model_name, num_classes, class_names)

        # Save only when mIoU improves; accuracy is tracked for logging.
        save_model = False
        if acc > best_acc:
            best_acc = acc
        if miou > best_miou:
            best_miou = miou
            save_model = True

        if save_model:
            fn_pth = '%s-%s-%.5f-%04d.pth' % (args.model_name, args.subset, best_miou, epoch)
            log.info('Save model...', fn=fn_pth)
            torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth))
        else:
            log.info('No need to save model')

        log.warn('Curr', accuracy=acc, mIOU=miou)
        log.warn('Best', accuracy=best_acc, mIOU=best_miou)