def setUpClass(self):
     """Configure training hyper-parameters and build the MNIST train loader.

     NOTE(review): unittest's setUpClass is conventionally a @classmethod
     taking `cls`; this fragment takes `self` — confirm how it is invoked.
     """
     # hyper-parameters
     self.batch_size = 32
     self.weight_decay = 0.0001
     self.momentum = 0.9
     self.learning_rate = 0.01
     # mnist dataset: digits resized to 224x224, presumably to feed an
     # ImageNet-style backbone — TODO confirm against the model under test
     self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
         .set_attrs(batch_size=self.batch_size, shuffle=True)
# --- Beispiel #2 (Example #2 — scrape separator, vote count 0) ---
 def setUpClass(self):
     """Set training hyper-parameters and build a multi-worker MNIST loader.

     The batch size can be overridden through the TEST_BATCH_SIZE
     environment variable (defaults to 100).
     """
     # hyper-parameters
     env_batch = os.environ.get("TEST_BATCH_SIZE", "100")
     self.batch_size = int(env_batch)
     self.weight_decay = 0.0001
     self.momentum = 0.9
     self.learning_rate = 0.1
     # mnist dataset, resized to 224x224
     resize = trans.Resize(224)
     loader = MNIST(train=True, transform=resize)
     self.train_loader = loader.set_attrs(
         batch_size=self.batch_size, shuffle=True)
     self.train_loader.num_workers = 4
def main():
    """Train a model on MNIST for a few epochs, validating after each epoch."""
    batch_size = 64
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 5
    train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(
        batch_size=batch_size, shuffle=True)

    # BUG FIX: validation must run on the held-out split; the original
    # passed train=True here, silently validating on the training data.
    val_loader = MNIST(train=False,
                       transform=trans.Resize(28)).set_attrs(batch_size=1,
                                                             shuffle=False)

    model = Model()
    optimizer = nn.SGD(model.parameters(), learning_rate, momentum,
                       weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch)
        test(model, val_loader, epoch)
 def test_dataset(self):
     """Smoke-test iterating the MNIST loader twice with a worker process.

     NOTE(review): the bare `return` below disables this test entirely —
     every line after it is dead code, apparently kept for manual
     debugging. Remove the `return` (or use self.skipTest) to re-enable.
     """
     return
     self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
         .set_attrs(batch_size=300, shuffle=True)
     self.train_loader.num_workers = 1
     import time
     # First pass: consume a handful of batches, then stop early.
     for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
         # time.sleep(5)
         # print("break")
         # break
         # self.train_loader.display_worker_status()
         if batch_idx > 30:
             break
         pass
     # Second pass over the same loader — presumably checks that workers
     # restart cleanly after an early break; confirm against jittor docs.
     for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
         # time.sleep(5)
         # print("break")
         # break
         # self.train_loader.display_worker_status()
         if batch_idx > 300:
             break
         pass
# --- Beispiel #5 (Example #5 — scrape separator, vote count 0) ---
# Computes (A - B)^2 — mean squared error used as the adversarial loss.
adversarial_loss = nn.MSELoss()

generator = Generator()
discriminator = Discriminator()

# Load the MNIST dataset.
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
# NOTE(review): the next assignment rebinds `transform` from the module
# imported above to the composed pipeline, shadowing the module name.
transform = transform.Compose([
    transform.Resize(opt.img_size),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True,
                   transform=transform).set_attrs(batch_size=opt.batch_size,
                                                  shuffle=True)

# Separate Adam optimizers for generator and discriminator.
optimizer_G = nn.Adam(generator.parameters(),
                      lr=opt.lr,
                      betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(),
                      lr=opt.lr,
                      betas=(opt.b1, opt.b2))

from PIL import Image


def save_image(img, path, nrow=10, padding=5):
    N, C, W, H = img.shape
    if (N % nrow != 0):
# --- Beispiel #6 (Example #6 — scrape separator, vote count 0) ---
# Loss functions used across the training loops.
bce_loss = nn.BCELoss()
xe_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

# Initialize generator and discriminator
generator = Generator_CNN(latent_dim, n_c, x_shape)
encoder = Encoder_CNN(latent_dim, n_c)
discriminator = Discriminator_CNN(wass_metric=wass_metric)

# Configure data loader
# NOTE(review): this rebinds `transform` from the transform module to the
# composed pipeline, shadowing the module name for the rest of the script.
transform = transform.Compose([
    transform.Resize(size=img_size),
    transform.Gray(),
])
dataloader = MNIST(train=True,
                   transform=transform).set_attrs(batch_size=batch_size,
                                                  shuffle=True)
testdata = MNIST(train=False,
                 transform=transform).set_attrs(batch_size=batch_size,
                                                shuffle=True)
# Grab one fixed test batch for later evaluation/visualization.
(test_imgs, test_labels) = next(iter(testdata))

# Joint generator+encoder parameter list for one shared optimizer.
# NOTE(review): assumes .parameters() returns a mutable list — confirm
# against the jittor Module API.
ge_chain = generator.parameters()
for p in encoder.parameters():
    ge_chain.append(p)
#TODO: weight_decay=decay
optimizer_GE = jt.optim.Adam(ge_chain, lr=lr, betas=(b1, b2), weight_decay=0.0)
optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Presumably per-step loss histories for GE and D — verify against usage.
ge_l = []
d_l = []
def val2():
    """Sanity-check that a test-split MNIST loader yields batches of 16."""
    loader = MNIST(train=False).set_attrs(batch_size=16)
    for idx, (batch_imgs, batch_labels) in enumerate(loader):
        assert (batch_imgs.shape[0] == 16)
        if idx == 5:
            break
# --- Beispiel #8 (Example #8 — scrape separator, vote count 0) ---
# Number of training epochs.
# NOTE(review): the original wrote `50 if task=="MNIST" else 50` — both
# branches were identical, so the conditional was dead. Simplified to the
# constant; TODO confirm the intended CelebA epoch count.
train_epoch = 50
# Standard image size used for training.
img_size = 112
# Adam optimizer betas.
betas = (0.5, 0.999)
# Number of image channels: 1 for MNIST, 3 for CelebA.
dim = 1 if task == "MNIST" else 3

if task == "MNIST":
    # NOTE(review): rebinds `transform` from the module to the composed
    # pipeline, shadowing the module name.
    transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.Gray(),
        transform.ImageNormalize(mean=[0.5], std=[0.5]),
    ])
    train_loader = MNIST(train=True, transform=transform).set_attrs(
        batch_size=batch_size, shuffle=True)
    eval_loader = MNIST(train=False, transform=transform).set_attrs(
        batch_size=batch_size, shuffle=True)
elif task == "CelebA":
    transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    train_dir = './data/celebA_train'
    train_loader = ImageFolder(train_dir).set_attrs(
        transform=transform, batch_size=batch_size, shuffle=True)
    eval_dir = './data/celebA_eval'
    eval_loader = ImageFolder(eval_dir).set_attrs(
        transform=transform, batch_size=batch_size, shuffle=True)

# Build models and their optimizers.
G = generator(dim)
D = discriminator(dim)
G_optim = jt.nn.Adam(G.parameters(), lr, betas=betas)
D_optim = jt.nn.Adam(D.parameters(), lr, betas=betas)