def test_imagenet(self):
        """Check that a Jittor ImageFolder pipeline reproduces a reference loader.

        First pass: wraps ``train_dataset`` (defined outside this view) in a
        ``torch.utils.data.DataLoader`` (batch_size=256, shuffle=False) and
        stores batches with index 0..``check_num_batch`` inclusive in
        ``tc_data``.

        Second pass: builds a Jittor ``ImageFolder`` over ``traindir`` with the
        same batch size / no shuffling, attaches a RandomCropAndResize(224) +
        RandomHorizontalFlip + ImageNormalize pipeline, and asserts that each
        batch's images and labels are ``np.allclose`` to the stored reference.

        Python's ``random`` is re-seeded with 0 before each pass.
        NOTE(review): the comparison can only succeed if both pipelines draw
        from ``random`` in the same order; the reference transforms are not
        visible here, so confirm against ``train_dataset``.

        NOTE(review): ``self`` is unused; presumably this is a
        ``unittest.TestCase`` method — verify against the enclosing class.
        """
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=256,
                                                   shuffle=False)

        # Seed before the reference pass; the identical seed is applied again
        # before the Jittor pass so random transforms can replay the same draws.
        random.seed(0)
        tc_data = []
        for i, data in enumerate(train_loader):
            # data[0] holds the image batch and data[1] the labels (see the
            # comparisons against tc_data[i][0] / tc_data[i][1] below).
            tc_data.append(data)
            print("get", data[0].shape)
            # Inclusive bound: check_num_batch + 1 batches are collected.
            if i == check_num_batch: break

        # Imported locally, so Jittor is only needed when this test runs.
        from jittor.dataset.dataset import ImageFolder
        import jittor.transform as transform

        # Same batch size and ordering as the reference loader so batch i lines up.
        dataset = ImageFolder(traindir).set_attrs(batch_size=256,
                                                  shuffle=False)

        dataset.set_attrs(transform=transform.Compose([
            transform.RandomCropAndResize(224),
            transform.RandomHorizontalFlip(),
            transform.ImageNormalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        ]))

        # Re-seed so the random crop/flip above start from the same RNG state
        # as the reference pass did.
        random.seed(0)

        # NOTE(review): if this loop yields more batches than tc_data holds,
        # tc_data[i] raises IndexError; if it yields fewer, the remaining
        # reference batches are never compared and the test still passes.
        for i, (images, labels) in enumerate(dataset):
            print("compare", i)
            assert np.allclose(images.numpy(), tc_data[i][0].numpy())
            assert np.allclose(labels.numpy(), tc_data[i][1].numpy())
            # Same inclusive stopping point as the reference pass.
            if i == check_num_batch: break
 def get_dataset():
     """Build an ImageFolder dataset over ``traindir``.

     The dataset yields batches of 256 without shuffling. Each image is
     resized with ``Resize(224)`` and normalized with the per-channel
     mean/std given below. ``num_workers`` is set to 0.

     Returns:
         The configured ``ImageFolder`` instance (the object returned by the
         first ``set_attrs`` call).
     """
     folder = ImageFolder(traindir)
     folder = folder.set_attrs(batch_size=256, shuffle=False)

     # Preprocessing pipeline, attached in a second set_attrs call together
     # with the worker setting.
     pipeline = transform.Compose([
         transform.Resize(224),
         transform.ImageNormalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
     ])
     folder.set_attrs(transform=pipeline, num_workers=0)
     return folder
# Example no. 3
# Build the train/eval loaders for the selected task.
# The preprocessing pipeline is bound to ``data_transform``, not ``transform``:
# rebinding ``transform`` would shadow the ``transform`` module for the rest of
# the file.
if task == "MNIST":
    data_transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.Gray(),
        transform.ImageNormalize(mean=[0.5], std=[0.5]),
    ])
    train_loader = MNIST(train=True, transform=data_transform).set_attrs(batch_size=batch_size, shuffle=True)
    eval_loader = MNIST(train=False, transform=data_transform).set_attrs(batch_size=batch_size, shuffle=True)
elif task == "CelebA":
    data_transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    train_dir = './data/celebA_train'
    train_loader = ImageFolder(train_dir).set_attrs(transform=data_transform, batch_size=batch_size, shuffle=True)
    eval_dir = './data/celebA_eval'
    eval_loader = ImageFolder(eval_dir).set_attrs(transform=data_transform, batch_size=batch_size, shuffle=True)
else:
    # Fail fast here; otherwise train_loader/eval_loader are never defined and
    # the script dies later with a confusing NameError.
    raise ValueError(f"unsupported task {task!r}; expected 'MNIST' or 'CelebA'")

# Instantiate the generator, then the discriminator (same order as before,
# so any random initialisation is unchanged), and give each network its own
# Adam optimizer sharing the same learning rate and betas.
G, D = generator(dim), discriminator(dim)
G_optim, D_optim = (jt.nn.Adam(net.parameters(), lr, betas=betas) for net in (G, D))

def train(epoch):
    for batch_idx, (x_, target) in enumerate(train_loader):
        mini_batch = x_.shape[0]

        # train discriminator
        D_result = D(x_)
        D_real_loss = ls_loss(D_result, True)