示例#1
0
    def __init__(self, args):
        """Build the frozen generator, the audio encoder, the masks and logging.

        Args:
            args: Run configuration. The attributes read here are ``device``,
                ``lr``, ``T``, ``gpu``, ``cont``, ``test``, ``model_path``,
                ``debug`` and ``save_dir``. ``save_dir`` is used as a plain
                string prefix (e.g. ``save_dir + 'models'``); no path
                separator is inserted on the caller's behalf.
        """
        super().__init__()

        self.device = args.device
        self.args = args

        self.initial_lr = self.args.lr
        self.lr = self.args.lr
        self.lr_rampdown_length = 0.4

        # Init global step
        self.global_step = 0
        self.step_start = 0

        # Init Generator: eval mode, and none of its parameters get gradients
        self.g = style_gan_2.PretrainedGenerator1024().eval().to(self.device)
        for param in self.g.parameters():
            param.requires_grad = False

        # Define audio encoder (the only module handed to the optimizer below)
        self.audio_encoder = models.AudioExpressionNet3(args.T).to(self.device).train()

        # Print # parameters
        print("# params {} (trainable {})".format(
            utils.count_params(self.audio_encoder),
            utils.count_trainable_params(self.audio_encoder)
        ))

        # Select optimizer and loss criterion
        self.optim = torch.optim.Adam(self.audio_encoder.parameters(), lr=self.lr)
        self.lpips = PerceptualLoss(model='net-lin', net='vgg', gpu_id=args.gpu)

        # Continue from / evaluate an existing checkpoint
        if self.args.cont or self.args.test:
            path = self.args.model_path
            self.load(path)
            # NOTE(review): presumably load() restores self.global_step, so
            # step_start marks where this run resumes -- confirm in load().
            self.step_start = self.global_step

        # Mouth mask for image
        # map_location restores the tensor straight onto the target device, so
        # a file saved from a device that is unavailable here (e.g. a GPU on a
        # CPU-only host) still loads.
        mouth_mask = torch.load(
            'saves/pre-trained/tagesschau_mouth_mask_5std.pt',
            map_location=self.device).to(self.device)
        # eyes_mask = torch.load('saves/pre-trained/tagesschau_eyes_mask_3std.pt').to(self.device)
        self.image_mask = mouth_mask.clamp(0., 1.)
        # self.image_mask = (mouth_mask + eyes_mask).clamp(0., 1.)

        # MSE mask: only entries 4..7 of the stored tensor are kept, with a
        # leading dimension of size 1 added in front.
        self.mse_mask = torch.load(
            'saves/pre-trained/mse_mask_var+1.pt',
            map_location=self.device)[4:8].unsqueeze(0).to(self.device)

        # Set up tensorboard (skipped for debug and test runs)
        if not self.args.debug and not self.args.test:
            tb_dir = self.args.save_dir
            # self.writer = SummaryWriter(tb_dir)
            self.train_writer = utils.HparamWriter(tb_dir + 'train/')
            self.val_writer = utils.HparamWriter(tb_dir + 'val/')
            self.train_writer.log_hyperparams(self.args)
            print(f"Logging run to {tb_dir}")

            # Create save dir
            os.makedirs(self.args.save_dir + 'models', exist_ok=True)
            os.makedirs(self.args.save_dir + 'sample', exist_ok=True)
示例#2
0
    def __init__(self,
                 g,
                 num_steps=1000,
                 initial_learning_rate=0.1,
                 initial_noise_factor=0.05,
                 verbose=True):
        """Prepare latent statistics, the optimisable inputs, Adam and the loss.

        Args:
            g: Generator. It must expose ``parameters()``, ``style``,
                ``n_latent`` and ``noises``; it is stored as ``self.g_ema``.
            num_steps: Stored as ``self.num_steps``.
            initial_learning_rate: Stored as ``self.initial_lr`` and passed to
                Adam as its learning rate.
            initial_noise_factor: Stored as ``self.initial_noise_factor``.
            verbose: Stored as ``self.verbose``.

        Side effects:
            Re-seeds the global torch RNG with ``123`` before sampling.
        """
        # Fixed settings and caller-supplied hyper-parameters
        self.num_steps = num_steps
        self.n_mean_latent = 10000
        self.initial_lr = initial_learning_rate
        self.initial_noise_factor = initial_noise_factor
        self.lr_rampdown_length = 0.25
        self.lr_rampup_length = 0.05
        self.noise_ramp_length = 0.75
        self.regularize_noise_weight = 1e5
        self.verbose = verbose

        # Placeholders; nothing is known about these yet
        self.latent_expr = self.lpips = self.target_images = None
        self.imag_gen = self.loss = self.lr = self.cur_step = None

        # Work on whichever device the generator's first parameter lives on
        self.g_ema = g
        self.device = next(g.parameters()).device

        # Latent statistics from random samples pushed through g.style
        n_samples = self.n_mean_latent
        self._info('Finding W midpoint and stddev using %d samples...' % n_samples)
        torch.manual_seed(123)
        with torch.no_grad():
            z_samples = torch.randn(n_samples, 512, device=self.device)
            w_samples = self.g_ema.style(z_samples)

        self.latent_mean = w_samples.mean(0)
        # Scalar: sqrt(sum of all squared deviations / number of samples)
        squared_dev_total = (w_samples - self.latent_mean).pow(2).sum()
        self.latent_std = (squared_dev_total / n_samples) ** 0.5
        self._info('std = {}'.format(self.latent_std))

        # Starting latent: the mean, repeated n_latent times along dim 0
        start_latent = self.latent_mean.detach().clone().unsqueeze(0)
        start_latent = start_latent.repeat(self.g_ema.n_latent, 1)
        start_latent.requires_grad = True
        self.latent_in = start_latent

        # Generator noise inputs, moved to the working device
        self.noises = [noise_map.to(self.device) for noise_map in g.noises]

        # Adam over the latent plus every noise tensor
        self.opt = torch.optim.Adam([self.latent_in, *self.noises],
                                    lr=self.initial_lr)

        # Perceptual loss (replaces the None placeholder set above)
        self.lpips = PerceptualLoss(model='net-lin', net='vgg').to(self.device)
    def __init__(self, args):
        """Create the frozen generator, the encoder, its optimizer and logging.

        Args:
            args: Run configuration. The attributes read here are ``device``,
                ``lr``, ``gpu``, ``cont``, ``test``, ``run``, ``model_path``,
                ``debug`` and ``save_dir``.
        """
        super().__init__()

        self.device = args.device
        self.args = args

        # Learning-rate settings
        base_lr = args.lr
        self.initial_lr = base_lr
        self.lr = base_lr
        self.lr_rampdown_length = 0.3
        self.lr_rampup_length = 0.1

        # Generator: eval mode, gradients switched off for every parameter
        generator = style_gan_2.PretrainedGenerator1024().eval()
        self.g = generator.to(self.device)
        for weight in self.g.parameters():
            weight.requires_grad = False

        # Generator's average latent: repeated 18 times, then a leading dim added
        avg_latent = self.g.latent_avg.repeat(18, 1)
        self.latent_avg = avg_latent.unsqueeze(0).to(self.device)

        # Step counter starts from zero
        self.global_step = 0

        # Encoder under training
        self.e = resnetEncoder().train().to(self.device)

        # Report parameter counts
        n_total = utils.count_params(self.e)
        n_trainable = utils.count_trainable_params(self.e)
        print("# params {} (trainable {})".format(n_total, n_trainable))

        # Optimizer and perceptual criterion
        self.optim = torch.optim.Adam(self.e.parameters(), lr=self.initial_lr)
        self.criterion = PerceptualLoss(model='net-lin', net='vgg',
                                        gpu_id=args.gpu)

        # Restore a checkpoint when continuing, testing or running
        needs_checkpoint = args.cont or args.test or args.run
        if needs_checkpoint:
            self.load(args.model_path)

        # Logging and save directories are only set up for real training runs
        if args.debug or args.test or args.run:
            return

        # The second-to-last '/'-separated piece of save_dir names the run
        run_name = args.save_dir.split('/')[-2]
        tb_dir = 'tensorboard_runs/encode_stylegan/' + run_name
        self.writer = SummaryWriter(tb_dir)
        print(f"Logging run to {tb_dir}")

        # Create save dir
        os.makedirs(args.save_dir + 'models', exist_ok=True)
def main(hparams):
    """Reconstruct the input images with every configured estimator.

    Images returned by ``model_input(hparams)`` are gathered into batches of
    ``hparams.batch_size``. For each full batch, measurements are generated,
    each estimator in ``hparams.model_types`` produces estimates, and the
    measurement loss, l2 loss, LPIPS score and latent estimate are stored per
    image key. Results are checkpointed periodically and once at the end.

    Args:
        hparams: Run configuration. The attributes read here are
            ``model_types``, ``noise_std``, ``batch_size``,
            ``num_measurements``, ``not_lazy``, ``n_input``, ``image_shape``,
            ``save_images``, ``checkpoint_iter``, ``print_stats`` and
            ``image_matrix``.

    Notes:
        * Image keys are used arithmetically (``key + 1``), so they have to
          be integer-like.
        * Images left in a final, incomplete batch are not processed; a
          warning is printed for them.
    """
    # set up perceptual loss
    # NOTE(review): the device is hard-coded, so use_gpu is always True here.
    device = 'cuda:0'
    percept = PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses, lpips_scores, z_hats = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}

    A = utils.get_A(hparams)
    # Noise: Student-t samples (2 degrees of freedom) scaled by noise_std.
    # NOTE(review): drawn once here and reused for every batch -- confirm
    # that sharing the same noise across batches is intended.
    noise_batch = hparams.noise_std * np.random.standard_t(
        2, size=(hparams.batch_size, hparams.num_measurements))

    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            if all(os.path.isfile(path) for path in save_paths.values()):
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input: one row of length n_input per image
        x_batch = np.concatenate(
            [img.reshape(1, hparams.n_input) for img in x_batch_dict.values()])

        # Construct measurements
        y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams)

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch, z_hat_batch, m_loss_batch = estimator(A, y_batch, hparams)

            # Distinct loop names keep the outer `key` / `x` intact for the
            # progress and checkpoint logic below.
            for i, batch_key in enumerate(x_batch_dict):
                x_true = xs_dict[batch_key]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][batch_key] = x_hat

                # Store metrics. Everything the estimator returns is aligned
                # with the batch, so it is indexed by the in-batch position
                # `i`, never by the global image key.
                measurement_losses[model_type][batch_key] = m_loss_batch[i]
                l2_losses[model_type][batch_key] = utils.get_l2_loss(x_hat, x_true)
                lpips_scores[model_type][batch_key] = utils.get_lpips_score(
                    percept, x_hat, x_true, hparams.image_shape)
                z_hats[model_type][batch_key] = z_hat_batch[i]

        print('Processed upto image {0} / {1}'.format(key+1, len(xs_dict)))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
            # Start a fresh estimate store once the current one is written out
            x_hats_dict = {model_type: {} for model_type in hparams.model_types}
            print('\nProcessed and saved first ', key+1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            measurement_loss_list = list(measurement_losses[model_type].values())
            l2_loss_list = list(l2_losses[model_type].values())
            mean_m_loss = np.mean(measurement_loss_list)
            mean_l2_loss = np.mean(l2_loss_list)
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')