def main(argv=None):
    """Cut the LRDE denoise images and masks into overlapping patches.

    Every image in ``image_dir`` (with an identically named mask in
    ``mask_dir``) is split into 256x256 patches with 30% overlap.  Each
    patch is written twice: as-is (suffix ``h0``) and flipped with
    ``np.flipud`` (suffix ``h1``) as a simple augmentation.
    """
    image_dir = '/mnt/nas/data/denoise/LRDE/image/'
    mask_dir = '/mnt/nas/data/denoise/LRDE/mask/'

    overlap = 30. / 100.  # 30. / 100. -> 317,750 patches in total
    imgh = 256
    imgw = 256

    image_save_dir = '/data/denoise/LRDE/image_patches'
    mask_save_dir = '/data/denoise/LRDE/mask_patches'
    os.makedirs(image_save_dir, exist_ok=True)
    os.makedirs(mask_save_dir, exist_ok=True)

    for image_path in os.listdir(image_dir):
        if not check_is_image(image_path):
            print('not image', image_path)
            continue

        image_name = image_path.split('.')[0]

        mask = cv2.imread(os.path.join(mask_dir, image_path),
                          cv2.IMREAD_GRAYSCALE)
        image = cv2.imread(os.path.join(image_dir, image_path))
        if image is None or mask is None:
            # cv2.imread returns None (no exception) when a file is
            # missing or unreadable; skip instead of crashing later.
            print(image_path, 'image or mask could not be read')
            continue

        print('processing the image:', image_path)

        image_patches, _ = get_image_patch(image,
                                           imgh,
                                           imgw,
                                           overlap=overlap,
                                           is_mask=False)
        mask_patches, _ = get_image_patch(mask,
                                          imgh,
                                          imgw,
                                          overlap=overlap,
                                          is_mask=True)

        print('get patches: %d' % len(image_patches))
        for idx, (img_color, mask_gray) in enumerate(
                zip(image_patches, mask_patches)):
            # original orientation
            cv2.imwrite('%s/%s_i%dh0.png' % (image_save_dir, image_name, idx),
                        img_color)
            cv2.imwrite('%s/%s_i%dh0.png' % (mask_save_dir, image_name, idx),
                        mask_gray)

            # flipped along the horizontal axis
            cv2.imwrite('%s/%s_i%dh1.png' % (image_save_dir, image_name, idx),
                        np.flipud(img_color))
            cv2.imwrite('%s/%s_i%dh1.png' % (mask_save_dir, image_name, idx),
                        np.flipud(mask_gray))
def unet_train(epochs, gpu, base_model_name, encoder_weights, generator_lr,
               discriminator_lr, lambda_bce, threshold, batch_size,
               image_train_dir, mask_train_dir, original_dir, fold_num,
               fold_total):
    """Step-1 training: four per-channel U-Net generators plus one shared
    WGAN-GP discriminator for document binarization.

    One generator is trained per input channel (0=blue, 1=green, 2=red,
    3=gray; see ``channel_dict``).  All four share a single discriminator
    that scores concatenated (image, mask) pairs.  After each epoch every
    generator is evaluated on the held-out fold by patch-wise prediction,
    overlap merging, thresholding, and F-measure against the ground truth.

    Args:
        epochs: number of training epochs.
        gpu: CUDA device index, interpolated into "cuda:%s".
        base_model_name: segmentation_models_pytorch encoder name.
        encoder_weights: pretrained-weight identifier for the encoder.
        generator_lr: Adam learning rate for the four generators.
        discriminator_lr: Adam learning rate for the discriminator.
        lambda_bce: weight of the BCE term added to the adversarial loss.
        threshold: forwarded to Dataset_Return_Four and encoded in the
            checkpoint directory name.
        batch_size: batch size for both training and patched evaluation.
        image_train_dir: directory of training image patches.
        mask_train_dir: directory of training mask patches.
        original_dir: directory containing full-size 'image'/'mask' folders
            used for per-epoch evaluation.
        fold_num: index of the held-out cross-validation fold.
        fold_total: total number of folds (every fold_total-th image
            starting at fold_num is held out).

    Side effects: creates a checkpoint directory, writes sample images each
    epoch, and saves final generator/discriminator weights to disk.
    """
    # make save directory (name encodes fold, model and hyper-parameters)
    weight_path = (
        './step1_label%d_' % fold_num) + base_model_name + '_' + str(
            int(lambda_bce)) + '_' + str(generator_lr) + '_' + str(threshold)
    image_save_path = weight_path + '/images'
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(image_save_path, exist_ok=True)

    # rgb , preprocess input (ImageNet normalization constants, used to
    # de-normalize sample images before saving)
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    train_data_set = Dataset_Return_Four(image_train_dir,
                                         mask_train_dir,
                                         os.path.join(original_dir, 'image'),
                                         base_model_name,
                                         encoder_weights,
                                         threshold=threshold,
                                         is_train=True,
                                         fold_num=fold_num,
                                         fold_total=fold_total)

    train_loader = DataLoader(train_data_set,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True)

    device = torch.device("cuda:%s" % gpu)

    # build the held-out evaluation list: (image path, mask path) pairs,
    # keeping every fold_total-th image starting at fold_num
    image_test_list = []
    test_image_pathes = os.listdir(os.path.join(original_dir, 'image'))
    for test_image_path in test_image_pathes:
        if not check_is_image(test_image_path):
            print(test_image_path, 'not image')
            continue
        image_test_list.append(
            (os.path.join(original_dir, 'image', test_image_path),
             os.path.join(original_dir, 'mask', test_image_path)))
    image_test_list = image_test_list[fold_num::fold_total]
    print('test len:', len(image_test_list))

    preprocess_input = get_preprocessing_fn(base_model_name,
                                            pretrained=encoder_weights)

    # one U-Net generator (and its own Adam optimizer) per input channel
    models = []
    optimizers = []
    for channel in range(4):
        models.append(
            smp.Unet(base_model_name,
                     encoder_weights=encoder_weights,
                     in_channels=3))
        models[channel].to(device)
        optimizers.append(
            optim.Adam(models[channel].parameters(),
                       lr=generator_lr,
                       betas=(0.5, 0.999)))

    # single discriminator shared by all four generators; input is the
    # 3-channel image concatenated with the 1-channel mask (4 channels)
    discriminator = Discriminator(in_channels=4)
    discriminator.to(device)
    optimizer_discriminator = optim.Adam(discriminator.parameters(),
                                         lr=discriminator_lr,
                                         betas=(0.5, 0.999))

    criterion = nn.BCEWithLogitsLoss()

    channel_dict = {0: 'blue', 1: 'green', 2: 'red', 3: 'gray'}
    value = int(256 * 0.5)  # binarization threshold (128) for eval output
    best_fmeasure = 0.0
    lambda_gp = 10.0  # WGAN-GP gradient-penalty weight
    epoch_start_time = time.time()
    for epoch in range(epochs):

        # train: freeze all generators up front; each is unfrozen only
        # while its own channel is being updated inside the batch loop
        for channel in range(4):
            models[channel].train()
            models[channel].requires_grad_(False)

        for idx, (images, masks) in enumerate(train_loader):
            # for sample image
            test_masks_pred_list = []
            test_masks_list = []
            test_images_list = []
            for channel in range(4):
                images[channel] = images[channel].to(device)
                masks[channel] = masks[channel].to(device)

                models[channel].requires_grad_(True)
                masks_pred = models[channel](images[channel])

                # discriminator update (WGAN-GP critic step)
                discriminator.requires_grad_(True)
                # Fake: detach so generator gradients do not flow here
                fake_AB = torch.cat((images[channel], masks_pred), 1).detach()
                pred_fake = discriminator(fake_AB)

                # Real
                real_AB = torch.cat((images[channel], masks[channel]), 1)
                pred_real = discriminator(real_AB)

                gradient_penalty = compute_gradient_penalty(
                    discriminator, real_AB, fake_AB, device)
                discriminator_loss = -torch.mean(pred_real) + torch.mean(
                    pred_fake) + lambda_gp * gradient_penalty

                optimizer_discriminator.zero_grad()
                discriminator_loss.backward()
                optimizer_discriminator.step()

                discriminator.requires_grad_(False)
                # end discriminator

                # generator update: adversarial term + weighted BCE
                fake_AB = torch.cat((images[channel], masks_pred), 1)
                pred_fake = discriminator(fake_AB)
                generator_loss = -torch.mean(pred_fake)
                bce_loss = criterion(masks_pred, masks[channel])
                total_loss = generator_loss + bce_loss * lambda_bce

                optimizers[channel].zero_grad()
                total_loss.backward()
                optimizers[channel].step()

                # re-freeze this generator before moving to the next channel
                models[channel].requires_grad_(False)
                # end generator

                if idx % 2000 == 0:
                    print(
                        'channel %s train step[%d/%d] discriminator loss: %.5f, total loss: %.5f, generator loss: %.5f, bce loss: %.5f, time: %.2f'
                        % (channel_dict[channel], idx, len(train_loader),
                           discriminator_loss.item(), total_loss.item(),
                           generator_loss.item(), bce_loss.item(),
                           time.time() - epoch_start_time))

                    # for sample images: grab 2 consecutive random samples
                    # NOTE(review): randrange(size - 2) raises ValueError
                    # when the batch holds fewer than 3 samples (e.g. a
                    # small final batch) — confirm batch sizes stay >= 3
                    rand_idx_start = randrange(masks[channel].size(0) - 2)
                    rand_idx_end = rand_idx_start + 2
                    test_masks_pred = torch.sigmoid(
                        masks_pred[rand_idx_start:rand_idx_end]).detach().cpu(
                        )
                    test_masks_pred = test_masks_pred.permute(
                        0, 2, 3, 1).numpy().astype(np.float32)
                    test_masks_pred = np.squeeze(test_masks_pred, axis=-1)
                    test_masks_pred_list.extend(test_masks_pred)

                    test_masks = masks[
                        channel][rand_idx_start:rand_idx_end].permute(
                            0, 2, 3, 1).cpu().numpy().astype(np.float32)
                    test_masks = np.squeeze(test_masks, axis=-1)
                    test_masks_list.extend(test_masks)

                    # de-normalize images back to [0, 1] for visualization
                    test_images = images[channel][
                        rand_idx_start:rand_idx_end].permute(0, 2, 3,
                                                             1).cpu().numpy()
                    test_images = test_images * imagenet_std + imagenet_mean
                    test_images = np.maximum(test_images, 0.0)
                    test_images = np.minimum(test_images, 1.0)
                    test_images_list.extend(test_images)

            if idx % 2000 == 0:
                sample_images(epoch, idx, test_images_list, test_masks_list,
                              test_masks_pred_list, image_save_path)
            # break

        # eval: run each generator over the held-out full-size images
        for channel in range(4):
            models[channel].eval()

        total_fmeasure = 0.0
        total_image_number = 0
        for eval_idx, (image_test, mask_test) in enumerate(image_test_list):
            image = cv2.imread(image_test)
            h, w, _ = image.shape
            image_name = image_test.split('/')[-1].split('.')[0]
            # print('eval the image:', image_name)

            gt_mask = cv2.imread(mask_test, cv2.IMREAD_GRAYSCALE)
            gt_mask = np.expand_dims(gt_mask, axis=-1)
            image_patches, poslist = get_image_patch(image,
                                                     256,
                                                     256,
                                                     overlap=0.5,
                                                     is_mask=False)

            # random_number = randrange(10)
            for channel in range(4):
                # build the single-channel (or grayscale) input patches this
                # generator was trained on
                color_patches = []
                for patch in image_patches:
                    tmp = patch.astype(np.float32)
                    if channel != 3:
                        color_patches.append(
                            preprocess_input(tmp[:, :, channel:channel + 1]))
                    else:
                        color_patches.append(
                            preprocess_input(
                                np.expand_dims(cv2.cvtColor(
                                    tmp, cv2.COLOR_BGR2GRAY),
                                               axis=-1)))

                # batched inference over all patches
                step = 0
                preds = []
                with torch.no_grad():
                    while step < len(image_patches):
                        ps = step
                        pe = step + batch_size
                        if pe >= len(image_patches):
                            pe = len(image_patches)

                        target = torch.from_numpy(
                            np.array(color_patches[ps:pe])).permute(
                                0, 3, 1, 2).float()
                        pred = torch.sigmoid(models[channel](
                            target.to(device))).cpu()
                        preds.extend(pred)
                        step += batch_size

                # handling overlap: merge by per-pixel minimum so the
                # darkest (most text-like) prediction wins
                out_img = np.ones((h, w, 1)) * 255
                for i in range(len(image_patches)):
                    patch = preds[i].permute(1, 2, 0).numpy() * 255

                    start_h, start_w, end_h, end_w, h_shift, w_shift = poslist[
                        i]
                    h_cut = end_h - start_h
                    w_cut = end_w - start_w

                    tmp = np.minimum(
                        out_img[start_h:end_h, start_w:end_w],
                        patch[h_shift:h_shift + h_cut,
                              w_shift:w_shift + w_cut])
                    out_img[start_h:end_h, start_w:end_w] = tmp

                # binarize the merged prediction at 128
                out_img = out_img.astype(np.uint8)
                out_img[out_img > value] = 255
                out_img[out_img <= value] = 0

                # if random_number == 0:
                #     cv2.imwrite('%s/%d_%d_%s.png' % (image_save_path, epoch, channel, image_name), out_img)

                # f_measure
                # background 1, text 0
                gt_mask[gt_mask > 0] = 1
                out_img[out_img > 0] = 1

                # true positive (text predicted where text really is)
                tp = np.zeros(gt_mask.shape, np.uint8)
                tp[(out_img == 0) & (gt_mask == 0)] = 1
                numtp = tp.sum()

                # false positive
                fp = np.zeros(gt_mask.shape, np.uint8)
                fp[(out_img == 0) & (gt_mask == 1)] = 1
                numfp = fp.sum()

                # false negative
                fn = np.zeros(gt_mask.shape, np.uint8)
                fn[(out_img == 1) & (gt_mask == 0)] = 1
                numfn = fn.sum()

                # NOTE(review): divides by zero when a prediction contains
                # no text pixels (numtp + numfp == 0) — confirm inputs
                precision = numtp / float(numtp + numfp)
                recall = numtp / float(numtp + numfn)
                fmeasure = 100. * (2. * recall * precision) / (
                    recall + precision)  # percent
                total_fmeasure += fmeasure
            total_image_number += 4  # one f-measure per channel per image
            # break
        # NOTE(review): ZeroDivisionError when image_test_list is empty
        total_fmeasure /= total_image_number

        if best_fmeasure < total_fmeasure:
            best_fmeasure = total_fmeasure

        print('epoch[%d/%d] fmeasure: %.4f, best_fmeasure: %.4f, time: %.2f' %
              (epoch + 1, epochs, total_fmeasure, best_fmeasure,
               time.time() - epoch_start_time))
        print()
    # save the last-epoch weights (file names embed the final f-measure)
    for channel in range(4):
        torch.save(
            models[channel].state_dict(), weight_path +
            '/unet_%d_%d_%.4f.pth' % (channel, epoch + 1, total_fmeasure))
    torch.save(discriminator.state_dict(),
               weight_path + '/dis_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
def main(argv=None):
    """Cut the Label_data images/masks into multi-scale overlapping patches.

    For every image, patches are sampled at several scale factors
    (crop size = scale * 256, i.e. 192/256/320/384), resized back to
    256x256 with nearest-neighbor interpolation, and written together with
    the matching binarized mask patch.  Output names carry the scale index
    (``s<k>``) and the patch index (``i<n>``).
    """
    image_dir = '/mnt/nas/data/denoise/Label_data/image'
    mask_dir = '/mnt/nas/data/denoise/Label_data/mask'

    overlap = 30. / 100.  # 30. / 100. -> 65,247 patches in total
    imgh = 256
    imgw = 256
    # sample patches at these scale factors, then resize back to 256x256
    scale_list = [0.75, 1.00, 1.25, 1.50]
    resize_size = (imgh, imgw)

    image_save_dir = '/data/denoise/Label_patch/image_patches'
    mask_save_dir = '/data/denoise/Label_patch/mask_patches'
    os.makedirs(image_save_dir, exist_ok=True)
    os.makedirs(mask_save_dir, exist_ok=True)

    for image_path in os.listdir(image_dir):
        if not check_is_image(image_path):
            print('not image', image_path)
            continue

        image_name = image_path.split('.')[0]
        image = cv2.imread(os.path.join(image_dir, image_path))

        # find and read mask file; a missing mask aborts the run
        if not os.path.isfile(os.path.join(mask_dir, image_path)):
            print(image_path, 'no mask')
            exit(1)

        mask = cv2.imread(os.path.join(mask_dir, image_path),
                          cv2.IMREAD_GRAYSCALE)
        # binarize the mask at 128
        mask[mask <= 128] = 0
        mask[mask > 128] = 255

        if image.shape[:2] != mask.shape[:2]:
            print(image_path, 'size mismatch')
            exit(1)

        print('processing the image:', image_path)

        for scale_cnt, scale in enumerate(scale_list):
            # (patches, crpH, crpW, 3)
            crpW = int(scale * imgw)
            crpH = int(scale * imgh)

            image_patches, _ = get_image_patch(image, crpH, crpW,
                                               overlap, False)
            mask_patches, _ = get_image_patch(mask, crpH, crpW, overlap,
                                              True)
            print('get patches: %d' % len(image_patches))

            for idx in range(len(image_patches)):
                # nearest-neighbor keeps the resized mask strictly binary
                img_color = cv2.resize(image_patches[idx],
                                       dsize=resize_size,
                                       interpolation=cv2.INTER_NEAREST)
                mask_gray = cv2.resize(mask_patches[idx],
                                       dsize=resize_size,
                                       interpolation=cv2.INTER_NEAREST)

                cv2.imwrite(
                    '%s/%s_s%di%d.png' %
                    (image_save_dir, image_name, scale_cnt, idx), img_color)
                cv2.imwrite(
                    '%s/%s_s%di%d.png' %
                    (mask_save_dir, image_name, scale_cnt, idx), mask_gray)
    image = cv2.imread(os.path.join(root_image_path, image_path))

    # find and read mask file
    if not os.path.isfile(os.path.join(root_mask_path, image_path)):
        print(image_path, 'no mask')
        exit(1)

    mask = cv2.imread(os.path.join(root_mask_path, image_path),
                      cv2.IMREAD_GRAYSCALE)
    mask[mask <= 128] = 0
    mask[mask > 128] = 255

    h, w, _ = image.shape
    image_patches, poslist = get_image_patch(image,
                                             crop_h,
                                             crop_w,
                                             overlap=predict_overlap_ratio,
                                             is_mask=False)

    merge_img = np.ones((h, w, 3))
    out_imgs = []

    for channel in range(4):
        color_patches = []
        for patch in image_patches:
            tmp = patch.astype(np.float32)
            if channel != 3:
                color_patches.append(
                    preprocess_input(tmp[:, :, channel:channel + 1]))
            else:
                color_patches.append(
Beispiel #5
0
        save_csv_file.writerow(csv_tmp)
        save_csv_file.writerow([])
        prev_dibco_year = dibco_year

    # prepare ground truth mask for f-measure
    gt_mask = cv2.imread(test_mask, cv2.IMREAD_GRAYSCALE)
    gt_mask[gt_mask > 0] = 1
    # end ground truth mask

    print('processing the image:', img_name)
    h, w, _ = image.shape

    # start step1
    image_patches, poslist = get_image_patch(image,
                                             256,
                                             256,
                                             overlap=0.1,
                                             is_mask=False)
    merge_img = np.ones((h, w, 3))
    out_imgs = []

    for channel in range(4):
        color_patches = []
        for patch in image_patches:
            tmp = patch.astype(np.float32)
            if channel != 3:
                color_patches.append(
                    preprocess_input(tmp[:, :, channel:channel + 1]))
            else:
                color_patches.append(
                    preprocess_input(
def unet_train(epochs, gpu, base_model_name, encoder_weights, generator_lr, discriminator_lr, lambda_bce,
                batch_size, image_train_dir, mask_train_dir, image_test_dir, original_dir, fold_num, fold_total):
    """Step-2 training: a single patch-level U-Net generator with a
    WGAN-GP discriminator for document binarization.

    Unlike step 1 (four per-channel generators), this trains one generator
    on full 3-channel patches.  The discriminator is updated every batch;
    the generator only every 5th batch.  After each epoch the model is
    evaluated on full-size test images via patch-wise prediction, overlap
    merging, thresholding, and F-measure.

    Args:
        epochs: number of training epochs.
        gpu: CUDA device index, interpolated into "cuda:%s".
        base_model_name: segmentation_models_pytorch encoder name.
        encoder_weights: pretrained-weight identifier for the encoder.
        generator_lr / discriminator_lr: Adam learning rates.
        lambda_bce: weight of the BCE term added to the adversarial loss.
        batch_size: batch size for training and patched evaluation.
        image_train_dir / mask_train_dir / image_test_dir: directories;
            any '%d' in them is replaced with fold_num.
        original_dir: directory holding the full-size 'mask' folder used
            for evaluation ground truth.
        fold_num / fold_total: fold selection (note: the dataset below is
            created with hard-coded fold_num=0, fold_total=5 — see review
            note).

    Side effects: creates a checkpoint directory, writes sample images,
    and saves final generator/discriminator weights.
    """
    # '%d' placeholders in the directory templates select the fold
    image_train_dir = image_train_dir.replace('%d', str(fold_num))
    mask_train_dir = mask_train_dir.replace('%d', str(fold_num))
    image_test_dir = image_test_dir.replace('%d', str(fold_num))

    weight_path = ('./step2_label%d_' % fold_num) + base_model_name + '_' + str(int(lambda_bce)) + '_' + str(generator_lr)
    image_save_path = weight_path + '/images'
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(image_save_path, exist_ok=True)

    # rgb , preprocess input (ImageNet stats, used to de-normalize samples)
    imagenet_mean = np.array( [0.485, 0.456, 0.406] )
    imagenet_std = np.array( [0.229, 0.224, 0.225] )

    # patch data loader
    # NOTE(review): fold_num=0, fold_total=5 are hard-coded here instead of
    # using the function's fold_num/fold_total arguments — confirm intended
    patch_train_data_set = Dataset_Return_One(image_train_dir, mask_train_dir, os.path.join(original_dir, 'image'),
                                                base_model_name, encoder_weights, is_train=True, fold_num=0, fold_total=5)

    patch_train_loader = DataLoader(patch_train_data_set, batch_size=batch_size, num_workers=4, shuffle=True)

    device = torch.device("cuda:%s" % gpu)

    # step2 patch unet: one generator over full 3-channel patches
    patch_model = smp.Unet(base_model_name, encoder_weights=encoder_weights, in_channels=3)
    patch_model.to(device)
    optimizer_patch_generator = optim.Adam(patch_model.parameters(), lr=generator_lr, betas=(0.5, 0.999))

    # discriminator scores the 4-channel (image, mask) concatenation
    discriminator = Discriminator(in_channels=4)
    discriminator.to(device)
    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=discriminator_lr, betas=(0.5, 0.999))

    criterion = nn.BCEWithLogitsLoss()

    # evaluation list: (test image path, ground-truth mask path) pairs
    image_test_pathes = os.listdir(image_test_dir)
    image_test_list = []
    for image_test_path in image_test_pathes:
        if not check_is_image(image_test_path):
            print(image_test_path, 'not image')
            continue
        image_test_list.append( (os.path.join(image_test_dir, image_test_path), 
                                    os.path.join(original_dir, 'mask', image_test_path)) )
    print('test len:', len(image_test_list))

    preprocess_input = get_preprocessing_fn(base_model_name, pretrained=encoder_weights)

    lambda_gp = 10.0  # WGAN-GP gradient-penalty weight
    threshold_value = int(256 * 0.5)  # binarization threshold (128)

    patch_best_fmeasure = 0.0
    epoch_start_time = time.time()
    for epoch in range(epochs):

        # train
        patch_model.train()
        for idx, (images, masks) in enumerate(patch_train_loader):
            # train discriminator with patch
            images = images.to(device)
            masks = masks.to(device)

            masks_pred = patch_model(images)

            # discriminator update (WGAN-GP critic step)
            discriminator.requires_grad_(True)
            # Fake: detached so no gradient reaches the generator here
            fake_AB = torch.cat((images, masks_pred), 1).detach()
            pred_fake = discriminator(fake_AB)

            # Real
            real_AB = torch.cat((images, masks), 1)
            pred_real = discriminator(real_AB)

            gradient_penalty = compute_gradient_penalty(discriminator, real_AB, fake_AB, device)
            discriminator_loss = -torch.mean(pred_real) + torch.mean(pred_fake) + lambda_gp * gradient_penalty

            optimizer_discriminator.zero_grad()
            discriminator_loss.backward()
            optimizer_discriminator.step()

            # generator is updated only every 5th batch (critic runs more
            # often than the generator, as usual for WGAN training)
            if idx % 5 == 0:
                discriminator.requires_grad_(False)

                # generator: adversarial term + weighted BCE
                fake_AB = torch.cat((images, masks_pred), 1)
                pred_fake = discriminator(fake_AB)
                generator_loss = -torch.mean(pred_fake)
                bce_loss = criterion(masks_pred, masks)
                total_loss = generator_loss + bce_loss * lambda_bce

                optimizer_patch_generator.zero_grad()
                total_loss.backward()
                optimizer_patch_generator.step()

            if idx % 2000 == 0:
                print('train step[%d/%d] patch discriminator loss: %.5f, total loss: %.5f, generator loss: %.5f, bce loss: %.5f, time: %.2f' % 
                            (idx, len(patch_train_loader), discriminator_loss.item(), total_loss.item(), generator_loss.item(), bce_loss.item(), time.time() - epoch_start_time))

                # sample 3 consecutive random patches for visualization
                # NOTE(review): randrange(size - 3) raises when the batch
                # holds fewer than 4 samples — confirm batch sizes
                rand_idx_start = randrange(masks.size(0) - 3)
                rand_idx_end = rand_idx_start + 3
                test_masks_pred = torch.sigmoid(masks_pred[rand_idx_start:rand_idx_end]).detach().cpu()
                test_masks_pred = test_masks_pred.permute(0, 2, 3, 1).numpy().astype(np.float32)
                test_masks_pred = np.squeeze(test_masks_pred, axis=-1)

                test_masks = masks[rand_idx_start:rand_idx_end].permute(0, 2, 3, 1).cpu().numpy().astype(np.float32)
                test_masks = np.squeeze(test_masks, axis=-1)

                # de-normalize images back to [0, 1] for saving
                test_images = images[rand_idx_start:rand_idx_end].permute(0, 2, 3, 1).cpu().numpy()
                test_images = test_images * imagenet_std + imagenet_mean
                test_images = np.maximum(test_images, 0.0)
                test_images = np.minimum(test_images, 1.0)
                sample_images(epoch, idx, test_images, test_masks, test_masks_pred, image_save_path)
            # NOTE(review): this break stops training after the FIRST batch
            # of every epoch — looks like a debugging leftover; confirm
            break

        # eval
        patch_model.eval()
        print('eval patch')
        total_fmeasure = 0.0
        total_image_number = 0
        # random_number = randrange(len(patch_image_test_list))
        for eval_idx, (image_test, mask_test) in enumerate(image_test_list):
            image = cv2.imread(image_test)
            h, w, _ = image.shape
            image_name = image_test.split('/')[-1].split('.')[0]

            gt_mask = cv2.imread(mask_test, cv2.IMREAD_GRAYSCALE)
            gt_mask = np.expand_dims(gt_mask, axis=-1)

            image_patches, poslist = get_image_patch(image, 256, 256, overlap=0.5, is_mask=False)
            color_patches = []
            for patch in image_patches:
                color_patches.append(preprocess_input(patch.astype(np.float32), input_space="BGR"))

            # batched inference over all patches
            step = 0
            preds = []
            with torch.no_grad():
                while step < len(image_patches):
                    ps = step
                    pe = step + batch_size
                    if pe >= len(image_patches):
                        pe = len(image_patches)

                    images_global = torch.from_numpy(np.array(color_patches[ps:pe])).permute(0, 3, 1, 2).float().to(device)
                    preds.extend( torch.sigmoid(patch_model(images_global)).cpu() )
                    step += batch_size

            # handling overlap: per-pixel minimum, darkest prediction wins
            out_img = np.ones((h, w, 1)) * 255
            for i in range(len(image_patches)):
                patch = preds[i].permute(1, 2, 0).numpy() * 255

                start_h, start_w, end_h, end_w, h_shift, w_shift = poslist[i]
                h_cut = end_h - start_h
                w_cut = end_w - start_w

                tmp = np.minimum(out_img[start_h:end_h, start_w:end_w], patch[h_shift:h_shift+h_cut, w_shift:w_shift+w_cut])
                out_img[start_h:end_h, start_w:end_w] = tmp

            # binarize the merged prediction at 128
            out_img = out_img.astype(np.uint8)
            out_img[out_img > threshold_value] = 255
            out_img[out_img <= threshold_value] = 0

            # if random_number == eval_idx:
            #     cv2.imwrite('%s/patch_%d_%s.png' % (image_save_path, epoch, image_name), out_img)

            # f_measure
            # background 1, text 0
            gt_mask[gt_mask > 0] = 1
            out_img[out_img > 0] = 1

            # true positive (text predicted where text really is)
            tp = np.zeros(gt_mask.shape, np.uint8)
            tp[(out_img==0) & (gt_mask==0)] = 1
            numtp = tp.sum()

            # false positive
            fp = np.zeros(gt_mask.shape, np.uint8)
            fp[(out_img==0) & (gt_mask==1)] = 1
            numfp = fp.sum()

            # false negative
            fn = np.zeros(gt_mask.shape, np.uint8)
            fn[(out_img==1) & (gt_mask==0)] = 1
            numfn = fn.sum()

            # NOTE(review): divides by zero when the prediction has no
            # text pixels (numtp + numfp == 0) — confirm inputs
            precision = (numtp) / float(numtp + numfp)
            recall = (numtp) / float(numtp + numfn)
            fmeasure = 100. * (2. * recall * precision) / (recall + precision) # percent

            total_fmeasure += fmeasure
            total_image_number += 1
            # NOTE(review): evaluates only the FIRST test image — looks
            # like a debugging leftover; confirm
            break
        total_fmeasure /= total_image_number

        if patch_best_fmeasure < total_fmeasure:
            patch_best_fmeasure = total_fmeasure
        print('epoch[%d/%d] patch fmeasure: %.4f, best_fmeasure: %.4f, time: %.2f' 
                    % (epoch + 1, epochs, total_fmeasure, patch_best_fmeasure, time.time() - epoch_start_time))
        print()

    # save last-epoch weights (file names embed the final f-measure)
    torch.save(patch_model.state_dict(), weight_path + '/unet_patch_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
    torch.save(discriminator.state_dict(), weight_path + '/dis_%d_%.4f.pth' % (epoch + 1, total_fmeasure))
Beispiel #7
0
        mask = cv2.imread(os.path.join(mask_train_dir, image_name + '.png'),
                          cv2.IMREAD_GRAYSCALE)
    elif os.path.isfile(os.path.join(mask_train_dir, image_name + '.bmp')):
        mask = cv2.imread(os.path.join(mask_train_dir, image_name + '.bmp'),
                          cv2.IMREAD_GRAYSCALE)
    else:
        print(img, 'no mask')
        exit(1)

    mask[mask < 190] = 0
    mask[mask >= 190] = 255

    h, w, _ = image.shape
    image_patches, poslist = get_image_patch(image,
                                             crop_h,
                                             crop_w,
                                             overlap=predict_overlap_ratio,
                                             is_mask=False)

    merge_img = np.ones((h, w, 3))
    out_imgs = []

    for channel in range(4):
        color_patches = []
        for patch in image_patches:
            tmp = patch.astype(np.float32)
            if channel != 3:
                color_patches.append(
                    preprocess_input(tmp[:, :, channel:channel + 1]))
            else:
                color_patches.append(