def get_data(hps, sess):
    """Fill in dataset defaults on ``hps`` and build paired (A/B) data loaders.

    Joint (two-domain) variant: only ``mnist``, ``cifar10`` and
    ``edges2shoes`` have joint loaders. ``hps`` is updated in place:
    ``image_size``, ``n_test`` and ``data_dir`` are defaulted when left at
    their sentinel values (-1, -1 and ""), and ``n_y``, ``rnd_crop``,
    ``direct_iterator``, ``local_batch_train``, ``local_batch_test`` and
    ``local_batch_init`` are always set.

    Args:
        hps: Hyperparameter namespace; mutated in place.
        sess: Not used by the joint loaders; accepted so the signature
            matches the single-domain ``get_data``.

    Returns:
        ``(train_iterator_A, test_iterator_A, data_init_A,
        train_iterator_B, test_iterator_B, data_init_B)`` as produced by the
        selected ``data_loaders`` joint module.

    Raises:
        NotImplementedError: ``hps.problem`` is a single-domain dataset
            (imagenet-oord, imagenet, celeba, lsun_realnvp, lsun).
        KeyError: ``hps.problem`` is not a known dataset name.
        ValueError: the rescaled per-rank training batch is smaller than 1.
    """
    # Single-domain datasets have no B domain and no joint loader; reject
    # them up front, before hps is mutated or any data is loaded.
    if hps.problem in ('imagenet-oord', 'imagenet', 'celeba',
                       'lsun_realnvp', 'lsun'):
        raise NotImplementedError(
            "problem %r has no joint (A/B) data loader; use 'mnist', "
            "'cifar10' or 'edges2shoes'" % (hps.problem,))

    if hps.image_size == -1:
        hps.image_size = {'edges2shoes': 32, 'mnist': 32,
                          'cifar10': 32}[hps.problem]
    if hps.n_test == -1:
        hps.n_test = {'edges2shoes': 200, 'mnist': 10000,
                      'cifar10': 10000}[hps.problem]
    hps.n_y = {'edges2shoes': 10, 'mnist': 10, 'cifar10': 10}[hps.problem]
    if hps.data_dir == "":
        hps.data_dir = {'edges2shoes': 'edges2shoes', 'mnist': None,
                        'cifar10': None}[hps.problem]

    # Random cropping is only enabled for lsun_realnvp, which is rejected above.
    hps.rnd_crop = False
    # mnist/cifar10 default to data_dir=None; appending a category to None
    # would raise TypeError, and the joint loaders do not receive data_dir.
    if hps.category and hps.data_dir is not None:
        hps.data_dir += '/%s' % hps.category

    # Use anchor_size to rescale batch size based on image_size.
    s = hps.anchor_size
    hps.local_batch_train = (hps.n_batch_train * s * s
                             // (hps.image_size * hps.image_size))
    hps.local_batch_init = (hps.n_batch_init * s * s
                            // (hps.image_size * hps.image_size))
    if hps.local_batch_train < 1:
        raise ValueError(
            "local_batch_train is %d; n_batch_train * anchor_size**2 must be "
            "at least image_size**2" % hps.local_batch_train)
    # Test batch: the largest divisor of 50 not exceeding the train batch
    # (64 -> 50, 32 -> 25, 16 -> 10, 8 -> 5, 4 -> 2, ...). The default n_test
    # values above are all multiples of 50, presumably the reason for this.
    hps.local_batch_test = next(
        d for d in (50, 25, 10, 5, 2, 1) if d <= hps.local_batch_train)

    print("Rank {} Batch sizes Train {} Test {} Init {}".format(
        hvd.rank(), hps.local_batch_train, hps.local_batch_test,
        hps.local_batch_init))

    if hps.problem == 'edges2shoes':
        import data_loaders.get_edges_shoes_joint as loader
    elif hps.problem in ('mnist', 'cifar10'):
        import data_loaders.get_mnist_cifar_joint as loader
    else:
        # Unknown names already fail at the table lookups above (KeyError);
        # this guards against the tables and branches drifting apart.
        raise ValueError("unsupported problem: %r" % (hps.problem,))
    hps.direct_iterator = False

    (train_iterator_A, test_iterator_A, data_init_A,
     train_iterator_B, test_iterator_B, data_init_B) = loader.get_data(
        hps.problem, hvd.size(), hvd.rank(), hps.dal,
        hps.local_batch_train, hps.local_batch_test, hps.local_batch_init,
        hps.image_size)
    return (train_iterator_A, test_iterator_A, data_init_A,
            train_iterator_B, test_iterator_B, data_init_B)
def get_data(hps, sess):
    """Fill in dataset defaults on ``hps`` and build single-domain data loaders.

    ``hps`` is updated in place: ``image_size``, ``n_test`` and ``data_dir``
    are defaulted when left at their sentinel values (-1, -1 and ""), and
    ``n_y``, ``rnd_crop``, ``direct_iterator``, ``local_batch_train``,
    ``local_batch_test`` and ``local_batch_init`` are always set.

    Args:
        hps: Hyperparameter namespace; mutated in place.
        sess: Passed to ``data_loaders.get_data`` for the datasets read from
            ``hps.data_dir`` (imagenet-oord, imagenet, celeba, lsun_realnvp,
            lsun); unused for mnist/cifar10.

    Returns:
        ``(train_iterator, test_iterator, data_init)`` from the selected
        ``data_loaders`` module.

    Raises:
        KeyError: ``hps.problem`` is not a known dataset name.
        ValueError: the rescaled per-rank training batch is smaller than 1.
    """
    if hps.image_size == -1:
        hps.image_size = {'mnist': 32, 'cifar10': 32, 'imagenet-oord': 64,
                          'imagenet': 256, 'celeba': 256,
                          'lsun_realnvp': 64, 'lsun': 256}[hps.problem]
    if hps.n_test == -1:
        hps.n_test = {'mnist': 10000, 'cifar10': 10000,
                      'imagenet-oord': 50000, 'imagenet': 50000,
                      'celeba': 3000, 'lsun_realnvp': 300*hvd.size(),
                      'lsun': 300*hvd.size()}[hps.problem]
    hps.n_y = {'mnist': 10, 'cifar10': 10, 'imagenet-oord': 1000,
               'imagenet': 1000, 'celeba': 1, 'lsun_realnvp': 1,
               'lsun': 1}[hps.problem]
    if hps.data_dir == "":
        hps.data_dir = {'mnist': None, 'cifar10': None,
                        'imagenet-oord': '/mnt/host/imagenet-oord-tfr',
                        'imagenet': '/mnt/host/imagenet-tfr',
                        'celeba': '/mnt/host/celeba-reshard-tfr',
                        'lsun_realnvp': '/mnt/host/lsun_realnvp',
                        'lsun': '/mnt/host/lsun'}[hps.problem]

    # Random cropping is only enabled for lsun_realnvp.
    hps.rnd_crop = hps.problem == 'lsun_realnvp'
    # mnist/cifar10 default to data_dir=None; appending a category to None
    # would raise TypeError, and their loader does not receive data_dir.
    if hps.category and hps.data_dir is not None:
        hps.data_dir += '/%s' % hps.category

    # Use anchor_size to rescale batch size based on image_size.
    s = hps.anchor_size
    hps.local_batch_train = (hps.n_batch_train * s * s
                             // (hps.image_size * hps.image_size))
    hps.local_batch_init = (hps.n_batch_init * s * s
                            // (hps.image_size * hps.image_size))
    if hps.local_batch_train < 1:
        raise ValueError(
            "local_batch_train is %d; n_batch_train * anchor_size**2 must be "
            "at least image_size**2" % hps.local_batch_train)
    # Test batch: the largest divisor of 50 not exceeding the train batch
    # (64 -> 50, 32 -> 25, 16 -> 10, 8 -> 5, 4 -> 2, ...). The default n_test
    # values above are all multiples of 50, presumably the reason for this.
    hps.local_batch_test = next(
        d for d in (50, 25, 10, 5, 2, 1) if d <= hps.local_batch_train)

    print("Rank {} Batch sizes Train {} Test {} Init {}".format(
        hvd.rank(), hps.local_batch_train, hps.local_batch_test,
        hps.local_batch_init))

    if hps.problem in ['imagenet-oord', 'imagenet', 'celeba',
                       'lsun_realnvp', 'lsun']:
        hps.direct_iterator = True
        import data_loaders.get_data as v
        train_iterator, test_iterator, data_init = v.get_data(
            sess, hps.data_dir, hvd.size(), hvd.rank(), hps.pmap, hps.fmap,
            hps.local_batch_train, hps.local_batch_test,
            hps.local_batch_init, hps.image_size, hps.rnd_crop)
    elif hps.problem in ['mnist', 'cifar10']:
        hps.direct_iterator = False
        import data_loaders.get_mnist_cifar as v
        train_iterator, test_iterator, data_init = v.get_data(
            hps.problem, hvd.size(), hvd.rank(), hps.dal,
            hps.local_batch_train, hps.local_batch_test,
            hps.local_batch_init, hps.image_size)
    else:
        # Unknown names already fail at the table lookups above (KeyError);
        # this guards against the tables and branches drifting apart.
        raise ValueError("unsupported problem: %r" % (hps.problem,))

    return train_iterator, test_iterator, data_init