Example #1
    parser.add_argument('--mixing',  # flag name assumed from the help text; the excerpt begins mid-call
                        action='store_true',
                        help='use mixing regularization')
    parser.add_argument(
        '--loss',
        type=str,
        default='wgan-gp',
        choices=['wgan-gp', 'r1'],
        help='class of gan loss',
    )

    args = parser.parse_args()

    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)).cuda()
    # g_running holds an exponential moving average of the generator weights;
    # it is used only for evaluation and sampling, never trained directly.
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)

    g_optimizer = optim.Adam(generator.module.generator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))
    # The style (mapping) network trains at 1/100 of the base learning rate;
    # 'mult' records that factor so later schedule updates can preserve it.
    g_optimizer.add_param_group({
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    })
    d_optimizer = optim.Adam(discriminator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))

    accumulate(g_running, generator.module, 0)
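
All four examples initialize the running generator with `accumulate(g_running, generator.module, 0)`, but the helper itself is not part of the excerpts. A minimal sketch of the usual exponential-moving-average update, assuming both models expose matching named parameters (the body mirrors how the function is called here, not a confirmed implementation):

import torch

def accumulate(to_model, from_model, decay=0.999):
    # Update `to_model` as an exponential moving average of `from_model`.
    # With decay=0 this simply copies `from_model`, which is how the
    # examples initialize g_running before training starts.
    to_params = dict(to_model.named_parameters())
    from_params = dict(from_model.named_parameters())
    with torch.no_grad():
        for name, param in from_params.items():
            to_params[name].mul_(decay).add_(param, alpha=1 - decay)

During training the same call is typically repeated after every generator step with a decay close to 1 (e.g. 0.999), so `g_running` tracks a smoothed copy of the generator.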
Example #2
def get_model(model_name, config, iteration=None, restart=False, from_step=False, load_discriminator=True,
              alpha=1, step=6, resolution=256, used_samples=0):
    """
    Function that creates a model.
    Arguments:
        model_name -- name to use for save and load the model.
        config -- dict of model parameters.
        iteration -- iteration to load; last if None
        restart -- if true, than creates new model even there is a saved model with `model_name`.
    """
    LOGGER.info(f'Getting model "{model_name}"')
    code_size = config.get('code_size', constants.DEFAULT_CODE_SIZE)
    init_size = config.get('init_size', constants.INIT_SIZE)
    n_frames_params = config.get('n_frames_params', dict())
    n_frames = n_frames_params.get('n', 1)
    from_rgb_activate = config['from_rgb_activate']
    two_noises = n_frames_params.get('two_noises', False)
    lr = config.get('lr', constants.LR)
    dyn_style_coordinates = n_frames_params.get('dyn_style_coordinates', 0)

    generator = nn.DataParallel(StyledGenerator(code_size,
                                                two_noises=two_noises,
                                                dyn_style_coordinates=dyn_style_coordinates,
                                                )).cuda()
    g_running = StyledGenerator(code_size,
                                two_noises=two_noises,
                                dyn_style_coordinates=dyn_style_coordinates,
                                ).cuda()
    g_running.train(False)
    discriminator = nn.DataParallel(Discriminator(from_rgb_activate=from_rgb_activate)).cuda()
    n_frames_discriminator = nn.DataParallel(
        NFramesDiscriminator(from_rgb_activate=from_rgb_activate, n_frames=n_frames)
    ).cuda()

    if not restart:
        if iteration is None:
            model = get_last_model(model_name, from_step)
        else:
            iteration = str(iteration).zfill(6)
            checkpoint_path = os.path.join(constants.CHECKPOINT_DIR, model_name, f'{iteration}.model')
            LOGGER.info(f'Loading {checkpoint_path}')
            model = torch.load(checkpoint_path)
        generator.module.load_state_dict(model['generator'])
        g_running.load_state_dict(model['g_running'])
        if load_discriminator:
            discriminator.module.load_state_dict(model['discriminator'])
        if 'n_frames_params' in config:
            n_frames_discriminator.module.load_state_dict(model['n_frames_discriminator'])
        alpha = model['alpha']
        step = model['step']
        LOGGER.debug(f'Step: {step}')
        resolution = model['resolution']
        used_samples = model['used_samples']
        LOGGER.debug(f'Used samples: {used_samples}.')
        iteration = model['iteration']
    else:
        alpha = 0
        step = int(math.log2(init_size)) - 2
        resolution = 4 * 2 ** step
        used_samples = 0
        iteration = 0
        accumulate(to_model=g_running, from_model=generator.module, decay=0)

    g_optimizer = optim.Adam(
        generator.module.generator.parameters(),
        lr=lr[resolution], betas=(0.0, 0.99)
    )

    style_module = generator.module
    style_params = list(style_module.style.parameters())
    g_optimizer.add_param_group(
        {
            'params': style_params,
            'lr': lr[resolution] * 0.01,
            'mult': 0.01,
        }
    )

    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr[resolution], betas=(0.0, 0.99))
    nfd_optimizer = optim.Adam(n_frames_discriminator.parameters(), lr=lr[resolution], betas=(0.0, 0.99))

    if not restart:
        g_optimizer.load_state_dict(model['g_optimizer'])
        d_optimizer.load_state_dict(model['d_optimizer'])
        nfd_optimizer.load_state_dict(model['nfd_optimizer'])

    return EasyDict(
        generator=generator,
        discriminator=discriminator,
        n_frames_discriminator=n_frames_discriminator,
        g_running=g_running,
        g_optimizer=g_optimizer,
        d_optimizer=d_optimizer,
        nfd_optimizer=nfd_optimizer,
        alpha=alpha,
        step=step,
        resolution=resolution,
        used_samples=used_samples,
        iteration=iteration,
    )
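
A hypothetical usage sketch for `get_model`: the config keys below are exactly the ones the function reads, while the numeric values, the `constants` fallbacks, and the checkpoint layout are assumptions. Note that `lr` must map each resolution to a learning rate, since the optimizers index it as `lr[resolution]`:

config = {
    'from_rgb_activate': True,   # required key
    'code_size': 512,            # else constants.DEFAULT_CODE_SIZE is used
    'init_size': 8,              # else constants.INIT_SIZE is used
    # Per-resolution learning rates; lr[resolution] is looked up directly.
    'lr': {8: 0.001, 16: 0.001, 32: 0.001, 64: 0.001, 128: 0.0015, 256: 0.002},
    'n_frames_params': {'n': 2, 'two_noises': False, 'dyn_style_coordinates': 0},
}

env = get_model('my_model', config, restart=True)  # fresh model, step derived from init_size
# ... train, save checkpoints ...
env = get_model('my_model', config)                # resume from the last checkpoint
print(env.resolution, env.step, env.used_samples)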
Example #3
    dataset = MultiResolutionDataset(f'./dataset/{args.dataset}_lmdb',
                                     transform,
                                     resolution=args.image_size)

    ### load G and D ###

    if args.supervised:
        G_target = nn.DataParallel(
            StyledGenerator(code_size,
                            dataset_size=len(dataset),
                            embed_dim=code_size)).cuda()
        G_running_target = StyledGenerator(code_size,
                                           dataset_size=len(dataset),
                                           embed_dim=code_size).cuda()
        G_running_target.train(False)
        accumulate(G_running_target, G_target.module, 0)
    else:
        G_target = nn.DataParallel(StyledGenerator(code_size)).cuda()
        D_target = nn.DataParallel(
            Discriminator(from_rgb_activate=True)).cuda()
        G_running_target = StyledGenerator(code_size).cuda()
        G_running_target.train(False)
        accumulate(G_running_target, G_target.module, 0)

        G_source = nn.DataParallel(StyledGenerator(code_size)).cuda()
        D_source = nn.DataParallel(
            Discriminator(from_rgb_activate=True)).cuda()
        requires_grad(G_source, False)
        requires_grad(D_source, False)
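
`requires_grad` is used above to freeze the source networks but is not defined in the excerpt. The conventional helper just toggles gradient tracking on every parameter (a sketch, assuming no buffers need special handling):

def requires_grad(model, flag=True):
    # Enable or disable gradient computation for all parameters of `model`.
    for p in model.parameters():
        p.requires_grad = flag

Freezing `G_source` and `D_source` this way keeps them as fixed references while only the target networks receive gradient updates.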
Example #4
def main(args, myargs):
    code_size = 512
    batch_size = 16
    n_critic = 1

    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)).cuda()
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)

    g_optimizer = optim.Adam(generator.module.generator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))
    g_optimizer.add_param_group({
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    })
    d_optimizer = optim.Adam(discriminator.parameters(),
                             lr=args.lr,
                             betas=(0.0, 0.99))

    accumulate(g_running, generator.module, 0)

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)

        generator.module.load_state_dict(ckpt['generator'])
        discriminator.module.load_state_dict(ckpt['discriminator'])
        g_running.load_state_dict(ckpt['g_running'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path, transform)

    if args.sched:
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {
            4: 512,
            8: 256,
            16: 128,
            32: 64,
            64: 32,
            128: 32,
            256: 32
        }

    else:
        args.lr = {}
        args.batch = {}

    args.gen_sample = {512: (8, 4), 1024: (4, 2)}

    args.batch_default = 32

    train(args,
          dataset,
          generator,
          discriminator,
          g_optimizer=g_optimizer,
          d_optimizer=d_optimizer,
          g_running=g_running,
          code_size=code_size,
          n_critic=n_critic,
          myargs=myargs)
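
The `args.lr` and `args.batch` dicts built under `args.sched` are keyed by image resolution, and the `'mult'` entry attached to the style parameter group above only matters if the learning-rate schedule respects it. Inside `train()` (not shown in the excerpt) the dicts are typically consumed like this; `adjust_lr` and the fallback values are assumptions modeled on common progressive-growing training loops:

def adjust_lr(optimizer, lr):
    # Rescale every param group, preserving per-group factors such as the
    # 0.01x multiplier attached to the style mapping network.
    for group in optimizer.param_groups:
        group['lr'] = lr * group.get('mult', 1)

# On each resolution switch during progressive growing:
resolution = 4 * 2 ** step
adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
batch_size = args.batch.get(resolution, args.batch_default)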