# Beispiel #1 (example 1 — scraped separator converted to a comment; the snippet below is a partial fragment)
                for i in to_cpu(decoded_target_cls, use_numpy=True).tolist()
            ]
            self._predicts += [
                self._labels[i]
                for i in to_cpu(decoded_pred_cls, use_numpy=True).tolist()
            ]


if __name__ == "__main__":
    # Script entry point: parse CLI args and auto-select CPU/GPU.
    args = init_args()
    device = auto_detect_device()
    log_dir = args.log_dir
    comment = "ssgan"  # tag used for the TensorBoard run name

    # Data provider
    # MNIST training split; `classes` holds the class labels.
    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    n_folds = 5
    # Semi-supervised fold split: 3 SS folds inside 5 CV folds, with
    # 100 labeled and 200 unlabeled samples per class. Only the labeled
    # subset is class-balanced (equal_target=True,
    # equal_unlabeled_target=False). Seeded for reproducibility.
    splitter = SSFoldSplit(train_ds,
                           n_ss_folds=3,
                           n_folds=n_folds,
                           target_col="target",
                           random_state=args.seed,
                           labeled_train_size_per_class=100,
                           unlabeled_train_size_per_class=200,
                           equal_target=True,
                           equal_unlabeled_target=False,
                           shuffle=True)

    # TensorBoard writer for training curves.
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    # Initializing Discriminator
    # NOTE(review): the snippet is truncated here by the source scrape;
    # the discriminator setup that followed is not visible.
# Beispiel #2 (example 2 — scraped separator converted to a comment)
def worker_process(gpu, ngpus, sampling_config, strategy_config, args):
    """Per-GPU worker: build data, model, optimizer and run the training Strategy.

    Intended to be spawned once per GPU (e.g. via torch.multiprocessing) in
    the distributed case, or called directly for single-device training.

    Args:
        gpu: Index of the GPU this worker owns (0 in the single-process case).
        ngpus: Total number of GPUs on the node (forwarded to the
            distributed conversion helper).
        sampling_config: Dict describing which loaders feed the train/eval
            stages; only its ``['train'|'eval']['data_provider']`` keys are
            read here.
        strategy_config: Unused in this body — presumably consumed elsewhere;
            kept for interface compatibility.
        args: Parsed CLI namespace (batch_size, world_size, lr, wd, bw,
            dropout, n_channels, workers, n_epochs, distributed, use_apex,
            snapshots, log_dir, comment, save_data, ...).
    """
    args.gpu = gpu  # this line of code is not redundant

    # Linear LR scaling rule for distributed training (scale by global
    # batch size relative to a 256 reference); no scaling otherwise.
    if args.distributed:
        lr_m = float(args.batch_size * args.world_size) / 256.
    else:
        lr_m = 1.0

    criterion = torch.nn.CrossEntropyLoss().to(gpu)

    # Datasets: MNIST train/test splits.
    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    test_ds, _ = get_mnist(data_folder=args.save_data, train=False)

    model = SimpleConvNet(bw=args.bw,
                          drop=args.dropout,
                          n_cls=len(classes),
                          n_channels=args.n_channels).to(gpu)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=args.lr * lr_m,
                                 weight_decay=args.wd)

    # Wrap model/optimizer for DDP and/or apex mixed precision as requested.
    args, model, optimizer = convert_according_to_args(args=args,
                                                       gpu=gpu,
                                                       ngpus=ngpus,
                                                       network=model,
                                                       optim=optimizer)

    # One item loader per stage; the distributed variant derives its
    # batching/sampling from `args` instead of explicit keyword arguments.
    item_loaders = {}
    for stage, df in zip(['train', 'eval'], [train_ds, test_ds]):
        if args.distributed:
            item_loaders[f'mnist_{stage}'] = DistributedItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                args=args)
        else:
            item_loaders[f'mnist_{stage}'] = ItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                batch_size=args.batch_size,
                num_workers=args.workers,
                shuffle=(stage == "train"))  # shuffle only the training split
    data_provider = DataProvider(item_loaders)

    # Only rank 0 logs metrics and saves snapshots; other ranks get no-op
    # callback tuples so they skip TensorBoard/checkpoint I/O entirely.
    if args.gpu == 0:
        log_dir = args.log_dir
        comment = args.comment
        summary_writer = SummaryWriter(log_dir=log_dir,
                                       comment=f'_{comment}gpu_{args.gpu}')
        train_cbs = (RunningAverageMeter(prefix="train", name="loss"),
                     AccuracyMeter(prefix="train", name="acc"))

        val_cbs = (RunningAverageMeter(prefix="eval", name="loss"),
                   AccuracyMeter(prefix="eval", name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='eval/loss',
                              save_dir=args.snapshots,
                              conditions='min',  # checkpoint on lowest eval loss
                              model=model))
    else:
        train_cbs = ()
        val_cbs = ()

    strategy = Strategy(data_provider=data_provider,
                        train_loader_names=tuple(
                            sampling_config['train']['data_provider'].keys()),
                        val_loader_names=tuple(
                            sampling_config['eval']['data_provider'].keys()),
                        data_sampling_config=sampling_config,
                        loss=criterion,
                        model=model,
                        n_epochs=args.n_epochs,
                        optimizer=optimizer,
                        train_callbacks=train_cbs,
                        val_callbacks=val_cbs,
                        device=torch.device(f'cuda:{args.gpu}'),
                        distributed=args.distributed,
                        use_apex=args.use_apex)

    strategy.run()