# Example 1
def load_configs_initialize_training():
    """Parse command-line options, merge them with a YAML config, and
    prepare everything needed to launch training/evaluation.

    Returns:
        tuple: ``(cfgs, gpus_per_node, run_name, hdf5_path, rank)`` where
            ``cfgs`` is the merged Configurations object, ``gpus_per_node``
            is the number of visible CUDA devices, ``run_name`` identifies
            this run for logging/checkpointing, ``hdf5_path`` is the
            optional pre-built HDF5 dataset path (None when not used), and
            ``rank`` is the current CUDA device index.

    Exits via ``sys.exit(1)`` (after printing help) when no actionable
    task was requested on the command line.
    """
    parser = ArgumentParser(add_help=True)
    parser.add_argument("--entity",
                        type=str,
                        default=None,
                        help="entity for wandb logging")
    parser.add_argument("--project",
                        type=str,
                        default=None,
                        help="project name for wandb logging")

    parser.add_argument("-cfg",
                        "--cfg_file",
                        type=str,
                        default="./src/configs/CIFAR10/ContraGAN.yaml")
    parser.add_argument("-data", "--data_dir", type=str, default=None)
    parser.add_argument("-save", "--save_dir", type=str, default="./")
    parser.add_argument("-ckpt", "--ckpt_dir", type=str, default=None)
    parser.add_argument("-best",
                        "--load_best",
                        action="store_true",
                        help="load the best performed checkpoint")

    parser.add_argument("--seed",
                        type=int,
                        default=-1,
                        help="seed for generating random numbers")
    parser.add_argument("-DDP",
                        "--distributed_data_parallel",
                        action="store_true")
    # NOTE: the original help text contained "\in" (a stray LaTeX escape),
    # which is an invalid Python string escape; fixed to plain "in".
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="cuda backend for DDP training in ['nccl', 'gloo']")
    parser.add_argument("-tn",
                        "--total_nodes",
                        default=1,
                        type=int,
                        help="total number of nodes for training")
    parser.add_argument("-cn",
                        "--current_node",
                        default=0,
                        type=int,
                        help="rank of the current node")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("-sync_bn",
                        "--synchronized_bn",
                        action="store_true",
                        help="turn on synchronized batchnorm")
    parser.add_argument("-mpc",
                        "--mixed_precision",
                        action="store_true",
                        help="turn on mixed precision training")

    parser.add_argument("--truncation_factor",
                        type=float,
                        default=-1.0,
                        help="truncation factor for applying truncation trick "
                             "(-1.0 means not applying truncation trick)")
    parser.add_argument("--truncation_cutoff",
                        type=float,
                        default=None,
                        help="truncation cutoff for stylegan "
                             "(apply truncation for only w[:truncation_cutoff]")
    parser.add_argument(
        "-batch_stat",
        "--batch_statistics",
        action="store_true",
        help="use the statistics of a batch when evaluating GAN "
             "(if false, use the moving average updated statistics)")
    parser.add_argument("-std_stat",
                        "--standing_statistics",
                        action="store_true",
                        help="apply standing statistics for evaluation")
    parser.add_argument(
        "-std_max",
        "--standing_max_batch",
        type=int,
        default=-1,
        help="maximum batch_size for calculating standing statistics "
             "(-1.0 means not applying standing statistics trick for evaluation)"
    )
    parser.add_argument("-std_step",
                        "--standing_step",
                        type=int,
                        default=-1,
                        help="# of steps for standing statistics "
                             "(-1.0 means not applying standing statistics trick for evaluation)"
                        )
    parser.add_argument(
        "--freezeD",
        type=int,
        default=-1,
        help="# of freezed blocks in the discriminator for transfer learning")

    # parser arguments to apply langevin sampling for GAN evaluation
    # In the arguments regarding 'decay', -1 means not applying the decay trick by default
    parser.add_argument(
        "-lgv",
        "--langevin_sampling",
        action="store_true",
        help=
        "apply langevin sampling to generate images from a Energy-Based Model")
    # NOTE: "\epsilon" was an invalid escape sequence; fixed to "epsilon".
    parser.add_argument(
        "-lgv_rate",
        "--langevin_rate",
        type=float,
        default=-1,
        help="an initial update rate for langevin sampling (epsilon)")
    parser.add_argument(
        "-lgv_std",
        "--langevin_noise_std",
        type=float,
        default=-1,
        help=
        "standard deviation of a gaussian noise used in langevin sampling (std of n_i)"
    )
    parser.add_argument(
        "-lgv_decay",
        "--langevin_decay",
        type=float,
        default=-1,
        help="decay strength for langevin_rate and langevin_noise_std")
    parser.add_argument(
        "-lgv_decay_steps",
        "--langevin_decay_steps",
        type=int,
        default=-1,
        help=
        "langevin_rate and langevin_noise_std decrease every 'langevin_decay_steps'"
    )
    parser.add_argument("-lgv_steps",
                        "--langevin_steps",
                        type=int,
                        default=-1,
                        help="total steps of langevin sampling")

    parser.add_argument("-t", "--train", action="store_true")
    parser.add_argument("-hdf5",
                        "--load_train_hdf5",
                        action="store_true",
                        help="load train images from a hdf5 file for fast I/O")
    parser.add_argument(
        "-l",
        "--load_data_in_memory",
        action="store_true",
        help="put the whole train dataset on the main memory for fast I/O")
    parser.add_argument(
        "-metrics",
        "--eval_metrics",
        nargs='+',
        default=['fid'],
        help=
        "evaluation metrics to use during training, a subset list of ['fid', 'is', 'prdc'] or none"
    )
    parser.add_argument("--num_eval",
                        type=int,
                        default=1,
                        help="number of runs for final evaluation.")
    parser.add_argument(
        "--resize_fn",
        type=str,
        default="legacy",
        help="which mode to use PIL.bicubic resizing for calculating clean metrics "
             "in ['legacy', 'clean']")
    parser.add_argument("-sr",
                        "--save_real_images",
                        action="store_true",
                        help="save images sampled from the reference dataset")
    parser.add_argument("-sf",
                        "--save_fake_images",
                        action="store_true",
                        help="save fake images generated by the GAN.")
    parser.add_argument("-sf_num",
                        "--save_fake_images_num",
                        type=int,
                        default=1,
                        help="number of fake images to save")
    parser.add_argument("-v",
                        "--vis_fake_images",
                        action="store_true",
                        help="visualize image canvas")
    parser.add_argument("-knn",
                        "--k_nearest_neighbor",
                        action="store_true",
                        help="conduct k-nearest neighbor analysis")
    parser.add_argument("-itp",
                        "--interpolation",
                        action="store_true",
                        help="conduct interpolation analysis")
    parser.add_argument("-fa",
                        "--frequency_analysis",
                        action="store_true",
                        help="conduct frequency analysis")
    parser.add_argument("-tsne",
                        "--tsne_analysis",
                        action="store_true",
                        help="conduct tsne analysis")
    parser.add_argument("-ifid",
                        "--intra_class_fid",
                        action="store_true",
                        help="calculate intra-class fid")
    parser.add_argument('--GAN_train',
                        action='store_true',
                        help="whether to calculate CAS (Recall)")
    parser.add_argument('--GAN_test',
                        action='store_true',
                        help="whether to calculate CAS (Precision)")
    parser.add_argument('-resume_ct',
                        '--resume_classifier_train',
                        action='store_true',
                        help="whether to resume classifier training for CAS")
    parser.add_argument("-sefa",
                        "--semantic_factorization",
                        action="store_true",
                        help="perform semantic (closed-form) factorization")
    parser.add_argument("-sefa_axis",
                        "--num_semantic_axis",
                        type=int,
                        default=-1,
                        help="number of semantic axis for sefa")
    parser.add_argument(
        "-sefa_max",
        "--maximum_variations",
        type=float,
        default=-1,
        help="interpolate between z and z + maximum_variations*eigen-vector")
    parser.add_argument(
        "-empty_cache",
        "--empty_cache",
        action="store_true",
        help="empty cuda caches after training step of generator and discriminator, "
             "slightly reduces memory usage but slows training speed. (not recommended for normal use)"
    )

    parser.add_argument("--print_every",
                        type=int,
                        default=100,
                        help="logging interval")
    parser.add_argument("-every",
                        "--save_every",
                        type=int,
                        default=2000,
                        help="save interval")
    parser.add_argument('--eval_backbone',
                        type=str,
                        default='Inception_V3',
                        help="[SwAV, Inception_V3]")
    parser.add_argument(
        "-ref",
        "--ref_dataset",
        type=str,
        default="train",
        help="reference dataset for evaluation[train/valid/test]")
    parser.add_argument(
        "--is_ref_dataset",
        action="store_true",
        help="whether to calculate an inception score of the ref dataset.")
    args = parser.parse_args()
    run_cfgs = vars(args)

    # Require at least one actionable task; otherwise show help and abort.
    # (With the default eval_metrics=['fid'], "none" is absent so training-free
    # evaluation runs are still allowed.)
    no_task_requested = (not args.train
                         and "none" in args.eval_metrics
                         and not any([args.save_real_images,
                                      args.save_fake_images,
                                      args.vis_fake_images,
                                      args.k_nearest_neighbor,
                                      args.interpolation,
                                      args.frequency_analysis,
                                      args.tsne_analysis,
                                      args.intra_class_fid,
                                      args.GAN_train,
                                      args.GAN_test,
                                      args.semantic_factorization]))
    if no_task_requested:
        parser.print_help(sys.stderr)
        sys.exit(1)

    gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device(
    )

    # Merge CLI options into the YAML-driven configuration under the RUN key.
    cfgs = config.Configurations(args.cfg_file)
    cfgs.update_cfgs(run_cfgs, super="RUN")
    cfgs.OPTIMIZATION.world_size = gpus_per_node * cfgs.RUN.total_nodes
    cfgs.check_compatability()

    run_name = log.make_run_name(
        RUN_NAME_FORMAT,
        data_name=cfgs.DATA.name,
        framework=cfgs.RUN.cfg_file.split("/")[-1][:-5],
        phase="train")

    # Datasets listed in MISC.no_proc_data skip cropping/resizing entirely.
    crop_long_edge = False if cfgs.DATA.name in cfgs.MISC.no_proc_data else True
    resize_size = None if cfgs.DATA.name in cfgs.MISC.no_proc_data else cfgs.DATA.img_size
    if cfgs.RUN.load_train_hdf5:
        # Building the HDF5 file may itself adjust the preprocessing flags.
        hdf5_path, crop_long_edge, resize_size = hdf5.make_hdf5(
            name=cfgs.DATA.name,
            img_size=cfgs.DATA.img_size,
            crop_long_edge=crop_long_edge,
            resize_size=resize_size,
            data_dir=cfgs.RUN.data_dir,
            DATA=cfgs.DATA,
            RUN=cfgs.RUN)
    else:
        hdf5_path = None
    cfgs.PRE.crop_long_edge, cfgs.PRE.resize_size = crop_long_edge, resize_size

    misc.prepare_folder(names=cfgs.MISC.base_folders,
                        save_dir=cfgs.RUN.save_dir)
    misc.download_data_if_possible(data_name=cfgs.DATA.name,
                                   data_dir=cfgs.RUN.data_dir)

    # -1 means "no fixed seed": draw a random one and record that it is not fixed.
    if cfgs.RUN.seed == -1:
        cfgs.RUN.seed = random.randint(1, 4096)
        cfgs.RUN.fix_seed = False
    else:
        cfgs.RUN.fix_seed = True

    if cfgs.OPTIMIZATION.world_size == 1:
        print(
            "You have chosen a specific GPU. This will completely disable data parallelism."
        )
    return cfgs, gpus_per_node, run_name, hdf5_path, rank
# Example 2
def main():
    """Entry point: parse CLI options, load the JSON model config, and
    launch training/evaluation (optionally via DDP)."""
    parser = ArgumentParser(add_help=False)
    parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json')
    parser.add_argument('--checkpoint_folder', type=str, default=None)
    parser.add_argument('-current', '--load_current', action='store_true', help='whether you load the current or best checkpoint')
    parser.add_argument('--log_output_path', type=str, default=None)

    parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('-nr', '--nr', default=0, type=int, help='ranking within the nodes')

    parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers')
    parser.add_argument('--num_workers', type=int, default=8, help='')
    parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='whether turn on synchronized batchnorm')
    parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='whether turn on mixed precision training')
    parser.add_argument('-LARS', '--LARS_optimizer', action='store_true', help='whether turn on LARS optimizer')
    parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether disable pytorch autograd debugging mode')

    parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset')
    parser.add_argument('-stat_otf', '--bn_stat_OnTheFly', action='store_true', help='when evaluating, use the statistics of a batch')
    parser.add_argument('-std_stat', '--standing_statistics', action='store_true')
    parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm')
    parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator')

    parser.add_argument('-l', '--load_all_data_in_memory', action='store_true')
    parser.add_argument('-t', '--train', action='store_true')
    parser.add_argument('-e', '--eval', action='store_true')
    parser.add_argument('-s', '--save_images', action='store_true')
    parser.add_argument('-iv', '--image_visualization', action='store_true', help='select whether conduct image visualization')
    parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='select whether conduct k-nearest neighbor analysis')
    parser.add_argument('-itp', '--interpolation', action='store_true', help='whether conduct interpolation analysis')
    parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether conduct frequency analysis')
    parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether conduct tsne analysis')
    parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas')
    parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas')

    parser.add_argument('--print_every', type=int, default=100, help='control log interval')
    parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval')
    parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]')
    opts = parser.parse_args()

    # At least one task flag must be set, else print usage and bail out.
    requested_tasks = (opts.train, opts.eval, opts.save_images,
                       opts.image_visualization, opts.k_nearest_neighbor,
                       opts.interpolation, opts.frequency_analysis,
                       opts.tsne_analysis)
    if not any(requested_tasks):
        parser.print_help(sys.stderr)
        sys.exit(1)

    # Guard clause: a config path is mandatory (argparse supplies a default).
    if opts.config_path is None:
        raise NotImplementedError
    with open(opts.config_path) as cfg_file:
        model_configs = json.load(cfg_file)
    train_configs = vars(opts)

    # Optionally pre-pack the train split into an HDF5 file for fast I/O.
    hdf5_path_train = None
    if train_configs['load_all_data_in_memory']:
        hdf5_path_train = make_hdf5(model_configs['data_processing'], train_configs, mode="train")

    # seed == -1 means "not fixed": pick one at random and favour speed
    # (benchmark on, deterministic off); otherwise favour reproducibility.
    if train_configs['seed'] == -1:
        train_configs['seed'] = random.randint(1, 4096)
        cudnn.benchmark, cudnn.deterministic = True, False
    else:
        cudnn.benchmark, cudnn.deterministic = False, True

    fix_all_seed(train_configs['seed'])
    gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device()
    world_size = gpus_per_node * train_configs['nodes']
    if world_size == 1:
        warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    # Run name is derived from the config file's basename (extension stripped).
    run_name = make_run_name(RUN_NAME_FORMAT, framework=train_configs['config_path'].split('/')[-1][:-5], phase='train')
    if train_configs['disable_debugging_API']:
        torch.autograd.set_detect_anomaly(False)
    check_flags(train_configs, model_configs, world_size)

    if train_configs['distributed_data_parallel'] and world_size > 1:
        print("Train the models through DistributedDataParallel (DDP) mode.")
        # mp.spawn prepends the local rank as the first positional argument.
        mp.spawn(prepare_train_eval, nprocs=gpus_per_node, args=(gpus_per_node, world_size, run_name,
                                                                 train_configs, model_configs, hdf5_path_train))
    else:
        prepare_train_eval(rank, gpus_per_node, world_size, run_name, train_configs, model_configs, hdf5_path_train=hdf5_path_train)
# Example 3
def main():
    """Entry point: parse CLI options, load the JSON model config, validate
    flag combinations, and launch training/evaluation (optionally via DDP).

    Exits via ``sys.exit(1)`` (after printing help) when no task flag is
    given. Raises AssertionError for incompatible eval_type/dataset or
    DDP/visualization combinations, NotImplementedError when no config
    path is available.
    """
    parser = ArgumentParser(add_help=False)
    parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json')
    parser.add_argument('--checkpoint_folder', type=str, default=None)
    parser.add_argument('-current', '--load_current', action='store_true', help='whether you load the current or best checkpoint')
    parser.add_argument('--log_output_path', type=str, default=None)

    parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers')
    parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true')
    parser.add_argument('--num_workers', type=int, default=8, help='')
    parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='whether turn on synchronized batchnorm')
    parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='whether turn on mixed precision training')
    parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether disable pytorch autograd debugging mode')

    parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset')
    parser.add_argument('-std_stat', '--standing_statistics', action='store_true')
    parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm')
    parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator')

    parser.add_argument('-l', '--load_all_data_in_memory', action='store_true')
    parser.add_argument('-t', '--train', action='store_true')
    parser.add_argument('-e', '--eval', action='store_true')
    parser.add_argument('-s', '--save_images', action='store_true')
    parser.add_argument('-iv', '--image_visualization', action='store_true', help='select whether conduct image visualization')
    parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='select whether conduct k-nearest neighbor analysis')
    parser.add_argument('-itp', '--interpolation', action='store_true', help='whether conduct interpolation analysis')
    parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether conduct frequency analysis')
    parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether conduct tsne analysis')
    parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas')
    parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas')

    parser.add_argument('--print_every', type=int, default=100, help='control log interval')
    parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval')
    parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]')
    args = parser.parse_args()

    # Require at least one actionable task; otherwise show help and abort.
    if not args.train and \
            not args.eval and \
            not args.save_images and \
            not args.image_visualization and \
            not args.k_nearest_neighbor and \
            not args.interpolation and \
            not args.frequency_analysis and \
            not args.tsne_analysis:
        parser.print_help(sys.stderr)
        sys.exit(1)

    if args.config_path is not None:
        with open(args.config_path) as f:
            model_config = json.load(f)
        train_config = vars(args)
    else:
        raise NotImplementedError

    # Each dataset only ships certain splits; validate eval_type against them.
    if model_config['data_processing']['dataset_name'] == 'cifar10':
        assert train_config['eval_type'] in ['train', 'test'], "Cifar10 does not contain dataset for validation."
    elif model_config['data_processing']['dataset_name'] in ['imagenet', 'tiny_imagenet', 'custom']:
        # Typos fixed in the user-facing message ("dose"/"evalutation").
        assert train_config['eval_type'] == 'train' or train_config['eval_type'] == 'valid', \
            "StudioGAN does not support the evaluation protocol that uses the test dataset on imagenet, tiny imagenet, and custom datasets"

    if train_config['distributed_data_parallel']:
        # Visualization/analysis modes are incompatible with DDP: the boolean
        # flags are summed, so the total must be zero.
        msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, and frequency_analysis with DDP. " +\
            "Please change DDP with a single GPU training or DataParallel instead."
        assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + \
            train_config['interpolation'] + train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg

    hdf5_path_train = make_hdf5(model_config['data_processing'], train_config, mode="train") \
        if train_config['load_all_data_in_memory'] else None

    # seed == -1 means "not fixed": favour speed (benchmark on) and leave the
    # RNGs unseeded; otherwise seed everything and run deterministically.
    # NOTE(review): unlike the sibling main() above, this variant does not
    # draw a random seed in the -1 branch — presumably intentional; confirm.
    if train_config['seed'] == -1:
        cudnn.benchmark, cudnn.deterministic = True, False
    else:
        fix_all_seed(train_config['seed'])
        cudnn.benchmark, cudnn.deterministic = False, True

    world_size, rank = torch.cuda.device_count(), torch.cuda.current_device()
    if world_size == 1: warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    if train_config['disable_debugging_API']: torch.autograd.set_detect_anomaly(False)
    check_flag_0(model_config['train']['optimization']['batch_size'], world_size, train_config['freeze_layers'], train_config['checkpoint_folder'],
                 model_config['train']['model']['architecture'], model_config['data_processing']['img_size'])

    # Run name is derived from the config file's basename (extension stripped).
    run_name = make_run_name(RUN_NAME_FORMAT, framework=train_config['config_path'].split('/')[-1][:-5], phase='train')

    if train_config['distributed_data_parallel'] and world_size > 1:
        print("Train the models through DistributedDataParallel (DDP) mode.")
        # mp.spawn prepends the local rank as the first positional argument.
        mp.spawn(prepare_train_eval, nprocs=world_size, args=(world_size, run_name, train_config, model_config, hdf5_path_train))
    else:
        prepare_train_eval(rank, world_size, run_name, train_config, model_config, hdf5_path_train=hdf5_path_train)
def load_frameowrk(
        seed, disable_debugging_API, num_workers, config_path,
        checkpoint_folder, reduce_train_dataset, standing_statistics,
        standing_step, freeze_layers, load_current, eval_type, dataset_name,
        num_classes, img_size, data_path, architecture, conditional_strategy,
        hypersphere_dim, nonlinear_embed, normalize_embed, g_spectral_norm,
        d_spectral_norm, activation_fn, attention,
        attention_after_nth_gen_block, attention_after_nth_dis_block, z_dim,
        shared_dim, g_conv_dim, d_conv_dim, G_depth, D_depth, optimizer,
        batch_size, d_lr, g_lr, momentum, nesterov, alpha, beta1, beta2,
        total_step, adv_loss, cr, g_init, d_init, random_flip_preprocessing,
        prior, truncated_factor, ema, ema_decay, ema_start, synchronized_bn,
        mixed_precision, hdf5_path_train, train_config, model_config, **_):
    if seed == 0:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        fix_all_seed(seed)
        cudnn.benchmark = False
        cudnn.deterministic = True

    if disable_debugging_API:
        torch.autograd.set_detect_anomaly(False)

    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()

    check_flag_0(batch_size, n_gpus, standing_statistics, ema, freeze_layers,
                 checkpoint_folder)
    assert batch_size % n_gpus == 0, "batch_size should be divided by the number of gpus "

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    standing_step = standing_step if standing_statistics is True else batch_size

    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config_path.split('/')[-1][:-5],
                             phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name,
                                data_path,
                                train=True,
                                download=True,
                                resize_size=img_size,
                                hdf5_path=hdf5_path_train,
                                random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset,
            [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(
        dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=eval_type))
    eval_mode = True if eval_type == 'train' else False
    eval_dataset = LoadDataset(dataset_name,
                               data_path,
                               train=eval_mode,
                               download=True,
                               resize_size=img_size,
                               hdf5_path=None,
                               random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(
        dataset_size=len(eval_dataset)))

    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sry, StudioGAN does not support dcgan models for generation of images larger than 32 resolution."
    module = __import__(
        'models.{architecture}'.format(architecture=architecture),
        fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(
        architecture=architecture))
    Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim,
                           g_spectral_norm, attention,
                           attention_after_nth_gen_block, activation_fn,
                           conditional_strategy, num_classes, g_init, G_depth,
                           mixed_precision).to(default_device)

    Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm,
                               attention, attention_after_nth_dis_block,
                               activation_fn, conditional_strategy,
                               hypersphere_dim, num_classes, nonlinear_embed,
                               normalize_embed, d_init, D_depth,
                               mixed_precision).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(
            z_dim,
            shared_dim,
            img_size,
            g_conv_dim,
            g_spectral_norm,
            attention,
            attention_after_nth_gen_block,
            activation_fn,
            conditional_strategy,
            num_classes,
            initialize=False,
            G_depth=G_depth,
            mixed_precision=mixed_precision).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 drop_last=False)

    G_loss = {
        'vanilla': loss_dcgan_gen,
        'least_square': loss_lsgan_gen,
        'hinge': loss_hinge_gen,
        'wasserstein': loss_wgan_gen
    }
    D_loss = {
        'vanilla': loss_dcgan_dis,
        'least_square': loss_lsgan_dis,
        'hinge': loss_hinge_dis,
        'wasserstein': loss_wgan_dis
    }

    if optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Gen.parameters()),
                                      g_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Dis.parameters()),
                                      d_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
    elif optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Gen.parameters()),
                                          g_lr,
                                          momentum=momentum,
                                          alpha=alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Dis.parameters()),
                                          d_lr,
                                          momentum=momentum,
                                          alpha=alpha)
    elif optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Gen.parameters()),
                                       g_lr, [beta1, beta2],
                                       eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Dis.parameters()),
                                       d_lr, [beta1, beta2],
                                       eps=1e-6)
    elif optimizer == "AdaBelief":
        G_optimizer = AdaBelief(filter(lambda p: p.requires_grad,
                                       Gen.parameters()),
                                g_lr, [beta1, beta2],
                                eps=1e-12,
                                rectify=False)
        D_optimizer = AdaBelief(filter(lambda p: p.requires_grad,
                                       Dis.parameters()),
                                d_lr, [beta1, beta2],
                                eps=1e-12,
                                rectify=False)
    else:
        raise NotImplementedError

    if checkpoint_folder is not None:
        when = "current" if load_current is True else "best"
        if not exists(abspath(checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name)
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir,
                     "model=G_ema-{when}-weights-step*.pth".format(
                         when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy,
                                       None,
                                       g_ema_checkpoint_dir,
                                       ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        if train_config['train']:
            assert seed == trained_seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if freeze_layers > -1:
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    else:
        checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name)

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=default_device)
        Dis = DataParallel(Dis, output_device=default_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=default_device)

        if synchronized_bn:
            Gen = convert_model(Gen).to(default_device)
            Dis = convert_model(Dis).to(default_device)
            if ema:
                Gen_copy = convert_model(Gen_copy).to(default_device)

    if train_config['eval']:
        inception_model = InceptionV3().to(default_device)
        if n_gpus > 1:
            inception_model = DataParallel(inception_model,
                                           output_device=default_device)
        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=default_device)
    else:
        mu, sigma, inception_model = None, None, None

    train_eval = Train_Eval(
        run_name=run_name,
        best_step=best_step,
        dataset_name=dataset_name,
        eval_type=eval_type,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        freeze_layers=freeze_layers,
        conditional_strategy=conditional_strategy,
        pos_collected_numerator=model_config['model']
        ['pos_collected_numerator'],
        z_dim=z_dim,
        num_classes=num_classes,
        hypersphere_dim=hypersphere_dim,
        d_spectral_norm=d_spectral_norm,
        g_spectral_norm=g_spectral_norm,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        batch_size=batch_size,
        g_steps_per_iter=model_config['optimization']['g_steps_per_iter'],
        d_steps_per_iter=model_config['optimization']['d_steps_per_iter'],
        accumulation_steps=model_config['optimization']['accumulation_steps'],
        total_step=total_step,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        contrastive_lambda=model_config['loss_function']['contrastive_lambda'],
        margin=model_config['loss_function']['margin'],
        tempering_type=model_config['loss_function']['tempering_type'],
        tempering_step=model_config['loss_function']['tempering_step'],
        start_temperature=model_config['loss_function']['start_temperature'],
        end_temperature=model_config['loss_function']['end_temperature'],
        weight_clipping_for_dis=model_config['loss_function']
        ['weight_clipping_for_dis'],
        weight_clipping_bound=model_config['loss_function']
        ['weight_clipping_bound'],
        gradient_penalty_for_dis=model_config['loss_function']
        ['gradient_penalty_for_dis'],
        gradient_penalty_lambda=model_config['loss_function']
        ['gradient_penalty_lambda'],
        deep_regret_analysis_for_dis=model_config['loss_function']
        ['deep_regret_analysis_for_dis'],
        regret_penalty_lambda=model_config['loss_function']
        ['regret_penalty_lambda'],
        cr=cr,
        cr_lambda=model_config['loss_function']['cr_lambda'],
        bcr=model_config['loss_function']['bcr'],
        real_lambda=model_config['loss_function']['real_lambda'],
        fake_lambda=model_config['loss_function']['fake_lambda'],
        zcr=model_config['loss_function']['zcr'],
        gen_lambda=model_config['loss_function']['gen_lambda'],
        dis_lambda=model_config['loss_function']['dis_lambda'],
        sigma_noise=model_config['loss_function']['sigma_noise'],
        diff_aug=model_config['training_and_sampling_setting']['diff_aug'],
        ada=model_config['training_and_sampling_setting']['ada'],
        prev_ada_p=prev_ada_p,
        ada_target=model_config['training_and_sampling_setting']['ada_target'],
        ada_length=model_config['training_and_sampling_setting']['ada_length'],
        prior=prior,
        truncated_factor=truncated_factor,
        ema=ema,
        latent_op=model_config['training_and_sampling_setting']['latent_op'],
        latent_op_rate=model_config['training_and_sampling_setting']
        ['latent_op_rate'],
        latent_op_step=model_config['training_and_sampling_setting']
        ['latent_op_step'],
        latent_op_step4eval=model_config['training_and_sampling_setting']
        ['latent_op_step4eval'],
        latent_op_alpha=model_config['training_and_sampling_setting']
        ['latent_op_alpha'],
        latent_op_beta=model_config['training_and_sampling_setting']
        ['latent_op_beta'],
        latent_norm_reg_weight=model_config['training_and_sampling_setting']
        ['latent_norm_reg_weight'],
        default_device=default_device,
        print_every=train_config['print_every'],
        save_every=train_config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=train_config['eval'],
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
        mixed_precision=mixed_precision,
        train_config=train_config,
        model_config=model_config,
    )

    if train_config['train']:
        step = train_eval.train(current_step=step, total_step=total_step)

    if train_config['eval']:
        is_save = train_eval.evaluation(
            step=step,
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['save_images']:
        train_eval.save_images(is_generate=True,
                               png=True,
                               npz=True,
                               standing_statistics=standing_statistics,
                               standing_step=standing_step)

    if train_config['image_visualization']:
        train_eval.run_image_visualization(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['k_nearest_neighbor']:
        train_eval.run_nearest_neighbor(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['interpolation']:
        assert architecture in [
            "big_resnet", "biggan_deep"
        ], "Not supported except for biggan and biggan_deep."
        train_eval.run_linear_interpolation(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            fix_z=True,
            fix_y=False,
            standing_statistics=standing_statistics,
            standing_step=standing_step)
        train_eval.run_linear_interpolation(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            fix_z=False,
            fix_y=True,
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['frequency_analysis']:
        train_eval.run_frequency_analysis(
            num_images=len(train_dataset) // num_classes,
            standing_statistics=standing_statistics,
            standing_step=standing_step)
# Esempio n. 5 ("Example no. 5" — scraper artifact separating code snippets)
# 0 (scraper artifact: vote count)
def prepare_train_eval(cfgs, hdf5_path_train, **_):
    """Build and run a complete GAN training/evaluation session from ``cfgs``.

    Loads the train/eval datasets, constructs the generator/discriminator
    (plus an optional EMA generator copy), selects adversarial losses and an
    optimizer, optionally restores checkpoints, wraps the models for
    multi-GPU DataParallel (with optional synchronized BN), precomputes
    Inception moments for FID, and finally dispatches train / eval /
    visualization tasks according to boolean flags on ``cfgs``.

    Args:
        cfgs: configuration namespace carrying all hyperparameters and
            run-control flags (seed, batch_size, optimizer, ema, eval, ...).
        hdf5_path_train: optional path to a pre-built HDF5 file backing the
            training dataset (forwarded to LoadDataset; eval always uses
            ``hdf5_path=None``).
        **_: ignored; lets callers pass a superset of keyword arguments.

    Returns:
        None. Side effects: trains/evaluates models and writes logs,
        TensorBoard summaries, checkpoints, and images.
    """
    # seed == -1: non-reproducible run with cudnn autotuning enabled;
    # otherwise fix all RNG seeds and force deterministic cudnn kernels.
    if cfgs.seed == -1:
        cudnn.benchmark, cudnn.deterministic = True, False
    else:
        fix_all_seed(cfgs.seed)
        cudnn.benchmark, cudnn.deterministic = False, True

    n_gpus, default_device = torch.cuda.device_count(), torch.cuda.current_device()
    if n_gpus ==1: warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    if cfgs.disable_debugging_API: torch.autograd.set_detect_anomaly(False)
    check_flag_0(cfgs.batch_size, n_gpus, cfgs.freeze_layers, cfgs.checkpoint_folder, cfgs.architecture, cfgs.img_size)
    # Derive the run name from the config filename; assumes a path layout
    # like ./src/configs/<dataset>/<framework>.yaml ([:-5] strips ".yaml").
    run_name = make_run_name(RUN_NAME_FORMAT, framework=cfgs.config_path.split('/')[3][:-5], phase='train')
    # Defaults for state that may be overwritten by checkpoint loading below.
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(cfgs.train_configs)
    logger.info(cfgs.model_configs)


    ##### load dataset #####
    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size,
                                hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing)
    # Optionally train on a random fraction (< 1.0) of the full dataset.
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset*len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=cfgs.eval_type))
    # eval_type == 'train' evaluates against the training split; any other
    # value selects the test/validation split (train=False).
    eval_mode = True if cfgs.eval_type == 'train' else False
    eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size,
                               hdf5_path=None, random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    # drop_last=True keeps training batches a constant size; eval keeps all samples.
    train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True, num_workers=cfgs.num_workers, drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True, num_workers=cfgs.num_workers, drop_last=False)


    ##### build model #####
    logger.info('Building model...')
    # Dynamically import models.<architecture>; the non-empty fromlist makes
    # __import__ return the submodule itself rather than the top-level package.
    module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=cfgs.architecture))
    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                           cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                           cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(default_device)

    Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention, cfgs.attention_after_nth_dis_block,
                               cfgs.activation_fn, cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes, cfgs.nonlinear_embed,
                               cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(default_device)

    if cfgs.ema:
        # Keep a second generator whose weights track an exponential moving
        # average of Gen's; initialize=False since the EMA helper copies
        # weights from Gen itself.
        print('Preparing EMA for G with decay of {}'.format(cfgs.ema_decay))
        Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                                    cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                                    initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(default_device)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)


    ### define loss functions and optimizers
    G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    # Optimize only trainable parameters (supports partially frozen models).
    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
    else:
        raise NotImplementedError


    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        # glob(...)[0]: assumes exactly one matching checkpoint per model/when;
        # raises IndexError if none is found.
        g_checkpoint_dir = glob.glob(join(checkpoint_dir,"model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(join(checkpoint_dir,"model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        # run_name may have changed after restore, so rebuild logger/writer.
        logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
            # Re-point the EMA helper at the restored source/target pair.
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        # Fine-tuning with frozen layers restarts step counters and best-FID
        # bookkeeping instead of continuing the restored run's progress.
        if cfgs.freeze_layers > -1 :
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None


    ##### wrap models with DP and convert BN to Sync BN #####
    # NOTE: wrapping happens AFTER checkpoint loading so state dicts match.
    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=default_device)
        Dis = DataParallel(Dis, output_device=default_device)
        if cfgs.ema:
            Gen_copy = DataParallel(Gen_copy, output_device=default_device)

        if cfgs.synchronized_bn:
            Gen = convert_model(Gen).to(default_device)
            Dis = convert_model(Dis).to(default_device)
            if cfgs.ema:
                Gen_copy = convert_model(Gen_copy).to(default_device)


    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(default_device)
        if n_gpus > 1: inception_model = DataParallel(inception_model, output_device=default_device)

        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=default_device)


    # Bundle everything into the worker object that owns the actual
    # train/eval/visualization loops.
    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        default_device=default_device,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    # Dispatch the requested tasks; evaluation reuses the final train step.
    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        # (is_save is not used after this point in the visible code.)
        is_save = worker.evaluation(step=step, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True, png=True, npz=True, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in ["big_resnet", "biggan_deep"], "Not supported except for biggan and biggan_deep."
        # Two passes: interpolate over class labels (fix_z) then over latents (fix_y).
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=True, fix_y=False,
                                            standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=False, fix_y=True,
                                            standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(num_images=len(train_dataset)//cfgs.num_classes,
                                          standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)
# Esempio n. 6 ("Example no. 6" — scraper artifact separating code snippets)
# 0 (scraper artifact: vote count)
def train_framework(
        seed, num_workers, config_path, reduce_train_dataset, load_current,
        type4eval_dataset, dataset_name, num_classes, img_size, data_path,
        architecture, conditional_strategy, hypersphere_dim, nonlinear_embed,
        normalize_embed, g_spectral_norm, d_spectral_norm, activation_fn,
        attention, attention_after_nth_gen_block,
        attention_after_nth_dis_block, z_dim, shared_dim, g_conv_dim,
        d_conv_dim, G_depth, D_depth, optimizer, batch_size, d_lr, g_lr,
        momentum, nesterov, alpha, beta1, beta2, total_step, adv_loss,
        consistency_reg, g_init, d_init, random_flip_preprocessing, prior,
        truncated_factor, latent_op, ema, ema_decay, ema_start,
        synchronized_bn, hdf5_path_train, train_config, model_config, **_):
    fix_all_seed(seed)
    cudnn.benchmark = False  # Not good Generator for undetermined input size
    cudnn.deterministic = True
    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()
    second_device = default_device if n_gpus == 1 else default_device + 1
    assert batch_size % n_gpus == 0, "batch_size should be divided by the number of gpus "

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    start_step, best_step, best_fid, best_fid_checkpoint_path = 0, 0, None, None
    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config_path.split('/')[3][:-5],
                             phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name,
                                data_path,
                                train=True,
                                download=True,
                                resize_size=img_size,
                                hdf5_path=hdf5_path_train,
                                consistency_reg=consistency_reg,
                                random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset,
            [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(
        dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=type4eval_dataset))
    eval_mode = True if type4eval_dataset == 'train' else False
    eval_dataset = LoadDataset(dataset_name,
                               data_path,
                               train=eval_mode,
                               download=True,
                               resize_size=img_size,
                               hdf5_path=None,
                               random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(
        dataset_size=len(eval_dataset)))

    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sry, StudioGAN does not support dcgan models for generation of images larger than 32 resolution."
    module = __import__(
        'models.{architecture}'.format(architecture=architecture),
        fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(
        architecture=architecture))
    Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim,
                           g_spectral_norm, attention,
                           attention_after_nth_gen_block, activation_fn,
                           conditional_strategy, num_classes, synchronized_bn,
                           g_init, G_depth).to(default_device)

    Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm,
                               attention, attention_after_nth_dis_block,
                               activation_fn, conditional_strategy,
                               hypersphere_dim, num_classes, nonlinear_embed,
                               normalize_embed, synchronized_bn, d_init,
                               D_depth).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(z_dim,
                                    shared_dim,
                                    img_size,
                                    g_conv_dim,
                                    g_spectral_norm,
                                    attention,
                                    attention_after_nth_gen_block,
                                    activation_fn,
                                    conditional_strategy,
                                    num_classes,
                                    synchronized_bn=False,
                                    initialize=False,
                                    G_depth=G_depth).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=second_device)
        Dis = DataParallel(Dis, output_device=second_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=second_device)
        if synchronized_bn:
            patch_replication_callback(Gen)
            patch_replication_callback(Dis)

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 drop_last=False)

    G_loss = {
        'vanilla': loss_dcgan_gen,
        'hinge': loss_hinge_gen,
        'wasserstein': loss_wgan_gen
    }
    D_loss = {
        'vanilla': loss_dcgan_dis,
        'hinge': loss_hinge_dis,
        'wasserstein': loss_wgan_dis
    }

    if optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Gen.parameters()),
                                      g_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Dis.parameters()),
                                      d_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
    elif optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Gen.parameters()),
                                          g_lr,
                                          momentum=momentum,
                                          alpha=alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Dis.parameters()),
                                          d_lr,
                                          momentum=momentum,
                                          alpha=alpha)
    elif optimizer == "Adam":
        G_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, Gen.parameters()), g_lr,
            [beta1, beta2])
        D_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, Dis.parameters()), d_lr,
            [beta1, beta2])
    elif optimizer == "AdamP":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Gen.parameters()),
                                       g_lr,
                                       betas=(beta1, beta2))
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Dis.parameters()),
                                       d_lr,
                                       betas=(beta1, beta2))
    else:
        raise NotImplementedError

    checkpoint_dir = make_checkpoint_dir(train_config['checkpoint_folder'],
                                         run_name)

    if train_config['checkpoint_folder'] is not None:
        when = "current" if load_current is True else "best"
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, start_step, best_step = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, start_step, best_step, best_fid, best_fid_checkpoint_path = load_checkpoint(
            Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir,
                     "model=G_ema-{when}-weights-step*.pth".format(
                         when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy,
                                       None,
                                       g_ema_checkpoint_dir,
                                       ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        assert seed == trained_seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))

    if train_config['eval']:
        inception_model = InceptionV3().to(default_device)
        inception_model = DataParallel(inception_model,
                                       output_device=second_device)
        mu, sigma, is_score, is_std = prepare_inception_moments_eval_dataset(
            dataloader=eval_dataloader,
            generator=Gen,
            eval_mode=type4eval_dataset,
            inception_model=inception_model,
            splits=10,
            run_name=run_name,
            logger=logger,
            device=second_device)
    else:
        mu, sigma, inception_model = None, None, None

    logger.info('Start training...')
    trainer = Trainer(
        run_name=run_name,
        best_step=best_step,
        dataset_name=dataset_name,
        type4eval_dataset=type4eval_dataset,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        conditional_strategy=conditional_strategy,
        z_dim=z_dim,
        num_classes=num_classes,
        hypersphere_dim=hypersphere_dim,
        d_spectral_norm=d_spectral_norm,
        g_spectral_norm=g_spectral_norm,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        batch_size=batch_size,
        g_steps_per_iter=model_config['optimization']['g_steps_per_iter'],
        d_steps_per_iter=model_config['optimization']['d_steps_per_iter'],
        accumulation_steps=model_config['optimization']['accumulation_steps'],
        total_step=total_step,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        contrastive_lambda=model_config['loss_function']['contrastive_lambda'],
        tempering_type=model_config['loss_function']['tempering_type'],
        tempering_step=model_config['loss_function']['tempering_step'],
        start_temperature=model_config['loss_function']['start_temperature'],
        end_temperature=model_config['loss_function']['end_temperature'],
        gradient_penalty_for_dis=model_config['loss_function']
        ['gradient_penalty_for_dis'],
        gradient_penelty_lambda=model_config['loss_function']
        ['gradient_penelty_lambda'],
        weight_clipping_for_dis=model_config['loss_function']
        ['weight_clipping_for_dis'],
        weight_clipping_bound=model_config['loss_function']
        ['weight_clipping_bound'],
        consistency_reg=consistency_reg,
        consistency_lambda=model_config['loss_function']['consistency_lambda'],
        diff_aug=model_config['training_and_sampling_setting']['diff_aug'],
        prior=prior,
        truncated_factor=truncated_factor,
        ema=ema,
        latent_op=latent_op,
        latent_op_rate=model_config['training_and_sampling_setting']
        ['latent_op_rate'],
        latent_op_step=model_config['training_and_sampling_setting']
        ['latent_op_step'],
        latent_op_step4eval=model_config['training_and_sampling_setting']
        ['latent_op_step4eval'],
        latent_op_alpha=model_config['training_and_sampling_setting']
        ['latent_op_alpha'],
        latent_op_beta=model_config['training_and_sampling_setting']
        ['latent_op_beta'],
        latent_norm_reg_weight=model_config['training_and_sampling_setting']
        ['latent_norm_reg_weight'],
        default_device=default_device,
        second_device=second_device,
        print_every=train_config['print_every'],
        save_every=train_config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=train_config['eval'],
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
        train_config=train_config,
        model_config=model_config,
    )

    if conditional_strategy == 'ContraGAN' and train_config['train']:
        trainer.run_ours(current_step=start_step, total_step=total_step)
    elif train_config['train']:
        trainer.run(current_step=start_step, total_step=total_step)
    elif train_config['eval']:
        is_save = trainer.evaluation(step=start_step)

    if train_config['k_nearest_neighbor'] > 0:
        trainer.K_Nearest_Neighbor(
            train_config['criterion_4_k_nearest_neighbor'],
            train_config['number_of_nearest_samples'],
            random.randrange(num_classes))
Esempio n. 7
0
def pretrain(data_dir, train_path, val_path, dictionary_path,
             dataset_limit, vocabulary_size, batch_size, max_len, epochs, clip_grads, device,
             layers_count, hidden_size, heads_count, d_ff, dropout_prob,
             log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_):
    """Pre-train a transformer encoder with the MLM + NSP loss and return the Trainer.

    Paths are resolved against ``data_dir`` when it is given; ``run_name`` is
    generated from ``config`` unless supplied by the caller.
    """
    # Seed every RNG we rely on so runs are reproducible.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Resolve dataset/dictionary locations relative to data_dir when provided.
    if data_dir is not None:
        train_path = join(data_dir, train_path)
        val_path = join(data_dir, val_path)
        dictionary_path = join(data_dir, dictionary_path)

    if run_name is None:
        run_name = make_run_name(RUN_NAME_FORMAT, phase='pretrain', config=config)
    logger = make_logger(run_name, log_output)
    logger.info(f'Run name : {run_name}')
    logger.info(config)

    logger.info('Constructing dictionaries...')
    dictionary = IndexDictionary.load(dictionary_path=dictionary_path,
                                      vocabulary_size=vocabulary_size)
    # The dictionary may be smaller than the requested size; use its true length.
    vocabulary_size = len(dictionary)
    logger.info(f'dictionary vocabulary : {vocabulary_size} tokens')

    logger.info('Loading datasets...')
    train_dataset = PairedDataset(data_path=train_path, dictionary=dictionary, dataset_limit=dataset_limit)
    val_dataset = PairedDataset(data_path=val_path, dictionary=dictionary, dataset_limit=dataset_limit)
    logger.info(f'Train dataset size : {len(train_dataset)}')

    logger.info('Building model...')
    model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size)

    logger.info(model)
    parameter_total = sum(p.nelement() for p in model.parameters())
    logger.info(f'{parameter_total} parameters')

    # Wrap the model with the combined MLM+NSP loss; parallelize across GPUs when >1.
    loss_model = MLMNSPLossModel(model)
    if torch.cuda.device_count() > 1:
        loss_model = DataParallel(loss_model, output_device=1)

    metric_functions = [mlm_accuracy, nsp_accuracy]

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  collate_fn=pretraining_collate_function)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size,
                                collate_fn=pretraining_collate_function)

    # Noam schedule: warmup then inverse-sqrt decay, as in the original BERT setup.
    optimizer = NoamOptimizer(model.parameters(),
                              d_model=hidden_size, factor=2, warmup_steps=10000, betas=(0.9, 0.999), weight_decay=0.01)

    checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config)

    logger.info('Start training...')
    trainer = Trainer(
        loss_model=loss_model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        metric_functions=metric_functions,
        optimizer=optimizer,
        clip_grads=clip_grads,
        logger=logger,
        checkpoint_dir=checkpoint_dir,
        print_every=print_every,
        save_every=save_every,
        device=device
    )

    trainer.run(epochs=epochs)
    return trainer
Esempio n. 8
0
def finetune(pretrained_checkpoint,
             data_dir, train_path, val_path, dictionary_path,
             vocabulary_size, batch_size, max_len, epochs, lr, clip_grads, device,
             layers_count, hidden_size, heads_count, d_ff, dropout_prob,
             log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_):
    """Fine-tune a pre-trained encoder on SST-2 binary classification; return the Trainer.

    ``pretrained_checkpoint`` must hold a ``state_dict`` compatible with the model
    built from the given architecture hyper-parameters.
    """
    # Seed every RNG we rely on so runs are reproducible.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Resolve dataset/dictionary locations relative to data_dir when provided.
    if data_dir is not None:
        train_path = join(data_dir, train_path)
        val_path = join(data_dir, val_path)
        dictionary_path = join(data_dir, dictionary_path)

    if run_name is None:
        run_name = make_run_name(RUN_NAME_FORMAT, phase='finetune', config=config)
    logger = make_logger(run_name, log_output)
    logger.info(f'Run name : {run_name}')
    logger.info(config)

    logger.info('Constructing dictionaries...')
    dictionary = IndexDictionary.load(dictionary_path=dictionary_path,
                                      vocabulary_size=vocabulary_size)
    # The dictionary may be smaller than the requested size; use its true length.
    vocabulary_size = len(dictionary)
    logger.info(f'dictionary vocabulary : {vocabulary_size} tokens')

    logger.info('Loading datasets...')
    train_dataset = SST2IndexedDataset(data_path=train_path, dictionary=dictionary)
    val_dataset = SST2IndexedDataset(data_path=val_path, dictionary=dictionary)
    logger.info(f'Train dataset size : {len(train_dataset)}')

    logger.info('Building model...')
    pretrained_model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size)
    # Restore pre-trained weights on CPU, then wrap with a 2-class head.
    checkpoint = torch.load(pretrained_checkpoint, map_location='cpu')
    pretrained_model.load_state_dict(checkpoint['state_dict'])

    model = FineTuneModel(pretrained_model, hidden_size, num_classes=2)

    logger.info(model)
    parameter_total = sum(p.nelement() for p in model.parameters())
    logger.info(f'{parameter_total} parameters')

    loss_model = ClassificationLossModel(model)
    metric_functions = [classification_accuracy]

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  collate_fn=classification_collate_function)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size,
                                collate_fn=classification_collate_function)

    optimizer = Adam(model.parameters(), lr=lr)

    checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config)

    logger.info('Start training...')
    trainer = Trainer(
        loss_model=loss_model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        metric_functions=metric_functions,
        optimizer=optimizer,
        clip_grads=clip_grads,
        logger=logger,
        checkpoint_dir=checkpoint_dir,
        print_every=print_every,
        save_every=save_every,
        device=device
    )

    trainer.run(epochs=epochs)
    return trainer
def train_framework(dataset_name, architecture, num_classes, img_size, data_path, eval_dataset, hdf5_path_train, hdf5_path_valid, train_rate, auxiliary_classifier,
                    projection_discriminator, contrastive_training, hyper_dim, nonlinear_embed, normalize_embed, g_spectral_norm, d_spectral_norm, attention, reduce_class,
                    at_after_th_gen_block, at_after_th_dis_block, leaky_relu, g_init, d_init, latent_op, consistency_reg, make_positive_aug, synchronized_bn, ema,
                    ema_decay, ema_start, adv_loss, z_dim, shared_dim, g_conv_dim, d_conv_dim, batch_size, total_step, truncated_factor, prior, d_lr, g_lr,
                    beta1, beta2, batch4metrics, config, **_):
    """Assemble and launch a full GAN training run.

    Sets up datasets/dataloaders, builds Generator and Discriminator (plus an
    optional EMA copy of G), restores checkpoints when a checkpoint folder is
    configured, optionally prepares InceptionV3 moments for FID/IS evaluation,
    and finally hands everything to ``Trainer``.

    Fix: the ``reduce_class`` dataset guard previously read
    ``dataset_name == "TINY_ILSVRC2012" or "ILSVRC2012"``, which is always
    truthy (the string literal short-circuits the ``or``), so the assertion
    could never fire. It is now a proper membership test.
    """
    fix_all_seed(config['seed'])
    # benchmark autotunes conv algorithms; appropriate here since input sizes are fixed.
    cudnn.benchmark = True
    cudnn.deterministic = False
    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()
    # With >1 GPU, DataParallel gathers outputs on the device after the default one.
    second_device = default_device if n_gpus == 1 else default_device+1
    assert batch_size % n_gpus == 0, "batch_size should be divided by the number of gpus "

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    start_step = 0
    best_val_fid, best_checkpoint_fid_path, best_val_is, best_checkpoint_is_path = None, None, None, None
    # NOTE(review): run name derives from a fixed position in config_path — assumes
    # a path layout like ./src/configs/<dataset>/<name>.yaml; confirm for other layouts.
    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config['config_path'].split('/')[3][:-5],
                             phase='train',
                             config=config)

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name, data_path, train=True, download=True, resize_size=img_size, hdf5_path=hdf5_path_train,
                                consistency_reg=consistency_reg, make_positive_aug=make_positive_aug)
    # Optionally train on a random subset (train_rate fraction) of the data.
    if train_rate < 1.0:
        num_train = int(train_rate*len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])

    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading valid datasets...')
    valid_dataset = LoadDataset(dataset_name, data_path, train=False, download=True, resize_size=img_size, hdf5_path=hdf5_path_valid)
    logger.info('Valid dataset size : {dataset_size}'.format(dataset_size=len(valid_dataset)))

    logger.info('Building model...')
    # Architecture modules live under models/<architecture>.py and expose
    # Generator/Discriminator classes.
    module = __import__('models.{architecture}'.format(architecture=architecture),fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=architecture))
    # reduce_class < 1.0 trains on a fraction of the label space.
    num_classes = int(reduce_class*num_classes)
    Gen = module.Generator(z_dim, shared_dim, g_conv_dim, g_spectral_norm, attention, at_after_th_gen_block, leaky_relu, auxiliary_classifier,
                           projection_discriminator, num_classes, contrastive_training, synchronized_bn, g_init).to(default_device)

    Dis = module.Discriminator(d_conv_dim, d_spectral_norm, attention, at_after_th_dis_block, leaky_relu, auxiliary_classifier,
                               projection_discriminator, hyper_dim, num_classes, contrastive_training, nonlinear_embed, normalize_embed,
                               synchronized_bn, d_init).to(default_device)

    if ema:
        # EMA copy of G: not trained directly, so no sync-BN and no re-init.
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(z_dim, shared_dim, g_conv_dim, g_spectral_norm, attention, at_after_th_gen_block, leaky_relu, auxiliary_classifier,
                                    projection_discriminator, num_classes, contrastive_training, synchronized_bn=False, initialize=False).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=second_device)
        Dis = DataParallel(Dis, output_device=second_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=second_device)
        if config['synchronized_bn']:
            patch_replication_callback(Gen)
            patch_replication_callback(Dis)

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)
    if reduce_class != 1.0:
        # Fixed: the original expression `dataset_name == "TINY_ILSVRC2012" or
        # "ILSVRC2012"` was always truthy, so this guard never triggered.
        assert dataset_name in ("TINY_ILSVRC2012", "ILSVRC2012"), "reduce_class mode can not be applied on the CIFAR10 dataset"
        # Zero-weight the tail of each dataset so the samplers only ever draw
        # from the first reduce_class fraction of examples.
        n_train = int(reduce_class*len(train_dataset))
        n_valid = int(reduce_class*len(valid_dataset))
        train_weights = [1.0]*n_train + [0.0]*(len(train_dataset) - n_train)
        valid_weights = [1.0]*n_valid + [0.0]*(len(valid_dataset) - n_valid)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))
        valid_sampler = torch.utils.data.sampler.WeightedRandomSampler(valid_weights, len(valid_weights))
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size, sampler=train_sampler, shuffle=False,
                                      pin_memory=True, num_workers=config['num_workers'], drop_last=True)

        evaluation_dataloader = DataLoader(valid_dataset,
                                           sampler=valid_sampler, batch_size=batch4metrics, shuffle=False,
                                           pin_memory=True, num_workers=config['num_workers'], drop_last=False)
    else:
        train_dataloader = DataLoader(train_dataset,
                                    batch_size=batch_size, shuffle=True, pin_memory=True,
                                    num_workers=config['num_workers'], drop_last=True)

        evaluation_dataloader = DataLoader(valid_dataset,
                                        batch_size=batch4metrics, shuffle=True, pin_memory=True,
                                        num_workers=config['num_workers'], drop_last=False)

    # Adversarial loss pairs, selected by adv_loss below.
    G_loss = {'vanilla': loss_dcgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, [beta1, beta2])
    D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, [beta1, beta2])

    checkpoint_dir = make_checkpoint_dir(config['checkpoint_folder'], run_name, config)

    if config['checkpoint_folder'] is not None:
        # Resume: restore G/D (and EMA copy) from the checkpoint at config['step'].
        logger = make_logger(run_name, config['log_output_path'])
        g_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,"model=G-step=" + str(config['step']) + "*.pth"))[0]
        d_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,"model=D-step=" + str(config['step']) + "*.pth"))[0]
        Gen, G_optimizer, seed, run_name, start_step = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, seed, run_name, start_step, best_val_fid, best_checkpoint_fid_path,\
        best_val_is, best_checkpoint_is_path = load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if ema:
            g_ema_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir, "model=G_ema-step=" + str(config['step']) + "*.pth"))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=ema)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        assert config['seed'] == seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))

    if config['eval']:
        # Precompute InceptionV3 activations/moments of the eval set for FID/IS.
        inception_model = InceptionV3().to(default_device)
        inception_model = DataParallel(inception_model, output_device=second_device)
        mu, sigma, is_score, is_std = prepare_inception_moments_eval_dataset(dataloader=evaluation_dataloader,
                                                                            inception_model=inception_model,
                                                                            reduce_class=reduce_class,
                                                                            splits=10,
                                                                            logger=logger,
                                                                            device=second_device,
                                                                            eval_dataset=eval_dataset)
    else:
        mu, sigma, inception_model = None, None, None

    logger.info('Start training...')
    trainer = Trainer(
        run_name=run_name,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataloader=train_dataloader,
        evaluation_dataloader=evaluation_dataloader,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        auxiliary_classifier=auxiliary_classifier,
        contrastive_training=contrastive_training,
        contrastive_lambda=config['contrastive_lambda'],
        softmax_posterior=config['softmax_posterior'],
        contrastive_softmax=config['contrastive_softmax'],
        hyper_dim=config['hyper_dim'],
        tempering=config['tempering'],
        discrete_tempering=config['discrete_tempering'],
        tempering_times=config['tempering_times'],
        start_temperature=config['start_temperature'],
        end_temperature=config['end_temperature'],
        gradient_penalty_for_dis=config['gradient_penalty_for_dis'],
        lambda4lp=config['lambda4lp'],
        lambda4gp=config['lambda4gp'],
        weight_clipping_for_dis=config['weight_clipping_for_dis'],
        weight_clipping_bound=config['weight_clipping_bound'],
        latent_op=latent_op,
        latent_op_rate=config['latent_op_rate'],
        latent_op_step=config['latent_op_step'],
        latent_op_step4eval=config['latent_op_step4eval'],
        latent_op_alpha=config['latent_op_alpha'],
        latent_op_beta=config['latent_op_beta'],
        latent_norm_reg_weight=config['latent_norm_reg_weight'],
        consistency_reg=consistency_reg,
        consistency_lambda=config['consistency_lambda'],
        make_positive_aug=make_positive_aug,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        default_device=default_device,
        second_device=second_device,
        batch_size=batch_size,
        z_dim=z_dim,
        num_classes=num_classes,
        truncated_factor=truncated_factor,
        prior=prior,
        g_steps_per_iter=config['g_steps_per_iter'],
        d_steps_per_iter=config['d_steps_per_iter'],
        accumulation_steps=config['accumulation_steps'],
        lambda4ortho=config['lambda4ortho'],
        print_every=config['print_every'],
        save_every=config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=config['eval'],
        mu=mu,
        sigma=sigma,
        best_val_fid=best_val_fid,
        best_checkpoint_fid_path=best_checkpoint_fid_path,
        best_val_is=best_val_is,
        best_checkpoint_is_path=best_checkpoint_is_path,
        config=config,
    )

    # Contrastive (ContraGAN-style) training uses its own loop.
    if contrastive_training:
        trainer.run_ours(current_step=start_step, total_step=total_step)
    else:
        trainer.run(current_step=start_step, total_step=total_step)