def train(args):
    """Train YOLOv1 on the VOC dataset described by ``args``.

    Builds the (optionally augmented) data pipeline, wraps the model in
    ``DataParallel``, and runs an Adam training loop, printing losses every
    100 steps and saving a checkpoint every 1000 epochs.

    Args:
        args: argparse-style namespace carrying dataset paths, model
            hyper-parameters (lr, weight_decay, dropout, l_coord, l_noobj),
            and feature flags (use_augmentation, use_summary, ...).
    """
    dataset = args.dataset  # currently unused; kept for future multi-dataset support
    data_path = args.data_path
    class_path = args.class_path
    checkpoint_path = args.checkpoint_path
    input_height = args.input_height
    input_width = args.input_width
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    lr = args.lr
    weight_decay = args.weight_decay
    dropout = args.dropout
    l_coord = args.l_coord
    l_noobj = args.l_noobj
    num_gpus = [i for i in range(args.num_gpus)]
    num_class = args.num_class

    USE_AUGMENTATION = args.use_augmentation
    # USE_VISDOM = args.use_visdom
    # USE_WANDB = args.use_wandb
    USE_SUMMARY = args.use_summary

    # Image augmentation: apply 2 randomly chosen augmenters per sample;
    # an empty Sequential is the identity pipeline.
    if USE_AUGMENTATION:
        seq = iaa.SomeOf(2, [
            iaa.Multiply((1.2, 1.5)),
            iaa.Affine(translate_px={"x": 3, "y": 10}, scale=(0.9, 0.9)),
            iaa.AdditiveGaussianNoise(scale=0.1 * 255),
            iaa.CoarseDropout(0.02, size_percent=0.15, per_channel=0.5),
            iaa.Affine(rotate=45),
            iaa.Sharpen(alpha=0.5)
        ])
    else:
        seq = iaa.Sequential([])

    composed = transforms.Compose([Augmenter(seq)])

    # DataLoader
    train_dataset = VOC(root=data_path, transform=composed, class_path=class_path)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=detection_collate)

    # model
    model = models.YOLOv1(num_class, dropout)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model, device_ids=num_gpus).to(device)
    else:
        model = torch.nn.DataParallel(model)

    if USE_SUMMARY:
        summary(model, (3, 448, 448))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    total_step = 0
    # total_train_step = num_epochs * total_step
    for epoch in range(1, num_epochs + 1):
        # Decay the LR only at these hand-picked milestone epochs.
        if epoch in (200, 400, 600, 20000, 30000):
            scheduler.step()

        for i, (images, labels, sizes) in enumerate(train_loader):
            total_step += 1
            images = images.to(device)
            labels = labels.to(device)

            pred = model(images)
            loss, losses = detection_loss_4_yolo(pred, labels, l_coord,
                                                 l_noobj, device)
            coord_loss = losses[0]
            size_loss = losses[1]
            objness_loss = losses[2]
            noobjness_loss = losses[3]
            class_loss = losses[4]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if total_step % 100 == 0:
                print("epoch: [{}/{}], step:{}, lr:{}, total_loss:{:.4f}, \ncoord:{:.4f}, size:{:.4f}, objness:{:.4f}, noobjness:{:.4f}, class:{:.4f}"
                      .format(epoch, num_epochs, total_step,
                              optimizer.param_groups[0]['lr'],
                              loss.item(), coord_loss, size_loss,
                              objness_loss, noobjness_loss, class_loss))

        if epoch % 1000 == 0:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "arch": "YoloV1",
                    # BUG FIX: was model.state.dict() / optimizer.state.dict(),
                    # which raises AttributeError — state_dict() is the method.
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict()
                },
                False,
                filename=os.path.join(
                    checkpoint_path,
                    # BUG FIX: "{:.05d}" is an invalid format spec for ints
                    # (precision not allowed) and raised ValueError; also
                    # optimizer.param_group -> optimizer.param_groups.
                    "ckpt_ep{:05d}_loss{:.04f}_lr{}.pth.tar".format(
                        epoch, loss.item(),
                        optimizer.param_groups[0]['lr'])))
def train(params):
    """Train YOLOv1 from a ``params`` dict (visdom/wandb-instrumented variant).

    Sets up optional wandb/visdom logging, builds the (optionally augmented)
    VOC data pipeline, wraps the model in DataParallel, and runs an Adam
    training loop with periodic console/visdom/wandb logging and a checkpoint
    every 1000 epochs.
    """
    # future work variable (read but not yet used below)
    dataset = params["dataset"]
    input_height = params["input_height"]
    input_width = params["input_width"]

    data_path = params["data_path"]
    class_path = params["class_path"]
    batch_size = params["batch_size"]
    num_epochs = params["num_epochs"]
    learning_rate = params["lr"]
    dropout = params["dropout"]
    num_gpus = [i for i in range(params["num_gpus"])]
    checkpoint_path = params["checkpoint_path"]

    # Feature flags.
    USE_VISDOM = params["use_visdom"]
    USE_WANDB = params["use_wandb"]
    USE_SUMMARY = params["use_summary"]
    USE_AUGMENTATION = params["use_augmentation"]
    USE_GTCHECKER = params["use_gtcheck"]
    USE_GITHASH = params["use_githash"]
    num_class = params["num_class"]

    if (USE_WANDB):
        wandb.init()
        wandb.config.update(params)  # adds all of the arguments as config variables

    # One class name per line in the class file.
    with open(class_path) as f:
        class_list = f.read().splitlines()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Record the current git commit so checkpoints can be traced to code.
    if (USE_GITHASH):
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        short_sha = repo.git.rev_parse(sha, short=7)

    # One visdom line plot per loss component.
    if USE_VISDOM:
        viz = visdom.Visdom(use_incoming_socket=False)
        vis_title = 'Yolo V1 Deepbaksu_vision (feat. martin, visionNoob) PyTorch on ' + 'VOC'
        vis_legend = ['Train Loss']
        iter_plot = create_vis_plot(viz, 'Iteration', 'Total Loss', vis_title, vis_legend)
        coord1_plot = create_vis_plot(viz, 'Iteration', 'coord1', vis_title, vis_legend)
        size1_plot = create_vis_plot(viz, 'Iteration', 'size1', vis_title, vis_legend)
        noobjectness1_plot = create_vis_plot(viz, 'Iteration', 'noobjectness1', vis_title, vis_legend)
        objectness1_plot = create_vis_plot(viz, 'Iteration', 'objectness1', vis_title, vis_legend)
        obj_cls_plot = create_vis_plot(viz, 'Iteration', 'obj_cls', vis_title, vis_legend)

    # 2. Data augmentation setting: 2 random augmenters per sample, or identity.
    if (USE_AUGMENTATION):
        seq = iaa.SomeOf(
            2,
            [
                iaa.Multiply(
                    (1.2, 1.5)),  # change brightness, doesn't affect BBs
                iaa.Affine(
                    translate_px={
                        "x": 3,
                        "y": 10
                    },
                    scale=(0.9, 0.9)
                ),  # translate by 40/60px on x/y axis, and scale to 50-70%, affects BBs
                iaa.AdditiveGaussianNoise(scale=0.1 * 255),
                iaa.CoarseDropout(0.02, size_percent=0.15, per_channel=0.5),
                iaa.Affine(rotate=45),
                iaa.Sharpen(alpha=0.5)
            ])
    else:
        seq = iaa.Sequential([])

    composed = transforms.Compose([Augmenter(seq)])

    # 3. Load Dataset
    # composed
    # transforms.ToTensor
    train_dataset = VOC(root=data_path,
                        transform=composed,
                        class_path=class_path)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=detection_collate)

    # 5. Load YOLOv1
    net = yolov1.YOLOv1(params={"dropout": dropout, "num_class": num_class})
    # model = torch.nn.DataParallel(net, device_ids=num_gpus).cuda()
    print("device : ", device)
    if device.type == 'cpu':
        model = torch.nn.DataParallel(net)
    else:
        model = torch.nn.DataParallel(net, device_ids=num_gpus).cuda()

    if USE_SUMMARY:
        summary(model, (3, 448, 448))

    # 7.Train the model
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    # Train the model
    total_step = len(train_loader)  # batches per epoch
    total_train_step = num_epochs * total_step

    # for epoch in range(num_epochs):
    for epoch in range(1, num_epochs + 1):
        # LR decays only at these hand-picked milestone epochs.
        if (epoch == 200) or (epoch == 400) or (epoch == 600) or (
                epoch == 20000) or (epoch == 30000):
            scheduler.step()

        for i, (images, labels, sizes) in enumerate(train_loader):
            # NOTE(review): epoch starts at 1, so this skips steps 1..total_step
            # and the `current_train_step < 100` warm-up logging condition below
            # can effectively never fire — confirm intended numbering.
            current_train_step = (epoch) * total_step + (i + 1)

            if USE_GTCHECKER:
                visualize_GT(images, labels, class_list)

            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Calc Loss
            loss, \
            obj_coord1_loss, \
            obj_size1_loss, \
            obj_class_loss, \
            noobjness1_loss, \
            objness1_loss = detection_loss_4_yolo(outputs, labels, device.type)
            # objness1_loss = detection_loss_4_yolo(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (((current_train_step) % 100) == 0) or (current_train_step % 10 == 0 and current_train_step < 100):
                # NOTE(review): prints `epoch + 1` although the loop already
                # starts at 1, so the displayed epoch runs up to num_epochs + 1.
                print(
                    'epoch: [{}/{}], total step: [{}/{}], batch step [{}/{}], lr: {}, total_loss: {:.4f}, coord1: {:.4f}, size1: {:.4f}, noobj_clss: {:.4f}, objness1: {:.4f}, class_loss: {:.4f}'
                    .format(epoch + 1, num_epochs, current_train_step,
                            total_train_step, i + 1, total_step, ([
                                param_group['lr']
                                for param_group in optimizer.param_groups
                            ])[0], loss.item(), obj_coord1_loss,
                            obj_size1_loss, noobjness1_loss, objness1_loss,
                            obj_class_loss))

            if USE_VISDOM:
                update_vis_plot(viz, (epoch + 1) * total_step + (i + 1),
                                loss.item(), iter_plot, None, 'append')
                update_vis_plot(viz, (epoch + 1) * total_step + (i + 1),
                                obj_coord1_loss, coord1_plot, None, 'append')
                update_vis_plot(viz, (epoch + 1) * total_step + (i + 1),
                                obj_size1_loss, size1_plot, None, 'append')
                update_vis_plot(viz, (epoch + 1) * total_step + (i + 1),
                                obj_class_loss, obj_cls_plot, None, 'append')
                update_vis_plot(viz, (epoch + 1) * total_step + (i + 1),
                                noobjness1_loss, noobjectness1_plot, None,
                                'append')
                update_vis_plot(viz, (epoch + 1) * total_step + (i + 1),
                                objness1_loss, objectness1_plot, None,
                                'append')

            if USE_WANDB:
                wandb.log({
                    'total_loss': loss.item(),
                    'obj_coord1_loss': obj_coord1_loss,
                    'obj_size1_loss': obj_size1_loss,
                    'obj_class_loss': obj_class_loss,
                    'noobjness1_loss': noobjness1_loss,
                    'objness1_loss': objness1_loss
                })

        # Fallback tag so the checkpoint filename is always well-formed.
        if not USE_GITHASH:
            short_sha = 'noHash'

        # if ((epoch % 1000) == 0) and (epoch != 0):
        if ((epoch % 1000) == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': "YOLOv1",
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                False,
                filename=os.path.join(
                    checkpoint_path,
                    'ckpt_{}_ep{:05d}_loss{:.04f}_lr{}.pth.tar'.format(
                        short_sha, epoch, loss.item(), ([
                            param_group['lr']
                            for param_group in optimizer.param_groups
                        ])[0])))
def main(args):
    """Train an SSD object detector on VOC or COCO.

    Builds the datasets/loaders for ``args.dataset``, optionally loads
    finetuning weights, freezes BN layers on request, then runs the
    train/validate loop with the selected LR scheduler, checkpointing on
    best validation loss and logging to TensorBoard.

    Args:
        args: argparse namespace with im_size, dataset/data paths, optimizer
            hyper-parameters, scheduler selection, and resume/finetune paths.
    """
    if args.im_size in [300, 512]:
        from model.detection.ssd_config import get_config
        cfg = get_config(args.im_size)
    else:
        print_error_message('{} image size not supported'.format(args.im_size))
        # BUG FIX: exit explicitly (as the dataset branch below does);
        # otherwise `cfg` is undefined in the code that follows.
        exit()

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainTransform(cfg.image_size)
    target_transform = MatchPrior(PriorBox(cfg)(), cfg.center_variance,
                                  cfg.size_variance, cfg.iou_threshold)
    val_transform = ValTransform(cfg.image_size)

    if args.dataset in ['voc', 'pascal']:
        from data_loader.detection.voc import VOCDataset, VOC_CLASS_LIST
        # VOC07 + VOC12 trainval is the conventional SSD training split.
        train_dataset_2007 = VOCDataset(root_dir=args.data_path,
                                        transform=train_transform,
                                        target_transform=target_transform,
                                        is_training=True,
                                        split="VOC2007")
        train_dataset_2012 = VOCDataset(root_dir=args.data_path,
                                        transform=train_transform,
                                        target_transform=target_transform,
                                        is_training=True,
                                        split="VOC2012")
        train_dataset = torch.utils.data.ConcatDataset(
            [train_dataset_2007, train_dataset_2012])
        val_dataset = VOCDataset(root_dir=args.data_path,
                                 transform=val_transform,
                                 target_transform=target_transform,
                                 is_training=False,
                                 split="VOC2007")
        num_classes = len(VOC_CLASS_LIST)
    elif args.dataset == 'coco':
        from data_loader.detection.coco import COCOObjectDetection, COCO_CLASS_LIST
        train_dataset = COCOObjectDetection(root_dir=args.data_path,
                                            transform=train_transform,
                                            target_transform=target_transform,
                                            is_training=True)
        val_dataset = COCOObjectDetection(root_dir=args.data_path,
                                          transform=val_transform,
                                          target_transform=target_transform,
                                          is_training=False)
        num_classes = len(COCO_CLASS_LIST)
    else:
        print_error_message('{} dataset is not supported yet'.format(
            args.dataset))
        exit()

    cfg.NUM_CLASSES = num_classes

    # -----------------------------------------------------------------------------
    # Dataset loader
    # -----------------------------------------------------------------------------
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = ssd(args, cfg)
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        # Freeze both the running statistics (eval mode) and the affine params.
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # -----------------------------------------------------------------------------
    # Optimizer and Criterion
    # -----------------------------------------------------------------------------
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd)
    criterion = MultiBoxLoss(neg_pos_ratio=cfg.neg_pos_ratio)

    # writer for logs
    writer = SummaryWriter(log_dir=args.save,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, cfg.image_size,
                                                     cfg.image_size))
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # are not swallowed; graph export is best-effort only.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # model stats
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, cfg.image_size,
                                             cfg.image_size))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            cfg.image_size, cfg.image_size, flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'

    min_val_loss = float('inf')
    start_epoch = 0  # start from epoch 0 or last epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            # BUG FIX: load the file whose existence was just checked
            # (args.resume); previously this loaded args.checkpoint.
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint['state_dict'])
            min_val_loss = checkpoint['min_loss']
            start_epoch = checkpoint['epoch']
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.to(device)
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    if args.lr_type == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.lr_type == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.lr_type == 'clr':
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=args.cycle_len,
                                steps=args.steps,
                                gamma=args.gamma,
                                step=True)
    elif args.lr_type == 'cosine':
        from utilities.lr_scheduler import CosineLR
        lr_scheduler = CosineLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler not yet supported'.format(
            args.lr_type))
        exit()
    print_info_message(lr_scheduler)

    # -----------------------------------------------------------------------------
    # Training and validation loop
    # -----------------------------------------------------------------------------
    extra_info_ckpt = '{}_{}'.format(args.model, args.s)
    for epoch in range(start_epoch, args.epochs):
        curr_lr = lr_scheduler.step(epoch)
        optimizer.param_groups[0]['lr'] = curr_lr
        print_info_message('Running epoch {} at LR {}'.format(epoch, curr_lr))
        train_loss, train_cl_loss, train_loc_loss = train(train_loader,
                                                          model,
                                                          criterion,
                                                          optimizer,
                                                          device,
                                                          epoch=epoch)
        val_loss, val_cl_loss, val_loc_loss = validate(val_loader,
                                                       model,
                                                       criterion,
                                                       device,
                                                       epoch=epoch)

        # Save checkpoint — unwrap DataParallel so weights load later
        # without the 'module.' prefix.
        is_best = val_loss < min_val_loss
        min_val_loss = min(val_loss, min_val_loss)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch,
                'model': args.model,
                'state_dict': weights_dict,
                'min_loss': min_val_loss
            }, is_best, args.save, extra_info_ckpt)

        writer.add_scalar('Detection/LR/learning_rate', round(curr_lr, 6), epoch)
        writer.add_scalar('Detection/Loss/train', train_loss, epoch)
        writer.add_scalar('Detection/Loss/val', val_loss, epoch)
        writer.add_scalar('Detection/Loss/train_cls', train_cl_loss, epoch)
        writer.add_scalar('Detection/Loss/val_cls', val_cl_loss, epoch)
        writer.add_scalar('Detection/Loss/train_loc', train_loc_loss, epoch)
        writer.add_scalar('Detection/Loss/val_loc', val_loc_loss, epoch)

    # NOTE(review): these log the best loss *value* against a FLOPs/params
    # "step" (loss-vs-complexity scatter); looks intentional but unusual —
    # confirm it matches the plotting convention used downstream.
    writer.add_scalar('Detection/Complexity/Flops', min_val_loss,
                      math.ceil(flops))
    writer.add_scalar('Detection/Complexity/Params', min_val_loss,
                      math.ceil(num_params))
    writer.close()
def main(args):
    """Train an ESPNetv2-based autoencoder that reconstructs RGB from depth.

    Builds the dataset for ``args.dataset`` (pascal / city / greenhouse),
    wraps model and criterion for single- or multi-GPU training, runs the
    MSE training loop, validates every 5 epochs, and checkpoints on the best
    validation loss while logging scalars/images to TensorBoard.

    Args:
        args: argparse namespace with crop_size (tuple), dataset paths,
            optimizer/scheduler hyper-parameters, and savedir.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hand-tuned per-class weights for Cityscapes; the last class is
        # given zero weight (ignored).
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GreenhouseDepth, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseDepth(root=args.data_path,
                                        list_name='train_depth_ae.txt',
                                        train=True,
                                        size=crop_size,
                                        scale=args.scale,
                                        use_filter=True)
        val_dataset = GreenhouseRGBDSegmentation(root=args.data_path,
                                                 list_name='val_depth_ae.txt',
                                                 train=False,
                                                 size=crop_size,
                                                 scale=args.scale,
                                                 use_depth=True)
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    from model.autoencoder.depth_autoencoder import espnetv2_autoenc
    args.classes = 3  # the decoder reconstructs a 3-channel RGB image
    model = espnetv2_autoenc(args)

    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr * args.lr_mult
    }]
    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 1, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        # Graph export is best-effort; narrowed from a bare except.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0
    print('device : ' + device)

    # Plain reconstruction loss; the segmentation criteria are kept for reference.
    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    #criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
    #                             device=device, ignore_idx=args.ignore_idx,
    #                             class_wts=class_wts.to(device))
    criterion = nn.MSELoss()
    # criterion = nn.L1Loss()

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

    if torch.backends.cudnn.is_available():
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)

    # Dump the full run configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    # BUG FIX: was initialised to 0.0 — with a non-negative reconstruction
    # loss, `val_loss < best_loss` could never be True, so the "best"
    # checkpoint was never marked. Start from +inf like the other trainers.
    best_loss = float('inf')
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_seg
        # optimizer.param_groups[1]['lr'] = lr_seg

        # Train
        model.train()
        losses = AverageMeter()
        for i, batch in enumerate(train_loader):
            inputs = batch[1].to(device=device)  # Depth
            target = batch[0].to(device=device)  # RGB

            outputs = model(inputs)

            if device == 'cuda':
                # DataParallel criteria may return per-GPU losses/outputs:
                # average the loss and gather outputs onto one device.
                loss = criterion(outputs, target).mean()
                if isinstance(outputs, (list, tuple)):
                    target_dev = outputs[0].device
                    outputs = gather(outputs, target_device=target_dev)
            else:
                loss = criterion(outputs, target)

            losses.update(loss.item(), inputs.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if not (i % 10):
            #     print("Step {}, write images".format(i))
            #     image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            #     writer.add_image('Autoencoder/results/train', image_grid, len(train_loader) * epoch + i)

            writer.add_scalar('Autoencoder/Loss/train', loss.item(),
                              len(train_loader) * epoch + i)
            print_info_message('Running batch {}/{} of epoch {}'.format(
                i + 1, len(train_loader), epoch + 1))

        train_loss = losses.avg
        writer.add_scalar('Autoencoder/LR/seg', round(lr_seg, 6), epoch)

        # Val — every 5 epochs. Epoch 0 satisfies the condition, so
        # val_loss is defined before its first use below.
        if epoch % 5 == 0:
            losses = AverageMeter()
            with torch.no_grad():
                for i, batch in enumerate(val_loader):
                    # NOTE(review): training reads depth from batch[1] but
                    # validation reads batch[2]; the val dataset
                    # (GreenhouseRGBDSegmentation) appears to return a
                    # different tuple layout — confirm against the loaders.
                    inputs = batch[2].to(device=device)  # Depth
                    target = batch[0].to(device=device)  # RGB

                    outputs = model(inputs)

                    if device == 'cuda':
                        loss = criterion(outputs, target)  # .mean()
                        if isinstance(outputs, (list, tuple)):
                            target_dev = outputs[0].device
                            outputs = gather(outputs, target_device=target_dev)
                    else:
                        loss = criterion(outputs, target)

                    losses.update(loss.item(), inputs.size(0))

                # Log the last validation batch as image grids.
                image_grid = torchvision.utils.make_grid(
                    outputs.data.cpu()).numpy()
                writer.add_image('Autoencoder/results/val', image_grid, epoch)
                image_grid = torchvision.utils.make_grid(
                    inputs.data.cpu()).numpy()
                writer.add_image('Autoencoder/inputs/val', image_grid, epoch)
                image_grid = torchvision.utils.make_grid(
                    target.data.cpu()).numpy()
                writer.add_image('Autoencoder/target/val', image_grid, epoch)

            val_loss = losses.avg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # remember best loss and save checkpoint (unwrap DataParallel so
        # weights load later without the 'module.' prefix)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Autoencoder/Loss/val', val_loss, epoch)

    writer.close()
def train(self):
    """Run the entity-linking training loop for up to ``maxtrsteps`` steps.

    Each step: read a batch from ``self.tr_reader``, convert it to (CUDA)
    tensors, forward through ``self.model``, compute the combined loss, and
    take an optimizer step via ``self.optstep``. Logs running averages every
    ``log_interval`` steps and checkpoints + validates every 1000 steps.

    Relies on module-level globals: ``maxtrsteps``, ``device_id``, ``bs``,
    ``log_interval``, ``ckptpath`` (and apparently ``mentype`` — see NOTE).

    Returns:
        tuple ``(bestmodel, bestval, beststep, steps)``.
    """
    # Initialize saver, model parameters (hidden inputs to lstm)
    # Load model if needed.
    self.model.train()
    start_time = time.time()
    avg_loss, avg_elloss, avg_mtypeloss = 0.0, 0.0, 0.0
    epochs = self.tr_reader.tr_epochs
    steps = 0
    ncorrect, ntotal = 0, 0
    ncorrectOA, ntotalOA = 0, 0
    ncorrectB, ntotalB = 0, 0
    bestmodel, bestval, beststep = self.model, 0.0, 0
    bestFinalVal = 0.0
    # Per-interval timers: batch reading / tensor conversion / forward+backward.
    readtime, convtime, processtime = 0, 0, 0

    # while ((steps < maxsteps and bestFinalVal < 0.999) or
    #        (CURR_SWITCHES < len(CURRICULUM_ORDER) - 1)):
    while steps < maxtrsteps:
        steps += 1
        # print(curr)
        rtimestart = time.time()
        b = self.tr_reader.next_train_batch()
        # Batch layout: left/right context, document bag-of-words (sparse),
        # mention-type targets, candidate wiki-id indices and probabilities.
        (leftb, leftlens, rightb, rightlens, docb, typesb, wididxsb,
         widprobsb) = (b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7])
        (ind, vals, dvsize) = docb
        readtime += (time.time() - rtimestart)

        ctimestart = time.time()
        # Build the sparse document tensor from (indices, values, size).
        ind = torch.LongTensor(ind)
        vals = torch.FloatTensor(vals)
        docb = torch.sparse.FloatTensor(ind.t(), vals, torch.Size(dvsize))
        (leftb, leftlens, rightb, rightlens, typesb, wididxsb,
         widprobsb) = (torch.FloatTensor(leftb), torch.LongTensor(leftlens),
                       torch.FloatTensor(rightb), torch.LongTensor(rightlens),
                       torch.FloatTensor(typesb), torch.LongTensor(wididxsb),
                       torch.FloatTensor(widprobsb))
        (leftb, leftlens, rightb, rightlens, docb, typesb, wididxsb,
         widprobsb) = utils.toCudaVariable(device_id, leftb, leftlens, rightb,
                                           rightlens, docb, typesb, wididxsb,
                                           widprobsb)
        # The true candidate is always at index 0 for every mention in the batch.
        truewidvec = utils.toCudaVariable(device_id,
                                          torch.LongTensor([0] * bs))[0]
        convtime += (time.time() - ctimestart)

        ptimestart = time.time()
        rets = self.model.forward_context(leftb=leftb,
                                          leftlens=leftlens,
                                          rightb=rightb,
                                          rightlens=rightlens,
                                          docb=docb,
                                          wididxsb=wididxsb)
        (wididxscores, wididxprobs, mentype_probs) = (rets[0], rets[1],
                                                      rets[2])
        # NOTE(review): `mentype` is not defined anywhere in this method —
        # presumably a module-level flag; confirm it exists at import time.
        (loss, elloss, mentype_loss) = self.model.lossfunc(
            mentype=mentype,
            predwidscores=wididxscores,
            truewidvec=truewidvec,
            mentype_probs=mentype_probs,
            mentype_trueprobs=typesb)
        self.optstep(loss)

        # NOTE(review): `.data.cpu().numpy()[0]` is the pre-PyTorch-0.4 idiom
        # for extracting a scalar (modern code would use `.item()`); this
        # block appears to target that legacy API — confirm the torch version.
        loss = loss.data.cpu().numpy()[0]
        elloss = elloss.data.cpu().numpy()[0]
        mentype_loss = mentype_loss.data.cpu().numpy()[0]
        avg_loss += loss
        avg_elloss += elloss
        avg_mtypeloss += mentype_loss
        processtime += (time.time() - ptimestart)

        if steps % log_interval == 0:
            totaltime = readtime + processtime + convtime
            print()
            # NOTE(review): the averages are rounded and printed here but the
            # accumulators are never reset (the reset below is commented out),
            # so subsequent "averages" mix already-averaged values with new
            # per-step losses — looks like a bug; confirm intended behaviour.
            avg_loss = utils.round_all(avg_loss / log_interval, 3)
            avg_elloss = utils.round_all(avg_elloss / log_interval, 3)
            avg_mtypeloss = utils.round_all(avg_mtypeloss / log_interval, 3)
            print("[{}, {}, rt:{:0.1f} secs ct:{:0.1f} pt:{:0.1f} "
                  "tt:{:0.1f} secs]: L:{} EL:{} MenTypL:{}".format(
                      steps, self.tr_reader.tr_epochs, readtime, convtime,
                      processtime, totaltime, avg_loss, avg_elloss,
                      avg_mtypeloss))
            readtime, convtime, processtime = 0, 0, 0
            # tracc = float(ncorrect)/float(ntotal)
            # oAtracc = float(ncorrectOA)/float(ntotalOA) if ntotalOA != 0.0 else 0.0
            # Btracc = float(ncorrectB)/float(ntotalB) if ntotalB != 0.0 else 0.0
            # avg_loss /= log_interval
            # time_elapsed = float(time.time() - start_time)/60.0
            # print("[{}, {}, {:0.1f} mins]: {}".format(
            #     steps, self.tr_reader.epochs,
            #     time_elapsed, avg_loss))
            # print("TrAcc: {} / {} : {:.3f}".format(
            #     ncorrect, ntotal, tracc))
            # print("OA : {}/{}: {}".format(ncorrectOA, ntotalOA, oAtracc))
            # print("Bool : {}/{}: {}".format(ncorrectB, ntotalB, Btracc))
            # avg_loss = 0.0
            # ntotal=0
            # ncorrect=0
            # ntotalOA = 0
            # ncorrectOA = 0
            # ntotalB = 0
            # ncorrectB = 0

        # if epochs != self.tr_reader.epochs or steps % 15000 == 0:
        if steps % 1000 == 0:
            # Periodic checkpoint + validation; the "best model" tracking by
            # validation accuracy is currently disabled (commented out), so
            # the latest model is always saved as best.
            print("Running Validation")
            print("Saving model: {}".format(ckptpath))
            bestmodel = copy.deepcopy(self.model)
            beststep = steps
            utils.save_checkpoint(m=bestmodel,
                                  o=self.optimizer,
                                  steps=steps,
                                  beststeps=beststep,
                                  path=ckptpath)
            self.validation()
            # (vt, vc, va) = self.validation_performance()
            # if va > bestval:
            #     bestval = va
            #     bestmodel = copy.deepcopy(self.model)
            #     beststep = steps
            # if bestval == 0.0 and va == 0.0:  # keep latest model
            #     bestval = va
            #     bestmodel = copy.deepcopy(self.model)
            #     beststep = steps
            # # Check if final curricula is reached, then update bestFinalVal
            # if CURR_SWITCHES == len(CURRICULUM_ORDER) - 1:
            #     bestFinalVal = bestval
            # print("[##] Total: {}. Correct: {}. Acc: {:0.3f} "
            #       "[B:{:.3f} E:{}]".format(vt, vc, va, bestval, beststep))
            # print("[##] Best Final Val : {}\n".format(bestFinalVal))
            # print("Saving model: {}".format(ckptpath))
            # # Saving latest model
            # bestmodel = copy.deepcopy(self.model)
            # utils.save_checkpoint(m=bestmodel, o=self.optimizer,
            #                       steps=steps, beststeps=beststep,
            #                       path=ckptpath)
            # epochs = self.tr_reader.epochs
            # Validation may have switched the model to eval mode; restore.
            self.model.train()

    return (bestmodel, bestval, beststep, steps)
def main(args):
    """Entry point for segmentation training (pascal / city) with optional
    FP warmup and quantization-aware-training (QAT) preparation.

    Pipeline: logger setup -> dataset/class-weight selection -> model
    selection -> per-parameter optimizer groups -> loss/parallel wrapping ->
    data loaders -> LR scheduler -> (optional) FP warmup + QAT prep ->
    (optional) resume -> train/val loop with checkpointing and scalar logging.

    NOTE(review): the original file's line structure was collapsed; the
    indentation below is a reconstruction — statements and tokens are
    unchanged, but block nesting at the spots marked below should be
    confirmed against the project's history.
    """
    # Scalar-logging directory and custom Logger (60066 is presumably a
    # port/run id for the Logger backend — confirm against Logger's API).
    logdir = args.savedir + '/logs/'
    if not os.path.isdir(logdir):
        os.makedirs(logdir)
    my_logger = Logger(60066, logdir)

    # Dataset-specific crop size; args.scale is overwritten in-place.
    # NOTE(review): no else branch — an unknown dataset would hit an
    # UnboundLocalError on crop_size at the print below.
    if args.dataset == 'pascal':
        crop_size = (512, 512)
        args.scale = (0.5, 2.0)
    elif args.dataset == 'city':
        crop_size = (768, 768)
        args.scale = (0.5, 2.0)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[1], crop_size[0], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    # ---- Dataset + per-class loss weights -------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded inverse-frequency class weights for Cityscapes;
        # index 19 (presumably the ignore/background class) gets weight 0.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ---- Model selection (imports kept local to avoid loading all archs) -
    if args.model == 'espnetv2':
        from model.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espnet':
        from model.espnet import espnet_seg
        args.classes = seg_classes
        model = espnet_seg(args)
    elif args.model == 'mobilenetv2_1_0':
        from model.mobilenetv2 import get_mobilenet_v2_1_0_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_1_0_seg(args)
    elif args.model == 'mobilenetv2_0_35':
        from model.mobilenetv2 import get_mobilenet_v2_0_35_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_35_seg(args)
    elif args.model == 'mobilenetv2_0_5':
        from model.mobilenetv2 import get_mobilenet_v2_0_5_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_5_seg(args)
    elif args.model == 'mobilenetv3_small':
        from model.mobilenetv3 import get_mobilenet_v3_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_small_seg(args)
    elif args.model == 'mobilenetv3_large':
        from model.mobilenetv3 import get_mobilenet_v3_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_large_seg(args)
    elif args.model == 'mobilenetv3_RE_small':
        from model.mobilenetv3 import get_mobilenet_v3_RE_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_small_seg(args)
    elif args.model == 'mobilenetv3_RE_large':
        from model.mobilenetv3 import get_mobilenet_v3_RE_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_large_seg(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ---- Per-parameter optimizer groups ---------------------------------
    # 4-D weights with in-channels == 1 (depthwise convs) get no weight
    # decay; other convs get full decay; everything else (BN, biases, FC)
    # gets 1% of the configured decay.
    train_params = []
    params_dict = dict(model.named_parameters())
    others = args.weight_decay * 0.01
    for key, value in params_dict.items():
        if len(value.data.shape) == 4:
            if value.data.shape[1] == 1:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': 0.0
                }]
            else:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]
        else:
            train_params += [{
                'params': [value],
                'lr': args.lr,
                'weight_decay': others
            }]
    # get_optimizer reads args.learning_rate, so mirror args.lr into it.
    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)

    # Model complexity for logging (input is W x H as crop_size[1] x [0]).
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[1],
                                             crop_size[0]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[1], crop_size[0], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    start_epoch = 0
    epochs_len = args.epochs  # NOTE(review): unused below
    best_miou = 0.0
    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    # ---- GPU wrapping ----------------------------------------------------
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # ---- Data loaders (city validates with batch size 1) -----------------
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers,
                                               drop_last=True)
    if args.dataset == 'city':
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)

    lr_scheduler = get_lr_scheduler(args)
    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    # ---- Optional full-precision warmup before QAT ----------------------
    if args.fp_epochs > 0:
        print_info_message("========== MODEL FP WARMUP ===========")
        for epoch in range(args.fp_epochs):
            lr = lr_scheduler.step(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print_info_message(
                'Running epoch {} with learning rates: {:.6f}'.format(
                    epoch, lr))
            start_t = time.time()  # NOTE(review): timing value never read
            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device)
    # NOTE(review): the nesting of the two blocks below (relative to the
    # warmup `if`) is reconstructed — confirm against the original layout.
    if args.optimizer.startswith('Q'):
        # Quantization-aware optimizers expose an is_warmup flag.
        optimizer.is_warmup = False
        print('exp_sensitivity calibration fin.')
    if not args.fp_train:
        # Fuse conv/bn/relu and prepare the quantized submodule for QAT.
        model.module.quantized.fuse_model()
        model.module.quantized.qconfig = torch.quantization.get_default_qat_qconfig(
            'qnnpack')
        torch.quantization.prepare_qat(model.module.quantized, inplace=True)

    # ---- Optional resume (weights only; optimizer state is not restored) -
    if args.resume:
        start_epoch = args.start_epoch
        if os.path.isfile(args.resume):
            print_info_message('Loading weights from {}'.format(args.resume))
            weight_dict = torch.load(args.resume, device)
            model.module.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for resume. Please check.')

    # ---- Main train/validate loop ---------------------------------------
    for epoch in range(start_epoch, args.epochs):
        lr = lr_scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print_info_message(
            'Running epoch {} with learning rates: {:.6f}'.format(epoch, lr))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)
        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        if is_best:
            # Also drop a standalone weights file for the best epoch.
            model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
            torch.save(weights_dict, model_file_name)
            print('weights saved in {}'.format(model_file_name))

        # Scalar logging. The two Complexity tags intentionally log
        # best_miou against flops/params as the step value (miou-vs-cost
        # curves), not against the epoch.
        info = {
            'Segmentation/LR': round(lr, 6),
            'Segmentation/Loss/train': train_loss,
            'Segmentation/Loss/val': val_loss,
            'Segmentation/mIOU/train': miou_train,
            'Segmentation/mIOU/val': miou_val,
            'Segmentation/Complexity/Flops': best_miou,
            'Segmentation/Complexity/Params': best_miou,
        }
        for tag, value in info.items():
            if tag == 'Segmentation/Complexity/Flops':
                my_logger.scalar_summary(tag, value, math.ceil(flops))
            elif tag == 'Segmentation/Complexity/Params':
                my_logger.scalar_summary(tag, value, math.ceil(num_params))
            else:
                my_logger.scalar_summary(tag, value, epoch + 1)
    print_info_message("========== TRAINING FINISHED ===========")
def main(args):
    """Entry point for segmentation training (pascal / city) with optional
    finetuning from pretrained weights, frozen batch-norm, and TensorBoard
    logging via SummaryWriter.

    NOTE(review): the original file's line structure was collapsed; the
    indentation below is a reconstruction — statements and tokens are
    unchanged.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    # ---- Dataset + per-class loss weights -------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded inverse-frequency class weights for Cityscapes;
        # index 19 (presumably the ignore/background class) gets weight 0.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ---- Model selection -------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # ---- Optional finetuning from a weights-only checkpoint --------------
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    # Freeze BN statistics and affine parameters (useful when finetuning
    # with small batch sizes).
    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Two parameter groups: the backbone at base lr, the segmentation head
    # at lr * lr_mult. No lr kwarg on SGD — every group carries its own lr.
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0],
                                             crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except:
        # Graph export goes through ONNX tracing, which not every arch
        # supports; failure here is non-fatal.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # ---- Optional resume (full checkpoint incl. optimizer state) ---------
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    # ---- GPU wrapping ----------------------------------------------------
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ---- LR scheduler selection ------------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        # "HybirdLR" spelling matches the utilities module — do not rename.
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    # ---- Main train/validate loop ---------------------------------------
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg
        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)
        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # Intentionally logs best_miou against flops/params as the step
        # value (miou-vs-cost curves), not against the epoch.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))
    writer.close()
    def run(self, *args, **kwargs):
        """Drive the full train/validate loop for the MIL model.

        Per epoch: step the LR scheduler, push the new LR into the
        optimizer, train, validate, checkpoint (tracking best accuracy),
        and log scalars. After the loop, the per-epoch validation
        accuracies are dumped (merged with any previous dump) to a JSON
        stats file sorted by accuracy, descending.

        NOTE(review): the original file's line structure was collapsed; the
        indentation below is a reconstruction — statements and tokens are
        unchanged.
        """
        # Attention maps are not needed during training.
        kwargs['need_attn'] = False
        if self.opts.warm_up:
            self.warm_up(args=args, kwargs=kwargs)

        if self.resume is not None:
            # find the LR value
            # Replay scheduler steps so the resumed epoch sees the right LR.
            for epoch in range(self.start_epoch):
                self.lr_scheduler.step(epoch)

        # epoch -> validation accuracy, for the stats dump below.
        eval_stats_dict = dict()
        for epoch in range(self.start_epoch, self.opts.epochs):
            epoch_lr = self.lr_scheduler.step(epoch)
            self.optimizer = update_optimizer(optimizer=self.optimizer,
                                              lr_value=epoch_lr)
            # Uncomment this line if you want to check the optimizer's LR is updated correctly
            # assert read_lr_from_optimzier(self.optimizer) == epoch_lr

            train_acc, train_loss = self.training(epoch=epoch,
                                                  lr=epoch_lr,
                                                  args=args,
                                                  kwargs=kwargs)
            val_acc, val_loss = self.validation(epoch=epoch,
                                                lr=epoch_lr,
                                                args=args,
                                                kwargs=kwargs)
            eval_stats_dict[epoch] = val_acc
            gc.collect()

            # remember best accuracy and save checkpoint for best model
            # (>= means a tie also refreshes the "best" checkpoint).
            is_best = val_acc >= self.best_acc
            self.best_acc = max(val_acc, self.best_acc)

            # Unwrap DataParallel before serializing so keys have no
            # 'module.' prefix.
            model_state = self.mi_model.module.state_dict() if isinstance(self.mi_model, torch.nn.DataParallel) \
                else self.mi_model.state_dict()
            optimizer_state = self.optimizer.state_dict()
            save_checkpoint(epoch=epoch,
                            model_state=model_state,
                            optimizer_state=optimizer_state,
                            best_perf=self.best_acc,
                            save_dir=self.opts.savedir,
                            is_best=is_best,
                            keep_best_k_models=self.opts.keep_best_k_models
                            )

            self.logger.add_scalar('LR', round(epoch_lr, 6), epoch)
            self.logger.add_scalar('TrainingLoss', train_loss, epoch)
            self.logger.add_scalar('TrainingAcc', train_acc, epoch)
            self.logger.add_scalar('ValidationLoss', val_loss, epoch)
            self.logger.add_scalar('ValidationAcc', val_acc, epoch)

        # dump the validation epoch id and accuracy data, so that it could be used for filtering later on
        eval_stats_dict_sort = {k: v
                                for k, v in sorted(eval_stats_dict.items(),
                                                   key=lambda item: item[1],
                                                   reverse=True
                                                   )}
        eval_stats_fname = '{}/val_stats_bag_{}_word_{}_{}_{}'.format(
            self.opts.savedir,
            self.opts.bag_size,
            self.opts.word_size,
            self.opts.attn_fn,
            self.opts.attn_type,
        )
        writer = DictWriter(file_name=eval_stats_fname, format='json')
        # if json file does not exist
        if not os.path.isfile(eval_stats_fname):
            writer.write(data_dict=eval_stats_dict_sort)
        else:
            # Merge with the stats of a previous run and re-sort before
            # overwriting.
            with open(eval_stats_fname, 'r') as json_file:
                eval_stats_dict_old = json.load(json_file)
            eval_stats_dict_old.update(eval_stats_dict_sort)
            eval_stats_dict_updated = {k: v
                                       for k, v in sorted(eval_stats_dict_old.items(),
                                                          key=lambda item: item[1],
                                                          reverse=True
                                                          )}
            writer.write(data_dict=eval_stats_dict_updated)
        self.logger.close()
def main(args):
    """Entry point for classification training (imagenet / coco / Heart).

    Pipeline: model selection -> optional finetune weight overlap load ->
    TensorBoard writer -> SGD optimizer -> optional resume -> dataset-
    specific loss + loaders -> LR scheduler -> train/validate loop with
    checkpointing and scalar logging.

    Fixes vs. the previous revision:
      * ``model.load_state_dict`` was called with a ``map_location``
        keyword, which ``nn.Module.load_state_dict`` does not accept
        (TypeError at runtime). ``map_location`` belongs on ``torch.load``.
      * bare ``except:`` clauses narrowed to ``except Exception:`` so
        KeyboardInterrupt/SystemExit are no longer swallowed.
      * tensor-size accounting via ``enumerate`` over dicts replaced with
        direct sums over ``.values()`` (same totals).
    """
    # -----------------------------------------------------------------------------
    # Create model
    # -----------------------------------------------------------------------------
    if args.model == 'dicenet':
        from model.classification import dicenet as net
        model = net.CNNModel(args)
    elif args.model == 'espnetv2':
        from model.classification import espnetv2 as net
        model = net.EESPNet(args)
    elif args.model == 'shufflenetv2':
        from model.classification import shufflenetv2 as net
        model = net.CNNModel(args)
    else:
        print_error_message('Model {} not yet implemented'.format(args.model))
        exit()

    if args.finetune:
        # laod the weights for finetuning
        if os.path.isfile(args.weights_ft):
            pretrained_dict = torch.load(args.weights_ft,
                                         map_location=torch.device('cpu'))
            print_info_message('Loading pretrained basenet model weights')
            model_dict = model.state_dict()

            # Keep only the parameters whose names exist in both the model
            # and the pretrained file (values come from the current model;
            # the actual weight transfer happens via update() below).
            overlap_dict = {
                k: v
                for k, v in model_dict.items() if k in pretrained_dict
            }
            total_size_overlap = sum(v.numel() for v in overlap_dict.values())
            total_size_pretrain = sum(
                v.numel() for v in pretrained_dict.values())
            if len(overlap_dict) == 0:
                print_error_message(
                    'No overlapping weights between model file and pretrained weight file. Please check'
                )
            print_info_message('Overlap ratio of weights: {:.2f} %'.format(
                (total_size_overlap * 100.0) / total_size_pretrain))
            model_dict.update(overlap_dict)
            model.load_state_dict(model_dict, strict=False)
            print_info_message('Pretrained basenet model loaded!!')
        else:
            print_error_message('Unable to find the weights: {}'.format(
                args.weights_ft))

    # -----------------------------------------------------------------------------
    # Writer for logging
    # -----------------------------------------------------------------------------
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        # NOTE(review): 70 input channels looks dataset-specific (Heart
        # signals?) — confirm; ONNX tracing failures are non-fatal here.
        writer.add_graph(model,
                         input_to_model=torch.randn(1, 70, args.inpSize,
                                                    args.inpSize))
    except Exception:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # network properties
    num_params = model_parameters(model)
    flops = compute_flops(model)
    print_info_message('FLOPs: {:.2f} million'.format(flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    best_acc = 0.0
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            # map_location goes on torch.load (load_state_dict has no such
            # argument — passing it there was a TypeError).
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device(device))
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    # -----------------------------------------------------------------------------
    # Loss Fn
    # -----------------------------------------------------------------------------
    if args.dataset == 'imagenet':
        criterion = nn.CrossEntropyLoss()
        acc_metric = 'Top-1'
    elif args.dataset == 'coco':
        # Multi-label classification.
        criterion = nn.BCEWithLogitsLoss()
        acc_metric = 'F1'
    elif args.dataset == 'Heart':
        criterion = nn.L1Loss()
        acc_metric = 'Test'
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # -----------------------------------------------------------------------------
    # Data Loaders
    # -----------------------------------------------------------------------------
    # Data loading code
    if args.dataset == 'imagenet':
        train_loader, val_loader = img_loader.data_loaders(args)
        # import the loaders too
        from utilities.train_eval_classification import train, validate
    elif args.dataset == 'coco':
        from data_loader.classification.coco import COCOClassification
        train_dataset = COCOClassification(root=args.data,
                                           split='train',
                                           year='2017',
                                           inp_size=args.inpSize,
                                           scale=args.scale,
                                           is_training=True)
        val_dataset = COCOClassification(root=args.data,
                                         split='val',
                                         year='2017',
                                         inp_size=args.inpSize,
                                         is_training=False)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)
        # import the loaders too
        from utilities.train_eval_classification import train_multi as train
        from utilities.train_eval_classification import validate_multi as validate
    elif args.dataset == 'Heart':
        from utilities.train_eval_classification import train, validate

        def load_npy(npy_path):
            # .npy files saved from a pickled dict need .item(); fall back
            # to the raw array for plain saves.
            try:
                data = np.load(npy_path).item()
            except Exception:
                data = np.load(npy_path)
            return data

        def loadData(data_path):
            # Unpack the (signals, ground-truths) pair stored in the file.
            npy_data = load_npy(data_path)
            signals = npy_data['signals']
            gts = npy_data['gts']
            return signals, gts

        ht_img_width, ht_img_height = args.inpSize, args.inpSize
        ht_batch_size = args.batch_size
        signal_length = args.channels
        signals_train, gts_train = loadData(
            '../DiCENeT/CardioNet/data_train/fps7_sample10_2D_train.npy')
        signals_val, gts_val = loadData(
            '../DiCENeT/CardioNet/data_train/fps7_sample10_2D_val.npy')
        from data_loader.classification.heart import HeartDataGenerator
        heart_train_data = HeartDataGenerator(signals_train, gts_train,
                                              ht_batch_size)
        heart_val_data = HeartDataGenerator(signals_val, gts_val,
                                            ht_batch_size)
        train_loader = torch.utils.data.DataLoader(heart_train_data,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(heart_val_data,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))

    # -----------------------------------------------------------------------------
    # LR schedulers
    # -----------------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_sizes = args.steps
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        from utilities.lr_scheduler import CyclicLR
        step_sizes = args.steps
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr, max_epochs=args.epochs)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    elif args.scheduler == 'hybrid':
        # "HybirdLR" spelling matches the utilities module — do not rename.
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max)
    else:
        print_error_message('Scheduler ({}) not yet implemented'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)

    # set up the epoch variable in case resuming training
    if args.start_epoch != 0:
        for epoch in range(args.start_epoch):
            lr_scheduler.step(epoch)

    # Persist the full run configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # -----------------------------------------------------------------------------
    # Training and Val Loop
    # -----------------------------------------------------------------------------
    extra_info_ckpt = args.model + '_' + str(args.s)
    for epoch in range(args.start_epoch, args.epochs):
        lr_log = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_log
        print_info_message("LR for epoch {} = {:.5f}".format(epoch, lr_log))
        train_acc, train_loss = train(data_loader=train_loader,
                                      model=model,
                                      criteria=criterion,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      device=device)
        # evaluate on validation set
        val_acc, val_loss = validate(data_loader=val_loader,
                                     model=model,
                                     criteria=criterion,
                                     device=device)

        # remember best prec@1 and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        # Unwrap DataParallel so serialized keys carry no 'module.' prefix.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': weights_dict,
                'best_prec1': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Classification/LR/learning_rate', lr_log, epoch)
        writer.add_scalar('Classification/Loss/Train', train_loss, epoch)
        writer.add_scalar('Classification/Loss/Val', val_loss, epoch)
        writer.add_scalar('Classification/{}/Train'.format(acc_metric),
                          train_acc, epoch)
        writer.add_scalar('Classification/{}/Val'.format(acc_metric), val_acc,
                          epoch)
        # Intentionally logs best_acc against flops/params as the step
        # value (accuracy-vs-cost curves), not against the epoch.
        writer.add_scalar('Classification/Complexity/Top1_vs_flops', best_acc,
                          round(flops, 2))
        writer.add_scalar('Classification/Complexity/Top1_vs_params',
                          best_acc, round(num_params, 2))
    writer.close()
def main(args):
    """Train the multi-task (segmentation + classification) ESPDNet model.

    Only the 'greenhouse' dataset and the 'espdnet' architecture are handled;
    anything else aborts with an error message. Progress (losses, mIoU,
    learning rates) is logged to TensorBoard under ``args.savedir``, and the
    best-mIoU checkpoint is saved each epoch.

    :param args: argparse-style namespace with dataset, model, optimizer and
                 logging options (see the option names referenced below).
    """
    crop_size = args.crop_size
    # Expect (height, width); fail fast if a scalar slipped through.
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ------------------------------------------------------------------
    # Dataset
    # ------------------------------------------------------------------
    if args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegCls, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseRGBDSegCls(
            root=args.data_path,
            list_name='train_greenhouse_mult.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = GreenhouseRGBDSegCls(root=args.data_path,
                                           list_name='val_greenhouse_mult.txt',
                                           train=False,
                                           size=crop_size,
                                           scale=args.scale,
                                           use_depth=args.use_depth)
        # Pre-computed per-class weights from disk; only the first four
        # entries are used here (NOTE(review): presumably one per foreground
        # class — confirm against how class_weights.npy is generated).
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
        # RGB color per class, used by the visualization helper below.
        color_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                      ('other_part_of_plant', (0, 255, 255)),
                                      ('artificial_objects', (255, 0, 0)),
                                      ('ground', (255, 255, 0)),
                                      ('background', (0, 0, 0))])
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    if args.model == 'espdnet':
        from model.segmentation.espdnet_mult import espdnet_mult
        args.classes = seg_classes
        # Number of classes for the classification head; matches the `5`
        # passed to calc_cls_class_weight below.
        args.cls_classes = 5
        model = espdnet_mult(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Optionally warm-start the whole model from a weight file.
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    # Optionally freeze BatchNorm statistics and affine parameters.
    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # Per-module learning rates: base network at args.lr, segmentation head
    # scaled by lr_mult, and (when depth is used) the depth encoder at the
    # base rate.
    if args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0],
                                             crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except:
        # Graph export can fail for ops unsupported by the tracer; training
        # continues regardless.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0
    best_miou = 0.0
    # Resume epoch counter, best score, model and optimizer state if a
    # checkpoint path was given.
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Class weights for the classification head, estimated from the training
    # set (5 classification classes).
    cls_class_weight = calc_cls_class_weight(train_loader, 5)
    print(cls_class_weight)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion_seg = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
    criterion_cls = nn.CrossEntropyLoss(
        weight=torch.from_numpy(cls_class_weight).float().to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion_seg = criterion_seg.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion_seg = DataParallelCriteria(criterion_seg)
            criterion_seg = criterion_seg.cuda()
        criterion_cls = criterion_cls.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # ------------------------------------------------------------------
    # Learning-rate scheduler (project-local implementations)
    # ------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the logs.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    # ------------------------------------------------------------------
    # Training and validation loop
    # ------------------------------------------------------------------
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base
        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss, train_seg_loss, train_cls_loss = train(
            model,
            train_loader,
            optimizer,
            criterion_seg,
            seg_classes,
            epoch,
            criterion_cls,
            device=device,
            use_depth=args.use_depth)
        # evaluate on validation set
        miou_val, val_loss, val_seg_loss, val_cls_loss = val(
            model,
            val_loader,
            criterion_seg,
            criterion_cls,
            seg_classes,
            device=device,
            use_depth=args.use_depth)

        # Visualize predictions on one validation batch.
        # NOTE(review): `.next()` is the pre-Py3 iterator method — this works
        # only with iterators that still expose it (older torch DataLoader);
        # confirm against the pinned torch version.
        batch = iter(val_loader).next()
        if args.use_depth:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        depths=batch[2].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)
        else:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)
            # image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            # writer.add_image('Segmentation/results/val', image_grid, epoch)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/train', train_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/train', train_cls_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/val', val_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/val', val_cls_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
def main(args):
    """Train a semantic-segmentation model on one of several datasets.

    Supported datasets: pascal, city (Cityscapes), greenhouse, ishihara,
    sun (SUN RGB-D), camvid, forest (Freiburg Forest). Supported models:
    espnetv2, espdnet, espdnetue, deeplabv3, unet, dicenet. Logs to
    TensorBoard under ``args.savedir`` and checkpoints the best-mIoU epoch.

    :param args: argparse-style namespace carrying dataset/model/optimizer/
                 scheduler/logging options.
    """
    crop_size = args.crop_size
    # Expect (height, width); fail fast if a scalar slipped through.
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ------------------------------------------------------------------
    # Dataset selection: each branch defines train/val datasets,
    # seg_classes, class_wts and (for most) color_encoding.
    # ------------------------------------------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST, color_encoding
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               coarse=False)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded inverse-frequency-style weights per Cityscapes class;
        # class 19 (ignore/void) gets weight 0.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 10 / 2.8149201869965
        class_wts[1] = 10 / 6.9850029945374
        class_wts[2] = 10 / 3.7890393733978
        class_wts[3] = 10 / 9.9428062438965
        class_wts[4] = 10 / 9.7702074050903
        class_wts[5] = 10 / 9.5110931396484
        class_wts[6] = 10 / 10.311357498169
        class_wts[7] = 10 / 10.026463508606
        class_wts[8] = 10 / 4.6323022842407
        class_wts[9] = 10 / 9.5608062744141
        class_wts[10] = 10 / 7.8698215484619
        class_wts[11] = 10 / 9.5168733596802
        class_wts[12] = 10 / 10.373730659485
        class_wts[13] = 10 / 6.6616044044495
        class_wts[14] = 10 / 10.260489463806
        class_wts[15] = 10 / 10.287888526917
        class_wts[16] = 10 / 10.289801597595
        class_wts[17] = 10 / 10.405355453491
        class_wts[18] = 10 / 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GREENHOUSE_CLASS_LIST, color_encoding
        train_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path,
            list_name=args.train_list,
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth,
            use_traversable=args.greenhouse_use_trav)
        val_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path,
            list_name=args.val_list,
            train=False,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth,
            use_traversable=args.greenhouse_use_trav)
        # Pre-computed weights loaded from disk (all entries used here).
        class_weights = np.load('class_weights.npy')  # [:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)
        print(GREENHOUSE_CLASS_LIST)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
        # color_encoding = OrderedDict([
        #     ('end_of_plant', (0, 255, 0)),
        #     ('other_part_of_plant', (0, 255, 255)),
        #     ('artificial_objects', (255, 0, 0)),
        #     ('ground', (255, 255, 0)),
        #     ('background', (0, 0, 0))
        # ])
    elif args.dataset == 'ishihara':
        print(args.use_depth)
        from data_loader.segmentation.ishihara_rgbd import IshiharaRGBDSegmentation, ISHIHARA_RGBD_CLASS_LIST
        train_dataset = IshiharaRGBDSegmentation(
            root=args.data_path,
            list_name='ishihara_rgbd_train.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = IshiharaRGBDSegmentation(
            root=args.data_path,
            list_name='ishihara_rgbd_val.txt',
            train=False,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        seg_classes = len(ISHIHARA_RGBD_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        color_encoding = OrderedDict([('Unlabeled', (0, 0, 0)),
                                      ('Building', (70, 70, 70)),
                                      ('Fence', (190, 153, 153)),
                                      ('Others', (72, 0, 90)),
                                      ('Pedestrian', (220, 20, 60)),
                                      ('Pole', (153, 153, 153)),
                                      ('Road ', (157, 234, 50)),
                                      ('Road', (128, 64, 128)),
                                      ('Sidewalk', (244, 35, 232)),
                                      ('Vegetation', (107, 142, 35)),
                                      ('Car', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)),
                                      ('Traffic ', (220, 220, 0))])
    elif args.dataset == 'sun':
        print(args.use_depth)
        from data_loader.segmentation.sun_rgbd import SUNRGBDSegmentation, SUN_RGBD_CLASS_LIST
        train_dataset = SUNRGBDSegmentation(root=args.data_path,
                                            list_name='sun_rgbd_train.txt',
                                            train=True,
                                            size=crop_size,
                                            ignore_idx=args.ignore_idx,
                                            scale=args.scale,
                                            use_depth=args.use_depth)
        val_dataset = SUNRGBDSegmentation(root=args.data_path,
                                          list_name='sun_rgbd_val.txt',
                                          train=False,
                                          size=crop_size,
                                          ignore_idx=args.ignore_idx,
                                          scale=args.scale,
                                          use_depth=args.use_depth)
        seg_classes = len(SUN_RGBD_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        color_encoding = OrderedDict([('Background', (0, 0, 0)),
                                      ('Bed', (0, 255, 0)),
                                      ('Books', (70, 70, 70)),
                                      ('Ceiling', (190, 153, 153)),
                                      ('Chair', (72, 0, 90)),
                                      ('Floor', (220, 20, 60)),
                                      ('Furniture', (153, 153, 153)),
                                      ('Objects', (157, 234, 50)),
                                      ('Picture', (128, 64, 128)),
                                      ('Sofa', (244, 35, 232)),
                                      ('Table', (107, 142, 35)),
                                      ('TV', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)),
                                      ('Window', (220, 220, 0))])
    elif args.dataset == 'camvid':
        print(args.use_depth)
        from data_loader.segmentation.camvid import CamVidSegmentation, CAMVID_CLASS_LIST, color_encoding
        train_dataset = CamVidSegmentation(
            root=args.data_path,
            list_name='train_camvid.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            label_conversion=args.label_conversion,
            normalize=args.normalize)
        val_dataset = CamVidSegmentation(
            root=args.data_path,
            list_name='val_camvid.txt',
            train=False,
            size=crop_size,
            scale=args.scale,
            label_conversion=args.label_conversion,
            normalize=args.normalize)
        if args.label_conversion:
            # Labels remapped to the greenhouse class set.
            from data_loader.segmentation.greenhouse import GREENHOUSE_CLASS_LIST, color_encoding
            seg_classes = len(GREENHOUSE_CLASS_LIST)
            class_wts = torch.ones(seg_classes)
        else:
            seg_classes = len(CAMVID_CLASS_LIST)
            # Estimate inverse-frequency class weights from the training set.
            tmp_loader = torch.utils.data.DataLoader(train_dataset,
                                                     batch_size=1,
                                                     shuffle=False)
            class_wts = calc_cls_class_weight(tmp_loader,
                                              seg_classes,
                                              inverted=True)
            class_wts = torch.from_numpy(class_wts).float().to(device)
            # class_wts = torch.ones(seg_classes)
            print("class weights : {}".format(class_wts))
        # CamVid has no depth channel.
        args.use_depth = False
    elif args.dataset == 'forest':
        from data_loader.segmentation.freiburg_forest import FreiburgForestDataset, FOREST_CLASS_LIST, color_encoding
        train_dataset = FreiburgForestDataset(train=True,
                                              size=crop_size,
                                              scale=args.scale,
                                              normalize=args.normalize)
        val_dataset = FreiburgForestDataset(train=False,
                                            size=crop_size,
                                            scale=args.scale,
                                            normalize=args.normalize)
        seg_classes = len(FOREST_CLASS_LIST)
        # Estimate inverse-frequency class weights from the training set.
        tmp_loader = torch.utils.data.DataLoader(train_dataset,
                                                 batch_size=1,
                                                 shuffle=False)
        class_wts = calc_cls_class_weight(tmp_loader,
                                          seg_classes,
                                          inverted=True)
        class_wts = torch.from_numpy(class_wts).float().to(device)
        # class_wts = torch.ones(seg_classes)
        print("class weights : {}".format(class_wts))
        # Freiburg Forest has no depth channel.
        args.use_depth = False
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model selection
    # ------------------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnetue_seg2(args, fix_pyr_plane_proj=True)
    elif args.model == 'deeplabv3':
        # from model.segmentation.deeplabv3 import DeepLabV3
        from torchvision.models.segmentation.segmentation import deeplabv3_resnet101
        args.classes = seg_classes
        # model = DeepLabV3(seg_classes)
        model = deeplabv3_resnet101(num_classes=seg_classes, aux_loss=True)
        # cuDNN disabled for this architecture (why: not stated here —
        # NOTE(review): presumably a known cuDNN incompatibility; confirm).
        torch.backends.cudnn.enabled = False
    elif args.model == 'unet':
        from model.segmentation.unet import UNet
        model = UNet(in_channels=3, out_channels=seg_classes)
        # model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
        #     in_channels=3, out_channels=seg_classes, init_features=32, pretrained=False)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Optionally warm-start the whole model from a weight file.
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    # Optionally freeze BatchNorm statistics and affine parameters.
    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # Per-module learning-rate groups. deeplabv3/unet expose no param-group
    # helpers, so they train all parameters at the base rate.
    if args.model == 'deeplabv3' or args.model == 'unet':
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
    elif args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr * args.lr_mult
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0],
                                             crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        # Fixed 288x480 dummy input for the graph trace.
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, 288, 480))
    except:
        # Graph export can fail for ops unsupported by the tracer; training
        # continues regardless.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0
    best_miou = 0.0
    # Resume epoch counter, best score, model and optimizer state if a
    # checkpoint path was given.
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    # Optional auxiliary NID loss.
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    # Validation uses a fixed batch size of 20 regardless of args.batch_size.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ------------------------------------------------------------------
    # Learning-rate scheduler (project-local implementations)
    # ------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()
    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the logs.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    # ------------------------------------------------------------------
    # Training and validation loop
    # ------------------------------------------------------------------
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        # deeplabv3/unet have a single param group, hence the guard.
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base
        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Models with an uncertainty/aux branch use the *_ue train/val loops.
        if args.model == 'espdnetue' or (
                (args.model == 'deeplabv3' or args.model == 'unet')
                and args.use_aux):
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        iou_train, train_loss = train(model,
                                      train_loader,
                                      optimizer,
                                      criterion,
                                      seg_classes,
                                      epoch,
                                      device=device,
                                      use_depth=args.use_depth,
                                      add_criterion=nid_loss)
        iou_val, val_loss = val(model,
                                val_loader,
                                criterion,
                                seg_classes,
                                device=device,
                                use_depth=args.use_depth,
                                add_criterion=nid_loss)

        # Visualize predictions on one training and one validation batch.
        # NOTE(review): `.next()` is the pre-Py3 iterator method — works only
        # with iterators that still expose it (older torch DataLoader).
        batch_train = iter(train_loader).next()
        batch = iter(val_loader).next()
        if args.use_depth:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                depths=batch_train[2].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          depths=batch[2].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)
            image_grid = torchvision.utils.make_grid(
                batch[2].to(device=device).data.cpu()).numpy()
            print(type(image_grid))
            writer.add_image('Segmentation/depths', image_grid, epoch)
        else:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)
            # image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            # writer.add_image('Segmentation/results/val', image_grid, epoch)

        # remember best miou and save checkpoint
        # mIoU is computed over classes 1-3 only (background/ignore excluded).
        miou_val = iou_val[[1, 2, 3]].mean()
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train',
                          iou_train[[1, 2, 3]].mean(), epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/plant_IOU/val', iou_val[1], epoch)
        writer.add_scalar('Segmentation/ao_IOU/val', iou_val[2], epoch)
        writer.add_scalar('Segmentation/ground_IOU/val', iou_val[3], epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
def main():
    """Generate pseudo-labels with a pretrained ESPDNet-UE model, then
    retrain the model on them.

    Uses the module-level ``args`` namespace (not a parameter). A timestamped
    save directory is created under ``args.save``; pseudo-labels are written
    to ``<save_path>/pred`` with a training list at ``tgt_train.lst``. The
    best-mIoU checkpoint (by the test split) is saved each time it improves.
    """
    # CUDA is assumed to be available — no CPU fallback here.
    device = 'cuda'

    # Timestamp in JST (UTC+9) used for the run directory name.
    now = datetime.datetime.now()
    now += datetime.timedelta(hours=9)
    timestr = now.strftime("%Y%m%d-%H%M%S")
    # NOTE(review): the two strings below are built but never used — the
    # save_path format string does not include them.
    use_depth_str = "_rgbd" if args.use_depth else "_rgb"
    if args.use_depth:
        trainable_fusion_str = "_gated" if args.trainable_fusion else "_naive"
    else:
        trainable_fusion_str = ""

    save_path = '{}/model_{}_{}/{}'.format(args.save, args.model,
                                           args.dataset, timestr)
    print(save_path)
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
    # List file of generated pseudo-labeled samples, and directory for the
    # predicted label images.
    tgt_train_lst = osp.join(save_path, 'tgt_train.lst')
    save_pred_path = osp.join(save_path, 'pred')
    if not os.path.isdir(save_pred_path):
        os.makedirs(save_pred_path)

    writer = SummaryWriter(save_path)

    # Dataset
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentationTrav, GREENHOUSE_CLASS_LIST
    args.classes = len(GREENHOUSE_CLASS_LIST)
    travset = GreenhouseRGBDSegmentationTrav(list_name=args.data_trav_list,
                                             use_depth=args.use_depth)
    # RGB color per class, used by the train/test visualization helpers.
    class_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                  ('other_part_of_plant', (0, 255, 255)),
                                  ('artificial_objects', (255, 0, 0)),
                                  ('ground', (255, 255, 0)),
                                  ('background', (0, 0, 0))])

    # Dataloader for generating the pseudo-labels
    travloader = torch.utils.data.DataLoader(travset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=args.pin_memory)

    # Model
    from model.segmentation.espdnet_ue import espdnetue_seg2
    args.weights = args.restore_from
    model = espdnetue_seg2(args,
                           load_entire_weights=True,
                           fix_pyr_plane_proj=True)
    model.to(device)

    # Write pseudo-labels to save_pred_path and the sample list to
    # tgt_train_lst using the pretrained model.
    generate_label(model, travloader, save_pred_path, tgt_train_lst)

    # Datset for training
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation
    trainset = GreenhouseRGBDSegmentation(list_name=tgt_train_lst,
                                          use_depth=args.use_depth,
                                          use_traversable=True)
    testset = GreenhouseRGBDSegmentation(list_name=args.data_test_list,
                                         use_depth=args.use_depth,
                                         use_traversable=True)

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=args.pin_memory)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=0,
                                             pin_memory=args.pin_memory)

    # Loss
    # Fixed per-class weights; the last class (background) is ignored via
    # weight 0.
    class_weights = torch.tensor([1.0, 0.2, 1.0, 1.0, 0.0]).to(device)
    if args.use_uncertainty:
        criterion = UncertaintyWeightedSegmentationLoss(
            args.classes, class_weights=class_weights)
    else:
        criterion = SegmentationLoss(n_classes=args.classes,
                                     device=device,
                                     class_weights=class_weights)
    criterion_test = SegmentationLoss(n_classes=args.classes,
                                      device=device,
                                      class_weights=class_weights)

    # Optimizer
    # Base network trains at a tenth of the segmentation-head rate.
    if args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.learning_rate * 0.1
        }, {
            'params': model.get_segment_params(),
            'lr': args.learning_rate
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.learning_rate
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.learning_rate * 0.1
        }, {
            'params': model.get_segment_params(),
            'lr': args.learning_rate
        }]

    if args.optimizer == 'SGD':
        optimizer = optim.SGD(train_params,
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(train_params,
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)

    # Cyclic LR between base and 10x base; momentum cycling only makes
    # sense for SGD.
    scheduler = optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=args.learning_rate,
        max_lr=args.learning_rate * 10,
        step_size_up=10,
        step_size_down=20,
        cycle_momentum=True if args.optimizer == 'SGD' else False)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100], gamma=0.5)

    best_miou = 0.0
    for i in range(0, args.epoch):
        # Run a training epoch
        train(trainloader,
              model,
              criterion,
              device,
              optimizer,
              class_encoding,
              i,
              writer=writer)

        # Update the learning rate
        scheduler.step()
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        # optimizer.param_groups[0]['lr'] = lr_base
        # if len(optimizer.param_groups) > 1:
        #     optimizer.param_groups[1]['lr'] = lr_seg
        # if args.use_depth:
        #     optimizer.param_groups[2]['lr'] = lr_base * 10

        new_miou = test(testloader,
                        model,
                        criterion_test,
                        device,
                        optimizer,
                        class_encoding,
                        i,
                        writer=writer)

        # Save the weights if it produces the best IoU
        is_best = new_miou > best_miou
        best_miou = max(new_miou, best_miou)
        model.to(device)
        # weights_dict = model.module.state_dict() if device == 'cuda' else model.state_dict()
        weights_dict = model.state_dict()
        extra_info_ckpt = '{}'.format(args.model)
        if is_best:
            save_checkpoint(
                {
                    'epoch': i + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, save_path, extra_info_ckpt)
def main():
    """Train the traversability probability estimator (LabelProbEstimator)
    on greenhouse RGB(-D) data, using a pre-trained ESPDNet-UE segmentation
    network as a feature provider.

    Reads configuration from the module-level ``args`` namespace; writes
    TensorBoard logs and checkpoints under ``args.save``.
    """
    device = 'cuda'

    # Timestamped output directory; +9h shifts the clock to JST for the folder
    # name — NOTE(review): assumes the host clock is UTC, confirm.
    now = datetime.datetime.now()
    now += datetime.timedelta(hours=9)
    timestr = now.strftime("%Y%m%d-%H%M%S")
    save_path = '{}/model_{}_{}/{}'.format(args.save, args.model, args.dataset,
                                           timestr)
    print(save_path)
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
    save_pred_path = osp.join(save_path, 'pred')
    if not os.path.isdir(save_pred_path):
        os.makedirs(save_pred_path)

    writer = SummaryWriter(save_path)

    #
    # Dataset
    #
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentationTrav, GREENHOUSE_CLASS_LIST
    args.classes = len(GREENHOUSE_CLASS_LIST)
    trav_train_set = GreenhouseRGBDSegmentationTrav(
        list_name=args.data_train_list, use_depth=args.use_depth)
    trav_test_set = GreenhouseRGBDSegmentationTrav(
        list_name=args.data_test_list, use_depth=args.use_depth)

    #
    # Dataloader for generating the pseudo-labels
    #
    trav_train_loader = torch.utils.data.DataLoader(trav_train_set,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=0,
                                                    pin_memory=args.pin_memory)
    # Entire test set in one batch so test()/calculate_iou() see it at once.
    trav_test_loader = torch.utils.data.DataLoader(
        trav_test_set,
        batch_size=len(trav_test_set),
        shuffle=False,
        num_workers=0,
        pin_memory=args.pin_memory)

    #
    # Models
    #
    # Label probability estimator: 32 input channels when features are
    # concatenated, 16 otherwise.
    from model.classification.label_prob_estimator import LabelProbEstimator
    in_channels = 32 if args.feature_construction == 'concat' else 16
    prob_model = LabelProbEstimator(in_channels=in_channels,
                                    spatial=args.spatial)
    prob_model.to(device)

    # Segmentation model; weights are restored from args.restore_from.
    from model.segmentation.espdnet_ue import espdnetue_seg2
    args.weights = args.restore_from
    seg_model = espdnetue_seg2(args,
                               load_entire_weights=True,
                               fix_pyr_plane_proj=True)
    seg_model.to(device)

    criterion = SelectiveBCE()

    # (Commented-out legacy dataset/loader/loss/optimizer setup removed.)

    # Optimizer — only the probability estimator is trained here; the
    # segmentation model is used as-is.
    if args.optimizer == 'SGD':
        optimizer = optim.SGD(prob_model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(prob_model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)

    if args.lr_scheduling == "cyclic":
        scheduler = optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=args.learning_rate,
            max_lr=args.learning_rate * 10,
            step_size_up=10,
            step_size_down=20,
            cycle_momentum=True if args.optimizer == 'SGD' else False)
    else:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[50, 150],
                                                   gamma=0.5)

    c = 1.0            # decision threshold; re-estimated by test() each epoch
    loss_old = 1000000  # best (lowest) test loss seen so far

    for epoch in range(0, args.epoch):
        calculate_iou(trav_test_loader,
                      seg_model,
                      prob_model,
                      c,
                      writer,
                      device=device,
                      writer_idx=epoch)

        # Run a training epoch
        train(trav_train_loader, prob_model, seg_model, criterion, device,
              optimizer, epoch, writer)

        scheduler.step()

        ret_dict = test(trav_test_loader, prob_model, seg_model, criterion,
                        device, epoch, writer)
        c = ret_dict["c"]
        loss = ret_dict["loss"]

        extra_info_ckpt = '{}_epoch_{}_c_{}'.format(args.model, epoch, c)
        weights_dict = prob_model.state_dict()
        # Save a checkpoint whenever the test loss improves.
        if loss < loss_old:
            print("Save weights")
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': 0.0,
                    'optimizer': optimizer.state_dict(),
                }, loss < loss_old, save_path, extra_info_ckpt)
            loss_old = loss
        print("c = {}".format(c))
def train(params):
    """Train a YOLOv2 detector on VOC-style data.

    Args:
        params: dict with keys ``dataset``, ``input_height``, ``input_width``,
            ``data_path``, ``val_data_path``, ``val_datalist_path``,
            ``datalist_path``, ``class_path``, ``batch_size``, ``num_epochs``,
            ``lr``, ``checkpoint_path``, ``use_augmentation``, ``use_gtcheck``,
            ``use_visdom``, ``use_githash``, ``num_class``.

    Side effects: writes checkpoints under ``checkpoint_path`` and (optionally)
    plots losses to a running visdom server.
    """
    # 1. Unpack configuration ("future work" variables are kept for parity
    # with the original even where the loop below does not use them yet).
    dataset = params["dataset"]
    input_height = params["input_height"]
    input_width = params["input_width"]
    data_path = params["data_path"]
    val_data_path = params["val_data_path"]
    val_datalist_path = params["val_datalist_path"]
    datalist_path = params["datalist_path"]
    class_path = params["class_path"]
    batch_size = params["batch_size"]
    num_epochs = params["num_epochs"]
    learning_rate = params["lr"]
    checkpoint_path = params["checkpoint_path"]
    USE_AUGMENTATION = params["use_augmentation"]
    USE_GTCHECKER = params["use_gtcheck"]
    USE_VISDOM = params["use_visdom"]
    USE_GITHASH = params["use_githash"]
    num_class = params["num_class"]
    num_gpus = [i for i in range(1)]

    with open(class_path) as f:
        class_list = f.read().splitlines()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    # Short git hash tags checkpoint file names with the code revision.
    # (Set 'noHash' up front instead of re-assigning it every epoch.)
    if USE_GITHASH:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        short_sha = repo.git.rev_parse(sha, short=7)
    else:
        short_sha = 'noHash'

    if USE_VISDOM:
        viz = visdom.Visdom(use_incoming_socket=False)
        vis_title = 'YOLOv2'
        vis_legend_Train = ['Train Loss']
        vis_legend_Val = ['Val Loss']
        iter_plot = create_vis_plot(viz, 'Iteration', 'Total Loss', vis_title,
                                    vis_legend_Train)
        val_plot = create_vis_plot(viz, 'Iteration', 'Validation Loss',
                                   vis_title, vis_legend_Val)

    # 2. Data augmentation setting
    if USE_AUGMENTATION:
        seq = iaa.SomeOf(
            2,
            [
                iaa.Multiply(
                    (1.2, 1.5)),  # change brightness, doesn't affect BBs
                iaa.Affine(
                    translate_px={
                        "x": 3,
                        "y": 10
                    }, scale=(0.9, 0.9)
                ),  # translate by 40/60px on x/y axis, and scale to 50-70%, affects BBs
                iaa.AdditiveGaussianNoise(scale=0.1 * 255),
                iaa.CoarseDropout(0.02, size_percent=0.15, per_channel=0.5),
                iaa.Affine(rotate=45),
                iaa.Sharpen(alpha=0.5)
            ])
    else:
        seq = iaa.Sequential([])
    composed = transforms.Compose([Augmenter(seq)])

    # 3. Load datasets and wrap them in loaders.
    train_dataset = VOC(root=data_path,
                        transform=composed,
                        class_path=class_path,
                        datalist_path=datalist_path)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=detection_collate)
    val_dataset = VOC(root=val_data_path,
                      transform=composed,
                      class_path=class_path,
                      datalist_path=val_datalist_path)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             collate_fn=detection_collate)

    # 5. Load YOLOv2.
    # BUGFIX: the model used to be moved to CUDA unconditionally *before* the
    # device check, which crashed on CPU-only machines. Build it once, based
    # on the actual device.
    net = yolov2.YOLOv2()
    print("device : ", device)
    if device.type == 'cpu':
        model = torch.nn.DataParallel(net)
    else:
        model = torch.nn.DataParallel(net, device_ids=num_gpus).cuda()

    # 7. Train the model
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    total_step = len(train_loader)
    total_train_step = num_epochs * total_step

    for epoch in range(1, num_epochs + 1):
        train_loss = 0
        total_val_loss = 0
        train_total_conf_loss = 0
        train_total_xy_loss = 0
        train_total_wh_loss = 0
        train_total_c_loss = 0
        val_total_conf_loss = 0
        val_total_xy_loss = 0
        val_total_wh_loss = 0
        val_total_c_loss = 0

        # Manual LR decay every 500 epochs during the first 1000 epochs;
        # rebuilding the optimizer resets its internal Adam state on purpose?
        # NOTE(review): original behavior preserved — confirm intent.
        if epoch % 500 == 0 and epoch < 1000:
            learning_rate /= 10
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=learning_rate,
                                         weight_decay=1e-5)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                               gamma=0.9)
        if epoch in (200, 400, 600, 20000, 30000):
            scheduler.step()

        model.train()
        for i, (images, labels, sizes) in enumerate(train_loader):
            current_train_step = epoch * total_step + (i + 1)

            if USE_GTCHECKER:
                visualize_GT(images, labels, class_list)

            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Loss (total plus per-component terms)
            one_loss, conf_loss, xy_loss, wh_loss, class_loss = detection_loss_4_yolo(
                outputs, labels, device.type)

            # Backward and optimize
            optimizer.zero_grad()
            one_loss.backward()
            optimizer.step()

            train_loss += one_loss.item()
            train_total_conf_loss += conf_loss.item()
            train_total_xy_loss += xy_loss.item()
            train_total_wh_loss += wh_loss.item()
            train_total_c_loss += class_loss.item()

        # Average the accumulated training losses over the epoch.
        train_total_conf_loss = train_total_conf_loss / len(train_loader)
        train_total_xy_loss = train_total_xy_loss / len(train_loader)
        train_total_wh_loss = train_total_wh_loss / len(train_loader)
        train_total_c_loss = train_total_c_loss / len(train_loader)
        train_epoch_loss = train_loss / len(train_loader)

        # BUGFIX: plotting was called unconditionally although 'viz' and the
        # plot handles only exist when USE_VISDOM is enabled.
        if USE_VISDOM:
            update_vis_plot(viz, epoch + 1, train_epoch_loss, iter_plot, None,
                            'append')

        # Validation pass (no gradients).
        model.eval()
        with torch.no_grad():
            for j, (v_images, v_labels, v_sizes) in enumerate(val_loader):
                v_images = v_images.to(device)
                v_labels = v_labels.to(device)

                v_outputs = model(v_images)
                val_loss, conf_loss, xy_loss, wh_loss, class_loss = detection_loss_4_yolo(
                    v_outputs, v_labels, device.type)

                total_val_loss += val_loss.item()
                val_total_conf_loss += conf_loss.item()
                val_total_xy_loss += xy_loss.item()
                val_total_wh_loss += wh_loss.item()
                val_total_c_loss += class_loss.item()

        val_epoch_loss = total_val_loss / len(val_loader)
        val_total_conf_loss = val_total_conf_loss / len(val_loader)
        val_total_xy_loss = val_total_xy_loss / len(val_loader)
        val_total_wh_loss = val_total_wh_loss / len(val_loader)
        val_total_c_loss = val_total_c_loss / len(val_loader)

        if USE_VISDOM:
            update_vis_plot(viz, epoch + 1, val_epoch_loss, val_plot, None,
                            'append')

        # Progress logging (every 100 steps, or every step early in training).
        if ((current_train_step % 100) == 0) or (current_train_step % 1 == 0
                                                 and current_train_step < 300):
            print(
                'epoch: [{}/{}], total step: [{}/{}], batch step [{}/{}], lr: {},one_loss: {:.4f},val_loss: {:.4f}'
                .format(epoch + 1, num_epochs, current_train_step,
                        total_train_step, i + 1, total_step,
                        optimizer.param_groups[0]['lr'], one_loss, val_loss))
        print('train loss', train_epoch_loss, 'val loss', val_epoch_loss)
        print('train conf loss', train_total_conf_loss, 'val conf loss',
              val_total_conf_loss)
        print('train xy loss', train_total_xy_loss, 'val xy loss',
              val_total_xy_loss)
        print('train wh loss', train_total_wh_loss, 'val wh loss',
              val_total_wh_loss)
        print('train class loss', train_total_c_loss, 'val class loss',
              val_total_c_loss)

        # Periodic checkpointing (every 10 epochs): a full training
        # checkpoint plus a bare model state_dict.
        if (epoch % 10) == 0:
            lr_now = optimizer.param_groups[0]['lr']
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': "YOLOv2",
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                False,
                filename=os.path.join(
                    checkpoint_path,
                    'ckpt_{}_ep{:05d}_loss{:.04f}_lr{}.pth.tar'.format(
                        short_sha, epoch, one_loss.item(), lr_now)))
            filename = os.path.join(
                checkpoint_path,
                'ckpt_{}_ep{:05d}_loss{:.04f}_lr{}model.pth.tar'.format(
                    short_sha, epoch, one_loss.item(), lr_now))
            torch.save(model.module.state_dict(), filename)
utils.set_seed(seed) ckptfilename = getCkptName(args.ckptname) ckptpath = os.path.join(ckptroot, modeltype, ckptfilename) print("CKPT PATH: {}".format(ckptpath)) # Initialized reader, model and optimizer trainer = Trainer() print("Done modelinit") print("Mode : {}".format(mode)) # print(trainer.tr_reader.ans2idx) if mode == 'train': (bestmodel, bestval, beststeps, steps) = trainer.train() print("Saving model: {}".format(ckptpath)) utils.save_checkpoint(m=bestmodel, o=trainer.optimizer, steps=steps, beststeps=beststeps, path=ckptpath) pp.pprint(args) elif mode == 'val': utils.load_checkpoint(ckptpath, trainer.model, trainer.optimizer) trainer.validation() # (vt, vc, va) = trainer.validation_performance() # print("Total: {}. Validation Acc: {}".format(vt, va)) sys.exit()
loss.backward() optimaizer.step() print( 'epoch: [{}/{}], total step:[{}/{}] , batchstep [{}/{}], lr: {},' 'total_loss: {:.4f}, objness1: {:.4f}, class_loss: {:.4f}'.format( epoch, num_epochs, current_train_step, total_train_step, i + 1, total_step, learning_rate, loss.item(), obj_coord1_loss, obj_size1_loss, obj_class_loss)) if (epoch % 2 == 0): ''' torch.save({'test': epoch}, 'cc.zip') print("Saved...") ''' save_checkpoint( { 'epoch': epoch + 1, 'arch': "YOLOv1", 'state_dict': model.state_dict(), }, False, filename=os.path.join( check_point_path, 'ep{:05d}_loss{:.04f}_lr{}.pth.tar'.format( epoch, loss.item(), learning_rate, ))) print("The check point is saved")
def main(args):
    """Two-phase segmentation training.

    Phase 1 (skipped when ``args.finetune`` is set): train ``args.model`` on
    the 13-class CamVid labels. Phase 2: rebuild the model for the converted
    (label_conversion=True) label set, transfer compatible weights from
    phase 1, and finetune with 100x lower learning rates for 50 epochs.

    Side effects: checkpoints under ``args.savedir`` and TensorBoard logs.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Phase 1: training with the 13-class CamVid labels
    #
    if not args.finetune:
        train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
            label_conversion=False)  # 13 classes
        args.use_depth = False  # 'use_depth' is always false for camvid

        print_info_message('Training samples: {}'.format(len(train_dataset)))
        print_info_message('Validation samples: {}'.format(len(val_dataset)))

        # Import model
        if args.model == 'espnetv2':
            from model.segmentation.espnetv2 import espnetv2_seg
            args.classes = seg_classes
            model = espnetv2_seg(args)
        elif args.model == 'espdnet':
            from model.segmentation.espdnet import espdnet_seg
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            print("Segmentation classes : {}".format(seg_classes))
            model = espdnet_seg(args)
        elif args.model == 'espdnetue':
            from model.segmentation.espdnet_ue import espdnetue_seg2
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            # BUGFIX: this line was a bare string expression (no-op); the
            # 'print' call was missing.
            print("Segmentation classes : {}".format(seg_classes))
            print(args.weights)
            model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
        else:
            print_error_message('Arch: {} not yet supported'.format(
                args.model))
            exit(-1)

        # Freeze batch normalization layers?
        if args.freeze_bn:
            freeze_bn_layer(model)

        # Set learning rates: base network gets a lower LR than the head.
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

        # Define an optimizer
        optimizer = optim.SGD(train_params,
                              lr=args.lr * args.lr_mult,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Compute the FLOPs and the number of parameters, and display it
        num_params, flops = show_network_stats(model, crop_size)

        try:
            writer.add_graph(model,
                             input_to_model=torch.Tensor(
                                 1, 3, crop_size[0], crop_size[1]))
        except Exception:  # narrowed from a bare 'except:'
            print_log_message(
                "Not able to generate the graph. Likely because your model is not supported by ONNX"
            )

        criterion = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
        nid_loss = NIDLoss(image_bin=32,
                           label_bin=seg_classes) if args.use_nid else None

        if num_gpus >= 1:
            if num_gpus == 1:
                # For a single GPU, we do not need DataParallel wrapper for
                # Criteria; fall back to torch's internal wrapper.
                from torch.nn.parallel import DataParallel
                model = DataParallel(model)
                model = model.cuda()
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss.cuda()
            else:
                from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
                model = DataParallelModel(model)
                model = model.cuda()
                criterion = DataParallelCriteria(criterion)
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss = DataParallelCriteria(nid_loss)
                    nid_loss = nid_loss.cuda()
            if torch.backends.cudnn.is_available():
                import torch.backends.cudnn as cudnn
                cudnn.benchmark = True
                cudnn.deterministic = True

        # Get data loaders for training and validation data
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=20,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # Get a learning rate scheduler
        lr_scheduler = get_lr_scheduler(args.scheduler)

        write_stats_to_json(num_params, flops)
        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

        #
        # Main training loop of 13 classes
        #
        start_epoch = 0
        best_miou = 0.0
        for epoch in range(start_epoch, args.epochs):
            # Scheduler provides the base LR; the segmentation head is scaled.
            lr_base = lr_scheduler.step(epoch)
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # Use different training functions for espdnetue
            if args.model == 'espdnetue':
                from utilities.train_eval_seg import train_seg_ue as train
                from utilities.train_eval_seg import val_seg_ue as val
            else:
                from utilities.train_eval_seg import train_seg as train
                from utilities.train_eval_seg import val_seg as val

            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device,
                                           use_depth=args.use_depth,
                                           add_criterion=nid_loss)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device,
                                     use_depth=args.use_depth,
                                     add_criterion=nid_loss)

            # BUGFIX: 'iter(loader).next()' is Python 2 syntax and raises
            # AttributeError on Python 3 — use the next() builtin.
            batch_train = next(iter(train_loader))
            batch = next(iter(val_loader))
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(
                model,
                images=batch[0].to(device=device),
                labels=batch[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/val',
                device=device)

            # Remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # NOTE(review): value/step arguments look swapped here (logs
            # best_miou at step=flops); preserved as in the original.
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))

        # Save the pretrained weights for phase 2, then free the model.
        model_dict = copy.deepcopy(model.state_dict())
        del model
        torch.cuda.empty_cache()

    #
    # Phase 2: finetuning with the converted (4 + ignore) label set
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Import model
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        # Transfer every phase-1 weight whose name and shape match the new
        # (converted-label) model; report the ones that do not.
        new_model_dict = model.state_dict()
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        no_overlap_dict = {
            k.replace('module.', ''): v
            for k, v in new_model_dict.items()
            if k.replace('module.', '') not in new_model_dict
            or new_model_dict[k.replace('module.', '')].size() != v.size()
        }
        print(no_overlap_dict.keys())

        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    # Sanity-check forward pass on a dummy input.
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())
    print(seg_classes)
    print(class_wts.size())

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Set learning rates: finetuning uses 100x lower rates.
    args.lr /= 100
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU, we do not need DataParallel wrapper for
            # Criteria; fall back to torch's internal wrapper.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler (finetuning runs a fixed 50 epochs).
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])

    #
    # Main finetuning loop
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device,
                                 use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # BUGFIX: Python 3 replacement for 'iter(loader).next()'.
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(
            model,
            images=batch_train[0].to(device=device),
            labels=batch_train[1].to(device=device),
            class_encoding=color_encoding,
            writer=writer,
            epoch=epoch,
            data='SegmentationConv/train',
            device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # Remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        # NOTE(review): value/step arguments look swapped; preserved as-is.
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()