コード例 #1
0
ファイル: train.py プロジェクト: momciloknezevic7/RI-GAN
                            logging.StreamHandler()
                        ])

    # fetching device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.debug(f"{device}")

    # training configuration
    with open("train_config.json", "r") as f:
        train_config = json.load(f)
        args = Namespace(**train_config)

    # initializing networks and optimizers
    if args.type == "DCGAN":
        G, D = utils.get_gan(GANType.DCGAN, device)
        G_optim, D_optim = utils.get_optimizers(G, D)
    elif args.type == "SN_DCGAN":
        G, D = utils.get_gan(GANType.SN_DCGAN, device, args.n_power_iterations)
        G_optim, D_optim = utils.get_optimizers(G, D)

    # initializing loader for data
    data_loader = utils.get_data_loader(args.batch_size, args.img_size)

    # setting up loss and GT
    adversarial_loss = nn.BCELoss()
    real_gt, fake_gt = utils.get_gt(args.batch_size, device)

    # for logging
    log_batch_size = 25
    log_noise = utils.get_latent_batch(log_batch_size, device)
    D_loss_values, G_loss_values = [], []
コード例 #2
0
def train_gan(training_config):
    """Train the classic GAN on MNIST.

    The discriminator and the generator are updated alternately once per
    batch, using binary cross-entropy as the adversarial loss. Progress is
    reported to the console, optionally to TensorBoard, and as image grids
    saved under ``training_config['debug_path']``. A generator checkpoint is
    written to ``CHECKPOINTS_PATH`` at the first batch of every
    ``checkpoint_freq``-th epoch, and the final generator to ``BINARIES_PATH``.

    Args:
        training_config: dict with the keys ``batch_size``, ``num_epochs``,
            ``enable_tensorboard``, ``debug_imagery_log_freq``,
            ``console_log_freq``, ``checkpoint_freq`` and ``debug_path``.
            Each ``*_freq`` value may be ``None`` to disable that output.
    """
    writer = SummaryWriter()
    device = torch.device("cpu")

    batch_size = training_config['batch_size']
    debug_imagery_log_freq = training_config['debug_imagery_log_freq']
    console_log_freq = training_config['console_log_freq']
    checkpoint_freq = training_config['checkpoint_freq']

    # Download MNIST dataset in the directory data
    mnist_data_loader = utils.get_mnist_data_loader(batch_size)
    num_batches = len(mnist_data_loader)

    discriminator_net, generator_net = utils.get_gan(device,
                                                     GANType.CLASSIC.name)
    discriminator_opt, generator_opt = utils.get_optimizers(
        discriminator_net, generator_net)

    # BCE with all-ones targets is -log(x); with all-zeros targets, -log(1 - x).
    adversarial_loss = nn.BCELoss()
    real_image_gt = torch.ones((batch_size, 1), device=device)
    fake_image_gt = torch.zeros((batch_size, 1), device=device)

    # Fixed latent batch so that debug imagery is comparable across steps.
    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)
    ref_grid_nrow = int(np.sqrt(ref_batch_size))
    # nn.Upsample holds no parameters, so one instance can be reused.
    upsample = nn.Upsample(scale_factor=2, mode='nearest')
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()

    utils.print_training_info_to_console(training_config)
    try:
        for epoch in range(training_config['num_epochs']):
            for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
                real_images = real_images.to(device)
                global_step = num_batches * epoch + batch_idx + 1

                # Train discriminator
                discriminator_opt.zero_grad()

                # Slice the targets so that a smaller final batch from the
                # loader does not cause a BCELoss input/target size mismatch.
                real_discriminator_loss = adversarial_loss(
                    discriminator_net(real_images),
                    real_image_gt[:real_images.size(0)])

                fake_images = generator_net(
                    utils.get_gaussian_latent_batch(batch_size, device))
                # detach(): the discriminator step must not backpropagate into G.
                fake_images_predictions = discriminator_net(fake_images.detach())
                fake_discriminator_loss = adversarial_loss(
                    fake_images_predictions, fake_image_gt)

                discriminator_loss = real_discriminator_loss + fake_discriminator_loss
                discriminator_loss.backward()
                discriminator_opt.step()

                # Train generator: fakes are scored against the "real" targets
                # so that G is pushed to fool the discriminator.
                generator_opt.zero_grad()

                generated_images_prediction = discriminator_net(
                    generator_net(
                        utils.get_gaussian_latent_batch(batch_size, device)))

                generator_loss = adversarial_loss(generated_images_prediction,
                                                  real_image_gt)

                generator_loss.backward()
                generator_opt.step()

                # Logging and checkpoint creation
                generator_loss_values.append(generator_loss.item())
                discriminator_loss_values.append(discriminator_loss.item())

                if training_config['enable_tensorboard']:
                    writer.add_scalars(
                        'Losses/g-and-d', {
                            'g': generator_loss.item(),
                            'd': discriminator_loss.item()
                        }, global_step)

                if console_log_freq is not None and batch_idx % console_log_freq == 0:
                    print(
                        f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{num_batches}]'
                    )

                if debug_imagery_log_freq is not None and batch_idx % debug_imagery_log_freq == 0:
                    # A single generator pass feeds both TensorBoard and disk.
                    with torch.no_grad():
                        log_generated_images_resized = upsample(
                            generator_net(ref_noise_batch))
                    if training_config['enable_tensorboard']:
                        intermediate_imagery_grid = make_grid(
                            log_generated_images_resized,
                            nrow=ref_grid_nrow,
                            normalize=True)
                        writer.add_image('intermediate generated imagery',
                                         intermediate_imagery_grid,
                                         global_step)
                    # Save intermediate generator images
                    save_image(log_generated_images_resized,
                               os.path.join(training_config['debug_path'],
                                            f'{str(img_cnt).zfill(6)}.jpg'),
                               nrow=ref_grid_nrow,
                               normalize=True)
                    img_cnt += 1

                # Save generator checkpoint (at the first batch of the epoch)
                if checkpoint_freq is not None and (
                        epoch + 1) % checkpoint_freq == 0 and batch_idx == 0:
                    ckpt_model_name = f"Classic_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                    torch.save(
                        utils.get_training_state(generator_net,
                                                 GANType.CLASSIC.name),
                        os.path.join(CHECKPOINTS_PATH, ckpt_model_name))
    finally:
        # Flush pending TensorBoard events and release the writer's files,
        # even if training is interrupted.
        writer.close()

    torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name),
               os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
コード例 #3
0
def train(cfg):
    '''
    Main training loop.

    Loads the data loaders, model, optimizer, loss function and an optional
    checkpoint described by ``cfg``, then trains for
    ``cfg['train']['num-epochs']`` epochs.
    - Every step's loss is printed.
    - Every ``cfg['train']['step-log']``-th step's loss is appended to
      ``losses.txt`` inside ``cfg['save-path']``.
    - One validation batch is evaluated every ``cfg['peek-validation']``
      steps.
    - ``utils.save_ckpt`` is called after every step.

    Args:
        cfg: nested configuration dict; it is also forwarded unchanged to
            the ``utils`` helpers.
    '''
    # Function-call print with a single argument behaves the same on
    # Python 2 and Python 3, unlike the Python-2-only print statement.
    print(json.dumps(cfg, sort_keys=True, indent=4))

    use_cuda = cfg['use-cuda']

    def _to_device(batch):
        # Wrap every entry of the batch dict in a Variable and move it to the
        # GPU when use_cuda is set. Updates the dict in place and returns it.
        for key, value in batch.items():
            batch[key] = Variable(value)
            if use_cuda:
                batch[key] = batch[key].cuda()
        return batch

    _, _, train_dl, val_dl = utils.get_data_loaders(cfg)

    model = utils.get_model(cfg)
    if use_cuda:
        model = model.cuda()
    model = utils.init_weights(model, cfg)

    # Get pretrained models, optimizers and loss functions
    optim = utils.get_optimizers(model, cfg)
    model, optim, metadata = utils.load_ckpt(model, optim, cfg)
    loss_fn = utils.get_losses(cfg)

    # Draw a fresh seed and start at step 0 unless resuming from a checkpoint,
    # in which case the stored seed and step counter are reused.
    seed = np.random.randint(2**32)
    ckpt = 0
    if metadata is not None:
        seed = metadata['seed']
        ckpt = metadata['ckpt']

    # Build schedulers only after loading the checkpoint; they receive the
    # (possibly resumed) step counter.
    scheduler = utils.get_schedulers(optim, cfg, ckpt)
    # Print optimizer state
    print(optim)

    if not os.path.exists(cfg['save-path']):
        os.makedirs(cfg['save-path'])

    # Seed every RNG with the (possibly restored) seed.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    num_epochs = cfg['train']['num-epochs']
    # `with` guarantees the loss log is closed even if training raises.
    with open(os.path.join(cfg['save-path'], 'losses.txt'), 'a+') as lossesfile:
        for epoch in range(num_epochs):
            model.train()
            for data in train_dl:
                optim.zero_grad()
                data = _to_device(data)

                outputs = model(data)
                loss_val = loss_fn(outputs, data, cfg)
                loss_np = loss_val.data.cpu().numpy()

                print('Epoch: {}, step: {}, loss: {}'.format(
                    epoch, ckpt, loss_np))

                # Log into the file every `step-log` steps
                if ckpt % cfg['train']['step-log'] == 0:
                    lossesfile.write('Epoch: {}, step: {}, loss: {}\n'.format(
                        epoch, ckpt, loss_np))

                loss_val.backward()
                optim.step()
                scheduler.step()

                ckpt += 1
                # Peek at a single validation batch every `peek-validation` steps
                if ckpt % cfg['peek-validation'] == 0:
                    model.eval()
                    with torch.no_grad():
                        for val_data in val_dl:
                            val_data = _to_device(val_data)
                            outputs = model(val_data)
                            loss_val = loss_fn(outputs, val_data, cfg)
                            val_loss_np = loss_val.data.cpu().numpy()

                            print('Validation loss: {}'.format(val_loss_np))
                            lossesfile.write(
                                'Validation loss: {}\n'.format(val_loss_np))
                            utils.save_images(val_data, outputs, cfg, ckpt)
                            break
                    model.train()

                # Save checkpoint
                utils.save_ckpt((model, optim), cfg, ckpt, seed)
コード例 #4
0
def train_vanilla_gan(training_config):
    """Train a vanilla GAN on MNIST.

    The discriminator and the generator are updated alternately once per
    batch, using binary cross-entropy as the adversarial loss. Progress is
    reported to the console, optionally to TensorBoard, and as image grids
    saved under ``training_config['debug_path']``. A generator checkpoint is
    written to ``CHECKPOINTS_PATH`` at the first batch of every
    ``checkpoint_freq``-th epoch, and the final generator to ``BINARIES_PATH``.

    Args:
        training_config: dict with the keys ``batch_size``, ``num_epochs``,
            ``enable_tensorboard``, ``debug_imagery_log_freq``,
            ``console_log_freq``, ``checkpoint_freq`` and ``debug_path``.
            Each ``*_freq`` value may be ``None`` to disable that output.
    """
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU

    batch_size = training_config['batch_size']
    debug_imagery_log_freq = training_config['debug_imagery_log_freq']
    console_log_freq = training_config['console_log_freq']
    checkpoint_freq = training_config['checkpoint_freq']

    # Prepare MNIST data loader (it will download MNIST the first time you run it)
    mnist_data_loader = utils.get_mnist_data_loader(batch_size)
    num_batches = len(mnist_data_loader)

    # Fetch feed-forward nets (place them on GPU if present) and optimizers which will tweak their weights
    discriminator_net, generator_net = utils.get_gan(device, GANType.VANILLA.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    # 1s will configure BCELoss into -log(x) whereas 0s will configure it to -log(1-x)
    # So that means we can effectively use binary cross-entropy loss to achieve adversarial loss!
    adversarial_loss = nn.BCELoss()
    real_images_gt = torch.ones((batch_size, 1), device=device)
    fake_images_gt = torch.zeros((batch_size, 1), device=device)

    # For logging purposes
    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)  # Track G's quality during training
    ref_grid_nrow = int(np.sqrt(ref_batch_size))
    # nn.Upsample holds no parameters, so build each once; TensorBoard and disk use different scales
    tensorboard_upsample = nn.Upsample(scale_factor=2, mode='nearest')
    disk_upsample = nn.Upsample(scale_factor=2.5, mode='nearest')
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()  # start measuring time

    # GAN training loop: on every batch the discriminator is updated before the generator
    utils.print_training_info_to_console(training_config)
    try:
        for epoch in range(training_config['num_epochs']):
            for batch_idx, (real_images, _) in enumerate(mnist_data_loader):

                real_images = real_images.to(device)  # Place imagery on GPU (if present)
                global_step = num_batches * epoch + batch_idx + 1

                #
                # Train discriminator: maximize V = log(D(x)) + log(1-D(G(z))) or equivalently minimize -V
                # Note: D = discriminator, x = real images, G = generator, z = latent Gaussian vectors, G(z) = fake images
                #

                # Zero out .grad variables in discriminator network (otherwise we would have corrupt results)
                discriminator_opt.zero_grad()

                # -log(D(x)) <- we minimize this by making D(x)/discriminator_net(real_images) as close to 1 as possible
                # Targets are sliced so a smaller final batch from the loader doesn't cause a BCELoss size mismatch
                real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt[:real_images.size(0)])

                # G(z) | G == generator_net and z == utils.get_gaussian_latent_batch(batch_size, device)
                fake_images = generator_net(utils.get_gaussian_latent_batch(batch_size, device))
                # D(G(z)), we call detach() so that we don't calculate gradients for the generator during backward()
                fake_images_predictions = discriminator_net(fake_images.detach())
                # -log(1 - D(G(z))) <- we minimize this by making D(G(z)) as close to 0 as possible
                fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

                discriminator_loss = real_discriminator_loss + fake_discriminator_loss
                discriminator_loss.backward()  # this will populate .grad vars in the discriminator net
                discriminator_opt.step()  # perform D weights update according to optimizer's strategy

                #
                # Train generator: minimize V1 = log(1-D(G(z))) or equivalently maximize V2 = log(D(G(z))) (or min of -V2)
                # The original expression (V1) had problems with diminishing gradients for G when D is too good.
                #

                # if you want to cause mode collapse probably the easiest way to do that would be to add "for i in range(n)"
                # here (simply train G more frequent than D), n = 10 worked for me other values will also work - experiment.

                # Zero out .grad variables in the generator network (otherwise we would have corrupt results)
                generator_opt.zero_grad()

                # D(G(z)) (see above for explanations)
                generated_images_predictions = discriminator_net(generator_net(utils.get_gaussian_latent_batch(batch_size, device)))
                # By placing real_images_gt here we minimize -log(D(G(z))) which happens when D approaches 1
                # i.e. we're tricking D into thinking that these generated images are real!
                generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

                generator_loss.backward()  # this will populate .grad vars in the G net (also in D but we won't use those)
                generator_opt.step()  # perform G weights update according to optimizer's strategy

                #
                # Logging and checkpoint creation
                #

                generator_loss_values.append(generator_loss.item())
                discriminator_loss_values.append(discriminator_loss.item())

                if training_config['enable_tensorboard']:
                    writer.add_scalars('losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, global_step)

                if console_log_freq is not None and batch_idx % console_log_freq == 0:
                    print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{num_batches}]')

                if debug_imagery_log_freq is not None and batch_idx % debug_imagery_log_freq == 0:
                    # A single generator pass on the fixed noise feeds both outputs below
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                    # Save debug imagery to tensorboard also (some redundancy but it may be more beginner-friendly)
                    if training_config['enable_tensorboard']:
                        intermediate_imagery_grid = make_grid(tensorboard_upsample(log_generated_images), nrow=ref_grid_nrow, normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, global_step)
                    # Save intermediate generator images (more convenient like this than through tensorboard)
                    log_generated_images_resized = disk_upsample(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=ref_grid_nrow, normalize=True)
                    img_cnt += 1

                # Save generator checkpoint (at the first batch of the epoch)
                if checkpoint_freq is not None and (epoch + 1) % checkpoint_freq == 0 and batch_idx == 0:
                    ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                    torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))
    finally:
        # Flush pending TensorBoard events and release the writer's files, even if training is interrupted
        writer.close()

    # Save the latest generator in the binaries directory
    torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
コード例 #5
0
def validate(cfg):
    '''
    Main validation loop.

    Loads the validation data loader, the model and its checkpoint as
    described by ``cfg``. It evaluates every validation batch under a fixed
    random seed (42) so runs are reproducible, prints the per-step loss, and
    then prints summary statistics (mean, std, 25th/50th/75th percentiles).
    When ``cfg['val']['save-img']`` is truthy, outputs are also written via
    ``utils.save_val_images``.

    Args:
        cfg: nested configuration dict; forwarded unchanged to the
            ``utils`` helpers.
    '''
    # Function-call print with a single argument behaves the same on
    # Python 2 and Python 3, unlike the Python-2-only print statement.
    print(json.dumps(cfg, sort_keys=True, indent=4))

    use_cuda = cfg['use-cuda']
    _, _, _, val_dl = utils.get_data_loaders(cfg)

    model = utils.get_model(cfg)
    if use_cuda:
        model = model.cuda()
    model = utils.init_weights(model, cfg)

    # Load pretrained weights. The optimizer returned by load_ckpt is
    # discarded; one is built only because load_ckpt takes it as an argument.
    optim = utils.get_optimizers(model, cfg)
    model, _, _ = utils.load_ckpt(model, optim, cfg)
    # Switch to eval mode after loading, since load_ckpt returns the model
    # that is used from here on.
    model.eval()
    loss_fn = utils.get_losses(cfg)

    # Validation must be reproducible, so always use a fixed seed rather
    # than the one stored in the checkpoint metadata.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Run the validation loop without building autograd graphs.
    losses_list = []
    with torch.no_grad():
        for idx, data in enumerate(val_dl):
            # Change to required device
            for key, value in data.items():
                data[key] = Variable(value)
                if use_cuda:
                    data[key] = data[key].cuda()

            data = utils.repeat_data(data, cfg)

            # Get all outputs
            outputs = model(data)
            loss_val = loss_fn(outputs, data, cfg, val=True)
            losses_list.append(float(loss_val))
            print('Step: {}, val_loss: {}'.format(idx,
                                                  loss_val.data.cpu().numpy()))

            if cfg['val']['save-img']:
                print(outputs['out'].shape)
                utils.save_val_images(data, outputs, cfg, idx)

    # np.percentile raises on an empty sequence, so guard the summary.
    if not losses_list:
        print('No validation batches were evaluated; skipping summary.')
        return

    print(
        """
        Summary:
        Mean:   {},
        Std:    {},
        25per:  {},
        50per:  {},
        75per:  {},
    """.format(
            np.mean(losses_list),
            np.std(losses_list),
            np.percentile(losses_list, 25),
            np.percentile(losses_list, 50),
            np.percentile(losses_list, 75),
        ))