Example #1
    def forward(self, batch, num_samples=None, reduce_ll=True):
        # Latent neural-process forward pass. `batch` carries the context
        # set (xc, yc) and the full set (x, y); AttrDict, stack and
        # logmeanexp are repo utilities, kl_divergence comes from
        # torch.distributions.
        outs = AttrDict()
        if self.training:
            # Prior conditioned on the context only; variational posterior
            # conditioned on all points.
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc,
                              batch.yc,
                              batch.x,
                              z=z,
                              num_samples=num_samples)

            # Guard against num_samples=None (the original `num_samples > 1`
            # raises a TypeError in that case).
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                # Importance-weighted bound, normalized per target point.
                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # Single-sample ELBO: reconstruction term minus the KL,
                # with the KL normalized per target point.
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc,
                              batch.yc,
                              batch.x,
                              num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y] * num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[..., :num_ctx].mean()
                outs.tar_ll = ll[..., num_ctx:].mean()
            else:
                outs.ctx_ll = ll[..., :num_ctx]
                outs.tar_ll = ll[..., num_ctx:]

        return outs
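
The snippet above leans on two small helpers, logmeanexp and stack, that are not shown in this listing. A minimal sketch of what they plausibly look like, assuming logmeanexp averages in log space over a given dimension and stack tiles a tensor once per latent sample:

import math
import torch

def logmeanexp(x, dim=0):
    # Numerically stable log of the mean of exponentials along `dim`:
    # logsumexp minus the log of the number of averaged elements.
    return x.logsumexp(dim) - math.log(x.shape[dim])

def stack(x, num_samples=None, dim=0):
    # Tile `x` along a new dimension, one copy per sample; pass it
    # through unchanged when no sampling is requested.
    return x if num_samples is None else torch.stack([x] * num_samples, dim)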
Example #2
def compute_ll(py, y, reduce_ll=True):
    # Per-point log-likelihood under the predictive distribution `py`,
    # summed over the output dimension. `reduce_ll` is taken as a
    # parameter here; in the original snippet it was a free variable.
    ll = py.log_prob(y).sum(-1)
    # A 3-D result carries a leading sample dimension (K * B * N);
    # average it out in log space when reduction is requested.
    if ll.dim() == 3 and reduce_ll:
        ll = logmeanexp(ll)
    return ll
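
For illustration, a hypothetical call with a diagonal Gaussian predictive distribution, assuming the logmeanexp helper sketched after Example #1 is in scope. The shapes (K=4 samples, B=2 tasks, N=10 points, D=3 outputs) are assumptions, not values from the source:

import torch
from torch.distributions import Normal

py = Normal(torch.zeros(4, 2, 10, 3), torch.ones(4, 2, 10, 3))
y = torch.randn(2, 10, 3)
# log_prob broadcasts y against the sample dimension; the result is
# reduced from shape (4, 2, 10) to (2, 10) by logmeanexp.
ll = compute_ll(py, y)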
Example #3
def ensemble(args, model):
    # Evaluate an ensemble of five independently trained runs on CelebA.
    # Requires os, os.path as osp, time, torch, copy.deepcopy and tqdm,
    # plus repo utilities (results_path, evalsets_path, gen_evalset,
    # RunningAverage, get_logger, logmeanexp).
    num_runs = 5
    models = []
    for i in range(num_runs):
        model_ = deepcopy(model)
        ckpt = torch.load(
            osp.join(results_path, 'celeba', args.model, f'run{i+1}',
                     'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    path = osp.join(evalsets_path, 'celeba')
    if not osp.isdir(path):
        os.makedirs(path)
    filename = 'no_noise.tar' if args.t_noise is None else \
            f'{args.t_noise}.tar'
    if not osp.isfile(osp.join(path, filename)):
        print('generating evaluation sets...')
        gen_evalset(args)

    eval_batches = torch.load(osp.join(path, filename))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_batches):
            for key, val in batch.items():
                batch[key] = val.cuda()

            ctx_ll = []
            tar_ll = []
            # Loop variable renamed to avoid shadowing the `model` argument
            # (as in Example #4).
            for model_ in models:
                outs = model_(batch,
                              num_samples=args.eval_num_samples,
                              reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # Without latent samples each model's ll is B * N: stack to
            # M * B * N. With samples it is K * B * N: concatenate to
            # (M*K) * B * N.
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            # Ensemble (mixture) log-likelihood: log-mean-exp over members.
            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    # Re-seed non-deterministically after evaluation (integer seeds).
    torch.manual_seed(int(time.time()))
    torch.cuda.manual_seed(int(time.time()))

    filename = 'ensemble'
    if args.t_noise is not None:
        filename += f'_{args.t_noise}'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'celeba', args.model, filename),
                        mode='w')
    logger.info(ravg.info())
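
Both ensemble drivers accumulate metrics through the repo's RunningAverage helper, whose implementation is not part of this listing. A minimal sketch of the interface used here (update and info), purely as an assumption about its behavior:

class RunningAverage:
    # Assumed interface: tracks the running mean of each named metric.
    def __init__(self):
        self.sums, self.counts = {}, {}

    def update(self, key, val):
        # Accept Python numbers or 0-dim torch tensors.
        val = val.item() if hasattr(val, 'item') else float(val)
        self.sums[key] = self.sums.get(key, 0.0) + val
        self.counts[key] = self.counts.get(key, 0) + 1

    def info(self):
        return ' '.join(f'{k} {self.sums[k] / self.counts[k]:.4f}'
                        for k in self.sums)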
Example #4
def ensemble(args, model):
    # Lotka-Volterra variant of the CelebA driver in Example #3; it
    # additionally standardizes each batch and can evaluate on the
    # real hare-lynx data instead of the simulated evaluation set.
    num_runs = 5
    models = []
    for i in range(num_runs):
        model_ = deepcopy(model)
        ckpt = torch.load(
            osp.join(results_path, 'lotka_volterra', args.model, f'run{i+1}',
                     'ckpt.tar'))
        model_.load_state_dict(ckpt['model'])
        model_.cuda()
        model_.eval()
        models.append(model_)

    torch.manual_seed(args.eval_seed)
    torch.cuda.manual_seed(args.eval_seed)

    if args.hare_lynx:
        eval_data = load_hare_lynx(1000, 16)
    else:
        eval_data = torch.load(
            osp.join(datasets_path, 'lotka_volterra', 'eval.tar'))

    ravg = RunningAverage()
    with torch.no_grad():
        for batch in tqdm(eval_data):
            batch = standardize(batch)
            for key, val in batch.items():
                batch[key] = val.cuda()

            ctx_ll = []
            tar_ll = []
            for model_ in models:
                outs = model_(batch,
                              num_samples=args.eval_num_samples,
                              reduce_ll=False)
                ctx_ll.append(outs.ctx_ll)
                tar_ll.append(outs.tar_ll)

            # Same stack/cat and mixture log-likelihood logic as Example #3.
            if ctx_ll[0].dim() == 2:
                ctx_ll = torch.stack(ctx_ll)
                tar_ll = torch.stack(tar_ll)
            else:
                ctx_ll = torch.cat(ctx_ll)
                tar_ll = torch.cat(tar_ll)

            ctx_ll = logmeanexp(ctx_ll).mean()
            tar_ll = logmeanexp(tar_ll).mean()

            ravg.update('ctx_ll', ctx_ll)
            ravg.update('tar_ll', tar_ll)

    # Re-seed non-deterministically after evaluation (integer seeds).
    torch.manual_seed(int(time.time()))
    torch.cuda.manual_seed(int(time.time()))

    filename = 'ensemble'
    if args.hare_lynx:
        filename += '_hare_lynx'
    filename += '.log'
    logger = get_logger(osp.join(results_path, 'lotka_volterra', args.model,
                                 filename),
                        mode='w')
    logger.info(ravg.info())
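
For reference, a hypothetical way to drive the Lotka-Volterra ensemble above. The attribute names on args are exactly those the function reads; the parser and its defaults are assumptions:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='np')  # placeholder model name
parser.add_argument('--eval_num_samples', type=int, default=50)
parser.add_argument('--eval_seed', type=int, default=42)
parser.add_argument('--hare_lynx', action='store_true')
args = parser.parse_args()

# `model` must be an instance of the architecture whose checkpoints are
# stored under results_path/lotka_volterra/<args.model>/run{1..5}.
ensemble(args, model)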