Example #1
class FlowExperiment(BaseExperiment):

    log_base = os.path.join(get_survae_path(), 'experiments/manifold/log')
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'check_every', 'eval_every',
        'device', 'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):

        # Edit args
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.check_every is None or args.check_every == 0:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", data_id, model_id,
                                    optim_id, f"seed{args.seed}",
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, data_id, model_id, optim_id,
                                    f"seed{args.seed}", args.name)

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every,
                                             save_samples=args.save_samples)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)

        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # automatic mixed precision
        # Bigger changes are needed to make this work with DataParallel
        # (each module's forward would need an @autocast() decoration).
        torch_version = str(torch.__version__).split('+')[0].split('.')[:2]
        pytorch_170 = tuple(int(v) for v in torch_version) >= (1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            # only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    def log_fn(self, epoch, train_dict, eval_dict):
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)

        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        resume_path = os.path.join(self.log_base, self.data_id, self.model_id,
                                   self.optim_id, self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):

            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]

            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None

            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def beta_annealing(self, epoch, min_beta=0.0, max_beta=1.0):
        """
        Beta annealing term to control the weight of the NDP portion of the flow.
        Often it can help to let the bijective portion of the flow begin to map to
        a the first base distribution before expecting the NDP portion of the flow find a 
        low dimensional representation.
        Only applicable when using 2 (or more) base distributions
        """
        if self.args.annealing_schedule > 0 and self.args.compression == "vae" and len(
                self.args.base_distributions) > 1:
            return max(
                min([(epoch * 1.0) / max([self.args.annealing_schedule, 1.0]),
                     max_beta]), min_beta)
        else:
            return 1.0

    def run(self):
        if self.args.resume:
            self.resume()

        super(FlowExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        else:
            return self._train(epoch)

    def _train_amp(self, epoch):
        """
        Same training procedure, but uses half precision to speed up training on GPUs

        NOTE: Not currently implemented, this only runs on the latest pytorch versions
              so I'm leaving this out for now.
        """
        self.model.train()
        beta = self.beta_annealing(epoch)
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:

            # Cast operations to mixed precision
            with torch.cuda.amp.autocast():
                loss = -self.model.log_prob(x.to(
                    self.args.device), beta=beta).sum() / (math.log(2) *
                                                           x.shape.numel())
                #loss = elbo_bpd(self.model, x.to(self.args.device), beta=beta)

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Unscale gradients and call (or skip) optimizer.step()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_iter:
                self.scheduler_iter.step()

            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.
                  format(epoch + 1, self.args.epochs, loss_count,
                         len(self.train_loader.dataset),
                         loss_sum / loss_count),
                  end='\r')

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'bpd': loss_sum / loss_count}

    def _train(self, epoch):
        self.model.train()
        beta = self.beta_annealing(epoch)
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:
            self.optimizer.zero_grad()

            loss = -self.model.log_prob(x.to(self.args.device),
                                        beta=beta).sum() / (math.log(2) *
                                                            x.shape.numel())
            #loss = elbo_bpd(self.model, x.to(self.args.device), beta=beta)
            loss.backward()

            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()

            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.
                  format(epoch + 1, self.args.epochs, loss_count,
                         len(self.train_loader.dataset),
                         loss_sum / loss_count),
                  end='\r')

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'bpd': loss_sum / loss_count}

    def eval_fn(self, epoch):
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                loss = elbo_bpd(self.model, x.to(self.args.device))
                loss_sum += loss.detach().cpu().item() * len(x)
                loss_count += len(x)
                print(
                    'Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                    .format(epoch + 1, self.args.epochs, loss_count,
                            len(self.eval_loader.dataset),
                            loss_sum / loss_count),
                    end='\r')
            print('')
        return {'bpd': loss_sum / loss_count}

    def sample_fn(self):
        self.model.eval()

        path_check = '{}/check/checkpoint.pt'.format(self.log_path)
        checkpoint = torch.load(path_check)

        path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
            self.log_path, checkpoint['current_epoch'], self.args.seed)
        if not os.path.exists(os.path.dirname(path_samples)):
            os.mkdir(os.path.dirname(path_samples))

        # save model samples
        samples = self.model.sample(
            self.args.samples).cpu().float() / (2**self.args.num_bits - 1)
        vutils.save_image(samples, path_samples, nrow=self.args.nrow)

        # save real samples too
        path_true_samples = '{}/samples/true_ep{}_s{}.png'.format(
            self.log_path, checkpoint['current_epoch'], self.args.seed)
        imgs = next(iter(self.eval_loader))[:self.args.samples].cpu().float()
        if imgs.max().item() > 2:
            imgs /= (2**self.args.num_bits - 1)
        vutils.save_image(imgs, path_true_samples, nrow=self.args.nrow)

    def stop_early(self, loss_dict, epoch):
        if self.args.early_stop == 0 or epoch < self.args.annealing_schedule:
            return False, True

        # else check if we've passed the early stopping threshold
        current_loss = loss_dict['bpd']
        model_improved = current_loss < self.best_loss
        if model_improved:
            early_stop_flag = False
            self.best_loss = current_loss
            self.best_loss_epoch = epoch
        else:
            # model didn't improve, do we consider it converged yet?
            early_stop_count = (epoch - self.best_loss_epoch)
            early_stop_flag = early_stop_count >= self.args.early_stop

        if early_stop_flag:
            print(
                f'Stopping training early: no improvement after {self.args.early_stop} epochs (last improvement at epoch {self.best_loss_epoch})'
            )

        return early_stop_flag, model_improved
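
The beta_annealing method above linearly ramps the NDP weight from min_beta to max_beta over args.annealing_schedule epochs. The standalone sketch below reproduces that ramp outside the class, with an illustrative schedule of 10 epochs (not a value taken from the repository):

def beta_schedule(epoch, annealing_schedule=10, min_beta=0.0, max_beta=1.0):
    # Linear ramp: 0 at epoch 0, reaching max_beta after `annealing_schedule`
    # epochs, then clamped. Mirrors the expression in FlowExperiment.beta_annealing.
    return max(min(epoch / max(annealing_schedule, 1.0), max_beta), min_beta)

print([round(beta_schedule(e), 1) for e in range(12)])
# -> [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.0]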
Example #2
class GaussianProcessExperiment(BaseExperiment):

    log_base = os.path.join(get_survae_path(), 'experiments/student_teacher/log')
    no_log_keys = ['project', 'name',
                   'log_tb', 'log_wandb',
                   'device',
                   'parallel',
                   'eval_every',
                   'num_flows', 'actnorm', 'scale_fn', 'hidden_units', 'range_flow', 'base_dist', 'affine', 'augment_size',
                   'pin_memory', 'num_workers']

    def __init__(self, args, data_id, model_id, model, teacher):

        # Edit args
        args.epoch = 1
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Move models
        self.teacher = teacher.to(args.device)
        self.teacher.eval()
        cond_id = args.cond_trans.lower()
        self.cond_size = 1 if cond_id.startswith('split') or cond_id.startswith('multiply') else 2

        seed_id = f"seed{args.seed}"
        arch_id = f"Gaussian_Process_{args.kernel}_kernel_a{int(args.gp_alpha)}_s{int(100*args.gp_length_scale)}"
        if args.name == "debug":
            log_path = os.path.join(
                self.log_base, "debug", model_id, data_id, cond_id, arch_id, seed_id, time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(
                self.log_base, model_id, data_id, cond_id, arch_id, seed_id, args.name)
        
        # Init parent
        super(GaussianProcessExperiment, self).__init__(model=model,
                                                        optimizer=None,
                                                        scheduler_iter=None,
                                                        scheduler_epoch=None,
                                                        log_path=log_path,
                                                        eval_every=args.eval_every)
        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args", get_args_table(
                args_dict).get_html_string(), global_step=0)

    def run(self):
        # Train
        train_dict = self.train_fn()
        self.log_train_metrics(train_dict)

        # Eval
        eval_dict = self.eval_fn()
        self.log_eval_metrics(eval_dict)

        # Log
        self.save_metrics()

        # Plotting
        self.plot_fn()

    def train_fn(self):
        self.teacher.eval()
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.train_samples)
            x = self.cond_fn(y)
            x = x.cpu().numpy()
            y = y.cpu().numpy()

        self.model.fit(x, y)
        nll = self.model.log_marginal_likelihood()
        r2 = self.model.score(x, y)
        print(f"Baseline: Log-marginal Likelihood={nll}, R-squared={r2}")
        return {'nll': nll, 'rsquared': r2}

    def eval_fn(self):
        K_test = 3 # number of MC samples
        
        self.teacher.eval()
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.test_samples)
            x = self.cond_fn(y)
            x = x.cpu().numpy()
            y = y.cpu().numpy()
            MC_samples = [self.model.predict(x, return_std=True) for _ in range(K_test)]

        means = np.stack([tup[0] for tup in MC_samples])
        logvar = np.stack([np.tile(tup[1], (2,1)).T for tup in MC_samples])

        test_ll = -0.5 * np.exp(-logvar) * (means - y.squeeze())**2.0 - 0.5 * logvar - 0.5 * np.log(2 * np.pi)
        test_ll = np.sum(np.sum(test_ll, -1), -1)
        test_ll = logsumexp(test_ll) - np.log(K_test)
        #pppp = test_ll / self.args.test_samples  # per point predictive probability
        rmse = np.mean((np.mean(means, 0) - y.squeeze())**2.0)**0.5

        r2 = self.model.score(x, y)
        print(f"Baseline: R-squared={r2}, rmse={rmse}, likelihood={test_ll}")
        return {'rsquared': r2, 'rmse': rmse, 'lhood': test_ll}

    def plot_fn(self):
        plot_path = os.path.join(self.log_path, "samples/")
        if not os.path.exists(plot_path):
            os.mkdir(plot_path)

        if self.args.dataset == 'face_einstein':
            bounds = [[0, 1], [0, 1]]
        else:
            bounds = [[-4, 4], [-4, 4]]

        # plot true data
        test_data = self.teacher.sample(num_samples=self.args.test_samples).data.numpy()
        plt.figure(figsize=(self.args.pixels/self.args.dpi, self.args.pixels/self.args.dpi), dpi=self.args.dpi)
        plt.hist2d(test_data[...,0], test_data[...,1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, f'teacher.png'), bbox_inches='tight', pad_inches=0)

        # Plot samples while varying the context
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.num_samples)
            x = self.cond_fn(y).cpu().numpy()
        y_mean, y_std = self.model.predict(x, return_std=True)
        temperature = 0.4
        samples = [np.random.normal(y_mean[:, i], y_std * temperature, self.args.num_samples).T[:, np.newaxis] \
                   for i in range(y_mean.shape[1])]
        samples = np.hstack(samples)
        plt.figure(figsize=(self.args.pixels/self.args.dpi, self.args.pixels/self.args.dpi), dpi=self.args.dpi)
        plt.hist2d(samples[...,0], samples[...,1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, f'varying_context_flow_samples.png'), bbox_inches='tight', pad_inches=0)
        
        # Plot density
        xv, yv = torch.meshgrid([torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size), torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)])
        y = torch.cat([xv.reshape(-1,1), yv.reshape(-1,1)], dim=-1)
        x = self.cond_fn(y).numpy()
        means, logvar = self.model.predict(x, return_std=True)
        logvar = np.tile(logvar, (2,1)).T
        logprobs = -0.5 * np.exp(-logvar) * (means - y.numpy())**2.0 - 0.5 * logvar - 0.5 * np.log(2 * np.pi)
        logprobs = np.sum(logprobs, -1)
        logprobs = logprobs - logprobs.max()
        probs = np.exp(logprobs)
        plt.figure(figsize=(self.args.pixels/self.args.dpi, self.args.pixels/self.args.dpi), dpi=self.args.dpi)
        plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.clim(0,self.args.clim)
        plt.savefig(os.path.join(plot_path, f'varying_context_flow_density.png'), bbox_inches='tight', pad_inches=0)

        # plot samples with fixed context
        for i, x in enumerate(self.context_permutations()):
            x = x.view((1, self.cond_size)).cpu().numpy()
            y_mean, y_std = self.model.predict(x, return_std=True)
            temperature = 1.5
            samples = np.hstack([np.random.normal(y_mean[:, i], y_std * temperature, self.args.num_samples).T[:, np.newaxis] \
                                 for i in range(y_mean.shape[1])])
            #samples = self.model.sample_y(x, n_samples=self.args.num_samples).T[..., 0]
            plt.figure(figsize=(self.args.pixels/self.args.dpi, self.args.pixels/self.args.dpi), dpi=self.args.dpi)
            plt.hist2d(samples[...,0], samples[...,1], bins=256, range=bounds)
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.savefig(os.path.join(plot_path, f'fixed_context_flow_samples_{i}.png'), bbox_inches='tight', pad_inches=0)

            # Plot density
            xv, yv = torch.meshgrid([torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size), torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)])
            y = torch.cat([xv.reshape(-1,1), yv.reshape(-1,1)], dim=-1).cpu().numpy()
            means, logvar = self.model.predict(x, return_std=True)
            logvar = np.tile(logvar, (2,1)).T
            logprobs = -0.5 * np.exp(-logvar) * (means - y)**2.0 - 0.5 * logvar - 0.5 * np.log(2 * np.pi)
            logprobs = np.sum(logprobs, -1)
            logprobs = logprobs - logprobs.max()
            probs = np.exp(logprobs)
            plt.figure(figsize=(self.args.pixels/self.args.dpi, self.args.pixels/self.args.dpi), dpi=self.args.dpi)
            plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.clim(0,self.args.clim)
            plt.savefig(os.path.join(plot_path, f'fixed_context_flow_density_{i}.png'), bbox_inches='tight', pad_inches=0)
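
GaussianProcessExperiment.eval_fn above scores the model with a Monte Carlo estimate of the predictive log-likelihood: per-point Gaussian log-densities are summed over output dimensions and test points for each of the K_test draws, then combined with a logsumexp. A condensed sketch of that estimator, assuming means and logvar arrays of shape (K, N, D), targets y of shape (N, D), and that logsumexp comes from scipy.special (as it appears to in the experiment code):

import numpy as np
from scipy.special import logsumexp

def mc_predictive_ll(means, logvar, y):
    # Gaussian log-density of y under each MC draw, summed over dims and points.
    ll = (-0.5 * np.exp(-logvar) * (means - y) ** 2
          - 0.5 * logvar - 0.5 * np.log(2 * np.pi))
    ll = ll.sum(axis=-1).sum(axis=-1)   # shape: (K,)
    # Log of the likelihood averaged over the K draws.
    return logsumexp(ll) - np.log(len(ll))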
Example #3
class FlowExperiment(BaseExperiment):

    log_base = os.path.join(get_survae_path(), 'experiments/image/log')
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'check_every', 'eval_every',
        'device', 'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):

        # Edit args
        if args.eval_every is None:
            args.eval_every = args.epochs
        if args.check_every is None:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        log_path = os.path.join(self.log_base, data_id, model_id, optim_id,
                                args.name)
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

    def log_fn(self, epoch, train_dict, eval_dict):

        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)

        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        resume_path = os.path.join(self.log_base, self.data_id, self.model_id,
                                   self.optim_id, self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]
            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None
            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        if self.args.resume: self.resume()
        super(FlowExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        for x in self.train_loader:
            self.optimizer.zero_grad()
            loss = elbo_bpd(self.model, x.to(self.args.device))
            loss.backward()
            self.optimizer.step()
            if self.scheduler_iter: self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * len(x)
            loss_count += len(x)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.
                  format(epoch + 1, self.args.epochs, loss_count,
                         len(self.train_loader.dataset),
                         loss_sum / loss_count),
                  end='\r')
        print('')
        if self.scheduler_epoch: self.scheduler_epoch.step()
        return {'bpd': loss_sum / loss_count}

    def eval_fn(self, epoch):
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                loss = elbo_bpd(self.model, x.to(self.args.device))
                loss_sum += loss.detach().cpu().item() * len(x)
                loss_count += len(x)
                print(
                    'Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'
                    .format(epoch + 1, self.args.epochs, loss_count,
                            len(self.eval_loader.dataset),
                            loss_sum / loss_count),
                    end='\r')
            print('')
        return {'bpd': loss_sum / loss_count}
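
The train and eval loops above delegate the objective to elbo_bpd, whose definition is not shown in this section. Judging from the explicit bits-per-dim computation in Example #1, it presumably converts the summed negative log-likelihood (or negative ELBO) from nats to bits per dimension; a minimal sketch under that assumption only:

import math

def elbo_bpd_sketch(model, x):
    # Assumed behaviour: sum the model's log-likelihood over the batch, negate,
    # and divide by log(2) times the total number of elements to get bits/dim.
    return -model.log_prob(x).sum() / (math.log(2) * x.numel())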
Example #4
class DropoutExperiment(BaseExperiment):

    log_base = os.path.join(get_survae_path(),
                            'experiments/student_teacher/log')
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'eval_every', 'device',
        'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, model, teacher,
                 optimizer, scheduler_iter, scheduler_epoch):

        # Edit args
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        cond_id = args.cond_trans.lower()
        arch_id = f"Dropout_h{args.hidden_units}"
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", model_id, data_id,
                                    cond_id, arch_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, model_id, data_id, cond_id,
                                    arch_id, seed_id, args.name)

        # Move models
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(DropoutExperiment,
              self).__init__(model=model,
                             optimizer=optimizer,
                             scheduler_iter=scheduler_iter,
                             scheduler_epoch=scheduler_epoch,
                             log_path=log_path,
                             eval_every=args.eval_every)
        # student teacher args
        teacher = teacher.to(args.device)
        self.teacher = teacher
        self.teacher.eval()
        self.cond_size = 1 if cond_id.startswith(
            'split') or cond_id.startswith('multiply') else 2

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)

        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # automatic mixed precision
        # Bigger changes are needed to make this work with DataParallel (each
        # module's forward would need an @autocast() decoration; see the sketch
        # after this class).
        torch_version = str(torch.__version__).split('+')[0].split('.')[:2]
        pytorch_170 = tuple(int(v) for v in torch_version) >= (1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            # only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    def log_fn(self, epoch, train_dict, eval_dict):
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)

        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        resume_path = os.path.join(self.log_base, self.model_id, self.data_id,
                                   self.cond_id, self.arch_id, self.seed_id,
                                   self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):

            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]

            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None

            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        if self.args.resume:
            self.resume()

        super(DropoutExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        else:
            return self._train(epoch)

    def _train_amp(self, epoch):
        """
        Same training procedure, but uses half precision to speed up training on GPUs.
        Only works on SOME GPUs and the latest version of Pytorch.
        """
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = self.args.train_samples // self.args.batch_size
        for i in range(iters_per_epoch):

            batch_size = self.args.batch_size
            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)

            with torch.cuda.amp.autocast():
                loss = self.model.loss(x, y)

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Unscale gradients and call (or skip) optimizer.step()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_iter:
                self.scheduler_iter.step()

            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'nll': loss_sum / loss_count}

    def _train(self, epoch):
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = max(1,
                              self.args.train_samples // self.args.batch_size)
        for i in range(iters_per_epoch):

            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)

            self.optimizer.zero_grad()
            loss = self.model.loss(x, y)
            loss.backward()

            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()

            loss_sum += loss.detach().cpu().item() * self.args.batch_size
            loss_count += self.args.batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'nll': loss_sum / loss_count}

    def eval_fn(self, epoch):
        K_test = 2  # Number of MC samples

        self.model.eval()
        self.teacher.eval()
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.test_samples)
            x = self.cond_fn(y)
            MC_samples = [self.model(x) for _ in range(K_test)]
            x = x.cpu().numpy()
            y = y.cpu().numpy()

        means = torch.stack([tup[0] for tup in MC_samples
                             ]).view(K_test, x.shape[0], 2).cpu().data.numpy()
        logvar = torch.stack([tup[1] for tup in MC_samples
                              ]).view(K_test, x.shape[0],
                                      2).cpu().data.numpy()

        test_ll = -0.5 * np.exp(-logvar) * (
            means - y.squeeze())**2.0 - 0.5 * logvar - 0.5 * np.log(2 * np.pi)
        test_ll = np.sum(np.sum(test_ll, -1), -1)
        test_ll = logsumexp(test_ll) - np.log(K_test)
        #pppp = test_ll / self.args.test_samples  # per point predictive probability
        rmse = np.mean((np.mean(means, 0) - y)**2.0)**0.5

        print("Eval:", rmse, "test LL:", test_ll)
        return {'rmse': rmse}

    def plot_fn(self):
        plot_path = os.path.join(self.log_path, "samples/")
        if not os.path.exists(plot_path):
            os.mkdir(plot_path)

        if self.args.dataset == 'face_einstein':
            bounds = [[0, 1], [0, 1]]
        else:
            bounds = [[-4, 4], [-4, 4]]

        # plot true data
        test_data = self.teacher.sample(
            num_samples=self.args.test_samples).data.numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(test_data[..., 0],
                   test_data[..., 1],
                   bins=256,
                   range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, f'teacher.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # Plot samples while varying the context
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.num_samples)
            x = self.cond_fn(y)
            y_mean, y_logvar, _ = self.model(x)

        y_mean = y_mean.cpu().numpy()
        y_var = np.exp(y_logvar.cpu().numpy())
        temperature = 0.4
        samples = [np.random.normal(y_mean[:, i], y_var[:, i] * temperature, self.args.num_samples).T[:, np.newaxis] \
                   for i in range(y_mean.shape[1])]
        samples = np.hstack(samples)
        #samples = y_mean.cpu().numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(samples[..., 0], samples[..., 1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_samples.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # Plot density
        xv, yv = torch.meshgrid([
            torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size),
            torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)
        ])
        y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)],
                      dim=-1).to(self.args.device)
        with torch.no_grad():
            x = self.cond_fn(y)
            means, logvar, _ = self.model(x)
        logprobs = -0.5 * torch.exp(-logvar) * (
            means - y)**2.0 - 0.5 * logvar - 0.5 * np.log(2 * np.pi)
        logprobs = torch.sum(logprobs, dim=1)
        logprobs = logprobs - logprobs.max()
        probs = torch.exp(logprobs)
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.clim(0, self.args.clim)
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_density.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # plot samples with fixed context
        for i, x in enumerate(self.context_permutations()):
            with torch.no_grad():
                y_mean, y_logvar, _ = self.model(
                    x.expand((self.args.num_samples, self.cond_size)))
            y_mean = y_mean.cpu().numpy()
            y_var = np.exp(y_logvar.cpu().numpy())
            temperature = 1.5
            samples = [np.random.normal(y_mean[:, i], y_var[:, i] * temperature, self.args.num_samples).T[:, np.newaxis] \
                       for i in range(y_mean.shape[1])]
            samples = np.hstack(samples)
            #samples = y_mean.cpu().numpy()
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.hist2d(samples[..., 0],
                       samples[..., 1],
                       bins=256,
                       range=bounds)
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.savefig(os.path.join(plot_path,
                                     f'fixed_context_flow_samples_{i}.png'),
                        bbox_inches='tight',
                        pad_inches=0)

            # Plot density
            xv, yv = torch.meshgrid([
                torch.linspace(bounds[0][0], bounds[0][1],
                               self.args.grid_size),
                torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)
            ])
            y = torch.cat(
                [xv.reshape(-1, 1), yv.reshape(-1, 1)],
                dim=-1).to(self.args.device)
            with torch.no_grad():
                means, logvar, _ = self.model(x.view(1, self.cond_size))
                means = means.expand((y.size(0), 2))
                logvar = logvar.expand((y.size(0), 2))
                logprobs = -0.5 * torch.exp(-logvar) * (
                    means - y)**2.0 - 0.5 * logvar - 0.5 * np.log(2 * np.pi)
                logprobs = torch.sum(logprobs, dim=1)
                logprobs = logprobs - logprobs.max()
                probs = torch.exp(logprobs)
            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.pcolormesh(xv, yv, probs.reshape(xv.shape), shading='auto')
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.clim(0, self.args.clim)
            plt.savefig(os.path.join(plot_path,
                                     f'fixed_context_flow_density_{i}.png'),
                        bbox_inches='tight',
                        pad_inches=0)
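
The AMP comment in DropoutExperiment.__init__ notes that mixed precision is disabled under DataParallel, because autocast state does not carry over to the side threads nn.DataParallel spawns; PyTorch's documented workaround is to apply @autocast() to each module's forward. A hedged sketch of that pattern (module and layer names are illustrative, not from the repository):

import torch
from torch.cuda.amp import autocast

class AmpFriendlyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 2)

    @autocast()
    def forward(self, x):
        # Autocast is re-entered inside forward, so it also applies on the
        # per-GPU replicas created by nn.DataParallel.
        return self.net(x)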
Example #5
class StudentExperiment(BaseExperiment):

    log_base = os.path.join(get_survae_path(),
                            'experiments/student_teacher/log')
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'eval_every', 'device',
        'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, model, teacher,
                 optimizer, scheduler_iter, scheduler_epoch):

        # Edit args
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        aug_or_abs = 'abs' if args.augment_size == 0 else f"aug{args.augment_size}"
        hidden = '_'.join([str(u) for u in args.hidden_units])
        cond_id = args.cond_trans.lower()
        arch_id = f"flow_{aug_or_abs}_k{args.num_flows}_h{hidden}_{'affine' if args.affine else 'additive'}{'_actnorm' if args.actnorm else ''}"
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", model_id, data_id,
                                    cond_id, arch_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, model_id, data_id, cond_id,
                                    arch_id, seed_id, args.name)

        # Move models
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(StudentExperiment,
              self).__init__(model=model,
                             optimizer=optimizer,
                             scheduler_iter=scheduler_iter,
                             scheduler_epoch=scheduler_epoch,
                             log_path=log_path,
                             eval_every=args.eval_every)
        # student teacher args
        teacher = teacher.to(args.device)
        self.teacher = teacher
        self.teacher.eval()
        self.cond_size = 1 if cond_id.startswith(
            'split') or cond_id.startswith('multiply') else 2

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)

        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # automatic mixed precision
        # Bigger changes are needed to make this work with DataParallel
        # (each module's forward would need an @autocast() decoration).
        torch_version = str(torch.__version__).split('+')[0].split('.')[:2]
        pytorch_170 = tuple(int(v) for v in torch_version) >= (1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            # only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    def log_fn(self, epoch, train_dict, eval_dict):
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)

        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        resume_path = os.path.join(self.log_base, self.model_id, self.data_id,
                                   self.cond_id, self.arch_id, self.seed_id,
                                   self.args.resume, 'check')
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):

            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]

            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None

            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        if self.args.resume:
            self.resume()

        super(StudentExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        else:
            return self._train(epoch)

    def _train_amp(self, epoch):
        """
        Same training procedure, but uses half precision to speed up training on GPUs.
        Only works on SOME GPUs and the latest version of Pytorch.
        """
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = self.args.train_samples // self.args.batch_size
        for i in range(iters_per_epoch):

            batch_size = self.args.batch_size
            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)

            with torch.cuda.amp.autocast():
                loss = -self.model.log_prob(y, context=x).mean()

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # Unscale gradients and call (or skip) optimizer.step()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_iter:
                self.scheduler_iter.step()

            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'nll': loss_sum / loss_count}

    def _train(self, epoch):
        self.model.train()
        self.teacher.eval()
        loss_sum = 0.0
        loss_count = 0
        iters_per_epoch = self.args.train_samples // self.args.batch_size
        for i in range(iters_per_epoch):

            with torch.no_grad():
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)

            self.optimizer.zero_grad()
            loss = -self.model.log_prob(y, context=x).mean()
            loss.backward()

            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()

            loss_sum += loss.detach().cpu().item() * self.args.batch_size
            loss_count += self.args.batch_size
            self.log_epoch("Training", loss_count, self.args.train_samples,
                           loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'nll': loss_sum / loss_count}

    def eval_fn(self):
        self.model.eval()
        self.teacher.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            iters_per_epoch = self.args.test_samples // self.args.batch_size
            for i in range(iters_per_epoch):
                y = self.teacher.sample(num_samples=self.args.batch_size)
                x = self.cond_fn(y)
                loss = -self.model.log_prob(y, context=x).mean()
                loss_sum += loss.detach().cpu().item() * self.args.batch_size
                loss_count += self.args.batch_size
                print(
                    'Evaluating. Epoch: {}/{}, Datapoint: {}/{}, NLL: {:.3f}'.
                    format(self.current_epoch + 1, self.args.epochs,
                           loss_count, self.args.test_samples,
                           loss_sum / loss_count),
                    end='\r')
            print('')
        return {'nll': loss_sum / loss_count}

    def plot_fn(self):
        plot_path = os.path.join(self.log_path, "samples/")
        if not os.path.exists(plot_path):
            os.mkdir(plot_path)

        if self.args.dataset == 'face_einstein':
            bounds = [[0, 1], [0, 1]]
        else:
            bounds = [[-4, 4], [-4, 4]]

        # plot true data
        test_data = self.teacher.sample(
            num_samples=self.args.test_samples).data.numpy()
        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(test_data[..., 0],
                   test_data[..., 1],
                   bins=256,
                   range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path, f'teacher.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # Plot samples while varying the context
        with torch.no_grad():
            y = self.teacher.sample(num_samples=self.args.num_samples)
            x = self.cond_fn(y)
            samples = self.model.sample(context=x, temperature=0.4)
            samples = samples.cpu().numpy()

        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.hist2d(samples[..., 0], samples[..., 1], bins=256, range=bounds)
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_samples.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # Plot density
        xv, yv = torch.meshgrid([
            torch.linspace(bounds[0][0], bounds[0][1], self.args.grid_size),
            torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)
        ])
        y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)],
                      dim=-1).to(self.args.device)
        x = self.cond_fn(y)
        with torch.no_grad():
            logprobs = self.model.log_prob(y, context=x)
            logprobs = logprobs - logprobs.max().item()

        plt.figure(figsize=(self.args.pixels / self.args.dpi,
                            self.args.pixels / self.args.dpi),
                   dpi=self.args.dpi)
        plt.pcolormesh(xv, yv, logprobs.exp().reshape(xv.shape))
        plt.xlim(bounds[0])
        plt.ylim(bounds[1])
        plt.axis('off')
        print('Range:',
              logprobs.exp().min().item(),
              logprobs.exp().max().item())
        print('Limits:', 0.0, self.args.clim)
        plt.clim(0, self.args.clim)
        plt.savefig(os.path.join(plot_path,
                                 f'varying_context_flow_density.png'),
                    bbox_inches='tight',
                    pad_inches=0)

        # plot samples with fixed context
        for i, x in enumerate(self.context_permutations()):
            with torch.no_grad():
                #y = self.teacher.sample(num_samples=1).expand((self.args.num_samples, 2))
                samples = self.model.sample(context=x.expand(
                    (self.args.num_samples, self.cond_size)),
                                            temperature=1.5).cpu().numpy()

            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.hist2d(samples[..., 0],
                       samples[..., 1],
                       bins=256,
                       range=bounds)
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            plt.savefig(os.path.join(plot_path,
                                     f'fixed_context_flow_samples_{i}.png'),
                        bbox_inches='tight',
                        pad_inches=0)

            # Plot density
            xv, yv = torch.meshgrid([
                torch.linspace(bounds[0][0], bounds[0][1],
                               self.args.grid_size),
                torch.linspace(bounds[1][0], bounds[1][1], self.args.grid_size)
            ])
            y = torch.cat(
                [xv.reshape(-1, 1), yv.reshape(-1, 1)],
                dim=-1).to(self.args.device)
            with torch.no_grad():
                logprobs = self.model.log_prob(y,
                                               context=x.expand(
                                                   (y.size(0),
                                                    self.cond_size)))
                logprobs = logprobs - logprobs.max().item()

            plt.figure(figsize=(self.args.pixels / self.args.dpi,
                                self.args.pixels / self.args.dpi),
                       dpi=self.args.dpi)
            plt.pcolormesh(xv, yv, logprobs.exp().reshape(xv.shape).cpu().numpy())
            plt.xlim(bounds[0])
            plt.ylim(bounds[1])
            plt.axis('off')
            print('Range:',
                  logprobs.exp().min().item(),
                  logprobs.exp().max().item())
            print('Limits:', 0.0, self.args.clim)
            plt.clim(0, self.args.clim)
            plt.savefig(os.path.join(plot_path,
                                     f'fixed_context_flow_density_{i}.png'),
                        bbox_inches='tight',
                        pad_inches=0)
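
The three density plots above all follow the same recipe: build a coordinate grid with torch.meshgrid, evaluate log_prob on the flattened grid, shift by the maximum so the peak density is 1 after exponentiation, and render with pcolormesh. A self-contained sketch of that recipe, with a standard 2-D Gaussian standing in for the conditional flow's log_prob (all names here are illustrative):

import math
import torch
import matplotlib.pyplot as plt

grid_size = 200
bounds = [[-4, 4], [-4, 4]]
xv, yv = torch.meshgrid([
    torch.linspace(bounds[0][0], bounds[0][1], grid_size),
    torch.linspace(bounds[1][0], bounds[1][1], grid_size)
])  # newer PyTorch prefers an explicit indexing='ij' argument here
y = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)], dim=-1)

# Stand-in log-density: a standard 2-D Gaussian instead of model.log_prob(y, context=x).
logprobs = -0.5 * (y ** 2).sum(dim=-1) - math.log(2 * math.pi)
logprobs = logprobs - logprobs.max()  # shift so exp() peaks at 1, matching the clim trick above

plt.figure(figsize=(4, 4))
plt.pcolormesh(xv.numpy(), yv.numpy(), logprobs.exp().reshape(xv.shape).numpy())
plt.axis('off')
plt.savefig('density_sketch.png', bbox_inches='tight', pad_inches=0)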
Example #6
class FlowExperiment(BaseExperiment):

    log_base = os.path.join(get_survae_path(), 'experiments/gbnf/log')
    no_log_keys = [
        'project', 'name', 'log_tb', 'log_wandb', 'check_every', 'eval_every',
        'device', 'parallel', 'pin_memory', 'num_workers'
    ]

    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):

        # Edit args
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.check_every is None or args.check_every == 0:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        arch_id = f"{args.coupling_network}_coupling_scales{args.num_scales}_steps{args.num_steps}"
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", data_id, model_id,
                                    arch_id, optim_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, data_id, model_id, arch_id,
                                    optim_id, seed_id, args.name)

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every)
        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)

        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # automatic mixed precision (AMP); using it together with DataParallel would
        # require larger changes (an @autocast() decoration on every forward), so it
        # is disabled in that case
        pytorch_170 = tuple(
            int(v) for v in str(torch.__version__).split('+')[0].split('.')[:2]) >= (1, 7)
        self.amp = args.amp and args.parallel != 'dp' and pytorch_170
        if self.amp:
            # only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()

    def log_fn(self, epoch, train_dict, eval_dict):
        # Tensorboard
        if self.args.log_tb:
            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name),
                                       metric_value,
                                       global_step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name),
                                           metric_value,
                                           global_step=epoch + 1)

        # Weights & Biases
        if self.args.log_wandb:
            for metric_name, metric_value in train_dict.items():
                wandb.log({'base/{}'.format(metric_name): metric_value},
                          step=epoch + 1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    wandb.log({'eval/{}'.format(metric_name): metric_value},
                              step=epoch + 1)

    def resume(self):
        resume_path = os.path.join(self.log_base, self.data_id, self.model_id,
                                   self.arch_id, self.optim_id, self.seed_id,
                                   self.args.resume, 'check')
        #self.checkpoint_load(resume_path, device=self.args.device)
        self.checkpoint_load(resume_path)
        for epoch in range(self.current_epoch):

            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]

            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[
                        self.eval_epochs.index(epoch)]
            else:
                eval_dict = None

            self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self):
        if self.args.resume:
            self.resume()

        super(FlowExperiment, self).run(epochs=self.args.epochs)

    def train_fn(self, epoch):
        if self.amp:
            # use automatic mixed precision
            return self._train_amp(epoch)
        else:
            return self._train(epoch)

    def _train_amp(self, epoch):
        """
        Same training procedure, but uses half precision to speed up training on GPUs.
        Only works on SOME GPUs and the latest version of Pytorch.
        """
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        print_every = max(
            1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
        for i, x in enumerate(self.train_loader):

            # Cast operations to mixed precision
            if self.args.super_resolution or self.args.conditional:
                batch_size = len(x[0])
                with torch.cuda.amp.autocast():
                    loss = cond_elbo_bpd(self.model,
                                         x[0].to(self.args.device),
                                         context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                with torch.cuda.amp.autocast():
                    loss = elbo_bpd(self.model, x.to(self.args.device))

            # Scale loss and call backward() to create scaled gradients
            self.scaler.scale(loss).backward()

            if self.max_grad_norm > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                self.scaler.unscale_(self.optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            # step() skips the optimizer update if the gradients contain infs/NaNs;
            # update() then adjusts the scale factor for the next iteration
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_iter:
                self.scheduler_iter.step()

            self.optimizer.zero_grad(set_to_none=True)

            # accumulate loss and report
            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            if i % print_every == 0 or i == (len(self.train_loader) - 1):
                self.log_epoch("Training", loss_count,
                               len(self.train_loader.dataset), loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'bpd': loss_sum / loss_count}
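
_train_amp follows the standard torch.cuda.amp recipe: run the forward pass under autocast, scale the loss before backward, unscale before gradient clipping, and let the scaler decide whether to apply the optimizer step. A stripped-down sketch of that loop on a generic model; model, loader, and loss_fn are placeholders, and only the unconditional branch above is mirrored:

import torch

def amp_train_epoch(model, loader, optimizer, loss_fn, device, max_grad_norm=1.0):
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    for x in loader:
        with torch.cuda.amp.autocast():        # forward pass in mixed precision
            loss = loss_fn(model, x.to(device))
        scaler.scale(loss).backward()          # backward on the scaled loss
        if max_grad_norm > 0:
            scaler.unscale_(optimizer)         # clip in true (unscaled) gradient units
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)                 # skipped internally if grads have infs/NaNs
        scaler.update()
        optimizer.zero_grad(set_to_none=True)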

    def _train(self, epoch):
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        print_every = max(
            1, (len(self.train_loader.dataset) // self.args.batch_size) // 20)
        for i, x in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            if self.args.super_resolution or self.args.conditional:
                batch_size = len(x[0])
                loss = cond_elbo_bpd(self.model,
                                     x[0].to(self.args.device),
                                     context=x[1].to(self.args.device))
            else:
                batch_size = len(x)
                loss = elbo_bpd(self.model, x.to(self.args.device))

            loss.backward()

            if self.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.max_grad_norm)

            self.optimizer.step()
            if self.scheduler_iter:
                self.scheduler_iter.step()

            loss_sum += loss.detach().cpu().item() * batch_size
            loss_count += batch_size
            if i % print_every == 0 or i == (len(self.train_loader) - 1):
                self.log_epoch("Training", loss_count,
                               len(self.train_loader.dataset), loss_sum)

        print('')
        if self.scheduler_epoch:
            self.scheduler_epoch.step()

        return {'bpd': loss_sum / loss_count}

    def eval_fn(self, epoch):
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for x in self.eval_loader:
                if self.args.super_resolution or self.args.conditional:
                    batch_size = len(x[0])
                    loss = cond_elbo_bpd(self.model,
                                         x[0].to(self.args.device),
                                         context=x[1].to(self.args.device))
                else:
                    batch_size = len(x)
                    loss = elbo_bpd(self.model, x.to(self.args.device))

                loss_sum += loss.detach().cpu().item() * batch_size
                loss_count += batch_size

            self.log_epoch("Evaluating", loss_count,
                           len(self.eval_loader.dataset), loss_sum)
            print('')
        return {'bpd': loss_sum / loss_count}
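
The 'bpd' value returned by train_fn and eval_fn is bits per dimension; elbo_bpd and cond_elbo_bpd presumably return the (conditional) ELBO already expressed in that unit. As a reference for the conversion itself, a small sketch (nats_to_bpd is a hypothetical helper, not part of the library):

import math

def nats_to_bpd(nll_nats, data_shape):
    # Divide a negative log-likelihood in nats by ln(2) times the number of
    # dimensions to obtain bits per dimension.
    num_dims = 1
    for d in data_shape:            # e.g. (3, 32, 32) for CIFAR-sized images
        num_dims *= d
    return nll_nats / (math.log(2) * num_dims)

# Example: an NLL of 2400 nats on a 3x32x32 image is about 1.13 bits/dim.
print(nats_to_bpd(2400.0, (3, 32, 32)))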

    def sample_fn(self, temperature=None, sample_new_batch=False):
        if self.args.samples < 1:
            return

        self.model.eval()
        get_new_batch = self.sample_batch is None or sample_new_batch
        if get_new_batch:
            self.sample_batch = next(iter(self.eval_loader))

        if self.args.super_resolution or self.args.conditional:
            imgs = self.sample_batch[0][:self.args.samples]
            context = self.sample_batch[1][:self.args.samples]
            self._cond_sample_fn(context,
                                 temperature=temperature,
                                 save_context=get_new_batch)
        else:
            imgs = self.sample_batch[:self.args.samples]
            self._sample_fn(temperature=temperature)

        if get_new_batch:
            # save real samples
            path_true_samples = '{}/samples/true_e{}_s{}.png'.format(
                self.log_path, self.current_epoch + 1, self.args.seed)
            self.save_images(imgs, path_true_samples)

    def _sample_fn(self, temperature=None):
        path_samples = '{}/samples/sample_e{}_s{}.png'.format(
            self.log_path, self.current_epoch + 1, self.args.seed)
        samples = self.model.sample(self.args.samples, temperature=temperature)
        self.save_images(samples, path_samples)

    def _cond_sample_fn(self, context, temperature=None, save_context=True):
        if save_context:
            # save low-resolution samples
            path_context = '{}/samples/context_e{}_s{}.png'.format(
                self.log_path, self.current_epoch + 1, self.args.seed)
            self.save_images(context, path_context)

        # save samples from model conditioned on context
        path_samples = '{}/samples/sample_e{}_s{}.png'.format(
            self.log_path, self.current_epoch + 1, self.args.seed)
        samples = self.model.sample(context.to(self.args.device),
                                    temperature=temperature)
        self.save_images(samples, path_samples)

    def save_images(self, imgs, file_path):
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        out = imgs.cpu().float()
        if out.max().item() > 2:
            out /= (2**self.args.num_bits - 1)

        vutils.save_image(out, file_path, nrow=self.args.nrow)
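
save_images treats a maximum value above 2 as a sign that the samples are still quantized integers in [0, 2**num_bits - 1] rather than floats in [0, 1], and rescales before writing the grid. A standalone sketch of the same normalization (save_grid and its num_bits default are illustrative):

import torch
import torchvision.utils as vutils

def save_grid(imgs, path, num_bits=8, nrow=8):
    out = imgs.cpu().float()
    if out.max().item() > 2:             # heuristic: values look quantized, not in [0, 1]
        out = out / (2 ** num_bits - 1)  # e.g. divide by 255 for 8-bit images
    vutils.save_image(out, path, nrow=nrow)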

    def stop_early(self, loss_dict, epoch):
        if self.args.early_stop == 0:
            # early stopping disabled: never stop, and always report an improvement
            return False, True

        # else check if we've passed the early stopping threshold
        current_loss = loss_dict['bpd']
        model_improved = current_loss < self.best_loss
        if model_improved:
            early_stop_flag = False
            self.best_loss = current_loss
            self.best_loss_epoch = epoch
        else:
            # model didn't improve, do we consider it converged yet?
            early_stop_count = (epoch - self.best_loss_epoch)
            early_stop_flag = early_stop_count >= self.args.early_stop

        if early_stop_flag:
            print(
                f'Stopping training early: no improvement after {self.args.early_stop} epochs (last improvement at epoch {self.best_loss_epoch})'
            )

        return early_stop_flag, model_improved
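
stop_early implements patience-based early stopping: it tracks the best loss and the epoch it occurred in, and stops once args.early_stop epochs have passed without improvement (a value of 0 disables the check). The same rule on a plain list of losses, as a compact sketch (should_stop is a hypothetical helper):

def should_stop(losses, patience):
    # Stop once the gap between the current epoch and the best epoch reaches `patience`.
    if patience == 0:
        return False                   # early stopping disabled
    best_epoch = min(range(len(losses)), key=lambda e: losses[e])
    return (len(losses) - 1) - best_epoch >= patience

# Example: best loss at epoch 1, then 3 epochs without improvement -> stop when patience=3.
print(should_stop([2.1, 1.9, 1.95, 2.0, 2.0], patience=3))  # True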