def load_unlabeled_data(tar, data_dir='dataset/'):
    """Build a weakly- and a strongly-augmented dataset over one image folder.

    Both datasets read ``<data_dir>/<tar>/images`` and differ only in the
    transform pipeline that is applied to every image.

    Parameters
    ----------
    tar : str
        Name of the sub-directory of ``data_dir`` that holds the ``images``
        folder.
    data_dir : str
        Root directory of the datasets. A trailing path separator is
        optional.

    Returns
    -------
    tuple
        ``(unlabeled_weak_data, unlabeled_strong_data)``: two
        ``datasets.ImageFolder`` instances rooted at the same folder.
    """
    import os

    from torchvision import transforms

    # Join the components instead of concatenating strings, so that
    # data_dir='dataset' and data_dir='dataset/' resolve to the same folder
    # (this also matches how load_data builds its paths).
    folder_tar = os.path.join(data_dir, tar, 'images')

    # Channel statistics shared by both pipelines.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform = {
        # Weak view: deterministic resize only.
        'weak':
        transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ]),
        # Strong view: random crop + flip + RandAugment.
        'strong':
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            # Paper: RandAugment: Practical data augmentation with no
            # separate search. Applied before ToTensor().
            RandAugment(),
            transforms.ToTensor(),
            # Cutout(size=16), # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            transforms.Normalize(mean=mean, std=std)
        ])
    }

    unlabeled_weak_data = datasets.ImageFolder(root=folder_tar,
                                               transform=transform['weak'])
    unlabeled_strong_data = datasets.ImageFolder(root=folder_tar,
                                                 transform=transform['strong'])

    return unlabeled_weak_data, unlabeled_strong_data
def remove_background(input_image):
    """Return ``input_image`` with every pixel outside the segmentation mask zeroed.

    The image is run through a pretrained ``fcn_resnet101`` segmentation
    network, the per-pixel arg-max label map is turned into a mask by
    ``segmentation_mask`` (defined elsewhere; according to the original
    comments it keeps the person), and the mask is applied with
    ``cv2.bitwise_and``.

    Parameters
    ----------
    input_image
        Image accepted by both ``cvtransforms.ToTensor`` and
        ``cv2.bitwise_and``.
        NOTE(review): presumably an HxWx3 uint8 array as returned by
        ``cv2.imread``. That would be BGR, while the normalization constants
        below are usually quoted for RGB input; confirm the channel order.

    Returns
    -------
    The masked image produced by ``cv2.bitwise_and``.

    Notes
    -----
    The network is constructed on every call, which is expensive; callers
    that process many images may want to cache it.
    """
    # SECURITY NOTE: unless PYTHONHTTPSVERIFY is set, this disables TLS
    # certificate verification for the default HTTPS context of the WHOLE
    # process, not only for the weight download below. Kept because it is
    # existing, deliberate behavior (it was added to make the pretrained
    # weights downloadable).
    if (not os.environ.get('PYTHONHTTPSVERIFY', '')
            and getattr(ssl, '_create_unverified_context', None)):
        ssl._create_default_https_context = ssl._create_unverified_context

    # Pretrained segmentation network, switched to evaluation mode.
    fully_convolutional_network = models.segmentation.fcn_resnet101(
        pretrained=True).eval()

    # Convert the image to a normalized tensor and add the batch dimension.
    image_to_tensor_transform = cvtransforms.Compose([
        cvtransforms.ToTensor(),
        cvtransforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
    ])
    input_tensor = image_to_tensor_transform(input_image).unsqueeze(0)

    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        output_model = fully_convolutional_network(input_tensor)['out']

    # Drop only the batch axis (a bare squeeze() would also drop any other
    # size-1 axis), then pick the highest-scoring class for every pixel.
    labeled_image = torch.argmax(output_model.squeeze(0),
                                 dim=0).cpu().numpy()

    # Turn the label map into the mask that selects the pixels to keep.
    mask = segmentation_mask(labeled_image)

    # Keep the masked pixels, zero the rest.
    result_image = cv2.bitwise_and(input_image,
                                   input_image,
                                   mask=mask.astype(np.uint8))

    return result_image
Beispiel #3
0
def get_cpu_transforms(augs: DictConfig) -> dict:
    """Build the CPU-side augmentation pipelines described by ``augs``.

    Parameters
    ----------
    augs : DictConfig
        Augmentation section of the configuration. The fields ``crop_size``,
        ``resize`` and ``pad`` are read; a field that is ``None`` switches the
        corresponding step off.

    Returns
    -------
    xform : dict
        Maps ``'train'``, ``'val'`` and ``'test'`` to a composed pipeline
        callable, e.g. ``auged_images = xform['train'](images)``.
        ``'val'`` and ``'test'`` share the very same pipeline object.
    """
    train_ops, val_ops = [], []
    both = (train_ops, val_ops)

    # The sequence of steps is significant: crop -> resize -> pad -> transpose.
    crop = augs.crop_size
    if crop is not None:
        # Random crop while training, deterministic center crop otherwise.
        train_ops += [transforms.RandomCrop(crop)]
        val_ops += [transforms.CenterCrop(crop)]

    size = augs.resize
    if size is not None:
        for ops in both:
            ops.append(transforms.Resize(size))

    if augs.pad is not None:
        padding = tuple(augs.pad)
        for ops in both:
            ops.append(transforms.Pad(padding))

    # Every pipeline ends with a Transpose step.
    for ops in both:
        ops.append(Transpose())

    train_pipeline = transforms.Compose(train_ops)
    val_pipeline = transforms.Compose(val_ops)

    xform = dict(train=train_pipeline, val=val_pipeline, test=val_pipeline)
    log.debug('CPU transforms: {}'.format(xform))
    return xform
Beispiel #4
0
def get_image_from_url(url='', timeout=10):
    """Load an image from a URL or a local path and preprocess it.

    Parameters
    ----------
    url : str
        When ``validators.url(url)`` is truthy the image is downloaded,
        otherwise ``url`` is handed to ``Image.open`` as a local path.
    timeout : float or tuple
        Timeout in seconds forwarded to ``requests.get`` so that an
        unresponsive server cannot block the caller forever.

    Returns
    -------
    The output of the preprocessing pipeline: resize to 183 (the ceiling of
    160 / 0.875), center crop to 160, ``ToTensor`` and ``Normalize``.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    requests.RequestException
        On connection problems or when the timeout expires.
    """
    import io

    img_transforms = opencv_transform.Compose([
        opencv_transform.Resize(int(math.ceil(160 / 0.875)),
                                interpolation=cv2.INTER_LINEAR),
        opencv_transform.CenterCrop(160),
        opencv_transform.ToTensor(),
        opencv_transform.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225]),
    ])

    # Download the image from the web if url is a valid URL.
    if validators.url(url):
        response = requests.get(url, timeout=timeout)
        # Fail with a clear HTTP error instead of letting Image.open choke
        # on an HTML error page.
        response.raise_for_status()
        # response.content is fully read and content-decoded, unlike the
        # undecoded response.raw stream.
        img = Image.open(io.BytesIO(response.content))
    else:
        img = Image.open(url)

    # NOTE(review): Image.open returns a PIL image, while this pipeline is
    # named after OpenCV and receives cv2 interpolation flags. Confirm that
    # it accepts PIL input (or convert to a numpy array first).
    img = img_transforms(img)
    return img
Beispiel #5
0
# Module-level configuration for the Sketch2Color script.
# NOTE: img_size, transformer_dim and lr are used below but are not defined in
# this section; they must already exist at module level when this code runs.

# Loss functions
criterion_L1 = torch.nn.L1Loss() # L1 loss
criterion_L2 = torch.nn.MSELoss() # L2 loss
# criterion_TV = TVLoss() # Total variation loss (disabled)
# Lambda coefficients (presumably loss-term weights; they are not used in
# this section, so verify against the training code)
lambda1 = 100
lambda2 = 1e-4
lambda3 = 1e-2
# Beta1 / beta2 hyperparameters for the Adam optimizer created below
beta1 = 0.5
beta2 = 0.999

batch_size = 1
ncluster = 16  # forwarded to Sketch2Color as its ncluster argument
# configs ######################################
# Test-time transform only resizes; per the note below, ToTensor and
# Normalize are applied in the dataloader.
test_transform = TF.Compose([ # ToTensor and Normalize in dataloader
    TF.Resize(256)
    ])

# The model must be built before the optimizer, which captures its parameters.
model = Sketch2Color(img_size=img_size, trs_dim=transformer_dim, ncluster=ncluster)
optimizer_G = torch.optim.Adam(model.parameters(),lr=lr, betas=(beta1, beta2))

# Upsampling module with scale factor 16 (not used in this section).
m = torch.nn.Upsample(scale_factor=16)

def load(model, optimizer_G):
    """Restore the weights of ``model`` from the fixed edge2color checkpoint.

    The checkpoint file ``checkpoint/edge2color/save/e40_n16_tf_plt.pth``
    (relative to the current working directory) is loaded onto the
    module-level ``device`` and its ``'Sketch2Color'`` entry is applied to
    ``model`` with ``strict=True``.

    Parameters
    ----------
    model
        Module whose ``load_state_dict`` receives the stored weights.
    optimizer_G
        Accepted for interface compatibility; currently unused, the optimizer
        state is not restored.

    Raises
    ------
    AssertionError
        If no ``checkpoint`` directory exists in the working directory.
    """
    print('Loading...', end=' ')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'

    # Build the path from its components. The previous literal mixed
    # backslashes and a forward slash, which only resolved on Windows and
    # relied on the invalid escape sequences "\e" and "\s".
    checkpoint_path = os.path.join('.', 'checkpoint', 'edge2color', 'save',
                                   'e40_n16_tf_plt.pth')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['Sketch2Color'], strict=True)
def load_data(src, tar, data_dir='dataset', use_cv2=False):
    """Create the source / target image-folder datasets.

    Parameters
    ----------
    src, tar : str
        Sub-directory names below ``data_dir``; the images are read from
        ``<data_dir>/<src>/images`` and ``<data_dir>/<tar>/images``.
    data_dir : str
        Root directory of the datasets.
    use_cv2 : bool
        When true, images are read with ``cv2.imread`` and transformed with
        ``opencv_transforms`` (bilinear resize); otherwise the default
        ``ImageFolder`` loader and ``torchvision.transforms`` are used.

    Returns
    -------
    tuple
        ``(source_data, target_train_data, target_test_data)``. The first two
        share the training pipeline (resize to 256x256, random 224 crop,
        random horizontal flip); the third uses the test pipeline (resize to
        224x224). Both pipelines end with ``ToTensor`` and ``Normalize``.
    """
    folder_src = os.path.join(data_dir, src, 'images')
    folder_tar = os.path.join(data_dir, tar, 'images')

    # Select the transform backend and the extra keyword arguments that only
    # the OpenCV variant needs.
    if use_cv2:
        import cv2
        from opencv_transforms import transforms

        def loader_opencv(path: str) -> np.ndarray:
            return cv2.imread(path)

        resize_options = {'interpolation': cv2.INTER_LINEAR}
        folder_options = {'loader': loader_opencv}
    else:
        from torchvision import transforms

        resize_options = {}
        folder_options = {}

    def tensor_tail():
        # Fresh ToTensor/Normalize instances for the end of each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]

    train_pipeline = transforms.Compose([
        transforms.Resize((256, 256), **resize_options),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        *tensor_tail(),
    ])
    test_pipeline = transforms.Compose([
        transforms.Resize((224, 224), **resize_options),
        *tensor_tail(),
    ])

    # DataLoader construction is left to the caller.
    source_data = datasets.ImageFolder(root=folder_src,
                                       transform=train_pipeline,
                                       **folder_options)
    target_train_data = datasets.ImageFolder(root=folder_tar,
                                             transform=train_pipeline,
                                             **folder_options)
    target_test_data = datasets.ImageFolder(root=folder_tar,
                                            transform=test_pipeline,
                                            **folder_options)

    return source_data, target_train_data, target_test_data
# Script: segment 'test_images/input.jpg' with a pretrained FCN-ResNet101 and
# build the mask used to remove its background.

# SECURITY NOTE: unless PYTHONHTTPSVERIFY is set, this disables TLS
# certificate verification for the default HTTPS context of the whole
# process. Kept because it is existing, deliberate behavior (it was added to
# make the pretrained weights downloadable).
if (not os.environ.get('PYTHONHTTPSVERIFY', '')
        and getattr(ssl, '_create_unverified_context', None)):
    ssl._create_default_https_context = ssl._create_unverified_context

# Pretrained segmentation network, switched to evaluation mode.
fully_convolutional_network = models.segmentation.fcn_resnet101(
    pretrained=True).eval()

input_image = cv2.imread('test_images/input.jpg')
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail here with a clear message rather than inside the transform.
if input_image is None:
    raise FileNotFoundError(
        "could not read image 'test_images/input.jpg'")

# Convert the image to a normalized tensor and add the batch dimension.
# NOTE(review): cv2.imread yields BGR data, while these normalization
# constants are usually quoted for RGB input; confirm the channel order.
image_to_tensor_transform = cvtransforms.Compose([
    cvtransforms.ToTensor(),
    cvtransforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
])

input_tensor = image_to_tensor_transform(input_image).unsqueeze(0)

# Pure inference: do not record an autograd graph for the forward pass.
with torch.no_grad():
    output_model = fully_convolutional_network(input_tensor)['out']

# Drop only the batch axis (a bare squeeze() would also drop any other size-1
# axis), then pick the highest-scoring class for every pixel.
labeled_image = torch.argmax(output_model.squeeze(0),
                             dim=0).cpu().numpy()

# Generating the segmentation mask to remove the background
mask = segmentation_mask(labeled_image)
Beispiel #8
0
def main(cfg):
    """Train the model selected by ``cfg`` and checkpoint it after every evaluation.

    Steps performed, in order:

    1. Create ``cfg.workdir``, set up logging to ``log.txt`` and dump the
       configuration to ``config.yml`` inside it.
    2. Instantiate ``iqa.__dict__[cfg.model.name](**cfg.model.kwargs)``, move
       it to CUDA when available and wrap it in ``nn.DataParallel`` when more
       than one CUDA device is visible.
    3. Create the train / validation memmap caches if their files do not
       exist yet and open them as ``MemMap`` datasets.
    4. Build an SGD optimizer with a ``OneCycleLR`` schedule and the data
       loaders, then save the initial state as step 0.
    5. Alternate ``train_steps`` (one evaluation interval) and ``evaluate``;
       after each round save a checkpoint and, when the validation loss
       improved, also save it as the best state.

    Parameters
    ----------
    cfg
        Configuration object. Fields read here: ``workdir``, ``model.name``,
        ``model.kwargs``, ``ava.train_cache``, ``ava.val_cache``,
        ``ava.train_labels``, ``ava.images``, ``num_workers``, ``batch_size``,
        ``num_epochs``, ``lr``, ``momentum``, ``weight_decay``, ``warmup_lr``
        and ``final_lr``. It must also provide ``dump_to_file``.
    """
    workdir = Path(cfg.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    set_logger(workdir / 'log.txt')
    cfg.dump_to_file(workdir / 'config.yml')
    saver = Saver(workdir, keep_num=10)
    logging.info(f'config: \n{cfg}')
    logging.info(f'use device: {device}')

    # The model class is looked up by name in the iqa module namespace.
    model = iqa.__dict__[cfg.model.name](**cfg.model.kwargs)
    model = model.to(device)

    # model_dp is used for forward passes; the plain model is kept for
    # parameter grouping and state_dict() below.
    if torch.cuda.device_count() > 1:
        model_dp = nn.DataParallel(model)
    else:
        model_dp = model

    train_transform = Transform(
        transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]))

    # NOTE(review): validation also uses RandomCrop, so evaluation results are
    # not deterministic between runs; confirm this is intended (vs. CenterCrop).
    val_transform = Transform(
        transforms.Compose([transforms.RandomCrop(224),
                            transforms.ToTensor()]))

    # Build the on-disk caches only when their files are missing.
    if not Path(cfg.ava.train_cache).exists():
        create_memmap(cfg.ava.train_labels, cfg.ava.images,
                      cfg.ava.train_cache, cfg.num_workers)
    # NOTE(review): the validation cache is created from cfg.ava.train_labels,
    # the same labels as the training cache. This looks like a copy-paste
    # slip; confirm whether a separate validation label file should be used.
    if not Path(cfg.ava.val_cache).exists():
        create_memmap(cfg.ava.train_labels, cfg.ava.images, cfg.ava.val_cache,
                      cfg.num_workers)

    trainset = MemMap(cfg.ava.train_cache, train_transform)
    valset = MemMap(cfg.ava.val_cache, val_transform)

    # One evaluation interval is one pass over the full batches of trainset.
    # NOTE: if trainset has fewer than cfg.batch_size samples both values are
    # 0 and the range() call of the training loop raises ValueError (step 0).
    total_steps = len(trainset) // cfg.batch_size * cfg.num_epochs
    eval_interval = len(trainset) // cfg.batch_size
    logging.info(f'total steps: {total_steps}, eval interval: {eval_interval}')
    model_dp.train()
    parameters = group_parameters(model)
    optimizer = SGD(parameters,
                    cfg.lr,
                    cfg.momentum,
                    weight_decay=cfg.weight_decay)

    # The two divisors are expressed through cfg.warmup_lr and cfg.final_lr.
    # Presumably (per OneCycleLR's initial_lr = max_lr / div_factor and
    # min_lr = initial_lr / final_div_factor) the schedule starts at
    # cfg.warmup_lr and ends at cfg.final_lr - TODO confirm.
    lr_scheduler = OneCycleLR(optimizer,
                              max_lr=cfg.lr,
                              div_factor=cfg.lr / cfg.warmup_lr,
                              total_steps=total_steps,
                              pct_start=0.01,
                              final_div_factor=cfg.warmup_lr / cfg.final_lr)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True)

    # Best validation loss seen so far; starts at a large sentinel value.
    curr_loss = 1e9
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'step': 0,  # init step,
        'cfg': cfg,
        'loss': curr_loss
    }

    # Checkpoint the untrained state as step 0.
    saver.save(0, state)

    trainloader = repeat_loader(train_loader)
    batch_processor = BatchProcessor(device)
    start = time.time()
    for step in range(0, total_steps, eval_interval):
        # Number of steps in this round, clipped at total_steps.
        num_steps = min(step + eval_interval, total_steps) - step
        # Rebinding the loop variable only affects this iteration's body;
        # from here on `step` is the number of steps completed so far.
        step += num_steps
        trainmeter = train_steps(model_dp, trainloader, optimizer,
                                 lr_scheduler, emd_loss, batch_processor,
                                 num_steps)
        valmeter = evaluate(model_dp, val_loader, emd_loss, batch_processor)
        finish = time.time()
        # NOTE(review): throughput is computed from eval_interval rather than
        # num_steps, and the measured time span includes the evaluate() call.
        img_s = cfg.batch_size * eval_interval / (finish - start)
        loss = valmeter.meters['loss'].global_avg

        # State is taken from the unwrapped model, not from model_dp.
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'step': step,  # steps completed so far
            'cfg': cfg,
            'loss': loss
        }
        saver.save(step, state)

        # Keep a separate copy of the state with the lowest validation loss.
        if loss < curr_loss:
            curr_loss = loss
            saver.save_best(state)

        logging.info(
            f'step: [{step}/{total_steps}] img_s: {img_s:.2f} train: [{trainmeter}] eval:[{valmeter}]'
        )
        # Restart the timer for the next round.
        start = time.time()