Example #1
    def __init__(self,
                 model: nn.Module,
                 name: str,
                 augmentation,
                 labeled_meta_data: pd.DataFrame,
                 unlabeled_meta_data: pd.DataFrame,
                 n_augmentations=1,
                 output_type='logits',
                 data_key: str = "data",
                 target_key: str = 'target',
                 parse_item_cb: callable or None = None,
                 root: str or None = None,
                 batch_size: int = 1,
                 num_workers: int = 0,
                 shuffle: bool = False,
                 pin_memory: bool = False,
                 collate_fn: callable = default_collate,
                 transform: callable or None = None,
                 sampler: torch.utils.data.sampler.Sampler or None = None,
                 batch_sampler=None,
                 drop_last: bool = False,
                 timeout: int = 0,
                 detach: bool = False):
        self._label_sampler = ItemLoader(meta_data=labeled_meta_data,
                                         parse_item_cb=parse_item_cb,
                                         root=root,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         shuffle=shuffle,
                                         pin_memory=pin_memory,
                                         collate_fn=collate_fn,
                                         transform=transform,
                                         sampler=sampler,
                                         batch_sampler=batch_sampler,
                                         drop_last=drop_last,
                                         timeout=timeout)

        self._unlabel_sampler = ItemLoader(meta_data=unlabeled_meta_data,
                                           parse_item_cb=parse_item_cb,
                                           root=root,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=shuffle,
                                           pin_memory=pin_memory,
                                           collate_fn=collate_fn,
                                           transform=transform,
                                           sampler=sampler,
                                           batch_sampler=batch_sampler,
                                           drop_last=drop_last,
                                           timeout=timeout)

        self._name = name
        self._model: nn.Module = model
        self._n_augmentations = n_augmentations
        self._augmentation = augmentation
        self._data_key = data_key
        self._target_key = target_key
        self._output_type = output_type
        self._detach = detach
        self._len = max(len(self._label_sampler), len(self._unlabel_sampler))
Example #2
def estimate_mean_std(config, metadata, parse_item_cb, num_threads=8, bs=16):
    mean_std_loader = ItemLoader(
        meta_data=metadata,
        transform=train_test_transforms(config)['train'],
        parse_item_cb=parse_item_cb,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=False)

    mean = None
    std = None
    for i in tqdm(range(len(mean_std_loader)),
                  desc='Calculating mean and standard deviation'):
        for batch in mean_std_loader.sample():
            if mean is None:
                mean = torch.zeros(batch['data'].size(1))
                std = torch.zeros(batch['data'].size(1))
            # for channel in range(batch['data'].size(1)):
            #     mean[channel] += batch['data'][:, channel, :, :].mean().item()
            #     std[channel] += batch['data'][:, channel, :, :].std().item()
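            # NOTE: the two lines below add the *global* batch mean/std to every
            # per-channel entry; the commented-out loop above would accumulate
            # true per-channel statistics instead.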
            mean += batch['data'].mean().item()
            std += batch['data'].std().item()

    mean /= len(mean_std_loader)
    std /= len(mean_std_loader)

    return mean, std
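
A minimal call sketch, assuming a config accepted by train_test_transforms, a metadata DataFrame and a parse_item callback defined elsewhere (all names below are placeholders):

mean, std = estimate_mean_std(config, metadata, parse_item_cb=parse_item, num_threads=4, bs=32)
print('dataset mean:', mean, 'std:', std)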
Example #3
def semixup_data_provider(model, alpha, n_classes, train_labeled_data, train_unlabeled_data, val_labeled_data,
                          transforms, parse_item, bs, num_workers, item_loaders=dict(), root="", augmentation=None,
                          data_rearrange=None):
    """
    Default setting of data provider for Semixup
    """
    # item_loaders["labeled_train"] = MixUpSampler(meta_data=train_labeled_data, name='l_mixup', alpha=alpha, model=model,
    #                                              transform=transforms['train'], parse_item_cb=parse_item, batch_size=bs,
    #                                              data_rearrange=data_rearrange,
    #                                              num_workers=num_workers, root=root, shuffle=True)
    item_loaders["labeled_train"] = ItemLoader(meta_data=train_labeled_data, name='l_norm', transform=transforms['train'],
                                              parse_item_cb=parse_item, batch_size=bs, num_workers=num_workers,
                                              root=root, shuffle=True)

    item_loaders["unlabeled_train"] = SemixupSampler(meta_data=train_unlabeled_data, name='u_mixup', alpha=alpha,
                                                     model=model, min_lambda=0.55,
                                                     transform=transforms['train'], parse_item_cb=parse_item, batch_size=bs,
                                                     data_rearrange=data_rearrange,
                                                     num_workers=num_workers, augmentation=augmentation, root=root,
                                                     shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(meta_data=val_labeled_data, name='l_norm', transform=transforms['eval'],
                                              parse_item_cb=parse_item, batch_size=bs, num_workers=num_workers,
                                              root=root, shuffle=False)

    return DataProvider(item_loaders)
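
A hedged call sketch (every argument is a placeholder for an object built elsewhere; 'transforms' is expected to be a dict with 'train' and 'eval' entries, as used above):

provider = semixup_data_provider(model=net, alpha=0.75, n_classes=5,
                                 train_labeled_data=df_train_l, train_unlabeled_data=df_train_u,
                                 val_labeled_data=df_val, transforms=my_transforms,
                                 parse_item=parse_item, bs=32, num_workers=4,
                                 augmentation=my_augmentation)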
Example #4
def ict_data_provider(model,
                      alpha,
                      n_classes,
                      train_labeled_data,
                      train_unlabeled_data,
                      val_labeled_data,
                      val_unlabeled_data,
                      transforms,
                      parse_item,
                      bs,
                      num_threads,
                      item_loaders=dict(),
                      root=""):
    """
    Default setting of data provider for ICT
    """
    item_loaders["labeled_train"] = ItemLoader(meta_data=train_labeled_data,
                                               name='l_norm',
                                               transform=transforms[1],
                                               parse_item_cb=parse_item,
                                               batch_size=bs,
                                               num_workers=num_threads,
                                               root=root,
                                               shuffle=True)

    item_loaders["unlabeled_train"] = MixUpSampler(
        meta_data=train_unlabeled_data,
        name='u_mixup',
        alpha=alpha,
        model=model,
        transform=transforms[0],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        root=root,
        shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(meta_data=val_labeled_data,
                                              name='l_norm',
                                              transform=transforms[1],
                                              parse_item_cb=parse_item,
                                              batch_size=bs,
                                              num_workers=num_threads,
                                              root=root,
                                              shuffle=False)

    item_loaders["unlabeled_eval"] = MixUpSampler(meta_data=val_unlabeled_data,
                                                  name='u_mixup',
                                                  alpha=alpha,
                                                  model=model,
                                                  transform=transforms[1],
                                                  parse_item_cb=parse_item,
                                                  batch_size=bs,
                                                  num_workers=num_threads,
                                                  root=root,
                                                  shuffle=False)

    return DataProvider(item_loaders)
Example #5
def test_loader_samples_batches(batch_size, n_samples, metadata_fname_target_5_classes,
                                ones_image_parser, img_target_transformer):
    item_loader = ItemLoader(meta_data=metadata_fname_target_5_classes, root='/tmp/',
                             batch_size=batch_size, parse_item_cb=ones_image_parser,
                             transform=img_target_transformer, shuffle=True)

    samples = item_loader.sample(n_samples)

    assert len(samples) == n_samples
    assert samples[0]['img'].size(0) == batch_size
    assert samples[0]['target'].size(0) == batch_size
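
The assertions above rely on ItemLoader.sample(n) returning a list of n collated batch dictionaries; a short sketch of consuming them (the 'img' and 'target' keys come from the parser used in this test):

for batch in item_loader.sample(3):
    imgs, targets = batch['img'], batch['target']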
Example #6
def create_data_provider(args, config, parser, metadata, mean, std):
    """
    Setup dataloader and augmentations
    :param args: General arguments
    :param config: Experiment parameters
    :param parser: Function for loading images
    :param metadata: Image paths and subject IDs
    :param mean: Dataset mean
    :param std: Dataset std
    :return: The compiled dataloader
    """
    # Compile ItemLoaders
    item_loaders = dict()
    for stage in ['train', 'val']:
        item_loaders[f'bfpn_{stage}'] = ItemLoader(
            meta_data=metadata[stage],
            transform=train_test_transforms(
                config,
                mean,
                std,
                crop_size=tuple(config['training']['crop_size']))[stage],
            parse_item_cb=parser,
            batch_size=config['training']['bs'],
            num_workers=args.num_threads,
            shuffle=True if stage == "train" else False)

    return DataProvider(item_loaders)
Example #7
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, classes = get_cifar10(data_folder=data_dir, train=True)
    eval_ds, _ = get_cifar10(data_folder=data_dir, train=False)
    n_channels = 3

    criterion = torch.nn.CrossEntropyLoss()

    model = ResNet(in_channels=n_channels, n_features=64, drop_rate=0.3).to(device).half()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.wd,
                                nesterov=True)

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, eval_ds]):
        item_loaders[f'loader_{stage}'] = ItemLoader(meta_data=df,
                                                     transform=my_transforms()[stage],
                                                     parse_item_cb=parse_item,
                                                     batch_size=cfg.bs, num_workers=cfg.num_workers,
                                                     shuffle=True if stage == "train" else False)

    data_provider = DataProvider(item_loaders)

    train_cbs = (CosineAnnealingWarmRestartsWithWarmup(optimizer=optimizer, warmup_epochs=(0, 10, 20),
                                                       warmup_lrs=(0, 0.1, 0.01), T_O=5, T_mult=2, eta_min=0),
                 RunningAverageMeter(name="loss"),
                 AccuracyMeter(name="acc"))

    val_cbs = (RunningAverageMeter(name="loss"),
               AccuracyMeter(name="acc"),
               ScalarMeterLogger(writer=summary_writer),
               ModelSaver(metric_names='loss', save_dir=cfg.snapshots, conditions='min', model=model),
               ModelSaver(metric_names='acc', save_dir=cfg.snapshots, conditions='max', model=model))

    session = dict()
    session['mymodel'] = Session(data_provider=data_provider,
                                 train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
                                 val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
                                 module=model, loss=criterion, optimizer=optimizer,
                                 train_callbacks=train_cbs,
                                 val_callbacks=val_cbs)

    strategy = Strategy(data_provider=data_provider,
                        data_sampling_config=cfg.sampling,
                        strategy_config=cfg.strategy,
                        sessions=session,
                        n_epochs=cfg.n_epochs,
                        device=device)

    strategy.run()
Example #8
def mixmatch_ema_data_provider(model,
                               augmentation,
                               labeled_meta_data,
                               unlabeled_meta_data,
                               val_labeled_data,
                               n_augmentations,
                               parse_item,
                               bs,
                               transforms,
                               root="",
                               num_threads=4):
    itemloader_dict = {}
    itemloader_dict['all_train'] = MixMatchSampler(
        model=model,
        name="train_mixmatch",
        augmentation=augmentation,
        labeled_meta_data=labeled_meta_data,
        unlabeled_meta_data=unlabeled_meta_data,
        n_augmentations=n_augmentations,
        data_key='data',
        target_key='target',
        parse_item_cb=parse_item,
        batch_size=bs,
        transform=transforms[0],
        num_workers=num_threads,
        shuffle=True)

    itemloader_dict['labeled_eval_st'] = ItemLoader(root=root,
                                                    meta_data=val_labeled_data,
                                                    name='l_eval',
                                                    transform=transforms[1],
                                                    parse_item_cb=parse_item,
                                                    batch_size=bs,
                                                    num_workers=num_threads,
                                                    shuffle=False)

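    # A second, identical loader registered under its own key -- presumably so the
    # student and the EMA teacher can each be evaluated on the same validation set.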
    itemloader_dict['labeled_eval_te'] = ItemLoader(root=root,
                                                    meta_data=val_labeled_data,
                                                    name='l_eval',
                                                    transform=transforms[1],
                                                    parse_item_cb=parse_item,
                                                    batch_size=bs,
                                                    num_workers=num_threads,
                                                    shuffle=False)

    return DataProvider(itemloader_dict)
Example #9
def create_data_provider(args, config, parser, metadata, mean, std):
    # Compile ItemLoaders
    item_loaders = dict()
    for stage in ['train', 'val']:
        item_loaders[f'bfpn_{stage}'] = ItemLoader(meta_data=metadata[stage],
                                                   transform=train_test_transforms(config, mean, std)[stage],
                                                   parse_item_cb=parser,
                                                   batch_size=args.bs, num_workers=args.num_threads,
                                                   shuffle=True if stage == "train" else False)

    return DataProvider(item_loaders)
Example #10
def init_data_provider(args, df_train, df_val, item_loaders, test_ds):
    """
    Initialize the data provider for the autoencoder.
    Parameters
    ----------
    args: Namespace
        arguments for the whole network parsed using argparse
    df_train: DataFrame
        training data as pandas DataFrame
    df_val: DataFrame
        validation data as pandas DataFrame
    item_loaders: dict
        empty dictionary to be populated by data samplers
    test_ds: DataFrame
        test data for visualization as pandas DataFrame


    Returns
    -------
    DataProvider
        DataProvider object constructed from all the data samplers
    """
    for stage, df in zip(['train', 'eval'], [df_train, df_val]):
        item_loaders[f'mnist_{stage}'] = ItemLoader(
            meta_data=df,
            transform=init_mnist_transforms()[0],
            parse_item_cb=parse_item_ae,
            batch_size=args.bs,
            num_workers=args.num_threads,
            shuffle=True if stage == 'train' else False)

    item_loaders['mnist_viz'] = ItemLoader(
        meta_data=test_ds,
        transform=init_mnist_transforms()[0],
        parse_item_cb=parse_item_ae,
        batch_size=args.bs,
        num_workers=args.num_threads,
        shuffle=False)  # 'stage' left over from the loop above would be 'eval' here; made explicit for the visualization loader
    return DataProvider(item_loaders)
Example #11
def test_loader_drop_last(batch_size, n_samples, metadata_fname_target_5_classes,
                          ones_image_parser, img_target_transformer, drop_last):
    item_loader = ItemLoader(meta_data=metadata_fname_target_5_classes, root='/tmp/',
                             batch_size=batch_size, parse_item_cb=ones_image_parser,
                             transform=img_target_transformer, shuffle=True, drop_last=drop_last)

    if drop_last:
        assert len(item_loader) == metadata_fname_target_5_classes.shape[0] // batch_size
    else:
        if metadata_fname_target_5_classes.shape[0] % batch_size != 0:
            assert len(item_loader) == metadata_fname_target_5_classes.shape[0] // batch_size + 1
        else:
            assert len(item_loader) == metadata_fname_target_5_classes.shape[0] // batch_size
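
For concreteness (illustrative numbers, not the fixture's): with 23 metadata rows and batch_size=4, len(item_loader) is 23 // 4 == 5 when drop_last=True, and 6 when drop_last=False, because the trailing partial batch of 3 rows is kept.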
Example #12
    g_network = Generator(nc=1, nz=args.latent_size,
                          ngf=args.g_net_features).to(device)
    g_optim = optim.Adam(g_network.parameters(),
                         lr=args.g_lr,
                         weight_decay=args.g_wd,
                         betas=(args.beta1, 0.999))
    g_crit = SSGeneratorLoss(d_network=d_network, d_loss=BCELoss()).to(device)

    item_loaders = dict()
    train_labeled_data, val_labeled_data, train_unlabeled_data, val_unlabeled_data = next(
        splitter)

    item_loaders["real_labeled_train"] = ItemLoader(
        meta_data=train_labeled_data,
        transform=init_mnist_transforms()[1],
        parse_item_cb=parse_item_mnist_ssgan,
        batch_size=args.bs,
        num_workers=args.num_threads,
        shuffle=True)

    item_loaders["real_unlabeled_train"] = ItemLoader(
        meta_data=train_unlabeled_data,
        transform=init_mnist_transforms()[1],
        parse_item_cb=parse_item_mnist_ssgan,
        batch_size=args.bs,
        num_workers=args.num_threads,
        shuffle=True)

    item_loaders["real_labeled_val"] = ItemLoader(
        meta_data=val_labeled_data,
        transform=init_mnist_transforms()[1],
Example #13
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, classes = get_mnist(data_folder=data_dir, train=True)
    n_classes = len(classes)
    n_channels = 1

    criterion = torch.nn.CrossEntropyLoss()

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    splitter = FoldSplit(train_ds, n_folds=5, target_col="target")

    for fold_id, (df_train, df_val) in enumerate(splitter):
        item_loaders = dict()

        for stage, df in zip(['train', 'eval'], [df_train, df_val]):
            item_loaders[f'loader_{stage}'] = ItemLoader(
                meta_data=df,
                transform=my_transforms()[stage],
                parse_item_cb=parse_item,
                batch_size=cfg.bs,
                num_workers=cfg.num_threads,
                shuffle=True if stage == "train" else False)

        model = SimpleConvNet(bw=cfg.bw,
                              drop_rate=cfg.dropout,
                              n_classes=n_classes).to(device)
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=cfg.lr,
                                     weight_decay=cfg.wd)
        data_provider = DataProvider(item_loaders)

        train_cbs = (RunningAverageMeter(name="loss"),
                     AccuracyMeter(name="acc"))

        val_cbs = (RunningAverageMeter(name="loss"), AccuracyMeter(name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='loss',
                              save_dir=cfg.snapshots,
                              conditions='min',
                              model=model),
                   ModelSaver(metric_names='acc',
                              save_dir=cfg.snapshots,
                              conditions='max',
                              model=model))

        session = dict()
        session['mymodel'] = Session(
            data_provider=data_provider,
            train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
            val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
            module=model,
            loss=criterion,
            optimizer=optimizer,
            train_callbacks=train_cbs,
            val_callbacks=val_cbs)

        strategy = Strategy(data_provider=data_provider,
                            data_sampling_config=cfg.sampling,
                            strategy_config=cfg.strategy,
                            sessions=session,
                            n_epochs=cfg.n_epochs,
                            device=device)

        strategy.run()
Example #14
                else:
                    pas_str = list_pas

                print("Processing {} with {} by best {}...".format(pas_str, comment, md))

                cm_norm_viz = ConfusionMatrixVisualizer(writer=writer, tag="CM_" + md + "_" + pas_str, normalize=True,
                                                        labels=["KL" + str(i) for i in range(5)],
                                                        parse_class=parse_class)
                cm_viz = ConfusionMatrixVisualizer(writer=writer, tag="CM_raw_" + md + "_" + pas_str, normalize=False,
                                                   labels=["KL" + str(i) for i in range(5)], parse_class=parse_class)

                ds_most_filtered = filter_most_by_pa(ds_most, df_most_ex, list_pas)

                loader = ItemLoader(root=args.root,
                                    meta_data=ds_most_filtered,
                                    transform=init_transform_wo_aug(),
                                    parse_item_cb=parse_item,
                                    batch_size=args.bs, num_workers=args.num_threads,
                                    shuffle=False, drop_last=False)

                kappa_meter.on_epoch_begin(0)
                acc_meter.on_epoch_begin(0)
                mse_meter.on_epoch_begin(0)

                cm_viz.on_epoch_begin(0)
                cm_norm_viz.on_epoch_begin(0)
                progress_bar = tqdm(range(len(loader)), total=len(loader), desc="Eval::")

                if save_detail_preds:
                    bi_preds_probs_all = []
                    bi_targets_all = []
Example #15
def mt_data_provider(st_model,
                     te_model,
                     train_labeled_data,
                     train_unlabeled_data,
                     val_labeled_data,
                     val_unlabeled_data,
                     transforms,
                     parse_item,
                     bs,
                     num_threads,
                     item_loaders=dict(),
                     n_augmentations=1,
                     output_type='logits',
                     root=""):
    """
    Default setting of data provider for Mean-Teacher

    """

    # Train
    item_loaders["labeled_train_st"] = AugmentedGroupSampler(
        root=root,
        name='l_st',
        meta_data=train_labeled_data,
        model=st_model,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=True)

    item_loaders["unlabeled_train_st"] = AugmentedGroupSampler(
        root=root,
        name='u_st',
        model=st_model,
        meta_data=train_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=True)

    item_loaders["labeled_train_te"] = AugmentedGroupSampler(
        root=root,
        name='l_te',
        meta_data=train_labeled_data,
        model=te_model,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        detach=True,
        shuffle=True)

    item_loaders["unlabeled_train_te"] = AugmentedGroupSampler(
        root=root,
        name='u_te',
        model=te_model,
        meta_data=train_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        detach=True,
        shuffle=True)

    # Eval

    item_loaders["labeled_eval_st"] = AugmentedGroupSampler(
        root=root,
        name='l_st',
        meta_data=val_labeled_data,
        model=st_model,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=False)

    item_loaders["unlabeled_eval_st"] = AugmentedGroupSampler(
        root=root,
        name='u_st',
        model=st_model,
        meta_data=val_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=False)

    item_loaders["labeled_eval_te"] = ItemLoader(root=root,
                                                 meta_data=val_labeled_data,
                                                 name='l_te_eval',
                                                 transform=transforms[1],
                                                 parse_item_cb=parse_item,
                                                 batch_size=bs,
                                                 num_workers=num_threads,
                                                 shuffle=False)

    return DataProvider(item_loaders)
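
A hedged call sketch (all names are placeholders; note that in the calls above transforms[1] serves as the base transform, transforms[2] as the extra augmentation, and transforms[0] is not referenced):

provider = mt_data_provider(st_model=student_net, te_model=teacher_net,
                            train_labeled_data=df_train_l, train_unlabeled_data=df_train_u,
                            val_labeled_data=df_val_l, val_unlabeled_data=df_val_u,
                            transforms=(train_tf, eval_tf, aug_tf),
                            parse_item=parse_item, bs=32, num_threads=4,
                            n_augmentations=2)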
Example #16
def pimodel_data_provider(model,
                          train_labeled_data,
                          train_unlabeled_data,
                          val_labeled_data,
                          val_unlabeled_data,
                          transforms,
                          parse_item,
                          bs,
                          num_threads,
                          item_loaders=dict(),
                          root="",
                          n_augmentations=1,
                          output_type='logits'):
    """
    Default setting of data provider for Pi-Model

    """
    item_loaders["labeled_train"] = ItemLoader(root=root,
                                               meta_data=train_labeled_data,
                                               name='l',
                                               transform=transforms[0],
                                               parse_item_cb=parse_item,
                                               batch_size=bs,
                                               num_workers=num_threads,
                                               shuffle=True)

    # item_loaders["labeled_train"] = AugmentedGroupSampler(root=root, model=model, name='l', output_type=output_type,
    #                                                       meta_data=train_labeled_data,
    #                                                       n_augmentations=n_augmentations,
    #                                                       augmentation=transforms[2],
    #                                                       transform=transforms[1],
    #                                                       parse_item_cb=parse_item,
    #                                                       batch_size=bs, num_workers=num_workers,
    #                                                       shuffle=True)

    item_loaders["unlabeled_train"] = AugmentedGroupSampler(
        root=root,
        model=model,
        name='u',
        output_type=output_type,
        meta_data=train_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[0],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(root=root,
                                              meta_data=val_labeled_data,
                                              name='l',
                                              transform=transforms[1],
                                              parse_item_cb=parse_item,
                                              batch_size=bs,
                                              num_workers=num_threads,
                                              shuffle=False)

    # item_loaders["labeled_eval"] = AugmentedGroupSampler(root=root, model=model, name='l', output_type=output_type,
    #                                                      meta_data=val_labeled_data,
    #                                                      n_augmentations=n_augmentations,
    #                                                      augmentation=transforms[2],
    #                                                      transform=transforms[1],
    #                                                      parse_item_cb=parse_item,
    #                                                      batch_size=bs, num_workers=num_workers,
    #                                                      shuffle=False)

    # item_loaders["unlabeled_eval"] = AugmentedGroupSampler(root=root, model=model, name='u', output_type=output_type,
    #                                                        meta_data=val_unlabeled_data,
    #                                                        n_augmentations=n_augmentations,
    #                                                        augmentation=transforms[2],
    #                                                        transform=transforms[1],
    #                                                        parse_item_cb=parse_item,
    #                                                        batch_size=bs, num_workers=num_workers,
    #                                                        shuffle=False)

    return DataProvider(item_loaders)
Example #17
class MixMatchEMASampler(object):
    def __init__(self,
                 st_model: nn.Module,
                 te_model: nn.Module,
                 name: str,
                 augmentation,
                 labeled_meta_data: pd.DataFrame,
                 unlabeled_meta_data: pd.DataFrame,
                 n_augmentations=1,
                 output_type='logits',
                 data_key: str = "data",
                 target_key: str = 'target',
                 parse_item_cb: callable or None = None,
                 root: str or None = None,
                 batch_size: int = 1,
                 num_workers: int = 0,
                 shuffle: bool = False,
                 pin_memory: bool = False,
                 collate_fn: callable = default_collate,
                 transform: callable or None = None,
                 sampler: torch.utils.data.sampler.Sampler or None = None,
                 batch_sampler=None,
                 drop_last: bool = False,
                 timeout: int = 0,
                 detach: bool = False):
        self._label_sampler = ItemLoader(meta_data=labeled_meta_data,
                                         parse_item_cb=parse_item_cb,
                                         root=root,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         shuffle=shuffle,
                                         pin_memory=pin_memory,
                                         collate_fn=collate_fn,
                                         transform=transform,
                                         sampler=sampler,
                                         batch_sampler=batch_sampler,
                                         drop_last=drop_last,
                                         timeout=timeout)

        self._unlabel_sampler = ItemLoader(meta_data=unlabeled_meta_data,
                                           parse_item_cb=parse_item_cb,
                                           root=root,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=shuffle,
                                           pin_memory=pin_memory,
                                           collate_fn=collate_fn,
                                           transform=transform,
                                           sampler=sampler,
                                           batch_sampler=batch_sampler,
                                           drop_last=drop_last,
                                           timeout=timeout)

        self._name = name
        self._st_model: nn.Module = st_model
        self._te_model: nn.Module = te_model
        # sample() below references self._model, which this class never assigns;
        # the EMA (teacher) model is assumed here so the snippet runs as written.
        self._model: nn.Module = te_model
        self._n_augmentations = n_augmentations
        self._augmentation = augmentation
        self._data_key = data_key
        self._target_key = target_key
        self._output_type = output_type
        self._detach = detach
        self._len = max(len(self._label_sampler), len(self._unlabel_sampler))

    def __len__(self):
        return self._len

    def _crop_if_needed(self, df1, df2):
        assert len(df1) == len(df2)
        for i in range(len(df1)):
            if len(df1[i][self._data_key]) != len(df2[i][self._data_key]):
                min_len = min(len(df1[i][self._data_key]), len(df2[i][self._data_key]))
                df1[i][self._data_key] = df1[i][self._data_key][:min_len, :]
                df2[i][self._data_key] = df2[i][self._data_key][:min_len, :]
                df1[i][self._target_key] = df1[i][self._target_key][:min_len]
                df2[i][self._target_key] = df2[i][self._target_key][:min_len]
        return df1, df2

    def sharpen(self, x, T=0.5):
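        # MixMatch temperature sharpening: p_i ** (1 / T) / sum_j p_j ** (1 / T).
        # Lower T pushes the averaged class probabilities toward a one-hot vector.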
        assert len(x.shape) == 2

        _x = torch.pow(x, 1 / T)
        s = torch.sum(_x, dim=-1, keepdim=True)
        _x = _x / s
        return _x

    def _create_union_data(self, r1, r2):
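        # Concatenate each labeled batch with its pseudo-labeled (unlabeled)
        # counterpart into one pool; rows of this pool are later drawn at random
        # as mixup partners in sample().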
        assert len(r1) == len(r2)
        r = []

        for i in range(len(r1)):
            union_rows = dict()
            union_rows[self._data_key] = torch.cat(
                [r1[i][self._data_key], r2[i][self._data_key]], dim=0)
            union_rows["probs"] = torch.cat([r1[i]["probs"], r2[i]["probs"]],
                                            dim=0)
            union_rows['name'] = r1[i]['name']
            r.append(union_rows)
        return r

    def _mixup(self, x1, y1, x2, y2, alpha=0.75):
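        # mixup: convex combination of two samples with a Beta(alpha, alpha) weight;
        # taking max(l, 1 - l) biases the mix toward the first (x1, y1) pair.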
        l = np.random.beta(alpha, alpha)
        l = max(l, 1 - l)
        x = l * x1 + (1 - l) * x2
        y = l * y1 + (1 - l) * y2
        return x, y

    def sample(self, k=1):
        samples = []
        labeled_sampled_rows = self._label_sampler.sample(k)
        unlabeled_sampled_rows = self._unlabel_sampler.sample(k)

        labeled_sampled_rows, unlabeled_sampled_rows = self._crop_if_needed(
            labeled_sampled_rows, unlabeled_sampled_rows)

        for i in range(k):
            # Unlabeled data
            unlabeled_sampled_rows[i][
                self._data_key] = unlabeled_sampled_rows[i][self._data_key].to(
                    next(self._model.parameters()).device)

            u_imgs = unlabeled_sampled_rows[i][self._data_key]

            list_imgs = []
            for b in range(u_imgs.shape[0]):
                for j in range(self._n_augmentations):
                    img = u_imgs[b, :, :, :]
                    if img.shape[0] == 1:
                        img = img[0, :, :]
                    else:
                        img = img.permute(1, 2, 0)

                    img_cpu = to_cpu(img)
                    aug_img = self._augmentation(img_cpu)
                    list_imgs.append(aug_img)

            batch_imgs = torch.cat(list_imgs, dim=0)
            batch_imgs = batch_imgs.to(next(self._model.parameters()).device)
            if self._output_type == 'logits':
                out = self._model(batch_imgs)
            elif self._output_type == 'features':
                out = self._model.get_features(batch_imgs)

            preds = F.softmax(out, dim=1)
            preds = preds.view(u_imgs.shape[0], -1, preds.shape[-1])

            mean_preds = torch.mean(preds, dim=1)
            guessing_labels = self.sharpen(mean_preds).detach()

            unlabeled_sampled_rows[i]["probs"] = guessing_labels

            # Labeled data
            labeled_sampled_rows[i][self._data_key] = labeled_sampled_rows[i][
                self._data_key].to(next(self._model.parameters()).device)
            target_l = labeled_sampled_rows[i][self._target_key]
            onehot_l = torch.zeros(guessing_labels.shape)
            onehot_l.scatter_(1, target_l.type(torch.int64).unsqueeze(-1), 1.0)
            labeled_sampled_rows[i]["probs"] = onehot_l.to(
                next(self._model.parameters()).device)

        union_rows = self._create_union_data(labeled_sampled_rows,
                                             unlabeled_sampled_rows)

        for i in range(k):
            ridx = np.random.permutation(
                union_rows[i][self._data_key].shape[0])
            u = unlabeled_sampled_rows[i]
            x = labeled_sampled_rows[i]

            x_mix, target_mix = self._mixup(
                x[self._data_key], x["probs"],
                union_rows[i][self._data_key][ridx[i]],
                union_rows[i]["probs"][ridx[i]])
            u_mix, pred_mix = self._mixup(
                u[self._data_key], u["probs"],
                union_rows[i][self._data_key][ridx[k + i]],
                union_rows[i]["probs"][ridx[k + i]])

            samples.append({
                'name': self._name,
                'x_mix': x_mix,
                'target_mix_x': target_mix,
                'u_mix': u_mix,
                'target_mix_u': pred_mix,
                'target_x': x[self._target_key]
            })
        return samples
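
A hedged usage sketch (networks, dataframes and transforms are placeholders defined elsewhere; the returned dictionaries expose the keys produced at the end of sample() above):

sampler = MixMatchEMASampler(st_model=student_net, te_model=teacher_net,
                             name='mixmatch', augmentation=aug_fn,
                             labeled_meta_data=df_labeled, unlabeled_meta_data=df_unlabeled,
                             parse_item_cb=parse_item, transform=eval_tf,
                             batch_size=32, shuffle=True)
for batch in sampler.sample(1):
    x_mix, y_mix = batch['x_mix'], batch['target_mix_x']
    u_mix, q_mix = batch['u_mix'], batch['target_mix_u']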
Example #18
        sampling_config = yaml.load(f, Loader=yaml.FullLoader)  # PyYAML >= 5.1 needs an explicit Loader

    for fold_id, df in enumerate(splitter):
        df_train = df[0]
        df_val = df[1]

        print("Fold {} on {} labeled samples...".format(
            fold_id, len(df_train.index)))
        item_loaders = dict()

        # Data provider
        for stage, df in zip(['train', 'eval'], [df_train, df_val]):
            item_loaders[f'data_{stage}'] = ItemLoader(
                root=args.root,
                meta_data=df,
                transform=init_transforms()[stage],
                parse_item_cb=parse_item,
                batch_size=args.bs,
                num_workers=args.num_threads,
                shuffle=True if stage == "train" else False)
        data_provider = DataProvider(item_loaders)

        # Visualizers
        summary_writer = SummaryWriter(logdir=logdir,
                                       comment=comment + "_fold" +
                                       str(fold_id + 1))
        model_dir = os.path.join(summary_writer.logdir, args.model_dir)
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)

        # Model
        model = make_model(model_name=args.model_name,
Example #19
def worker_process(gpu, ngpus, sampling_config, strategy_config, args):
    args.gpu = gpu  # this line of code is not redundant
    if args.distributed:
        lr_m = float(args.batch_size * args.world_size) / 256.
    else:
        lr_m = 1.0
    criterion = torch.nn.CrossEntropyLoss().to(gpu)
    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    test_ds, _ = get_mnist(data_folder=args.save_data, train=False)
    model = SimpleConvNet(bw=args.bw,
                          drop=args.dropout,
                          n_cls=len(classes),
                          n_channels=args.n_channels).to(gpu)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=args.lr * lr_m,
                                 weight_decay=args.wd)

    args, model, optimizer = convert_according_to_args(args=args,
                                                       gpu=gpu,
                                                       ngpus=ngpus,
                                                       network=model,
                                                       optim=optimizer)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, test_ds]):
        if args.distributed:
            item_loaders[f'mnist_{stage}'] = DistributedItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                args=args)
        else:
            item_loaders[f'mnist_{stage}'] = ItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                batch_size=args.batch_size,
                num_workers=args.workers,
                shuffle=True if stage == "train" else False)
    data_provider = DataProvider(item_loaders)
    if args.gpu == 0:
        log_dir = args.log_dir
        comment = args.comment
        summary_writer = SummaryWriter(log_dir=log_dir,
                                       comment='_' + comment + 'gpu_' +
                                       str(args.gpu))
        train_cbs = (RunningAverageMeter(prefix="train", name="loss"),
                     AccuracyMeter(prefix="train", name="acc"))

        val_cbs = (RunningAverageMeter(prefix="eval", name="loss"),
                   AccuracyMeter(prefix="eval", name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='eval/loss',
                              save_dir=args.snapshots,
                              conditions='min',
                              model=model))
    else:
        train_cbs = ()
        val_cbs = ()

    strategy = Strategy(data_provider=data_provider,
                        train_loader_names=tuple(
                            sampling_config['train']['data_provider'].keys()),
                        val_loader_names=tuple(
                            sampling_config['eval']['data_provider'].keys()),
                        data_sampling_config=sampling_config,
                        loss=criterion,
                        model=model,
                        n_epochs=args.n_epochs,
                        optimizer=optimizer,
                        train_callbacks=train_cbs,
                        val_callbacks=val_cbs,
                        device=torch.device('cuda:{}'.format(args.gpu)),
                        distributed=args.distributed,
                        use_apex=args.use_apex)

    strategy.run()