Code example #1
File: train.py  Project: ErrorInever/WGAN-GP
def train_one_epoch(epoch,
                    dataloader,
                    gen,
                    critic,
                    opt_gen,
                    opt_critic,
                    fixed_noise,
                    device,
                    metric_logger,
                    num_samples,
                    freq=100):
    """
    Train one epoch
    :param epoch: ``int`` current epoch
    :param dataloader: object of dataloader
    :param gen: Generator model
    :param critic: Discriminator model
    :param opt_gen: Optimizer for generator
    :param opt_critic: Optimizer for discriminator
    :param fixed_noise: ``tensor[[cfg.BATCH_SIZE, latent_space_dimension, 1, 1]]`` fixed noise (latent space) for image metrics
    :param device: cuda device or cpu
    :param metric_logger: object of MetricLogger
    :param num_samples: ``int`` well retrievable sqrt() (for example: 4, 16, 64) for good result,
    number of samples for grid image metric
    :param freq: ``int``, freq < len(dataloader)`` freq for display results
    """
    for batch_idx, img in enumerate(dataloader):
        real = img.to(device)
        # Train critic
        acc_real = 0
        acc_fake = 0
        for _ in range(cfg.CRITIC_ITERATIONS):
            noise = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            acc_real += critic_real.detach()  # detach so the running stat does not retain the graph
            critic_fake = critic(fake).reshape(-1)
            acc_fake += critic_fake.detach()
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = -1 * (torch.mean(critic_real) -
                                torch.mean(critic_fake)) + cfg.LAMBDA_GP * gp
            critic.zero_grad()
            # retain the graph so the generator update below can still backprop through `fake`
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        acc_real = acc_real / cfg.CRITIC_ITERATIONS
        acc_fake = acc_fake / cfg.CRITIC_ITERATIONS

        # Train generator: minimize -E[critic(gen_fake)]
        output = critic(fake).reshape(-1)
        loss_gen = -1 * torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # log metrics
        if batch_idx % freq == 0:
            with torch.no_grad():
                metric_logger.log(loss_critic, loss_gen, acc_real, acc_fake)
                fake = gen(fixed_noise)
                metric_logger.log_image(fake, num_samples, epoch, batch_idx,
                                        len(dataloader))
                metric_logger.display_status(epoch, cfg.NUM_EPOCHS, batch_idx,
                                             len(dataloader), loss_critic,
                                             loss_gen, acc_real, acc_fake)
Code example #2
File: train.py  Project: ErrorInever/WGAN-GP
        cfg.NUM_EPOCHS = end_epoch
    else:
        print("=> Init default weights of models and fixed noise")
        # FIXME: initializing the weights from a normal distribution can sometimes (usually) cause mode collapse
        # init_weights(gen)
        # init_weights(disc)
        # defining optimizers after init weights
        opt_gen = optim.Adam(gen.parameters(),
                             lr=cfg.LEARNING_RATE,
                             betas=(0.5, 0.999))
        opt_critic = optim.Adam(critic.parameters(),
                                lr=cfg.LEARNING_RATE,
                                betas=(0.5, 0.999))
        start_epoch = 1
        end_epoch = cfg.NUM_EPOCHS
        fixed_noise = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device)

    if args.resume_id:
        metric_logger = MetricLogger(cfg.PROJECT_VERSION_NAME,
                                     resume_id=args.resume_id)
    else:
        metric_logger = MetricLogger(cfg.PROJECT_VERSION_NAME)

    # gradient metrics
    wandb.watch(gen)
    wandb.watch(critic)
    # model mode
    gen.train()
    critic.train()

    start_time = time.time()
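The commented-out `init_weights` calls above refer to a helper defined elsewhere in the project. A common DCGAN-style implementation (an assumption; the repository's actual helper may differ) draws conv and batch-norm weights from N(0, 0.02):

import torch.nn as nn


def init_weights(model, mean=0.0, std=0.02):
    """DCGAN-style init: conv weights from N(mean, std), batch-norm scale from N(1, std)."""
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, mean, std)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, std)
            nn.init.constant_(m.bias.data, 0)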
Code example #3
def train_one_epoch(epoch,
                    dataloader,
                    gen,
                    disc,
                    criterion,
                    opt_gen,
                    opt_disc,
                    fixed_noise,
                    device,
                    metric_logger,
                    num_samples,
                    freq=100):
    """
    Train one epoch
    :param epoch: ``int`` current epoch
    :param dataloader: object of dataloader
    :param gen: Generator model
    :param disc: Discriminator model
    :param criterion: Loss function (for this case: binary cross entropy)
    :param opt_gen: Optimizer for generator
    :param opt_disc: Optimizer for discriminator
    :param fixed_noise: ``tensor[[cfg.BATCH_SIZE, latent_space_dimension, 1, 1]]``
    fixed noise (latent space) for image metrics
    :param device: cuda device or cpu
    :param metric_logger: object of MetricLogger
    :param num_samples: ``int`` well retrievable sqrt() (for example: 4, 16, 64) for good result,
    number of samples for grid image metric
    :param freq: ``int``, must be < len(dataloader)`` freq for display results
    """
    for batch_idx, img in enumerate(dataloader):
        real = img.to(device)
        noise = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device)
        fake = gen(noise)

        # Train discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train generator: minimize log(1 - D(G(z))), which is equivalent to maximizing log(D(G(z)))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # log metrics
        if batch_idx % freq == 0:
            with torch.no_grad():
                metric_logger.log(loss_disc, loss_gen, disc_real, disc_fake)
                fake = gen(fixed_noise)
                metric_logger.log_image(fake, num_samples, epoch, batch_idx,
                                        len(dataloader))
                metric_logger.display_status(epoch, cfg.NUM_EPOCHS, batch_idx,
                                             len(dataloader), loss_disc,
                                             loss_gen, disc_real, disc_fake)
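Both training loops rely on `get_random_noise`. A minimal sketch consistent with how it is called here and with the ``[N, z, 1, 1]`` noise shape the docstrings describe (the helper's real body may differ):

import torch


def get_random_noise(batch_size, z_dim, device):
    """Sample latent vectors shaped [batch_size, z_dim, 1, 1] for a convolutional generator."""
    return torch.randn(batch_size, z_dim, 1, 1, device=device)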
Code example #4
        out_path = 'DCGAN-Anime-Faces'
    else:
        out_path = args.out_path

    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        # fall back to CPU so `device` is always defined, even if CUDA was requested but unavailable
        device = torch.device('cpu')

    gen = Generator(128, 3, 64)  # likely (z_dim, image channels, feature-map width)
    load_gen(gen, args.path_ckpt, device)
    gen.eval()

    if args.grid:
        noise = get_random_noise(args.num_samples, args.z_size, device)
        print("==> Generate IMAGE GRID...")
        output = gen(noise)
        show_batch(output, out_path, num_samples=args.num_samples, figsize=(args.img_size, args.img_size))
    elif args.gif:
        noise = get_random_noise(args.num_samples, args.z_size, device)
        print("==> Generate GIF...")
        images = latent_space_interpolation_sequence(noise, step_interpolation=args.steps)
        output = gen(images)
        if args.resize and isinstance(args.resize, int):
            print(f"==> Resize images to {args.resize}px")
            output = F.interpolate(output, size=args.resize)

        images = []
        for img in output:
            img = img.detach().permute(1, 2, 0)
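`latent_space_interpolation_sequence` is defined elsewhere in the project. A minimal sketch that matches this call, assuming straight linear interpolation between consecutive noise vectors with `step_interpolation` frames per pair (the project's actual scheme, e.g. spherical interpolation, may differ):

import torch


def latent_space_interpolation_sequence(noise, step_interpolation=10):
    """Linearly interpolate between consecutive latent vectors.
    noise: tensor [N, z, 1, 1]; returns tensor [(N - 1) * step_interpolation, z, 1, 1]."""
    frames = []
    for i in range(len(noise) - 1):
        start, end = noise[i], noise[i + 1]
        for t in torch.linspace(0, 1, step_interpolation):
            frames.append(start * (1 - t) + end * t)
    return torch.stack(frames)

Feeding the stacked frames through the generator in one batch, as the snippet does, keeps the GIF frames consistent with a single forward pass.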
Code example #5
                         betas=(0.0, 0.99))
    opt_dis = optim.Adam(params=dis.parameters(),
                         lr=cfg.LEARNING_RATE,
                         betas=(0.0, 0.99))
    # defining gradient scalers for automatic mixed precision
    scaler_gen = torch.cuda.amp.GradScaler()
    scaler_dis = torch.cuda.amp.GradScaler()

    if args.checkpoint:
        fixed_noise, cfg.START_EPOCH = load_checkpoint(args.checkpoint, gen,
                                                       opt_gen, scaler_gen,
                                                       dis, opt_dis,
                                                       scaler_dis,
                                                       cfg.LEARNING_RATE)
    else:
        fixed_noise = get_random_noise(cfg.FIXED_NOISE_SAMPLES,
                                       cfg.Z_DIMENSION, device)
        # logger.info("load weights from normal distribution")
        # init_weights(gen)
        # init_weights(dis)

    gen.train()
    dis.train()

    metric_logger = MetricLogger(cfg.PROJECT_VERSION_NAME)

    for epoch in range(cfg.START_EPOCH, cfg.END_EPOCH):
        if args.wgp:
            train_one_epoch_with_gp(gen,
                                    opt_gen,
                                    dis,
                                    opt_dis,
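The call to `train_one_epoch_with_gp` is truncated here, but it presumably consumes the two `GradScaler` objects defined above. A minimal sketch of one mixed-precision critic update showing the intended scaler pattern (an illustration of how `torch.cuda.amp` is typically wired into a WGAN-GP step, not the project's exact loop; it reuses the hypothetical `gradient_penalty` sketched earlier):

import torch


def critic_step_amp(critic, gen, opt_dis, scaler_dis, real, noise, lambda_gp):
    """One WGAN-GP critic update under automatic mixed precision."""
    opt_dis.zero_grad()
    with torch.cuda.amp.autocast():
        fake = gen(noise)
        critic_real = critic(real).reshape(-1)
        critic_fake = critic(fake).reshape(-1)
    # the gradient penalty is often computed outside autocast for numerical stability
    gp = gradient_penalty(critic, real, fake, device=real.device)
    loss = -(critic_real.mean() - critic_fake.mean()) + lambda_gp * gp
    scaler_dis.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler_dis.step(opt_dis)            # unscale gradients, then take the optimizer step
    scaler_dis.update()                 # adjust the scale factor for the next iteration
    return loss.detach()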