Code Example #1
def load_StudioGAN_ckpts(ckpt_dir, load_best, Gen, Dis, g_optimizer, d_optimizer, run_name, apply_g_ema, Gen_ema, ema,
                         is_train, RUN, logger, global_rank, device, cfg_file):
    when = "best" if load_best is True else "current"
    Gen_ckpt_path = glob.glob(join(ckpt_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]
    Dis_ckpt_path = glob.glob(join(ckpt_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]
    prev_run_name = torch.load(Dis_ckpt_path, map_location=lambda storage, loc: storage)["run_name"]
    is_freezeD = True if RUN.freezeD > -1 else False

    load_ckpt(model=Gen,
              optimizer=g_optimizer,
              ckpt_path=Gen_ckpt_path,
              load_model=True,
              load_opt=False if prev_run_name in blacklist or is_freezeD or not is_train else True,
              load_misc=False,
              is_freezeD=is_freezeD)

    seed, prev_run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path =\
        load_ckpt(model=Dis,
                  optimizer=d_optimizer,
                  ckpt_path=Dis_ckpt_path,
                  load_model=True,
                  load_opt=False if prev_run_name in blacklist or is_freezeD or not is_train else True,
                  load_misc=True,
                  is_freezeD=is_freezeD)

    if not is_train:
        prev_run_name = cfg_file[cfg_file.rindex("/")+1:cfg_file.index(".yaml")]+prev_run_name[prev_run_name.index("-train"):]

    if apply_g_ema:
        Gen_ema_ckpt_path = glob.glob(join(ckpt_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
        load_ckpt(model=Gen_ema,
                  optimizer=None,
                  ckpt_path=Gen_ema_ckpt_path,
                  load_model=True,
                  load_opt=False,
                  load_misc=False,
                  is_freezeD=is_freezeD)

        ema.source, ema.target = Gen, Gen_ema

    if is_train and RUN.seed != seed:
        RUN.seed = seed + global_rank
        misc.fix_seed(RUN.seed)

    if device == 0:
        if not is_freezeD:
            logger = log.make_logger(RUN.save_dir, prev_run_name, None)

        logger.info("Generator checkpoint is {}".format(Gen_ckpt_path))
        if apply_g_ema:
            logger.info("EMA_Generator checkpoint is {}".format(Gen_ema_ckpt_path))
        logger.info("Discriminator checkpoint is {}".format(Dis_ckpt_path))

    if is_freezeD:
        prev_run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path =\
            run_name, 0, 0, "initialize", None, 0, None, None
    return prev_run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path, logger
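
The helper below is not part of StudioGAN; it is a minimal sketch of the checkpoint-path resolution pattern used in Code Example #1 (glob over the "model=G-best-weights-step*.pth" style naming convention, then torch.load onto CPU). The resolve_ckpt_path name, the example directory, and the assumption that a checkpoint is a dictionary containing a "run_name" key are illustrative only.

import glob
from os.path import join

import torch


def resolve_ckpt_path(ckpt_dir, model="G", load_best=True):
    """Return the first checkpoint matching StudioGAN's file-naming convention."""
    when = "best" if load_best else "current"
    pattern = join(ckpt_dir, "model={model}-{when}-weights-step*.pth".format(model=model, when=when))
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError("no checkpoint matches {}".format(pattern))
    return matches[0]


if __name__ == "__main__":
    # Hypothetical run directory; point this at a real checkpoint folder before running.
    path = resolve_ckpt_path("./checkpoints/example_run", model="D", load_best=True)
    checkpoint = torch.load(path, map_location="cpu")  # load onto CPU first
    print(checkpoint.get("run_name"))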
Code Example #2
File: loader.py  Project: jlim13/PyTorch-StudioGAN
def prepare_train_eval(rank, world_size, run_name, train_config, model_config,
                       hdf5_path_train):
    cfgs = dict2clsattr(train_config, model_config)
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None
    if cfgs.distributed_data_parallel:
        print("Use GPU: {} for training.".format(rank))
        setup(rank, world_size)
        torch.cuda.set_device(rank)

    writer = SummaryWriter(
        log_dir=join('./logs', run_name)) if rank == 0 else None
    if rank == 0:
        logger = make_logger(run_name, None)
        logger.info('Run name : {run_name}'.format(run_name=run_name))
        logger.info(train_config)
        logger.info(model_config)
    else:
        logger = None

    ##### load dataset #####
    if rank == 0: logger.info('Load train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name,
                                cfgs.data_path,
                                train=True,
                                download=True,
                                resize_size=cfgs.img_size,
                                hdf5_path=hdf5_path_train,
                                random_flip=cfgs.random_flip_preprocessing)
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset,
            [num_train, len(train_dataset) - num_train])
    if rank == 0:
        logger.info('Train dataset size : {dataset_size}'.format(
            dataset_size=len(train_dataset)))

    if rank == 0:
        logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
    eval_mode = True if cfgs.eval_type == 'train' else False
    eval_dataset = LoadDataset(cfgs.dataset_name,
                               cfgs.data_path,
                               train=eval_mode,
                               download=True,
                               resize_size=cfgs.img_size,
                               hdf5_path=None,
                               random_flip=False)
    if rank == 0:
        logger.info('Eval dataset size : {dataset_size}'.format(
            dataset_size=len(eval_dataset)))

    if cfgs.distributed_data_parallel:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        cfgs.batch_size = cfgs.batch_size // world_size
    else:
        train_sampler = None

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfgs.batch_size,
                                  shuffle=(train_sampler is None),
                                  pin_memory=True,
                                  num_workers=cfgs.num_workers,
                                  sampler=train_sampler,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=cfgs.batch_size,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=cfgs.num_workers,
                                 drop_last=False)

    ##### build model #####
    if rank == 0: logger.info('Build model...')
    module = __import__(
        'models.{architecture}'.format(architecture=cfgs.architecture),
        fromlist=['something'])
    if rank == 0:
        logger.info('Modules are located on models.{architecture}.'.format(
            architecture=cfgs.architecture))
    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size,
                           cfgs.g_conv_dim, cfgs.g_spectral_norm,
                           cfgs.attention, cfgs.attention_after_nth_gen_block,
                           cfgs.activation_fn, cfgs.conditional_strategy,
                           cfgs.num_classes, cfgs.g_init, cfgs.G_depth,
                           cfgs.mixed_precision).to(rank)

    Dis = module.Discriminator(
        cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention,
        cfgs.attention_after_nth_dis_block, cfgs.activation_fn,
        cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes,
        cfgs.nonlinear_embed, cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth,
        cfgs.mixed_precision).to(rank)

    if cfgs.ema:
        if rank == 0:
            logger.info('Prepare EMA for G with decay of {}.'.format(
                cfgs.ema_decay))
        Gen_copy = module.Generator(
            cfgs.z_dim,
            cfgs.shared_dim,
            cfgs.img_size,
            cfgs.g_conv_dim,
            cfgs.g_spectral_norm,
            cfgs.attention,
            cfgs.attention_after_nth_gen_block,
            cfgs.activation_fn,
            cfgs.conditional_strategy,
            cfgs.num_classes,
            initialize=False,
            G_depth=cfgs.G_depth,
            mixed_precision=cfgs.mixed_precision).to(rank)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if rank == 0: logger.info(count_parameters(Gen))
    if rank == 0: logger.info(Gen)

    if rank == 0: logger.info(count_parameters(Dis))
    if rank == 0: logger.info(Dis)

    ### define loss functions and optimizers
    G_loss = {
        'vanilla': loss_dcgan_gen,
        'least_square': loss_lsgan_gen,
        'hinge': loss_hinge_gen,
        'wasserstein': loss_wgan_gen
    }
    D_loss = {
        'vanilla': loss_dcgan_dis,
        'least_square': loss_lsgan_dis,
        'hinge': loss_hinge_dis,
        'wasserstein': loss_wgan_dis
    }

    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Gen.parameters()),
                                      cfgs.g_lr,
                                      momentum=cfgs.momentum,
                                      nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Dis.parameters()),
                                      cfgs.d_lr,
                                      momentum=cfgs.momentum,
                                      nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Gen.parameters()),
                                          cfgs.g_lr,
                                          momentum=cfgs.momentum,
                                          alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Dis.parameters()),
                                          cfgs.d_lr,
                                          momentum=cfgs.momentum,
                                          alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Gen.parameters()),
                                       cfgs.g_lr, [cfgs.beta1, cfgs.beta2],
                                       eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Dis.parameters()),
                                       cfgs.d_lr, [cfgs.beta1, cfgs.beta2],
                                       eps=1e-6)
    else:
        raise NotImplementedError

    if cfgs.LARS_optimizer:
        G_optimizer = LARS(optimizer=G_optimizer, eps=1e-8, trust_coef=0.001)
        D_optimizer = LARS(optimizer=D_optimizer, eps=1e-8, trust_coef=0.001)

    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if rank == 0: logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir,
                     "model=G_ema-{when}-weights-step*.pth".format(
                         when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy,
                                       None,
                                       g_ema_checkpoint_dir,
                                       ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(
            log_dir=join('./logs', run_name)) if rank == 0 else None
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be the same!"

        if rank == 0:
            logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        if rank == 0:
            logger.info(
                'Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if cfgs.freeze_layers > -1:
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None

    ##### wrap models with DP and convert BN to Sync BN #####
    if world_size > 1:
        if cfgs.distributed_data_parallel:
            if cfgs.synchronized_bn:
                process_group = torch.distributed.new_group(
                    [w for w in range(world_size)])
                Gen = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    Gen, process_group)
                Dis = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    Dis, process_group)
                if cfgs.ema:
                    Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        Gen_copy, process_group)

            Gen = DDP(Gen,
                      device_ids=[rank],
                      broadcast_buffers=False,
                      find_unused_parameters=True)
            Dis = DDP(Dis,
                      device_ids=[rank],
                      broadcast_buffers=False,
                      find_unused_parameters=True)
            if cfgs.ema:
                Gen_copy = DDP(Gen_copy,
                               device_ids=[rank],
                               broadcast_buffers=False,
                               find_unused_parameters=True)
        else:
            Gen = DataParallel(Gen, output_device=rank)
            Dis = DataParallel(Dis, output_device=rank)
            if cfgs.ema:
                Gen_copy = DataParallel(Gen_copy, output_device=rank)

            if cfgs.synchronized_bn:
                Gen = convert_model(Gen).to(rank)
                Dis = convert_model(Dis).to(rank)
                if cfgs.ema:
                    Gen_copy = convert_model(Gen_copy).to(rank)

    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(rank)
        if world_size > 1 and cfgs.distributed_data_parallel:
            toggle_grad(inception_model, on=True)
            inception_model = DDP(inception_model,
                                  device_ids=[rank],
                                  broadcast_buffers=False,
                                  find_unused_parameters=True)
        elif world_size > 1 and cfgs.distributed_data_parallel is False:
            inception_model = DataParallel(inception_model, output_device=rank)
        else:
            pass

        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=rank)

    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=world_size,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        rank=rank,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        is_save = worker.evaluation(
            step=step,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True,
                           png=True,
                           npz=True,
                           standing_statistics=cfgs.standing_statistics,
                           standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in [
            "big_resnet", "biggan_deep"
        ], "StudioGAN does not support interpolation analysis except for biggan and biggan_deep."
        worker.run_linear_interpolation(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            fix_z=True,
            fix_y=False,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(
            nrow=cfgs.nrow,
            ncol=cfgs.ncol,
            fix_z=False,
            fix_y=True,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(
            num_images=len(train_dataset) // cfgs.num_classes,
            standing_statistics=cfgs.standing_statistics,
            standing_step=cfgs.standing_step)

    if cfgs.tsne_analysis:
        worker.run_tsne(dataloader=eval_dataloader,
                        standing_statistics=cfgs.standing_statistics,
                        standing_step=cfgs.standing_step)
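
Because prepare_train_eval takes rank as its first argument and calls setup(rank, world_size) before pinning the CUDA device, it is meant to be launched once per GPU. The launcher below is an illustrative sketch only: the setup helper, the environment variables, and the backend choice are assumptions, not the project's actual main script.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size, master_addr="127.0.0.1", master_port="29500"):
    # Initialize the default process group used by DistributedDataParallel.
    os.environ.setdefault("MASTER_ADDR", master_addr)
    os.environ.setdefault("MASTER_PORT", master_port)
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def demo_worker(rank, world_size):
    # Stand-in for prepare_train_eval(rank, world_size, ...).
    setup(rank, world_size)
    print("rank {}/{} initialized".format(rank, world_size))
    dist.destroy_process_group()


if __name__ == "__main__":
    n_procs = max(torch.cuda.device_count(), 1)
    # mp.spawn passes the process index as the first positional argument,
    # which is why prepare_train_eval takes rank first.
    mp.spawn(demo_worker, args=(n_procs,), nprocs=n_procs, join=True)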
Code Example #3
def load_frameowrk(
        seed, disable_debugging_API, num_workers, config_path,
        checkpoint_folder, reduce_train_dataset, standing_statistics,
        standing_step, freeze_layers, load_current, eval_type, dataset_name,
        num_classes, img_size, data_path, architecture, conditional_strategy,
        hypersphere_dim, nonlinear_embed, normalize_embed, g_spectral_norm,
        d_spectral_norm, activation_fn, attention,
        attention_after_nth_gen_block, attention_after_nth_dis_block, z_dim,
        shared_dim, g_conv_dim, d_conv_dim, G_depth, D_depth, optimizer,
        batch_size, d_lr, g_lr, momentum, nesterov, alpha, beta1, beta2,
        total_step, adv_loss, cr, g_init, d_init, random_flip_preprocessing,
        prior, truncated_factor, ema, ema_decay, ema_start, synchronized_bn,
        mixed_precision, hdf5_path_train, train_config, model_config, **_):
    if seed == 0:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        fix_all_seed(seed)
        cudnn.benchmark = False
        cudnn.deterministic = True

    if disable_debugging_API:
        torch.autograd.set_detect_anomaly(False)

    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()

    check_flag_0(batch_size, n_gpus, standing_statistics, ema, freeze_layers,
                 checkpoint_folder)
    assert batch_size % n_gpus == 0, "batch_size should be divisible by the number of GPUs"

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    standing_step = standing_step if standing_statistics is True else batch_size

    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config_path.split('/')[-1][:-5],
                             phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name,
                                data_path,
                                train=True,
                                download=True,
                                resize_size=img_size,
                                hdf5_path=hdf5_path_train,
                                random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset,
            [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(
        dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=eval_type))
    eval_mode = True if eval_type == 'train' else False
    eval_dataset = LoadDataset(dataset_name,
                               data_path,
                               train=eval_mode,
                               download=True,
                               resize_size=img_size,
                               hdf5_path=None,
                               random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(
        dataset_size=len(eval_dataset)))

    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sorry, StudioGAN does not support dcgan models for generating images larger than 32x32 resolution."
    module = __import__(
        'models.{architecture}'.format(architecture=architecture),
        fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(
        architecture=architecture))
    Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim,
                           g_spectral_norm, attention,
                           attention_after_nth_gen_block, activation_fn,
                           conditional_strategy, num_classes, g_init, G_depth,
                           mixed_precision).to(default_device)

    Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm,
                               attention, attention_after_nth_dis_block,
                               activation_fn, conditional_strategy,
                               hypersphere_dim, num_classes, nonlinear_embed,
                               normalize_embed, d_init, D_depth,
                               mixed_precision).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(
            z_dim,
            shared_dim,
            img_size,
            g_conv_dim,
            g_spectral_norm,
            attention,
            attention_after_nth_gen_block,
            activation_fn,
            conditional_strategy,
            num_classes,
            initialize=False,
            G_depth=G_depth,
            mixed_precision=mixed_precision).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 drop_last=False)

    G_loss = {
        'vanilla': loss_dcgan_gen,
        'least_square': loss_lsgan_gen,
        'hinge': loss_hinge_gen,
        'wasserstein': loss_wgan_gen
    }
    D_loss = {
        'vanilla': loss_dcgan_dis,
        'least_square': loss_lsgan_dis,
        'hinge': loss_hinge_dis,
        'wasserstein': loss_wgan_dis
    }

    if optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Gen.parameters()),
                                      g_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Dis.parameters()),
                                      d_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
    elif optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Gen.parameters()),
                                          g_lr,
                                          momentum=momentum,
                                          alpha=alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Dis.parameters()),
                                          d_lr,
                                          momentum=momentum,
                                          alpha=alpha)
    elif optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Gen.parameters()),
                                       g_lr, [beta1, beta2],
                                       eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Dis.parameters()),
                                       d_lr, [beta1, beta2],
                                       eps=1e-6)
    elif optimizer == "AdaBelief":
        G_optimizer = AdaBelief(filter(lambda p: p.requires_grad,
                                       Gen.parameters()),
                                g_lr, [beta1, beta2],
                                eps=1e-12,
                                rectify=False)
        D_optimizer = AdaBelief(filter(lambda p: p.requires_grad,
                                       Dis.parameters()),
                                d_lr, [beta1, beta2],
                                eps=1e-12,
                                rectify=False)
    else:
        raise NotImplementedError

    if checkpoint_folder is not None:
        when = "current" if load_current is True else "best"
        if not exists(abspath(checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name)
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir,
                     "model=G_ema-{when}-weights-step*.pth".format(
                         when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy,
                                       None,
                                       g_ema_checkpoint_dir,
                                       ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        if train_config['train']:
            assert seed == trained_seed, "seed for sampling random numbers should be the same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if freeze_layers > -1:
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    else:
        checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name)

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=default_device)
        Dis = DataParallel(Dis, output_device=default_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=default_device)

        if synchronized_bn:
            Gen = convert_model(Gen).to(default_device)
            Dis = convert_model(Dis).to(default_device)
            if ema:
                Gen_copy = convert_model(Gen_copy).to(default_device)

    if train_config['eval']:
        inception_model = InceptionV3().to(default_device)
        if n_gpus > 1:
            inception_model = DataParallel(inception_model,
                                           output_device=default_device)
        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=default_device)
    else:
        mu, sigma, inception_model = None, None, None

    train_eval = Train_Eval(
        run_name=run_name,
        best_step=best_step,
        dataset_name=dataset_name,
        eval_type=eval_type,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        freeze_layers=freeze_layers,
        conditional_strategy=conditional_strategy,
        pos_collected_numerator=model_config['model']
        ['pos_collected_numerator'],
        z_dim=z_dim,
        num_classes=num_classes,
        hypersphere_dim=hypersphere_dim,
        d_spectral_norm=d_spectral_norm,
        g_spectral_norm=g_spectral_norm,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        batch_size=batch_size,
        g_steps_per_iter=model_config['optimization']['g_steps_per_iter'],
        d_steps_per_iter=model_config['optimization']['d_steps_per_iter'],
        accumulation_steps=model_config['optimization']['accumulation_steps'],
        total_step=total_step,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        contrastive_lambda=model_config['loss_function']['contrastive_lambda'],
        margin=model_config['loss_function']['margin'],
        tempering_type=model_config['loss_function']['tempering_type'],
        tempering_step=model_config['loss_function']['tempering_step'],
        start_temperature=model_config['loss_function']['start_temperature'],
        end_temperature=model_config['loss_function']['end_temperature'],
        weight_clipping_for_dis=model_config['loss_function']
        ['weight_clipping_for_dis'],
        weight_clipping_bound=model_config['loss_function']
        ['weight_clipping_bound'],
        gradient_penalty_for_dis=model_config['loss_function']
        ['gradient_penalty_for_dis'],
        gradient_penalty_lambda=model_config['loss_function']
        ['gradient_penalty_lambda'],
        deep_regret_analysis_for_dis=model_config['loss_function']
        ['deep_regret_analysis_for_dis'],
        regret_penalty_lambda=model_config['loss_function']
        ['regret_penalty_lambda'],
        cr=cr,
        cr_lambda=model_config['loss_function']['cr_lambda'],
        bcr=model_config['loss_function']['bcr'],
        real_lambda=model_config['loss_function']['real_lambda'],
        fake_lambda=model_config['loss_function']['fake_lambda'],
        zcr=model_config['loss_function']['zcr'],
        gen_lambda=model_config['loss_function']['gen_lambda'],
        dis_lambda=model_config['loss_function']['dis_lambda'],
        sigma_noise=model_config['loss_function']['sigma_noise'],
        diff_aug=model_config['training_and_sampling_setting']['diff_aug'],
        ada=model_config['training_and_sampling_setting']['ada'],
        prev_ada_p=prev_ada_p,
        ada_target=model_config['training_and_sampling_setting']['ada_target'],
        ada_length=model_config['training_and_sampling_setting']['ada_length'],
        prior=prior,
        truncated_factor=truncated_factor,
        ema=ema,
        latent_op=model_config['training_and_sampling_setting']['latent_op'],
        latent_op_rate=model_config['training_and_sampling_setting']
        ['latent_op_rate'],
        latent_op_step=model_config['training_and_sampling_setting']
        ['latent_op_step'],
        latent_op_step4eval=model_config['training_and_sampling_setting']
        ['latent_op_step4eval'],
        latent_op_alpha=model_config['training_and_sampling_setting']
        ['latent_op_alpha'],
        latent_op_beta=model_config['training_and_sampling_setting']
        ['latent_op_beta'],
        latent_norm_reg_weight=model_config['training_and_sampling_setting']
        ['latent_norm_reg_weight'],
        default_device=default_device,
        print_every=train_config['print_every'],
        save_every=train_config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=train_config['eval'],
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
        mixed_precision=mixed_precision,
        train_config=train_config,
        model_config=model_config,
    )

    if train_config['train']:
        step = train_eval.train(current_step=step, total_step=total_step)

    if train_config['eval']:
        is_save = train_eval.evaluation(
            step=step,
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['save_images']:
        train_eval.save_images(is_generate=True,
                               png=True,
                               npz=True,
                               standing_statistics=standing_statistics,
                               standing_step=standing_step)

    if train_config['image_visualization']:
        train_eval.run_image_visualization(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['k_nearest_neighbor']:
        train_eval.run_nearest_neighbor(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['interpolation']:
        assert architecture in [
            "big_resnet", "biggan_deep"
        ], "Not supported except for biggan and biggan_deep."
        train_eval.run_linear_interpolation(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            fix_z=True,
            fix_y=False,
            standing_statistics=standing_statistics,
            standing_step=standing_step)
        train_eval.run_linear_interpolation(
            nrow=train_config['nrow'],
            ncol=train_config['ncol'],
            fix_z=False,
            fix_y=True,
            standing_statistics=standing_statistics,
            standing_step=standing_step)

    if train_config['frequency_analysis']:
        train_eval.run_frequency_analysis(
            num_images=len(train_dataset) // num_classes,
            standing_statistics=standing_statistics,
            standing_step=standing_step)
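
The ema_ helper used above exposes source and target attributes, which the checkpoint-loading branch reassigns after restoring Gen_copy. The class below is a simplified stand-in for such a helper, not StudioGAN's actual implementation; it only shows how an exponential moving average of the generator weights is typically maintained.

import torch


class SimpleEMA:
    """Simplified EMA helper: keeps a live model (source) and its averaged copy (target)."""

    def __init__(self, source, target, decay=0.9999, start_iter=0):
        self.source = source          # generator being trained
        self.target = target          # exponential moving average of its weights
        self.decay = decay
        self.start_iter = start_iter

    @torch.no_grad()
    def update(self, current_iter):
        # Copy weights verbatim until start_iter, then blend with the decay factor.
        beta = 0.0 if current_iter < self.start_iter else self.decay
        src_params = dict(self.source.named_parameters())
        for name, tgt_param in self.target.named_parameters():
            tgt_param.mul_(beta).add_(src_params[name], alpha=1.0 - beta)


if __name__ == "__main__":
    gen = torch.nn.Linear(4, 4)
    gen_copy = torch.nn.Linear(4, 4)
    gen_copy.load_state_dict(gen.state_dict())
    ema_helper = SimpleEMA(gen, gen_copy, decay=0.999, start_iter=10)
    ema_helper.update(current_iter=100)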
Code Example #4
def prepare_train_eval(cfgs, hdf5_path_train, **_):
    if cfgs.seed == -1:
        cudnn.benchmark, cudnn.deterministic = True, False
    else:
        fix_all_seed(cfgs.seed)
        cudnn.benchmark, cudnn.deterministic = False, True

    n_gpus, default_device = torch.cuda.device_count(), torch.cuda.current_device()
    if n_gpus == 1: warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    if cfgs.disable_debugging_API: torch.autograd.set_detect_anomaly(False)
    check_flag_0(cfgs.batch_size, n_gpus, cfgs.freeze_layers, cfgs.checkpoint_folder, cfgs.architecture, cfgs.img_size)
    run_name = make_run_name(RUN_NAME_FORMAT, framework=cfgs.config_path.split('/')[3][:-5], phase='train')
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(cfgs.train_configs)
    logger.info(cfgs.model_configs)


    ##### load dataset #####
    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size,
                                hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing)
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset*len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=cfgs.eval_type))
    eval_mode = True if cfgs.eval_type == 'train' else False
    eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size,
                               hdf5_path=None, random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True, num_workers=cfgs.num_workers, drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True, num_workers=cfgs.num_workers, drop_last=False)


    ##### build model #####
    logger.info('Building model...')
    module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=cfgs.architecture))
    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                           cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                           cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(default_device)

    Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention, cfgs.attention_after_nth_dis_block,
                               cfgs.activation_fn, cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes, cfgs.nonlinear_embed,
                               cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(default_device)

    if cfgs.ema:
        print('Preparing EMA for G with decay of {}'.format(cfgs.ema_decay))
        Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                                    cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                                    initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(default_device)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)


    ### define loss functions and optimizers
    G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
    else:
        raise NotImplementedError


    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        g_checkpoint_dir = glob.glob(join(checkpoint_dir,"model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(join(checkpoint_dir,"model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "seed for sampling random numbers should be the same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if cfgs.freeze_layers > -1:
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None


    ##### wrap models with DP and convert BN to Sync BN #####
    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=default_device)
        Dis = DataParallel(Dis, output_device=default_device)
        if cfgs.ema:
            Gen_copy = DataParallel(Gen_copy, output_device=default_device)

        if cfgs.synchronized_bn:
            Gen = convert_model(Gen).to(default_device)
            Dis = convert_model(Dis).to(default_device)
            if cfgs.ema:
                Gen_copy = convert_model(Gen_copy).to(default_device)


    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(default_device)
        if n_gpus > 1: inception_model = DataParallel(inception_model, output_device=default_device)

        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=default_device)


    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        default_device=default_device,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        is_save = worker.evaluation(step=step, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True, png=True, npz=True, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in ["big_resnet", "biggan_deep"], "Not supported except for biggan and biggan_deep."
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=True, fix_y=False,
                                            standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=False, fix_y=True,
                                            standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(num_images=len(train_dataset)//cfgs.num_classes,
                                          standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)
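
The G_loss / D_loss dictionaries above dispatch on cfgs.adv_loss. For reference, the functions below sketch the hinge variant with assumed signatures; StudioGAN's own loss_hinge_gen / loss_hinge_dis may take additional arguments, so this only illustrates the underlying formulas.

import torch
import torch.nn.functional as F


def loss_hinge_dis_sketch(dis_out_real, dis_out_fake):
    # E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]
    return torch.mean(F.relu(1.0 - dis_out_real)) + torch.mean(F.relu(1.0 + dis_out_fake))


def loss_hinge_gen_sketch(dis_out_fake):
    # -E[D(G(z))]
    return -torch.mean(dis_out_fake)


if __name__ == "__main__":
    real_logits, fake_logits = torch.randn(8), torch.randn(8)
    print(loss_hinge_dis_sketch(real_logits, fake_logits).item())
    print(loss_hinge_gen_sketch(fake_logits).item())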
Code Example #5
File: train.py  Project: MLDL/PyTorch-StudioGAN
def train_framework(
        seed, num_workers, config_path, reduce_train_dataset, load_current,
        type4eval_dataset, dataset_name, num_classes, img_size, data_path,
        architecture, conditional_strategy, hypersphere_dim, nonlinear_embed,
        normalize_embed, g_spectral_norm, d_spectral_norm, activation_fn,
        attention, attention_after_nth_gen_block,
        attention_after_nth_dis_block, z_dim, shared_dim, g_conv_dim,
        d_conv_dim, G_depth, D_depth, optimizer, batch_size, d_lr, g_lr,
        momentum, nesterov, alpha, beta1, beta2, total_step, adv_loss,
        consistency_reg, g_init, d_init, random_flip_preprocessing, prior,
        truncated_factor, latent_op, ema, ema_decay, ema_start,
        synchronized_bn, hdf5_path_train, train_config, model_config, **_):
    fix_all_seed(seed)
    cudnn.benchmark = False  # cudnn benchmarking is not beneficial when input sizes are not fixed
    cudnn.deterministic = True
    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()
    second_device = default_device if n_gpus == 1 else default_device + 1
    assert batch_size % n_gpus == 0, "batch_size should be divisible by the number of GPUs"

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    start_step, best_step, best_fid, best_fid_checkpoint_path = 0, 0, None, None
    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config_path.split('/')[3][:-5],
                             phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name,
                                data_path,
                                train=True,
                                download=True,
                                resize_size=img_size,
                                hdf5_path=hdf5_path_train,
                                consistency_reg=consistency_reg,
                                random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset,
            [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(
        dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=type4eval_dataset))
    eval_mode = True if type4eval_dataset == 'train' else False
    eval_dataset = LoadDataset(dataset_name,
                               data_path,
                               train=eval_mode,
                               download=True,
                               resize_size=img_size,
                               hdf5_path=None,
                               random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(
        dataset_size=len(eval_dataset)))

    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sorry, StudioGAN does not support dcgan models for generating images larger than 32x32 resolution."
    module = __import__(
        'models.{architecture}'.format(architecture=architecture),
        fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(
        architecture=architecture))
    Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim,
                           g_spectral_norm, attention,
                           attention_after_nth_gen_block, activation_fn,
                           conditional_strategy, num_classes, synchronized_bn,
                           g_init, G_depth).to(default_device)

    Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm,
                               attention, attention_after_nth_dis_block,
                               activation_fn, conditional_strategy,
                               hypersphere_dim, num_classes, nonlinear_embed,
                               normalize_embed, synchronized_bn, d_init,
                               D_depth).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(z_dim,
                                    shared_dim,
                                    img_size,
                                    g_conv_dim,
                                    g_spectral_norm,
                                    attention,
                                    attention_after_nth_gen_block,
                                    activation_fn,
                                    conditional_strategy,
                                    num_classes,
                                    synchronized_bn=False,
                                    initialize=False,
                                    G_depth=G_depth).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=second_device)
        Dis = DataParallel(Dis, output_device=second_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=second_device)
        if synchronized_bn:
            patch_replication_callback(Gen)
            patch_replication_callback(Dis)

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 drop_last=False)

    G_loss = {
        'vanilla': loss_dcgan_gen,
        'hinge': loss_hinge_gen,
        'wasserstein': loss_wgan_gen
    }
    D_loss = {
        'vanilla': loss_dcgan_dis,
        'hinge': loss_hinge_dis,
        'wasserstein': loss_wgan_dis
    }

    if optimizer == "SGD":
        G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Gen.parameters()),
                                      g_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
        D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                             Dis.parameters()),
                                      d_lr,
                                      momentum=momentum,
                                      nesterov=nesterov)
    elif optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Gen.parameters()),
                                          g_lr,
                                          momentum=momentum,
                                          alpha=alpha)
        D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                 Dis.parameters()),
                                          d_lr,
                                          momentum=momentum,
                                          alpha=alpha)
    elif optimizer == "Adam":
        G_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, Gen.parameters()), g_lr,
            [beta1, beta2])
        D_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, Dis.parameters()), d_lr,
            [beta1, beta2])
    elif optimizer == "AdamP":
        G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Gen.parameters()),
                                       g_lr,
                                       betas=(beta1, beta2))
        D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              Dis.parameters()),
                                       d_lr,
                                       betas=(beta1, beta2))
    else:
        raise NotImplementedError

    checkpoint_dir = make_checkpoint_dir(train_config['checkpoint_folder'],
                                         run_name)

    if train_config['checkpoint_folder'] is not None:
        when = "current" if load_current is True else "best"
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir,
                 "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        Gen, G_optimizer, trained_seed, run_name, start_step, best_step = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, start_step, best_step, best_fid, best_fid_checkpoint_path = load_checkpoint(
            Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir,
                     "model=G_ema-{when}-weights-step*.pth".format(
                         when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy,
                                       None,
                                       g_ema_checkpoint_dir,
                                       ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        assert seed == trained_seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))

    if train_config['eval']:
        inception_model = InceptionV3().to(default_device)
        inception_model = DataParallel(inception_model,
                                       output_device=second_device)
        mu, sigma, is_score, is_std = prepare_inception_moments_eval_dataset(
            dataloader=eval_dataloader,
            generator=Gen,
            eval_mode=type4eval_dataset,
            inception_model=inception_model,
            splits=10,
            run_name=run_name,
            logger=logger,
            device=second_device)
    else:
        mu, sigma, inception_model = None, None, None

    logger.info('Start training...')
    trainer = Trainer(
        run_name=run_name,
        best_step=best_step,
        dataset_name=dataset_name,
        type4eval_dataset=type4eval_dataset,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        conditional_strategy=conditional_strategy,
        z_dim=z_dim,
        num_classes=num_classes,
        hypersphere_dim=hypersphere_dim,
        d_spectral_norm=d_spectral_norm,
        g_spectral_norm=g_spectral_norm,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        batch_size=batch_size,
        g_steps_per_iter=model_config['optimization']['g_steps_per_iter'],
        d_steps_per_iter=model_config['optimization']['d_steps_per_iter'],
        accumulation_steps=model_config['optimization']['accumulation_steps'],
        total_step=total_step,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        contrastive_lambda=model_config['loss_function']['contrastive_lambda'],
        tempering_type=model_config['loss_function']['tempering_type'],
        tempering_step=model_config['loss_function']['tempering_step'],
        start_temperature=model_config['loss_function']['start_temperature'],
        end_temperature=model_config['loss_function']['end_temperature'],
        gradient_penalty_for_dis=model_config['loss_function']
        ['gradient_penalty_for_dis'],
        gradient_penelty_lambda=model_config['loss_function']
        ['gradient_penelty_lambda'],
        weight_clipping_for_dis=model_config['loss_function']
        ['weight_clipping_for_dis'],
        weight_clipping_bound=model_config['loss_function']
        ['weight_clipping_bound'],
        consistency_reg=consistency_reg,
        consistency_lambda=model_config['loss_function']['consistency_lambda'],
        diff_aug=model_config['training_and_sampling_setting']['diff_aug'],
        prior=prior,
        truncated_factor=truncated_factor,
        ema=ema,
        latent_op=latent_op,
        latent_op_rate=model_config['training_and_sampling_setting']
        ['latent_op_rate'],
        latent_op_step=model_config['training_and_sampling_setting']
        ['latent_op_step'],
        latent_op_step4eval=model_config['training_and_sampling_setting']
        ['latent_op_step4eval'],
        latent_op_alpha=model_config['training_and_sampling_setting']
        ['latent_op_alpha'],
        latent_op_beta=model_config['training_and_sampling_setting']
        ['latent_op_beta'],
        latent_norm_reg_weight=model_config['training_and_sampling_setting']
        ['latent_norm_reg_weight'],
        default_device=default_device,
        second_device=second_device,
        print_every=train_config['print_every'],
        save_every=train_config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=train_config['eval'],
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
        train_config=train_config,
        model_config=model_config,
    )

    if conditional_strategy == 'ContraGAN' and train_config['train']:
        trainer.run_ours(current_step=start_step, total_step=total_step)
    elif train_config['train']:
        trainer.run(current_step=start_step, total_step=total_step)
    elif train_config['eval']:
        is_save = trainer.evaluation(step=start_step)

    if train_config['k_nearest_neighbor'] > 0:
        trainer.K_Nearest_Neighbor(
            train_config['criterion_4_k_nearest_neighbor'],
            train_config['number_of_nearest_samples'],
            random.randrange(num_classes))
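
The `G_loss` / `D_loss` dictionaries above dispatch to loss functions defined elsewhere in the repository (`loss_hinge_gen`, `loss_hinge_dis`, and so on). As a rough guide, the 'hinge' entries correspond to the standard hinge-GAN objectives; the minimal sketch below illustrates that formulation and is not the repository's exact implementation.

import torch
import torch.nn.functional as F

def loss_hinge_dis(dis_out_real, dis_out_fake):
    # penalize real outputs below +1 and fake outputs above -1
    return torch.mean(F.relu(1.0 - dis_out_real)) + torch.mean(F.relu(1.0 + dis_out_fake))

def loss_hinge_gen(dis_out_fake):
    # the generator tries to push the discriminator score on fakes up
    return -torch.mean(dis_out_fake)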
Code example #6
0
File: train.py Project: rikudoayush/BERT-in-pytorch-
def pretrain(data_dir, train_path, val_path, dictionary_path,
             dataset_limit, vocabulary_size, batch_size, max_len, epochs, clip_grads, device,
             layers_count, hidden_size, heads_count, d_ff, dropout_prob,
             log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_):

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    train_path = train_path if data_dir is None else join(data_dir, train_path)
    val_path = val_path if data_dir is None else join(data_dir, val_path)
    dictionary_path = dictionary_path if data_dir is None else join(data_dir, dictionary_path)

    run_name = run_name if run_name is not None else make_run_name(RUN_NAME_FORMAT, phase='pretrain', config=config)
    logger = make_logger(run_name, log_output)
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(config)

    logger.info('Constructing dictionaries...')
    dictionary = IndexDictionary.load(dictionary_path=dictionary_path,
                                      vocabulary_size=vocabulary_size)
    vocabulary_size = len(dictionary)
    logger.info(f'dictionary vocabulary : {vocabulary_size} tokens')

    logger.info('Loading datasets...')
    train_dataset = PairedDataset(data_path=train_path, dictionary=dictionary, dataset_limit=dataset_limit)
    val_dataset = PairedDataset(data_path=val_path, dictionary=dictionary, dataset_limit=dataset_limit)
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Building model...')
    model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size)

    logger.info(model)
    logger.info('{parameters_count} parameters'.format(
        parameters_count=sum([p.nelement() for p in model.parameters()])))

    loss_model = MLMNSPLossModel(model)
    if torch.cuda.device_count() > 1:
        loss_model = DataParallel(loss_model, output_device=1)

    metric_functions = [mlm_accuracy, nsp_accuracy]

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=pretraining_collate_function)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=pretraining_collate_function)

    optimizer = NoamOptimizer(model.parameters(),
                              d_model=hidden_size, factor=2, warmup_steps=10000, betas=(0.9, 0.999), weight_decay=0.01)

    checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config)

    logger.info('Start training...')
    trainer = Trainer(
        loss_model=loss_model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        metric_functions=metric_functions,
        optimizer=optimizer,
        clip_grads=clip_grads,
        logger=logger,
        checkpoint_dir=checkpoint_dir,
        print_every=print_every,
        save_every=save_every,
        device=device
    )

    trainer.run(epochs=epochs)
    return trainer
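
The `NoamOptimizer` used above applies the Noam learning-rate schedule from the Transformer paper: the rate warms up for `warmup_steps` steps and then decays with the inverse square root of the step count. A minimal sketch of that schedule, assuming a plain PyTorch optimizer underneath (the repository's class additionally forwards `betas` and `weight_decay` to Adam), could look like this:

import torch

class NoamSchedule:
    """Illustrative sketch: lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)."""

    def __init__(self, optimizer, d_model, factor=2, warmup_steps=10000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.factor = factor
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        # update the learning rate of every parameter group, then take an optimizer step
        self.step_num += 1
        lr = self.factor * self.d_model ** -0.5 * min(
            self.step_num ** -0.5, self.step_num * self.warmup_steps ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()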
Code example #7
0
File: train.py Project: rikudoayush/BERT-in-pytorch-
def finetune(pretrained_checkpoint,
             data_dir, train_path, val_path, dictionary_path,
             vocabulary_size, batch_size, max_len, epochs, lr, clip_grads, device,
             layers_count, hidden_size, heads_count, d_ff, dropout_prob,
             log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_):

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    train_path = train_path if data_dir is None else join(data_dir, train_path)
    val_path = val_path if data_dir is None else join(data_dir, val_path)
    dictionary_path = dictionary_path if data_dir is None else join(data_dir, dictionary_path)

    run_name = run_name if run_name is not None else make_run_name(RUN_NAME_FORMAT, phase='finetune', config=config)
    logger = make_logger(run_name, log_output)
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(config)

    logger.info('Constructing dictionaries...')
    dictionary = IndexDictionary.load(dictionary_path=dictionary_path,
                                      vocabulary_size=vocabulary_size)
    vocabulary_size = len(dictionary)
    logger.info(f'dictionary vocabulary : {vocabulary_size} tokens')

    logger.info('Loading datasets...')
    train_dataset = SST2IndexedDataset(data_path=train_path, dictionary=dictionary)
    val_dataset = SST2IndexedDataset(data_path=val_path, dictionary=dictionary)
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Building model...')
    pretrained_model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size)
    pretrained_model.load_state_dict(torch.load(pretrained_checkpoint, map_location='cpu')['state_dict'])

    model = FineTuneModel(pretrained_model, hidden_size, num_classes=2)

    logger.info(model)
    logger.info('{parameters_count} parameters'.format(
        parameters_count=sum([p.nelement() for p in model.parameters()])))

    loss_model = ClassificationLossModel(model)
    metric_functions = [classification_accuracy]

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=classification_collate_function)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=classification_collate_function)

    optimizer = Adam(model.parameters(), lr=lr)

    checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config)

    logger.info('Start training...')
    trainer = Trainer(
        loss_model=loss_model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        metric_functions=metric_functions,
        optimizer=optimizer,
        clip_grads=clip_grads,
        logger=logger,
        checkpoint_dir=checkpoint_dir,
        print_every=print_every,
        save_every=save_every,
        device=device
    )

    trainer.run(epochs=epochs)
    return trainer
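
`FineTuneModel` above attaches a classification head to the pretrained encoder; its definition is not shown in this snippet. A hypothetical sketch of such a wrapper, assuming the encoder returns per-token hidden states of size `hidden_size` (an assumption, not the repository's documented interface), might be:

from torch import nn

class SimpleFineTuneHead(nn.Module):
    """Hypothetical sketch: pretrained encoder followed by a linear classifier."""

    def __init__(self, encoder, hidden_size, num_classes=2):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs):
        hidden = self.encoder(inputs)   # assumed shape: (batch, seq_len, hidden_size)
        pooled = hidden[:, 0, :]        # use the first token as the sequence summary
        return self.classifier(pooled)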
Code example #8
0
def load_worker(local_rank, cfgs, gpus_per_node, run_name, hdf5_path):
    # -----------------------------------------------------------------------------
    # define default variables for loading ckpt or evaluating the trained GAN model.
    # -----------------------------------------------------------------------------
    step, epoch, topk, best_step, best_fid, best_ckpt_path, is_best = \
        0, 0, cfgs.OPTIMIZATION.batch_size, 0, None, None, False
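    # initial augmentation probability: taken from ADA when configured, otherwise from APA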
    aa_p = cfgs.AUG.ada_initial_augment_p if cfgs.AUG.ada_initial_augment_p != "N/A" else cfgs.AUG.apa_initial_augment_p
    mu, sigma, eval_model, num_rows, num_cols = None, None, None, 10, 8
    loss_list_dict = {"gen_loss": [], "dis_loss": [], "cls_loss": []}
    metric_dict_during_train = {}
    if "none" in cfgs.RUN.eval_metrics:
        cfgs.RUN.eval_metrics = []
    if "is" in cfgs.RUN.eval_metrics:
        metric_dict_during_train.update({
            "IS": [],
            "Top1_acc": [],
            "Top5_acc": []
        })
    if "fid" in cfgs.RUN.eval_metrics:
        metric_dict_during_train.update({"FID": []})
    if "prdc" in cfgs.RUN.eval_metrics:
        metric_dict_during_train.update({
            "Improved_Precision": [],
            "Improved_Recall": [],
            "Density": [],
            "Coverage": []
        })

    # -----------------------------------------------------------------------------
    # determine cuda, cudnn, and backends settings.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.fix_seed:
        cudnn.benchmark, cudnn.deterministic = False, True
    else:
        cudnn.benchmark, cudnn.deterministic = True, False

    if cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]:
        # Improves training speed
        conv2d_gradfix.enabled = True
        # Avoids errors with the augmentation pipe
        grid_sample_gradfix.enabled = True
        if cfgs.RUN.mixed_precision:
            # Disable TF32 for matmul to keep full float32 numerical precision
            torch.backends.cuda.matmul.allow_tf32 = False
            # Disable TF32 for convolutions to keep full float32 numerical precision
            torch.backends.cudnn.allow_tf32 = False

    # -----------------------------------------------------------------------------
    # initialize all processes and fix seed of each process
    # -----------------------------------------------------------------------------
    if cfgs.RUN.distributed_data_parallel:
        global_rank = cfgs.RUN.current_node * (gpus_per_node) + local_rank
        print("Use GPU: {global_rank} for training.".format(
            global_rank=global_rank))
        misc.setup(global_rank, cfgs.OPTIMIZATION.world_size, cfgs.RUN.backend)
        torch.cuda.set_device(local_rank)
    else:
        global_rank = local_rank

    misc.fix_seed(cfgs.RUN.seed + global_rank)

    # -----------------------------------------------------------------------------
    # Intialize python logger.
    # -----------------------------------------------------------------------------
    if local_rank == 0:
        logger = log.make_logger(cfgs.RUN.save_dir, run_name, None)
        if cfgs.RUN.ckpt_dir is not None and cfgs.RUN.freezeD == -1:
            folder_hier = cfgs.RUN.ckpt_dir.split("/")
            if folder_hier[-1] == "":
                folder_hier.pop()
            logger.info(
                "Run name : {run_name}".format(run_name=folder_hier.pop()))
        else:
            logger.info("Run name : {run_name}".format(run_name=run_name))
        for k, v in cfgs.super_cfgs.items():
            logger.info("cfgs." + k + " =")
            logger.info(json.dumps(vars(v), indent=2))
    else:
        logger = None

    # -----------------------------------------------------------------------------
    # load train and evaluation datasets.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.train or cfgs.RUN.intra_class_fid or cfgs.RUN.GAN_train or cfgs.RUN.GAN_test:
        if local_rank == 0:
            logger.info(
                "Load {name} train dataset.".format(name=cfgs.DATA.name))
        train_dataset = Dataset_(
            data_name=cfgs.DATA.name,
            data_dir=cfgs.RUN.data_dir,
            train=True,
            crop_long_edge=cfgs.PRE.crop_long_edge,
            resize_size=cfgs.PRE.resize_size,
            random_flip=cfgs.PRE.apply_rflip,
            normalize=True,
            hdf5_path=hdf5_path,
            load_data_in_memory=cfgs.RUN.load_data_in_memory)
        if local_rank == 0:
            logger.info("Train dataset size: {dataset_size}".format(
                dataset_size=len(train_dataset)))
    else:
        train_dataset = None

    if len(cfgs.RUN.eval_metrics) + cfgs.RUN.save_real_images + cfgs.RUN.k_nearest_neighbor + \
            cfgs.RUN.frequency_analysis + cfgs.RUN.tsne_analysis:
        if local_rank == 0:
            logger.info("Load {name} {ref} dataset.".format(
                name=cfgs.DATA.name, ref=cfgs.RUN.ref_dataset))
        eval_dataset = Dataset_(
            data_name=cfgs.DATA.name,
            data_dir=cfgs.RUN.data_dir,
            train=True if cfgs.RUN.ref_dataset == "train" else False,
            crop_long_edge=False
            if cfgs.DATA.name in cfgs.MISC.no_proc_data else True,
            resize_size=None if cfgs.DATA.name in cfgs.MISC.no_proc_data else
            cfgs.DATA.img_size,
            random_flip=False,
            hdf5_path=None,
            normalize=True,
            load_data_in_memory=False)
        if local_rank == 0:
            logger.info("Eval dataset size: {dataset_size}".format(
                dataset_size=len(eval_dataset)))
    else:
        eval_dataset = None

    # -----------------------------------------------------------------------------
    # define a distributed sampler for DDP train and evaluation.
    # define dataloaders for train and evaluation.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.distributed_data_parallel:
        cfgs.OPTIMIZATION.batch_size = cfgs.OPTIMIZATION.batch_size // cfgs.OPTIMIZATION.world_size

    if cfgs.RUN.train and cfgs.RUN.distributed_data_parallel:
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=cfgs.OPTIMIZATION.world_size,
            rank=local_rank,
            shuffle=True,
            drop_last=True)
        topk = cfgs.OPTIMIZATION.batch_size
    else:
        train_sampler = None
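    # one "basket" drawn from the train dataloader holds enough samples for a full generator update:
    # batch_size per step x gradient-accumulation steps x discriminator updates per step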
    cfgs.OPTIMIZATION.basket_size = cfgs.OPTIMIZATION.batch_size * cfgs.OPTIMIZATION.acml_steps * cfgs.OPTIMIZATION.d_updates_per_step

    if cfgs.RUN.train or cfgs.RUN.intra_class_fid or cfgs.RUN.GAN_train or cfgs.RUN.GAN_test:
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=cfgs.OPTIMIZATION.basket_size,
                                      shuffle=(train_sampler is None),
                                      pin_memory=True,
                                      num_workers=cfgs.RUN.num_workers,
                                      sampler=train_sampler,
                                      drop_last=True,
                                      persistent_workers=True)
    else:
        train_dataloader = None

    if len(cfgs.RUN.eval_metrics) + cfgs.RUN.save_real_images + cfgs.RUN.k_nearest_neighbor + \
            cfgs.RUN.frequency_analysis + cfgs.RUN.tsne_analysis:
        if cfgs.RUN.distributed_data_parallel:
            eval_sampler = DistributedSampler(
                eval_dataset,
                num_replicas=cfgs.OPTIMIZATION.world_size,
                rank=local_rank,
                shuffle=False,
                drop_last=False)
        else:
            eval_sampler = None

        eval_dataloader = DataLoader(dataset=eval_dataset,
                                     batch_size=cfgs.OPTIMIZATION.batch_size,
                                     shuffle=False,
                                     pin_memory=True,
                                     num_workers=cfgs.RUN.num_workers,
                                     sampler=eval_sampler,
                                     drop_last=False)
    else:
        eval_dataloader = None

    # -----------------------------------------------------------------------------
    # load a generator and a discriminator
    # if cfgs.MODEL.apply_g_ema is True, load an exponential moving average generator (Gen_copy).
    # -----------------------------------------------------------------------------
    Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis, ema =\
        model.load_generator_discriminator(DATA=cfgs.DATA,
                                           OPTIMIZATION=cfgs.OPTIMIZATION,
                                           MODEL=cfgs.MODEL,
                                           STYLEGAN=cfgs.STYLEGAN,
                                           MODULES=cfgs.MODULES,
                                           RUN=cfgs.RUN,
                                           device=local_rank,
                                           logger=logger)

    if local_rank != 0:
        custom_ops.verbosity = "none"

    # -----------------------------------------------------------------------------
    # define optimizers for adversarial training
    # -----------------------------------------------------------------------------
    cfgs.define_optimizer(Gen, Dis)

    # -----------------------------------------------------------------------------
    # load the generator and the discriminator from a checkpoint if possible
    # -----------------------------------------------------------------------------
    if cfgs.RUN.ckpt_dir is not None:
        if local_rank == 0:
            os.remove(join(cfgs.RUN.save_dir, "logs", run_name + ".log"))
        run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path, logger =\
            ckpt.load_StudioGAN_ckpts(ckpt_dir=cfgs.RUN.ckpt_dir,
                                      load_best=cfgs.RUN.load_best,
                                      Gen=Gen,
                                      Dis=Dis,
                                      g_optimizer=cfgs.OPTIMIZATION.g_optimizer,
                                      d_optimizer=cfgs.OPTIMIZATION.d_optimizer,
                                      run_name=run_name,
                                      apply_g_ema=cfgs.MODEL.apply_g_ema,
                                      Gen_ema=Gen_ema,
                                      ema=ema,
                                      is_train=cfgs.RUN.train,
                                      RUN=cfgs.RUN,
                                      logger=logger,
                                      global_rank=global_rank,
                                      device=local_rank,
                                      cfg_file=cfgs.RUN.cfg_file)

        if topk == "initialize":
            topk == cfgs.OPTIMIZATION.batch_size
        if cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]:
            ema.ema_rampup = "N/A"  # disable EMA rampup
            if cfgs.MODEL.backbone == "stylegan3" and cfgs.STYLEGAN.stylegan3_cfg == "stylegan3-r":
                cfgs.STYLEGAN.blur_init_sigma = "N/A"  # disable blur rampup
        if cfgs.AUG.apply_ada:
            cfgs.AUG.ada_kimg = 100  # make ADA react faster at the beginning

    if cfgs.RUN.ckpt_dir is None or cfgs.RUN.freezeD != -1:
        if local_rank == 0:
            cfgs.RUN.ckpt_dir = ckpt.make_ckpt_dir(
                join(cfgs.RUN.save_dir, "checkpoints", run_name))
        dict_dir = join(cfgs.RUN.save_dir, "statistics", run_name)
        loss_list_dict = misc.load_log_dicts(directory=dict_dir,
                                             file_name="losses.npy",
                                             ph=loss_list_dict)
        metric_dict_during_train = misc.load_log_dicts(
            directory=dict_dir,
            file_name="metrics.npy",
            ph=metric_dict_during_train)

    # -----------------------------------------------------------------------------
    # prepare parallel training
    # -----------------------------------------------------------------------------
    if cfgs.OPTIMIZATION.world_size > 1:
        Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis =\
        model.prepare_parallel_training(Gen=Gen,
                                        Gen_mapping=Gen_mapping,
                                        Gen_synthesis=Gen_synthesis,
                                        Dis=Dis,
                                        Gen_ema=Gen_ema,
                                        Gen_ema_mapping=Gen_ema_mapping,
                                        Gen_ema_synthesis=Gen_ema_synthesis,
                                        MODEL=cfgs.MODEL,
                                        world_size=cfgs.OPTIMIZATION.world_size,
                                        distributed_data_parallel=cfgs.RUN.distributed_data_parallel,
                                        synchronized_bn=cfgs.RUN.synchronized_bn,
                                        apply_g_ema=cfgs.MODEL.apply_g_ema,
                                        device=local_rank)

    # -----------------------------------------------------------------------------
    # load a pre-trained network (InceptionV3 or ResNet50 trained using SwAV)
    # -----------------------------------------------------------------------------
    if len(cfgs.RUN.eval_metrics) or cfgs.RUN.intra_class_fid:
        eval_model = pp.LoadEvalModel(
            eval_backbone=cfgs.RUN.eval_backbone,
            resize_fn=cfgs.RUN.resize_fn,
            world_size=cfgs.OPTIMIZATION.world_size,
            distributed_data_parallel=cfgs.RUN.distributed_data_parallel,
            device=local_rank)

    if "fid" in cfgs.RUN.eval_metrics:
        mu, sigma = pp.prepare_moments(data_loader=eval_dataloader,
                                       eval_model=eval_model,
                                       quantize=True,
                                       cfgs=cfgs,
                                       logger=logger,
                                       device=local_rank)

    if cfgs.RUN.is_ref_dataset:
        pp.calculate_ins(data_loader=eval_dataloader,
                         eval_model=eval_model,
                         quantize=True,
                         splits=1,
                         cfgs=cfgs,
                         logger=logger,
                         device=local_rank)

    # -----------------------------------------------------------------------------
    # initialize WORKER for training and evaluating GAN
    # -----------------------------------------------------------------------------
    worker = WORKER(
        cfgs=cfgs,
        run_name=run_name,
        Gen=Gen,
        Gen_mapping=Gen_mapping,
        Gen_synthesis=Gen_synthesis,
        Dis=Dis,
        Gen_ema=Gen_ema,
        Gen_ema_mapping=Gen_ema_mapping,
        Gen_ema_synthesis=Gen_ema_synthesis,
        ema=ema,
        eval_model=eval_model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        global_rank=global_rank,
        local_rank=local_rank,
        mu=mu,
        sigma=sigma,
        logger=logger,
        aa_p=aa_p,
        best_step=best_step,
        best_fid=best_fid,
        best_ckpt_path=best_ckpt_path,
        loss_list_dict=loss_list_dict,
        metric_dict_during_train=metric_dict_during_train,
    )

    # -----------------------------------------------------------------------------
    # train GAN until "total_steps" generator updates
    # -----------------------------------------------------------------------------
    if cfgs.RUN.train:
        if global_rank == 0:
            logger.info("Start training!")

        worker.training, worker.topk = True, topk
        worker.prepare_train_iter(epoch_counter=epoch)
        while step <= cfgs.OPTIMIZATION.total_steps:
            if cfgs.OPTIMIZATION.d_first:
                real_cond_loss, dis_acml_loss = worker.train_discriminator(
                    current_step=step)
                gen_acml_loss = worker.train_generator(current_step=step)
            else:
                gen_acml_loss = worker.train_generator(current_step=step)
                real_cond_loss, dis_acml_loss = worker.train_discriminator(
                    current_step=step)

            if global_rank == 0 and (step + 1) % cfgs.RUN.print_every == 0:

                worker.log_train_statistics(current_step=step,
                                            real_cond_loss=real_cond_loss,
                                            gen_acml_loss=gen_acml_loss,
                                            dis_acml_loss=dis_acml_loss)
            step += 1

            if cfgs.LOSS.apply_topk:
                if (epoch + 1) == worker.epoch_counter:
                    epoch += 1
                    worker.topk = losses.adjust_k(
                        current_k=worker.topk,
                        topk_gamma=cfgs.LOSS.topk_gamma,
                        sup_k=int(cfgs.OPTIMIZATION.batch_size *
                                  cfgs.LOSS.topk_nu))

            if step % cfgs.RUN.save_every == 0:
                # visualize fake images
                if global_rank == 0:
                    worker.visualize_fake_images(num_cols=num_cols,
                                                 current_step=step)

                # evaluate GAN for monitoring purpose
                if len(cfgs.RUN.eval_metrics):
                    is_best = worker.evaluate(step=step,
                                              metrics=cfgs.RUN.eval_metrics,
                                              writing=True,
                                              training=True)

                # save GAN in "./checkpoints/RUN_NAME/*"
                if global_rank == 0:
                    worker.save(step=step, is_best=is_best)

                # stop processes until all processes arrive
                if cfgs.RUN.distributed_data_parallel:
                    dist.barrier(worker.group)

        if global_rank == 0:
            logger.info("End of training!")

    # -----------------------------------------------------------------------------
    # re-evaluate the best GAN and conduct ordered analyses
    # -----------------------------------------------------------------------------
    print("")
    worker.training, worker.epoch_counter = False, epoch
    worker.gen_ctlr.standing_statistics = cfgs.RUN.standing_statistics
    worker.gen_ctlr.standing_max_batch = cfgs.RUN.standing_max_batch
    worker.gen_ctlr.standing_step = cfgs.RUN.standing_step

    if global_rank == 0:
        best_step = ckpt.load_best_model(ckpt_dir=cfgs.RUN.ckpt_dir,
                                         Gen=Gen,
                                         Dis=Dis,
                                         apply_g_ema=cfgs.MODEL.apply_g_ema,
                                         Gen_ema=Gen_ema,
                                         ema=ema)

    if len(cfgs.RUN.eval_metrics):
        for e in range(cfgs.RUN.num_eval):
            if global_rank == 0:
                print(""), logger.info("-" * 80)
            _ = worker.evaluate(step=best_step,
                                metrics=cfgs.RUN.eval_metrics,
                                writing=False,
                                training=False)

    if cfgs.RUN.save_real_images:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.save_real_images()

    if cfgs.RUN.save_fake_images:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.save_fake_images(num_images=cfgs.RUN.save_fake_images_num)

    if cfgs.RUN.vis_fake_images:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.visualize_fake_images(num_cols=num_cols, current_step=best_step)

    if cfgs.RUN.k_nearest_neighbor:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_k_nearest_neighbor(dataset=eval_dataset,
                                      num_rows=num_rows,
                                      num_cols=num_cols)

    if cfgs.RUN.interpolation:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_linear_interpolation(num_rows=num_rows,
                                        num_cols=num_cols,
                                        fix_z=True,
                                        fix_y=False)
        worker.run_linear_interpolation(num_rows=num_rows,
                                        num_cols=num_cols,
                                        fix_z=False,
                                        fix_y=True)

    if cfgs.RUN.frequency_analysis:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_frequency_analysis(dataloader=eval_dataloader)

    if cfgs.RUN.tsne_analysis:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_tsne(dataloader=eval_dataloader)

    if cfgs.RUN.intra_class_fid:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.calculate_intra_class_fid(dataset=train_dataset)

    if cfgs.RUN.semantic_factorization:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_semantic_factorization(
            num_rows=cfgs.RUN.num_semantic_axis,
            num_cols=num_cols,
            maximum_variations=cfgs.RUN.maximum_variations)
    if cfgs.RUN.GAN_train:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.compute_GAN_train_or_test_classifier_accuracy_score(
            GAN_train=True, GAN_test=False)

    if cfgs.RUN.GAN_test:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.compute_GAN_train_or_test_classifier_accuracy_score(
            GAN_train=False, GAN_test=True)

    if global_rank == 0:
        wandb.finish()
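
When `cfgs.LOSS.apply_topk` is set, the training loop above shrinks `worker.topk` once per epoch via `losses.adjust_k`, following the top-k GAN training recipe in which only the k highest-scoring fake samples contribute to the generator loss. A minimal sketch of such a decay rule (geometric shrinkage with a floor of `topk_nu * batch_size`; the repository's exact implementation may differ in details such as rounding) is:

def adjust_k(current_k, topk_gamma, sup_k):
    # shrink k geometrically each epoch, but never below the floor sup_k
    return max(current_k * topk_gamma, sup_k)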
Code example #9
0
def train_framework(dataset_name, architecture, num_classes, img_size, data_path, eval_dataset, hdf5_path_train, hdf5_path_valid, train_rate, auxiliary_classifier,
                    projection_discriminator, contrastive_training, hyper_dim, nonlinear_embed, normalize_embed, g_spectral_norm, d_spectral_norm, attention, reduce_class,
                    at_after_th_gen_block, at_after_th_dis_block, leaky_relu, g_init, d_init, latent_op, consistency_reg, make_positive_aug, synchronized_bn, ema,
                    ema_decay, ema_start, adv_loss, z_dim, shared_dim, g_conv_dim, d_conv_dim, batch_size, total_step, truncated_factor, prior, d_lr, g_lr,
                    beta1, beta2, batch4metrics, config, **_):

    fix_all_seed(config['seed'])
    cudnn.benchmark = True  # speeds up training when input sizes are fixed; not a good choice for variable input sizes
    cudnn.deterministic = False
    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()
    second_device = default_device if n_gpus == 1 else default_device+1
    assert batch_size % n_gpus == 0, "batch_size should be divisible by the number of GPUs"

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    start_step = 0
    best_val_fid, best_checkpoint_fid_path, best_val_is, best_checkpoint_is_path = None, None, None, None
    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config['config_path'].split('/')[3][:-5],
                             phase='train',
                             config=config)

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name, data_path, train=True, download=True, resize_size=img_size, hdf5_path=hdf5_path_train,
                                consistency_reg=consistency_reg, make_positive_aug=make_positive_aug)
    if train_rate < 1.0:
        num_train = int(train_rate*len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])

    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading valid datasets...')
    valid_dataset = LoadDataset(dataset_name, data_path, train=False, download=True, resize_size=img_size, hdf5_path=hdf5_path_valid)
    logger.info('Valid dataset size : {dataset_size}'.format(dataset_size=len(valid_dataset)))

    logger.info('Building model...')
    module = __import__('models.{architecture}'.format(architecture=architecture),fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=architecture))
    num_classes = int(reduce_class*num_classes)
    Gen = module.Generator(z_dim, shared_dim, g_conv_dim, g_spectral_norm, attention, at_after_th_gen_block, leaky_relu, auxiliary_classifier,
                           projection_discriminator, num_classes, contrastive_training, synchronized_bn, g_init).to(default_device)

    Dis = module.Discriminator(d_conv_dim, d_spectral_norm, attention, at_after_th_dis_block, leaky_relu, auxiliary_classifier, 
                               projection_discriminator, hyper_dim, num_classes, contrastive_training, nonlinear_embed, normalize_embed,
                               synchronized_bn, d_init).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(z_dim, shared_dim, g_conv_dim, g_spectral_norm, attention, at_after_th_gen_block, leaky_relu, auxiliary_classifier,
                                    projection_discriminator, num_classes, contrastive_training, synchronized_bn=False, initialize=False).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=second_device)
        Dis = DataParallel(Dis, output_device=second_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=second_device)
        if config['synchronized_bn']:
            patch_replication_callback(Gen)
            patch_replication_callback(Dis)

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)
    if reduce_class != 1.0:
        assert dataset_name == "TINY_ILSVRC2012" or "ILSVRC2012", "reduce_class mode can not be applied on the CIFAR10 dataset"
        n_train = int(reduce_class*len(train_dataset))
        n_valid = int(reduce_class*len(valid_dataset))
        train_weights = [1.0]*n_train + [0.0]*(len(train_dataset) - n_train)
        valid_weights = [1.0]*n_valid + [0.0]*(len(valid_dataset) - n_valid)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))
        valid_sampler = torch.utils.data.sampler.WeightedRandomSampler(valid_weights, len(valid_weights))
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size, sampler=train_sampler, shuffle=False,
                                      pin_memory=True, num_workers=config['num_workers'], drop_last=True)

        evaluation_dataloader = DataLoader(valid_dataset,
                                           sampler=valid_sampler, batch_size=batch4metrics, shuffle=False,
                                           pin_memory=True, num_workers=config['num_workers'], drop_last=False)
    else:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size, shuffle=True, pin_memory=True,
                                      num_workers=config['num_workers'], drop_last=True)

        evaluation_dataloader = DataLoader(valid_dataset,
                                           batch_size=batch4metrics, shuffle=True, pin_memory=True,
                                           num_workers=config['num_workers'], drop_last=False)

    G_loss = {'vanilla': loss_dcgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, [beta1, beta2])
    D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, [beta1, beta2])

    checkpoint_dir = make_checkpoint_dir(config['checkpoint_folder'], run_name, config)

    if config['checkpoint_folder'] is not None:
        logger = make_logger(run_name, config['log_output_path'])
        g_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,"model=G-step=" + str(config['step']) + "*.pth"))[0]
        d_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,"model=D-step=" + str(config['step']) + "*.pth"))[0]
        Gen, G_optimizer, seed, run_name, start_step = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, seed, run_name, start_step, best_val_fid, best_checkpoint_fid_path,\
        best_val_is, best_checkpoint_is_path = load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if ema:
            g_ema_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir, "model=G_ema-step=" + str(config['step']) + "*.pth"))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=ema)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        assert config['seed'] == seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))

    if config['eval']:
        inception_model = InceptionV3().to(default_device)
        inception_model = DataParallel(inception_model, output_device=second_device)
        mu, sigma, is_score, is_std = prepare_inception_moments_eval_dataset(dataloader=evaluation_dataloader,
                                                                            inception_model=inception_model,
                                                                            reduce_class=reduce_class,
                                                                            splits=10,
                                                                            logger=logger,
                                                                            device=second_device,
                                                                            eval_dataset=eval_dataset)
    else:
        mu, sigma, inception_model = None, None, None

    logger.info('Start training...')
    trainer = Trainer(
        run_name=run_name,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataloader=train_dataloader,
        evaluation_dataloader=evaluation_dataloader,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        auxiliary_classifier=auxiliary_classifier,
        contrastive_training=contrastive_training,
        contrastive_lambda=config['contrastive_lambda'],
        softmax_posterior=config['softmax_posterior'],
        contrastive_softmax=config['contrastive_softmax'],
        hyper_dim=config['hyper_dim'],
        tempering=config['tempering'],
        discrete_tempering=config['discrete_tempering'],
        tempering_times=config['tempering_times'],
        start_temperature=config['start_temperature'],
        end_temperature=config['end_temperature'],
        gradient_penalty_for_dis=config['gradient_penalty_for_dis'],
        lambda4lp=config['lambda4lp'],
        lambda4gp=config['lambda4gp'],
        weight_clipping_for_dis=config['weight_clipping_for_dis'],
        weight_clipping_bound=config['weight_clipping_bound'],
        latent_op=latent_op,
        latent_op_rate=config['latent_op_rate'],
        latent_op_step=config['latent_op_step'],
        latent_op_step4eval=config['latent_op_step4eval'],
        latent_op_alpha=config['latent_op_alpha'],
        latent_op_beta=config['latent_op_beta'],
        latent_norm_reg_weight=config['latent_norm_reg_weight'],
        consistency_reg=consistency_reg,
        consistency_lambda=config['consistency_lambda'],
        make_positive_aug=make_positive_aug,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        default_device=default_device,
        second_device=second_device,
        batch_size=batch_size,
        z_dim=z_dim,
        num_classes=num_classes,
        truncated_factor=truncated_factor,
        prior=prior,
        g_steps_per_iter=config['g_steps_per_iter'],
        d_steps_per_iter=config['d_steps_per_iter'],
        accumulation_steps=config['accumulation_steps'],
        lambda4ortho=config['lambda4ortho'],
        print_every=config['print_every'],
        save_every=config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=config['eval'],
        mu=mu,
        sigma=sigma,
        best_val_fid=best_val_fid,
        best_checkpoint_fid_path=best_checkpoint_fid_path,
        best_val_is=best_val_is,
        best_checkpoint_is_path=best_checkpoint_is_path,
        config=config,
    )

    if contrastive_training:
        trainer.run_ours(current_step=start_step, total_step=total_step)
    else:
        trainer.run(current_step=start_step, total_step=total_step)
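
Several of the examples above create an EMA helper with `ema_(Gen, Gen_copy, ema_decay, ema_start)` and later re-point `Gen_ema.source` / `Gen_ema.target` after loading a checkpoint. A minimal sketch consistent with that interface, assuming `source` is the live generator and `target` accumulates the averaged weights (the repository's class may handle buffers and the start iteration differently), is:

import torch

class ema_:
    """Minimal sketch of an exponential moving average over generator weights."""

    def __init__(self, source, target, decay=0.9999, start_iter=0):
        self.source, self.target = source, target
        self.decay, self.start_iter = decay, start_iter

    def update(self, iteration):
        # before start_iter simply copy the weights; afterwards blend them with the decay factor
        decay = 0.0 if iteration < self.start_iter else self.decay
        with torch.no_grad():
            src_params = dict(self.source.named_parameters())
            for name, tgt_param in self.target.named_parameters():
                tgt_param.copy_(decay * tgt_param + (1.0 - decay) * src_params[name])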