Example #1
def get_data_loaders(data_dir, batch_size, val_batch_size, num_workers,
                     include_coarse):
    transform = Compose([
        RandomHorizontalFlip(),
        RandomAffine(translate=(0.1, 0.1), scale=(0.7, 2.0), shear=(-10, 10)),
        RandomGaussionBlur(radius=2.0),
        ColorJitter(0.1, 0.1, 0.1, 0.1),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomGaussionNoise(),
        ConvertIdToTrainId()
    ])

    val_transform = Compose([
        ToTensor(),
        ConvertIdToTrainId(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    fine = CityscapesDataset(root=data_dir,
                             split='train',
                             mode='fine',
                             transforms=transform)

    if include_coarse:
        coarse = CityscapesDataset(root=data_dir,
                                   split='train_extra',
                                   mode='coarse',
                                   transforms=transform)
        train_loader = data.DataLoader(FineCoarseDataset(fine, coarse),
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers,
                                       pin_memory=True)
    else:
        train_loader = data.DataLoader(fine,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers,
                                       pin_memory=True)

    val_loader = data.DataLoader(CityscapesDataset(root=data_dir,
                                                   split='val',
                                                   transforms=val_transform),
                                 batch_size=val_batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=True)

    return train_loader, val_loader
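
FineCoarseDataset is not defined in this snippet. A minimal sketch of what it plausibly is, assuming it simply chains the fine- and coarse-annotated sets so a single loader can sample from both (the class name comes from the example above; the body is an assumption):

from torch.utils import data

class FineCoarseDataset(data.Dataset):
    """Chains a fine- and a coarse-annotated dataset into one (assumed behavior)."""

    def __init__(self, fine, coarse):
        self.fine = fine
        self.coarse = coarse

    def __len__(self):
        return len(self.fine) + len(self.coarse)

    def __getitem__(self, idx):
        # Indices below len(fine) hit the fine set; the rest map into coarse.
        if idx < len(self.fine):
            return self.fine[idx]
        return self.coarse[idx - len(self.fine)]

torch.utils.data.ConcatDataset would do the same job if no extra bookkeeping is needed.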
Example #2
def run(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = CityscapesDataset.num_classes()
    if args.checkpoint:
        model = GoogLeNetFCN(num_classes)
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    else:
        model = googlenet_fcn(pretrained=True, num_classes=num_classes)

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size

    val_loader = get_data_loaders(args.dataset_dir, args.batch_size, args.num_workers)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=255)

    def _prepare_batch(batch, non_blocking=True):
        image, target = batch

        return (convert_tensor(image, device=device, non_blocking=non_blocking),
                convert_tensor(target, device=device, non_blocking=non_blocking))

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            image, target = _prepare_batch(batch)
            pred = model(image)

            return pred, target

    evaluator = Engine(_inference)
    cm = ConfusionMatrix(num_classes)
    IoU(cm).attach(evaluator, 'IoU')
    Loss(criterion).attach(evaluator, 'loss')

    pbar = ProgressBar(persist=True, desc='Eval')
    pbar.attach(evaluator)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        metrics = engine.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        pbar.log_message("Validation results:\nLoss: {:.2e}\nmIoU: {:.1f}"
                         .format(loss, mean_iou))

    print("Start validation")
    evaluator.run(val_loader, max_epochs=1)
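
These run() functions are presumably driven by an argparse CLI. A minimal sketch of an entry point for Example #2 (flag names mirror the args attributes used above; the defaults are assumptions):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate GoogLeNetFCN on Cityscapes')
    parser.add_argument('--dataset-dir', required=True)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--checkpoint', default=None, help='optional state_dict to evaluate')
    run(parser.parse_args())

argparse converts the dashes to underscores, so --dataset-dir arrives as args.dataset_dir.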
Example #3
def get_data_loaders(data_dir, batch_size, val_batch_size, num_workers):
    joint_transforms = Compose([
        RandomHorizontalFlip(),
        RandomAffine(scale=(0.9, 1.6), shear=(-15, 15), fillcolor=255),
        ColorJitter(0.3, 0.3, 0.3),
        # RandomGaussionBlur(sigma=(0, 1.2)),
        ToTensor()
    ])

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_loader = DataLoader(CityscapesDataset(root=data_dir, split='train', joint_transform=joint_transforms,
                                                img_transform=normalize),
                              batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

    val_loader = DataLoader(CityscapesDataset(root=data_dir, split='val', joint_transform=ToTensor(),
                                              img_transform=normalize),
                            batch_size=val_batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader
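
A joint_transform must draw its random parameters once and apply them to both the image and the label mask, otherwise pixels and labels drift apart. A minimal sketch of a joint RandomHorizontalFlip in that style (the two-argument __call__ matches Example #7; this body is an assumption, not the repo's implementation):

import random
import torchvision.transforms.functional as F

class RandomHorizontalFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        # One coin flip, applied to image and mask together.
        if random.random() < self.p:
            return F.hflip(img), F.hflip(target)
        return img, target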
Example #4
def get_data_loaders(data_dir, batch_size, num_workers):
    val_transforms = Compose([
        ToTensor(),
        ConvertIdToTrainId(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_loader = DataLoader(CityscapesDataset(root=data_dir, split='val', transforms=val_transforms),
                            batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return val_loader
Example #5
def run(args):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    num_classes = CityscapesDataset.num_classes()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = GoogLeNetFCN(num_classes)
    model.init_from_googlenet()

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    train_loader, val_loader = get_data_loaders(args.dataset_dir,
                                                args.batch_size,
                                                args.val_batch_size,
                                                args.num_workers,
                                                args.include_coarse)

    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='sum')

    # Bias parameters get twice the learning rate and no weight decay,
    # the usual FCN training recipe; weights use the optimizer defaults below.
    optimizer = optim.SGD(
        [{'params': [param for name, param in model.named_parameters()
                     if name.endswith('weight')]},
         {'params': [param for name, param in model.named_parameters()
                     if name.endswith('bias')],
          'lr': args.lr * 2,
          'weight_decay': 0}],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint.get('bestIoU', 0.0)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.exit()

    if args.freeze_bn:
        print("Freezing batch norm")
        model = freeze_batchnorm(model)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device,
                                        non_blocking=True)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    cm = ConfusionMatrix(num_classes)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'loss': Loss(criterion),
                                                'IoU': IoU(cm),
                                                'accuracy': cmAccuracy(cm)
                                            },
                                            device=device,
                                            non_blocking=True)

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        return trainer.state.iteration

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training',
                                               metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)

    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(
                         tag='validation',
                         metric_names=['loss', 'IoU', 'accuracy'],
                         global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        is_best = mean_iou.item() > trainer.state.best_iou
        trainer.state.best_iou = max(mean_iou.item(), trainer.state.best_iou)

        name = 'epoch{}_mIoU={:.1f}.pth'.format(trainer.state.epoch, mean_iou)
        file = {
            'model': model.state_dict(),
            'epoch': trainer.state.epoch,
            'iteration': engine.state.iteration,
            'optimizer': optimizer.state_dict(),
            'args': args,
            'bestIoU': trainer.state.best_iou
        }

        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        if is_best:
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch
            engine.state.iteration = args.start_epoch * len(
                engine.state.dataloader)
            engine.state.best_iou = best_iou
        else:
            engine.state.best_iou = 0.0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(
            engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU']
        acc = metrics['accuracy']
        mean_iou = iou.mean()

        pbar.log_message(
            "Validation results - Epoch: [{}/{}]: Loss: {:.2e}, Accuracy: {:.1f}, mIoU: {:.1f}"
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    acc * 100.0, mean_iou * 100.0))

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
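
Example #5 calls two helpers that are not shown, freeze_batchnorm and save. A minimal sketch of both, assuming freeze_batchnorm puts BatchNorm layers in eval mode and stops gradient updates to their affine parameters, and save wraps torch.save (both bodies are assumptions):

import os
import torch
from torch import nn

def freeze_batchnorm(model):
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            # Fix running statistics and exclude affine params from training.
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
    return model

def save(obj, output_dir, filename):
    # Persist a checkpoint dict or state_dict under output_dir.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(obj, os.path.join(output_dir, filename))

One caveat: ignite's supervised trainer calls model.train() on every iteration, which flips the frozen layers back to train mode, so the real helper likely also re-applies eval() from an event handler.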
Example #6
def run(args):
    train_loader, val_loader = get_data_loaders(args.dataset_dir, args.batch_size, args.val_batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    num_classes = CityscapesDataset.num_classes()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = GoogLeNetFCN(num_classes)
    model.init_from_googlenet()

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=255)

    # named_parameters() yields (name, param) pairs; bias params get
    # twice the learning rate and no weight decay.
    optimizer = optim.SGD([{'params': [p for name, p in model.named_parameters() if not name.endswith('bias')],
                            'lr': args.lr, 'weight_decay': 5e-4},
                           {'params': [p for name, p in model.named_parameters() if name.endswith('bias')],
                            'lr': args.lr * 2}], momentum=args.momentum, lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    trainer = create_supervised_trainer(model, optimizer, criterion, device, non_blocking=True)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    cm = ConfusionMatrix(num_classes)
    evaluator = create_supervised_evaluator(model, metrics={'loss': Loss(criterion),
                                                            'IoU': IoU(cm, ignore_index=0)},
                                            device=device, non_blocking=True)

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        return trainer.state.iteration

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training',
                                               metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag='validation',
                                               metric_names=['loss', 'IoU'],
                                               global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        name = 'epoch{}_mIoU={:.1f}.pth'.format(trainer.state.epoch, mean_iou)
        file = {'model': model.state_dict(), 'epoch': trainer.state.epoch,
                'optimizer': optimizer.state_dict(), 'args': args}

        torch.save(file, os.path.join(args.output_dir, 'checkpoint_{}'.format(name)))
        torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_{}'.format(name)))

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.log_message('Start Validation - Epoch: [{}/{}]'.format(engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU']
        mean_iou = iou.mean()

        pbar.log_message('Validation results - Epoch: [{}/{}]: Loss: {:.2e}, mIoU: {:.1f}'
                         .format(engine.state.epoch, engine.state.max_epochs, loss, mean_iou * 100.0))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        engine.state.exception_raised = True
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            name = 'epoch{}_exception.pth'.format(trainer.state.epoch)
            file = {'model': model.state_dict(), 'epoch': trainer.state.epoch,
                    'optimizer': optimizer.state_dict()}

            torch.save(file, os.path.join(args.output_dir, 'checkpoint_{}'.format(name)))
            torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_{}'.format(name)))
        else:
            raise e

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
Example #7
    def __call__(self, img, target):
        # Remap the raw Cityscapes ids in the label mask to train ids.
        target = CityscapesDataset.convert_id_to_train_id(target)

        return img, target
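
For reference, convert_id_to_train_id maps the raw Cityscapes label ids onto the 19 training ids and collapses everything else to the ignore index 255, which matches ignore_index=255 in the losses above. A minimal sketch built on the label table from cityscapesscripts (the tensor-based implementation is an assumption):

import torch
from cityscapesscripts.helpers.labels import labels

# Lookup table: raw label id -> train id; ids not used for training become 255.
_ID_TO_TRAIN_ID = torch.full((256,), 255, dtype=torch.long)
for label in labels:
    if label.id >= 0 and 0 <= label.trainId < 255:
        _ID_TO_TRAIN_ID[label.id] = label.trainId

def convert_id_to_train_id(target):
    # target is an integer label map; fancy indexing remaps every pixel at once.
    return _ID_TO_TRAIN_ID[target.long()]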