def test_imagenet(self):
    """Check that jittor's ImageFolder pipeline reproduces the torch baseline.

    Both pipelines are seeded with random.seed(0) before iteration so that
    the random crop/flip transforms make identical choices, then the first
    ``check_num_batch + 1`` batches are compared elementwise.
    """
    # Collect reference batches from the torch DataLoader.
    baseline_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=False)
    random.seed(0)
    reference_batches = []
    for idx, batch in enumerate(baseline_loader):
        reference_batches.append(batch)
        print("get", batch[0].shape)
        if idx == check_num_batch:
            break

    # Build the equivalent jittor pipeline over the same directory.
    from jittor.dataset.dataset import ImageFolder
    import jittor.transform as transform
    pipeline = transform.Compose([
        transform.RandomCropAndResize(224),
        transform.RandomHorizontalFlip(),
        transform.ImageNormalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
    ])
    jt_dataset = ImageFolder(traindir).set_attrs(batch_size=256, shuffle=False)
    jt_dataset.set_attrs(transform=pipeline)

    # Re-seed so the random transforms repeat the baseline's choices,
    # then compare batch-by-batch against the stored references.
    random.seed(0)
    for idx, (jt_images, jt_labels) in enumerate(jt_dataset):
        print("compare", idx)
        assert np.allclose(jt_images.numpy(), reference_batches[idx][0].numpy())
        assert np.allclose(jt_labels.numpy(), reference_batches[idx][1].numpy())
        if idx == check_num_batch:
            break
def get_dataset():
    """Return the ImageNet training dataset as a jittor ImageFolder.

    The dataset is configured with batch_size=256, no shuffling, a
    Resize(224) + ImageNormalize pipeline, and num_workers=0 so loading
    happens in the main process.
    """
    pipeline = transform.Compose([
        transform.Resize(224),
        transform.ImageNormalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
    ])
    folder = ImageFolder(traindir).set_attrs(batch_size=256, shuffle=False)
    folder.set_attrs(transform=pipeline, num_workers=0)
    return folder
# Select the data pipeline and loaders for the requested task.
if task == "MNIST":
    # NOTE(review): this rebinds the name `transform` from the imported module
    # to the composed pipeline. It works because the right-hand side is fully
    # evaluated (module attributes resolved) before the assignment, but it is
    # fragile — any later use of the module through `transform` would break.
    transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.Gray(),
        transform.ImageNormalize(mean=[0.5], std=[0.5]),
    ])
    train_loader = MNIST(train=True, transform=transform).set_attrs(
        batch_size=batch_size, shuffle=True)
    eval_loader = MNIST(train=False, transform=transform).set_attrs(
        batch_size=batch_size, shuffle=True)
elif task == "CelebA":
    # Same module-shadowing pattern as the MNIST branch (see note above in
    # this block); three-channel normalization for RGB face images.
    transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    train_dir = './data/celebA_train'
    train_loader = ImageFolder(train_dir).set_attrs(
        transform=transform, batch_size=batch_size, shuffle=True)
    eval_dir = './data/celebA_eval'
    eval_loader = ImageFolder(eval_dir).set_attrs(
        transform=transform, batch_size=batch_size, shuffle=True)

# Build the GAN pair and one Adam optimizer per network.
# NOTE(review): `dim`, `lr`, and `betas` are defined outside this chunk —
# confirm their values against the rest of the file.
G = generator(dim)
D = discriminator(dim)
G_optim = jt.nn.Adam(G.parameters(), lr, betas=betas)
D_optim = jt.nn.Adam(D.parameters(), lr, betas=betas)

def train(epoch):
    # Run one training epoch over `train_loader`.
    # NOTE(review): this definition appears truncated in this chunk — only
    # the start of the discriminator update is visible; the generator update
    # (and any use of `target`, `mini_batch`, or D_optim) must be elsewhere.
    for batch_idx, (x_, target) in enumerate(train_loader):
        mini_batch = x_.shape[0]
        # train discriminator: score the real batch and compute the
        # least-squares loss against the "real" label.
        D_result = D(x_)
        D_real_loss = ls_loss(D_result, True)