def get_data_loaders(data_dir, batch_size, val_batch_size, num_workers, include_coarse):
    # Train-time augmentation; normalization uses the standard ImageNet statistics.
    transform = Compose([
        RandomHorizontalFlip(),
        RandomAffine(translate=(0.1, 0.1), scale=(0.7, 2.0), shear=(-10, 10)),
        RandomGaussionBlur(radius=2.0),
        ColorJitter(0.1, 0.1, 0.1, 0.1),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomGaussionNoise(),
        ConvertIdToTrainId()
    ])

    val_transform = Compose([
        ToTensor(),
        ConvertIdToTrainId(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    fine = CityscapesDataset(root=data_dir, split='train', mode='fine', transforms=transform)

    if include_coarse:
        # Optionally mix in the coarsely annotated train_extra split.
        coarse = CityscapesDataset(root=data_dir, split='train_extra', mode='coarse', transforms=transform)
        train_loader = data.DataLoader(FineCoarseDataset(fine, coarse), batch_size=batch_size,
                                       shuffle=True, num_workers=num_workers, pin_memory=True)
    else:
        train_loader = data.DataLoader(fine, batch_size=batch_size, shuffle=True,
                                       num_workers=num_workers, pin_memory=True)

    val_loader = data.DataLoader(CityscapesDataset(root=data_dir, split='val', transforms=val_transform),
                                 batch_size=val_batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader
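
# FineCoarseDataset is referenced above but not defined here. A minimal sketch of
# what such a wrapper might look like, assuming it simply concatenates the fine
# and coarse annotation sets behind one Dataset interface (the semantics are an
# assumption, not the repository's actual implementation):
from torch.utils import data


class FineCoarseDataset(data.Dataset):
    """Hypothetical wrapper exposing fine + coarse samples as one dataset."""

    def __init__(self, fine, coarse):
        self.fine = fine
        self.coarse = coarse

    def __len__(self):
        return len(self.fine) + len(self.coarse)

    def __getitem__(self, index):
        # Indices below len(fine) come from the finely annotated split,
        # the rest from the coarse train_extra split.
        if index < len(self.fine):
            return self.fine[index]
        return self.coarse[index - len(self.fine)]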
def run(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    num_classes = CityscapesDataset.num_classes()

    # Evaluate either a local checkpoint or the pretrained release weights.
    if args.checkpoint:
        model = GoogLeNetFCN(num_classes)
        model.load_state_dict(torch.load(args.checkpoint))
    else:
        model = googlenet_fcn(pretrained=True, num_classes=num_classes)

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size

    val_loader = get_data_loaders(args.dataset_dir, args.batch_size, args.num_workers)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=255)

    def _prepare_batch(batch, non_blocking=True):
        image, target = batch
        return (convert_tensor(image, device=device, non_blocking=non_blocking),
                convert_tensor(target, device=device, non_blocking=non_blocking))

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            image, target = _prepare_batch(batch)
            pred = model(image)
            return pred, target

    evaluator = Engine(_inference)

    cm = ConfusionMatrix(num_classes)
    IoU(cm).attach(evaluator, 'IoU')
    Loss(criterion).attach(evaluator, 'loss')

    pbar = ProgressBar(persist=True, desc='Eval')
    pbar.attach(evaluator)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        metrics = engine.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        pbar.log_message("Validation results:\nLoss: {:.2e}\nmIoU: {:.1f}"
                         .format(loss, mean_iou))

    print("Start validation")
    evaluator.run(val_loader, max_epochs=1)
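
# A sketch of how this evaluation routine might be launched standalone. The flag
# names mirror the attributes read from `args` in run(); the defaults and the
# argparse wiring are illustrative assumptions, not the repository's actual CLI:
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate GoogLeNetFCN on Cityscapes')
    parser.add_argument('--dataset-dir', required=True, help='root of the Cityscapes dataset')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--checkpoint', default=None,
                        help='optional state_dict to load instead of the pretrained weights')
    run(parser.parse_args())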
def get_data_loaders(data_dir, batch_size, val_batch_size, num_workers):
    # Joint transforms apply the same random parameters to the image and its
    # label mask; fillcolor=255 pads the warped mask with the loss's ignore index.
    joint_transforms = Compose([
        RandomHorizontalFlip(),
        RandomAffine(scale=(0.9, 1.6), shear=(-15, 15), fillcolor=255),
        ColorJitter(0.3, 0.3, 0.3),
        # RandomGaussionBlur(sigma=(0, 1.2)),
        ToTensor()
    ])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_loader = DataLoader(CityscapesDataset(root=data_dir, split='train',
                                                joint_transform=joint_transforms,
                                                img_transform=normalize),
                              batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)

    val_loader = DataLoader(CityscapesDataset(root=data_dir, split='val',
                                              joint_transform=ToTensor(),
                                              img_transform=normalize),
                            batch_size=val_batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader
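
# The joint pipeline above must draw one set of random parameters and apply it to
# both the image and the mask so they stay aligned. A minimal sketch of one such
# transform, assuming the (img, target) calling convention used by
# ConvertIdToTrainId below; an illustration, not the repository's implementation:
import random

from torchvision.transforms import functional as TF


class JointRandomHorizontalFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        # One coin flip, applied to both inputs (PIL images at this stage,
        # since this runs before ToTensor in the pipeline above).
        if random.random() < self.p:
            return TF.hflip(img), TF.hflip(target)
        return img, target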
def get_data_loaders(data_dir, batch_size, num_workers):
    val_transforms = Compose([
        ToTensor(),
        ConvertIdToTrainId(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_loader = DataLoader(CityscapesDataset(root=data_dir, split='val', transforms=val_transforms),
                            batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)

    return val_loader
def run(args):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    num_classes = CityscapesDataset.num_classes()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = GoogLeNetFCN(num_classes)
    model.init_from_googlenet()

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    train_loader, val_loader = get_data_loaders(args.dataset_dir, args.batch_size,
                                                args.val_batch_size, args.num_workers,
                                                args.include_coarse)

    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='sum')

    # Biases get twice the learning rate and no weight decay, a common FCN setup.
    optimizer = optim.SGD([
        {'params': [param for name, param in model.named_parameters() if name.endswith('weight')]},
        {'params': [param for name, param in model.named_parameters() if name.endswith('bias')],
         'lr': args.lr * 2, 'weight_decay': 0}
    ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint.get('bestIoU', 0.0)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.exit()

    if args.freeze_bn:
        print("Freezing batch norm")
        model = freeze_batchnorm(model)

    trainer = create_supervised_trainer(model, optimizer, criterion, device, non_blocking=True)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    cm = ConfusionMatrix(num_classes)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'loss': Loss(criterion),
                                                     'IoU': IoU(cm),
                                                     'accuracy': cmAccuracy(cm)},
                                            device=device, non_blocking=True)

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        # Log validation metrics against the trainer's global step.
        return trainer.state.iteration

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training', metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag='validation',
                                               metric_names=['loss', 'IoU', 'accuracy'],
                                               global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        is_best = mean_iou.item() > trainer.state.best_iou
        trainer.state.best_iou = max(mean_iou.item(), trainer.state.best_iou)

        name = 'epoch{}_mIoU={:.1f}.pth'.format(trainer.state.epoch, mean_iou)
        file = {'model': model.state_dict(),
                'epoch': trainer.state.epoch,
                # store the trainer's step, not the evaluator's
                'iteration': trainer.state.iteration,
                'optimizer': optimizer.state_dict(),
                'args': args,
                'bestIoU': trainer.state.best_iou}
        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        if is_best:
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch
            engine.state.iteration = args.start_epoch * len(engine.state.dataloader)
            engine.state.best_iou = best_iou
        else:
            engine.state.best_iou = 0.0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(engine.state.epoch,
                                                                    engine.state.max_epochs))
        evaluator.run(val_loader)

        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU']
        acc = metrics['accuracy']
        mean_iou = iou.mean()

        pbar.log_message("Validation results - Epoch: [{}/{}]: Loss: {:.2e}, Accuracy: {:.1f}, mIoU: {:.1f}"
                         .format(engine.state.epoch, engine.state.max_epochs,
                                 loss, acc * 100.0, mean_iou * 100.0))

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)

    tb_logger.close()
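
# freeze_batchnorm() and save() are used above but defined elsewhere in the
# repository. Minimal sketches of plausible implementations under standard
# PyTorch semantics (both are assumptions, not the actual utility code):
import os

import torch
import torch.nn as nn


def freeze_batchnorm(model):
    # Put every BatchNorm layer in eval mode and stop gradient updates to its
    # affine parameters. Note: the Ignite trainer calls model.train() each
    # update, which would re-enable training mode, so a fuller implementation
    # would also override train() on these modules.
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
    return model


def save(obj, output_dir, name):
    # Persist a checkpoint dict or state_dict under the output directory.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(obj, os.path.join(output_dir, name))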
def run(args):
    train_loader, val_loader = get_data_loaders(args.dataset_dir, args.batch_size,
                                                args.val_batch_size, args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    num_classes = CityscapesDataset.num_classes()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = GoogLeNetFCN(num_classes)
    model.init_from_googlenet()

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=255)

    # named_parameters() yields (name, param) pairs; biases get twice the
    # learning rate and no weight decay.
    optimizer = optim.SGD([{'params': [p for name, p in model.named_parameters() if name[-4:] != 'bias'],
                            'lr': args.lr, 'weight_decay': 5e-4},
                           {'params': [p for name, p in model.named_parameters() if name[-4:] == 'bias'],
                            'lr': args.lr * 2}],
                          momentum=args.momentum, lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    trainer = create_supervised_trainer(model, optimizer, criterion, device, non_blocking=True)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    cm = ConfusionMatrix(num_classes)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'loss': Loss(criterion),
                                                     'IoU': IoU(cm, ignore_index=0)},
                                            device=device, non_blocking=True)

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        return trainer.state.iteration

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training', metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag='validation', metric_names=['loss', 'IoU'],
                                               global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()
        name = 'epoch{}_mIoU={:.1f}.pth'.format(trainer.state.epoch, mean_iou)
        file = {'model': model.state_dict(), 'epoch': trainer.state.epoch,
                'optimizer': optimizer.state_dict(), 'args': args}
        torch.save(file, os.path.join(args.output_dir, 'checkpoint_{}'.format(name)))
        torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_{}'.format(name)))

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.log_message('Start Validation - Epoch: [{}/{}]'.format(engine.state.epoch,
                                                                    engine.state.max_epochs))
        evaluator.run(val_loader)

        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU']
        mean_iou = iou.mean()

        pbar.log_message('Validation results - Epoch: [{}/{}]: Loss: {:.2e}, mIoU: {:.1f}'
                         .format(engine.state.epoch, engine.state.max_epochs, loss, mean_iou * 100.0))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        engine.state.exception_raised = True
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            # Save an emergency checkpoint before shutting down.
            name = 'epoch{}_exception.pth'.format(trainer.state.epoch)
            file = {'model': model.state_dict(), 'epoch': trainer.state.epoch,
                    'optimizer': optimizer.state_dict()}
            torch.save(file, os.path.join(args.output_dir, 'checkpoint_{}'.format(name)))
            torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_{}'.format(name)))
        else:
            raise e

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)

    tb_logger.close()
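
# A sketch of how this training script might be launched standalone. The flag
# names mirror the attributes read from `args` above; the defaults and the
# argparse wiring are illustrative assumptions, not the repository's actual CLI:
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train GoogLeNetFCN on Cityscapes')
    parser.add_argument('--dataset-dir', required=True, help='root of the Cityscapes dataset')
    parser.add_argument('--output-dir', default='checkpoints')
    parser.add_argument('--log-dir', default='logs')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--val-batch-size', type=int, default=4)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--resume', default=None, help='checkpoint to resume from')
    parser.add_argument('--start-epoch', type=int, default=0)
    run(parser.parse_args())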
class ConvertIdToTrainId:
    """Transform used in the pipelines above: remaps raw Cityscapes label ids to train ids."""

    def __call__(self, img, target):
        target = CityscapesDataset.convert_id_to_train_id(target)
        return img, target
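
# CityscapesDataset.convert_id_to_train_id is defined elsewhere in the
# repository. A common way to implement such a remapping is a lookup table
# indexed by the raw label id; the sketch below assumes `target` is an integer
# tensor of raw Cityscapes ids and 255 is the ignore index (an illustration,
# not the dataset class's actual code):
import torch


def convert_id_to_train_id_sketch(target, id_to_train_id):
    # id_to_train_id: 1-D LongTensor of length 256 mapping raw id -> train id,
    # with unused ids mapped to 255 so the loss ignores them.
    return id_to_train_id[target.long()]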