def __init__(self, model: nn.Module, name: str, augmentation,
             labeled_meta_data: pd.DataFrame, unlabeled_meta_data: pd.DataFrame,
             n_augmentations=1, output_type='logits',
             data_key: str = "data", target_key: str = 'target',
             parse_item_cb: callable or None = None, root: str or None = None,
             batch_size: int = 1, num_workers: int = 0, shuffle: bool = False,
             pin_memory: bool = False, collate_fn: callable = default_collate,
             transform: callable or None = None,
             sampler: torch.utils.data.sampler.Sampler or None = None,
             batch_sampler=None, drop_last: bool = False, timeout: int = 0,
             detach: bool = False):
    self._label_sampler = ItemLoader(meta_data=labeled_meta_data, parse_item_cb=parse_item_cb,
                                     root=root, batch_size=batch_size, num_workers=num_workers,
                                     shuffle=shuffle, pin_memory=pin_memory, collate_fn=collate_fn,
                                     transform=transform, sampler=sampler, batch_sampler=batch_sampler,
                                     drop_last=drop_last, timeout=timeout)
    self._unlabel_sampler = ItemLoader(meta_data=unlabeled_meta_data, parse_item_cb=parse_item_cb,
                                       root=root, batch_size=batch_size, num_workers=num_workers,
                                       shuffle=shuffle, pin_memory=pin_memory, collate_fn=collate_fn,
                                       transform=transform, sampler=sampler, batch_sampler=batch_sampler,
                                       drop_last=drop_last, timeout=timeout)
    self._name = name
    self._model: nn.Module = model
    self._n_augmentations = n_augmentations
    self._augmentation = augmentation
    self._data_key = data_key
    self._target_key = target_key
    self._output_type = output_type
    self._detach = detach
    self._len = max(len(self._label_sampler), len(self._unlabel_sampler))
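
# Why _len is the max of the two loaders: one "epoch" of this sampler covers
# the longer of the labeled/unlabeled streams once while the shorter one gets
# re-drawn. A standalone illustration of that arithmetic (made-up batch counts):
n_labeled_batches, n_unlabeled_batches = 12, 50
steps_per_epoch = max(n_labeled_batches, n_unlabeled_batches)
assert steps_per_epoch == 50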
def estimate_mean_std(config, metadata, parse_item_cb, num_threads=8, bs=16):
    mean_std_loader = ItemLoader(meta_data=metadata,
                                 transform=train_test_transforms(config)['train'],
                                 parse_item_cb=parse_item_cb,
                                 batch_size=bs, num_workers=num_threads,
                                 shuffle=False)

    mean = None
    std = None
    for i in tqdm(range(len(mean_std_loader)), desc='Calculating mean and standard deviation'):
        for batch in mean_std_loader.sample():
            if mean is None:
                mean = torch.zeros(batch['data'].size(1))
                std = torch.zeros(batch['data'].size(1))
            # Global batch statistics broadcast to every channel slot; the
            # per-channel variant is kept below for reference.
            # for channel in range(batch['data'].size(1)):
            #     mean[channel] += batch['data'][:, channel, :, :].mean().item()
            #     std[channel] += batch['data'][:, channel, :, :].std().item()
            mean += batch['data'].mean().item()
            std += batch['data'].std().item()

    mean /= len(mean_std_loader)
    std /= len(mean_std_loader)

    return mean, std
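
# Standalone sketch of the accumulation logic above on synthetic data (pure
# PyTorch, no ItemLoader). Note the global batch mean/std matches the
# commented-out per-channel variant only for single-channel inputs.
import torch

batches = [torch.rand(16, 1, 28, 28) for _ in range(10)]  # stand-in for loader output
mean = torch.zeros(batches[0].size(1))
std = torch.zeros(batches[0].size(1))
for batch in batches:
    mean += batch.mean().item()
    std += batch.std().item()
mean /= len(batches)
std /= len(batches)
print(mean, std)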
def semixup_data_provider(model, alpha, n_classes, train_labeled_data, train_unlabeled_data, val_labeled_data,
                          transforms, parse_item, bs, num_workers, item_loaders=None, root="",
                          augmentation=None, data_rearrange=None):
    """
    Default setting of data provider for Semixup
    """
    # Avoid a shared mutable default: the dict is mutated below.
    if item_loaders is None:
        item_loaders = dict()

    # item_loaders["labeled_train"] = MixUpSampler(meta_data=train_labeled_data, name='l_mixup', alpha=alpha,
    #                                              model=model, transform=transforms['train'],
    #                                              parse_item_cb=parse_item, batch_size=bs,
    #                                              data_rearrange=data_rearrange,
    #                                              num_workers=num_workers, root=root, shuffle=True)

    item_loaders["labeled_train"] = ItemLoader(meta_data=train_labeled_data, name='l_norm',
                                               transform=transforms['train'], parse_item_cb=parse_item,
                                               batch_size=bs, num_workers=num_workers, root=root, shuffle=True)

    item_loaders["unlabeled_train"] = SemixupSampler(meta_data=train_unlabeled_data, name='u_mixup', alpha=alpha,
                                                     model=model, min_lambda=0.55, transform=transforms['train'],
                                                     parse_item_cb=parse_item, batch_size=bs,
                                                     data_rearrange=data_rearrange, num_workers=num_workers,
                                                     augmentation=augmentation, root=root, shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(meta_data=val_labeled_data, name='l_norm',
                                              transform=transforms['eval'], parse_item_cb=parse_item,
                                              batch_size=bs, num_workers=num_workers, root=root, shuffle=False)

    return DataProvider(item_loaders)
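
# Hypothetical call site (all names below are placeholders, not defined in this
# file). Note that this provider indexes transforms as a dict with
# 'train'/'eval' keys, unlike the tuple convention of the ICT/pi-model
# providers below.
# provider = semixup_data_provider(model=net, alpha=0.75, n_classes=5,
#                                  train_labeled_data=df_l, train_unlabeled_data=df_u,
#                                  val_labeled_data=df_val,
#                                  transforms={'train': train_tf, 'eval': eval_tf},
#                                  parse_item=parse_item, bs=32, num_workers=4,
#                                  augmentation=aug_fn)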
def ict_data_provider(model, alpha, n_classes, train_labeled_data, train_unlabeled_data, val_labeled_data,
                      val_unlabeled_data, transforms, parse_item, bs, num_threads, item_loaders=None, root=""):
    """
    Default setting of data provider for ICT
    """
    # Avoid a shared mutable default: the dict is mutated below.
    if item_loaders is None:
        item_loaders = dict()

    item_loaders["labeled_train"] = ItemLoader(meta_data=train_labeled_data, name='l_norm',
                                               transform=transforms[1], parse_item_cb=parse_item,
                                               batch_size=bs, num_workers=num_threads, root=root, shuffle=True)

    item_loaders["unlabeled_train"] = MixUpSampler(meta_data=train_unlabeled_data, name='u_mixup', alpha=alpha,
                                                   model=model, transform=transforms[0], parse_item_cb=parse_item,
                                                   batch_size=bs, num_workers=num_threads, root=root, shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(meta_data=val_labeled_data, name='l_norm',
                                              transform=transforms[1], parse_item_cb=parse_item,
                                              batch_size=bs, num_workers=num_threads, root=root, shuffle=False)

    item_loaders["unlabeled_eval"] = MixUpSampler(meta_data=val_unlabeled_data, name='u_mixup', alpha=alpha,
                                                  model=model, transform=transforms[1], parse_item_cb=parse_item,
                                                  batch_size=bs, num_workers=num_threads, root=root, shuffle=False)

    return DataProvider(item_loaders)
def test_loader_samples_batches(batch_size, n_samples, metadata_fname_target_5_classes,
                                ones_image_parser, img_target_transformer):
    item_loader = ItemLoader(meta_data=metadata_fname_target_5_classes, root='/tmp/',
                             batch_size=batch_size, parse_item_cb=ones_image_parser,
                             transform=img_target_transformer, shuffle=True)
    samples = item_loader.sample(n_samples)

    assert len(samples) == n_samples
    assert samples[0]['img'].size(0) == batch_size
    assert samples[0]['target'].size(0) == batch_size
def create_data_provider(args, config, parser, metadata, mean, std):
    """
    Sets up the dataloaders and augmentations

    :param args: General arguments
    :param config: Experiment parameters
    :param parser: Function for loading images
    :param metadata: Image paths and subject IDs
    :param mean: Dataset mean
    :param std: Dataset std
    :return: The compiled dataloader
    """
    # Compile ItemLoaders
    item_loaders = dict()
    for stage in ['train', 'val']:
        item_loaders[f'bfpn_{stage}'] = ItemLoader(meta_data=metadata[stage],
                                                   transform=train_test_transforms(
                                                       config, mean, std,
                                                       crop_size=tuple(config['training']['crop_size']))[stage],
                                                   parse_item_cb=parser,
                                                   batch_size=config['training']['bs'],
                                                   num_workers=args.num_threads,
                                                   shuffle=True if stage == "train" else False)

    return DataProvider(item_loaders)
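
# Hypothetical call site (placeholder names for objects the surrounding
# project defines):
# provider = create_data_provider(args, config, parser=parse_item,
#                                 metadata={'train': df_train, 'val': df_val},
#                                 mean=mean, std=std)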
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)
    train_ds, classes = get_cifar10(data_folder=data_dir, train=True)
    eval_ds, _ = get_cifar10(data_folder=data_dir, train=False)

    n_channels = 3
    criterion = torch.nn.CrossEntropyLoss()
    model = ResNet(in_channels=n_channels, n_features=64, drop_rate=0.3).to(device).half()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=cfg.lr, momentum=cfg.momentum,
                                weight_decay=cfg.wd, nesterov=True)

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, eval_ds]):
        item_loaders[f'loader_{stage}'] = ItemLoader(meta_data=df, transform=my_transforms()[stage],
                                                     parse_item_cb=parse_item, batch_size=cfg.bs,
                                                     num_workers=cfg.num_workers,
                                                     shuffle=True if stage == "train" else False)

    data_provider = DataProvider(item_loaders)

    train_cbs = (CosineAnnealingWarmRestartsWithWarmup(optimizer=optimizer,
                                                       warmup_epochs=(0, 10, 20),
                                                       warmup_lrs=(0, 0.1, 0.01),
                                                       T_0=5, T_mult=2, eta_min=0),
                 RunningAverageMeter(name="loss"),
                 AccuracyMeter(name="acc"))

    val_cbs = (RunningAverageMeter(name="loss"),
               AccuracyMeter(name="acc"),
               ScalarMeterLogger(writer=summary_writer),
               ModelSaver(metric_names='loss', save_dir=cfg.snapshots, conditions='min', model=model),
               ModelSaver(metric_names='acc', save_dir=cfg.snapshots, conditions='max', model=model))

    session = dict()
    session['mymodel'] = Session(data_provider=data_provider,
                                 train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
                                 val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
                                 module=model, loss=criterion, optimizer=optimizer,
                                 train_callbacks=train_cbs,
                                 val_callbacks=val_cbs)

    strategy = Strategy(data_provider=data_provider,
                        data_sampling_config=cfg.sampling,
                        strategy_config=cfg.strategy,
                        sessions=session,
                        n_epochs=cfg.n_epochs,
                        device=device)

    strategy.run()
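
# The warmup callback above wraps a cosine warm-restart schedule; a
# self-contained sketch of the underlying PyTorch scheduler (without the
# warmup part, whose exact signature belongs to the project's custom class):
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, T_mult=2, eta_min=0)
for epoch in range(15):
    opt.step()  # one (dummy) epoch of training
    sched.step()
    print(epoch, round(opt.param_groups[0]['lr'], 4))  # restarts at epochs 5 and 15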
def mixmatch_ema_data_provider(model, augmentation, labeled_meta_data, unlabeled_meta_data, val_labeled_data,
                               n_augmentations, parse_item, bs, transforms, root="", num_threads=4):
    itemloader_dict = {}
    itemloader_dict['all_train'] = MixMatchSampler(model=model, name="train_mixmatch",
                                                   augmentation=augmentation,
                                                   labeled_meta_data=labeled_meta_data,
                                                   unlabeled_meta_data=unlabeled_meta_data,
                                                   n_augmentations=n_augmentations,
                                                   data_key='data', target_key='target',
                                                   parse_item_cb=parse_item, batch_size=bs,
                                                   transform=transforms[0],
                                                   num_workers=num_threads, shuffle=True)
    itemloader_dict['labeled_eval_st'] = ItemLoader(root=root, meta_data=val_labeled_data, name='l_eval',
                                                    transform=transforms[1], parse_item_cb=parse_item,
                                                    batch_size=bs, num_workers=num_threads, shuffle=False)
    itemloader_dict['labeled_eval_te'] = ItemLoader(root=root, meta_data=val_labeled_data, name='l_eval',
                                                    transform=transforms[1], parse_item_cb=parse_item,
                                                    batch_size=bs, num_workers=num_threads, shuffle=False)

    return DataProvider(itemloader_dict)
def create_data_provider(args, config, parser, metadata, mean, std):
    # Compile ItemLoaders
    item_loaders = dict()
    for stage in ['train', 'val']:
        item_loaders[f'bfpn_{stage}'] = ItemLoader(meta_data=metadata[stage],
                                                   transform=train_test_transforms(config, mean, std)[stage],
                                                   parse_item_cb=parser, batch_size=args.bs,
                                                   num_workers=args.num_threads,
                                                   shuffle=True if stage == "train" else False)

    return DataProvider(item_loaders)
def init_data_provider(args, df_train, df_val, item_loaders, test_ds):
    """
    Initializes the data provider for the autoencoder

    Parameters
    ----------
    args: Namespace
        Arguments for the whole network parsed using argparse
    df_train: DataFrame
        Training data as a pandas DataFrame
    df_val: DataFrame
        Validation data as a pandas DataFrame
    item_loaders: dict
        Empty dictionary to be populated by data samplers
    test_ds: DataFrame
        Test data for visualization as a pandas DataFrame

    Returns
    -------
    DataProvider
        DataProvider object constructed from all the data samplers
    """
    for stage, df in zip(['train', 'eval'], [df_train, df_val]):
        item_loaders[f'mnist_{stage}'] = ItemLoader(meta_data=df, transform=init_mnist_transforms()[0],
                                                    parse_item_cb=parse_item_ae, batch_size=args.bs,
                                                    num_workers=args.num_threads,
                                                    shuffle=True if stage == 'train' else False)

    # The fixed visualization set does not need shuffling
    item_loaders['mnist_viz'] = ItemLoader(meta_data=test_ds, transform=init_mnist_transforms()[0],
                                           parse_item_cb=parse_item_ae, batch_size=args.bs,
                                           num_workers=args.num_threads, shuffle=False)

    return DataProvider(item_loaders)
def test_loader_drop_last(batch_size, n_samples, metadata_fname_target_5_classes,
                          ones_image_parser, img_target_transformer, drop_last):
    item_loader = ItemLoader(meta_data=metadata_fname_target_5_classes, root='/tmp/',
                             batch_size=batch_size, parse_item_cb=ones_image_parser,
                             transform=img_target_transformer, shuffle=True, drop_last=drop_last)

    n_rows = metadata_fname_target_5_classes.shape[0]
    if drop_last:
        assert len(item_loader) == n_rows // batch_size
    elif n_rows % batch_size != 0:
        assert len(item_loader) == n_rows // batch_size + 1
    else:
        assert len(item_loader) == n_rows // batch_size
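
# The length logic tested above is floor vs. ceiling division; a standalone
# illustration with made-up numbers:
import math

n_rows, bs = 103, 16
assert n_rows // bs == 6                             # drop_last=True: partial batch dropped
assert math.ceil(n_rows / bs) == 7                   # drop_last=False: partial batch kept
assert n_rows // bs + (1 if n_rows % bs else 0) == math.ceil(n_rows / bs)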
g_network = Generator(nc=1, nz=args.latent_size, ngf=args.g_net_features).to(device)
g_optim = optim.Adam(g_network.parameters(), lr=args.g_lr, weight_decay=args.g_wd, betas=(args.beta1, 0.999))
g_crit = SSGeneratorLoss(d_network=d_network, d_loss=BCELoss()).to(device)

item_loaders = dict()

train_labeled_data, val_labeled_data, train_unlabeled_data, val_unlabeled_data = next(splitter)

item_loaders["real_labeled_train"] = ItemLoader(meta_data=train_labeled_data,
                                                transform=init_mnist_transforms()[1],
                                                parse_item_cb=parse_item_mnist_ssgan,
                                                batch_size=args.bs, num_workers=args.num_threads,
                                                shuffle=True)

item_loaders["real_unlabeled_train"] = ItemLoader(meta_data=train_unlabeled_data,
                                                  transform=init_mnist_transforms()[1],
                                                  parse_item_cb=parse_item_mnist_ssgan,
                                                  batch_size=args.bs, num_workers=args.num_threads,
                                                  shuffle=True)

item_loaders["real_labeled_val"] = ItemLoader(meta_data=val_labeled_data,
                                              transform=init_mnist_transforms()[1],
                                              parse_item_cb=parse_item_mnist_ssgan,
                                              batch_size=args.bs, num_workers=args.num_threads,
                                              shuffle=False)
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)
    train_ds, classes = get_mnist(data_folder=data_dir, train=True)
    n_classes = len(classes)
    n_channels = 1
    criterion = torch.nn.CrossEntropyLoss()

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    splitter = FoldSplit(train_ds, n_folds=5, target_col="target")

    for fold_id, (df_train, df_val) in enumerate(splitter):
        item_loaders = dict()
        for stage, df in zip(['train', 'eval'], [df_train, df_val]):
            item_loaders[f'loader_{stage}'] = ItemLoader(meta_data=df, transform=my_transforms()[stage],
                                                         parse_item_cb=parse_item, batch_size=cfg.bs,
                                                         num_workers=cfg.num_threads,
                                                         shuffle=True if stage == "train" else False)

        model = SimpleConvNet(bw=cfg.bw, drop_rate=cfg.dropout, n_classes=n_classes).to(device)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
        data_provider = DataProvider(item_loaders)

        train_cbs = (RunningAverageMeter(name="loss"), AccuracyMeter(name="acc"))

        val_cbs = (RunningAverageMeter(name="loss"), AccuracyMeter(name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='loss', save_dir=cfg.snapshots, conditions='min', model=model),
                   ModelSaver(metric_names='acc', save_dir=cfg.snapshots, conditions='max', model=model))

        session = dict()
        session['mymodel'] = Session(data_provider=data_provider,
                                     train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
                                     val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
                                     module=model, loss=criterion, optimizer=optimizer,
                                     train_callbacks=train_cbs,
                                     val_callbacks=val_cbs)

        strategy = Strategy(data_provider=data_provider,
                            data_sampling_config=cfg.sampling,
                            strategy_config=cfg.strategy,
                            sessions=session,
                            n_epochs=cfg.n_epochs,
                            device=device)

        strategy.run()
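
# FoldSplit above is the project's own splitter, iterated as
# (df_train, df_val) pairs; a self-contained analogue of that iteration
# pattern using scikit-learn's stratified K-fold on a toy DataFrame:
import pandas as pd
from sklearn.model_selection import StratifiedKFold

toy_df = pd.DataFrame({'data': range(20), 'target': [i % 4 for i in range(20)]})
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for fold_id, (tr_idx, va_idx) in enumerate(skf.split(toy_df, toy_df['target'])):
    df_train, df_val = toy_df.iloc[tr_idx], toy_df.iloc[va_idx]
    print(fold_id, len(df_train), len(df_val))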
else:
    pas_str = list_pas

print("Processing {} with {} by best {}...".format(pas_str, comment, md))

cm_norm_viz = ConfusionMatrixVisualizer(writer=writer, tag="CM_" + md + "_" + pas_str, normalize=True,
                                        labels=["KL" + str(i) for i in range(5)], parse_class=parse_class)
cm_viz = ConfusionMatrixVisualizer(writer=writer, tag="CM_raw_" + md + "_" + pas_str, normalize=False,
                                   labels=["KL" + str(i) for i in range(5)], parse_class=parse_class)

ds_most_filtered = filter_most_by_pa(ds_most, df_most_ex, list_pas)
loader = ItemLoader(root=args.root, meta_data=ds_most_filtered, transform=init_transform_wo_aug(),
                    parse_item_cb=parse_item, batch_size=args.bs, num_workers=args.num_threads,
                    shuffle=False, drop_last=False)

kappa_meter.on_epoch_begin(0)
acc_meter.on_epoch_begin(0)
mse_meter.on_epoch_begin(0)
cm_viz.on_epoch_begin(0)
cm_norm_viz.on_epoch_begin(0)

progress_bar = tqdm(range(len(loader)), total=len(loader), desc="Eval::")

if save_detail_preds:
    bi_preds_probs_all = []
    bi_targets_all = []
def mt_data_provider(st_model, te_model, train_labeled_data, train_unlabeled_data, val_labeled_data,
                     val_unlabeled_data, transforms, parse_item, bs, num_threads, item_loaders=None,
                     n_augmentations=1, output_type='logits', root=""):
    """
    Default setting of data provider for Mean-Teacher
    """
    # Avoid a shared mutable default: the dict is mutated below.
    if item_loaders is None:
        item_loaders = dict()

    # Train
    item_loaders["labeled_train_st"] = AugmentedGroupSampler(root=root, name='l_st',
                                                             meta_data=train_labeled_data, model=st_model,
                                                             n_augmentations=n_augmentations,
                                                             augmentation=transforms[2], transform=transforms[1],
                                                             parse_item_cb=parse_item,
                                                             batch_size=bs, num_workers=num_threads,
                                                             shuffle=True)
    item_loaders["unlabeled_train_st"] = AugmentedGroupSampler(root=root, name='u_st', model=st_model,
                                                               meta_data=train_unlabeled_data,
                                                               n_augmentations=n_augmentations,
                                                               augmentation=transforms[2], transform=transforms[1],
                                                               parse_item_cb=parse_item,
                                                               batch_size=bs, num_workers=num_threads,
                                                               shuffle=True)
    item_loaders["labeled_train_te"] = AugmentedGroupSampler(root=root, name='l_te',
                                                             meta_data=train_labeled_data, model=te_model,
                                                             n_augmentations=n_augmentations,
                                                             augmentation=transforms[2], transform=transforms[1],
                                                             parse_item_cb=parse_item,
                                                             batch_size=bs, num_workers=num_threads,
                                                             detach=True, shuffle=True)
    item_loaders["unlabeled_train_te"] = AugmentedGroupSampler(root=root, name='u_te', model=te_model,
                                                               meta_data=train_unlabeled_data,
                                                               n_augmentations=n_augmentations,
                                                               augmentation=transforms[2], transform=transforms[1],
                                                               parse_item_cb=parse_item,
                                                               batch_size=bs, num_workers=num_threads,
                                                               detach=True, shuffle=True)

    # Eval
    item_loaders["labeled_eval_st"] = AugmentedGroupSampler(root=root, name='l_st',
                                                            meta_data=val_labeled_data, model=st_model,
                                                            n_augmentations=n_augmentations,
                                                            augmentation=transforms[2], transform=transforms[1],
                                                            parse_item_cb=parse_item,
                                                            batch_size=bs, num_workers=num_threads,
                                                            shuffle=False)
    item_loaders["unlabeled_eval_st"] = AugmentedGroupSampler(root=root, name='u_st', model=st_model,
                                                              meta_data=val_unlabeled_data,
                                                              n_augmentations=n_augmentations,
                                                              augmentation=transforms[2], transform=transforms[1],
                                                              parse_item_cb=parse_item,
                                                              batch_size=bs, num_workers=num_threads,
                                                              shuffle=False)
    item_loaders["labeled_eval_te"] = ItemLoader(root=root, meta_data=val_labeled_data, name='l_te_eval',
                                                 transform=transforms[1], parse_item_cb=parse_item,
                                                 batch_size=bs, num_workers=num_threads, shuffle=False)

    return DataProvider(item_loaders)
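
# detach=True on the teacher loaders above is meant to stop gradients from
# flowing into the EMA teacher; the underlying torch mechanics, standalone:
import torch

t = torch.randn(3, requires_grad=True)
out = (t * 2).detach()
assert not out.requires_grad  # no gradient path back to t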
def pimodel_data_provider(model, train_labeled_data, train_unlabeled_data, val_labeled_data, val_unlabeled_data,
                          transforms, parse_item, bs, num_threads, item_loaders=None, root="",
                          n_augmentations=1, output_type='logits'):
    """
    Default setting of data provider for Pi-Model
    """
    # Avoid a shared mutable default: the dict is mutated below.
    if item_loaders is None:
        item_loaders = dict()

    item_loaders["labeled_train"] = ItemLoader(root=root, meta_data=train_labeled_data, name='l',
                                               transform=transforms[0], parse_item_cb=parse_item,
                                               batch_size=bs, num_workers=num_threads, shuffle=True)
    # item_loaders["labeled_train"] = AugmentedGroupSampler(root=root, model=model, name='l', output_type=output_type,
    #                                                       meta_data=train_labeled_data,
    #                                                       n_augmentations=n_augmentations,
    #                                                       augmentation=transforms[2],
    #                                                       transform=transforms[1],
    #                                                       parse_item_cb=parse_item,
    #                                                       batch_size=bs, num_workers=num_threads,
    #                                                       shuffle=True)

    item_loaders["unlabeled_train"] = AugmentedGroupSampler(root=root, model=model, name='u',
                                                            output_type=output_type,
                                                            meta_data=train_unlabeled_data,
                                                            n_augmentations=n_augmentations,
                                                            augmentation=transforms[2], transform=transforms[0],
                                                            parse_item_cb=parse_item,
                                                            batch_size=bs, num_workers=num_threads,
                                                            shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(root=root, meta_data=val_labeled_data, name='l',
                                              transform=transforms[1], parse_item_cb=parse_item,
                                              batch_size=bs, num_workers=num_threads, shuffle=False)
    # item_loaders["labeled_eval"] = AugmentedGroupSampler(root=root, model=model, name='l', output_type=output_type,
    #                                                      meta_data=val_labeled_data,
    #                                                      n_augmentations=n_augmentations,
    #                                                      augmentation=transforms[2],
    #                                                      transform=transforms[1],
    #                                                      parse_item_cb=parse_item,
    #                                                      batch_size=bs, num_workers=num_threads,
    #                                                      shuffle=False)
    # item_loaders["unlabeled_eval"] = AugmentedGroupSampler(root=root, model=model, name='u', output_type=output_type,
    #                                                        meta_data=val_unlabeled_data,
    #                                                        n_augmentations=n_augmentations,
    #                                                        augmentation=transforms[2],
    #                                                        transform=transforms[1],
    #                                                        parse_item_cb=parse_item,
    #                                                        batch_size=bs, num_workers=num_threads,
    #                                                        shuffle=False)

    return DataProvider(item_loaders)
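
# Hypothetical wiring (placeholder names); from the indexing above, transforms
# is expected to be a triple: (train_transform, eval_transform, augmentation).
# provider = pimodel_data_provider(model=net,
#                                  train_labeled_data=df_l, train_unlabeled_data=df_u,
#                                  val_labeled_data=df_val_l, val_unlabeled_data=df_val_u,
#                                  transforms=(train_tf, eval_tf, aug_tf),
#                                  parse_item=parse_item, bs=32, num_threads=4)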
class MixMatchEMASampler(object):
    def __init__(self, st_model: nn.Module, te_model: nn.Module, name: str, augmentation,
                 labeled_meta_data: pd.DataFrame, unlabeled_meta_data: pd.DataFrame,
                 n_augmentations=1, output_type='logits',
                 data_key: str = "data", target_key: str = 'target',
                 parse_item_cb: callable or None = None, root: str or None = None,
                 batch_size: int = 1, num_workers: int = 0, shuffle: bool = False,
                 pin_memory: bool = False, collate_fn: callable = default_collate,
                 transform: callable or None = None,
                 sampler: torch.utils.data.sampler.Sampler or None = None,
                 batch_sampler=None, drop_last: bool = False, timeout: int = 0,
                 detach: bool = False):
        self._label_sampler = ItemLoader(meta_data=labeled_meta_data, parse_item_cb=parse_item_cb,
                                         root=root, batch_size=batch_size, num_workers=num_workers,
                                         shuffle=shuffle, pin_memory=pin_memory, collate_fn=collate_fn,
                                         transform=transform, sampler=sampler, batch_sampler=batch_sampler,
                                         drop_last=drop_last, timeout=timeout)
        self._unlabel_sampler = ItemLoader(meta_data=unlabeled_meta_data, parse_item_cb=parse_item_cb,
                                           root=root, batch_size=batch_size, num_workers=num_workers,
                                           shuffle=shuffle, pin_memory=pin_memory, collate_fn=collate_fn,
                                           transform=transform, sampler=sampler, batch_sampler=batch_sampler,
                                           drop_last=drop_last, timeout=timeout)
        self._name = name
        self._st_model: nn.Module = st_model
        self._te_model: nn.Module = te_model
        self._n_augmentations = n_augmentations
        self._augmentation = augmentation
        self._data_key = data_key
        self._target_key = target_key
        self._output_type = output_type
        self._detach = detach
        self._len = max(len(self._label_sampler), len(self._unlabel_sampler))

    def __len__(self):
        return self._len

    def _crop_if_needed(self, df1, df2):
        # Truncate paired batches to a common length so labeled and unlabeled
        # rows can be combined elementwise
        assert len(df1) == len(df2)
        for i in range(len(df1)):
            if len(df1[i][self._data_key]) != len(df2[i][self._data_key]):
                min_len = min(len(df1[i][self._data_key]), len(df2[i][self._data_key]))
                df1[i][self._data_key] = df1[i][self._data_key][:min_len, :]
                df2[i][self._data_key] = df2[i][self._data_key][:min_len, :]
                df1[i][self._target_key] = df1[i][self._target_key][:min_len]
                df2[i][self._target_key] = df2[i][self._target_key][:min_len]
        return df1, df2

    def sharpen(self, x, T=0.5):
        # Temperature sharpening of the guessed label distribution
        assert len(x.shape) == 2
        _x = torch.pow(x, 1 / T)
        s = torch.sum(_x, dim=-1, keepdim=True)
        _x = _x / s
        return _x

    def _create_union_data(self, r1, r2):
        assert len(r1) == len(r2)
        r = []
        for i in range(len(r1)):
            union_rows = dict()
            union_rows[self._data_key] = torch.cat([r1[i][self._data_key], r2[i][self._data_key]], dim=0)
            union_rows["probs"] = torch.cat([r1[i]["probs"], r2[i]["probs"]], dim=0)
            union_rows['name'] = r1[i]['name']
            r.append(union_rows)
        return r

    def _mixup(self, x1, y1, x2, y2, alpha=0.75):
        l = np.random.beta(alpha, alpha)
        l = max(l, 1 - l)  # keep the mix dominated by the first argument
        x = l * x1 + (1 - l) * x2
        y = l * y1 + (1 - l) * y2
        return x, y

    def sample(self, k=1):
        samples = []
        labeled_sampled_rows = self._label_sampler.sample(k)
        unlabeled_sampled_rows = self._unlabel_sampler.sample(k)
        labeled_sampled_rows, unlabeled_sampled_rows = self._crop_if_needed(labeled_sampled_rows,
                                                                            unlabeled_sampled_rows)
        device = next(self._te_model.parameters()).device
        for i in range(k):
            # Unlabeled data: guess labels with the EMA (teacher) model over
            # n augmented views of each image
            unlabeled_sampled_rows[i][self._data_key] = unlabeled_sampled_rows[i][self._data_key].to(device)
            u_imgs = unlabeled_sampled_rows[i][self._data_key]
            list_imgs = []
            for b in range(u_imgs.shape[0]):
                for j in range(self._n_augmentations):
                    img = u_imgs[b, :, :, :]
                    if img.shape[0] == 1:
                        img = img[0, :, :]
                    else:
                        img = img.permute(1, 2, 0)
                    img_cpu = to_cpu(img)
                    aug_img = self._augmentation(img_cpu)
                    list_imgs.append(aug_img)

            batch_imgs = torch.cat(list_imgs, dim=0)
            batch_imgs = batch_imgs.to(device)

            if self._output_type == 'logits':
                out = self._te_model(batch_imgs)
            elif self._output_type == 'features':
                out = self._te_model.get_features(batch_imgs)

            preds = F.softmax(out, dim=1)
            preds = preds.view(u_imgs.shape[0], -1, preds.shape[-1])
            mean_preds = torch.mean(preds, dim=1)
            guessing_labels = self.sharpen(mean_preds).detach()
            unlabeled_sampled_rows[i]["probs"] = guessing_labels

            # Labeled data: one-hot targets
            labeled_sampled_rows[i][self._data_key] = labeled_sampled_rows[i][self._data_key].to(device)
            target_l = labeled_sampled_rows[i][self._target_key]
            onehot_l = torch.zeros(guessing_labels.shape)
            onehot_l.scatter_(1, target_l.type(torch.int64).unsqueeze(-1), 1.0)
            labeled_sampled_rows[i]["probs"] = onehot_l.to(device)

        union_rows = self._create_union_data(labeled_sampled_rows, unlabeled_sampled_rows)

        for i in range(k):
            # MixUp each labeled/unlabeled batch with distinct shuffled
            # partners drawn from the union of both batches (canonical
            # MixMatch pairing)
            ridx = np.random.permutation(union_rows[i][self._data_key].shape[0])
            u = unlabeled_sampled_rows[i]
            x = labeled_sampled_rows[i]
            n_l = x[self._data_key].shape[0]
            x_mix, target_mix = self._mixup(x[self._data_key], x["probs"],
                                            union_rows[i][self._data_key][ridx[:n_l]],
                                            union_rows[i]["probs"][ridx[:n_l]])
            u_mix, pred_mix = self._mixup(u[self._data_key], u["probs"],
                                          union_rows[i][self._data_key][ridx[n_l:]],
                                          union_rows[i]["probs"][ridx[n_l:]])
            samples.append({'name': self._name, 'x_mix': x_mix, 'target_mix_x': target_mix,
                            'u_mix': u_mix, 'target_mix_u': pred_mix,
                            'target_x': x[self._target_key]})

        return samples
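
# Self-contained check of the sharpen() and _mixup() math used by the sampler
# above (pure PyTorch/NumPy, independent of the class):
import numpy as np
import torch
import torch.nn.functional as F

probs = F.softmax(torch.randn(4, 5), dim=1)
T = 0.5
sharpened = torch.pow(probs, 1 / T)
sharpened = sharpened / sharpened.sum(dim=-1, keepdim=True)  # peakier, still a distribution
assert torch.allclose(sharpened.sum(dim=-1), torch.ones(4))

lam = np.random.beta(0.75, 0.75)
lam = max(lam, 1 - lam)  # keep the mix dominated by the first argument
x1, x2 = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
x_mix = lam * x1 + (1 - lam) * x2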
sampling_config = yaml.load(f, Loader=yaml.FullLoader)

for fold_id, df in enumerate(splitter):
    df_train = df[0]
    df_val = df[1]

    print("Fold {} on {} labeled samples...".format(fold_id, len(df_train.index)))

    item_loaders = dict()

    # Data provider
    for stage, df in zip(['train', 'eval'], [df_train, df_val]):
        item_loaders[f'data_{stage}'] = ItemLoader(root=args.root, meta_data=df,
                                                   transform=init_transforms()[stage],
                                                   parse_item_cb=parse_item, batch_size=args.bs,
                                                   num_workers=args.num_threads,
                                                   shuffle=True if stage == "train" else False)

    data_provider = DataProvider(item_loaders)

    # Visualizers
    summary_writer = SummaryWriter(logdir=logdir, comment=comment + "_fold" + str(fold_id + 1))
    model_dir = os.path.join(summary_writer.logdir, args.model_dir)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    # Model
    model = make_model(model_name=args.model_name,
def worker_process(gpu, ngpus, sampling_config, strategy_config, args):
    args.gpu = gpu  # this line of code is not redundant

    if args.distributed:
        lr_m = float(args.batch_size * args.world_size) / 256.
    else:
        lr_m = 1.0

    criterion = torch.nn.CrossEntropyLoss().to(gpu)

    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    test_ds, _ = get_mnist(data_folder=args.save_data, train=False)

    model = SimpleConvNet(bw=args.bw, drop=args.dropout, n_cls=len(classes),
                          n_channels=args.n_channels).to(gpu)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr * lr_m, weight_decay=args.wd)

    args, model, optimizer = convert_according_to_args(args=args, gpu=gpu, ngpus=ngpus,
                                                       network=model, optim=optimizer)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, test_ds]):
        if args.distributed:
            item_loaders[f'mnist_{stage}'] = DistributedItemLoader(meta_data=df,
                                                                   transform=init_mnist_cifar_transforms(1, stage),
                                                                   parse_item_cb=parse_item_mnist,
                                                                   args=args)
        else:
            item_loaders[f'mnist_{stage}'] = ItemLoader(meta_data=df,
                                                        transform=init_mnist_cifar_transforms(1, stage),
                                                        parse_item_cb=parse_item_mnist,
                                                        batch_size=args.batch_size,
                                                        num_workers=args.workers,
                                                        shuffle=True if stage == "train" else False)

    data_provider = DataProvider(item_loaders)

    if args.gpu == 0:
        log_dir = args.log_dir
        comment = args.comment
        summary_writer = SummaryWriter(log_dir=log_dir, comment='_' + comment + 'gpu_' + str(args.gpu))

        train_cbs = (RunningAverageMeter(prefix="train", name="loss"),
                     AccuracyMeter(prefix="train", name="acc"))

        val_cbs = (RunningAverageMeter(prefix="eval", name="loss"),
                   AccuracyMeter(prefix="eval", name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='eval/loss', save_dir=args.snapshots,
                              conditions='min', model=model))
    else:
        train_cbs = ()
        val_cbs = ()

    strategy = Strategy(data_provider=data_provider,
                        train_loader_names=tuple(sampling_config['train']['data_provider'].keys()),
                        val_loader_names=tuple(sampling_config['eval']['data_provider'].keys()),
                        data_sampling_config=sampling_config,
                        loss=criterion,
                        model=model,
                        n_epochs=args.n_epochs,
                        optimizer=optimizer,
                        train_callbacks=train_cbs,
                        val_callbacks=val_cbs,
                        device=torch.device('cuda:{}'.format(args.gpu)),
                        distributed=args.distributed,
                        use_apex=args.use_apex)

    strategy.run()
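
# The lr_m factor above is the linear LR scaling rule for large-batch
# distributed training (base LR scaled by total batch size / 256); standalone
# arithmetic with made-up values:
base_lr, bs, world_size = 1e-3, 64, 8
scaled_lr = base_lr * (bs * world_size / 256.)
assert abs(scaled_lr - 2e-3) < 1e-12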