def main(args):
    # Read all the images in the folder.
    if args.dataset == 'city':
        # image_path = os.path.join(args.data_path, "leftImg8bit", args.split, "*", "*.png")
        image_path = './input/*.png'
        image_list = glob.glob(image_path)
        print(image_list)
        from data_loader.segmentation.cityscapes import CITYSCAPE_CLASS_LIST
        seg_classes = len(CITYSCAPE_CLASS_LIST)
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOC_CLASS_LIST
        seg_classes = len(VOC_CLASS_LIST)
        data_file = os.path.join(args.data_path, 'VOC2012', 'list', '{}.txt'.format(args.split))
        if not os.path.isfile(data_file):
            print_error_message('{} file does not exist'.format(data_file))
        image_list = []
        with open(data_file, 'r') as lines:
            for line in lines:
                rgb_img_loc = '{}/{}/{}'.format(args.data_path, 'VOC2012', line.split()[0])
                if not os.path.isfile(rgb_img_loc):
                    print_error_message('{} image file does not exist'.format(rgb_img_loc))
                image_list.append(rgb_img_loc)
    else:
        print_error_message('{} dataset not yet supported'.format(args.dataset))

    if len(image_list) == 0:
        print_error_message('No files in directory: {}'.format(image_path))
    print_info_message('# of images for testing: {}'.format(len(image_list)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    # Model information.
    num_params = model_parameters(model)
    flops = compute_flops(model, input=torch.Tensor(1, 3, args.im_size[0], args.im_size[1]))
    print_info_message('FLOPs for an input of size {}x{}: {:.2f} million'.format(
        args.im_size[0], args.im_size[1], flops))
    print_info_message('# of parameters: {}'.format(num_params))

    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        print_error_message('Weight file does not exist or not specified. Please check: {}'.format(args.weights_test))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    evaluate(args, model, image_list, device=device)
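# --- Usage sketch (not part of the original script) ---
# A minimal entry point assuming main() is wired to argparse. The option names
# below are inferred from the attributes read in main() above (args.dataset,
# args.data_path, args.split, args.model, args.im_size, args.weights_test);
# the repo's real parser may define them differently.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Segmentation testing (sketch)')
    parser.add_argument('--dataset', default='city', choices=['city', 'pascal'])
    parser.add_argument('--data-path', dest='data_path', default='./data')
    parser.add_argument('--split', default='val', choices=['train', 'val'])
    parser.add_argument('--model', default='espnetv2', choices=['espnetv2', 'dicenet'])
    parser.add_argument('--im-size', dest='im_size', type=int, nargs=2, default=[512, 256])
    parser.add_argument('--weights-test', dest='weights_test', default='', help='pretrained weights')
    main(parser.parse_args())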
def main(args):
    # Read all the images in the folder.
    if args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CITYSCAPE_CLASS_LIST
        seg_classes = len(CITYSCAPE_CLASS_LIST)
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOC_CLASS_LIST
        seg_classes = len(VOC_CLASS_LIST)
    else:
        print_error_message('{} dataset not yet supported'.format(args.dataset))

    image_list = []
    for extn in IMAGE_EXTENSIONS:
        image_list += glob.glob(args.data_path + os.sep + '*' + extn)

    if len(image_list) == 0:
        print_error_message('No files in directory: {}'.format(args.data_path))
    print_info_message('# of images used for demonstration: {}'.format(len(image_list)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        print_error_message('Weight file does not exist or not specified. Please check: {}'.format(args.weights_test))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    if torch.backends.cudnn.is_available():
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        cudnn.deterministic = True

    run_segmentation(args, model, image_list, device=device)
def main(args):
    image_list = []
    for extn in IMAGE_EXTENSIONS:
        image_list += glob.glob(args.data_path + os.sep + '*' + extn)

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = 20
        model = espnetv2_seg(args)

    if args.weights_test:
        print('Loading model weights')
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print('Weight loaded successfully')
    else:
        print('Error: weight file does not exist or was not specified: {}'.format(args.weights_test))

    model = model.cuda()
    run_segmentation(model, image_list, device='cuda')
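# --- Hypothetical helper definition (assumed, not shown in the original) ---
# The two demo variants above iterate over IMAGE_EXTENSIONS, which is defined
# or imported elsewhere in the repo. A plausible definition is a list of
# common image suffixes, matched by the glob calls above:
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp']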
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message('Running Model at image resolution {}x{} with batch size {}'.format(
        crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path, train=True, crop_size=crop_size,
                                        scale=args.scale, coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path, train=False, crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path, train=True, size=crop_size,
                                               scale=args.scale, coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False, size=crop_size,
                                             scale=args.scale, coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Pre-computed Cityscapes class weights; the last (ignore) class gets weight 0.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(args.finetune))
            weight_dict = torch.load(args.finetune, map_location=torch.device('cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                    {'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult}]
    # Each parameter group carries its own learning rate, so no global lr is needed here.
    optimizer = optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model, input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message('FLOPs for an input of size {}x{}: {:.2f} million'.format(
        crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir, comment='Training and Validation logs')
    try:
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    except Exception:
        print_log_message("Not able to generate the graph. Likely because your model is not supported by ONNX")

    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(args.resume))

    # criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
                                 device=device, ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need the DataParallel wrapper for the criteria,
            # so fall back to PyTorch's internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, pin_memory=True, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             shuffle=False, pin_memory=True, num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [step_size * i for i in range(1, int(math.ceil(args.epochs / step_size)))]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr, steps=step_sizes, gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [step_size * i for i in range(1, int(math.ceil(args.epochs / step_size)))]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr, cycle_len=5, steps=step_sizes, gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr, max_epochs=args.epochs, power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr, max_epochs=args.epochs,
                                clr_max=args.clr_max, cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # Set the optimizer with the learning rate. This can be done inside the MyLRScheduler.
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message('Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'.format(
            epoch, lr_base, lr_seg))
        miou_train, train_loss = train(model, train_loader, optimizer, criterion, seg_classes,
                                       epoch, device=device)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes, device=device)

        # Remember best miou and save checkpoint.
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict() if device == 'cuda' else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.model,
            'state_dict': weights_dict,
            'best_miou': best_miou,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou, math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou, math.ceil(num_params))

    writer.close()
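# --- Checkpoint layout sketch ---
# save_checkpoint() above persists a dict with the keys 'epoch', 'arch',
# 'state_dict', 'best_miou' and 'optimizer', and the --resume branch reads the
# same keys back. A minimal inspection helper (the helper name is ours,
# assuming save_checkpoint writes the dict with torch.save):
def inspect_checkpoint(ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    print('arch: {}, epochs trained: {}, best mIOU: {}'.format(
        checkpoint['arch'], checkpoint['epoch'], checkpoint['best_miou']))
    # To continue training: model.load_state_dict(checkpoint['state_dict'])
    # and optimizer.load_state_dict(checkpoint['optimizer']).
    return checkpoint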
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message('Running Model at image resolution {}x{} with batch size {}'.format(
        crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path, train=True, crop_size=crop_size,
                                        scale=args.scale, coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path, train=False, crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST, color_encoding
        train_dataset = CityscapesSegmentation(root=args.data_path, train=True, coarse=False)
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False, coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Inverted variant of the pre-computed Cityscapes weights; the last
        # (ignore) class gets weight 0.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 10 / 2.8149201869965
        class_wts[1] = 10 / 6.9850029945374
        class_wts[2] = 10 / 3.7890393733978
        class_wts[3] = 10 / 9.9428062438965
        class_wts[4] = 10 / 9.7702074050903
        class_wts[5] = 10 / 9.5110931396484
        class_wts[6] = 10 / 10.311357498169
        class_wts[7] = 10 / 10.026463508606
        class_wts[8] = 10 / 4.6323022842407
        class_wts[9] = 10 / 9.5608062744141
        class_wts[10] = 10 / 7.8698215484619
        class_wts[11] = 10 / 9.5168733596802
        class_wts[12] = 10 / 10.373730659485
        class_wts[13] = 10 / 6.6616044044495
        class_wts[14] = 10 / 10.260489463806
        class_wts[15] = 10 / 10.287888526917
        class_wts[16] = 10 / 10.289801597595
        class_wts[17] = 10 / 10.405355453491
        class_wts[18] = 10 / 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GREENHOUSE_CLASS_LIST, color_encoding
        train_dataset = GreenhouseRGBDSegmentation(root=args.data_path, list_name=args.train_list,
                                                   train=True, size=crop_size, scale=args.scale,
                                                   use_depth=args.use_depth,
                                                   use_traversable=args.greenhouse_use_trav)
        val_dataset = GreenhouseRGBDSegmentation(root=args.data_path, list_name=args.val_list,
                                                 train=False, size=crop_size, scale=args.scale,
                                                 use_depth=args.use_depth,
                                                 use_traversable=args.greenhouse_use_trav)
        class_weights = np.load('class_weights.npy')  # [:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)
        print(GREENHOUSE_CLASS_LIST)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
        # color_encoding = OrderedDict([
        #     ('end_of_plant', (0, 255, 0)),
        #     ('other_part_of_plant', (0, 255, 255)),
        #     ('artificial_objects', (255, 0, 0)),
        #     ('ground', (255, 255, 0)),
        #     ('background', (0, 0, 0))
        # ])
    elif args.dataset == 'ishihara':
        print(args.use_depth)
        from data_loader.segmentation.ishihara_rgbd import IshiharaRGBDSegmentation, ISHIHARA_RGBD_CLASS_LIST
        train_dataset = IshiharaRGBDSegmentation(root=args.data_path, list_name='ishihara_rgbd_train.txt',
                                                 train=True, size=crop_size, scale=args.scale,
                                                 use_depth=args.use_depth)
        val_dataset = IshiharaRGBDSegmentation(root=args.data_path, list_name='ishihara_rgbd_val.txt',
                                               train=False, size=crop_size, scale=args.scale,
                                               use_depth=args.use_depth)
        seg_classes = len(ISHIHARA_RGBD_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        color_encoding = OrderedDict([
            ('Unlabeled', (0, 0, 0)),
            ('Building', (70, 70, 70)),
            ('Fence', (190, 153, 153)),
            ('Others', (72, 0, 90)),
            ('Pedestrian', (220, 20, 60)),
            ('Pole', (153, 153, 153)),
            ('Road ', (157, 234, 50)),
            ('Road', (128, 64, 128)),
            ('Sidewalk', (244, 35, 232)),
            ('Vegetation', (107, 142, 35)),
            ('Car', (0, 0, 255)),
            ('Wall', (102, 102, 156)),
            ('Traffic ', (220, 220, 0))
        ])
    elif args.dataset == 'sun':
        print(args.use_depth)
        from data_loader.segmentation.sun_rgbd import SUNRGBDSegmentation, SUN_RGBD_CLASS_LIST
        train_dataset = SUNRGBDSegmentation(root=args.data_path, list_name='sun_rgbd_train.txt',
                                            train=True, size=crop_size, ignore_idx=args.ignore_idx,
                                            scale=args.scale, use_depth=args.use_depth)
        val_dataset = SUNRGBDSegmentation(root=args.data_path, list_name='sun_rgbd_val.txt',
                                          train=False, size=crop_size, ignore_idx=args.ignore_idx,
                                          scale=args.scale, use_depth=args.use_depth)
        seg_classes = len(SUN_RGBD_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        color_encoding = OrderedDict([
            ('Background', (0, 0, 0)),
            ('Bed', (0, 255, 0)),
            ('Books', (70, 70, 70)),
            ('Ceiling', (190, 153, 153)),
            ('Chair', (72, 0, 90)),
            ('Floor', (220, 20, 60)),
            ('Furniture', (153, 153, 153)),
            ('Objects', (157, 234, 50)),
            ('Picture', (128, 64, 128)),
            ('Sofa', (244, 35, 232)),
            ('Table', (107, 142, 35)),
            ('TV', (0, 0, 255)),
            ('Wall', (102, 102, 156)),
            ('Window', (220, 220, 0))
        ])
    elif args.dataset == 'camvid':
        print(args.use_depth)
        from data_loader.segmentation.camvid import CamVidSegmentation, CAMVID_CLASS_LIST, color_encoding
        train_dataset = CamVidSegmentation(root=args.data_path, list_name='train_camvid.txt',
                                           train=True, size=crop_size, scale=args.scale,
                                           label_conversion=args.label_conversion,
                                           normalize=args.normalize)
        val_dataset = CamVidSegmentation(root=args.data_path, list_name='val_camvid.txt',
                                         train=False, size=crop_size, scale=args.scale,
                                         label_conversion=args.label_conversion,
                                         normalize=args.normalize)
        if args.label_conversion:
            from data_loader.segmentation.greenhouse import GREENHOUSE_CLASS_LIST, color_encoding
            seg_classes = len(GREENHOUSE_CLASS_LIST)
            class_wts = torch.ones(seg_classes)
        else:
            seg_classes = len(CAMVID_CLASS_LIST)
            tmp_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
            class_wts = calc_cls_class_weight(tmp_loader, seg_classes, inverted=True)
            class_wts = torch.from_numpy(class_wts).float().to(device)
            # class_wts = torch.ones(seg_classes)
            print("class weights : {}".format(class_wts))
        args.use_depth = False
    elif args.dataset == 'forest':
        from data_loader.segmentation.freiburg_forest import FreiburgForestDataset, FOREST_CLASS_LIST, color_encoding
        train_dataset = FreiburgForestDataset(train=True, size=crop_size, scale=args.scale,
                                              normalize=args.normalize)
        val_dataset = FreiburgForestDataset(train=False, size=crop_size, scale=args.scale,
                                            normalize=args.normalize)
        seg_classes = len(FOREST_CLASS_LIST)
        tmp_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
        class_wts = calc_cls_class_weight(tmp_loader, seg_classes, inverted=True)
        class_wts = torch.from_numpy(class_wts).float().to(device)
        # class_wts = torch.ones(seg_classes)
        print("class weights : {}".format(class_wts))
        args.use_depth = False
    else:
        print_error_message('Dataset: {} not yet supported'.format(args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnetue_seg2(args, fix_pyr_plane_proj=True)
    elif args.model == 'deeplabv3':
        # from model.segmentation.deeplabv3 import DeepLabV3
        from torchvision.models.segmentation import deeplabv3_resnet101
        args.classes = seg_classes
        # model = DeepLabV3(seg_classes)
        model = deeplabv3_resnet101(num_classes=seg_classes, aux_loss=True)
        torch.backends.cudnn.enabled = False
    elif args.model == 'unet':
        from model.segmentation.unet import UNet
        model = UNet(in_channels=3, out_channels=seg_classes)
        # model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
        #                        in_channels=3, out_channels=seg_classes, init_features=32,
        #                        pretrained=False)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(args.finetune))
            weight_dict = torch.load(args.finetune, map_location=torch.device('cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    if args.model == 'deeplabv3' or args.model == 'unet':
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
    elif args.use_depth:
        train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                        {'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult},
                        {'params': model.get_depth_encoder_params(), 'lr': args.lr * args.lr_mult}]
    else:
        train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                        {'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult}]

    optimizer = optim.SGD(train_params, lr=args.lr * args.lr_mult,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model, input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message('FLOPs for an input of size {}x{}: {:.2f} million'.format(
        crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir, comment='Training and Validation logs')
    try:
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, 288, 480))
    except Exception:
        print_log_message("Not able to generate the graph. Likely because your model is not supported by ONNX")

    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(args.resume))

    print('device : ' + device)

    # criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
                                 device=device, ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32, label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need the DataParallel wrapper for the criteria,
            # so fall back to PyTorch's internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, pin_memory=True, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=20, shuffle=False,
                                             pin_memory=True, num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [step_size * i for i in range(1, int(math.ceil(args.epochs / step_size)))]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr, steps=step_sizes, gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [step_size * i for i in range(1, int(math.ceil(args.epochs / step_size)))]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr, cycle_len=5, steps=step_sizes, gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr, max_epochs=args.epochs, power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr, max_epochs=args.epochs,
                                clr_max=args.clr_max, cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # Set the optimizer with the learning rate. This can be done inside the MyLRScheduler.
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]['lr'] = lr_seg
            if args.use_depth:
                optimizer.param_groups[2]['lr'] = lr_base

        print_info_message('Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'.format(
            epoch, lr_base, lr_seg))

        # Models with an auxiliary (uncertainty) branch use dedicated train/val functions.
        if args.model == 'espdnetue' or ((args.model == 'deeplabv3' or args.model == 'unet') and args.use_aux):
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        iou_train, train_loss = train(model, train_loader, optimizer, criterion, seg_classes, epoch,
                                      device=device, use_depth=args.use_depth, add_criterion=nid_loss)
        iou_val, val_loss = val(model, val_loader, criterion, seg_classes, device=device,
                                use_depth=args.use_depth, add_criterion=nid_loss)

        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        if args.use_depth:
            in_training_visualization_img(model, images=batch_train[0].to(device=device),
                                          depths=batch_train[2].to(device=device),
                                          labels=batch_train[1].to(device=device),
                                          class_encoding=color_encoding, writer=writer,
                                          epoch=epoch, data='Segmentation/train', device=device)
            in_training_visualization_img(model, images=batch[0].to(device=device),
                                          depths=batch[2].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding, writer=writer,
                                          epoch=epoch, data='Segmentation/val', device=device)

            image_grid = torchvision.utils.make_grid(batch[2].to(device=device).data.cpu()).numpy()
            print(type(image_grid))
            writer.add_image('Segmentation/depths', image_grid, epoch)
        else:
            in_training_visualization_img(model, images=batch_train[0].to(device=device),
                                          labels=batch_train[1].to(device=device),
                                          class_encoding=color_encoding, writer=writer,
                                          epoch=epoch, data='Segmentation/train', device=device)
            in_training_visualization_img(model, images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding, writer=writer,
                                          epoch=epoch, data='Segmentation/val', device=device)
            # image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            # writer.add_image('Segmentation/results/val', image_grid, epoch)

        # Remember best miou and save checkpoint. mIOU is averaged over the
        # plant, artificial-object and ground classes (indices 1-3).
        miou_val = iou_val[[1, 2, 3]].mean()
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict() if device == 'cuda' else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.model,
            'state_dict': weights_dict,
            'best_miou': best_miou,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', iou_train[[1, 2, 3]].mean(), epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/plant_IOU/val', iou_val[1], epoch)
        writer.add_scalar('Segmentation/ao_IOU/val', iou_val[2], epoch)
        writer.add_scalar('Segmentation/ground_IOU/val', iou_val[3], epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou, math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou, math.ceil(num_params))

    writer.close()
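# --- Class-weight sketch ---
# The greenhouse branch above loads pre-computed weights from
# 'class_weights.npy', and the camvid/forest branches compute them with
# calc_cls_class_weight(..., inverted=True). A minimal inverse-frequency
# sketch of that idea (the helper name is ours, and it assumes the loader
# yields (image, label) batches with integer label maps; the repo's actual
# implementation may differ):
def inverse_frequency_class_weights(loader, n_classes, ignore_idx=255):
    import numpy as np

    counts = np.zeros(n_classes, dtype=np.float64)
    for _, label in loader:
        label = label.numpy().ravel()
        label = label[label != ignore_idx]       # drop ignored pixels
        counts += np.bincount(label, minlength=n_classes)[:n_classes]
    freq = counts / max(counts.sum(), 1.0)
    weights = 1.0 / (freq + 1e-8)                # rare classes get larger weights
    weights[counts == 0] = 0.0                   # absent classes get zero weight
    return (weights / weights.max()).astype(np.float32)
# e.g. np.save('class_weights.npy', inverse_frequency_class_weights(tmp_loader, seg_classes))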
def main(args):
    # Read all the images in the folder.
    if args.dataset == 'city':
        # image_path = os.path.join(args.data_path, "leftImg8bit", args.split, "*", "*.png")
        # image_list = glob.glob(image_path)
        # from data_loader.segmentation.cityscapes import CITYSCAPE_CLASS_LIST
        # seg_classes = len(CITYSCAPE_CLASS_LIST)
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        val_dataset = CityscapesSegmentation(root=args.data_path, train=False, size=(256, 256),
                                             scale=args.s, coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'pascal':
        # from data_loader.segmentation.voc import VOC_CLASS_LIST
        # seg_classes = len(VOC_CLASS_LIST)
        # data_file = os.path.join(args.data_path, 'VOC2012', 'list', '{}.txt'.format(args.split))
        # if not os.path.isfile(data_file):
        #     print_error_message('{} file does not exist'.format(data_file))
        # image_list = []
        # with open(data_file, 'r') as lines:
        #     for line in lines:
        #         rgb_img_loc = '{}/{}/{}'.format(args.data_path, 'VOC2012', line.split()[0])
        #         if not os.path.isfile(rgb_img_loc):
        #             print_error_message('{} image file does not exist'.format(rgb_img_loc))
        #         image_list.append(rgb_img_loc)
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        val_dataset = VOCSegmentation(root=args.data_path, train=False, crop_size=(256, 256),
                                      scale=args.s)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey':
        from data_loader.segmentation.hockey import HockeySegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        train_dataset = HockeySegmentationDataset(root=args.data_path, train=True,
                                                  crop_size=(256, 256), scale=args.s)
        val_dataset = HockeySegmentationDataset(root=args.data_path, train=False,
                                                crop_size=(256, 256), scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey_rink_seg':
        from data_loader.segmentation.hockey_rink_seg import HockeyRinkSegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        train_dataset = HockeyRinkSegmentationDataset(root=args.data_path, train=True,
                                                      crop_size=(256, 256), scale=args.s)
        val_dataset = HockeyRinkSegmentationDataset(root=args.data_path, train=False,
                                                    crop_size=(256, 256), scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    else:
        print_error_message('{} dataset not yet supported'.format(args.dataset))

    if len(val_dataset) == 0:
        print_error_message('No files in directory: {}'.format(args.data_path))
    print_info_message('# of images for testing: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    # Model information.
    num_params = model_parameters(model)
    flops = compute_flops(model, input=torch.Tensor(1, 3, args.im_size[0], args.im_size[1]))
    print_info_message('FLOPs for an input of size {}x{}: {:.2f} million'.format(
        args.im_size[0], args.im_size[1], flops))
    print_info_message('# of parameters: {}'.format(num_params))

    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        # Accept both raw state dicts and full training checkpoints.
        if isinstance(weight_dict, dict) and 'state_dict' in weight_dict:
            model.load_state_dict(weight_dict['state_dict'])
        else:
            model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        print_error_message('Weight file does not exist or not specified. Please check: {}'.format(args.weights_test))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=40, shuffle=False,
                                             pin_memory=True, num_workers=4)

    criterion = SegmentationLoss(n_classes=seg_classes, loss_type='ce', device=device,
                                 ignore_idx=255, class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need the DataParallel wrapper for the criteria,
            # so fall back to PyTorch's internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # evaluate(args, model, image_list, seg_classes, device=device)
    miou_val, val_loss = val(model, val_loader, criterion, seg_classes, device=device)
    print_info_message('mIOU: {}'.format(miou_val))
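# --- Loading helper sketch ---
# The test main above unwraps full checkpoints (the 'state_dict' key), and the
# finetuning code below strips the 'module.' prefix that DataParallel adds to
# parameter names. A small helper combining both (the helper name is ours):
def load_weights_flexibly(model, weights_path):
    blob = torch.load(weights_path, map_location=torch.device('cpu'))
    state = blob['state_dict'] if isinstance(blob, dict) and 'state_dict' in blob else blob
    # Strip a leading 'module.' so weights saved through DataParallel load
    # into an unwrapped model.
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    model.load_state_dict(state)
    return model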
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message('Running Model at image resolution {}x{} with batch size {}'.format(
        crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard.
    writer = SummaryWriter(log_dir=args.savedir, comment='Training and Validation logs')

    #
    # Training the model with the 13 classes of the CamVid dataset
    # TODO: This process should be done only if specified
    #
    # if not args.finetune:
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=False)  # 13 classes
    args.use_depth = False  # 'use_depth' is always False for CamVid

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Import the model.
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Freeze batch normalization layers?
    if args.freeze_bn:
        freeze_bn_layer(model)

    # Set learning rates.
    train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                    {'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult}]

    # Define an optimizer.
    optimizer = optim.SGD(train_params, lr=args.lr * args.lr_mult,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    # Compute the FLOPs and the number of parameters, and display them.
    num_params, flops = show_network_stats(model, crop_size)

    try:
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    except Exception:
        print_log_message("Not able to generate the graph. Likely because your model is not supported by ONNX")

    # criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
                                 device=device, ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32, label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need the DataParallel wrapper for the criteria,
            # so fall back to PyTorch's internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, pin_memory=True, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=20, shuffle=False,
                                             pin_memory=True, num_workers=args.workers)

    # Get a learning rate scheduler.
    lr_scheduler = get_lr_scheduler(args.scheduler)

    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    #
    # Main training loop for the 13 classes
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # Set the optimizer with the learning rate. This can be done inside the MyLRScheduler.
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message('Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'.format(
            epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue.
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model, train_loader, optimizer, criterion, seg_classes, epoch,
                                       device=device, use_depth=args.use_depth, add_criterion=nid_loss)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes, device=device,
                                 use_depth=args.use_depth, add_criterion=nid_loss)

        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model, images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding, writer=writer,
                                      epoch=epoch, data='Segmentation/train', device=device)
        in_training_visualization_img(model, images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding, writer=writer,
                                      epoch=epoch, data='Segmentation/val', device=device)

        # Remember best miou and save checkpoint.
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict() if device == 'cuda' else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.model,
            'state_dict': weights_dict,
            'best_miou': best_miou,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou, math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou, math.ceil(num_params))

    # Save the pretrained weights.
    model_dict = copy.deepcopy(model.state_dict())
    del model
    torch.cuda.empty_cache()

    #
    # Finetuning with the converted label set (5 classes)
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # set_parameters_for_finetuning()

    # Import the model.
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        # Transfer the 13-class weights into the new model wherever parameter
        # names and shapes still match.
        new_model_dict = model.state_dict()
        # for k, v in model_dict.items():
        #     if k.lstrip('module.') in new_model_dict:
        #         print('In:{}'.format(k.lstrip('module.')))
        #     else:
        #         print('Not In:{}'.format(k.lstrip('module.')))
        overlap_dict = {k.replace('module.', ''): v for k, v in model_dict.items()
                        if k.replace('module.', '') in new_model_dict
                        and new_model_dict[k.replace('module.', '')].size() == v.size()}
        # Parameters that could not be transferred (missing or re-sized).
        no_overlap_dict = {k: v for k, v in new_model_dict.items() if k not in overlap_dict}
        print(no_overlap_dict.keys())

        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())
    print(seg_classes)
    print(class_wts.size())
    # print(model_dict.keys())
    # print(new_model_dict.keys())

    criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
                                 device=device, ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32, label_bin=seg_classes) if args.use_nid else None

    # Set learning rates (reduced for finetuning).
    args.lr /= 100
    train_params = [{'params': model.get_basenet_params(), 'lr': args.lr},
                    {'params': model.get_segment_params(), 'lr': args.lr * args.lr_mult}]

    # Define an optimizer.
    optimizer = optim.SGD(train_params, lr=args.lr * args.lr_mult,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need the DataParallel wrapper for the criteria,
            # so fall back to PyTorch's internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, pin_memory=True, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=20, shuffle=False,
                                             pin_memory=True, num_workers=args.workers)

    # Get a learning rate scheduler.
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display them.
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s, crop_size[0])

    #
    # Main training loop for finetuning (5 classes)
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # Set the optimizer with the learning rate. This can be done inside the MyLRScheduler.
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message('Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'.format(
            epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue.
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model, train_loader, optimizer, criterion, seg_classes, epoch,
                                       device=device, use_depth=args.use_depth, add_criterion=nid_loss)
        miou_val, val_loss = val(model, val_loader, criterion, seg_classes, device=device,
                                 use_depth=args.use_depth, add_criterion=nid_loss)

        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model, images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding, writer=writer,
                                      epoch=epoch, data='SegmentationConv/train', device=device)
        in_training_visualization_img(model, images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding, writer=writer,
                                      epoch=epoch, data='SegmentationConv/val', device=device)

        # Remember best miou and save checkpoint.
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict() if device == 'cuda' else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.model,
            'state_dict': weights_dict,
            'best_miou': best_miou,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou, math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou, math.ceil(num_params))

    writer.close()
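# --- freeze_bn_layer sketch ---
# The finetuning main above calls freeze_bn_layer(model), which is defined
# elsewhere in the repo. A sketch that mirrors the inline freezing logic used
# by the other training mains in this file:
def freeze_bn_layer(model):
    print_info_message('Freezing batch normalization layers')
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                        # use running stats, do not update them
            m.weight.requires_grad = False  # freeze the affine scale
            m.bias.requires_grad = False    # freeze the affine shift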