示例#1
0
def HairstyleParsing(model, i_dir, o_dir):
    """Classify the hairstyle of every test image and save a side-by-side figure.

    For each image found by ``get_img_files(i_dir)`` the predicted class name
    (looked up in the module-level ``HairStyle`` sequence) is printed. A
    two-panel figure is then drawn: the de-normalized input image on the left
    and the reference picture ``../hairModel/<class name>.png`` on the right.
    The figure is displayed for ~3 seconds and saved as ``<o_dir><k>.png``,
    where ``k`` is the running index of the image over the whole dataset.

    Args:
        model: Network called on a batch already moved to ``device``; its output
            is reduced with ``torch.max(outputs, 1)``, so presumably it returns
            (batch, n_classes) scores — TODO confirm.
        i_dir: Directory passed to ``get_img_files`` to collect the test images.
        o_dir: Prefix prepended verbatim to each output file name (include a
            trailing path separator if it is a directory).
    """
    test_files = get_img_files(i_dir)
    data_loader = get_testdata_loader(test_files)

    # Channel-wise mean/std used to undo the input normalization for display.
    # NOTE(review): assumed to match the transform in get_testdata_loader.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Running image index across all batches. Deriving it from
    # batch_idx * len(current_batch) collides with earlier indices (and
    # overwrites earlier files) when the final batch is smaller.
    n_saved = 0
    fig = plt.figure()
    with torch.no_grad():
        for inputs in data_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            for img_t, pred in zip(inputs, preds):
                title = HairStyle[int(pred)]
                print(title)

                # (C, H, W) tensor -> (H, W, C) array, de-normalized into [0, 1].
                inp = img_t.cpu().numpy().transpose((1, 2, 0))
                inp = np.clip(std * inp + mean, 0, 1)

                # Redraw from scratch so images do not accumulate on the same
                # axes, and put the title on the figure rather than whichever
                # axes happened to be current.
                fig.clf()
                fig.suptitle(title)
                ax_input, ax_ref = fig.subplots(1, 2)
                ax_input.imshow(inp)
                # Do not reuse the name `model` here: it is the network that is
                # called again on the next batch. The `with` closes the file;
                # imshow has already converted the image to an array.
                with Image.open('../hairModel/' + title + '.png') as ref_img:
                    ax_ref.imshow(ref_img)

                plt.pause(3)  # give the GUI event loop time to draw the figure
                plt.show()
                fig.savefig(o_dir + str(n_saved) + '.png')
                n_saved += 1
示例#2
0
def run_cv(img_size, pre_trained):
    """Train a ``MobileNetV2_unet`` on the first K-fold split of the image set.

    The images from ``get_img_files()`` are split with
    ``KFold(n_splits=N_CV, random_state=RANDOM_STATE, shuffle=True)``. Only the
    first fold is trained, because the loop ends with ``break``. The training
    history returned by ``Trainer.train`` is written to
    ``OUT_DIR/<fold>-hist.csv``.

    Args:
        img_size: Forwarded to ``get_data_loaders``.
        pre_trained: Forwarded to ``MobileNetV2_unet(pre_trained=...)``.
    """
    image_files = get_img_files()
    kf = KFold(n_splits=N_CV, random_state=RANDOM_STATE, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for n, (train_idx, val_idx) in enumerate(kf.split(image_files)):
        train_files = image_files[train_idx]
        val_files = image_files[val_idx]

        writer = SummaryWriter()
        try:
            def on_after_epoch(m, df_hist):
                # Callback handed to Trainer (presumably invoked after each
                # epoch with the model and the history so far): save the best
                # checkpoint for this fold, log to TensorBoard, print history.
                save_best_model(n, m, df_hist)
                write_on_board(writer, df_hist)
                log_hist(df_hist)

            criterion = dice_loss(scale=2)
            data_loaders = get_data_loaders(train_files, val_files, img_size)
            trainer = Trainer(data_loaders, criterion, device, on_after_epoch)

            model = MobileNetV2_unet(pre_trained=pre_trained)
            model.to(device)
            optimizer = Adam(model.parameters(), lr=LR)

            hist = trainer.train(model, optimizer, num_epochs=N_EPOCHS)
            hist.to_csv('{}/{}-hist.csv'.format(OUT_DIR, n), index=False)
        finally:
            # Close the writer even if training raises, so its event file is
            # flushed and released instead of leaking.
            writer.close()

        # Deliberately train only the first fold; remove to run full CV.
        break
示例#3
0
def evaluate():
    """Show input / label / prediction for validation images of each saved fold.

    Reproduces the training K-fold split (``N_CV`` folds, ``RANDOM_STATE``),
    loads ``OUT_DIR/<fold>-best.pth`` into a ``MobileNetV2_unet`` and runs it
    over that fold's validation images. It prints the forward-pass time of
    every batch and displays each sample as three panels: input, label and
    prediction. It returns as soon as more than 10 samples (i.e. 11) have
    been shown.
    """
    img_size = (IMG_SIZE, IMG_SIZE)
    # The model output is reshaped to half the input resolution below.
    # NOTE(review): assumes a single-channel mask at IMG_SIZE // 2 — confirm.
    half = IMG_SIZE // 2
    n_shown = 0

    image_files = get_img_files()
    kf = KFold(n_splits=N_CV, random_state=RANDOM_STATE, shuffle=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for n, (_, val_idx) in enumerate(kf.split(image_files)):
        val_files = image_files[val_idx]
        data_loader = get_data_loaders(val_files)

        model = MobileNetV2_unet()
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(
            torch.load('{}/{}-best.pth'.format(OUT_DIR, n), map_location=device))
        model.to(device)
        model.eval()

        with torch.no_grad():
            for inputs, labels in data_loader:
                # time.clock() was removed in Python 3.8; perf_counter is the
                # high-resolution replacement for measuring elapsed time.
                start = time.perf_counter()
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                if device.type == 'cuda':
                    # CUDA launches are asynchronous; wait so the measured time
                    # actually covers the forward pass.
                    torch.cuda.synchronize(device)
                elapsed = time.perf_counter() - start
                print(elapsed)

                for img, label, pred in zip(inputs, labels, outputs):
                    img = img.cpu().numpy().transpose((1, 2, 0)) * 255
                    label = label.cpu().numpy().reshape(*img_size) * 255
                    pred = pred.cpu().numpy().reshape(half, half) * 255

                    img = cv2.resize(img.astype(np.uint8), img_size)
                    label = cv2.resize(label.astype(np.uint8), img_size)
                    pred = cv2.resize(pred.astype(np.uint8), img_size)

                    # 1x3 grid: input | label | prediction (mixing 121/132/122
                    # made the panels overlap).
                    plt.subplot(131)
                    plt.imshow(img)
                    plt.subplot(132)
                    plt.imshow(label)
                    plt.subplot(133)
                    plt.imshow(pred)
                    plt.show()
                    n_shown += 1
                    if n_shown > 10:
                        return
示例#4
0
# Module-level configuration for the U-Net training script.
IMG_SIZE = 224
RANDOM_STATE = 1  # seed passed to KFold so the fold split is reproducible

EXPERIMENT = 'train_unet'
OUT_DIR = 'outputs/{}'.format(EXPERIMENT)  # evaluates to 'outputs/train_unet'

data_dir = '../Database/'  # directory handed to get_img_files below
img_size = 224  # NOTE(review): duplicates IMG_SIZE; this one is passed to get_data_loaders
pre_trained_mobnet2 = './mobilenet_v2.pth.tar'  # presumably pretrained backbone weights; unused in the visible code
n_class = 9  # presumably the number of output classes; unused in the visible code — confirm

# NOTE(review): switching matplotlib to interactive mode at import time is a
# module-level side effect that affects every importer.
plt.ion()  # interactive mode

#def run_cv(img_size, pre_trained, target):
if __name__ == '__main__':
    # Script entry point: collect the images under data_dir and build the
    # train/val data loaders for each K-fold split.
    # NOTE(review): N_CV is not defined in this chunk; it must come from elsewhere.
    image_files = get_img_files(data_dir)
    kf = KFold(n_splits=N_CV, random_state=RANDOM_STATE, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for n, (train_idx, val_idx) in enumerate(kf.split(image_files)):
        #----Prepare Data-------------------------------------------------------------------------------
        # Index arrays from KFold are used directly as indices, so image_files
        # is presumably a numpy array (a plain list would fail here).
        train_files = image_files[train_idx]
        val_files = image_files[val_idx]
        data_loaders = get_data_loaders(train_files, val_files, img_size)
        dataset_sizes = [len(train_files), len(val_files)]
        print('dataset_sizes:', dataset_sizes)
        # Pull one batch from data_loaders[0] (presumably the training loader,
        # since train_files is passed first). In the visible code the batch is
        # only consumed by the commented-out preview below.
        inputs, classes = next(iter(data_loaders[0]))
        #out = torchvision.utils.make_grid(inputs)
        #imshow(out, title=[x for x in classes])
        #----Prepare Model-------------------------------------------------------------------------------
示例#5
0
def evaluate():
    img_size = (IMG_SIZE, IMG_SIZE)
    n_shown = 0

    image_files = get_img_files()
    kf = KFold(n_splits=N_CV, random_state=RANDOM_STATE, shuffle=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    idx = 0
    f = open('IoU_Asia.txt', 'w')
    for n, (train_idx, val_idx) in enumerate(kf.split(image_files)):
        #for n, (train_idx, val_idx) in enumerate(image_files):
        print('n:%i, idx:%i' % (n, idx))
        val_files = image_files[val_idx]
        data_loader = get_data_loaders(val_files)

        model = MobileNetV2_unet()
        model.load_state_dict(torch.load('{}/{}-best.pth'.format(OUT_DIR, 0)))
        model.to(device)
        model.eval()

        with torch.no_grad():
            for inputs, labels in data_loader:

                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                idx += 1
                #idx_in = 1
                #if idx < 2001:
                #    continue
                for i, l, o in zip(inputs, labels, outputs):
                    i = i.cpu().numpy().transpose((1, 2, 0)) * 255
                    l = l.cpu().numpy().reshape(*img_size) * 255
                    o = o.cpu().numpy().reshape(int(IMG_SIZE / 2),
                                                int(IMG_SIZE / 2)) * 255

                    i = cv2.resize(i.astype(np.uint8), img_size)
                    l = cv2.resize(l.astype(np.uint8), img_size)
                    o = cv2.resize(o.astype(np.uint8), img_size)

                    h = l.shape[0]
                    w = l.shape[1]
                    union = np.zeros((h, w), np.uint8)
                    overlap = np.zeros((h, w), np.uint8)
                    union_pixel = 0
                    overlap_pixel = 0
                    for row in range(h):
                        for col in range(w):
                            if l[row][col] == 0 and o[row][col] == 0:
                                union[row][col] = 0
                            elif o[row][col] == 0:
                                union[row][col] = l[row][col]
                                union_pixel += 1
                            else:
                                union[row][col] = o[row][col]
                                union_pixel += 1

                            if l[row][col] == o[row][col]:
                                overlap[row][col] = o[row][col]
                                if l[row][col] != 0:
                                    overlap_pixel += 1
                            else:
                                overlap[row][col] = 0

                    plt.subplot(121)
                    plt.imshow(i)
                    plt.subplot(122)
                    plt.imshow(o)
                    o_dir = "data/raw/output_asia/%i.png" % (idx)
                    plt.savefig(o_dir)