from AdversariallyLearnedInference import Generator_x, Generator_z, Discriminator_x, Discriminator_z, Discriminator_x_z, AdversariallyLearnedInferenceTrainerV2
from homura.trainers import SupervisedTrainer
from homura.optim import Adam
from homura.reporters import TQDMReporter, TensorboardReporter
from homura.vision.data import VisionSet
from torch.nn import functional as F
from torch.nn import ConvTranspose2d
from torchvision.datasets import CIFAR10
from torchvision import transforms as tf
import argparse
from torch.cuda import device_count
from homura.metrics import accuracy
from utils import SemiVisionSet
from managpu import GpuManager
# Select a GPU through managpu. This runs at module import time, i.e. before
# main() is ever called.
# NOTE(review): set_by_memory(1) presumably picks one device according to free
# memory (e.g. by restricting the visible CUDA devices) — confirm against the
# managpu documentation.
GpuManager().set_by_memory(1)


def main(args):
    cifar = SemiVisionSet(CIFAR10,
                          args.dataset,
                          10, [
                              tf.ToTensor(),
                              tf.Normalize((0.4914, 0.4822, 0.4465),
                                           (0.2023, 0.1994, 0.2010))
                          ],
                          semi_size=40000)
    train_loader, test_loader, num_classes = cifar(args.batch_size,
                                                   num_workers=4,
                                                   return_num_classes=True,
                                                   use_prefetcher=True)
    model_dict = {
Example #2
0
# -*- coding: UTF-8 -*-

import sys
# Make packages located three directories up importable; presumably needed for
# the project-level imports below (`config`, `torcherry`, `model`) — confirm.
# NOTE(review): a relative sys.path entry is resolved against the current
# working directory, not this file's location, so the script only finds those
# packages when launched from the expected directory.
sys.path.append("../../..")

from managpu import GpuManager
# Select a GPU through managpu before `torch` is imported further down.
# NOTE(review): the ordering presumably matters because managpu restricts the
# visible CUDA devices (e.g. via CUDA_VISIBLE_DEVICES), which only takes effect
# if done before CUDA is initialised — confirm against the managpu documentation.
my_gpu = GpuManager()
my_gpu.set_by_memory(1)

import os
import time
import argparse

from config import proj_cfg

import torch

import torcherry as tc
from torcherry.utils.metric import MetricAccuracy, MetricLoss
from torcherry.utils.checkpoint import CheckBestValAcc
from torcherry.utils.util import set_env_seed

from model.resnet_cifar10 import resnet20_cifar10, resnet32_cifar10
from model.resnet_tn_cifar10 import *

Models = dict(
    resnet20_cifar10=resnet20_cifar10,
    resnet32_cifar10=resnet32_cifar10,
    TRResNet20_CIFAR10=TRResNet20_CIFAR10,
    TRResNet32_CIFAR10=TRResNet32_CIFAR10,
    BTTResNet20_CIFAR10=BTTResNet20_CIFAR10,