for i in to_cpu(decoded_target_cls, use_numpy=True).tolist() ] self._predicts += [ self._labels[i] for i in to_cpu(decoded_pred_cls, use_numpy=True).tolist() ] if __name__ == "__main__": args = init_args() device = auto_detect_device() log_dir = args.log_dir comment = "ssgan" # Data provider train_ds, classes = get_mnist(data_folder=args.save_data, train=True) n_folds = 5 splitter = SSFoldSplit(train_ds, n_ss_folds=3, n_folds=n_folds, target_col="target", random_state=args.seed, labeled_train_size_per_class=100, unlabeled_train_size_per_class=200, equal_target=True, equal_unlabeled_target=False, shuffle=True) summary_writer = SummaryWriter(log_dir=log_dir, comment=comment) # Initializing Discriminator
def worker_process(gpu, ngpus, sampling_config, strategy_config, args):
    """Build the MNIST training pipeline for one device and run it.

    The function creates the loss, datasets, model and optimizer, lets
    ``convert_according_to_args`` adapt them to the run configuration,
    builds one item loader per stage (``mnist_train`` / ``mnist_eval``),
    attaches metric/logging/snapshot callbacks on device 0 only, and finally
    executes ``Strategy.run()``.

    Args:
        gpu: Device index; used for ``.to(gpu)`` and stored on ``args.gpu``.
        ngpus: Forwarded unchanged to ``convert_according_to_args``.
        sampling_config: Mapping indexed as
            ``sampling_config[stage]['data_provider']`` for ``'train'`` and
            ``'eval'``; the keys of those mappings become the loader names
            handed to ``Strategy``. Also passed as ``data_sampling_config``.
        strategy_config: Accepted for interface compatibility; not read here.
        args: Run configuration. Attributes read: ``distributed``,
            ``batch_size``, ``world_size``, ``save_data``, ``bw``,
            ``dropout``, ``n_channels``, ``lr``, ``wd``, ``workers``,
            ``log_dir``, ``comment``, ``snapshots``, ``n_epochs``,
            ``use_apex``. ``args.gpu`` is overwritten with ``gpu``.

    Returns:
        None.
    """
    # Not redundant: args.gpu is read further down (rank-0 check, device
    # string) and args is also handed to helpers that may rely on it.
    args.gpu = gpu

    # In distributed mode the learning rate is scaled by the global batch
    # size relative to 256; otherwise it is used as given.
    lr_m = (
        float(args.batch_size * args.world_size) / 256.
        if args.distributed
        else 1.0
    )

    criterion = torch.nn.CrossEntropyLoss().to(gpu)

    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    test_ds, _ = get_mnist(data_folder=args.save_data, train=False)

    model = SimpleConvNet(
        bw=args.bw,
        drop=args.dropout,
        n_cls=len(classes),
        n_channels=args.n_channels,
    ).to(gpu)
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr * lr_m,
        weight_decay=args.wd,
    )

    # The helper may return replaced/wrapped objects, so rebind all three.
    args, model, optimizer = convert_according_to_args(
        args=args, gpu=gpu, ngpus=ngpus, network=model, optim=optimizer)

    # One loader per stage, keyed as 'mnist_<stage>'.
    item_loaders = {}
    for stage, dataset in (('train', train_ds), ('eval', test_ds)):
        stage_transform = init_mnist_cifar_transforms(1, stage)
        if args.distributed:
            loader = DistributedItemLoader(
                meta_data=dataset,
                transform=stage_transform,
                parse_item_cb=parse_item_mnist,
                args=args,
            )
        else:
            loader = ItemLoader(
                meta_data=dataset,
                transform=stage_transform,
                parse_item_cb=parse_item_mnist,
                batch_size=args.batch_size,
                num_workers=args.workers,
                shuffle=(stage == "train"),  # only the training stage shuffles
            )
        item_loaders[f'mnist_{stage}'] = loader

    data_provider = DataProvider(item_loaders)

    # Only device 0 records metrics, writes summaries and saves snapshots;
    # every other worker runs without callbacks.
    train_cbs = ()
    val_cbs = ()
    if args.gpu == 0:
        # NOTE(review): tensorboard-style writers ignore `comment` when
        # `log_dir` is given explicitly - confirm for the writer used here.
        summary_writer = SummaryWriter(
            log_dir=args.log_dir,
            comment='_' + args.comment + 'gpu_' + str(args.gpu),
        )
        train_cbs = (
            RunningAverageMeter(prefix="train", name="loss"),
            AccuracyMeter(prefix="train", name="acc"),
        )
        val_cbs = (
            RunningAverageMeter(prefix="eval", name="loss"),
            AccuracyMeter(prefix="eval", name="acc"),
            ScalarMeterLogger(writer=summary_writer),
            ModelSaver(
                metric_names='eval/loss',
                save_dir=args.snapshots,
                conditions='min',
                model=model,
            ),
        )

    train_loader_names = tuple(sampling_config['train']['data_provider'].keys())
    val_loader_names = tuple(sampling_config['eval']['data_provider'].keys())
    device = torch.device('cuda:{}'.format(args.gpu))

    strategy = Strategy(
        data_provider=data_provider,
        train_loader_names=train_loader_names,
        val_loader_names=val_loader_names,
        data_sampling_config=sampling_config,
        loss=criterion,
        model=model,
        n_epochs=args.n_epochs,
        optimizer=optimizer,
        train_callbacks=train_cbs,
        val_callbacks=val_cbs,
        device=device,
        distributed=args.distributed,
        use_apex=args.use_apex,
    )
    strategy.run()