def visualize_predictions(image_logger, max_samples, metric_fn, logits, gt,
                          samples_per_row=16):
    """Log a grid of the lowest-scoring predictions next to their ground truth.

    Samples are sorted by ``metric_fn`` in ascending order and at most
    ``max_samples`` of them are drawn. Each band of ``samples_per_row``
    samples occupies two figure rows: sigmoid probabilities (titled with the
    sample's metric value) on top and the ground truth directly below.

    Args:
        image_logger: Callable that receives the finished matplotlib figure.
        max_samples: Upper bound on the number of samples drawn.
        metric_fn: Called as ``metric_fn(logits, gt, average=False)`` on
            tensors; expected to return one value per sample.
        logits: Array of raw model outputs whose first axis indexes samples.
        gt: Array of ground-truth masks whose first axis indexes samples.
        samples_per_row: Number of samples per band (default 16).
    """
    num_samples = min(len(gt), max_samples)
    metrics = to_numpy(
        metric_fn(from_numpy(logits), from_numpy(gt), average=False))
    # Ascending order: the samples with the lowest metric come first.
    order = np.argsort(metrics)[:num_samples]
    gt = gt[order]
    logits = logits[order]
    metrics = metrics[order]

    # Drop singleton axes (such as a channel axis) but keep the sample axis.
    # A bare squeeze() also removes it when num_samples == 1, after which
    # probs[i] would index image rows instead of samples.
    logits = logits.reshape(
        (logits.shape[0], ) + tuple(d for d in logits.shape[1:] if d != 1))
    # 0.5 * (1 + tanh(x / 2)) equals the logistic sigmoid 1 / (1 + exp(-x)),
    # but it cannot overflow (and warn) for large negative logits.
    probs = 0.5 * (1.0 + np.tanh(0.5 * logits))

    num_bands = int(np.ceil(num_samples / samples_per_row))
    num_rows = 2 * num_bands
    fig = plt.figure(figsize=(6, 1 * num_rows))
    for i, (prob, mask, metric) in enumerate(zip(probs, gt, metrics)):
        band, col = divmod(i, samples_per_row)
        # 1-based subplot slot in the band's prediction row; the ground truth
        # goes exactly one row (samples_per_row slots) further down.
        pred_pos = 2 * band * samples_per_row + col + 1

        ax = fig.add_subplot(num_rows, samples_per_row, pred_pos)
        ax.set_title(f'{metric:.1f}')
        ax.imshow(prob, vmin=0, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])

        ax = fig.add_subplot(num_rows, samples_per_row,
                             pred_pos + samples_per_row)
        ax.imshow(mask)
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    image_logger(fig)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def find_outliers(checkpoint_path,
                  num_folds=5,
                  fold_ids=None,
                  batch_size=1,
                  limit=None,
                  tta=False):
    """Rank validation images by the per-image Lovasz hinge loss of a checkpoint.

    Loads the checkpoint and runs it over the images of the requested folds
    of the area-stratified split of ``data/train.csv``. It then prints a
    table of ``image_id`` / ``loss`` sorted by ascending loss.

    Args:
        checkpoint_path: Path passed to ``load_checkpoint``.
        num_folds: Number of folds in the area-stratified split.
        fold_ids: Folds to evaluate; ``None`` means ``[0, 1, 2, 3]``.
        batch_size: Batch size of the validation generator.
        limit: Optional cap forwarded to ``get_validation_generator``.
        tta: If true, average the raw outputs with those of the input flipped
            along dim 3 (flipped back before averaging).

    Returns:
        The (unsorted) ``pandas.DataFrame`` with columns ``image_id`` and
        ``loss``.
    """
    # None instead of a list literal: a mutable default is shared across calls.
    if fold_ids is None:
        fold_ids = [0, 1, 2, 3]

    model = as_cuda(load_checkpoint(checkpoint_path))
    model.eval()

    mask_db = get_mask_db('data/train.csv')
    all_image_ids, all_fold_ids = get_area_stratified_split(mask_db, num_folds)
    image_ids = all_image_ids[np.isin(all_fold_ids, fold_ids)]

    losses = []
    generator = get_validation_generator(num_folds, fold_ids, batch_size,
                                         limit)
    # Inference only: scope autograd off for this loop instead of flipping the
    # process-wide flag with torch.set_grad_enabled(False).
    with torch.no_grad():
        for inputs, gt in tqdm(generator, total=len(generator)):
            inputs, gt = from_numpy(inputs), from_numpy(gt)
            outputs = model(inputs)
            if tta:
                flipped_outputs = model(inputs.flip(dims=(3, )))
                outputs = (outputs + flipped_outputs.flip(dims=(3, ))) / 2
            batch_losses = to_numpy(
                lovasz_hinge_loss(outputs, gt, average=False))
            losses.extend(batch_losses)

    # `limit` may stop the generator before all ids are seen. This assumes
    # the generator yields images in the same order as image_ids; TODO confirm
    # in get_validation_generator.
    image_ids = image_ids[:len(losses)]
    stats = pd.DataFrame(data={'image_id': image_ids, 'loss': losses})
    print(stats.sort_values('loss'))
    return stats
 def on_validation_end(self, logs, outputs, gt):
     """Log a smoothed 3-D surface of the validation loss map.

     ``self.loss_fn`` is evaluated on the epoch's validation outputs and
     targets. The result is averaged over axis 0, smoothed with
     ``uniform_filter(size=6, mode='nearest')`` and sent to
     ``self.image_logger`` as a 3-D surface plot.

     Args:
         logs: Epoch log dictionary (not used here).
         outputs: Validation model outputs as a numpy array.
         gt: Validation targets as a numpy array.
     """
     # Assumes loss_fn returns per-element losses so that the mean over
     # axis 0 is a 2-D map (plot_surface requires one); TODO confirm.
     losses = to_numpy(self.loss_fn(from_numpy(outputs), from_numpy(gt)))
     losses = losses.mean(axis=0)
     losses = uniform_filter(losses, size=6, mode='nearest')
     # With meshgrid's default 'xy' indexing the grids come out shaped
     # (len(second arg), len(first arg)). Passing columns first and rows second
     # therefore matches losses.shape for any map, not only square ones. For
     # square maps this is the same grid as before.
     x, y = np.meshgrid(np.arange(losses.shape[1]),
                        np.arange(losses.shape[0]))
     fig = plt.figure()
     # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6).
     ax = fig.add_subplot(111, projection='3d')
     ax.plot_surface(x,
                     y,
                     losses,
                     linewidth=0,
                     antialiased=True,
                     cmap=plt.get_cmap('viridis'),
                     edgecolor='none')
     ax.view_init(60, 35)
     self.image_logger(fig)
     # Release the figure so one-per-epoch figures do not pile up in pyplot.
     plt.close(fig)
Example #4
0
 def on_validation_end(self, logs, outputs, gt):
     """Log a 20-bin histogram of per-sample ``self.metric_fn`` values.

     Args:
         logs: Epoch log dictionary (not used here).
         outputs: Validation model outputs as a numpy array.
         gt: Validation targets as a numpy array.
     """
     values = to_numpy(
         self.metric_fn(from_numpy(outputs), from_numpy(gt), average=False))
     # Draw on a fresh figure. plt.hist() on the implicit current figure
     # stacks onto earlier epochs' histograms, or onto another callback's plot
     # if that figure is still current.
     fig = plt.figure()
     ax = fig.add_subplot(111)
     ax.hist(values, bins=20)
     ax.set_title(self.metric_fn.__name__)
     self.image_logger(fig)
     plt.close(fig)
Example #5
0
# Test-image ids whose submitted mask is always forced to empty, whatever the
# model predicts for them. NOTE(review): how this list was derived is not
# recorded here; confirm before editing it. A frozenset gives O(1) lookups.
_PREDICT_BLACKLIST = frozenset([
    '0035c56490',
    '005855cd72',
    '00b6d3a31f',
    '00c473f654',
    '00f12566b9',
    '01a40c3405',
    '01d6be9b57',
    '0279cf7419',
    '02c9416d86',
    '0300ad09f3',
    '0350a9e3bd',
    '06b664b866',
    '06fa0b053b',
    '0794c37f5a',
    '09ce1453ea',
    '0a7c09181a',
    '0a90963914',
    '0bda89116e',
    '0be00ec340',
    '0da877df19',
    '0f13aea58f',
    '0fcf26daaf',
    '0fd7ee2ea9',
    '1081dc0cb9',
    '108fe2c3d8',
    '10947d16b1',
    '10f1d4a32c',
    '10f59e0caf',
    '11a0fc8072',
    '12de44ce0c',
    '131fc7cbf2',
    '132cea196e',
    '1409ddabbf',
    '1471bc7b6d',
    '147a36429e',
    '153558a258',
    '1557139238',
    '162f55f72b',
    '162f738a03',
    '1817464894',
    '186e591d03',
    '1915af8856',
    '1aaffdb790',
    '1bb08d1d48',
    '1bd9b1a1c8',
    '1c16231286',
    '1c165d61d6',
    '1c21b0e9bb',
    '1c920e3604',
    '1dac95d5b7',
    '1df8119d15',
    '1df8d91491',
    '1e10c8b214',
    '1e95bad8bb',
    '21beba74c7',
    '21c45388e8',
    '228e88f048',
    '22a9fdff9e',
    '22ca26f94a',
    '23062a0de4',
    '242f1fac62',
    '2454a76962',
    '245daa2004',
    '24a18f95ed',
    '24f57fec17',
    '2567b0c9de',
    '2662f5581a',
    '270dd2b8e6',
    '27a79442e2',
    '27b98d2fb7',
    '2849772dd6',
    '29d32c7561',
    '29fdaa3c4e',
    '2a0583533f',
    '2aa30fc1d4',
    '2cf9b202bb',
    '2e40cf4b88',
    '2f85d3f736',
    '2f8b6beb17',
    '2fac24e793',
    '30e5c4d227',
    '31c6a2860f',
    '31d4b415e4',
    '328d1497af',
    '3336643b56',
    '33cb556649',
    '3474bba21d',
    '34f546117e',
    '353e010b7b',
    '35b730549d',
    '36c0313b9e',
    '37062e1138',
    '3746667eb4',
    '3780e66035',
    '37ab36b8b8',
    '38bda63402',
    '38cc48d708',
    '39aa1f3a2b',
    '3c3dedc0cc',
    '3c8cf08665',
    '3ce8a8bdf4',
    '3cf20ce659',
    '3d05dc9c29',
    '3d21e1dc58',
    '3d9c671ce1',
    '3e7fa000c7',
    '3e871346f4',
    '3e991d73fa',
    '3fe3373b67',
    '4062ebb8f8',
    '40e3262fb1',
    '4150e85a49',
    '42e3dc40ba',
    '4303cf97d3',
    '432f07b09e',
    '447b0d530d',
    '44df56fa12',
    '457c69bcc8',
    '45e3277f3f',
    '4622cee4b9',
    '46a93d2bd4',
    '4775e574dd',
    '48132be0ac',
    '485e869caf',
    '48a6b2914d',
    '49174ccdd9',
    '49ba3412bd',
    '49f33f1832',
    '4a96727bdf',
    '4b478173aa',
    '4b64c96dbe',
    '4c6c90518d',
    '4d1825be4e',
    '4d3eeda971',
    '4e70e4d96c',
    '4f18b39baa',
    '4f29c4cd04',
    '4fb2778316',
    '503f08307e',
    '50ac343646',
    '5119fdbfde',
    '51806f45ce',
    '5241080fcb',
    '52cfb25e88',
    '52dbe09f4c',
    '52f369a734',
    '536decbbed',
    '54fbd50faa',
    '55b9cb8acd',
    '571e7f5a50',
    '578cae22ea',
    '5791c70c05',
    '5850ff1f52',
    '5895535b14',
    '58b6687314',
    '58d550a35d',
    '58ee6464fd',
    '5a06c45958',
    '5af890206e',
    '5b7ebdc259',
    '5c0acc3f31',
    '5d049a35b3',
    '5d36a9659e',
    '5d4578efd7',
    '5e52f098d9',
    '5ef842ff3e',
    '5f3f6d6ca6',
    '60611407d3',
    '60ed5847a1',
    '61a0a17dcd',
    '629f639f2c',
    '62ab055b81',
    '63660a3693',
    '63a371159d',
    '648ac9c05a',
    '65556f12d3',
    '6654fbf093',
    '66b8883dd0',
    '67d21b9e92',
    '68a1443ac6',
    '68a469ef7e',
    '68ff47165b',
    '69c7de5d26',
    '6a7d046783',
    '6ae646336d',
    '6afae09aa9',
    '6c6f886709',
    '6d0ab73c29',
    '6d23b68142',
    '6d3209c6d5',
    '6d75fcd108',
    '6e28a340f7',
    '6e470db51c',
    '6e5e055e6f',
    '6e67571b91',
    '6f0d7bdd29',
    '6fc195d3de',
    '706321cca9',
    '70b96ce692',
    '717bd5d6b1',
    '71b7cc2fdc',
    '71e937b285',
    '7304dcacee',
    '7370b31cd0',
    '75d8153a52',
    '75f0269699',
    '7671c2d961',
    '76b26d2e39',
    '76de89d0eb',
    '7785879425',
    '77b35c01c7',
    '78e2b8371a',
    '792cef86c2',
    '7940a73429',
    '7959b9bd53',
    '7968eec1f1',
    '79949a4117',
    '79bb9dc447',
    '7a60a99114',
    '7ae7a91efc',
    '7bf1f7ff2c',
    '7ceede7d55',
    '7e94eb71be',
    '7f6d43202d',
    '800b9f8b72',
    '800d311316',
    '8016056c46',
    '808c63ed8f',
    '8135ca6dde',
    '81ad4fed2b',
    '82db07ac64',
    '837543dddb',
    '84d6b146a8',
    '8507aeb1cb',
    '85ef988548',
    '85fe1cb502',
    '868bb336b9',
    '86bb8ddfe4',
    '86e7716f79',
    '8869f3399c',
    '88e2efec9e',
    '8920afbcd2',
    '897d7e821c',
    '8a39b1ce9f',
    '8a6401e9d5',
    '8c1cdc6be2',
    '8c4081751a',
    '8cb2bbf1a8',
    '8e48845dd4',
    '8fc26c0caa',
    '90519d19c7',
    '906c1ff22e',
    '90c214f1e5',
    '90efc61382',
    '91791bd48c',
    '923aca4789',
    '9290c42ee0',
    '92b1a7eec7',
    '934fcb3879',
    '93c3bd5f5c',
    '93ff5d63de',
    '94c1d3a759',
    '9511d7e887',
    '953d0eb2ab',
    '96798b90f9',
    '967f41605c',
    '96d2ddd94f',
    '988251c854',
    '98a174a0d4',
    '98a9095250',
    '98b6e72be0',
    '995c49297f',
    '998e781159',
    '9aad618687',
    '9b66bcf23c',
    '9b86b8a5fe',
    '9b952b1af0',
    '9c14cb5581',
    '9c58e81c46',
    '9cf74c6432',
    '9d96a157ca',
    '9f3cc74e77',
    '9f5029183b',
    '9f757eb55b',
    'a01b8c4af5',
    'a16611557b',
    'a16bbd70da',
    'a17c4dc16f',
    'a21e1cf8d4',
    'a4554beb3c',
    'a4763177ef',
    'a4af4ec79e',
    'a89e1ab744',
    'a8a86d7d11',
    'aa10246dda',
    'ac42154416',
    'acdbb294e4',
    'ad5d638041',
    'ad91e70d5c',
    'ad972c7127',
    'ade51acc52',
    'b0fda65353',
    'b180ac0113',
    'b3459251e5',
    'b3be6d0b23',
    'b3c8932dc8',
    'b5f5cb0885',
    'b72181a9d5',
    'b77b60349e',
    'b935f60197',
    'bc09c0039c',
    'be02cfe5da',
    'be56c6732d',
    'be8fb89ea9',
    'bf0f9e16e7',
    'bf66dafb31',
    'c1466f194c',
    'c19559637d',
    'c1dbd8af8d',
    'c1f15b6967',
    'c20570220b',
    'c215f25359',
    'c25ec5b066',
    'c28f0753df',
    'c2cf683cdc',
    'c329b6d198',
    'c3bafa1d78',
    'c4291c0396',
    'c4a43ba621',
    'c59e59474e',
    'c5dad43641',
    'c64b87ba5a',
    'c7bfd8548e',
    'c8409479b3',
    'c877496b7a',
    'c89bda5c02',
    'c9526a0744',
    'c96f52ba7a',
    'c9867c4064',
    'ca4d59fef3',
    'ca9f801c0a',
    'cae26f3214',
    'cc2fc654f5',
    'cc5b03f643',
    'ccc645a996',
    'cd22ea7236',
    'ce96a07785',
    'cee5f80b1c',
    'd075ac58be',
    'd0c4763c9b',
    'd2066b2414',
    'd27db3dcc0',
    'd2e4b1e381',
    'd2f25a78a6',
    'd3acb2d561',
    'd4a84662ee',
    'd4f2ed2ce2',
    'd795565f86',
    'd8e7f62b4e',
    'dba4475686',
    'ddaa8f2cc9',
    'ddd3e5ca3f',
    'de6e6ed26a',
    'de96056281',
    'df0027e3ab',
    'dfab2c098a',
    'e0257b4c20',
    'e0ad4ecf12',
    'e18e04a6ff',
    'e1a0800dc7',
    'e224e91c50',
    'e25d11adc9',
    'e576adfb4f',
    'e634f95937',
    'e699288e54',
    'e7688312ec',
    'e7afd37c7f',
    'e84940ac6a',
    'e86a7ecaa5',
    'e88de6b8fc',
    'e8957ff25d',
    'e902cfed5a',
    'e925b2a8ff',
    'ea1488cd35',
    'ea3b072e78',
    'ea5cc18a2d',
    'ea6dfe38ae',
    'eac66ccf10',
    'eae2a8b2ef',
    'eaf44cece7',
    'ebeeab3a36',
    'ec712f4129',
    'ed8fe4f4e0',
    'ee54465e3d',
    'ee9f7913a1',
    'eeb676e37b',
    'f2a038cbd1',
    'f2fd5c81ac',
    'f38d7d303f',
    'f3b6e11340',
    'f3bf3050d0',
    'f43de2bb2a',
    'f560c14550',
    'f5b747d45b',
    'f5d2d4d974',
    'f5ddab9cab',
    'f5f5901c55',
    'f63f03c926',
    'f70a411f05',
    'f7ed9b7ab7',
    'f8088ca24a',
    'f8a05eac67',
    'f8afabcb4c',
    'f8f9e5ba3b',
    'f96265d40e',
    'f99a792e1c',
    'f9c42b8d50',
    'fae8a1af69',
    'fbac994408',
    'fbc8a6a4f9',
    'fbf6226410',
    'fbf7b17153',
    'fd1f3e5caa',
    'fd2642606b',
    'fd5cc36d93',
    'fdad2f99d8',
    'fdda40625e',
    'ff06e0f167',
    'ff608418c1',
    'ffb3e66bbe',
])


def predict(checkpoint_path, batch_size=1, limit=None, tta=False):
    """Write a run-length-encoded submission for the test images.

    Every test image is turned into a binary mask by rounding the predicted
    sigmoid probability (values above 0.5 become foreground). The mask is
    cropped by ``[13:-14, 13:-14]`` and encoded with ``encode_rle``. Masks
    with at most 3 foreground pixels, and ids in ``_PREDICT_BLACKLIST``, are
    written as an empty (``None``) mask. The result is saved to
    ``./data/submissions/__latest.csv`` with columns ``id`` and ``rle_mask``.

    Args:
        checkpoint_path: Path passed to ``load_checkpoint``.
        batch_size: Batch size of the test generator.
        limit: Optional cap on the number of test images.
        tta: If true, average the probabilities with those of the input
            flipped along dim 3 (flipped back before averaging).
    """
    from collections import deque

    model = as_cuda(load_checkpoint(checkpoint_path))
    model.eval()

    # Ids in get_images_in order. This assumes get_test_generator yields the
    # images in the same order; TODO confirm. A deque makes taking the next
    # id O(1) (list.pop(0) is O(n)) and still raises IndexError when exhausted.
    pending_ids = deque([
        path.split('/')[-1].split('.')[0]
        for path in get_images_in('data/test/images')
    ][:limit])

    records = []
    test_generator = get_test_generator(batch_size, limit)
    # Inference only: scope autograd off instead of changing the global flag.
    with torch.no_grad():
        for inputs, _ in tqdm(test_generator, total=len(test_generator)):
            inputs = from_numpy(inputs)
            # The model outputs logits, so round the sigmoid probability on
            # both paths. Rounding a raw logit gives integers such as -3 or 5,
            # not a 0/1 mask.
            probs = torch.sigmoid(model(inputs))
            if tta:
                flipped_probs = torch.sigmoid(
                    model(inputs.flip(dims=(3, ))).flip(dims=(3, )))
                probs = (probs + flipped_probs) / 2
            # TODO AS: Ignoring reflected regions, since they are cut on submission
            masks = to_numpy(
                probs[:, 0, :, :].round().long())[:, 13:-14, 13:-14]
            for mask in masks:
                image_id = pending_ids.popleft()
                if mask.sum() <= 3 or image_id in _PREDICT_BLACKLIST:
                    records.append((image_id, None))
                else:
                    records.append((image_id, encode_rle(mask)))

    # Build from records with explicit columns so that an empty run still
    # writes a header. zip(*records) raised ValueError when records was empty.
    df = pd.DataFrame(records, columns=['id', 'rle_mask'])
    df.to_csv('./data/submissions/__latest.csv', index=False)
Example #6
0
def fit_model(model,
              train_generator,
              validation_generator,
              optimizer,
              loss_fn,
              num_epochs,
              logger,
              callbacks=None,
              metrics=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Each epoch does the following:

    - runs one optimisation pass over ``train_generator``;
    - runs one validation pass over ``validation_generator``, with flip
      test-time augmentation along dim 3;
    - calls ``on_validation_end`` on every callback;
    - sends a table of the epoch's mean losses and metrics to ``logger``.

    Args:
        model: Module mapping an input batch to logits.
        train_generator: Sized iterable of numpy ``(inputs, gt)`` batches.
        validation_generator: Sized iterable of numpy ``(inputs, gt)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: ``loss_fn(outputs, gt)`` returning a scalar tensor.
        num_epochs: Number of epochs to run.
        logger: Callable receiving the formatted per-epoch table.
        callbacks: Objects providing ``on_train_begin()``,
            ``on_train_batch_end(loss)`` and
            ``on_validation_end(logs, outputs, gt)``. ``None`` means no callbacks.
        metrics: Callables ``func(outputs, gt)``, logged as ``train_<name>``
            and ``val_<name>``. ``None`` means no metrics.

    Autograd mode is scoped to each phase, so the caller's global grad mode
    is unchanged on return. The previous version left it disabled.
    """
    # None instead of [] defaults: mutable defaults are shared across calls.
    callbacks = [] if callbacks is None else callbacks
    metrics = [] if metrics is None else metrics

    for epoch in tqdm(range(num_epochs)):
        logs = _train_one_epoch(model, train_generator, optimizer, loss_fn,
                                callbacks, metrics)
        val_logs, all_outputs, all_gt = _validate_one_epoch(
            model, validation_generator, loss_fn, metrics)
        logs.update(val_logs)

        # Callbacks have always run with autograd disabled; keep that.
        with torch.no_grad():
            for callback in callbacks:
                callback.on_validation_end(logs, all_outputs, all_gt)

        epoch_rows = [['epoch', epoch]]
        epoch_rows += [[name, f'{value:.3f}'] for name, value in logs.items()]
        logger(tabulate(epoch_rows))


def _train_one_epoch(model, train_generator, optimizer, loss_fn, callbacks,
                     metrics):
    """Run one optimisation pass; return the mean ``train_*`` loss/metrics."""
    num_batches = len(train_generator)
    logs = {'train_loss': 0}
    for func in metrics:
        logs[f'train_{func.__name__}'] = 0

    model.train()
    with torch.enable_grad():
        for callback in callbacks:
            callback.on_train_begin()
        for inputs, gt in tqdm(train_generator, total=num_batches):
            inputs, gt = from_numpy(inputs), from_numpy(gt)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, gt)
            loss.backward()
            optimizer.step()
            # .item() rather than loss.data[0]: indexing a 0-dim tensor is an
            # error from PyTorch 0.5 on.
            batch_loss = loss.item()
            logs['train_loss'] += batch_loss
            for func in metrics:
                logs[f'train_{func.__name__}'] += func(outputs.detach(), gt)
            for callback in callbacks:
                callback.on_train_batch_end(batch_loss)

    for name in logs:
        logs[name] /= num_batches
    return logs


def _validate_one_epoch(model, validation_generator, loss_fn, metrics,
                        eps=1e-7):
    """Run flip-TTA validation; return mean ``val_*`` logs, logits and targets.

    The sigmoid outputs for each input and its dim-3 flip are averaged and
    mapped back to logits. The averaged probabilities are first clamped to
    ``[eps, 1 - eps]``: once both sigmoids round to exactly 0 or 1 (float32
    saturates for |logit| above about 16.6), ``log(p / (1 - p))`` yields
    +/-inf, which can turn the loss into inf/NaN.
    """
    num_batches = len(validation_generator)
    logs = {'val_loss': 0}
    for func in metrics:
        logs[f'val_{func.__name__}'] = 0
    all_outputs = []
    all_gt = []

    model.eval()
    with torch.no_grad():
        for inputs, gt in tqdm(validation_generator, total=num_batches):
            all_gt.append(gt)
            inputs, gt = from_numpy(inputs), from_numpy(gt)
            # TODO AS: Extract as cmd opt
            probs = torch.sigmoid(model(inputs))
            flipped_probs = torch.sigmoid(
                model(inputs.flip(dims=(3, ))).flip(dims=(3, )))
            probs = ((probs + flipped_probs) / 2).clamp(eps, 1 - eps)
            outputs = torch.log(probs / (1 - probs))
            logs['val_loss'] += loss_fn(outputs, gt).item()
            for func in metrics:
                logs[f'val_{func.__name__}'] += func(outputs, gt)
            # outputs is always a single tensor here (it passed through
            # torch.sigmoid). The old tuple branches were unreachable and
            # inconsistent: one appended a list, the other checked for a tuple.
            all_outputs.append(to_numpy(outputs))

    for name in logs:
        logs[name] /= num_batches
    return logs, np.concatenate(all_outputs), np.concatenate(all_gt)