Ejemplo n.º 1
0
def train(generator, optimizer, scheduler, eng, params, pca=None):
    """Run the generator training loop for ``params.numIter`` iterations.

    Each iteration refreshes the schedule-dependent fields on ``params``
    (``iter``, ``batch_size``, ``sigma``, ``binary_amp``), samples a latent
    batch, evaluates the generated images with ``compute_effs_and_gradients``
    and applies one optimizer step on ``global_loss_function``. A checkpoint
    is written every 5000 iterations and once more just before returning.

    Args:
        generator: network being trained; called as ``generator(z, params)``.
        optimizer: optimizer that updates the generator parameters.
        scheduler: learning-rate scheduler; stepped once per trained
            iteration, after the optimizer.
        eng: solver handle forwarded to ``compute_effs_and_gradients`` and
            ``evaluate_training_generator`` (the original comment calls it an
            EM solver).
        params: hyper-parameter object; read and mutated in place.
        pca: unused by this function; kept for call compatibility.

    Returns:
        None. Results are delivered through checkpoints, plots and the
        mutated ``params``.
    """
    generator.train()

    # Start with empty histories, or resume them (and the iteration offset)
    # from the checkpoint previously loaded onto ``params``.
    if params.restore_from is None:
        effs_mean_history = []
        binarization_history = []
        diversity_history = []
        iter0 = 0
    else:
        effs_mean_history = params.checkpoint['effs_mean_history']
        binarization_history = params.checkpoint['binarization_history']
        diversity_history = params.checkpoint['diversity_history']
        iter0 = params.checkpoint['iter']

    with tqdm(total=params.numIter) as t:
        it = 0
        while True:
            it += 1
            params.iter = it + iter0

            # Fraction of the configured iteration budget reached so far.
            normIter = params.iter / params.numIter

            # Batch size follows a power-law ramp from start to end value.
            params.batch_size = int(
                params.batch_size_start +
                (params.batch_size_end - params.batch_size_start) *
                (1 - (1 - normIter)**params.batch_size_power))

            # Sigma is interpolated linearly between its start and end value.
            params.sigma = params.sigma_start + (params.sigma_end -
                                                 params.sigma_start) * normIter

            # Binarization amplitude in the tanh function: grows by one every
            # 100 iterations and is pinned to 10 from iteration 1000 onwards.
            if params.iter < 1000:
                params.binary_amp = int(params.iter / 100) + 1
            else:
                params.binary_amp = 10

            # Save the model. This happens *before* the update of iteration
            # ``it``, so the stored 'iter' is the number of completed
            # iterations (it + iter0 - 1), and the optimizer/scheduler states
            # match that count.
            if it % 5000 == 0 or it > params.numIter:
                model_dir = os.path.join(params.output_dir, 'model',
                                         'iter{}'.format(it + iter0))
                os.makedirs(model_dir, exist_ok=True)
                utils.save_checkpoint(
                    {
                        'iter': it + iter0 - 1,
                        'gen_state_dict': generator.state_dict(),
                        'optim_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'effs_mean_history': effs_mean_history,
                        'binarization_history': binarization_history,
                        'diversity_history': diversity_history
                    },
                    checkpoint=model_dir)

            # Terminate once the configured number of iterations has run.
            if it > params.numIter:
                return

            # Sample latent vectors and generate a batch of images.
            z = sample_z(params.batch_size, params)
            gen_imgs = generator(z, params)

            # Efficiencies and gradients come from the external solver.
            effs, gradients = compute_effs_and_gradients(gen_imgs, eng, params)

            # Clear gradients left over from the previous iteration.
            optimizer.zero_grad()

            # The binarization penalty switches from its start to its end
            # value at ``params.binary_step_iter``.
            if params.iter < params.binary_step_iter:
                binary_penalty = params.binary_penalty_start
            else:
                binary_penalty = params.binary_penalty_end
            g_loss = global_loss_function(gen_imgs, effs, gradients,
                                          params.sigma, binary_penalty)

            # Train the generator.
            g_loss.backward()
            optimizer.step()

            # Learning-rate decay. PyTorch >= 1.1 expects the scheduler to be
            # stepped after the optimizer; stepping it first skips the first
            # value of the learning-rate schedule.
            scheduler.step()

            # Periodic evaluation in eval mode.
            if it % params.plot_iter == 0:
                generator.eval()

                # Visualize generated images at various conditions.
                visualize_generated_images(generator, params)

                # Evaluate the performance of the current generator.
                effs_mean, binarization, diversity = evaluate_training_generator(
                    generator, eng, params)

                effs_mean_history.append(effs_mean)
                binarization_history.append(binarization)
                diversity_history.append(diversity)

                utils.plot_loss_history((effs_mean_history, diversity_history,
                                         binarization_history), params)
                generator.train()

            t.update()
Ejemplo n.º 2
0
    # Build both networks from the shared hyper-parameters and move them to
    # the GPU when ``params.cuda`` is set.
    generator = Generator(params)
    discriminator = Discriminator(params)
    if params.cuda:
        generator.cuda()
        discriminator.cuda()

    # Define the optimizers: one Adam instance per network, each with its own
    # learning rate and beta coefficients taken from ``params``.
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=params.lr_gen,
                                   betas=(params.beta1_gen, params.beta2_gen))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=params.lr_dis,
                                   betas=(params.beta1_dis, params.beta2_dis))

    # train the model and save
    # NOTE(review): ``dataloader`` is defined outside this excerpt.
    logging.info('Start training')
    loss_history = train((generator, discriminator),
                         (optimizer_G, optimizer_D), dataloader, params)

    # plot loss history and save
    # NOTE(review): ``output_dir`` is defined outside this excerpt.
    utils.plot_loss_history(loss_history, output_dir)

    # Generate images and save
    # Evaluation grid: wavelengths 500..1300 in steps of 50 and
    # angles 35..85 in steps of 5 (both ranges inclusive).
    wavelengths = [w for w in range(500, 1301, 50)]
    angles = [a for a in range(35, 86, 5)]

    logging.info(
        'Start generating devices for wavelength range {} to {} and angle range from {} to {} \n'
        .format(min(wavelengths), max(wavelengths), min(angles), max(angles)))
    evaluate(generator, wavelengths, angles, num_imgs=500, params=params)
Ejemplo n.º 3
0
def train(models, optimizers, schedulers, eng, params):
    """Train a conditional generator with gradients from an external solver.

    Every iteration draws a batch of latent vectors (one wavelength column,
    one angle column, then noise), generates images, sends them to
    ``eng.GradientFromSolver_1D_parallel`` and uses the returned
    efficiencies/gradients to build the loss for one optimizer step. On the
    first iteration and every ``params.save_iter`` iterations the generator
    is evaluated, the histories are plotted and a checkpoint is written; a
    final checkpoint plus ``scatter.mat`` are written on the last iteration.

    Args:
        models: the generator network (a single module despite the name).
        optimizers: the optimizer of the generator.
        schedulers: the learning-rate scheduler of that optimizer.
        eng: solver engine exposing ``GradientFromSolver_1D_parallel``.
        params: hyper-parameter object; read and mutated in place
            (``iter``, ``binary_amp``, ``solver_batch_size``).

    Returns:
        None.

    NOTE(review): ``Tensor`` and ``randconst`` are not defined in this
    function; presumably they are module-level globals - verify.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    generator = models
    optimizer_G = optimizers
    scheduler_G = schedulers

    generator.train()
    # Loop-invariant: place the generator on the device once instead of on
    # every iteration.
    generator.to(device)

    pca = PCA_model("PCA.mat")

    make_figure_dir(params.output_dir)

    # Start with empty histories, or resume them from the loaded checkpoint.
    if params.restore_from is None:
        Eff_mean_history = []
        Binarization_history = []
        pattern_variance = []
        iter0 = 0
        imgs_2 = []
        Effs_2 = []
    else:
        Eff_mean_history = params.checkpoint['Eff_mean_history']
        iter0 = params.checkpoint['iter']
        Binarization_history = params.checkpoint['Binarization_history']
        pattern_variance = params.checkpoint['pattern_variance']
        imgs_2 = params.checkpoint['imgs_2']
        Effs_2 = params.checkpoint['Effs_2']

    if params.tensorboard:
        loss_logger = logger.set_logger(params.output_dir)

    def checkpoint_state(iteration):
        """Collect everything needed to resume training at ``iteration``."""
        return {
            'iter': iteration,
            'gen_state_dict': generator.state_dict(),
            'optim_G_state_dict': optimizer_G.state_dict(),
            'scheduler_G_state_dict': scheduler_G.state_dict(),
            'Eff_mean_history': Eff_mean_history,
            'Binarization_history': Binarization_history,
            'pattern_variance': pattern_variance,
            'Effs_2': Effs_2,
            'imgs_2': imgs_2
        }

    with tqdm(total=params.numIter, leave=False, ncols=70) as t:

        for i in range(params.numIter):
            it = i + 1
            normIter = it / params.numIter
            params.iter = it + iter0

            # Binarization amplitude in the tanh function: grows by one every
            # 100 iterations. From iteration 1000 on it is pinned to 10 (the
            # value the ramp reaches), so a run resumed past that point does
            # not depend on a stale or missing ``params.binary_amp``.
            if params.iter < 1000:
                params.binary_amp = int(params.iter / 100) + 1
            else:
                params.binary_amp = 10

            # Solver batch size follows a power-law ramp from start to end.
            params.solver_batch_size = int(
                params.solver_batch_size_start +
                (params.solver_batch_size_end -
                 params.solver_batch_size_start) *
                (1 - (1 - normIter)**params.solver_batch_size_power))

            # Noise part of the latent vector: constant, uniform in
            # [-amplitude, amplitude), or normal scaled by the amplitude.
            if params.noise_constant == 1:
                noise = (torch.ones(params.solver_batch_size,
                                    params.noise_dims).type(Tensor) *
                         randconst) * params.noise_amplitude
            elif params.noise_distribution == 'uniform':
                noise = ((torch.rand(params.solver_batch_size,
                                     params.noise_dims).type(Tensor) * 2. -
                          1.) * params.noise_amplitude)
            else:
                noise = (torch.randn(params.solver_batch_size,
                                     params.noise_dims).type(Tensor)
                         ) * params.noise_amplitude

            # Conditions are randomized per sample: wavelength uniform in
            # [600, 1200) and angle uniform in [40, 80). (A disabled variant
            # in the original shared one random pair across the whole batch.)
            lamda = torch.rand(params.solver_batch_size,
                               1).type(Tensor) * 600 + 600
            theta = torch.rand(params.solver_batch_size,
                               1).type(Tensor) * 40 + 40

            # Latent layout: [wavelength, angle, noise...].
            z = torch.cat((lamda, theta, noise), 1)
            z = z.to(device)
            gen_imgs = generator(z, params.binary_amp)

            # Convert the batch to MATLAB doubles for the solver call.
            img = torch.squeeze(gen_imgs[:, 0, :]).data.cpu().numpy()
            img = matlab.double(img.tolist())

            wavelength = matlab.double(lamda.cpu().numpy().tolist())
            desired_angle = matlab.double(theta.cpu().numpy().tolist())

            # Column 0 of the solver result is the efficiency; the remaining
            # columns are the gradients.
            Grads_and_Effs = eng.GradientFromSolver_1D_parallel(
                img, wavelength, desired_angle)
            Grads_and_Effs = Tensor(Grads_and_Effs)
            grads = Grads_and_Effs[:, 1:]
            Efficiency_real = Grads_and_Effs[:, 0]

            Eff_max = torch.max(Efficiency_real.view(-1))
            Eff_reshape = Efficiency_real.view(-1, 1).unsqueeze(2)

            # Weight each sample's gradient by exp((eff - eff_max) / sigma)
            # / sigma; subtracting the batch maximum keeps the exponent <= 0.
            Gradients = Tensor(grads).unsqueeze(1) * gen_imgs * (
                1. / params.sigma * torch.exp(
                    (Eff_reshape - Eff_max) / params.sigma))

            # Train generator
            optimizer_G.zero_grad()

            if params.iter < params.binary_step_iter:
                binary_penalty = params.binary_penalty_start
            else:
                binary_penalty = params.binary_penalty_end

            g_loss_solver = -torch.mean(torch.mean(Gradients, dim=0).view(-1))
            if params.binary == 1:
                # Extra term |x| * (2 - |x|), scaled by the current penalty.
                abs_imgs = torch.abs(gen_imgs.view(-1))
                g_loss_solver = g_loss_solver - torch.mean(
                    abs_imgs * (2.0 - abs_imgs)) * binary_penalty

            g_loss_solver.backward()
            optimizer_G.step()

            # PyTorch >= 1.1 expects the scheduler to be stepped after the
            # optimizer; stepping it first skips the first learning-rate
            # value of the schedule.
            scheduler_G.step()

            if params.tensorboard:
                loss_logger.scalar_summary(
                    'loss',
                    g_loss_solver.cpu().detach().numpy(), it)

            if it == 1 or it % params.save_iter == 0:

                # Evaluation / visualization in eval mode.
                generator.eval()
                outputs_imgs = sample_images(generator, 100, params)

                Binarization = torch.mean(torch.abs(outputs_imgs.view(-1)))
                Binarization_history.append(Binarization)

                diversity = torch.mean(torch.std(outputs_imgs, dim=0))
                pattern_variance.append(diversity.data)

                numImgs = 1 if params.noise_constant == 1 else 100

                img_2_tmp, Eff_2_tmp = PCA_analysis(generator, pca, eng,
                                                    params, numImgs)
                imgs_2.append(img_2_tmp)
                Effs_2.append(Eff_2_tmp)

                Eff_mean_history.append(np.mean(Eff_2_tmp))
                utils.plot_loss_history(
                    ([], [], Eff_mean_history, pattern_variance,
                     Binarization_history), params.output_dir)

                generator.train()

                # Periodic checkpoint in its own per-iteration directory.
                model_dir = os.path.join(params.output_dir, 'model',
                                         'iter{}'.format(it + iter0))
                os.makedirs(model_dir, exist_ok=True)
                utils.save_checkpoint(checkpoint_state(it + iter0),
                                      checkpoint=model_dir)

            if it == params.numIter:
                # Final checkpoint directly under <output_dir>/model, plus the
                # accumulated evaluation samples as a .mat file.
                model_dir = os.path.join(params.output_dir, 'model')
                utils.save_checkpoint(checkpoint_state(it + iter0),
                                      checkpoint=model_dir)

                io.savemat(os.path.join(params.output_dir, 'scatter.mat'),
                           mdict={
                               'imgs_2': np.asarray(imgs_2),
                               'Effs_2': np.asarray(Effs_2)
                           })
                # No early return: this is the last pass of the loop, and
                # falling through lets the progress bar record it.

            t.update()
Ejemplo n.º 4
0
def train_model(model,
                metrics,
                train_batch_generator,
                val_batch_generator,
                opt,
                lr_scheduler=None,
                ckpt_name=None,
                n_epochs=30,
                plot_path=None):
    """
    Train a model and plot loss and metrics against the epoch number.

    Whenever the mean validation metric of an epoch exceeds the best value
    seen so far (starting from 0.0) and ``ckpt_name`` is given, the whole
    model object is written to ``ckpt_name`` with ``torch.save``.

    :params:
        model   - model to be trained; moved to CUDA when available and
                  invoked through its ``forward()`` method

        metrics - callable evaluating the target metric; it receives the
                  predictions and the masks as numpy arrays

        train_batch_generator - iterable of training batches; each batch is
                  indexed with the keys 'image', 'mask' and 'weights'

        val_batch_generator   - iterable of validation batches with the
                  same keys

        opt          - optimizer from torch.optim
        lr_scheduler - optional scheduler; stepped once per epoch with the
                       mean validation loss as argument
        ckpt_name    - full path to the checkpoint file, or None to disable
                       checkpointing
        n_epochs     - number of epochs, default=30
        plot_path    - directory in which the per-epoch plots are saved,
                       default=None (plots are only shown)

    :returns:
        (model, opt, loss_history_val, metrics_history_val)

    NOTE(review): ``loss`` is not defined in this function; presumably it is
    a module-level loss function taking (predictions, masks, weights).
    """
    # Function-scope import: only needed to report skipped batches.
    import warnings

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    loss_history_train, metrics_history_train = [], []
    loss_history_val, metrics_history_val = [], []

    top_val_metrics_value = 0.0

    total_time = time.time()

    for epoch in range(n_epochs):

        train_loss, train_metrics = [], []  # per-batch values of this epoch
        val_loss, val_metrics = [], []  # per-batch values of this epoch

        start_time = time.time()

        # Training phase

        model.train(True)
        for batch in tqdm(train_batch_generator, desc='Training'):

            batch_img = batch['image'].to(device)
            batch_msk = batch['mask'].to(device)
            weights = batch['weights'].to(device)

            opt.zero_grad()
            batch_preds = model.forward(batch_img)
            loss_train = loss(batch_preds, batch_msk, weights)
            with autograd.detect_anomaly():
                try:
                    loss_train.backward()
                except RuntimeError as err:
                    # Best effort: drop this batch, but say so instead of
                    # hiding the failure completely.
                    warnings.warn(
                        f"Skipping training batch, backward() failed: {err}")
                    continue

            opt.step()

            train_loss.append(loss_train.detach().cpu().numpy())
            train_metrics.append(
                metrics(batch_preds.detach().cpu().numpy(),
                        batch_msk.detach().cpu().numpy()))

            torch.cuda.empty_cache()

        # Evaluation phase: no parameter update happens here, so autograd is
        # disabled to avoid building (and holding) a graph per batch.

        model.train(False)
        with torch.no_grad():
            for batch in tqdm(val_batch_generator, desc='Validation'):

                batch_img = batch['image'].to(device)
                batch_msk = batch['mask'].to(device)
                weights = batch['weights'].to(device)

                batch_preds = model.forward(batch_img)
                loss_val = loss(batch_preds, batch_msk, weights)

                val_loss.append(loss_val.cpu().numpy())
                val_metrics.append(
                    metrics(batch_preds.cpu().numpy(),
                            batch_msk.cpu().numpy()))

                torch.cuda.empty_cache()

        # The per-batch lists are reset every epoch, so their mean is the
        # epoch mean.
        train_loss_value = np.mean(train_loss)
        train_metrics_value = np.mean(train_metrics)
        loss_history_train.append(train_loss_value)
        metrics_history_train.append(train_metrics_value)

        val_loss_value = np.mean(val_loss)
        val_metrics_value = np.mean(val_metrics)
        loss_history_val.append(val_loss_value)
        metrics_history_val.append(val_metrics_value)

        if lr_scheduler is not None:
            lr_scheduler.step(val_loss_value)

        # Keep the model with the best validation metric seen so far.
        if val_metrics_value > top_val_metrics_value and ckpt_name is not None:
            top_val_metrics_value = val_metrics_value
            with open(ckpt_name, 'wb') as ckpt_file:
                torch.save(model, ckpt_file)

        clear_output(True)

        # Nothing is plotted after the very first epoch (epoch index 0), so
        # the figure is only created from the second epoch on and is closed
        # afterwards so figures do not pile up over the run.
        if epoch:
            fig, axarr = plt.subplots(1, 2, figsize=(16, 8))
            metrics_log = {
                'train': metrics_history_train,
                'val': metrics_history_val
            }
            plot_loss_history(metrics_log, axarr[1], 'Metrics')
            loss_log = {'train': loss_history_train, 'val': loss_history_val}
            plot_loss_history(loss_log, axarr[0], 'Loss')
            plt.legend()
            if plot_path:
                plt.savefig(os.path.join(plot_path, f'epoch{epoch}.jpg'))
            plt.show()
            plt.close(fig)

        # display the results for the current epoch:
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, n_epochs,
                                                   time.time() - start_time))
        print("  Training metrics: \t{:.6f}".format(train_metrics_value))
        print("  Validation metrics: \t\t\t{:.6f} ".format(val_metrics_value))

    print(f"Trainig took {time.time() - total_time}s in total")

    return model, opt, loss_history_val, metrics_history_val
Ejemplo n.º 5
0
    def train(self, plot_period: int=5):
        """Train discriminator and generator adversarially for ``self.epochs`` epochs.

        Per batch iteration the discriminator is updated on a real and a fake
        batch (both fed with the current gaussian noise rate), then the
        generator is updated by scoring the fake batch against the "real"
        targets. After each epoch the noise rate is advanced, the models are
        handed to ``save_models``, the mean losses are recorded and printed,
        and generated images are shown unless ``plot_period`` is None. The
        loss history is plotted once training has finished.

        Args:
            plot_period: period forwarded to ``show_generated``; pass None to
                skip showing generated images.
        """
        # Loss functions and optimizers.
        bce_disc = nn.BCELoss()
        bce_gen = nn.BCELoss()

        optimizer_disc = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.disc_lr, betas=(0.5, 0.999))
        optimizer_gen = torch.optim.Adam(
            self.generator.parameters(), lr=self.gen_lr, betas=(0.5, 0.999))

        # Register this run with the benchmark logger.
        self.benchmark_logger.create_entry(
            self.benchmark_id, optimizer_disc, bce_disc, self.epochs,
            self.batch_size, self.disc_lr, self.gen_lr, self.disc_lr_decay,
            self.gen_lr_decay, self.lr_decay_period, self.gaussian_noise_range)

        # The first epoch runs with the lower end of the noise range.
        noise_rate = self.gaussian_noise_range[0]

        # A period of 1 means the discriminator is stepped on every iteration.
        disc_update_period = 1

        # Manual learning-rate decay is switched off; flip to True to enable.
        apply_lr_decay = False

        # Per-epoch mean losses, kept for the final plot.
        history_disc, history_disc_real, history_disc_fake = [], [], []
        history_gen = []

        for epoch in range(self.epochs):
            # Per-iteration losses of the current epoch.
            batch_disc, batch_disc_real, batch_disc_fake = [], [], []
            batch_gen = []

            for iteration in tqdm(range(self.iterations), ncols=120, desc="batch-iterations"):
                images_real, targets_real, images_fake, targets_fake = self._create_batch(iteration)

                # ---- discriminator update ----
                self.discriminator.zero_grad()

                # Real images.
                disc = self.discriminator.train()
                scores_real = disc(images_real, gaussian_noise_rate=noise_rate)
                loss_real = bce_disc(scores_real, targets_real)
                loss_real.backward()

                # Fake images; the graph is retained because the same fake
                # batch is backpropagated again in the generator update.
                disc = self.discriminator.train()
                scores_fake = disc(images_fake, gaussian_noise_rate=noise_rate)
                loss_fake = bce_disc(scores_fake, targets_fake)
                loss_fake.backward(retain_graph=True)

                if iteration % disc_update_period == 0:
                    optimizer_disc.step()

                # Record the discriminator losses.
                batch_disc.append(loss_real.item() + loss_fake.item())
                batch_disc_real.append(loss_real.item())
                batch_disc_fake.append(loss_fake.item())

                # ---- generator update ----
                self.generator.zero_grad()

                # Score the fake images against the "real" targets; no noise
                # rate is passed to the discriminator here.
                disc = self.discriminator.train()
                scores_fake = disc(images_fake)
                loss_gen = bce_gen(scores_fake, targets_real)

                loss_gen.backward()
                optimizer_gen.step()

                batch_gen.append(loss_gen.item())

            # Linear gaussian-noise schedule for the discriminator inputs: the
            # value picked after epoch ``epoch`` is used by the next epoch.
            noise_schedule = np.linspace(
                self.gaussian_noise_range[0], self.gaussian_noise_range[1], self.epochs)
            noise_rate = noise_schedule[epoch]

            # Hand the models to the saver (it receives epoch and period).
            save_models(self.generator, self.discriminator, save_to=(self.models_path), current_epoch=epoch, period=5)

            # Average losses of this epoch, rounded to 4 decimals.
            mean_disc = round(np.mean(batch_disc), 4)
            mean_disc_real = round(np.mean(batch_disc_real), 4)
            mean_disc_fake = round(np.mean(batch_disc_fake), 4)
            mean_gen = round(np.mean(batch_gen), 4)

            # Current learning rates, rounded to 7 decimals for display.
            shown_disc_lr = round(optimizer_disc.param_groups[0]["lr"], 7)
            shown_gen_lr = round(optimizer_gen.param_groups[0]["lr"], 7)

            # Optional learning-rate decay (disabled, see ``apply_lr_decay``).
            if apply_lr_decay:
                optimizer_disc.param_groups[0]["lr"] = lr_decay(
                    lr=optimizer_disc.param_groups[0]["lr"], epoch=epoch,
                    decay_rate=self.disc_lr_decay, period=self.lr_decay_period)
                optimizer_gen.param_groups[0]["lr"] = lr_decay(
                    lr=optimizer_gen.param_groups[0]["lr"], epoch=epoch,
                    decay_rate=self.gen_lr_decay, period=self.lr_decay_period)

            # Keep the epoch means for plotting.
            history_disc.append(mean_disc)
            history_disc_real.append(mean_disc_real)
            history_disc_fake.append(mean_disc_fake)
            history_gen.append(mean_gen)

            # Print the training progress.
            print_progress(epoch, self.epochs, mean_disc, mean_disc_real, mean_disc_fake, mean_gen, shown_disc_lr, shown_gen_lr)

            # Show generated images.
            if plot_period is not None:
                image_file = self.generated_images_path + "/" + str(epoch + 1) + ".png"
                show_generated(self.generator, view_seconds=1, current_epoch=epoch, period=plot_period, save_to=image_file)

        # Plot the loss history of the whole run.
        plot_file = self.plots_path + "/" + self.benchmark_id + ".png"
        plot_loss_history(history_disc, history_disc_real, history_disc_fake, history_gen, save_to=plot_file)