Ejemplo n.º 1
0
        image_file_name = 'data/raw/{:s}.jpg'.format(hash)
        result_file_name = 'data/images_alpha/{:s}.png'.format(hash)

        if os.path.exists(result_file_name):
            return SKIP_ITEM

        try:
            image = load_image(image_file_name)
        except ValueError:
            print("Could not open {:s}.".format(image_file_name))
            return SKIP_ITEM
        
        return image, result_file_name

# Feed the dataset through a loader that yields one entry at a time, in order,
# with eight worker processes.
dataset = ImageDataset()
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8)

for batch in tqdm(data_loader):
    # Entries flagged with SKIP_ITEM carry no work and are passed over.
    if batch == SKIP_ITEM:
        continue

    image, result_file_name = batch
    image = classifier.apply(image.to(device), margin=MARGIN, create_alpha=ALPHA)

    # A result is only kept when it exists, has exactly three dimensions and
    # measures at least 10 along each of its last two dimensions.
    usable = (
        image is not None
        and len(image.shape) == 3
        and image.shape[1] >= 10
        and image.shape[2] >= 10
    )
    if not usable:
        print("Found nothing.")
        continue

    # NOTE(review): result_file_name is presumably a one-element batch of
    # paths (batch_size=1), hence the [0] -- confirm against the dataset.
    utils.save_image(image, result_file_name[0])
Ejemplo n.º 2
0
import glob
import os
import random
from shutil import copyfile

import torch
from torchvision import utils

from classifier import Classifier
from mask_loader import load_image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CLASSIFIER_FILENAME = 'trained_models/classifier.to'

# Place the model on the selected device (rather than unconditionally on
# CUDA) and remap the checkpoint's tensors to it, so the script also runs on
# CPU-only machines.
classifier = Classifier()
classifier.to(device)
classifier.load_state_dict(torch.load(CLASSIFIER_FILENAME, map_location=device))
classifier.eval()

# NOTE(review): glob only recurses when '**' is a whole path component; the
# pattern below behaves like 'data/raw/*.jpg'. Confirm whether images in
# subdirectories were meant to be included ('data/raw/**/*.jpg').
file_names = glob.glob('data/raw/**.jpg', recursive=True)
if not file_names:
    # random.choice() would otherwise fail with an opaque IndexError.
    raise SystemExit('No .jpg files found in data/raw.')

# Make sure the output directory exists before anything is written to it.
os.makedirs('data/test', exist_ok=True)

# Sample random raw images indefinitely (stop with Ctrl+C). For every image
# the classifier returns a result for, store a copy of the source image next
# to the classifier output in data/test.
while True:
    file_name = random.choice(file_names)
    # File name without directory and extension.
    image_hash = os.path.splitext(os.path.basename(file_name))[0]

    image = load_image(file_name).to(device)
    image = classifier.apply(image)

    # No result for this image; draw another one.
    if image is None:
        continue

    copyfile(file_name, 'data/test/{:s}.jpg'.format(image_hash))
    utils.save_image(image, 'data/test/{:s}_result.png'.format(image_hash))
Ejemplo n.º 3
0
            image = load_image(image_file_name)
        except:
            print("Could not open {:s}.".format(image_file_name))
            return SKIP_ITEM

        return image, result_file_name


# Feed the dataset through a loader that yields one entry at a time, in
# shuffled order, with eight worker processes.
dataset = ImageDataset()
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8)

for item in tqdm(data_loader):
    # Entries flagged with SKIP_ITEM carry no work and are passed over.
    if item == SKIP_ITEM:
        continue

    image, result_file_name = item
    image = image.to(device)

    try:
        image = classifier.apply(image, margin=0, create_alpha=True)
    except Exception:
        # KeyboardInterrupt derives from BaseException, not Exception, so it
        # is never caught here and Ctrl+C still stops the run.
        print("Error while handling {:s}".format(result_file_name[0]))
        # Skip this entry: `image` still holds the unprocessed input, which
        # must not reach the checks and the save below.
        continue

    # A result is only kept when it exists, has exactly three dimensions and
    # measures at least 10 along each of its last two dimensions.
    if (image is None or len(image.shape) != 3
            or image.shape[1] < 10 or image.shape[2] < 10):
        print("Found nothing.")
        continue

    # NOTE(review): result_file_name is presumably a one-element batch of
    # paths (batch_size=1), hence the [0] -- confirm against the dataset.
    utils.save_image(image, result_file_name[0])