Ejemplo n.º 1
0
  def test_extract_ImageNet100_CMC(self):
    """
    Copy the ImageNet-100 (CMC) class subset of ImageNet into `cfg.saved_dir`.

    Config keys read here (from PrepareImageNet.yaml): `data_dir`, `saved_dir`,
    `class_list_file`. For every class listed in `class_list_file`, the directories
    `{data_dir}/train/<class>` and `{data_dir}/val/<class>` are copied recursively to
    `{saved_dir}/train/<class>` and `{saved_dir}/val/<class>`. The class directory
    name is the first whitespace-separated token of each line; blank lines are
    skipped and any further columns are ignored.

    Usage:
        proj_root=moco-exp
        python template_lib/modelarts/scripts/copy_tool.py \
          -s s3://bucket-7001/ZhouPeng/codes/$proj_root -d /cache/$proj_root -t copytree
        cd /cache/$proj_root

        export CUDA_VISIBLE_DEVICES=0
        export TIME_STR=0
        export PYTHONPATH=./
        python -c "from template_lib.proj.imagenet.tests.test_imagenet import Testing_PrepareImageNet;\
          Testing_PrepareImageNet().test_extract_ImageNet100_CMC()"

    :return:
    """
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
      os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if 'TIME_STR' not in os.environ:
      os.environ['TIME_STR'] = '0' if utils.is_debugging() else '0'
    from template_lib.v2.config_cfgnode.argparser import \
      (get_command_and_outdir, setup_outdir_and_yaml)
    from template_lib.v2.config_cfgnode import global_cfg
    from template_lib.modelarts import modelarts_utils
    # NOTE: distutils is deprecated (PEP 632) and removed in Python 3.12; on
    # Python >= 3.8, shutil.copytree(src, dst, dirs_exist_ok=True) is the replacement.
    from distutils.dir_util import copy_tree

    command, outdir = get_command_and_outdir(self, func_name=sys._getframe().f_code.co_name, file=__file__)
    argv_str = f"""
                --tl_config_file template_lib/proj/imagenet/tests/configs/PrepareImageNet.yaml
                --tl_command {command}
                --tl_outdir {outdir}
                """
    args, cfg = setup_outdir_and_yaml(argv_str, return_cfg=True)

    # Sync/download inputs from OBS (ModelArts) before touching local paths.
    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

    train_dir = f'{cfg.data_dir}/train'
    val_dir = f'{cfg.data_dir}/val'
    save_train_dir = f'{cfg.saved_dir}/train'
    save_val_dir = f'{cfg.saved_dir}/val'
    os.makedirs(save_train_dir, exist_ok=True)
    os.makedirs(save_val_dir, exist_ok=True)

    with open(cfg.class_list_file, 'r') as f:
      # Keep only the first column of each non-blank line (the class directory
      # name); blank lines such as a trailing newline would otherwise crash.
      class_list = [line.split()[0] for line in f if line.strip()]
    for class_subdir in tqdm.tqdm(class_list):
      train_class_dir = f'{train_dir}/{class_subdir}'
      save_train_class_dir = f'{save_train_dir}/{class_subdir}'
      copy_tree(train_class_dir, save_train_class_dir)

      val_class_dir = f'{val_dir}/{class_subdir}'
      save_val_class_dir = f'{save_val_dir}/{class_subdir}'
      copy_tree(val_class_dir, save_val_class_dir)

    # Push the extracted subset and the results dir back to OBS.
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_upload', {}), global_cfg=global_cfg, download=False)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
Ejemplo n.º 2
0
  def test_extract_ImageNet_1000x50(self):
    """
    Build a reduced ImageNet training set with at most `cfg.num_per_class` images per class.

    Every leaf directory under `{cfg.data_dir}/train` is treated as one class; its
    files are sorted by name and the first `cfg.num_per_class` are copied to the same
    relative location under `cfg.saved_dir`.

    Usage:
        proj_root=moco-exp
        python template_lib/modelarts/scripts/copy_tool.py \
          -s s3://bucket-7001/ZhouPeng/codes/$proj_root -d /cache/$proj_root -t copytree
        cd /cache/$proj_root

        export CUDA_VISIBLE_DEVICES=0
        export TIME_STR=0
        export PYTHONPATH=./
        python -c "from template_lib.proj.imagenet.tests.test_imagenet import Testing_PrepareImageNet;\
          Testing_PrepareImageNet().test_extract_ImageNet_1000x50()"

    :return:
    """
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    if 'TIME_STR' not in os.environ:
      os.environ['TIME_STR'] = '0' if utils.is_debugging() else '0'
    from template_lib.v2.config_cfgnode.argparser import \
      (get_command_and_outdir, setup_outdir_and_yaml, get_append_cmd_str, start_cmd_run)
    from template_lib.v2.config_cfgnode import update_parser_defaults_from_yaml, global_cfg
    from template_lib.modelarts import modelarts_utils

    this_func = sys._getframe().f_code.co_name
    command, outdir = get_command_and_outdir(self, func_name=this_func, file=__file__)
    argv_str = f"""
                --tl_config_file template_lib/proj/imagenet/tests/configs/PrepareImageNet.yaml
                --tl_command {command}
                --tl_outdir {outdir}
                """
    args, cfg = setup_outdir_and_yaml(argv_str, return_cfg=True)
    # Make both the YAML config and CLI args visible through global_cfg.
    global_cfg.merge_from_dict(cfg)
    global_cfg.merge_from_dict(vars(args))

    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

    src_root = cfg.data_dir
    n_classes_seen = 0
    for cur_dir, child_dirs, filenames in os.walk(f'{src_root}/train'):
      # Only leaf directories (one per class) contain the images to sample.
      if child_dirs:
        continue
      n_classes_seen += 1
      picked = sorted(filenames)[:cfg.num_per_class]
      for fname in tqdm.tqdm(picked, desc=f'class: {n_classes_seen}'):
        src_path = os.path.join(cur_dir, fname)
        rel_dir = os.path.dirname(os.path.relpath(src_path, src_root))
        dst_dir = f'{cfg.saved_dir}/{rel_dir}'
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copy(src_path, dst_dir)

    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_upload', {}), global_cfg=global_cfg, download=False)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
Ejemplo n.º 3
0
def main():
  """Smoke-test an `ImageListDataset` data loader, optionally under distributed launch.

  Only the process with `local_rank == 0` syncs/downloads from ModelArts before the run
  and uploads afterwards. A single batch is drawn from the loader to check the pipeline.
  In distributed mode, all ranks meet at a final barrier.
  """
  parser = build_parser()
  early_args, _ = parser.parse_known_args()
  main_proc = (early_args.local_rank == 0)

  update_parser_defaults_from_yaml(parser, is_main_process=main_proc)

  if main_proc:
    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

  # Re-parse now that YAML-provided defaults are installed on the parser.
  args = parser.parse_args()
  setup_runtime(seed=args.seed)

  use_dist = ddp_utils.is_distributed()
  if use_dist:
    dist_utils.init_dist(args.launcher, backend='nccl')
    # Give each rank its own random seed.
    torch.manual_seed(args.seed + dist.get_rank())

  dataset = torch_data_utils.ImageListDataset(meta_file=global_cfg.image_list_file)
  sampler = (torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
             if use_dist else None)

  loader = data_utils.DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 sampler=sampler,
                                 num_workers=args.num_workers,
                                 pin_memory=False)

  # Pull one batch; the iterator is kept bound so its workers live until return.
  batch_iter = iter(loader)
  first_batch = next(batch_iter)

  if main_proc:
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_upload', {}), global_cfg=global_cfg, download=False)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
  if use_dist:
    dist.barrier()
Ejemplo n.º 4
0
    def create_optimizers(self, opt):
        """Create Adam optimizers for the generator side and the discriminator.

        The generator optimizer covers ``netG`` (plus ``netE`` when ``opt.use_vae``).
        Learning rates are ``opt.lr`` for both nets when ``opt.no_TTUR``. Otherwise
        TTUR applies: G uses ``lr / 2`` and D uses ``lr * 2``. Extra optimizer keyword
        arguments come from ``global_cfg['G_optim_cfg']`` / ``global_cfg['D_optim_cfg']``.

        Args:
            opt: options object; reads ``use_vae``, ``isTrain``, ``beta1``, ``beta2``,
                ``no_TTUR`` and ``lr``.

        Returns:
            ``(optimizer_G, optimizer_D)``. ``optimizer_D`` is ``None`` when
            ``opt.isTrain`` is false, because no discriminator parameters are
            collected in that case (previously this path raised UnboundLocalError).
        """
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2),
                                       **global_cfg.get('G_optim_cfg', {}))

        optimizer_D = None
        if opt.isTrain:
            D_params = list(self.netD.parameters())
            optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2),
                                           **global_cfg.get('D_optim_cfg', {}))

        return optimizer_G, optimizer_D
Ejemplo n.º 5
0
def main():
    """Parse arguments, prepare ModelArts I/O, then launch ``main_worker``.

    With ``--multiprocessing-distributed``, one process per visible GPU is spawned via
    ``mp.spawn`` and ``world_size`` is scaled by the GPU count. Otherwise
    ``main_worker`` runs once in this process.
    """
    update_parser_defaults_from_yaml(parser)
    args = parser.parse_args()
    global_cfg.merge_from_dict(vars(args))

    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    download_cfg = global_cfg.get('modelarts_download', {})
    modelarts_utils.prepare_dataset(download_cfg, global_cfg=global_cfg)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # Fall back to the launcher-provided WORLD_SIZE when using env:// init.
    needs_env_world_size = args.dist_url == "env://" and args.world_size == -1
    if needs_env_world_size:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    gpu_count = torch.cuda.device_count()
    if not args.multiprocessing_distributed:
        # Single-process path: run the worker directly.
        main_worker(args.gpu, gpu_count, args)
        return

    # One process per GPU on each node, so the global world size grows accordingly.
    args.world_size = gpu_count * args.world_size
    mp.spawn(main_worker,
             nprocs=gpu_count,
             args=(gpu_count, args))
Ejemplo n.º 6
0
def run(config):
    """Train a BigGAN-style generator/discriminator pair described by ``config``.

    ``config`` is a dict of hyperparameters. This function first writes derived entries
    into it ('resolution', 'n_classes', 'G_activation', 'D_activation', and
    'skip_init' when resuming). It then rebinds ``config`` to the result of
    ``utils.update_config_roots``. Some options are also read from the
    module-level ``global_cfg``: 'use_official_eval' (default True),
    'inception_file' and 'test_every_epoch'.

    Training runs from ``state_dict['epoch']`` up to ``config['num_epochs']``.
    Every ``config['save_every']`` iterations weights and samples are saved via
    ``train_fns.save_and_sample``. Evaluation via ``train_fns.test`` runs every
    ``config['test_every']`` iterations, at iteration 1, and every
    ``test_every_epoch`` epochs when that key is set in ``global_cfg``.
    Returns None.
    """
    logger = logging.getLogger('tl')
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        # The EMA copy skips weight init and does not build its own optimizer.
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    logger.info(G)
    logger.info(D)
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    # get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])
    if global_cfg.get('use_official_eval', True):
        # get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])
        get_inception_metrics = inception_utils.prepare_FID_IS(global_cfg)
    else:
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            global_cfg.inception_file, config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    # NOTE(review): the loop below unpacks `metrics, default_dict = train(x, y)`;
    # confirm dummy_training_function's inner fn returns a 2-tuple as well.
    else:
        train = train_fns.dummy_training_function()

    # Prepare Sample function for use with inception metrics
    # (the official-eval path samples images only; the legacy path also needs labels)
    if global_cfg.get('use_official_eval', True):
        return_y = False
    else:
        return_y = True
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config,
        return_y=return_y)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics, default_dict = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            summary_defaultdict2txtfig(default_dict,
                                       prefix='train',
                                       step=state_dict['itr'],
                                       textlogger=global_textlogger)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval, on the very first iteration, and every
            # `test_every_epoch` epochs (measured in loader iterations). When
            # 'test_every_epoch' is unset the divisor is inf; itr % inf == itr, so
            # that last term is never true for itr >= 1.
            if not (state_dict['itr'] % config['test_every']) or state_dict['itr']==1 or \
                  state_dict['itr'] % (global_cfg.get('test_every_epoch', float('inf')) * len(loaders[0])) == 0:
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Ejemplo n.º 7
0
    def __init__(self,
                 G_ch=64,
                 dim_z=128,
                 bottom_width=4,
                 resolution=128,
                 G_kernel_size=3,
                 G_attn='64',
                 n_classes=1000,
                 num_G_SVs=1,
                 num_G_SV_itrs=1,
                 G_shared=True,
                 shared_dim=0,
                 hier=False,
                 cross_replica=False,
                 mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5,
                 G_B1=0.0,
                 G_B2=0.999,
                 adam_eps=1e-8,
                 BN_eps=1e-5,
                 SN_eps=1e-12,
                 G_mixed_precision=False,
                 G_fp16=False,
                 G_init='ortho',
                 skip_init=False,
                 no_optim=False,
                 G_param='SN',
                 norm_style='bn',
                 **kwargs):
        """Build a BigGAN-style generator and, unless ``no_optim``, its optimizer.

        Args:
            G_ch: channel width multiplier passed to ``G_arch``.
            dim_z: latent dimensionality; when ``hier`` is set it is rounded down
                to a multiple of the number of z slots.
            bottom_width: spatial size of the first feature map produced by ``self.linear``.
            resolution: output resolution; selects the ``G_arch`` entry.
            G_kernel_size: stored as ``self.kernel_size``; the convs below still use
                a hard-coded ``kernel_size=3``.
            G_attn: attention specification passed to ``G_arch``.
            n_classes: number of classes for the class embedding.
            num_G_SVs, num_G_SV_itrs, SN_eps: spectral-norm settings (used when
                ``G_param == 'SN'``).
            G_shared: use one shared class embedding for all conditional BNs.
            shared_dim: shared-embedding size; ``dim_z`` is used when it is <= 0.
            hier: split z hierarchically across the blocks.
            cross_replica, mybn, BN_eps, norm_style: batch-norm options.
            G_activation: nonlinearity module for residual blocks and the output layer.
            G_lr, G_B1, G_B2, adam_eps: optimizer hyperparameters.
            G_mixed_precision: use ``utils.Adam16`` instead of the regular optimizer.
            G_fp16: stored as ``self.fp16``.
            G_init: initialization style, stored as ``self.init``.
            skip_init: do not call ``self.init_weights()``.
            no_optim: do not create ``self.optim`` (e.g. for an EMA copy).
            G_param: 'SN' for spectral-normed layers; anything else uses plain
                ``nn.Conv2d`` / ``nn.Linear``.
            **kwargs: extra config entries, ignored.

        Raises:
            ValueError: when an optimizer is built (no ``no_optim``, no mixed
                precision) and ``global_cfg.optim_type`` is neither 'adam' nor 'adams'.
        """
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_G_SVs,
                                                num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs,
                                                  num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d,
                                                kernel_size=3,
                                                padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False)
                     if self.G_shared else self.which_embedding)
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=(self.shared_dim + self.z_chunk_size
                        if self.G_shared else self.n_classes),
            norm_style=self.norm_style,
            eps=self.BN_eps)

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim)
                       if G_shared else layers.identity())
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch['in_channels'][0] * (self.bottom_width**2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[
                layers.GBlock(
                    in_channels=self.arch['in_channels'][index],
                    out_channels=self.arch['out_channels'][index],
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                              if self.arch['upsample'][index] else None))
            ]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' %
                      self.arch['resolution'][index])
                self.blocks[-1] += [
                    layers.Attention(self.arch['out_channels'][index],
                                     self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            layers.bn(self.arch['out_channels'][-1],
                      cross_replica=self.cross_replica,
                      mybn=self.mybn), self.activation,
            self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=0,
                                      eps=self.adam_eps)
        else:
            # A missing Generator section or a missing/null weight_decay falls back
            # to 0 (Adam's default). Passing None on to the optimizer raises TypeError.
            weight_decay = global_cfg.get('Generator', {}).get('weight_decay') or 0
            optim_type = global_cfg.get('optim_type', 'adam')
            if optim_type.lower() == 'adam':
                self.optim = optim.Adam(params=self.parameters(),
                                        lr=self.lr,
                                        betas=(self.B1, self.B2),
                                        weight_decay=weight_decay,
                                        eps=self.adam_eps)
            elif optim_type.lower() == 'adams':
                from .adams import AdamS
                self.optim = AdamS(params=self.parameters(),
                                   lr=self.lr,
                                   betas=(self.B1, self.B2),
                                   weight_decay=weight_decay,
                                   eps=self.adam_eps,
                                   amsgrad=False)
            else:
                # Explicit raise (not `assert`), so the check survives `python -O`.
                raise ValueError(f'Unsupported optim_type for Generator: {optim_type!r}')
Ejemplo n.º 8
0
    def __init__(self,
                 D_ch=64,
                 D_wide=True,
                 resolution=128,
                 D_kernel_size=3,
                 D_attn='64',
                 n_classes=1000,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4,
                 D_B1=0.0,
                 D_B2=0.999,
                 adam_eps=1e-8,
                 SN_eps=1e-12,
                 output_dim=1,
                 D_mixed_precision=False,
                 D_fp16=False,
                 D_init='ortho',
                 skip_init=False,
                 D_param='SN',
                 **kwargs):
        """Build a BigGAN-style discriminator and its optimizer.

        Args:
            D_ch: channel width multiplier passed to ``D_arch``.
            D_wide: wide D (BigGAN / SA-GAN) versus skinny D (SN-GAN).
            resolution: input resolution; selects the ``D_arch`` entry.
            D_kernel_size: stored as ``self.kernel_size``; the convs below still use
                a hard-coded ``kernel_size=3``.
            D_attn: attention specification passed to ``D_arch``.
            n_classes: number of classes.
            num_D_SVs, num_D_SV_itrs, SN_eps: spectral-norm settings.
            D_activation: nonlinearity module used in the blocks.
            D_lr, D_B1, D_B2, adam_eps: optimizer hyperparameters.
            output_dim: NOT used. The output width of ``self.linear`` comes from
                ``global_cfg.Discriminator.output_dim`` when set, else ``n_classes + 2``.
            D_mixed_precision: use ``utils.Adam16`` instead of the regular optimizer.
            D_fp16: stored as ``self.fp16``.
            D_init: initialization style, stored as ``self.init``.
            skip_init: do not call ``self.init_weights()``.
            D_param: 'SN' for spectral-normed layers; anything else uses plain
                ``nn.Conv2d`` / ``nn.Linear`` / ``nn.Embedding``.
            **kwargs: extra config entries, ignored.

        Raises:
            ValueError: when not using mixed precision and ``global_cfg.optim_type``
                is neither 'adam' nor 'adams'.
        """
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Discriminator-specific options from the global config. A missing section
        # or a missing/null weight_decay falls back to 0 (Adam's default), matching
        # the Generator.
        d_cfg = global_cfg.get('Discriminator', {})
        self.weight_decay = d_cfg.get('weight_decay') or 0

        # Which convs, batchnorms, and linear layers to use
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3,
                                                padding=1,
                                                num_svs=num_D_SVs,
                                                num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs,
                                                  num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs,
                                                     num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        else:
            # Plain (non-spectral-normed) layers, mirroring the Generator. Previously
            # these attributes were left undefined, so the construction below failed.
            self.which_conv = functools.partial(nn.Conv2d,
                                                kernel_size=3,
                                                padding=1)
            self.which_linear = nn.Linear
            self.which_embedding = nn.Embedding
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[
                layers.DBlock(
                    in_channels=self.arch['in_channels'][index],
                    out_channels=self.arch['out_channels'][index],
                    which_conv=self.which_conv,
                    wide=self.D_wide,
                    activation=self.activation,
                    preactivation=(index > 0),
                    downsample=(nn.AvgPool2d(2)
                                if self.arch['downsample'][index] else None))
            ]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' %
                      self.arch['resolution'][index])
                self.blocks[-1] += [
                    layers.Attention(self.arch['out_channels'][index],
                                     self.which_conv)
                ]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The width comes from global_cfg.Discriminator.output_dim
        # when present, else n_classes + 2. The `output_dim` argument is ignored.
        output_dim = d_cfg.get('output_dim', self.n_classes + 2)
        self.linear = self.which_linear(self.arch['out_channels'][-1],
                                        output_dim)

        # Embedding for projection discrimination
        # self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(),
                                      lr=self.lr,
                                      betas=(self.B1, self.B2),
                                      weight_decay=self.weight_decay,
                                      eps=self.adam_eps)
        else:
            optim_type = global_cfg.get('optim_type', 'adam')
            if optim_type.lower() == 'adam':
                self.optim = optim.Adam(params=self.parameters(),
                                        lr=self.lr,
                                        betas=(self.B1, self.B2),
                                        weight_decay=self.weight_decay,
                                        eps=self.adam_eps)
            elif optim_type.lower() == 'adams':
                from .adams import AdamS
                self.optim = AdamS(params=self.parameters(),
                                   lr=self.lr,
                                   betas=(self.B1, self.B2),
                                   weight_decay=self.weight_decay,
                                   eps=self.adam_eps,
                                   amsgrad=False)
            else:
                # Explicit raise (not `assert`), so the check survives `python -O`.
                raise ValueError(f'Unsupported optim_type for Discriminator: {optim_type!r}')
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0