Example no. 1
    def __init__(self, cfg, mode='train', device='cpu'):
        super().__init__(cfg, mode, device)
        self.backbone = cfg['model']['backbone']
        self.in_channels = cfg['data']['in_channels']
        self.num_classes = cfg['data']['num_classes']
        self.preprocessing = get_preprocessing_fn(encoder_name=self.backbone,
                                                  pretrained='imagenet' if self.pretrained else False)
        transforms_mode = 'train' if self.mode == 'train' else 'val'
        self.transforms = get_transforms(config=cfg,
                                         key=f'transforms_{transforms_mode}',
                                         imagenet=self.pretrained,
                                         norm=False,
                                         to_tensor=False)
Example no. 2
    def __init__(self, image_dir, mask_dir, original_dir,
                 base_model_name, encoder_weights, threshold=0.30,
                 is_train=True, fold_num=0, fold_total=5):
        self.threshold = threshold
        self.preprocess_input = get_preprocessing_fn(base_model_name, pretrained=encoder_weights)

        self.image_patches = []
        self.mask_patches = []
        
        original_image_name = []
        original_image_paths = os.listdir(original_dir)
        original_image_paths = sorted(original_image_paths)
        for original_image_path in original_image_paths:
            if not check_is_image(original_image_path):
                print(original_image_path, 'not image')
                continue
            original_image_name.append(original_image_path.split('.')[0])

        # s[i::k]: every k-th element of s starting at index i (the fold split)
        test_image_name = original_image_name[fold_num::fold_total]
        train_image_name = [i for i in original_image_name if i not in test_image_name]
        print('total image len:', len(original_image_name),
              'train len: %d' % len(train_image_name) if is_train else 'test len: %d' % len(test_image_name))

        cnt = 0
        images = os.listdir(image_dir)

        for image in images:
            if not check_is_image(image):
                print(image, 'not image')
                continue

            if not os.path.isfile(os.path.join(mask_dir, image)):
                print(image, 'no mask')
                continue
            
            if is_train and image.split('-')[0] in train_image_name:
                self.image_patches.append(os.path.join(image_dir, image))
                self.mask_patches.append(os.path.join(mask_dir, image))
            elif not is_train and image.split('-')[0] in test_image_name:
                self.image_patches.append(os.path.join(image_dir, image))
                self.mask_patches.append(os.path.join(mask_dir, image))
            cnt += 1

            if cnt % 50000 == 0:
                print(cnt)
                # break
        print('total patch len:', cnt, 'train patch len:' if is_train else 'test patch len:', len(self.image_patches))
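
A minimal sketch of the slice-based k-fold split used above; the names are illustrative:

# Hypothetical demo: s[i::k] takes every k-th element starting at index i.
names = ['img%02d' % i for i in range(10)]
fold_num, fold_total = 0, 5
test_names = names[fold_num::fold_total]                 # ['img00', 'img05']
train_names = [n for n in names if n not in test_names]  # remaining 8 names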
Example no. 3
def make_dataloader(samples,
                    collate_fn,
                    model_config,
                    patch_sizes,
                    aug_config=None,
                    batch_size=32,
                    shuffle=True,
                    num_workers=8,
                    dataloader_type='in_memory',
                    preprocessing_mask_fn=None,
                    image_loader=None,
                    mask_loader=None):
    if aug_config is not None:
        augmentation_pipeline = make_aug(**aug_config)
    else:
        augmentation_pipeline = None

    if model_config['source'] == 'basic':
        preprocessing_image_fn = image_process_basic
    else:
        preprocessing_image_fn = get_preprocessing_fn(
            encoder_name=model_config['params']['encoder_name'],
            pretrained=model_config['params']['encoder_weights'])

    if dataloader_type == 'in_memory':
        dataset = PatchDataset(samples,
                               augmentation_fn=augmentation_pipeline,
                               preprocessing_image_fn=preprocessing_image_fn,
                               preprocessing_mask_fn=preprocessing_mask_fn)
    elif dataloader_type == 'lazy':
        dataset = ImageDataset.from_samples(
            samples=samples,
            patch_sizes=patch_sizes,
            image_loader=image_loader,
            mask_loader=mask_loader,
            augmentation_fn=augmentation_pipeline,
            preprocessing_image_fn=preprocessing_image_fn,
            preprocessing_mask_fn=preprocessing_mask_fn)
    else:
        raise ValueError('Wrong dataset type!')

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=collate_fn,
                            shuffle=shuffle,
                            num_workers=num_workers)
    return dataloader
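
A hedged usage sketch for make_dataloader; model_config follows the branches above, while samples and the other argument values are illustrative placeholders:

# Illustrative call; collate_fn=None falls back to the default DataLoader collation.
model_config = {'source': 'smp',
                'params': {'encoder_name': 'resnet34',
                           'encoder_weights': 'imagenet'}}
loader = make_dataloader(samples,                # assumed list of (image, mask) pairs
                         collate_fn=None,
                         model_config=model_config,
                         patch_sizes=(256, 256),
                         batch_size=16,
                         dataloader_type='in_memory')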
Example no. 4
    def __init__(self, network, device, push_enabled, place_enabled,
                 num_rotations):
        super(ManipulationNet, self).__init__()
        self.device = device
        self.push_enabled = push_enabled
        self.place_enabled = place_enabled
        self.net = None
        self.preprocess_input = None

        # Initialize network
        if network == 'grconvnet':
            self.net = GenerativeResnet()
        elif network == 'grconvnet3':
            # if self.push_enabled:
            self.push_net = GenerativeResnet3()
            self.grasp_net = GenerativeResnet3()
            self.place_net = GenerativeResnet3()
        elif network == 'grconvnet4':
            # if self.push_enabled:
            self.push_net = GenerativeResnet4()
            self.grasp_net = GenerativeResnet4()
            self.place_net = GenerativeResnet4()
        elif network == 'denseunet':
            self.net = DenseUNet()
        elif network == 'efficientunet':
            encoder = 'efficientnet-b4'
            encoder_weights = 'imagenet'
            self.preprocess_input = get_preprocessing_fn(
                encoder, pretrained=encoder_weights)
            self.push_net = smp.Unet(encoder,
                                     encoder_weights=encoder_weights,
                                     in_channels=4)
            self.grasp_net = smp.Unet(encoder,
                                      encoder_weights=encoder_weights,
                                      in_channels=4)
            self.place_net = smp.Unet(encoder,
                                      encoder_weights=encoder_weights,
                                      in_channels=4)
        else:
            raise NotImplementedError(
                'Network type {} is not implemented'.format(network))

        self.num_rotations = num_rotations

        # Initialize variables
        self.padding_width = 0
        self.output_prob = []
Example no. 5
    def __init__(self, image_dir, mask_dir, base_model_name, encoder_weights):
        self.base_model_name = base_model_name
        self.preprocess_input = get_preprocessing_fn(base_model_name, pretrained=encoder_weights)

        self.image_paths = []
        self.mask_paths = []

        image_paths = os.listdir(image_dir)
        for image_path in image_paths:
            if not check_is_image(image_path):
                print('not image', image_path)
                continue

            if not os.path.isfile(os.path.join(mask_dir, image_path)):
                print('no mask', image_path)
                continue

            self.image_paths.append(os.path.join(image_dir, image_path))
            self.mask_paths.append(os.path.join(mask_dir, image_path))
def unet_train(epochs, gpu, base_model_name, encoder_weights, generator_lr,
               discriminator_lr, lambda_bce, threshold, batch_size,
               image_train_dir, mask_train_dir, original_dir, fold_num,
               fold_total):
    # make save directory
    weight_path = (
        './step1_label%d_' % fold_num) + base_model_name + '_' + str(
            int(lambda_bce)) + '_' + str(generator_lr) + '_' + str(threshold)
    image_save_path = weight_path + '/images'
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(image_save_path, exist_ok=True)

    # rgb , preprocess input
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    train_data_set = Dataset_Return_Four(image_train_dir,
                                         mask_train_dir,
                                         os.path.join(original_dir, 'image'),
                                         base_model_name,
                                         encoder_weights,
                                         threshold=threshold,
                                         is_train=True,
                                         fold_num=fold_num,
                                         fold_total=fold_total)

    train_loader = DataLoader(train_data_set,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True)

    device = torch.device("cuda:%s" % gpu)

    image_test_list = []
    test_image_paths = os.listdir(os.path.join(original_dir, 'image'))
    for test_image_path in test_image_paths:
        if not check_is_image(test_image_path):
            print(test_image_path, 'not image')
            continue
        image_test_list.append(
            (os.path.join(original_dir, 'image', test_image_path),
             os.path.join(original_dir, 'mask', test_image_path)))
    image_test_list = image_test_list[fold_num::fold_total]
    print('test len:', len(image_test_list))

    preprocess_input = get_preprocessing_fn(base_model_name,
                                            pretrained=encoder_weights)

    models = []
    optimizers = []
    for channel in range(4):
        models.append(
            smp.Unet(base_model_name,
                     encoder_weights=encoder_weights,
                     in_channels=3))
        models[channel].to(device)
        optimizers.append(
            optim.Adam(models[channel].parameters(),
                       lr=generator_lr,
                       betas=(0.5, 0.999)))

    discriminator = Discriminator(in_channels=4)
    discriminator.to(device)
    optimizer_discriminator = optim.Adam(discriminator.parameters(),
                                         lr=discriminator_lr,
                                         betas=(0.5, 0.999))

    criterion = nn.BCEWithLogitsLoss()

    channel_dict = {0: 'blue', 1: 'green', 2: 'red', 3: 'gray'}
    value = int(256 * 0.5)
    best_fmeasure = 0.0
    lambda_gp = 10.0
    epoch_start_time = time.time()
    for epoch in range(epochs):

        # train
        for channel in range(4):
            models[channel].train()
            models[channel].requires_grad_(False)

        for idx, (images, masks) in enumerate(train_loader):
            # for sample image
            test_masks_pred_list = []
            test_masks_list = []
            test_images_list = []
            for channel in range(4):
                images[channel] = images[channel].to(device)
                masks[channel] = masks[channel].to(device)

                models[channel].requires_grad_(True)
                masks_pred = models[channel](images[channel])

                # discriminator
                discriminator.requires_grad_(True)
                # Fake
                fake_AB = torch.cat((images[channel], masks_pred), 1).detach()
                pred_fake = discriminator(fake_AB)

                # Real
                real_AB = torch.cat((images[channel], masks[channel]), 1)
                pred_real = discriminator(real_AB)

                gradient_penalty = compute_gradient_penalty(
                    discriminator, real_AB, fake_AB, device)
                discriminator_loss = -torch.mean(pred_real) + torch.mean(
                    pred_fake) + lambda_gp * gradient_penalty

                optimizer_discriminator.zero_grad()
                discriminator_loss.backward()
                optimizer_discriminator.step()

                discriminator.requires_grad_(False)
                # end discriminator

                # generator
                fake_AB = torch.cat((images[channel], masks_pred), 1)
                pred_fake = discriminator(fake_AB)
                generator_loss = -torch.mean(pred_fake)
                bce_loss = criterion(masks_pred, masks[channel])
                total_loss = generator_loss + bce_loss * lambda_bce

                optimizers[channel].zero_grad()
                total_loss.backward()
                optimizers[channel].step()

                models[channel].requires_grad_(False)
                # end generator

                if idx % 2000 == 0:
                    print(
                        'channel %s train step[%d/%d] discriminator loss: %.5f, total loss: %.5f, generator loss: %.5f, bce loss: %.5f, time: %.2f'
                        % (channel_dict[channel], idx, len(train_loader),
                           discriminator_loss.item(), total_loss.item(),
                           generator_loss.item(), bce_loss.item(),
                           time.time() - epoch_start_time))

                    # for sample images
                    rand_idx_start = randrange(masks[channel].size(0) - 2)
                    rand_idx_end = rand_idx_start + 2
                    test_masks_pred = torch.sigmoid(
                        masks_pred[rand_idx_start:rand_idx_end]).detach().cpu(
                        )
                    test_masks_pred = test_masks_pred.permute(
                        0, 2, 3, 1).numpy().astype(np.float32)
                    test_masks_pred = np.squeeze(test_masks_pred, axis=-1)
                    test_masks_pred_list.extend(test_masks_pred)

                    test_masks = masks[
                        channel][rand_idx_start:rand_idx_end].permute(
                            0, 2, 3, 1).cpu().numpy().astype(np.float32)
                    test_masks = np.squeeze(test_masks, axis=-1)
                    test_masks_list.extend(test_masks)

                    test_images = images[channel][
                        rand_idx_start:rand_idx_end].permute(0, 2, 3,
                                                             1).cpu().numpy()
                    test_images = test_images * imagenet_std + imagenet_mean
                    test_images = np.maximum(test_images, 0.0)
                    test_images = np.minimum(test_images, 1.0)
                    test_images_list.extend(test_images)

            if idx % 2000 == 0:
                sample_images(epoch, idx, test_images_list, test_masks_list,
                              test_masks_pred_list, image_save_path)
            # break

        # eval
        for channel in range(4):
            models[channel].eval()

        total_fmeasure = 0.0
        total_image_number = 0
        for eval_idx, (image_test, mask_test) in enumerate(image_test_list):
            image = cv2.imread(image_test)
            h, w, _ = image.shape
            image_name = image_test.split('/')[-1].split('.')[0]
            # print('eval the image:', image_name)

            gt_mask = cv2.imread(mask_test, cv2.IMREAD_GRAYSCALE)
            gt_mask = np.expand_dims(gt_mask, axis=-1)
            image_patches, poslist = get_image_patch(image,
                                                     256,
                                                     256,
                                                     overlap=0.5,
                                                     is_mask=False)

            # random_number = randrange(10)
            for channel in range(4):
                color_patches = []
                for patch in image_patches:
                    tmp = patch.astype(np.float32)
                    if channel != 3:
                        color_patches.append(
                            preprocess_input(tmp[:, :, channel:channel + 1]))
                    else:
                        color_patches.append(
                            preprocess_input(
                                np.expand_dims(cv2.cvtColor(
                                    tmp, cv2.COLOR_BGR2GRAY),
                                               axis=-1)))

                step = 0
                preds = []
                with torch.no_grad():
                    while step < len(image_patches):
                        ps = step
                        pe = step + batch_size
                        if pe >= len(image_patches):
                            pe = len(image_patches)

                        target = torch.from_numpy(
                            np.array(color_patches[ps:pe])).permute(
                                0, 3, 1, 2).float()
                        pred = torch.sigmoid(models[channel](
                            target.to(device))).cpu()
                        preds.extend(pred)
                        step += batch_size

                # handling overlap
                out_img = np.ones((h, w, 1)) * 255
                for i in range(len(image_patches)):
                    patch = preds[i].permute(1, 2, 0).numpy() * 255

                    start_h, start_w, end_h, end_w, h_shift, w_shift = poslist[
                        i]
                    h_cut = end_h - start_h
                    w_cut = end_w - start_w

                    tmp = np.minimum(
                        out_img[start_h:end_h, start_w:end_w],
                        patch[h_shift:h_shift + h_cut,
                              w_shift:w_shift + w_cut])
                    out_img[start_h:end_h, start_w:end_w] = tmp

                out_img = out_img.astype(np.uint8)
                out_img[out_img > value] = 255
                out_img[out_img <= value] = 0

                # if random_number == 0:
                #     cv2.imwrite('%s/%d_%d_%s.png' % (image_save_path, epoch, channel, image_name), out_img)

                # f_measure
                # background 1, text 0
                gt_mask[gt_mask > 0] = 1
                out_img[out_img > 0] = 1

                # true positive
                tp = np.zeros(gt_mask.shape, np.uint8)
                tp[(out_img == 0) & (gt_mask == 0)] = 1
                numtp = tp.sum()

                # false positive
                fp = np.zeros(gt_mask.shape, np.uint8)
                fp[(out_img == 0) & (gt_mask == 1)] = 1
                numfp = fp.sum()

                # false negative
                fn = np.zeros(gt_mask.shape, np.uint8)
                fn[(out_img == 1) & (gt_mask == 0)] = 1
                numfn = fn.sum()

                precision = numtp / float(numtp + numfp)
                recall = numtp / float(numtp + numfn)
                fmeasure = 100. * (2. * recall * precision) / (
                    recall + precision)  # percent
                total_fmeasure += fmeasure
            total_image_number += 4
            # break
        total_fmeasure /= total_image_number

        if best_fmeasure < total_fmeasure:
            best_fmeasure = total_fmeasure

        print('epoch[%d/%d] fmeasure: %.4f, best_fmeasure: %.4f, time: %.2f' %
              (epoch + 1, epochs, total_fmeasure, best_fmeasure,
               time.time() - epoch_start_time))
        print()
    for channel in range(4):
        torch.save(
            models[channel].state_dict(), weight_path +
            '/unet_%d_%d_%.4f.pth' % (channel, epoch + 1, total_fmeasure))
    torch.save(discriminator.state_dict(),
               weight_path + '/dis_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
        mask[mask < 0.5] = 0
        mask = np.expand_dims(mask, 2)
        sample = self.augmentation(image=image, mask=mask)
        image, mask = sample['image'], sample['mask'].reshape(384, 480, 1)
        sample = self.preprocessing(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']
        return image, mask, torch.tensor(self.label[i])

    def __len__(self):
        return len(self.masks_fps)



propro = get_preprocessing(
    get_preprocessing_fn('resnet101', pretrained='imagenet'))
train_dataset = Dataset(
    im_train_dir,
    el_train_dir,  #grab_train_dir, 
    el_path,  #grab_path
    train_label,
    augmentation=get_training_augmentation(),
    preprocessing=propro)
val_dataset = Dataset(
    im_val_dir,
    el_val_dir,  #grab_val_dir, 
    el_path,  #grab_path,
    val_label,
    augmentation=get_training_augmentation(),
    preprocessing=propro)
test_dataset = Dataset(im_test_dir,
model.requires_grad_(False)
model.eval()
models.append(model)

# gray
model = smp.Unet(base_model_name,
                 encoder_weights=encoder_weights,
                 in_channels=3)
model.load_state_dict(torch.load(weight_list[3], map_location='cpu'))
model.to(device)
model.requires_grad_(False)
model.eval()
models.append(model)

batch_size = 16
preprocess_input = get_preprocessing_fn(base_model_name,
                                        pretrained=encoder_weights)

# make directory
image_save_path = './predicted_image_for_step2_lrde_%d' % opt.fold_num
os.makedirs(image_save_path, exist_ok=True)

train_image_save_path = os.path.join(image_save_path, 'train')
os.makedirs(train_image_save_path, exist_ok=True)

test_image_save_path = os.path.join(image_save_path, 'test')
os.makedirs(test_image_save_path, exist_ok=True)

# patch directory
patch_save_path = os.path.join(train_image_save_path, 'patch')
os.makedirs(patch_save_path, exist_ok=True)
def unet_train(epochs, gpu, base_model_name, encoder_weights, generator_lr,
               discriminator_lr, lambda_bce, batch_size, image_train_dir,
               mask_train_dir, original_dir, fold_num, fold_total):
    # rgb , preprocess input
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    train_data_set = Dataset_Return_One(image_train_dir,
                                        mask_train_dir,
                                        os.path.join(original_dir, 'image'),
                                        base_model_name,
                                        encoder_weights,
                                        is_train=True,
                                        fold_num=fold_num,
                                        fold_total=fold_total)

    train_loader = DataLoader(train_data_set,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True)

    weight_path = ('./step2_resize_label_%d_' %
                   fold_num) + base_model_name + '_' + str(
                       int(lambda_bce)) + '_' + str(generator_lr)
    image_save_path = weight_path + '/images'
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(image_save_path, exist_ok=True)

    device = torch.device("cuda:%s" % gpu)

    # step2 resize unet
    model = smp.Unet(base_model_name,
                     encoder_weights=encoder_weights,
                     in_channels=3)
    model.to(device)
    optimizer_generator = optim.Adam(model.parameters(),
                                     lr=generator_lr,
                                     betas=(0.5, 0.999))

    discriminator = Discriminator(in_channels=4)
    discriminator.to(device)
    optimizer_discriminator = optim.Adam(discriminator.parameters(),
                                         lr=discriminator_lr,
                                         betas=(0.5, 0.999))

    criterion = nn.BCEWithLogitsLoss()

    image_test_list = []
    test_image_paths = os.listdir(os.path.join(original_dir, 'image'))
    test_image_paths = sorted(test_image_paths)
    for test_image_path in test_image_paths:
        if not check_is_image(test_image_path):
            print('not image', test_image_path)
            continue
        image_test_list.append(
            (os.path.join(original_dir, 'image', test_image_path),
             os.path.join(original_dir, 'mask', test_image_path)))
    image_test_list = image_test_list[fold_num::fold_total]
    print('test len:', len(image_test_list))

    preprocess_input = get_preprocessing_fn(base_model_name,
                                            pretrained=encoder_weights)

    value = int(256 * 0.5)
    lambda_gp = 10.0
    reshape = (512, 512)
    skip_resize_ratio = 6
    skip_max_length = 512
    padding_resize_ratio = 4
    kernel = np.ones((5, 5), np.uint8)

    best_fmeasure = 0.0
    epoch_start_time = time.time()
    for epoch in range(epochs):

        # train
        model.train()
        for idx, (images, masks) in enumerate(train_loader):
            images = images.to(device)
            masks = masks.to(device)
            masks_pred = model(images)

            # discriminator
            discriminator.requires_grad_(True)
            # Fake
            fake_AB = torch.cat((images, masks_pred), 1).detach()
            pred_fake = discriminator(fake_AB)

            # Real
            real_AB = torch.cat((images, masks), 1)
            pred_real = discriminator(real_AB)

            gradient_penalty = compute_gradient_penalty(
                discriminator, real_AB, fake_AB, device)
            discriminator_loss = -torch.mean(pred_real) + torch.mean(
                pred_fake) + lambda_gp * gradient_penalty

            optimizer_discriminator.zero_grad()
            discriminator_loss.backward()
            optimizer_discriminator.step()

            if idx % 5 == 0:
                discriminator.requires_grad_(False)

                # generator
                fake_AB = torch.cat((images, masks_pred), 1)
                pred_fake = discriminator(fake_AB)
                generator_loss = -torch.mean(pred_fake)
                bce_loss = criterion(masks_pred, masks)
                total_loss = generator_loss + bce_loss * lambda_bce

                optimizer_generator.zero_grad()
                total_loss.backward()
                optimizer_generator.step()

            if idx % 100 == 0:
                print(
                    'train step[%d/%d] discriminator loss: %.5f, total loss: %.5f, generator loss: %.5f, bce loss: %.5f, time: %.2f'
                    % (idx, len(train_loader), discriminator_loss.item(),
                       total_loss.item(), generator_loss.item(),
                       bce_loss.item(), time.time() - epoch_start_time))

            if epoch % 10 == 0 and idx % 100 == 0:
                rand_idx_start = randrange(masks.size(0) - 3)
                rand_idx_end = rand_idx_start + 3
                test_masks_pred = torch.sigmoid(
                    masks_pred[rand_idx_start:rand_idx_end]).detach().cpu()
                test_masks_pred = test_masks_pred.permute(
                    0, 2, 3, 1).numpy().astype(np.float32)
                test_masks_pred = np.squeeze(test_masks_pred, axis=-1)

                test_masks = masks[rand_idx_start:rand_idx_end].permute(
                    0, 2, 3, 1).cpu().numpy().astype(np.float32)
                test_masks = np.squeeze(test_masks, axis=-1)

                test_images = images[rand_idx_start:rand_idx_end].permute(
                    0, 2, 3, 1).cpu().numpy()
                test_images = test_images * imagenet_std + imagenet_mean
                test_images = np.maximum(test_images, 0.0)
                test_images = np.minimum(test_images, 1.0)
                sample_images(epoch, idx, test_images, test_masks,
                              test_masks_pred, image_save_path)
            # break

        # eval
        model.eval()
        total_fmeasure = 0.0
        total_image_number = 0
        random_number = randrange(len(image_test_list))
        for eval_idx, (image_test, mask_test) in enumerate(image_test_list):
            image = cv2.imread(image_test)
            h, w, _ = image.shape
            min_length = min(h, w)
            max_length = max(h, w)

            # skip global prediction for extreme aspect ratios or small images
            if min_length * skip_resize_ratio < max_length or max_length < skip_max_length:
                continue

            image_name = image_test.split('/')[-1].split('.')[0]

            gt_mask = cv2.imread(mask_test, cv2.IMREAD_GRAYSCALE)

            if min_length * padding_resize_ratio < max_length:
                image, _ = image_padding(image)
                gt_mask, _ = image_padding(gt_mask, is_mask=True)

            image = cv2.resize(image,
                               dsize=reshape,
                               interpolation=cv2.INTER_NEAREST)
            gt_mask = cv2.resize(gt_mask,
                                 dsize=reshape,
                                 interpolation=cv2.INTER_NEAREST)
            gt_mask = cv2.erode(gt_mask, kernel, iterations=1)

            image = preprocess_input(image, input_space="BGR")
            image = np.expand_dims(image, axis=0)
            with torch.no_grad():
                image = torch.from_numpy(image).permute(0, 3, 1,
                                                        2).float().to(device)
                pred = torch.sigmoid(model(image)).cpu()

            out_img = pred[0].permute(1, 2, 0).numpy() * 255
            out_img = out_img.astype(np.uint8)
            out_img[out_img > value] = 255
            out_img[out_img <= value] = 0

            # if random_number == 0:
            #     cv2.imwrite('%s/%d_%s.png' % (image_save_path, epoch, image_name), out_img)

            gt_mask = np.expand_dims(gt_mask, axis=-1)
            # f_measure
            # background 1, text 0
            gt_mask[gt_mask > 0] = 1
            out_img[out_img > 0] = 1

            # true positive
            tp = np.zeros(gt_mask.shape, np.uint8)
            tp[(out_img == 0) & (gt_mask == 0)] = 1
            numtp = tp.sum()

            # false positive
            fp = np.zeros(gt_mask.shape, np.uint8)
            fp[(out_img == 0) & (gt_mask == 1)] = 1
            numfp = fp.sum()

            # false negative
            fn = np.zeros(gt_mask.shape, np.uint8)
            fn[(out_img == 1) & (gt_mask == 0)] = 1
            numfn = fn.sum()

            precision = numtp / float(numtp + numfp)
            recall = numtp / float(numtp + numfn)
            fmeasure = 100. * (2. * recall * precision) / (recall + precision)  # percent

            total_fmeasure += fmeasure
            total_image_number += 1
            # break
        total_fmeasure /= total_image_number

        if best_fmeasure < total_fmeasure:
            best_fmeasure = total_fmeasure

        print('epoch[%d/%d] fmeasure: %.4f, best_fmeasure: %.4f, time: %.2f' %
              (epoch + 1, epochs, total_fmeasure, best_fmeasure,
               time.time() - epoch_start_time))
        print()

    torch.save(
        model.state_dict(),
        weight_path + '/unet_global_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
    torch.save(
        discriminator.state_dict(),
        weight_path + '/dis_global_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)
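
The to_tensor helper referenced above is not defined in this snippet; in the segmentation_models_pytorch examples it is typically the channel-first cast below, reproduced here as an assumption:

def to_tensor(x, **kwargs):
    # HWC image/mask -> CHW float32, the layout PyTorch expects
    return x.transpose(2, 0, 1).astype('float32')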


class WSOLDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, data, label, grabcut, transform=None):
        """
        Args:
            data (array-like): Input images.
            label (array-like): Image-level class labels.
            grabcut (array-like): GrabCut masks used as segmentation targets.
            transform (callable, optional): Optional transform to be applied
                on a sample (unused; augmentation happens in __getitem__).
        """
        self.grabcut = grabcut
        self.data = data
        self.label = label
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = augmentation(image=self.data[idx], mask=self.grabcut[idx])
        sample1 = preprocessing(image=sample['image'], mask=sample['mask'].reshape(384, 480, 1))
        out = {'image': sample1['image'], 'mask': sample1['mask'], 'label': torch.tensor(self.label[idx])}
        return out

preprocess_input = get_preprocessing_fn('resnet101', pretrained='imagenet')
augmentation = get_validation_augmentation()
preprocessing = get_preprocessing(preprocess_input)

train_dataset = WSOLDataset(data=train_raw,label=train_labels,grabcut=train_ell)
val_dataset = WSOLDataset(data=val_raw,label=val_labels,grabcut=val_ell)
test_dataset = WSOLDataset(data=test_raw,label=test_labels,grabcut=test_grab)


BATCH_SIZE = 8                                            
train_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=BATCH_SIZE, shuffle=True,
                                             num_workers=1, drop_last=True)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=BATCH_SIZE, shuffle=True,
                                             num_workers=1,drop_last=True)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=BATCH_SIZE, shuffle=False,
                                             num_workers=1,drop_last=True)

def get_dice_iou(pred, mas, iou, dice):
    # iou, dice = [], []
    for i in range(len(pred)):
        pre = pred[i,132:252,180:300].numpy()
        gt = mas[i,132:252,180:300].numpy()
        inter = np.logical_and(pre==1,gt==1)
        union  = np.logical_or(pre==1,gt==1)
        iou.append(np.sum(inter)/np.sum(union))
        dice.append(np.sum(pre[gt==1])*2.0 / (np.sum(pre) + np.sum(gt)))
    # return iou, dice
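
A toy sanity check of the IoU and Dice formulas used in get_dice_iou (the real call crops a [132:252, 180:300] window from batched tensors; this sketch uses small arrays):

import numpy as np
pre = np.array([[1, 1, 0],
                [0, 1, 0]])
gt = np.array([[1, 0, 0],
               [0, 1, 1]])
inter = np.logical_and(pre == 1, gt == 1).sum()            # 2
union = np.logical_or(pre == 1, gt == 1).sum()             # 4
print(inter / union)                                       # IoU = 0.5
print(2.0 * pre[gt == 1].sum() / (pre.sum() + gt.sum()))   # Dice = 0.666...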

def train(epoch):
    for batch_idx, sam in enumerate(train_loader):
        data = sam['image'].float().cuda()
        target = sam['mask'].float().cuda()

        data_upsampled = data
        target_upsampled = target

        # go pair by pair
        sub_batch_id = 0
        for ind in [i*COSEG_BATCH_SIZE for i in range(int(data_upsampled.shape[0]/COSEG_BATCH_SIZE))]:
            b1 = (ind,ind+COSEG_BATCH_SIZE)
            if ind+COSEG_BATCH_SIZE >= data_upsampled.shape[0]:
                b2 = (0,COSEG_BATCH_SIZE)
            else:
                b2 = (ind+COSEG_BATCH_SIZE,ind+COSEG_BATCH_SIZE*2)

            output1, output2 = new_coseg_model(data_upsampled[b1[0]:b1[1]], data_upsampled[b2[0]:b2[1]])

            loss1 = criterion(sig(output1.squeeze()), target_upsampled[b1[0]:b1[1]].squeeze(1))
            loss2 = criterion(sig(output2.squeeze()), target_upsampled[b2[0]:b2[1]].squeeze(1))

            loss = loss1 + loss2
            optimizer.zero_grad()  # clear gradients accumulated by the previous sub-batch
            loss.backward()

            torch.nn.utils.clip_grad_norm_(new_coseg_model.parameters(), 0.05)
            optimizer.step()
            
            out1 = sig(output1.squeeze()).cpu()
            out2 = sig(output2.squeeze()).cpu()

            pred1 = torch.empty(out1.shape)
            pred2 = torch.empty(out2.shape)
            pred1[out1>=0.5]=1
            pred1[out1<0.5]=0
            pred2[out2>=0.5]=1
            pred2[out2<0.5]=0
            
            iou, dice = [], []    
            get_dice_iou(pred1, target_upsampled[b1[0]:b1[1]].squeeze(1).long().cpu(), iou, dice)
            get_dice_iou(pred2, target_upsampled[b2[0]:b2[1]].squeeze(1).long().cpu(), iou, dice)

            avg_iou=np.mean(iou)
            avg_dice=np.mean(dice)
    
            BATCH_ID = batch_idx * N_SUB_BATCHES + sub_batch_id
            print("Epoch %s: Batch %i/%i Loss %.2f iou %.2f dice %.2f" % (epoch, BATCH_ID,N_BATCHES, loss.item(), avg_iou, avg_dice))

            logs['loss'] = loss.item()
            logs['iou'] = avg_iou
            logs['dice'] = avg_dice

            #liveloss.update(logs)
            #liveloss.send()

            sub_batch_id += 1  # advance the sub-batch counter used in BATCH_ID
            
def valid(epoch):
    val_loss=[]
    with torch.no_grad():
        for batch_idx, sam in enumerate(val_loader):
            data = sam['image'].float().cuda()
            target = sam['mask'].float().cuda()

            data_upsampled = data
            target_upsampled = target

            # go pair by pair
            sub_batch_id = 0
            for ind in [i*COSEG_BATCH_SIZE for i in range(int(data_upsampled.shape[0]/COSEG_BATCH_SIZE))]:
                b1 = (ind,ind+COSEG_BATCH_SIZE)
                if ind+COSEG_BATCH_SIZE >= data_upsampled.shape[0]:
                    b2 = (0,COSEG_BATCH_SIZE)
                else:
                    b2 = (ind+COSEG_BATCH_SIZE,ind+COSEG_BATCH_SIZE*2)

                output1, output2 = new_coseg_model(data_upsampled[b1[0]:b1[1]], data_upsampled[b2[0]:b2[1]])

                loss1 = criterion(sig(output1.squeeze()), target_upsampled[b1[0]:b1[1]].squeeze(1))
                loss2 = criterion(sig(output2.squeeze()), target_upsampled[b2[0]:b2[1]].squeeze(1))

                loss = loss1 + loss2
                val_loss.append(loss.item())
                
    avg_loss=np.mean(val_loss)
    print('Val loss for epoch %s is %.4f' % (epoch, avg_loss))
    return avg_loss

criterion = nn.BCELoss()
criterion.cuda()

new_coseg_model = model1().cuda()

sig = nn.Sigmoid()
optimizer = Adam(new_coseg_model.parameters(), lr=1e-5)

COSEG_BATCH_SIZE = 4
N_LARGE_BATCHES = len(train_loader)
BATCH_SIZE = 8
N_SUB_BATCHES = BATCH_SIZE // COSEG_BATCH_SIZE

N_BATCHES = N_SUB_BATCHES * N_LARGE_BATCHES

#liveloss = PlotLosses()
logs = {}

min_val_loss=float('inf')
max_patient=5
m=0
for epoch in range(50):
    print('Epoch ', epoch)
    new_coseg_model.train()
    train(epoch)
    
    new_coseg_model.eval()
    new_val_loss = valid(epoch)
    if new_val_loss < min_val_loss:
        min_val_loss = new_val_loss
        torch.save({'model_dict': new_coseg_model.state_dict()},acoseg_path)
        print('model saved')
        m=0
    else:
        m+=1
        if m >= max_patient:
            break
Example no. 11
import cv2
from albumentations import (Compose, GridDistortion, HorizontalFlip,
                            IAAPiecewiseAffine, Lambda, OneOf,
                            OpticalDistortion, RandomRotate90,
                            ShiftScaleRotate, VerticalFlip)
from segmentation_models_pytorch.encoders import get_preprocessing_fn

import torch
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader

from patho.utils.utils import set_all_seeds, check_lr
from patho.cfg.config import path_cfg, img_proc_cfg
from patho.model.model_torch import DiceLoss, HuBMAPModel, get_dice_coeff
from patho.datagen.datagen import HuBMAPDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENCODER_NAME = 'se_resnext50_32x4d'

preprocessing_fn = Lambda(image=get_preprocessing_fn(encoder_name=ENCODER_NAME,
                                                     pretrained='imagenet'))

# https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter
transforms = Compose([
    HorizontalFlip(),
    VerticalFlip(),
    RandomRotate90(),
    ShiftScaleRotate(shift_limit=0.0625,
                     scale_limit=0.2,
                     rotate_limit=15,
                     p=0.9,
                     border_mode=cv2.BORDER_REFLECT),
    OneOf([
        OpticalDistortion(p=0.3),
        GridDistortion(p=.1),
        IAAPiecewiseAffine(p=0.3),

new_coseg_model = model1().cuda()
checkpoint = torch.load(acoseg_path)
new_coseg_model.load_state_dict(checkpoint['model_dict'])
print('loaded pre-trained model')

sig = nn.Sigmoid()

criterion = nn.BCELoss()
test_loss=[]
mas = torch.tensor([])
res = torch.tensor([])

new_coseg_model.eval()
with torch.no_grad():
    for batch_idx, sam in enumerate(test_loader):
        #print(batch_idx)
        data = sam['image'].float().cuda()
        target = sam['mask'].float().cuda()
        
        output1, output2 = new_coseg_model(data[0:4], data[4:8])
        output1 = sig(output1)
        output2 = sig(output2)
               
        loss = criterion(output1.squeeze(), target[0:4].squeeze(1)) + criterion(output2.squeeze(), target[4:8].squeeze(1))
        test_loss.append(loss.item())

        mas = torch.cat((mas,target.cpu()),dim=0)
        res = torch.cat((res,output1.cpu()),dim=0)
        res = torch.cat((res,output2.cpu()),dim=0)
    avg_test_loss = np.mean(test_loss)
    print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))
mas = np.array(mas)
res = np.array(res)
pred = deepcopy(res) 
pred[res>=0.5]=1 #check threshold
pred[res<0.5]=0
cut_pred = []
cut_mas = []
for i in range(pred.shape[0]):
    cut_pred.append(np.uint8(pred[i,0,132:252,180:300]))
    cut_mas.append(np.uint8(mas[i,0,132:252,180:300]))
    
metrics(cut_mas,cut_pred)
Example no. 13
        columns=["Image_Label", "EncodedPixels"],
        index=False,
    )


if __name__ == "__main__":
    model_name = sys.argv[1]
    test_data_path = sys.argv[2]
    class_params_path = sys.argv[3]
    output_path = sys.argv[4]

    df_test = pd.read_csv(os.path.join(test_data_path, "sample_submission.csv"))
    test_ids = (
        df_test["Image_Label"].apply(lambda x: x.split("_")[0]).drop_duplicates().values
    )
    preprocess_fn = get_preprocessing_fn(model_name, "imagenet")
    test_dataset = CloudDataset(
        df=df_test,
        path=test_data_path,
        img_ids=test_ids,
        image_folder="test_images",
        transforms=get_transforms("valid"),
        preprocessing_fn=preprocess_fn,
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=32, shuffle=False, num_workers=4,
    )
    model = Unet(model_name, classes=4, activation=None)
    class_params = np.load(class_params_path, allow_pickle=True).item()
    infer(model, test_loader, class_params, output_path)
def unet_train(epochs, gpu, base_model_name, encoder_weights, generator_lr, discriminator_lr, lambda_bce,
                batch_size, image_train_dir, mask_train_dir, image_test_dir, original_dir, fold_num, fold_total):
    image_train_dir = image_train_dir.replace('%d', str(fold_num))
    mask_train_dir = mask_train_dir.replace('%d', str(fold_num))
    image_test_dir = image_test_dir.replace('%d', str(fold_num))
    
    weight_path = ('./step2_label%d_' % fold_num) + base_model_name + '_' + str(int(lambda_bce)) + '_' + str(generator_lr)
    image_save_path = weight_path + '/images'
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(image_save_path, exist_ok=True)

    # rgb , preprocess input
    imagenet_mean = np.array( [0.485, 0.456, 0.406] )
    imagenet_std = np.array( [0.229, 0.224, 0.225] )

    # patch data loader
    patch_train_data_set = Dataset_Return_One(image_train_dir, mask_train_dir, os.path.join(original_dir, 'image'),
                                                base_model_name, encoder_weights, is_train=True,
                                                fold_num=fold_num, fold_total=fold_total)

    patch_train_loader = DataLoader(patch_train_data_set, batch_size=batch_size, num_workers=4, shuffle=True)

    device = torch.device("cuda:%s" % gpu)

    # step2 patch unet
    patch_model = smp.Unet(base_model_name, encoder_weights=encoder_weights, in_channels=3)
    patch_model.to(device)
    optimizer_patch_generator = optim.Adam(patch_model.parameters(), lr=generator_lr, betas=(0.5, 0.999))

    discriminator = Discriminator(in_channels=4)
    discriminator.to(device)
    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=discriminator_lr, betas=(0.5, 0.999))

    criterion = nn.BCEWithLogitsLoss()

    image_test_pathes = os.listdir(image_test_dir)
    image_test_list = []
    for image_test_path in image_test_pathes:
        if not check_is_image(image_test_path):
            print(image_test_path, 'not image')
            continue
        image_test_list.append( (os.path.join(image_test_dir, image_test_path), 
                                    os.path.join(original_dir, 'mask', image_test_path)) )
    print('test len:', len(image_test_list))

    preprocess_input = get_preprocessing_fn(base_model_name, pretrained=encoder_weights)

    lambda_gp = 10.0
    threshold_value = int(256 * 0.5)

    patch_best_fmeasure = 0.0
    epoch_start_time = time.time()
    for epoch in range(epochs):

        # train
        patch_model.train()
        for idx, (images, masks) in enumerate(patch_train_loader):
            # train discriminator with patch
            images = images.to(device)
            masks = masks.to(device)

            masks_pred = patch_model(images)

            # discriminator
            discriminator.requires_grad_(True)
            # Fake
            fake_AB = torch.cat((images, masks_pred), 1).detach()
            pred_fake = discriminator(fake_AB)

            # Real
            real_AB = torch.cat((images, masks), 1)
            pred_real = discriminator(real_AB)

            gradient_penalty = compute_gradient_penalty(discriminator, real_AB, fake_AB, device)
            discriminator_loss = -torch.mean(pred_real) + torch.mean(pred_fake) + lambda_gp * gradient_penalty

            optimizer_discriminator.zero_grad()
            discriminator_loss.backward()
            optimizer_discriminator.step()

            if idx % 5 == 0:
                discriminator.requires_grad_(False)

                # generator
                fake_AB = torch.cat((images, masks_pred), 1)
                pred_fake = discriminator(fake_AB)
                generator_loss = -torch.mean(pred_fake)
                bce_loss = criterion(masks_pred, masks)
                total_loss = generator_loss + bce_loss * lambda_bce
                
                optimizer_patch_generator.zero_grad()
                total_loss.backward()
                optimizer_patch_generator.step()

            if idx % 2000 == 0:
                print('train step[%d/%d] patch discriminator loss: %.5f, total loss: %.5f, generator loss: %.5f, bce loss: %.5f, time: %.2f' % 
                            (idx, len(patch_train_loader), discriminator_loss.item(), total_loss.item(), generator_loss.item(), bce_loss.item(), time.time() - epoch_start_time))

                rand_idx_start = randrange(masks.size(0) - 3)
                rand_idx_end = rand_idx_start + 3
                test_masks_pred = torch.sigmoid(masks_pred[rand_idx_start:rand_idx_end]).detach().cpu()
                test_masks_pred = test_masks_pred.permute(0, 2, 3, 1).numpy().astype(np.float32)
                test_masks_pred = np.squeeze(test_masks_pred, axis=-1)

                test_masks = masks[rand_idx_start:rand_idx_end].permute(0, 2, 3, 1).cpu().numpy().astype(np.float32)
                test_masks = np.squeeze(test_masks, axis=-1)

                test_images = images[rand_idx_start:rand_idx_end].permute(0, 2, 3, 1).cpu().numpy()
                test_images = test_images * imagenet_std + imagenet_mean
                test_images = np.maximum(test_images, 0.0)
                test_images = np.minimum(test_images, 1.0)
                sample_images(epoch, idx, test_images, test_masks, test_masks_pred, image_save_path)
            # break
        
        # eval
        patch_model.eval()
        print('eval patch')
        total_fmeasure = 0.0
        total_image_number = 0
        # random_number = randrange(len(patch_image_test_list))
        for eval_idx, (image_test, mask_test) in enumerate(image_test_list):
            image = cv2.imread(image_test)
            h, w, _ = image.shape
            image_name = image_test.split('/')[-1].split('.')[0]

            gt_mask = cv2.imread(mask_test, cv2.IMREAD_GRAYSCALE)
            gt_mask = np.expand_dims(gt_mask, axis=-1)

            image_patches, poslist = get_image_patch(image, 256, 256, overlap=0.5, is_mask=False)
            color_patches = []
            for patch in image_patches:
                color_patches.append(preprocess_input(patch.astype(np.float32), input_space="BGR"))

            step = 0
            preds = []
            with torch.no_grad():
                while step < len(image_patches):
                    ps = step
                    pe = step + batch_size
                    if pe >= len(image_patches):
                        pe = len(image_patches)

                    images_global = torch.from_numpy(np.array(color_patches[ps:pe])).permute(0, 3, 1, 2).float().to(device)
                    preds.extend( torch.sigmoid(patch_model(images_global)).cpu() )
                    step += batch_size

            # handling overlap
            out_img = np.ones((h, w, 1)) * 255
            for i in range(len(image_patches)):
                patch = preds[i].permute(1, 2, 0).numpy() * 255

                start_h, start_w, end_h, end_w, h_shift, w_shift = poslist[i]
                h_cut = end_h - start_h
                w_cut = end_w - start_w

                tmp = np.minimum(out_img[start_h:end_h, start_w:end_w], patch[h_shift:h_shift+h_cut, w_shift:w_shift+w_cut])
                out_img[start_h:end_h, start_w:end_w] = tmp

            out_img = out_img.astype(np.uint8)
            out_img[out_img > threshold_value] = 255
            out_img[out_img <= threshold_value] = 0

            # if random_number == eval_idx:
            #     cv2.imwrite('%s/patch_%d_%s.png' % (image_save_path, epoch, image_name), out_img)

            # f_measure
            # background 1, text 0
            gt_mask[gt_mask > 0] = 1
            out_img[out_img > 0] = 1

            # true positive
            tp = np.zeros(gt_mask.shape, np.uint8)
            tp[(out_img==0) & (gt_mask==0)] = 1
            numtp = tp.sum()

            # false positive
            fp = np.zeros(gt_mask.shape, np.uint8)
            fp[(out_img==0) & (gt_mask==1)] = 1
            numfp = fp.sum()

            # false negative
            fn = np.zeros(gt_mask.shape, np.uint8)
            fn[(out_img==1) & (gt_mask==0)] = 1
            numfn = fn.sum()

            precision = (numtp) / float(numtp + numfp)
            recall = (numtp) / float(numtp + numfn)
            fmeasure = 100. * (2. * recall * precision) / (recall + precision) # percent

            total_fmeasure += fmeasure
            total_image_number += 1
            # break
        total_fmeasure /= total_image_number

        if patch_best_fmeasure < total_fmeasure:
            patch_best_fmeasure = total_fmeasure
        print('epoch[%d/%d] patch fmeasure: %.4f, best_fmeasure: %.4f, time: %.2f' 
                    % (epoch + 1, epochs, total_fmeasure, patch_best_fmeasure, time.time() - epoch_start_time))
        print()

    torch.save(patch_model.state_dict(), weight_path + '/unet_patch_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
    torch.save(discriminator.state_dict(), weight_path + '/dis_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
Example no. 15
model.requires_grad_(False)
model.eval()
models.append(model)

# gray
model = smp.Unet(base_model_name,
                 encoder_weights=encoder_weights,
                 in_channels=3)
model.load_state_dict(torch.load(weight_list[3], map_location='cpu'))
model.to(device)
model.requires_grad_(False)
model.eval()
models.append(model)

batch_size = 16
preprocess_input = get_preprocessing_fn(base_model_name, pretrained='imagenet')

# make directory
image_save_path = './predicted_image_for_step2_dibco'
os.makedirs(image_save_path, exist_ok=True)

train_image_save_path = os.path.join(image_save_path, 'train')
os.makedirs(train_image_save_path, exist_ok=True)

test_image_save_path = os.path.join(image_save_path, 'test')
os.makedirs(test_image_save_path, exist_ok=True)

# patch directory
patch_save_path = os.path.join(train_image_save_path, 'patch')
os.makedirs(patch_save_path, exist_ok=True)
    def get_datasets(self, **kwargs):
        path = kwargs.get("path", None)
        df_train_name = kwargs.get("df_train_name", None)
        df_pl_name = kwargs.get("df_pl_name", None)
        image_folder = kwargs.get("image_folder", None)
        encoder_name = kwargs.get("model_name", None)
        test_mode = kwargs.get("test_mode", None)
        type = kwargs.get("type", None)
        height = kwargs.get("height", None)
        width = kwargs.get("width", None)

        if type in ("train", "postprocess"):  # postprocess also needs df_train/valid_ids below
            df_train = pd.read_csv(os.path.join(path, df_train_name))
            if df_pl_name is not None:
                df_pl = pd.read_csv(os.path.join(path, df_pl_name))
                df_train = pd.concat([df_train, df_pl])
                print(
                    f"Pseudo-labels named {df_pl_name} {len(df_pl) / 4} added to train df"
                )

            if test_mode:
                df_train = df_train[:150]

            df_train["label"] = df_train["Image_Label"].apply(
                lambda x: x.split("_")[1])
            df_train["im_id"] = df_train["Image_Label"].apply(
                lambda x: x.split("_")[0])

            id_mask_count = (
                df_train.loc[~df_train["EncodedPixels"].isnull(),
                             "Image_Label"].apply(lambda x: x.split("_")[0]).
                value_counts().reset_index().rename(columns={
                    "index": "img_id",
                    "Image_Label": "count"
                }).sort_values(["count", "img_id"]))
            assert len(id_mask_count["img_id"].values) == len(
                id_mask_count["img_id"].unique())
            train_ids, valid_ids = train_test_split(
                id_mask_count["img_id"].values,
                random_state=42,
                stratify=id_mask_count["count"],
                test_size=0.1,
            )

        df_test = pd.read_csv(os.path.join(path, "sample_submission.csv"))
        df_test["label"] = df_test["Image_Label"].apply(
            lambda x: x.split("_")[1])
        df_test["im_id"] = df_test["Image_Label"].apply(
            lambda x: x.split("_")[0])
        test_ids = (df_test["Image_Label"].apply(
            lambda x: x.split("_")[0]).drop_duplicates().values)

        preprocess_fn = get_preprocessing_fn(encoder_name,
                                             pretrained="imagenet")

        if type != "test":
            train_dataset = CloudDataset(
                df=df_train,
                path=path,
                img_ids=train_ids,
                image_folder=image_folder,
                transforms=get_transforms("train"),
                preprocessing_fn=preprocess_fn,
                height=height,
                width=width,
            )

            valid_dataset = CloudDataset(
                df=df_train,
                path=path,
                img_ids=valid_ids,
                image_folder=image_folder,
                transforms=get_transforms("valid"),
                preprocessing_fn=preprocess_fn,
                height=height,
                width=width,
            )

        test_dataset = CloudDataset(
            df=df_test,
            path=path,
            img_ids=test_ids,
            image_folder="test_images",
            transforms=get_transforms("valid"),
            preprocessing_fn=preprocess_fn,
            height=height,
            width=width,
        )

        datasets = collections.OrderedDict()
        if type == "train":
            datasets["train"] = train_dataset
            datasets["valid"] = valid_dataset
        elif type == "postprocess":
            datasets["infer"] = valid_dataset
        elif type == "test":
            datasets["infer"] = test_dataset

        return datasets
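
A hedged usage sketch; the kwargs mirror the keys read at the top of get_datasets, with illustrative values:

# Assumed call site inside the experiment class that defines get_datasets.
datasets = self.get_datasets(path='data/clouds',          # illustrative dataset root
                             df_train_name='train.csv',
                             df_pl_name=None,             # no pseudo-labels
                             image_folder='train_images',
                             model_name='resnet34',
                             test_mode=False,
                             type='train',
                             height=320,
                             width=640)
train_dataset, valid_dataset = datasets['train'], datasets['valid']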