from AdversariallyLearnedInference import Generator_x, Generator_z, Discriminator_x, Discriminator_z, Discriminator_x_z, AdversariallyLearnedInferenceTrainerV2 from homura.trainers import SupervisedTrainer from homura.optim import Adam from homura.reporters import TQDMReporter, TensorboardReporter from homura.vision.data import VisionSet from torch.nn import functional as F from torch.nn import ConvTranspose2d from torchvision.datasets import CIFAR10 from torchvision import transforms as tf import argparse from torch.cuda import device_count from homura.metrics import accuracy from utils import SemiVisionSet from managpu import GpuManager GpuManager().set_by_memory(1) def main(args): cifar = SemiVisionSet(CIFAR10, args.dataset, 10, [ tf.ToTensor(), tf.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ], semi_size=40000) train_loader, test_loader, num_classes = cifar(args.batch_size, num_workers=4, return_num_classes=True, use_prefetcher=True) model_dict = {
# -*- coding: UTF-8 -*- import sys sys.path.append("../../..") from managpu import GpuManager my_gpu = GpuManager() my_gpu.set_by_memory(1) import os import time import argparse from config import proj_cfg import torch import torcherry as tc from torcherry.utils.metric import MetricAccuracy, MetricLoss from torcherry.utils.checkpoint import CheckBestValAcc from torcherry.utils.util import set_env_seed from model.resnet_cifar10 import resnet20_cifar10, resnet32_cifar10 from model.resnet_tn_cifar10 import * Models = dict( resnet20_cifar10=resnet20_cifar10, resnet32_cifar10=resnet32_cifar10, TRResNet20_CIFAR10=TRResNet20_CIFAR10, TRResNet32_CIFAR10=TRResNet32_CIFAR10, BTTResNet20_CIFAR10=BTTResNet20_CIFAR10,