def test_extract_ImageNet100_CMC(self):
    """
    Usage:
        proj_root=moco-exp
        python template_lib/modelarts/scripts/copy_tool.py \
          -s s3://bucket-7001/ZhouPeng/codes/$proj_root -d /cache/$proj_root -t copytree
        cd /cache/$proj_root

        export CUDA_VISIBLE_DEVICES=0
        export TIME_STR=0
        export PYTHONPATH=./
        python -c "from template_lib.proj.imagenet.tests.test_imagenet import Testing_PrepareImageNet;\
          Testing_PrepareImageNet().test_extract_ImageNet100_CMC()"

    :return:
    """
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if 'TIME_STR' not in os.environ:
        os.environ['TIME_STR'] = '0' if utils.is_debugging() else '0'
    from template_lib.v2.config_cfgnode.argparser import (
        get_command_and_outdir, setup_outdir_and_yaml, get_append_cmd_str, start_cmd_run)
    from template_lib.v2.config_cfgnode import update_parser_defaults_from_yaml, global_cfg
    from template_lib.modelarts import modelarts_utils
    from distutils.dir_util import copy_tree

    command, outdir = get_command_and_outdir(self, func_name=sys._getframe().f_code.co_name, file=__file__)
    argv_str = f"""
                --tl_config_file template_lib/proj/imagenet/tests/configs/PrepareImageNet.yaml
                --tl_command {command}
                --tl_outdir {outdir}
                """
    args, cfg = setup_outdir_and_yaml(argv_str, return_cfg=True)

    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

    train_dir = f'{cfg.data_dir}/train'
    val_dir = f'{cfg.data_dir}/val'
    save_train_dir = f'{cfg.saved_dir}/train'
    save_val_dir = f'{cfg.saved_dir}/val'
    os.makedirs(save_train_dir, exist_ok=True)
    os.makedirs(save_val_dir, exist_ok=True)

    with open(cfg.class_list_file, 'r') as f:
        class_list = f.readlines()
    for class_subdir in tqdm.tqdm(class_list):
        class_subdir, _ = class_subdir.strip().split()
        train_class_dir = f'{train_dir}/{class_subdir}'
        save_train_class_dir = f'{save_train_dir}/{class_subdir}'
        copy_tree(train_class_dir, save_train_class_dir)

        val_class_dir = f'{val_dir}/{class_subdir}'
        save_val_class_dir = f'{save_val_dir}/{class_subdir}'
        copy_tree(val_class_dir, save_val_class_dir)

    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_upload', {}), global_cfg=global_cfg, download=False)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    pass
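# Illustrative note (not part of the original test): cfg.class_list_file is assumed to
# contain one "<wnid> <label>" pair per line, as in the ImageNet-100 list used by CMC.
# A hypothetical line "n01440764 0" yields class_subdir == 'n01440764', so only the
# listed class folders are copied from train/ and val/ into saved_dir.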
def test_extract_ImageNet_1000x50(self):
    """
    Usage:
        proj_root=moco-exp
        python template_lib/modelarts/scripts/copy_tool.py \
          -s s3://bucket-7001/ZhouPeng/codes/$proj_root -d /cache/$proj_root -t copytree
        cd /cache/$proj_root

        export CUDA_VISIBLE_DEVICES=0
        export TIME_STR=0
        export PYTHONPATH=./
        python -c "from template_lib.proj.imagenet.tests.test_imagenet import Testing_PrepareImageNet;\
          Testing_PrepareImageNet().test_extract_ImageNet_1000x50()"

    :return:
    """
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if 'TIME_STR' not in os.environ:
        os.environ['TIME_STR'] = '0' if utils.is_debugging() else '0'
    from template_lib.v2.config_cfgnode.argparser import (
        get_command_and_outdir, setup_outdir_and_yaml, get_append_cmd_str, start_cmd_run)
    from template_lib.v2.config_cfgnode import update_parser_defaults_from_yaml, global_cfg
    from template_lib.modelarts import modelarts_utils

    command, outdir = get_command_and_outdir(self, func_name=sys._getframe().f_code.co_name, file=__file__)
    argv_str = f"""
                --tl_config_file template_lib/proj/imagenet/tests/configs/PrepareImageNet.yaml
                --tl_command {command}
                --tl_outdir {outdir}
                """
    args, cfg = setup_outdir_and_yaml(argv_str, return_cfg=True)
    global_cfg.merge_from_dict(cfg)
    global_cfg.merge_from_dict(vars(args))

    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

    train_dir = f'{cfg.data_dir}/train'
    counter_cls = 0
    for rootdir, subdir, files in os.walk(train_dir):
        if len(subdir) == 0:
            counter_cls += 1
            extracted_files = sorted(files)[:cfg.num_per_class]
            for file in tqdm.tqdm(extracted_files, desc=f'class: {counter_cls}'):
                img_path = os.path.join(rootdir, file)
                img_rel_path = os.path.relpath(img_path, cfg.data_dir)
                saved_img_path = f'{cfg.saved_dir}/{os.path.dirname(img_rel_path)}'
                os.makedirs(saved_img_path, exist_ok=True)
                shutil.copy(img_path, saved_img_path)
        pass

    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_upload', {}), global_cfg=global_cfg, download=False)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    pass
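# Illustrative note (not part of the original test): with a hypothetical
# cfg.num_per_class == 50, the walk above keeps the first 50 images (sorted by filename)
# of every leaf class directory, e.g. a file at
#   {cfg.data_dir}/train/n01440764/n01440764_10026.JPEG
# would be copied to
#   {cfg.saved_dir}/train/n01440764/n01440764_10026.JPEG
# giving a 1000-class x 50-images-per-class subset; the paths shown are examples only.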
def main():
    parser = build_parser()
    args, _ = parser.parse_known_args()
    is_main_process = args.local_rank == 0

    update_parser_defaults_from_yaml(parser, is_main_process=is_main_process)

    if is_main_process:
        modelarts_utils.setup_tl_outdir_obs(global_cfg)
        modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
        modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

    args = parser.parse_args()
    setup_runtime(seed=args.seed)

    distributed = ddp_utils.is_distributed()
    if distributed:
        dist_utils.init_dist(args.launcher, backend='nccl')
        # important: use different random seed for different process
        torch.manual_seed(args.seed + dist.get_rank())

    # dataset
    dataset = torch_data_utils.ImageListDataset(meta_file=global_cfg.image_list_file)
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        sampler = None
    train_loader = data_utils.DataLoader(
        dataset, batch_size=1, shuffle=False, sampler=sampler,
        num_workers=args.num_workers, pin_memory=False)

    # test
    data_iter = iter(train_loader)
    data = next(data_iter)

    if is_main_process:
        modelarts_utils.prepare_dataset(global_cfg.get('modelarts_upload', {}), global_cfg=global_cfg, download=False)
        modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)

    if distributed:
        dist.barrier()
    pass
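# Illustrative note (not part of the original script): with shuffle=False the
# DistributedSampler assigns each rank every world_size-th index of the image list
# (padding the tail so all shards are equal length), so each process iterates its own
# slice of global_cfg.image_list_file with batch_size=1; dist.barrier() at the end
# keeps the ranks synchronized before exit.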
def create_optimizers(self, opt):
    G_params = list(self.netG.parameters())
    if opt.use_vae:
        G_params += list(self.netE.parameters())
    if opt.isTrain:
        D_params = list(self.netD.parameters())

    beta1, beta2 = opt.beta1, opt.beta2
    if opt.no_TTUR:
        G_lr, D_lr = opt.lr, opt.lr
    else:
        G_lr, D_lr = opt.lr / 2, opt.lr * 2

    optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2),
                                   **global_cfg.get('G_optim_cfg', {}))
    optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2),
                                   **global_cfg.get('D_optim_cfg', {}))

    return optimizer_G, optimizer_D
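# Minimal usage sketch (not from the original repo; the SimpleNamespace fields are
# hypothetical stand-ins for the real option object). It shows the TTUR split applied
# above: when no_TTUR is False, G trains at lr/2 and D at lr*2.
#
#   from types import SimpleNamespace
#   opt = SimpleNamespace(isTrain=True, use_vae=False, no_TTUR=False,
#                         lr=2e-4, beta1=0.0, beta2=0.9)
#   optimizer_G, optimizer_D = model.create_optimizers(opt)
#   # G lr == 1e-4, D lr == 4e-4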
def main():
    update_parser_defaults_from_yaml(parser)
    args = parser.parse_args()
    global_cfg.merge_from_dict(vars(args))

    modelarts_utils.setup_tl_outdir_obs(global_cfg)
    modelarts_utils.modelarts_sync_results_dir(global_cfg, join=True)
    modelarts_utils.prepare_dataset(global_cfg.get('modelarts_download', {}), global_cfg=global_cfg)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
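# Illustrative note (not from the original script): args.world_size initially counts
# nodes; with multiprocessing_distributed it is rescaled to count processes, e.g. a
# hypothetical 2 nodes with 4 GPUs per node gives world_size = 4 * 2 = 8, and mp.spawn
# starts one main_worker per local GPU, passing that GPU's index as its first argument.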
def run(config):
    logger = logging.getLogger('tl')

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    logger.info(G)
    logger.info(D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, state_dict,
                           config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    # get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])
    if global_cfg.get('use_official_eval', True):
        # get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])
        get_inception_metrics = inception_utils.prepare_FID_IS(global_cfg)
    else:
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            global_cfg.inception_file, config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                         device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()

    # Prepare Sample function for use with inception metrics
    if global_cfg.get('use_official_eval', True):
        return_y = False
    else:
        return_y = True
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config, return_y=return_y)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics, default_dict = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            summary_defaultdict2txtfig(default_dict, prefix='train',
                                       step=state_dict['itr'],
                                       textlogger=global_textlogger)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']) or state_dict['itr'] == 1 or \
                    state_dict['itr'] % (global_cfg.get('test_every_epoch', float('inf')) * len(loaders[0])) == 0:
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name, test_log)

        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
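# Illustrative note (not part of the original training script): the loader batch size
# set above is D_batch_size = batch_size * num_D_steps * num_D_accumulations, so one
# loader iteration supplies every D update of a single G step, e.g. a hypothetical
# batch_size=256, num_D_steps=2, num_D_accumulations=8 fetches 4096 samples per iteration.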
def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
             G_kernel_size=3, G_attn='64', n_classes=1000,
             num_G_SVs=1, num_G_SV_itrs=1,
             G_shared=True, shared_dim=0, hier=False,
             cross_replica=False, mybn=False,
             G_activation=nn.ReLU(inplace=False),
             G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
             BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
             G_init='ortho', skip_init=False, no_optim=False,
             G_param='SN', norm_style='bn',
             **kwargs):
    super(Generator, self).__init__()
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

    # If using hierarchical latents, adjust z
    if self.hier:
        # Number of places z slots into
        self.num_slots = len(self.arch['in_channels']) + 1
        self.z_chunk_size = (self.dim_z // self.num_slots)
        # Recalculate latent dimensionality for even splitting into chunks
        self.dim_z = self.z_chunk_size * self.num_slots
    else:
        self.num_slots = 1
        self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                              eps=self.SN_eps)
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear

    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False)
                 if self.G_shared else self.which_embedding)
    self.which_bn = functools.partial(
        layers.ccbn,
        which_linear=bn_linear,
        cross_replica=self.cross_replica,
        mybn=self.mybn,
        input_size=(self.shared_dim + self.z_chunk_size
                    if self.G_shared else self.n_classes),
        norm_style=self.norm_style,
        eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim)
                   if G_shared else layers.identity())
    # First linear layer
    self.linear = self.which_linear(
        self.dim_z // self.num_slots,
        self.arch['in_channels'][0] * (self.bottom_width ** 2))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[
            layers.GBlock(in_channels=self.arch['in_channels'][index],
                          out_channels=self.arch['out_channels'][index],
                          which_conv=self.which_conv,
                          which_bn=self.which_bn,
                          activation=self.activation,
                          upsample=(functools.partial(F.interpolate, scale_factor=2)
                                    if self.arch['upsample'][index] else None))
        ]]
        # If attention on this block, attach it to the end
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
            self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(
        layers.bn(self.arch['out_channels'][-1],
                  cross_replica=self.cross_replica,
                  mybn=self.mybn),
        self.activation,
        self.which_conv(self.arch['out_channels'][-1], 3))

    # Initialize weights. Optionally skip init for testing.
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    # If this is an EMA copy, no need for an optim, so just return now
    if no_optim:
        return
    self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
    if G_mixed_precision:
        print('Using fp16 adam in G...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps)
    else:
        weight_decay = global_cfg.Generator.get('weight_decay')
        optim_type = global_cfg.get('optim_type', 'adam')
        if optim_type.lower() == 'adam':
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=weight_decay,
                                    eps=self.adam_eps)
        elif optim_type.lower() == 'adams':
            from .adams import AdamS
            self.optim = AdamS(params=self.parameters(), lr=self.lr,
                               betas=(self.B1, self.B2),
                               weight_decay=weight_decay,
                               eps=self.adam_eps, amsgrad=False)
        else:
            assert 0
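# Illustrative note (not part of the original model code): with hier=True the latent is
# split evenly across num_slots = len(arch['in_channels']) + 1 places; e.g. for a
# hypothetical dim_z=128 and 5 input stages, num_slots=6, z_chunk_size = 128 // 6 = 21,
# and dim_z is rounded down to 21 * 6 = 126 so the first linear layer and every block
# each receive an equal 21-dim chunk of z.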
def __init__(self, D_ch=64, D_wide=True, resolution=128,
             D_kernel_size=3, D_attn='64', n_classes=1000,
             num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
             D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
             SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
             D_init='ortho', skip_init=False, D_param='SN', **kwargs):
    super(Discriminator, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention)[resolution]

    self.weight_decay = global_cfg.Discriminator.weight_decay

    # Which convs, batchnorms, and linear layers to use
    # No option to turn off SN in D right now
    if self.D_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                              eps=self.SN_eps)
        self.which_embedding = functools.partial(layers.SNEmbedding,
                                                 num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                 eps=self.SN_eps)

    # Prepare model
    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[
            layers.DBlock(in_channels=self.arch['in_channels'][index],
                          out_channels=self.arch['out_channels'][index],
                          which_conv=self.which_conv,
                          wide=self.D_wide,
                          activation=self.activation,
                          preactivation=(index > 0),
                          downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))
        ]]
        # If attention on this block, attach it to the end
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
            self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g.
    # turning this into a VAE with an inference output.
    output_dim = global_cfg.Discriminator.output_dim if 'output_dim' in global_cfg.Discriminator \
        else self.n_classes + 2
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)

    # Embedding for projection discrimination
    # self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

    # Initialize weights
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
        print('Using fp16 adam in D...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2),
                                  weight_decay=self.weight_decay,
                                  eps=self.adam_eps)
    else:
        optim_type = global_cfg.get('optim_type', 'adam')
        if optim_type.lower() == 'adam':
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2),
                                    weight_decay=self.weight_decay,
                                    eps=self.adam_eps)
        elif optim_type.lower() == 'adams':
            from .adams import AdamS
            self.optim = AdamS(params=self.parameters(), lr=self.lr,
                               betas=(self.B1, self.B2),
                               weight_decay=self.weight_decay,
                               eps=self.adam_eps, amsgrad=False)
        else:
            assert 0
    # LR scheduling, left here for forward compatibility
    # self.lr_sched = {'itr' : 0}# if self.progressive else {}
    # self.j = 0
    pass
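# Illustrative note (not part of the original model code): unlike the stock BigGAN head
# (output_dim=1 plus a projection embedding, commented out above), this variant defaults
# the linear output to n_classes + 2 units unless global_cfg.Discriminator.output_dim is
# set, e.g. a hypothetical n_classes=1000 yields a 1002-way output.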