Ejemplo n.º 1
0
parser.add_argument(
    "--n_cpu",
    type=int,
    default=0,
    help="number of cpu threads to use during batch generation",
)
parser.add_argument(
    "--nz",
    type=int,
    default=10,
    help="dimensionality of the latent space",
)
parser.add_argument(
    "--imageSize",
    type=int,
    default=28,
    help="size of each image dimension",
)
parser.add_argument(
    "--sample_interval",
    type=int,
    default=100,
    help="interval between image sampling",
)

# Parse the command line, echo the resulting options, and make sure the
# experiment output directory exists before anything is written to it.
opt = parser.parse_args()
print(opt)
os.makedirs(opt.experiment, exist_ok=True)

## device
# Prefer the GPU when one is visible; everything below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## load dataset
os.makedirs(os.path.join(opt.dataroot, opt.dataset), exist_ok=True)
trans = transforms.Compose([transforms.ToTensor()])
# Held-out split (train=False), images only converted to tensors.
dataset = load_dataset(opt.dataroot, opt.dataset, opt.imageSize, trans, train=False)
# Explicit check rather than `assert`: asserts are stripped under `python -O`,
# which would let an empty/None dataset slip through to the DataLoader.
if not dataset:
    raise RuntimeError(
        f"failed to load dataset {opt.dataset!r} from {opt.dataroot!r}"
    )
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
opt.dataSize = len(dataset)

## parameters
# Number of classes is read from the dataset's `classes` attribute.
opt.n_classes = len(dataset.classes)

## load pretrained model
# Both networks take opt.imageSize**2 as their first argument -- presumably the
# flattened single-channel image size; confirm against Encoder/Decoder.
enc = Encoder(opt.imageSize**2, opt.nz, opt.n_classes).to(device)
dec = Decoder(opt.imageSize**2, opt.nz, opt.n_classes).to(device)
# map_location=device lets checkpoints that were saved on a GPU be restored on a
# CPU-only host (where `device` above falls back to "cpu").
enc.load_state_dict(torch.load(opt.enc_pth, map_location=device))
dec.load_state_dict(torch.load(opt.dec_pth, map_location=device))
print("Pretrained models have been loaded.")

## random seed (currently disabled)
# opt.seed = 42
# torch.manual_seed(opt.seed)
# np.random.seed(opt.seed)

## cudnn
# `deterministic=True` only yields repeatable results if cuDNN also stops
# auto-tuning: with `benchmark=True` the convolution algorithm is chosen by
# timing on every run and may differ between runs. PyTorch's reproducibility
# notes require benchmark to be disabled alongside deterministic mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## dataset
# Training split without any transform (trans=None).
train_dataset = load_dataset(opt.dataroot,
                             opt.dataset,
                             opt.imageSize,
                             trans=None,
                             train=True)
train_dataloader = DataLoader(train_dataset,
                              batch_size=opt.batchSize,
                              shuffle=True)


## model init
def weights_init(m):
    """Initialise the parameters of a single module ``m`` in place.

    Dispatch is on the module's class name:

    * names containing ``'Conv'`` (this also matches e.g. ``ConvTranspose2d``):
      weight ~ N(0.0, 0.02); the bias is left unchanged.
    * names containing ``'BatchNorm'``: weight ~ N(1.0, 0.02), bias set to 0.

    Any other module is left untouched. Parameters that are absent or ``None``
    (e.g. ``BatchNorm2d(affine=False)``, ``Conv2d(bias=False)``, or a container
    class whose name merely contains ``'Conv'``) are skipped instead of raising
    ``AttributeError``.
    """
    classname = m.__class__.__name__
    # getattr with a default: nn.Module raises AttributeError for unknown names,
    # and affine-less norm layers register weight/bias as None.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    # torch.nn.init runs under no_grad, replacing the discouraged `.data` access;
    # it draws from the same RNG in the same order as `.data.normal_`.
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)
## random seed
# opt.seed = 42
# torch.manual_seed(opt.seed)
# np.random.seed(opt.seed)

## cudnn
# `deterministic=True` is only effective when cuDNN auto-tuning is off: with
# `benchmark=True` the algorithm is re-selected by timing on each run and can
# vary between runs. PyTorch's reproducibility notes require both settings.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## dataset
# NOTE(review): this reads `opt.size`, whereas the visible parser defines
# `--imageSize`; confirm that a `--size` option is declared elsewhere.
test_dataset = load_dataset(opt.dataroot, opt.dataset, opt.size, trans=None, train=False)
# Test split is iterated in a fixed order (shuffle=False).
test_dataloader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=False)
opt.dataSize = len(test_dataset)

## model
gen = NetG(opt).to(device)
disc = NetD(opt).to(device)
# Do NOT wrap these calls in `assert`: under `python -O` asserts are stripped,
# so the weights would silently never be loaded. The `_IncompatibleKeys`
# tuple returned by load_state_dict is always truthy anyway, and the default
# strict=True already raises on missing/unexpected keys.
# map_location=device allows GPU-saved checkpoints to load on a CPU-only host.
gen.load_state_dict(torch.load(opt.gen_pth, map_location=device))
disc.load_state_dict(torch.load(opt.disc_pth, map_location=device))
print("Pretrained models have been loaded.")

## record results
# TensorBoard log directory is "../runs" followed by the experiment path with
# its first character dropped. NOTE(review): per the SummaryWriter docs,
# `comment` is only appended to the *default* log_dir, so passing it alongside
# an explicit log_dir has no effect.
writer = SummaryWriter(
    log_dir=f"../runs{opt.experiment[1:]}",
    comment=opt.experiment[1:],
)

## def
def splitImage(img, size):