def forward(self, batch, num_samples=None, reduce_ll=True):
    """Compute the training loss or evaluation log-likelihoods for a batch.

    Training: amortized variational objective. The latent prior `pz` is
    encoded from the context set and the posterior `qz` from the full set.
    With multiple samples an IWAE-style bound is used; otherwise the usual
    ELBO (reconstruction - KL).

    Evaluation: predictive log-likelihoods, split into context and target
    parts along the point dimension.

    Args:
        batch: AttrDict with tensors `x`, `y` (full set) and `xc`, `yc`
            (context set); assumed shape (B, N, dim) — points on dim -2.
        num_samples: number of latent samples (None = single sample).
        reduce_ll: if True, average multi-sample log-likelihoods via
            logmeanexp before splitting into context/target parts.

    Returns:
        AttrDict with `loss` (training) or `ctx_ll`/`tar_ll` (eval).
    """
    outs = AttrDict()
    if self.training:
        pz = self.lenc(batch.xc, batch.yc)
        qz = self.lenc(batch.x, batch.y)
        z = qz.rsample() if num_samples is None else \
                qz.rsample([num_samples])
        py = self.predict(batch.xc, batch.yc, batch.x,
                z=z, num_samples=num_samples)

        # BUGFIX: `num_samples > 1` raised TypeError when num_samples is
        # None (a case explicitly supported by the rsample() call above).
        if num_samples is not None and num_samples > 1:
            # K * B * N
            recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
            # K * B
            log_qz = qz.log_prob(z).sum(-1)
            log_pz = pz.log_prob(z).sum(-1)
            # K * B — IWAE importance weights
            log_w = recon.sum(-1) + log_pz - log_qz
            outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
        else:
            # Single-sample ELBO: reconstruction minus per-point KL.
            outs.recon = py.log_prob(batch.y).sum(-1).mean()
            outs.kld = kl_divergence(qz, pz).sum(-1).mean()
            outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]
    else:
        py = self.predict(batch.xc, batch.yc, batch.x,
                num_samples=num_samples)
        if num_samples is None:
            ll = py.log_prob(batch.y).sum(-1)
        else:
            # Replicate targets across the K sample dimension.
            y = torch.stack([batch.y] * num_samples)
            if reduce_ll:
                ll = logmeanexp(py.log_prob(y).sum(-1))
            else:
                ll = py.log_prob(y).sum(-1)
        num_ctx = batch.xc.shape[-2]
        # Points are ordered context-first, so a split at num_ctx
        # separates context from target log-likelihoods.
        if reduce_ll:
            outs.ctx_ll = ll[..., :num_ctx].mean()
            outs.tar_ll = ll[..., num_ctx:].mean()
        else:
            outs.ctx_ll = ll[..., :num_ctx]
            outs.tar_ll = ll[..., num_ctx:]
    return outs
def compute_ll(py, y, reduce_ll=True):
    """Per-point log-likelihood of targets `y` under predictive dist `py`.

    Args:
        py: a torch distribution; `log_prob(y)` is summed over the last
            (feature) dimension.
        y: target tensor, broadcastable against `py`'s batch shape.
        reduce_ll: if True and the result carries a leading sample
            dimension (3-D), collapse it with logmeanexp.

    Returns:
        Tensor of log-likelihoods, (B, N) or (K, B, N) when not reduced.
    """
    # BUGFIX: `reduce_ll` was read from an undefined/enclosing scope,
    # raising NameError for 3-D inputs; it is now an explicit parameter.
    ll = py.log_prob(y).sum(-1)
    if ll.dim() == 3 and reduce_ll:
        ll = logmeanexp(ll)
    return ll
def ensemble(args, model):
    """Evaluate an ensemble of 5 independently trained runs on CelebA.

    Loads `ckpt.tar` from run1..run5, averages their predictive
    likelihoods via logmeanexp per batch, and logs the running averages.

    Args:
        args: namespace with `model`, `t_noise`, `eval_num_samples`.
        model: a template module; each member is a deepcopy loaded with
            its own run checkpoint.
    """
    num_runs = 5
    models = []
    for i in range(num_runs):
        model_ = deepcopy(model)
        ckpt = torch.load(
                osp.join(results_path, 'celeba', args.model,
                    f'run{i+1}', 'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    path = osp.join(evalsets_path, 'celeba')
    # idiom: exist_ok replaces the isdir guard (race-free).
    os.makedirs(path, exist_ok=True)
    filename = 'no_noise.tar' if args.t_noise is None else \
            f'{args.t_noise}.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)
    eval_batches = torch.load(osp.join(path, filename))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()
            ctx_ll = []
            tar_ll = []
            # BUGFIX: loop variable renamed from `model` to `model_` so
            # it no longer shadows the `model` parameter.
            for model_ in models:
                outs = model_(batch, num_samples=args.eval_num_samples,
                        reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # 2-D member outputs get a new ensemble axis; 3-D ones
            # already carry a sample axis and are concatenated along it.
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    # Re-seed nondeterministically after the fixed-seed evaluation;
    # manual_seed documents an int seed, so truncate the float time.
    torch.manual_seed(int(time.time()))
    torch.cuda.manual_seed(int(time.time()))

    filename = 'ensemble'  # was a pointless f-string with no fields
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'celeba',
        args.model, filename), mode='w')
    logger.info(ravg.info())
def ensemble(args, model):
    """Ensemble evaluation of 5 trained runs on the Lotka-Volterra task.

    Each member is a deep copy of `model` restored from its run's
    `ckpt.tar`. Per batch, member predictive log-likelihoods are combined
    with logmeanexp and accumulated into a RunningAverage, which is then
    written to a log file.
    """
    ensemble_members = []
    for i in range(5):
        member = deepcopy(model)
        state = torch.load(
                osp.join(results_path, 'lotka_volterra', args.model,
                    f'run{i+1}', 'ckpt.tar'))
        member.load_state_dict(state['model'])
        member.cuda()
        member.eval()
        ensemble_members.append(member)

    # Fixed seed so every ensemble member sees identical evaluation data.
    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    if args.hare_lynx:
        eval_data = load_hare_lynx(1000, 16)
    else:
        eval_data = torch.load(
                osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_data):
            batch = standardize(batch)
            for key in batch:
                batch[key] = batch[key].cuda()

            member_outs = [
                m(batch, num_samples=args.eval_num_samples, reduce_ll=False)
                for m in ensemble_members
            ]
            ctx_ll = [o.ctx_ll for o in member_outs]
            tar_ll = [o.tar_ll for o in member_outs]

            # stack 2-D member outputs along a fresh ensemble axis;
            # cat 3-D ones along their existing sample axis.
            combine = torch.stack if ctx_ll[0].dim() == 2 else torch.cat
            ravg.update('ctx_ll', logmeanexp(combine(ctx_ll)).mean())
            ravg.update('tar_ll', logmeanexp(combine(tar_ll)).mean())

    # Restore nondeterministic seeding after evaluation.
    torch.manual_seed(time.time())
    torch.cuda.manual_seed(time.time())

    log_name = 'ensemble'
    if args.hare_lynx:
        log_name += '_hare_lynx'
    log_name += '.log'
    logger = get_logger(osp.join(results_path, 'lotka_volterra',
        args.model, log_name), mode='w')
    logger.info(ravg.info())