def main(args):
    """Quantization entry point: calibrate a segmentation model with QAT,
    convert it to a quantized model, evaluate it, and save the result.

    Expects ``args`` to carry dataset/model/optimizer settings (argparse
    namespace). Side effects: writes a temporary ``temp.p`` file when
    measuring model size and saves the quantized weights under
    ``args.savedir``.
    """
    # Read all the images in the folder (used later by evaluate()).
    crop_size = args.crop_size
    if args.dataset == 'city':
        image_path = os.path.join(args.data_path, "leftImg8bit", args.split, "*", "*.png")
        image_list = glob.glob(image_path)
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path, train=True, size=crop_size,
                                               scale=args.scale, coarse=False)
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False, size=crop_size,
                                             scale=args.scale, coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path, train=True, crop_size=crop_size,
                                        scale=args.scale, coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path, train=False, crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        # Image list for evaluation comes from the split file, not a glob.
        data_file = os.path.join(args.data_path, 'VOC2012', 'list', '{}.txt'.format(args.split))
        if not os.path.isfile(data_file):
            print_error_message('{} file does not exist'.format(data_file))
        image_list = []
        with open(data_file, 'r') as lines:
            for line in lines:
                rgb_img_loc = '{}/{}/{}'.format(args.data_path, 'VOC2012', line.split()[0])
                if not os.path.isfile(rgb_img_loc):
                    print_error_message('{} image file does not exist'.format(rgb_img_loc))
                image_list.append(rgb_img_loc)
    else:
        print_error_message('{} dataset not yet supported'.format(args.dataset))

    if len(image_list) == 0:
        # NOTE(review): image_path is only defined on the 'city' branch; this
        # message would NameError for 'pascal' — confirm intended behavior.
        print_error_message('No files in directory: {}'.format(image_path))
    print_info_message('# of images for testing: {}'.format(len(image_list)))

    # Model selection: each branch imports its factory lazily and builds the
    # segmentation head with args.classes set to the dataset's class count.
    if args.model == 'espnetv2':
        from model.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espnet':
        from model.espnet import espnet_seg
        args.classes = seg_classes
        model = espnet_seg(args)
    elif args.model == 'mobilenetv2_1_0':
        from model.mobilenetv2 import get_mobilenet_v2_1_0_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_1_0_seg(args)
    elif args.model == 'mobilenetv2_0_35':
        from model.mobilenetv2 import get_mobilenet_v2_0_35_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_35_seg(args)
    elif args.model == 'mobilenetv2_0_5':
        from model.mobilenetv2 import get_mobilenet_v2_0_5_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_5_seg(args)
    elif args.model == 'mobilenetv3_small':
        from model.mobilenetv3 import get_mobilenet_v3_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_small_seg(args)
    elif args.model == 'mobilenetv3_large':
        from model.mobilenetv3 import get_mobilenet_v3_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_large_seg(args)
    elif args.model == 'mobilenetv3_RE_small':
        from model.mobilenetv3 import get_mobilenet_v3_RE_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_small_seg(args)
    elif args.model == 'mobilenetv3_RE_large':
        from model.mobilenetv3 import get_mobilenet_v3_RE_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_large_seg(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Build per-parameter groups: depthwise convs (4-D weight with a single
    # input channel) get no weight decay, other convs get full decay, and
    # everything else (biases, BN params, ...) gets a reduced decay.
    train_params = []
    params_dict = dict(model.named_parameters())
    others = args.weight_decay * 0.01
    for key, value in params_dict.items():
        if len(value.data.shape) == 4:
            if value.data.shape[1] == 1:
                train_params += [{'params': [value], 'lr': args.lr, 'weight_decay': 0.0}]
                # names_set.append(key)
            else:
                train_params += [{'params': [value], 'lr': args.lr, 'weight_decay': args.weight_decay}]
        else:
            train_params += [{'params': [value], 'lr': args.lr, 'weight_decay': others}]
    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)

    num_params = model_parameters(model)
    flops = compute_flops(model, input=torch.Tensor(1, 3, args.crop_size[0], args.crop_size[1]))
    print_info_message('FLOPs for an input of size {}x{}: {:.2f} million'.format(
        args.crop_size[0], args.crop_size[1], flops))
    print_info_message('# of parameters: {}'.format(num_params))

    def print_size_of_model(model):
        # Serialize to a temp file to measure on-disk size of the state dict.
        torch.save(model.state_dict(), "temp.p")
        print('Size (MB):', os.path.getsize("temp.p") / 1e6)
        os.remove('temp.p')

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    print("========== MODEL CALIBRATION ===========")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, pin_memory=True,
                                               num_workers=args.workers, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             shuffle=False, pin_memory=True,
                                             num_workers=args.workers, drop_last=True)
    criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
                                 device=device, ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    # One calibration pass over the training data before quantization.
    miou_train, train_loss = train(model, train_loader, optimizer, criterion,
                                   seg_classes, 1, device=device)

    print('========== ORIGINAL MODEL SIZE ==========')
    print_size_of_model(model)

    # QAT pipeline: fuse -> attach qconfig -> prepare -> (load weights) -> convert.
    model.quantized.fuse_model()
    model.quantized.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
    torch.quantization.prepare_qat(model.quantized, inplace=True)
    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        # BUG FIX: the original passed the message and `format(args.weights_test)`
        # as two separate arguments (comma instead of `.format(...)`), so the
        # path was never interpolated into the message.
        print_error_message('weight file does not exist or not specified. '
                            'Please check: {}'.format(args.weights_test))

    qat_miou_val, qat_val_loss = val(model, val_loader, criterion, seg_classes, device=device)

    print("========== QUANTIZED MODEL SIZE ==========")
    torch.quantization.convert(model.quantized.eval(), inplace=True)
    print_size_of_model(model)

    miou_val, val_loss = val(model, val_loader, criterion, seg_classes, device=device)
    evaluate(args, model, image_list, device=device)
    print("========== EVALUATION FINISHED ==========")

    this_state_dict = model.state_dict()
    this_savedir = args.savedir
    model_file_name = this_savedir + '/quantized_' + args.weights_test.split('/')[-1] + '.pth'
    torch.save(this_state_dict, model_file_name)
    print("Accuracy(QAT) : {} mIOU(val): {:.4f}".format(args.model, qat_miou_val))
    print("Accuracy(Quantized) : {} mIOU(val): {:.4f}".format(args.model, miou_val))
    print("quantized model saved in {}".format(model_file_name))
def main(args):
    """Training entry point for espnetv2/dicenet segmentation models.

    Builds the dataset (pascal or city), the model, a two-group SGD
    optimizer (base net vs. segmentation head), an LR scheduler, then runs
    the epoch loop with checkpointing and TensorBoard logging.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path, train=True,
                                        crop_size=crop_size, scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path, train=False,
                                      crop_size=crop_size, scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path, train=True,
                                               size=crop_size, scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False,
                                             size=crop_size, scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded per-class loss weights for Cityscapes; the last class
        # (index 19) is weighted 0 — presumably the ignore/void class.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))
    # Model selection: factories are imported lazily per architecture.
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')
    if args.freeze_bn:
        # Put BN layers in eval mode and stop their affine params updating.
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    # Two parameter groups: the backbone at lr, the seg head at lr * lr_mult.
    # The epoch loop below updates param_groups[0]/[1] by index — keep order.
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    # lr is supplied per param group, so no global lr is passed to SGD here.
    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except:
        # Graph export is best-effort only; tracing can fail for some models.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        # Restore full training state (epoch, best score, model, optimizer).
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))
    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)
    # LR scheduler selection; each scheduler class is imported lazily.
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)
    # Persist the full run configuration (plus model stats) for reproducibility.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg
        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)
        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        # Unwrap DataParallel before saving so keys have no 'module.' prefix.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value/step look swapped here (best_miou logged as the
        # scalar with flops/params as the step) — confirm this is intended.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))
    writer.close()
def main(args):
    """Quantization-aware training (QAT) entry point.

    Optionally runs full-precision warmup epochs, then fuses/prepares the
    model's `quantized` submodule for QAT and trains with checkpointing and
    custom Logger scalar summaries.
    """
    logdir = args.savedir + '/logs/'
    if not os.path.isdir(logdir):
        os.makedirs(logdir)
    # Logger port/id 60066 is hard-coded — presumably a fixed logging endpoint.
    my_logger = Logger(60066, logdir)
    # Crop size and scale range are fixed per dataset here (not from args).
    if args.dataset == 'pascal':
        crop_size = (512, 512)
        args.scale = (0.5, 2.0)
    elif args.dataset == 'city':
        crop_size = (768, 768)
        args.scale = (0.5, 2.0)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[1], crop_size[0], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path, train=True,
                                        crop_size=crop_size, scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path, train=False,
                                      crop_size=crop_size, scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path, train=True,
                                               size=crop_size, scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False,
                                             size=crop_size, scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded per-class loss weights; index 19 weighted 0 (void class).
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))
    # Model selection: each factory is imported lazily per architecture.
    if args.model == 'espnetv2':
        from model.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espnet':
        from model.espnet import espnet_seg
        args.classes = seg_classes
        model = espnet_seg(args)
    elif args.model == 'mobilenetv2_1_0':
        from model.mobilenetv2 import get_mobilenet_v2_1_0_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_1_0_seg(args)
    elif args.model == 'mobilenetv2_0_35':
        from model.mobilenetv2 import get_mobilenet_v2_0_35_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_35_seg(args)
    elif args.model == 'mobilenetv2_0_5':
        from model.mobilenetv2 import get_mobilenet_v2_0_5_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_5_seg(args)
    elif args.model == 'mobilenetv3_small':
        from model.mobilenetv3 import get_mobilenet_v3_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_small_seg(args)
    elif args.model == 'mobilenetv3_large':
        from model.mobilenetv3 import get_mobilenet_v3_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_large_seg(args)
    elif args.model == 'mobilenetv3_RE_small':
        from model.mobilenetv3 import get_mobilenet_v3_RE_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_small_seg(args)
    elif args.model == 'mobilenetv3_RE_large':
        from model.mobilenetv3 import get_mobilenet_v3_RE_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_large_seg(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    # Per-parameter groups: depthwise convs (4-D weight, 1 input channel)
    # get no weight decay; other convs full decay; the rest reduced decay.
    train_params = []
    params_dict = dict(model.named_parameters())
    others = args.weight_decay * 0.01
    for key, value in params_dict.items():
        if len(value.data.shape) == 4:
            if value.data.shape[1] == 1:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': 0.0
                }]
            else:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]
        else:
            train_params += [{
                'params': [value],
                'lr': args.lr,
                'weight_decay': others
            }]
    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[1], crop_size[0]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[1], crop_size[0], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))
    start_epoch = 0
    epochs_len = args.epochs  # NOTE(review): unused local — kept as-is.
    best_miou = 0.0
    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers,
                                               drop_last=True)
    # Cityscapes validates at batch size 1; other datasets use args.batch_size.
    if args.dataset == 'city':
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)
    lr_scheduler = get_lr_scheduler(args)
    print_info_message(lr_scheduler)
    # Persist the full run configuration (plus model stats) for reproducibility.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    # Optional full-precision warmup before switching to QAT.
    if args.fp_epochs > 0:
        print_info_message("========== MODEL FP WARMUP ===========")
        for epoch in range(args.fp_epochs):
            lr = lr_scheduler.step(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print_info_message(
                'Running epoch {} with learning rates: {:.6f}'.format(
                    epoch, lr))
            start_t = time.time()  # NOTE(review): unused local — kept as-is.
            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device)
    # Q* optimizers expose an is_warmup flag; switch it off after warmup.
    if args.optimizer.startswith('Q'):
        optimizer.is_warmup = False
        print('exp_sensitivity calibration fin.')
    if not args.fp_train:
        # QAT preparation: fuse -> attach qconfig -> prepare (in place).
        model.module.quantized.fuse_model()
        model.module.quantized.qconfig = torch.quantization.get_default_qat_qconfig(
            'qnnpack')
        torch.quantization.prepare_qat(model.module.quantized, inplace=True)
    if args.resume:
        # NOTE: only weights are restored here (no optimizer state); the
        # start epoch comes from args, not from the checkpoint file.
        start_epoch = args.start_epoch
        if os.path.isfile(args.resume):
            print_info_message('Loading weights from {}'.format(args.resume))
            weight_dict = torch.load(args.resume, device)
            model.module.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for resume. Please check.')
    for epoch in range(start_epoch, args.epochs):
        lr = lr_scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print_info_message(
            'Running epoch {} with learning rates: {:.6f}'.format(epoch, lr))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)
        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        # Unwrap DataParallel before saving so keys have no 'module.' prefix.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        if is_best:
            # Also save a standalone copy of the best weights for this epoch.
            model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
            torch.save(weights_dict, model_file_name)
            print('weights saved in {}'.format(model_file_name))
        # NOTE(review): Flops/Params entries log best_miou as the value with
        # flops/params as the step — looks swapped; confirm intent.
        info = {
            'Segmentation/LR': round(lr, 6),
            'Segmentation/Loss/train': train_loss,
            'Segmentation/Loss/val': val_loss,
            'Segmentation/mIOU/train': miou_train,
            'Segmentation/mIOU/val': miou_val,
            'Segmentation/Complexity/Flops': best_miou,
            'Segmentation/Complexity/Params': best_miou,
        }
        for tag, value in info.items():
            if tag == 'Segmentation/Complexity/Flops':
                my_logger.scalar_summary(tag, value, math.ceil(flops))
            elif tag == 'Segmentation/Complexity/Params':
                my_logger.scalar_summary(tag, value, math.ceil(num_params))
            else:
                my_logger.scalar_summary(tag, value, epoch + 1)
    print_info_message("========== TRAINING FINISHED ===========")
def main(args):
    """Training entry point for the multi-task (segmentation +
    classification) greenhouse model (espdnet_mult).

    Builds the RGB(-D) greenhouse dataset, a two- or three-group SGD
    optimizer (base net / seg head / optional depth encoder), an LR
    scheduler, then runs the epoch loop with checkpointing, TensorBoard
    logging, and per-epoch visualization of one validation batch.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    if args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegCls, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseRGBDSegCls(root=args.data_path,
                                             list_name='train_greenhouse_mult.txt',
                                             train=True, size=crop_size,
                                             scale=args.scale,
                                             use_depth=args.use_depth)
        val_dataset = GreenhouseRGBDSegCls(root=args.data_path,
                                           list_name='val_greenhouse_mult.txt',
                                           train=False, size=crop_size,
                                           scale=args.scale,
                                           use_depth=args.use_depth)
        # Precomputed class weights from file; only the first 4 are used.
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
        color_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                      ('other_part_of_plant', (0, 255, 255)),
                                      ('artificial_objects', (255, 0, 0)),
                                      ('ground', (255, 255, 0)),
                                      ('background', (0, 0, 0))])
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))
    if args.model == 'espdnet':
        from model.segmentation.espdnet_mult import espdnet_mult
        args.classes = seg_classes
        args.cls_classes = 5  # auxiliary classification head size
        model = espdnet_mult(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')
    if args.freeze_bn:
        # Put BN layers in eval mode and stop their affine params updating.
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False
    # Param groups: backbone at lr, seg head at lr * lr_mult, and (with
    # depth) the depth encoder at lr. The epoch loop updates groups by index.
    if args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except:
        # Graph export is best-effort only; tracing can fail for some models.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        # Restore full training state (epoch, best score, model, optimizer).
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))
    print('device : ' + device)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)
    # Weights for the auxiliary 5-way classification loss, computed from the
    # training set distribution.
    cls_class_weight = calc_cls_class_weight(train_loader, 5)
    print(cls_class_weight)
    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion_seg = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
    criterion_cls = nn.CrossEntropyLoss(
        weight=torch.from_numpy(cls_class_weight).float().to(device))
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion_seg = criterion_seg.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion_seg = DataParallelCriteria(criterion_seg)
            criterion_seg = criterion_seg.cuda()
        criterion_cls = criterion_cls.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True
    # LR scheduler selection; each scheduler class is imported lazily.
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)
    # Persist the full run configuration (plus model stats) for reproducibility.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base
        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss, train_seg_loss, train_cls_loss = train(
            model, train_loader, optimizer, criterion_seg, seg_classes,
            epoch, criterion_cls, device=device, use_depth=args.use_depth)
        miou_val, val_loss, val_seg_loss, val_cls_loss = val(
            model, val_loader, criterion_seg, criterion_cls, seg_classes,
            device=device, use_depth=args.use_depth)
        # BUG FIX: the original used `iter(val_loader).next()`, which is
        # Python-2 iterator syntax and raises AttributeError on Python 3.
        batch = next(iter(val_loader))
        if args.use_depth:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        depths=batch[2].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)
        else:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)
        # image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
        # writer.add_image('Segmentation/results/val', image_grid, epoch)
        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        # Unwrap DataParallel before saving so keys have no 'module.' prefix.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/train', train_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/train', train_cls_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/val', val_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/val', val_cls_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value/step look swapped here (best_miou logged as the
        # scalar with flops/params as the step) — confirm this is intended.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))
    writer.close()
def main(args):
    """Train a semantic-segmentation network on the dataset named by ``args.dataset``.

    Builds the train/val datasets and loaders, constructs the model selected by
    ``args.model``, optionally restores weights (finetune/resume), then runs the
    epoch loop: train, validate, visualize a sample batch, checkpoint on best
    validation mIoU, and log scalars to TensorBoard.

    Args:
        args: argparse.Namespace holding the full training configuration
            (data paths, model/optimizer/scheduler settings, etc.).
            ``args.classes`` and ``args.use_depth`` are written as side effects.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ------------------------------------------------------------------
    # Dataset selection. Every branch must define train_dataset,
    # val_dataset, seg_classes and class_wts (per-class loss weights);
    # some branches also (re)bind color_encoding for visualization.
    # ------------------------------------------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path, train=True, crop_size=crop_size,
                                        scale=args.scale, coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path, train=False, crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST, color_encoding
        train_dataset = CityscapesSegmentation(root=args.data_path, train=True, coarse=False)
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False, coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Inverse-frequency style class weights (10 / w); class 19 is ignored.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 10 / 2.8149201869965
        class_wts[1] = 10 / 6.9850029945374
        class_wts[2] = 10 / 3.7890393733978
        class_wts[3] = 10 / 9.9428062438965
        class_wts[4] = 10 / 9.7702074050903
        class_wts[5] = 10 / 9.5110931396484
        class_wts[6] = 10 / 10.311357498169
        class_wts[7] = 10 / 10.026463508606
        class_wts[8] = 10 / 4.6323022842407
        class_wts[9] = 10 / 9.5608062744141
        class_wts[10] = 10 / 7.8698215484619
        class_wts[11] = 10 / 9.5168733596802
        class_wts[12] = 10 / 10.373730659485
        class_wts[13] = 10 / 6.6616044044495
        class_wts[14] = 10 / 10.260489463806
        class_wts[15] = 10 / 10.287888526917
        class_wts[16] = 10 / 10.289801597595
        class_wts[17] = 10 / 10.405355453491
        class_wts[18] = 10 / 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GREENHOUSE_CLASS_LIST, color_encoding
        train_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path, list_name=args.train_list, train=True, size=crop_size,
            scale=args.scale, use_depth=args.use_depth, use_traversable=args.greenhouse_use_trav)
        val_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path, list_name=args.val_list, train=False, size=crop_size,
            scale=args.scale, use_depth=args.use_depth, use_traversable=args.greenhouse_use_trav)
        # Precomputed weights; NOTE: loaded from the current working directory.
        class_weights = np.load('class_weights.npy')
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)
        print(GREENHOUSE_CLASS_LIST)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
    elif args.dataset == 'ishihara':
        print(args.use_depth)
        from data_loader.segmentation.ishihara_rgbd import IshiharaRGBDSegmentation, ISHIHARA_RGBD_CLASS_LIST
        train_dataset = IshiharaRGBDSegmentation(root=args.data_path, list_name='ishihara_rgbd_train.txt',
                                                 train=True, size=crop_size, scale=args.scale,
                                                 use_depth=args.use_depth)
        val_dataset = IshiharaRGBDSegmentation(root=args.data_path, list_name='ishihara_rgbd_val.txt',
                                               train=False, size=crop_size, scale=args.scale,
                                               use_depth=args.use_depth)
        seg_classes = len(ISHIHARA_RGBD_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        color_encoding = OrderedDict([('Unlabeled', (0, 0, 0)), ('Building', (70, 70, 70)),
                                      ('Fence', (190, 153, 153)), ('Others', (72, 0, 90)),
                                      ('Pedestrian', (220, 20, 60)), ('Pole', (153, 153, 153)),
                                      ('Road ', (157, 234, 50)), ('Road', (128, 64, 128)),
                                      ('Sidewalk', (244, 35, 232)), ('Vegetation', (107, 142, 35)),
                                      ('Car', (0, 0, 255)), ('Wall', (102, 102, 156)),
                                      ('Traffic ', (220, 220, 0))])
    elif args.dataset == 'sun':
        print(args.use_depth)
        from data_loader.segmentation.sun_rgbd import SUNRGBDSegmentation, SUN_RGBD_CLASS_LIST
        train_dataset = SUNRGBDSegmentation(root=args.data_path, list_name='sun_rgbd_train.txt',
                                            train=True, size=crop_size, ignore_idx=args.ignore_idx,
                                            scale=args.scale, use_depth=args.use_depth)
        val_dataset = SUNRGBDSegmentation(root=args.data_path, list_name='sun_rgbd_val.txt',
                                          train=False, size=crop_size, ignore_idx=args.ignore_idx,
                                          scale=args.scale, use_depth=args.use_depth)
        seg_classes = len(SUN_RGBD_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        color_encoding = OrderedDict([('Background', (0, 0, 0)), ('Bed', (0, 255, 0)),
                                      ('Books', (70, 70, 70)), ('Ceiling', (190, 153, 153)),
                                      ('Chair', (72, 0, 90)), ('Floor', (220, 20, 60)),
                                      ('Furniture', (153, 153, 153)), ('Objects', (157, 234, 50)),
                                      ('Picture', (128, 64, 128)), ('Sofa', (244, 35, 232)),
                                      ('Table', (107, 142, 35)), ('TV', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)), ('Window', (220, 220, 0))])
    elif args.dataset == 'camvid':
        print(args.use_depth)
        from data_loader.segmentation.camvid import CamVidSegmentation, CAMVID_CLASS_LIST, color_encoding
        train_dataset = CamVidSegmentation(
            root=args.data_path, list_name='train_camvid.txt', train=True, size=crop_size,
            scale=args.scale, label_conversion=args.label_conversion, normalize=args.normalize)
        val_dataset = CamVidSegmentation(
            root=args.data_path, list_name='val_camvid.txt', train=False, size=crop_size,
            scale=args.scale, label_conversion=args.label_conversion, normalize=args.normalize)
        if args.label_conversion:
            # Labels remapped to the greenhouse taxonomy; use its class list/colors.
            from data_loader.segmentation.greenhouse import GREENHOUSE_CLASS_LIST, color_encoding
            seg_classes = len(GREENHOUSE_CLASS_LIST)
            class_wts = torch.ones(seg_classes)
        else:
            seg_classes = len(CAMVID_CLASS_LIST)
            # One pass over the training set to estimate inverse-frequency weights.
            tmp_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
            class_wts = calc_cls_class_weight(tmp_loader, seg_classes, inverted=True)
            class_wts = torch.from_numpy(class_wts).float().to(device)
            print("class weights : {}".format(class_wts))
        args.use_depth = False  # CamVid has no depth channel
    elif args.dataset == 'forest':
        from data_loader.segmentation.freiburg_forest import FreiburgForestDataset, FOREST_CLASS_LIST, color_encoding
        train_dataset = FreiburgForestDataset(train=True, size=crop_size, scale=args.scale,
                                              normalize=args.normalize)
        val_dataset = FreiburgForestDataset(train=False, size=crop_size, scale=args.scale,
                                            normalize=args.normalize)
        seg_classes = len(FOREST_CLASS_LIST)
        tmp_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
        class_wts = calc_cls_class_weight(tmp_loader, seg_classes, inverted=True)
        class_wts = torch.from_numpy(class_wts).float().to(device)
        print("class weights : {}".format(class_wts))
        args.use_depth = False  # Freiburg Forest has no depth channel
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model selection. args.classes is set for factories that read it.
    # ------------------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnetue_seg2(args, fix_pyr_plane_proj=True)
    elif args.model == 'deeplabv3':
        from torchvision.models.segmentation.segmentation import deeplabv3_resnet101
        args.classes = seg_classes
        model = deeplabv3_resnet101(num_classes=seg_classes, aux_loss=True)
        # NOTE(review): cudnn disabled here, presumably to work around a
        # deeplabv3 incompatibility — confirm whether this is still needed.
        torch.backends.cudnn.enabled = False
    elif args.model == 'unet':
        from model.segmentation.unet import UNet
        model = UNet(in_channels=3, out_channels=seg_classes)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Optional warm start from a full-model state dict.
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    # Optionally freeze all BatchNorm statistics and affine parameters.
    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # Per-group learning rates: segmentation head gets lr * lr_mult,
    # backbone gets the base lr (third-party models use a single group).
    if args.model == 'deeplabv3' or args.model == 'unet':
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
    elif args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(), 'lr': args.lr
        }, {
            'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(), 'lr': args.lr * args.lr_mult
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(), 'lr': args.lr
        }, {
            'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult
        }]
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, 288, 480))
    except Exception:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # Optional resume: restores epoch counter, best score, model and optimizer.
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need DataParallel wrapper for
            # Criteria. So, falling back to its internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ------------------------------------------------------------------
    # Learning-rate scheduler selection (project-local schedulers).
    # ------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        # Scheduler returns the base lr; the segmentation head scales by lr_mult.
        lr_base = lr_scheduler.step(epoch)
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Models with an auxiliary/uncertainty branch use the *_ue routines.
        if args.model == 'espdnetue' or (
                (args.model == 'deeplabv3' or args.model == 'unet')
                and args.use_aux):
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        iou_train, train_loss = train(model,
                                      train_loader,
                                      optimizer,
                                      criterion,
                                      seg_classes,
                                      epoch,
                                      device=device,
                                      use_depth=args.use_depth,
                                      add_criterion=nid_loss)
        iou_val, val_loss = val(model,
                                val_loader,
                                criterion,
                                seg_classes,
                                device=device,
                                use_depth=args.use_depth,
                                add_criterion=nid_loss)

        # Visualize one batch from each split.
        # Fix: the Python-2-era iterator method `.next()` was removed in
        # Python 3 / current PyTorch; use the builtin next() instead.
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        if args.use_depth:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                depths=batch_train[2].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          depths=batch[2].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)
            image_grid = torchvision.utils.make_grid(
                batch[2].to(device=device).data.cpu()).numpy()
            print(type(image_grid))
            writer.add_image('Segmentation/depths', image_grid, epoch)
        else:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

        # Remember best mIoU (computed over classes 1-3 only — presumably
        # plant / artificial-object / ground; background excluded) and
        # save a checkpoint every epoch.
        miou_val = iou_val[[1, 2, 3]].mean()
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train',
                          iou_train[[1, 2, 3]].mean(), epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/plant_IOU/val', iou_val[1], epoch)
        writer.add_scalar('Segmentation/ao_IOU/val', iou_val[2], epoch)
        writer.add_scalar('Segmentation/ground_IOU/val', iou_val[3], epoch)
        # NOTE: value/step intentionally "swapped" so TensorBoard plots
        # best mIoU against model complexity on the x-axis.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
def main():
    """Self-training driver for the greenhouse dataset.

    Generates pseudo-labels for an unlabeled "traversability" image list with a
    pretrained ESPDNet-UE model, then finetunes that model on the generated
    labels, checkpointing whenever validation mIoU improves.

    Relies on the module-level ``args`` namespace (not a parameter) and on the
    helpers ``generate_label``, ``train``, ``test`` and ``save_checkpoint``
    defined elsewhere in this file.
    """
    # NOTE(review): device is hard-coded; this function assumes a GPU is present.
    device = 'cuda'

    # Timestamped run directory. +9h shifts UTC to JST for the folder name.
    now = datetime.datetime.now()
    now += datetime.timedelta(hours=9)
    timestr = now.strftime("%Y%m%d-%H%M%S")
    # NOTE(review): the two strings below are computed but never used —
    # possibly intended to be part of save_path.
    use_depth_str = "_rgbd" if args.use_depth else "_rgb"
    if args.use_depth:
        trainable_fusion_str = "_gated" if args.trainable_fusion else "_naive"
    else:
        trainable_fusion_str = ""

    save_path = '{}/model_{}_{}/{}'.format(args.save, args.model, args.dataset,
                                           timestr)
    print(save_path)
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    # Pseudo-label outputs: an image list file and a directory of predictions.
    tgt_train_lst = osp.join(save_path, 'tgt_train.lst')
    save_pred_path = osp.join(save_path, 'pred')
    if not os.path.isdir(save_pred_path):
        os.makedirs(save_pred_path)

    writer = SummaryWriter(save_path)

    # Dataset of unlabeled images to pseudo-label.
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentationTrav, GREENHOUSE_CLASS_LIST
    args.classes = len(GREENHOUSE_CLASS_LIST)
    travset = GreenhouseRGBDSegmentationTrav(list_name=args.data_trav_list,
                                             use_depth=args.use_depth)

    # Label -> RGB color map used by the visualization helpers.
    class_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                  ('other_part_of_plant', (0, 255, 255)),
                                  ('artificial_objects', (255, 0, 0)),
                                  ('ground', (255, 255, 0)),
                                  ('background', (0, 0, 0))])

    # Dataloader for generating the pseudo-labels (batch of 1, fixed order).
    travloader = torch.utils.data.DataLoader(travset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=args.pin_memory)

    # Pretrained model used both for label generation and as the finetune start.
    from model.segmentation.espdnet_ue import espdnetue_seg2
    args.weights = args.restore_from
    model = espdnetue_seg2(args,
                           load_entire_weights=True,
                           fix_pyr_plane_proj=True)
    model.to(device)

    # Write pseudo-labels to save_pred_path and the list file tgt_train_lst;
    # must run before the training dataset below is constructed from that list.
    generate_label(model, travloader, save_pred_path, tgt_train_lst)

    # Dataset for training on the freshly generated pseudo-labels.
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation
    trainset = GreenhouseRGBDSegmentation(list_name=tgt_train_lst,
                                          use_depth=args.use_depth,
                                          use_traversable=True)
    testset = GreenhouseRGBDSegmentation(list_name=args.data_test_list,
                                         use_depth=args.use_depth,
                                         use_traversable=True)

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=args.pin_memory)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=0,
                                             pin_memory=args.pin_memory)

    # Loss: fixed per-class weights (class 4, background, is ignored via 0.0).
    class_weights = torch.tensor([1.0, 0.2, 1.0, 1.0, 0.0]).to(device)
    if args.use_uncertainty:
        # Down-weights pixels by the model's predicted uncertainty.
        criterion = UncertaintyWeightedSegmentationLoss(
            args.classes, class_weights=class_weights)
    else:
        criterion = SegmentationLoss(n_classes=args.classes,
                                     device=device,
                                     class_weights=class_weights)
    criterion_test = SegmentationLoss(n_classes=args.classes,
                                      device=device,
                                      class_weights=class_weights)

    # Optimizer: backbone at 0.1x the head learning rate.
    if args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.learning_rate * 0.1
        }, {
            'params': model.get_segment_params(),
            'lr': args.learning_rate
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.learning_rate
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.learning_rate * 0.1
        }, {
            'params': model.get_segment_params(),
            'lr': args.learning_rate
        }]

    if args.optimizer == 'SGD':
        optimizer = optim.SGD(train_params,
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(train_params,
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)

    # Cyclic LR between lr and 10*lr; momentum cycling only applies to SGD.
    scheduler = optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=args.learning_rate,
        max_lr=args.learning_rate * 10,
        step_size_up=10,
        step_size_down=20,
        cycle_momentum=True if args.optimizer == 'SGD' else False)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100], gamma=0.5)

    best_miou = 0.0
    for i in range(0, args.epoch):
        # Run a training epoch
        train(trainloader,
              model,
              criterion,
              device,
              optimizer,
              class_encoding,
              i,
              writer=writer)

        # Update the learning rate
        scheduler.step()

        new_miou = test(testloader,
                        model,
                        criterion_test,
                        device,
                        optimizer,
                        class_encoding,
                        i,
                        writer=writer)

        # Save the weights if it produces the best IoU
        is_best = new_miou > best_miou
        best_miou = max(new_miou, best_miou)
        model.to(device)

        # weights_dict = model.module.state_dict() if device == 'cuda' else model.state_dict()
        weights_dict = model.state_dict()
        extra_info_ckpt = '{}'.format(args.model)
        if is_best:
            save_checkpoint(
                {
                    'epoch': i + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, save_path, extra_info_ckpt)
def main(args):
    """Evaluate a trained segmentation model on the validation split of a dataset.

    Builds the validation dataset selected by ``args.dataset``, constructs the
    network selected by ``args.model``, loads weights from
    ``args.weights_test``, and reports validation mIoU.

    Args:
        args: argparse.Namespace with evaluation configuration (``dataset``,
            ``data_path``, ``model``, ``s`` scale, ``im_size``,
            ``weights_test``). ``args.classes`` is written as a side effect.
    """
    # ------------------------------------------------------------------
    # Dataset selection. Every branch must define val_dataset, seg_classes
    # and class_wts (per-class loss weights).
    # ------------------------------------------------------------------
    if args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=(256, 256),
                                             scale=args.s,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Precomputed Cityscapes class weights; class 19 is ignored.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=(256, 256),
                                      scale=args.s)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey':
        from data_loader.segmentation.hockey import HockeySegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        # NOTE(review): train_dataset is built but never used in this
        # evaluation path — possibly left over from the training script.
        train_dataset = HockeySegmentationDataset(root=args.data_path,
                                                  train=True,
                                                  crop_size=(256, 256),
                                                  scale=args.s)
        val_dataset = HockeySegmentationDataset(root=args.data_path,
                                                train=False,
                                                crop_size=(256, 256),
                                                scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey_rink_seg':
        from data_loader.segmentation.hockey_rink_seg import HockeyRinkSegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        train_dataset = HockeyRinkSegmentationDataset(root=args.data_path,
                                                      train=True,
                                                      crop_size=(256, 256),
                                                      scale=args.s)
        val_dataset = HockeyRinkSegmentationDataset(root=args.data_path,
                                                    train=False,
                                                    crop_size=(256, 256),
                                                    scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))

    if len(val_dataset) == 0:
        # Fix: the original referenced `image_path`, which is never defined in
        # this function (NameError). Report the configured data path instead.
        print_error_message('No files in directory: {}'.format(args.data_path))
    print_info_message('# of images for testing: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model selection.
    # ------------------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    # Model complexity information.
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, args.im_size[0],
                                             args.im_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            args.im_size[0], args.im_size[1], flops))
    print_info_message('# of parameters: {}'.format(num_params))

    # Load the weights to evaluate; accepts either a raw state dict or a full
    # training checkpoint containing a 'state_dict' entry.
    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test,
                                 map_location=torch.device('cpu'))
        if isinstance(weight_dict, dict) and 'state_dict' in weight_dict:
            model.load_state_dict(weight_dict['state_dict'])
        else:
            model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        # Fix: the original passed the message and `format(args.weights_test)`
        # as two arguments (comma instead of calling `.format` on the string).
        print_error_message(
            'weight file does not exist or not specified. Please check: {}'.format(
                args.weights_test))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=40,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=4)

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type='ce',
                                 device=device,
                                 ignore_idx=255,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need DataParallel wrapper for
            # Criteria. So, falling back to its internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    miou_val, val_loss = val(model,
                             val_loader,
                             criterion,
                             seg_classes,
                             device=device)
    print_info_message('mIOU: {}'.format(miou_val))
def main(args):
    """Two-stage training driver for CamVid-style semantic segmentation.

    Stage 1 trains the selected architecture on the full (13-class) label
    set.  Stage 2 re-creates the model for the label-converted
    (reduced-class) dataset, transfers every stage-1 weight whose name and
    shape still match, and finetunes with a 100x smaller learning rate.

    Helpers such as ``import_dataset``, ``show_network_stats``,
    ``get_lr_scheduler``, ``save_checkpoint`` and the train/val routines are
    defined elsewhere in this project.

    :param args: parsed command-line namespace; mutated in place
                 (``classes``, ``use_depth``, ``ignore_idx``, ``lr``,
                 ``epochs`` are overwritten during the run).
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))

    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Stage 1: training with the full (13-class) label set
    #
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=False)  # 13 classes
    args.use_depth = False  # 'use_depth' is always false for camvid

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Import and build the model for the full label set
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        # FIX: this was a bare expression in the original (the `print` call
        # was missing), so the message was silently never emitted.
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Optionally freeze batch-normalization layers
    if args.freeze_bn:
        freeze_bn_layer(model)

    # Separate learning rates for the base network and the segmentation head
    train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                    {'params': model.get_segment_params(),
                     'lr': args.lr * args.lr_mult}]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Compute the FLOPs and the number of parameters, and display them
    num_params, flops = show_network_stats(model, crop_size)

    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        # FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; the intended failure (unsupported
        # ops during graph tracing) is an Exception subclass.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU we do not need the project's DataParallel
            # wrapper for the criteria, so fall back to torch's own wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Data loaders for training and validation
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Learning-rate scheduler (project-specific: step() returns the new lr)
    lr_scheduler = get_lr_scheduler(args.scheduler)

    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    #
    # Stage-1 main training loop
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # Set the optimizer with the learning rates
        # (this could be done inside the scheduler itself)
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # espdnetue has dedicated train/val functions (uncertainty variant)
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model, train_loader, optimizer,
                                       criterion, seg_classes, epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes,
                                 device=device, use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # FIX: `iter(loader).next()` is the Python-2 iterator protocol and
        # raises AttributeError on Python 3 / current PyTorch; the builtin
        # next() works on both.
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='Segmentation/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='Segmentation/val',
                                      device=device)

        # Remember the best mIoU and save a checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value/step arguments look swapped here (best_miou is
        # logged as the value with FLOPs/params as the step); kept as-is to
        # preserve existing tensorboard logs -- confirm intent.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    # Keep the stage-1 weights, then release the model before stage 2
    model_dict = copy.deepcopy(model.state_dict())
    del model
    torch.cuda.empty_cache()

    #
    # Stage 2: finetuning with the label-converted (reduced-class) dataset
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Re-create the model for the new number of classes
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        # Transfer every stage-1 weight whose (de-prefixed) name exists in
        # the new model with an identical shape.
        new_model_dict = model.state_dict()
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        # NOTE(review): this comprehension iterates new_model_dict and tests
        # membership against new_model_dict itself, so it is always empty;
        # it was presumably meant to report parameters NOT transferred.
        # Kept as-is since it only feeds the debug print below.
        no_overlap_dict = {
            k.replace('module.', ''): v
            for k, v in new_model_dict.items()
            if k.replace('module.', '') not in new_model_dict
            or new_model_dict[k.replace('module.', '')].size() != v.size()
        }
        print(no_overlap_dict.keys())
        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    # Sanity-check forward pass with a dummy input
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())
    print(seg_classes)
    print(class_wts.size())

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Finetuning uses a 100x smaller learning rate
    args.lr /= 100
    train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                    {'params': model.get_segment_params(),
                     'lr': args.lr * args.lr_mult}]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU we do not need the project's DataParallel
            # wrapper for the criteria, so fall back to torch's own wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Data loaders for the finetuning stage
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Finetuning always runs a fixed 50 epochs
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Recompute FLOPs / parameter count for the re-created model
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])

    #
    # Stage-2 (finetuning) training loop
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # espdnetue has dedicated train/val functions (uncertainty variant)
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model, train_loader, optimizer,
                                       criterion, seg_classes, epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes,
                                 device=device, use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # FIX: Python-2 `.next()` replaced with the builtin next() (see above)
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # Remember the best mIoU and save a checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        # NOTE(review): same value/step oddity as in stage 1; kept as-is.
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()