Ejemplo n.º 1
0
def get_baseline_experiment(experiment_name):
    """Rebuild the baseline experiment and restore its best checkpoint.

    A ``HAN`` model is instantiated, ``.eval()`` is called on it, and the
    result is wrapped in an ``Experiment`` configured for classification
    with a cross-entropy loss. The experiment monitors
    ``val_fscore_macro`` (mode ``max``) and reports a macro-averaged
    ``FBeta`` epoch metric.

    Args:
        experiment_name: Passed through as the first positional argument
            of ``Experiment``.

    Returns:
        The ``Experiment`` instance, after ``load_checkpoint('best')`` has
        been called on it.
    """
    # NOTE(review): ``nb_layers`` is not defined in this function; it is
    # presumably a module-level setting -- confirm.
    model = HAN(20, 10, 300, 2, nb_layers, .25).eval()

    experiment_options = {
        'monitor_metric': "val_fscore_macro",
        'monitor_mode': "max",
        'loss_function': 'cross_entropy',
        'task': "classification",
        'epoch_metrics': [FBeta(average='macro')],
        'device': 0,
    }
    experiment = Experiment(experiment_name, model, **experiment_options)

    # Restore the weights saved under the 'best' checkpoint tag.
    experiment.load_checkpoint('best')
    return experiment
Ejemplo n.º 2
0
def get_proposed_hmc_experiment(experiment_name):
    """Rebuild the proposed multi-label experiment and restore its best
    checkpoint.

    An ``MLHAN_MultiLabel`` model is instantiated, ``.eval()`` is called on
    it, and the result is wrapped in an ``Experiment`` that uses
    ``MultiLevelMultiLabelLoss`` as loss. The experiment monitors
    ``val_fscore_macro`` (mode ``max``) and reports two macro-averaged
    epoch metrics: ``FBetaLowerLevelMultiLabel`` and
    ``FBetaUpperLevelMultiLabel``.

    Args:
        experiment_name: Passed through as the first positional argument
            of ``Experiment``.

    Returns:
        The ``Experiment`` instance, after ``load_checkpoint('best')`` has
        been called on it.
    """
    # NOTE(review): ``nb_layers`` is not defined in this function; it is
    # presumably a module-level setting -- confirm.
    model = MLHAN_MultiLabel(20, 10, 300, [3, 2], nb_layers, .25).eval()

    loss = MultiLevelMultiLabelLoss()
    level_metrics = [
        FBetaLowerLevelMultiLabel(average='macro'),
        FBetaUpperLevelMultiLabel(average='macro'),
    ]

    experiment = Experiment(experiment_name,
                            model,
                            monitor_metric="val_fscore_macro",
                            monitor_mode="max",
                            loss_function=loss,
                            epoch_metrics=level_metrics,
                            device=0)

    # Restore the weights saved under the 'best' checkpoint tag.
    experiment.load_checkpoint('best')
    return experiment
def main(args):
    """Train the model, test it, then export per-image results.

    Steps, in order:

    1. Compute the dataset mean/std and the result of
       ``compute_pos_weight`` (all printed) from the images under
       ``<root_dir>/Images/1024_with_jedi``.
    2. Train ``fcn_resnet50`` through an ``Experiment`` stored in
       ``<root_dir>/moar``, monitoring ``val_miou``.
    3. Run ``exp.test`` on the test split.
    4. Reload checkpoint 11 and, for every image of the dataset, save a
       comparison figure (input / target / prediction), a grayscale
       prediction mask (0 / 127 / 255 for classes 0 / 1 / 2), and collect
       one row of per-class IoU / F1 statistics.
    5. Write all rows to ``<root_dir>/Images/results/moar/final_stats.csv``
       as tab-separated values.

    Args:
        args: Object exposing ``root_dir`` (base directory for data and
            results) and ``device`` (anything accepted by
            ``torch.device``).

    Raises:
        ValueError: Re-raised from the per-image metric computation, after
            the offending file name and tensor shapes have been printed.
    """
    images_dir = os.path.join(args.root_dir, 'Images/1024_with_jedi')

    # Un-normalized view of the data, used only to derive statistics.
    raw_dataset = RegressionDatasetFolder(images_dir,
                                          input_only_transform=None,
                                          transform=Compose([ToTensor()]))
    mean, std = compute_mean_std(raw_dataset)
    print(mean)
    print(std)
    pos_weights = compute_pos_weight(raw_dataset)
    print(pos_weights)

    # Despite its name, ``test_dataset`` backs both the validation loader
    # used during training and the final test loader (through ``Subset``).
    test_dataset = RegressionDatasetFolder(
        images_dir,
        input_only_transform=Compose([Normalize(mean, std)]),
        transform=Compose(
            [Lambda(lambda img: pad_resize(img, 1024, 1024)),
             ToTensor()]),
        in_memory=True)

    # Normalized view without ``pad_resize`` that also yields file names;
    # used to compute the splits and as model input in the export loop.
    valid_dataset = RegressionDatasetFolder(
        images_dir,
        input_only_transform=Compose([Normalize(mean, std)]),
        transform=Compose([ToTensor()]),
        include_fname=True)

    train_split, valid_split, test_split, train_weights = get_splits(
        valid_dataset)
    valid_loader = DataLoader(Subset(test_dataset, valid_split),
                              batch_size=8,
                              num_workers=8,
                              pin_memory=False)

    # Alternative architectures that were tried:
    # module = deeplabv3_efficientnet(n=5)
    module = fcn_resnet50(dropout=0.8)
    # module = deeplabv3_resnet50()

    optim = torch.optim.Adam(module.parameters(), lr=5e-4, weight_decay=2e-3)
    device = torch.device(args.device)
    exp = Experiment(directory=os.path.join(args.root_dir, 'moar'),
                     module=module,
                     device=device,
                     optimizer=optim,
                     loss_function=LovaszSoftmax(),
                     metrics=[miou, PixelWiseF1(None)],
                     monitor_metric='val_miou',
                     monitor_mode='max')

    lr_schedulers = [
        ReduceLROnPlateau(monitor='val_miou',
                          mode='max',
                          factor=0.2,
                          patience=3,
                          threshold=1e-1,
                          threshold_mode='abs')
    ]
    callbacks = [
        EarlyStopping(monitor='val_miou',
                      min_delta=1e-1,
                      patience=8,
                      verbose=True,
                      mode='max')
    ]

    # One training stage per (crop size, batch size) pair; stage ``i`` is
    # given ``(1 + i) * 30`` as its ``epochs`` argument.
    for i, (crop_size, batch_size) in enumerate(zip([512], [5])):
        train_loader = get_loader_for_crop_batch(crop_size, batch_size,
                                                 train_split, mean, std,
                                                 train_weights, args.root_dir)

        exp.train(train_loader=train_loader,
                  valid_loader=valid_loader,
                  epochs=(1 + i) * 30,
                  lr_schedulers=lr_schedulers,
                  callbacks=callbacks)

    raw_dataset.print_filenames()

    # Same files as ``valid_dataset`` but without normalization, so the
    # input can be displayed as-is in the figures.
    pure_dataset = RegressionDatasetFolder(images_dir,
                                           transform=Compose([ToTensor()]),
                                           include_fname=True)

    test_loader = DataLoader(Subset(test_dataset, test_split),
                             batch_size=8,
                             num_workers=8,
                             pin_memory=False)
    # From here on ``valid_loader`` iterates the WHOLE dataset one image at
    # a time, in lockstep with ``pure_loader`` (neither loader shuffles).
    valid_loader = DataLoader(valid_dataset,
                              batch_size=1,
                              num_workers=8,
                              pin_memory=False)
    pure_loader = DataLoader(pure_dataset,
                             batch_size=1,
                             num_workers=8,
                             pin_memory=False)

    exp.test(test_loader)

    # for checkpoint in [11, 15, 16, 17, 21]:
    #     print("Testing checkpoint {}".format(checkpoint))
    #     exp.load_checkpoint(checkpoint)
    #     test_model_on_checkpoint(exp.model, test_loader)

    exp.load_checkpoint(11)
    module = exp.model.model
    module.eval()

    generate_output_folders(args.root_dir)

    splits = [(train_split, 'train'), (valid_split, 'valid'),
              (test_split, 'test')]

    results_csv = [[
        'Name', 'Type', 'Split', 'iou_nothing', 'iou_bark', 'iou_node',
        'iou_mean', 'f1_nothing', 'f1_bark', 'f1_node', 'f1_mean',
        'Output Bark %', 'Output Node %', 'Target Bark %', 'Target Node %'
    ]]

    # Loop-invariant labels for the figure panels and the classes.
    names = ['Input', 'Target', 'Generated image']
    class_names = ['Nothing', 'Bark', 'Node']

    with torch.no_grad():
        for image_number, (batch, pure_batch) in enumerate(
                zip(valid_loader, pure_loader)):
            input_image = pure_batch[0]
            target = pure_batch[1]
            fname = pure_batch[2][0]
            wood_type = pure_batch[3][0]

            del pure_batch

            outputs = module(batch[0].to(device))
            outputs = remove_small_zones(outputs)

            del batch

            # Metrics are computed on the model outputs before they are
            # collapsed to class indices by ``argmax`` below.
            try:
                class_accs = iou(outputs, target.to(device))
                f1s = PixelWiseF1('all')(outputs, target) * 100

                acc = class_accs.mean()
                f1 = f1s.mean()
            except ValueError:
                print('Error on file {}'.format(fname))
                print(outputs.shape)
                print(target.shape)
                # Bare ``raise`` keeps the original traceback intact.
                raise

            outputs = torch.argmax(outputs, dim=1)

            imgs = [input_image, target, outputs]
            imgs = [img.detach().cpu().squeeze().numpy() for img in imgs]

            fig, axs = plt.subplots(1, 3)

            for ax, img, title in zip(axs.flatten(), imgs, names):
                # A 3-D array is the raw input image; 2-D arrays are label
                # maps (target / prediction).
                raw = (len(img.shape) == 3)

                if raw:  # Raw input
                    # Channel-first -> channel-last for ``imshow``.
                    img = img.transpose(1, 2, 0)

                plotted_img = ax.imshow(img, vmax=2)
                ax.set_title(title)
                ax.axis('off')

                if not raw:  # Label map
                    # Legend entries for the classes present in this map;
                    # the last label map (the prediction) wins.
                    values = np.unique(img.ravel())
                    patches = [
                        mpatches.Patch(
                            color=plotted_img.cmap(plotted_img.norm(value)),
                            label='{} zone'.format(class_names[value]))
                        for value in values
                    ]

            suptitle = 'Mean iou : {:.3f}\n'.format(acc)

            # NOTE(review): if ``image_number`` belongs to no split,
            # ``split`` keeps the previous image's value (or is unbound on
            # the first image). Presumably ``get_splits`` partitions every
            # index -- confirm.
            for split_idxs, split_name in splits:
                if image_number in split_idxs:
                    split = split_name

            running_csv_stats = [fname, wood_type, split]

            for c, c_acc in zip(class_names, class_accs):
                suptitle += '{} : {:.3f};  '.format('iou_' + c, c_acc)
                running_csv_stats.append('{:.3f}'.format(c_acc))

            running_csv_stats.append('{:.3f}'.format(acc))
            suptitle += '\nMean f1 : {:.3f}\n'.format(f1)

            for c, c_f1 in zip(class_names, f1s):
                suptitle += '{} : {:.3f};  '.format('f1_' + c, c_f1)
                running_csv_stats.append('{:.3f}'.format(c_f1))

            running_csv_stats.append('{:.3f}'.format(f1))

            # Share of pixels predicted as class 1 (bark) and 2 (node).
            for class_idx in [1, 2]:
                class_percent = (outputs == class_idx).float().mean().cpu()
                running_csv_stats.append('{:.5f}'.format(class_percent * 100))

            # Same shares measured on the target.
            for class_idx in [1, 2]:
                class_percent = (target == class_idx).float().mean().cpu()
                running_csv_stats.append('{:.5f}'.format(class_percent * 100))

            fig.legend(handles=patches,
                       title='Classes',
                       bbox_to_anchor=(0.4, -0.2, 0.5, 0.5))
            plt.suptitle(suptitle)
            plt.tight_layout()
            # plt.show()
            # Only the template is formatted, so braces in ``root_dir``
            # cannot be mistaken for replacement fields.
            figure_path = os.path.join(
                args.root_dir,
                'Images/results/moar/combined_images/{}/{}/{}'.format(
                    wood_type, split, fname))
            plt.savefig(figure_path, format='png', dpi=900)
            plt.close()

            # Encode the predicted classes as gray levels: 0 stays 0,
            # class 1 -> 127, class 2 -> 255.
            outputs = outputs.squeeze().cpu().numpy()
            dual_outputs = np.zeros((outputs.shape[0], outputs.shape[1]),
                                    dtype=np.uint8)
            dual_outputs[outputs == 1] = 127
            dual_outputs[outputs == 2] = 255

            dual = Image.fromarray(dual_outputs, mode='L')
            mask_path = os.path.join(
                args.root_dir,
                'Images/results/moar/outputs/{}/{}/{}'.format(
                    wood_type, split, fname))
            dual.save(mask_path)

            results_csv.append(running_csv_stats)

    csv_file = os.path.join(args.root_dir, 'Images', 'results', 'moar',
                            'final_stats.csv')

    # ``newline=''`` lets the csv module control line endings itself, as
    # its documentation requires for files passed to ``csv.writer``.
    with open(csv_file, 'w', newline='') as f:
        csv_writer = csv.writer(f, delimiter='\t')
        csv_writer.writerows(results_csv)