Example #1
def _create_dummy_weight(category_name, model_name, workdir):
    # Create dummy weight
    category_type = surface_types.from_string(category_name)
    net = models.load_model(model_name, category_type)
    weight_path = workdir / f'{model_name}_{category_name}_dummy.pth'
    torch.save(net.state_dict(), weight_path)
    return weight_path
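
This helper looks like a test utility: it builds the requested model only to serialize its untrained state_dict as a placeholder checkpoint. A minimal sketch of how a test might exercise it, assuming pytest's tmp_path fixture and that _create_dummy_weight, models and surface_types are importable from the project (their import paths are not shown in these snippets):

import torch

def test_dummy_weight_roundtrip(tmp_path):
    # Write a placeholder checkpoint for an untrained unet11 binary model...
    weight_path = _create_dummy_weight('binary', 'unet11', tmp_path)
    assert weight_path.exists()

    # ...then confirm it loads back into a freshly built network of the same architecture.
    net = models.load_model('unet11', surface_types.from_string('binary'))
    net.load_state_dict(torch.load(weight_path, map_location='cpu'))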
Example #2
    def _create_dummy_weight(self, category_name):
        # Create dummy weight
        category_type = surface_types.from_string(category_name)
        net = models.load_model(self.model_name, category_type)
        weight_path = self.workdir / 'dummy.pth'
        torch.save(net.state_dict(), weight_path)
        return weight_path
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight-path', required=True)
    parser.add_argument('--image-dirs', required=True, type=str, nargs='+')
    parser.add_argument('--mask-dirs', required=True, type=str, nargs='+')
    parser.add_argument('--model-name', type=str, default='unet11')
    parser.add_argument('--dataset-type', type=str, default='base')
    parser.add_argument('--save-path', default='forward')
    parser.add_argument('--category-type',
                        default='binary',
                        choices=['binary', 'simple'])
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--device-id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--input-size', type=int, nargs=2, default=(640, 640))
    parser.add_argument('--jaccard-weight', type=float, default=0.3)

    args = parser.parse_args()
    print(args)

    image_dirs = [Path(p) for p in args.image_dirs]
    mask_dirs = [Path(p) for p in args.mask_dirs]
    for data_dir in (image_dirs + mask_dirs):
        assert data_dir.exists(), f'{str(data_dir)} does not exist.'

    device = torch_tools.get_device(args.cpu, args.device_id)
    torch_tools.set_seeds(args.seed, device)

    weight_path = Path(args.weight_path)

    category_type = datasets.surface_types.from_string(args.category_type)

    save_path = Path(args.save_path)
    if not save_path.exists():
        save_path.mkdir(parents=True)

    net = models.load_model(args.model_name, category_type).to(device)
    state_dict = torch.load(weight_path, map_location=device)
    net.load_state_dict(state_dict=state_dict)

    input_size = args.input_size

    transform = Compose([
        CenterCrop(*input_size),
    ])

    dataset = datasets.create_dataset(
        args.dataset_type,
        image_dirs,
        mask_dirs,
        category_type,
        transform,
    )

    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    criterion = models.loss.get_criterion(category_type, args.jaccard_weight)
    evaluate(net, loader, criterion, device, save_path, category_type)
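
The Compose / CenterCrop calls above match albumentations' API (an assumption based on the names). A short sketch of how that center-crop transform behaves on a dummy image/mask pair:

import numpy as np
from albumentations import CenterCrop, Compose

transform = Compose([
    CenterCrop(640, 640),
])

image = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy frame
mask = np.zeros((720, 1280), dtype=np.uint8)      # dummy label map

out = transform(image=image, mask=mask)
print(out['image'].shape, out['mask'].shape)  # (640, 640, 3) (640, 640)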
Example #4
    def __init__(self, weight_path: Path, device) -> None:
        self.device = device

        self.net = models.load_model(_MODEL_NAME,
                                     surface_types.from_string('simple')).to(
                                         self.device)

        print(f'Loading {str(weight_path)}')
        state_dict = torch.load(weight_path, map_location=self.device)
        self.net.load_state_dict(state_dict=state_dict)
        print('Done')
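
For context, the same loading pattern can be written as a standalone function. This is only a sketch: the _MODEL_NAME value, the checkpoint path, and the function name are assumptions (the original keeps _MODEL_NAME as a module-level constant that is not shown here), while models and surface_types are the project modules referenced above.

import torch
from pathlib import Path

_MODEL_NAME = 'unet11'  # assumed value

def load_simple_segmentation_net(weight_path: Path, device: torch.device):
    # Build the architecture for the 'simple' category set, then restore trained weights.
    net = models.load_model(_MODEL_NAME, surface_types.from_string('simple')).to(device)
    net.load_state_dict(torch.load(weight_path, map_location=device))
    net.eval()  # inference mode
    return net

net = load_simple_segmentation_net(Path('runs/simple.pth'), torch.device('cpu'))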
Example #5
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--train-image-dirs', required=True, nargs='+')
    parser.add_argument('--train-mask-dirs', required=True, nargs='+')
    parser.add_argument('--train-dataset-types', nargs='+')

    parser.add_argument('--validation-image-dirs', required=True, nargs='+')
    parser.add_argument('--validation-mask-dirs', required=True, nargs='+')
    parser.add_argument('--validation-dataset-types', nargs='+')

    parser.add_argument('--save-dir', default='./runs')
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--category-type', default='binary', choices=['binary', 'simple'])
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--input-size', type=int, nargs=2, default=(640, 640))
    parser.add_argument('--jaccard-weight', type=float, default=0.3)

    available_networks = ['unet11', 'unet16']
    parser.add_argument('--model-name', type=str, default='unet11', choices=available_networks)

    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--device-id', type=int, default=0)
    parser.add_argument('--run-name', type=str, required=True)

    args = parser.parse_args()
    print(args)

    device = torch_tools.get_device(args.cpu, args.device_id)
    torch_tools.set_seeds(args.seed, device)

    train_image_dirs = [Path(p) for p in args.train_image_dirs]
    train_mask_dirs = [Path(p) for p in args.train_mask_dirs]
    validation_image_dirs = [Path(p) for p in args.validation_image_dirs]
    validation_mask_dirs = [Path(p) for p in args.validation_mask_dirs]

    for data_dir in (train_image_dirs + train_mask_dirs +
                     validation_image_dirs + validation_mask_dirs):
        assert data_dir.exists(), f'{str(data_dir)} does not exist.'

    input_size = args.input_size
    w, h = input_size
    rate = 0.9

    '''
    train_transform = Compose([
        HorizontalFlip(p=0.5),
        RandomCrop(*input_size),
    ])
    '''
    # Transforms
    transforms = {}

    # Basic transform
    transforms['base'] = Compose([
        HorizontalFlip(p=0.5),
        #IAAPerspective(scale=(0.05, 0.1), p=0.3),
        Rotate(5, p=0.5),
        RandomGamma(p=0.5),
        HueSaturationValue(
            hue_shift_limit=10,
            sat_shift_limit=15,
            val_shift_limit=10,
            p=0.5
        ),
        RandomBrightnessContrast(p=0.5),
        OneOf([
            RandomSizedCrop((int(h * rate), int(w * rate)), h, w, p=1.0),
            RandomCrop(h, w, p=1.0),
        ], p=1.0)
    ])

    # BDD dataset
    transforms['bdd'] = transforms['base']

    # Always shrink to 22% - 50%
    transforms['walk'] = Compose([
        HorizontalFlip(p=0.5),
        Rotate(5, p=0.5),
        RandomGamma(p=0.5),
        HueSaturationValue(
            hue_shift_limit=10,
            sat_shift_limit=15,
            val_shift_limit=10,
            p=0.5
        ),
        RandomBrightnessContrast(p=0.5),
        RandomScale((-0.78, -0.5), p=1.0),
        RandomCrop(h, w, p=1.0),
    ])

    # MISC dataset transform
    transforms['misc'] = Compose([
        HorizontalFlip(p=0.5),
        Rotate(5, p=0.5),
        RandomGamma(p=0.5),
        HueSaturationValue(
            hue_shift_limit=10,
            sat_shift_limit=15,
            val_shift_limit=10,
            p=0.5
        ),
        RandomBrightnessContrast(p=0.5),
        Resize(h, w, p=1.0),
    ])

    validation_transform = Compose([
        CenterCrop(h, w),
    ])

    category_type = datasets.surface_types.from_string(args.category_type)

    # Logger
    log_dir = _get_log_dir(args)
    logger = logging.Logger(log_dir, n_save=16, image_size=256, category_type=category_type)
    logger.writer.add_text('args', str(args))

    train_datasets = []
    for image_dir, mask_dir, dataset_type in zip(train_image_dirs, train_mask_dirs, args.train_dataset_types):
        _dataset = datasets.create_dataset(
            dataset_type,
            [image_dir],
            [mask_dir],
            category_type,
            transforms[dataset_type],
        )
        train_datasets.append(_dataset)

    validation_datasets = []
    for image_dir, mask_dir, dataset_type in zip(validation_image_dirs, validation_mask_dirs, args.validation_dataset_types):
        _dataset = datasets.create_dataset(
            dataset_type,
            [image_dir],
            [mask_dir],
            category_type,
            validation_transform,
        )
        validation_datasets.append(_dataset)

    # Merge datasets
    train_dataset = ConcatDataset(train_datasets)
    validation_dataset = ConcatDataset(validation_datasets)

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    validation_loader = DataLoader(validation_dataset, args.batch_size, shuffle=False)
    net = models.load_model(args.model_name, category_type).to(device)

    criterion = models.loss.get_criterion(category_type, args.jaccard_weight)
    optimizer = torch.optim.Adam(net.parameters())

    for epoch in range(1, args.epochs + 1):
        print(f'epoch: {epoch:03d}')
        sys.stdout.flush()
        train(net, train_loader, epoch, optimizer, criterion, device, logger)
        evaluate(net, validation_loader, epoch, criterion, device, logger, 'validation')
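
The script merges several per-domain datasets, each built with its own augmentation pipeline, into a single loader via ConcatDataset. A toy sketch of that merging pattern, using TensorDataset as a stand-in for the project's datasets.create_dataset output:

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Stand-ins for two per-domain datasets ('base' and 'walk' style sources).
base_ds = TensorDataset(torch.randn(8, 3, 64, 64), torch.zeros(8, dtype=torch.long))
walk_ds = TensorDataset(torch.randn(4, 3, 64, 64), torch.ones(4, dtype=torch.long))

merged = ConcatDataset([base_ds, walk_ds])  # 12 samples total
loader = DataLoader(merged, batch_size=4, shuffle=True)

for images, labels in loader:
    print(images.shape, labels.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4])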