예제 #1
0
def _warn_if_res_dataset_too_high( max_res, ds_name ):
  """Warn when the current config's `res_dataset` exceeds `max_res` for `ds_name`.

  Does nothing when `get_current_configuration( 'config', raise_exception = False )`
  returns None.
  """
  import warnings

  config = get_current_configuration( 'config', raise_exception = False )
  if config is not None and config.res_dataset > max_res:
    message = f'WARNING: config.res_dataset currently set to {config.res_dataset},' + \
              f' but recommended to set --res_dataset to {max_res} or below when using {ds_name}!'
    # This is a recommendation, not an error: emit the warning and let
    # dataset construction continue (raising it would abort the caller).
    warnings.warn( message, RuntimeWarning, stacklevel = 3 )


def make_torchvision_dataset( dataset, dataset_dir, is_training_set = True, ds_transforms = None, *args, **kwargs ):
  """Build a dataset object for `dataset` rooted at `dataset_dir`.

  Names matching the LSUN Bedrooms or CIFAR-10 entries of
  `TORCHVISION_CATERED_DATASETS` (compared after `str.casefold`) are built with
  `datasets.LSUN` / `datasets.CIFAR10`. Any other name is treated as a custom
  dataset in the `TRAINING_SET_DIRNAME` / `VALID_SET_DIRNAME` sub-directory of
  `dataset_dir`.

  Args:
    dataset: Dataset name; case-folded before matching.
    dataset_dir: Dataset root directory, as `str` or `pathlib.Path`.
    is_training_set: Select the training split when True, else the validation split.
    ds_transforms: Passed as `transform=` to the dataset constructor.
    *args: Unused; accepted for call-site compatibility.
    **kwargs: For the generic custom-dataset path, must hold a non-None
      `loader` callable. The remaining entries are forwarded to
      `DatasetFolderSingleClass`.

  Returns:
    The constructed dataset object.

  Raises:
    TypeError: If `dataset_dir` is neither `str` nor `pathlib.Path`.
    ValueError: If the generic custom-dataset path is reached without a `loader`.

  Warns:
    RuntimeWarning: If `config.res_dataset` exceeds the recommended resolution
      for LSUN Bedrooms (256) or CIFAR-10 (32). Also warns when falling back to
      the generic `DatasetFolderSingleClass`.
  """
  import warnings

  if not isinstance( dataset_dir, ( str, Path, ) ):
    raise TypeError( '"dataset_dir" must be of type "str" or "pathlib.Path"' )
  dataset_dir = str( dataset_dir )

  dataset_title = dataset.casefold()
  # TODO: Add more torchvision-catered datasets:
  if dataset_title in TORCHVISION_CATERED_DATASETS['LSUN Bedrooms'][0]:
    _warn_if_res_dataset_too_high( 256, 'LSUN Bedrooms' )
    classes_categories = TORCHVISION_CATERED_DATASETS['LSUN Bedrooms'][1]
    classes = classes_categories[0] if is_training_set else classes_categories[1]
    dataset = datasets.LSUN(
      root = dataset_dir, classes = classes, transform = ds_transforms
    )
  elif dataset_title in TORCHVISION_CATERED_DATASETS['CIFAR-10'][0]:
    _warn_if_res_dataset_too_high( 32, 'CIFAR-10' )
    split_categories = TORCHVISION_CATERED_DATASETS['CIFAR-10'][1]
    # NOTE(review): forwarded as CIFAR10's `train` flag, so these entries are
    # presumably booleans — confirm against TORCHVISION_CATERED_DATASETS.
    train_flag = split_categories[0] if is_training_set else split_categories[1]
    dataset = datasets.CIFAR10(
      root = dataset_dir, train = train_flag, transform = ds_transforms, download = True
    )
  else:
    # Custom Datasets, stored under `dataset_dir/<split dirname>`:
    imgs_dirname = TRAINING_SET_DIRNAME if is_training_set else VALID_SET_DIRNAME
    if get_dataset_img_extension( str( Path( dataset_dir ) / imgs_dirname ) ) in IMG_EXTENSIONS:
      dataset = ImageFolderSingleClass( root = dataset_dir, category = imgs_dirname, transform = ds_transforms )
    else:
      if kwargs.get( 'loader' ) is None:
        raise ValueError( 'Generic DatasetFolderSingleClass requires input `loader:callable` auxiliary argument.' )
      # Warn (do not raise): raising here would make the construction below unreachable.
      warnings.warn( '\nUsing generic DatasetFolderSingleClass...\n' + \
                     ' WARNING: The input "loader" argument for this class that acts on your data must output an object that' + \
                                ' can be converted into a PIL Image object (e.g. an object of type `np.uint8` or `torch.uint8`).\n',
                     RuntimeWarning, stacklevel = 2 )
      # `loader` is passed explicitly, so remove it before forwarding the rest of kwargs.
      loader = kwargs.pop( 'loader' )
      dataset = DatasetFolderSingleClass( root = dataset_dir, category = imgs_dirname,
                                          loader = loader, transform = ds_transforms, **kwargs )
  return dataset
예제 #2
0
              help = 'percent of images generated from 2 latent vectors z instead of 1; 0. means no mixing regularization' )

            parser.add_argument( '--beta_trunc_trick', type = float, default = .995, \
              help = 'decay coefficient for computing EWMA of disentangled latent vector w during training,' + \
                     ' which will be used for the truncation trick in disentangled latent space W during evaluation' )
            parser.add_argument( '--psi_trunc_trick', type = float, default = .7, \
              help = 'multiplicative coefficient in range [0,1) for the truncation trick in disentangled latent space W during evaluation;' + \
                     ' the smaller the value, the more the truncation' )
            parser.add_argument( '--cutoff_trunc_trick', type = int, default = 4, \
              help = 'final generator block ( = log2(res) - 1 ) to use the computed psi-truncated disentangled latent vector w during evaluation' )

    # ---------------------------------------------------------------------------- #

    config = parser.parse_args()

    data_config = get_current_configuration('data_config',
                                            raise_exception=False)

    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Post-processing:
    # ----------------

    config.dev = torch.device(config.dev)

    if config.pin_memory and config.dev != torch.device('cuda'):
        config = None
        raise ValueError(
            '--pin_memory should be set to `False` if not using CUDA.')

    if config.random_seed == -1:
        np.random.seed(None)
        if config.n_gpu > 1:
예제 #3
0
from utils.data_utils import prepare_dataset, prepare_dataloader
from resnetgan.learner import GANLearner
from progan.learner import ProGANLearner
from stylegan.learner import StyleGANLearner

from pathlib import Path

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #

# File path (relative to the current working directory) for the GAN model file.
# NOTE(review): presumably where the trained learner is saved/loaded; its use is
# not visible in this chunk — confirm.
SAVE_MODEL_PATH = './models/gan_model.tar'

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #

if __name__ == '__main__':

    config = get_current_configuration('config')
    data_config = get_current_configuration('data_config')

    # Construct DataLoader(s) according to config and data_config:
    # ------------------------------------------------------------
    train_ds, valid_ds = prepare_dataset(data_config)
    train_dl, valid_dl, z_valid_dl = \
      prepare_dataloader( config, data_config, train_ds, valid_ds )

    # Instantiate GAN Learner:
    # ------------------------
    if config.model == 'ResNet GAN':
        learner = GANLearner(config)
    elif config.model == 'ProGAN':
        learner = ProGANLearner(config)
    elif config.model == 'StyleGAN':