def main():
    global args, logger
    args = get_parser()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    # select the backbone by architecture name
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    elif args.arch == 'pointweb_cls':
        from model.pointweb.pointweb_cls import PointWebCls as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    model = torch.nn.DataParallel(model.cuda())
    logger.info(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    names = [line.rstrip('\n') for line in open(args.names_path)]
    # load the trained checkpoint for evaluation
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    test(model, criterion, names)
def main():
    init()
    # select the backbone by architecture name
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    if args.sync_bn:
        # replace BatchNorm layers with synchronized BatchNorm for multi-GPU training
        from util.util import convert_to_syncbn
        convert_to_syncbn(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_epoch, gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.sync_bn:
        # the replication callback must be patched onto the DataParallel wrapper
        from lib.sync_bn import patch_replication_callback
        patch_replication_callback(model)
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))
    train_transform = transform.Compose([transform.ToTensor()])
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train', data_root=args.train_full_folder, num_point=args.num_point,
                           test_area=args.test_area, block_size=args.block_size,
                           sample_rate=args.sample_rate, transform=train_transform)
    elif args.data_name == 'scannet':
        train_data = ScanNet(split='train', data_root=args.data_root, num_point=args.num_point,
                             block_size=args.block_size, sample_rate=args.sample_rate,
                             transform=train_transform)
    elif args.data_name == 'modelnet40':
        train_data = dataset.PointData(split='train', data_root=args.data_root, data_list=args.train_list,
                                       transform=train_transform, num_point=args.num_point, random_index=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True,
                                               num_workers=args.train_workers, pin_memory=True)
    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        val_data = dataset.PointData(split='val', data_root=args.data_root, data_list=args.val_list,
                                     transform=val_transform)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.train_batch_size_val, shuffle=False,
                                                 num_workers=args.train_workers, pin_memory=True)
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch)
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, filename)
            # keep only the two most recent periodic checkpoints
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
def main():
    init()
    # select the backbone by architecture name
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointnet2_paconv_seg':
        from model.pointnet2.pointnet2_paconv_seg import PointNet2SSGSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz, args=args)
    best_mIoU = 0.0
    if args.sync_bn:
        # replace BatchNorm layers with synchronized BatchNorm for multi-GPU training
        from util.util import convert_to_syncbn
        convert_to_syncbn(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    if args.get('lr_multidecay', False):
        # decay the learning rate at 60% and 80% of the total epochs
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=[int(args.epochs * 0.6), int(args.epochs * 0.8)],
                                             gamma=args.multiplier)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_epoch, gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.sync_bn:
        from lib.sync_bn import patch_replication_callback
        patch_replication_callback(model)
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            try:
                best_mIoU = checkpoint['val_mIoU']
            except Exception:
                pass
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))
    if args.get('no_transformation', True):
        train_transform = None
    else:
        # point cloud augmentation: random rotation, scaling, jitter and color drop
        train_transform = transform.Compose([
            transform.RandomRotate(along_z=args.get('rotate_along_z', True)),
            transform.RandomScale(scale_low=args.get('scale_low', 0.8),
                                  scale_high=args.get('scale_high', 1.2)),
            transform.RandomJitter(sigma=args.get('jitter_sigma', 0.01),
                                   clip=args.get('jitter_clip', 0.05)),
            transform.RandomDropColor(color_augment=args.get('color_augment', 0.0))
        ])
    logger.info(train_transform)
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train', data_root=args.train_full_folder, num_point=args.num_point,
                           test_area=args.test_area, block_size=args.block_size, sample_rate=args.sample_rate,
                           transform=train_transform, fea_dim=args.get('fea_dim', 6),
                           shuffle_idx=args.get('shuffle_idx', False))
    else:
        raise ValueError('{} dataset not supported.'.format(args.data_name))
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True,
                                               num_workers=args.train_workers, pin_memory=True, drop_last=True)
    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        if args.data_name == 's3dis':
            val_data = dataset.PointData(split='val', data_root=args.data_root, data_list=args.val_list,
                                         transform=val_transform,
                                         norm_as_feat=args.get('norm_as_feat', True),
                                         fea_dim=args.get('fea_dim', 6))
        else:
            raise ValueError('{} dataset not supported.'.format(args.data_name))
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.train_batch_size_val, shuffle=False,
                                                 num_workers=args.train_workers, pin_memory=True)
    for epoch in range(args.start_epoch, args.epochs):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch,
                                                                 args.get('correlation_loss', False))
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(),
                        'commit_id': get_git_commit_id()}, filename)
            # keep only the two most recent periodic checkpoints
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
                try:
                    os.remove(deletename)
                except Exception:
                    logger.info('{} Not found.'.format(deletename))
        if args.evaluate and epoch_log % args.get('eval_freq', 1) == 0:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
            if mIoU_val > best_mIoU:
                # track and save the best-performing model on the validation set
                best_mIoU = mIoU_val
                filename = args.save_path + '/best_train.pth'
                logger.info('Best Model Saving checkpoint to: ' + filename)
                torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(),
                            'val_mIoU': best_mIoU, 'commit_id': get_git_commit_id()}, filename)
        scheduler.step()
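# Minimal entry-point sketch (an assumption, not part of the excerpt above): it allows the
# training script to be launched directly, provided init(), logger, writer and the helper
# functions referenced in main() are defined elsewhere in the file.
if __name__ == '__main__':
    main()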