Exemplo n.º 1
0
def estimate_mean_std(config, metadata, parse_item_cb, num_threads=8, bs=16):
    """Estimate the mean and standard deviation of a dataset.

    Performs one pass over the data (with the 'train' transform applied)
    and averages batch-level statistics.

    NOTE(review): the returned tensors have one entry per channel, but every
    entry holds the same *global* value, because ``.mean()``/``.std()`` are
    taken over the whole batch rather than per channel — confirm this is the
    intended normalization.

    Args:
        config: experiment configuration passed to ``train_test_transforms``.
        metadata: dataset metadata consumed by ``ItemLoader``.
        parse_item_cb: callback used by the loader to parse a single item.
        num_threads: number of loader worker processes.
        bs: batch size used while scanning the data.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors sized to the channel count
        (``batch['data'].size(1)``).
    """
    mean_std_loader = ItemLoader(
        meta_data=metadata,
        transform=train_test_transforms(config)['train'],
        parse_item_cb=parse_item_cb,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=False)

    mean = None
    std = None
    for _ in tqdm(range(len(mean_std_loader)),
                  desc='Calculating mean and standard deviation'):
        for batch in mean_std_loader.sample():
            if mean is None:
                # Lazily allocate one slot per channel once the channel
                # count is known from the first batch.
                mean = torch.zeros(batch['data'].size(1))
                std = torch.zeros(batch['data'].size(1))
            # Global batch statistics broadcast into every channel slot.
            mean += batch['data'].mean().item()
            std += batch['data'].std().item()

    # Average over the number of sampling steps (one batch per step assumed).
    mean /= len(mean_std_loader)
    std /= len(mean_std_loader)

    return mean, std
Exemplo n.º 2
0
def test_loader_samples_batches(batch_size, n_samples, metadata_fname_target_5_classes,
                                ones_image_parser, img_target_transformer):
    """Sampling ``n_samples`` times must yield that many full-size batches."""
    loader = ItemLoader(meta_data=metadata_fname_target_5_classes, root='/tmp/',
                        batch_size=batch_size, parse_item_cb=ones_image_parser,
                        transform=img_target_transformer, shuffle=True)

    batches = loader.sample(n_samples)

    first = batches[0]
    assert len(batches) == n_samples
    assert first['img'].size(0) == batch_size
    assert first['target'].size(0) == batch_size
Exemplo n.º 3
0
                                    shuffle=False, drop_last=False)

                kappa_meter.on_epoch_begin(0)
                acc_meter.on_epoch_begin(0)
                mse_meter.on_epoch_begin(0)

                cm_viz.on_epoch_begin(0)
                cm_norm_viz.on_epoch_begin(0)
                progress_bar = tqdm(range(len(loader)), total=len(loader), desc="Eval::")

                if save_detail_preds:
                    bi_preds_probs_all = []
                    bi_targets_all = []

                for i in progress_bar:
                    sample = loader.sample(1)[0]
                    _output = d_model(sample['data'].to(next(d_model.parameters()).device))
                    sample['target'] = sample['target'].type(torch.int32)

                    output = _output

                    if save_servere_wrong_predictions:
                        preds_logits = to_cpu(output)

                        targets_cpu = to_cpu(sample['target'])

                        preds_probs = softmax(preds_logits, axis=-1)
                        preds = np.argmax(preds_probs, axis=-1)
                        bi_probs_cpu = preds_probs[:, 2] + preds_probs[:, 3] + preds_probs[:, 4]
                        bi_target_cpu = np.zeros_like(targets_cpu).tolist()
                        for i in range(targets_cpu.shape[0]):
Exemplo n.º 4
0
class MixMatchEMASampler(object):
    """MixMatch-style batch sampler backed by a student/teacher (EMA) pair.

    Wraps two ``ItemLoader`` instances — one over labeled and one over
    unlabeled metadata — and, per sampling step:

    1. guesses soft labels for unlabeled images by averaging teacher-model
       predictions over ``n_augmentations`` augmented copies and sharpening;
    2. one-hot encodes the labeled targets;
    3. mixes both streams with MixUp against a shuffled union of the two.

    BUGFIX(review): the original ``sample()`` referenced ``self._model``,
    which was never assigned (``__init__`` only stores ``_st_model`` and
    ``_te_model``), so sampling always raised ``AttributeError``. Label
    guessing is routed through the teacher (EMA) model here, which is the
    conventional choice in mean-teacher/MixMatch-EMA setups — confirm
    against the training code.
    """

    def __init__(self,
                 st_model: nn.Module,
                 te_model: nn.Module,
                 name: str,
                 augmentation,
                 labeled_meta_data: pd.DataFrame,
                 unlabeled_meta_data: pd.DataFrame,
                 n_augmentations=1,
                 output_type='logits',
                 data_key: str = "data",
                 target_key: str = 'target',
                 parse_item_cb: callable or None = None,
                 root: str or None = None,
                 batch_size: int = 1,
                 num_workers: int = 0,
                 shuffle: bool = False,
                 pin_memory: bool = False,
                 collate_fn: callable = default_collate,
                 transform: callable or None = None,
                 sampler: torch.utils.data.sampler.Sampler or None = None,
                 batch_sampler=None,
                 drop_last: bool = False,
                 timeout: int = 0,
                 detach: bool = False):
        # Loader over the labeled portion of the data.
        self._label_sampler = ItemLoader(meta_data=labeled_meta_data,
                                         parse_item_cb=parse_item_cb,
                                         root=root,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         shuffle=shuffle,
                                         pin_memory=pin_memory,
                                         collate_fn=collate_fn,
                                         transform=transform,
                                         sampler=sampler,
                                         batch_sampler=batch_sampler,
                                         drop_last=drop_last,
                                         timeout=timeout)

        # Loader over the unlabeled portion, configured identically.
        self._unlabel_sampler = ItemLoader(meta_data=unlabeled_meta_data,
                                           parse_item_cb=parse_item_cb,
                                           root=root,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=shuffle,
                                           pin_memory=pin_memory,
                                           collate_fn=collate_fn,
                                           transform=transform,
                                           sampler=sampler,
                                           batch_sampler=batch_sampler,
                                           drop_last=drop_last,
                                           timeout=timeout)

        self._name = name
        self._st_model: nn.Module = st_model
        self._te_model: nn.Module = te_model
        self._n_augmentations = n_augmentations
        self._augmentation = augmentation
        self._data_key = data_key
        self._target_key = target_key
        self._output_type = output_type
        # NOTE(review): stored but not used anywhere in this class — dead
        # flag, or consumed by a subclass/caller; verify before removing.
        self._detach = detach
        # One epoch covers the longer of the two streams.
        self._len = max(len(self._label_sampler), len(self._unlabel_sampler))

    def __len__(self):
        """Number of sampling steps per epoch."""
        return self._len

    def _crop_if_needed(self, df1, df2):
        """Truncate paired batches to a common length along dim 0.

        Both lists must have the same number of batches; whenever the i-th
        batches differ in size, both data and target tensors are cropped to
        the shorter one so they can be mixed element-wise later.
        """
        assert len(df1) == len(df2)
        for i in range(len(df1)):
            if len(df1[i]['data']) != len(df2[i]['data']):
                min_len = min(len(df1[i]['data']), len(df2[i]['data']))
                df1[i][self._data_key] = df1[i][self._data_key][:min_len, :]
                df2[i][self._data_key] = df2[i][self._data_key][:min_len, :]
                df1[i][self._target_key] = df1[i][self._target_key][:min_len]
                df2[i][self._target_key] = df2[i][self._target_key][:min_len]
        return df1, df2

    def sharpen(self, x, T=0.5):
        """Sharpen a (batch, classes) probability matrix.

        Raises each probability to ``1 / T`` and renormalizes rows; lower
        ``T`` pushes the distribution toward its mode (MixMatch Eq. 7).
        """
        assert len(x.shape) == 2

        _x = torch.pow(x, 1 / T)
        s = torch.sum(_x, dim=-1, keepdim=True)
        _x = _x / s
        return _x

    def _create_union_data(self, r1, r2):
        """Concatenate paired batches (data and probs) along dim 0.

        The resulting union is what MixUp partners are drawn from; 'name'
        is taken from the first list.
        """
        assert len(r1) == len(r2)
        r = []

        for i in range(len(r1)):
            union_rows = dict()
            union_rows[self._data_key] = torch.cat(
                [r1[i][self._data_key], r2[i][self._data_key]], dim=0)
            union_rows["probs"] = torch.cat([r1[i]["probs"], r2[i]["probs"]],
                                            dim=0)
            union_rows['name'] = r1[i]['name']
            r.append(union_rows)
        return r

    def _mixup(self, x1, y1, x2, y2, alpha=0.75):
        """MixUp two (input, label) pairs with a Beta(alpha, alpha) weight.

        ``l`` is clamped to >= 0.5 so the first argument dominates the mix
        (the MixMatch variant of MixUp).
        """
        l = np.random.beta(alpha, alpha)
        l = max(l, 1 - l)
        x = l * x1 + (1 - l) * x2
        y = l * y1 + (1 - l) * y2
        return x, y

    def sample(self, k=1):
        """Draw ``k`` mixed batches ready for MixMatch training.

        Returns a list of ``k`` dicts with keys: 'name', 'x_mix',
        'target_mix_x', 'u_mix', 'target_mix_u', 'target_x'.
        """
        samples = []
        labeled_sampled_rows = self._label_sampler.sample(k)
        unlabeled_sampled_rows = self._unlabel_sampler.sample(k)

        labeled_sampled_rows, unlabeled_sampled_rows = self._crop_if_needed(
            labeled_sampled_rows, unlabeled_sampled_rows)

        # Guessing device from the teacher model's parameters.
        device = next(self._te_model.parameters()).device

        for i in range(k):
            # Unlabeled data: guess labels from augmented teacher predictions.
            unlabeled_sampled_rows[i][
                self._data_key] = unlabeled_sampled_rows[i][self._data_key].to(
                    device)

            u_imgs = unlabeled_sampled_rows[i][self._data_key]

            list_imgs = []
            for b in range(u_imgs.shape[0]):
                for j in range(self._n_augmentations):
                    img = u_imgs[b, :, :, :]
                    if img.shape[0] == 1:
                        # Single-channel: drop the channel axis for the
                        # augmentation callable.
                        img = img[0, :, :]
                    else:
                        # CHW -> HWC for the augmentation callable.
                        img = img.permute(1, 2, 0)

                    img_cpu = to_cpu(img)
                    aug_img = self._augmentation(img_cpu)
                    list_imgs.append(aug_img)

            batch_imgs = torch.cat(list_imgs, dim=0)
            batch_imgs = batch_imgs.to(device)
            if self._output_type == 'logits':
                out = self._te_model(batch_imgs)
            elif self._output_type == 'features':
                out = self._te_model.get_features(batch_imgs)

            preds = F.softmax(out, dim=1)
            # (batch, n_augmentations, classes): average over augmentations.
            preds = preds.view(u_imgs.shape[0], -1, preds.shape[-1])

            mean_preds = torch.mean(preds, dim=1)
            guessing_labels = self.sharpen(mean_preds).detach()

            unlabeled_sampled_rows[i]["probs"] = guessing_labels

            # Labeled data: one-hot encode the integer targets.
            labeled_sampled_rows[i][self._data_key] = labeled_sampled_rows[i][
                self._data_key].to(device)
            target_l = labeled_sampled_rows[i][self._target_key]
            onehot_l = torch.zeros(guessing_labels.shape)
            onehot_l.scatter_(1, target_l.type(torch.int64).unsqueeze(-1), 1.0)
            labeled_sampled_rows[i]["probs"] = onehot_l.to(device)

        union_rows = self._create_union_data(labeled_sampled_rows,
                                             unlabeled_sampled_rows)

        for i in range(k):
            ridx = np.random.permutation(
                union_rows[i][self._data_key].shape[0])
            u = unlabeled_sampled_rows[i]
            x = labeled_sampled_rows[i]

            # NOTE(review): ``ridx[i]`` / ``ridx[k + i]`` are *scalar*
            # indices, so the whole labeled/unlabeled batch is mixed against
            # a single broadcast union row. Canonical MixMatch mixes against
            # a full shuffled batch — confirm this is intentional.
            x_mix, target_mix = self._mixup(
                x[self._data_key], x["probs"],
                union_rows[i][self._data_key][ridx[i]],
                union_rows[i]["probs"][ridx[i]])
            u_mix, pred_mix = self._mixup(
                u[self._data_key], u["probs"],
                union_rows[i][self._data_key][ridx[k + i]],
                union_rows[i]["probs"][ridx[k + i]])

            samples.append({
                'name': self._name,
                'x_mix': x_mix,
                'target_mix_x': target_mix,
                'u_mix': u_mix,
                'target_mix_u': pred_mix,
                'target_x': x[self._target_key]
            })
        return samples