Example 1
def evaluate(data_iter, model, pad_id):
    model.eval()
    data_iter.init_epoch()
    size = len(data_iter.data())
    seq_loss = 0.0
    bow_loss = 0.0
    kld_z = 0.0
    kld_t = 0.0
    seq_words = 0
    bow_words = 0
    for batch in data_iter:
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths-1, pad_id)
        batch_seq = seq_recon_loss(
            results.seq_outputs, targets, pad_id
        )
        batch_bow = bow_recon_loss(
            results.bow_outputs, results.bow_targets
        )
        batch_kld_z = total_kld(results.posterior_z)
        batch_kld_t = total_kld(results.posterior_t,
                                results.prior_t).to(inputs.device)
        seq_loss += batch_seq.item() / size
        bow_loss += batch_bow.item() / size        
        kld_z += batch_kld_z.item() / size
        kld_t += batch_kld_t.item() / size
        seq_words += torch.sum(lengths-1).item()
        bow_words += torch.sum(results.bow_targets).item()  # keep the counter a plain Python number
    seq_ppl = math.exp(seq_loss * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld_z, kld_t,
            seq_ppl, bow_ppl)
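
The helpers seq_recon_loss and bow_recon_loss are not shown in these snippets. Below is a minimal sketch of what such helpers typically look like in text-VAE code, assuming seq_outputs holds per-token logits of shape (batch, seq_len, vocab) and bow_outputs holds per-document unigram logits; it is an illustrative reconstruction, not the repository's actual implementation.

import torch.nn.functional as F


def seq_recon_loss(seq_outputs, targets, pad_id):
    # Token-level cross-entropy summed over the whole batch; padding
    # positions are excluded via ignore_index.
    vocab_size = seq_outputs.size(-1)
    return F.cross_entropy(seq_outputs.reshape(-1, vocab_size),
                           targets.reshape(-1),
                           ignore_index=pad_id,
                           reduction='sum')


def bow_recon_loss(bow_outputs, bow_targets):
    # Bag-of-words negative log-likelihood of the target word counts
    # under the predicted unigram distribution.
    log_probs = F.log_softmax(bow_outputs, dim=-1)
    return -(bow_targets * log_probs).sum()
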
Example 2
def train_alt(data_iter, model, pad_id, optimizer, epoch):
    model.train()
    data_iter.init_epoch()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    seq_loss = 0.0
    bow_loss = 0.0
    kld_z = 0.0
    kld_t = 0.0    
    seq_words = 0
    bow_words = 0
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths-1, pad_id)
        batch_bow = bow_recon_loss(
            results.bow_outputs, results.bow_targets
        )
        batch_kld_t = total_kld(results.posterior_t,
                                results.prior_t).to(inputs.device)
        
        bow_loss += batch_bow.item() / size        
        kld_t += batch_kld_t.item() / size        
        bow_words += torch.sum(results.bow_targets).item()
        optimizer.zero_grad()
        kld_term = batch_kld_t
        loss = batch_bow + kld_term
        loss.backward()
        optimizer.step()
    data_iter.init_epoch()        
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths-1, pad_id)
        batch_seq = seq_recon_loss(
            results.seq_outputs, targets, pad_id
        )
        batch_kld_z = total_kld(results.posterior_z)
        
        seq_loss += batch_seq.item() / size
        kld_z += batch_kld_z.item() / size
        seq_words += torch.sum(lengths-1).item()
        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) + i) if args.kla else 1.
        optimizer.zero_grad()
        kld_term = batch_kld_z
        loss = batch_seq + kld_weight * kld_term
        loss.backward()
        optimizer.step()
    seq_ppl = math.exp(seq_loss * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld_z, kld_t,
            seq_ppl, bow_ppl)
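
weight_schedule is only called here when args.kla (KL annealing) is enabled and is not defined in this snippet. A plausible sketch is the usual sigmoid warm-up from text-VAE training; the functional form and the defaults below are assumptions, not taken from the original repository.

import math


def weight_schedule(step, x0=5000, k=0.0025):
    # Sigmoid KL-annealing schedule: the weight rises smoothly from ~0
    # towards 1 around step x0. Purely illustrative defaults.
    return 1.0 / (1.0 + math.exp(-k * (step - x0)))
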
Example 3
def baseline(args, data_iter, model, optimizer, epoch, train=True):
    batch_size = args.batch_size
    size = len(data_iter.dataset)
    if train:
        model.train()
    else:
        model.eval()
    # data_iter.init_epoch()
    re_loss = 0
    r_re_loss = 0
    kl_divergence = 0
    r_kl_divergence = 0
    discriminator_loss = 0
    nll = 0
    for i, (data, label) in enumerate(data_iter):
        data = data.to(args.device)
        disloss = torch.zeros(1).to(args.device)

        if train:
            recon, q_z, p_z, z = model(data)
            recon = recon.view(-1, data.size(-2), data.size(-1))
            reloss = recon_loss(recon, data)  # sum over batch
            kld = total_kld(q_z, p_z)  # sum over batch
            optimizer.zero_grad()
            loss = (reloss + kld) / batch_size
            loss.backward()
            optimizer.step()
        else:
            angles = torch.randint(0, 3, (data.size(0), )).to(args.device)
            r_data = batch_rotate(data.clone(), angles)
            r_recon, r_qz, r_pz, r_z = model(r_data)
            r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
            reloss = recon_loss(r_recon, r_data)
            kld = total_kld(r_qz, r_pz)

        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        discriminator_loss += disloss.item() / size

    nll = re_loss + kl_divergence
    return nll, re_loss, kl_divergence, discriminator_loss
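
total_kld(q_z, p_z) is used throughout these examples but never defined in them; the inline comments above describe it as a KL term summed over the batch. A plausible sketch for diagonal-Gaussian posteriors and priors (an assumption about the helper, not its actual source) is:

import torch
from torch.distributions import Normal, kl_divergence


def total_kld(posterior, prior=None):
    # KL(q(z|x) || p(z)) summed over the batch and the latent dimensions.
    # With no explicit prior, fall back to a standard normal of matching shape.
    if prior is None:
        prior = Normal(torch.zeros_like(posterior.mean),
                       torch.ones_like(posterior.stddev))
    return kl_divergence(posterior, prior).sum()
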
Example 4
def evaluate(data_iter, model, pad_id):
    model.eval()
    data_iter.init_epoch()
    size = len(data_iter.data())
    seq_loss = 0.0
    bow_loss = 0.0
    kld = 0.0
    mi = 0.0
    tc = 0.0
    dwkl = 0.0
    seq_words = 0
    bow_words = 0
    log_c = 0
    mmd_loss = 0
    if args.multi:
        model = model.module
    for batch in data_iter:
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        (posterior, prior, z, seq_outputs, bow_outputs, bow_targets,
         log_copula, log_marginals) = model(inputs, lengths - 1, pad_id)
        batch_seq = seq_recon_loss(seq_outputs, targets, pad_id)
        batch_bow = bow_recon_loss(bow_outputs, bow_targets)
        # kld terms are averaged across the mini-batch
        batch_kld = total_kld(posterior, prior) / batch_size
        batch_mi, batch_tc, batch_dwkl = kld_decomp(posterior, prior, z)

        prior_samples = torch.randn(args.batch_size, z.size(-1)).to(z.device)
        mmd = compute_mmd(prior_samples, z)

        iw_nll = model.iw_nll(posterior, prior, inputs, targets, lengths - 1,
                              pad_id, args.nsamples)

        seq_loss += batch_seq.item() / size
        # bow_loss += batch_bow.item() / size
        bow_loss += iw_nll.item() * batch_size / size
        kld += batch_kld.item() * batch_size / size
        mi += batch_mi.item() * batch_size / size
        tc += batch_tc.item() * batch_size / size
        dwkl += batch_dwkl.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        bow_words += torch.sum(bow_targets).item()
        log_c += torch.sum(log_copula).item() / size
        mmd_loss += mmd.item() / size
    # seq_ppl = math.exp(seq_loss * size / seq_words)
    seq_ppl = math.exp((seq_loss + kld - log_c) * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld, mi, tc, dwkl, seq_ppl, bow_ppl, log_c,
            mmd_loss)
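
In this example compute_mmd is called on two sets of latent samples (prior_samples and z); other snippets pass distribution objects and a kernel name, so the real signature clearly varies. As a hedged stand-in, a common RBF-kernel estimator of maximum mean discrepancy between two sample sets looks like this:

import torch


def compute_mmd(x, y):
    # Biased MMD^2 estimate between sample sets x and y with an RBF kernel
    # whose bandwidth is tied to the latent dimensionality.
    def rbf(a, b):
        dim = a.size(1)
        return torch.exp(-torch.cdist(a, b).pow(2) / dim)
    return rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean()
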
Example 5
def main(args):
    print("Loading data")
    dataset = args.data.rstrip('/').split('/')[-1]
    torch.cuda.set_device(args.cuda)
    device = args.device
    if dataset == 'mnist':
        train_loader, test_loader = get_mnist(args.batch_size, 'data/mnist')
        num = 10
    elif dataset == 'fashion':
        train_loader, test_loader = get_fashion_mnist(args.batch_size,
                                                      'data/fashion')
        num = 10
    elif dataset == 'svhn':
        train_loader, test_loader, _ = get_svhn(args.batch_size, 'data/svhn')
        num = 10
    elif dataset == 'stl':
        train_loader, test_loader, _ = get_stl10(args.batch_size, 'data/stl10')
        num = 10  # STL-10 has 10 classes; without this, `num` is undefined below
    elif dataset == 'cifar':
        train_loader, test_loader = get_cifar(args.batch_size, 'data/cifar')
        num = 10
    elif dataset == 'chair':
        train_loader, test_loader = get_chair(args.batch_size,
                                              '~/data/rendered_chairs')
        num = 1393
    elif dataset == 'yale':
        train_loader, test_loader = get_yale(args.batch_size, 'data/yale')
        num = 38
    model = VAE(28 * 28, args.code_dim, args.batch_size, num,
                dataset).to(device)
    phi = nn.Sequential(
        nn.Linear(args.code_dim, args.phi_dim),
        nn.LeakyReLU(0.2, True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_phi = torch.optim.Adam(phi.parameters(), lr=args.lr)
    criterion = nn.MSELoss(reduction='sum')
    for epoch in range(args.epochs):
        re_loss = 0
        kl_div = 0
        size = len(train_loader.dataset)
        for data, target in train_loader:
            data, target = data.squeeze(1).to(device), target.to(device)
            c = F.one_hot(target.long(), num_classes=num).float()
            output, q_z, p_z, z = model(data, c)
            hsic = HSIC(phi(z), target.long(), num)
            if dataset == 'mnist' or dataset == 'fashion':
                reloss = recon_loss(output, data.view(-1, 28 * 28))
            else:
                reloss = criterion(output, data)
            kld = total_kld(q_z, p_z)
            loss = reloss + kld + args.c * hsic

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer_phi.zero_grad()
            neg = -HSIC(phi(z.detach()), target.long(), num)
            neg.backward()
            optimizer_phi.step()

            re_loss += reloss.item() / size
            kl_div += kld.item() / size
        print('-' * 50)
        print(
            " Epoch {} |re loss {:5.2f} | kl div {:5.2f} | hs {:5.2f}".format(
                epoch, re_loss, kl_div, hsic))
    for data, target in test_loader:
        data, target = data.squeeze(1).to(device), target.to(device)
        c = F.one_hot(target.long(), num_classes=num).float()
        output, _, _, z = model(data, c)
        break
    if dataset == 'mnist' or dataset == 'fashion':
        img_size = [data.size(0), 1, 28, 28]
    else:
        img_size = [data.size(0), 3, 32, 32]
    images = [data.view(img_size)[:30].cpu()]
    for i in range(10):
        c = F.one_hot(torch.ones(z.size(0)).long() * i,
                      num_classes=num).float().to(device)
        output = model.decoder(torch.cat((z, c), dim=-1))
        images.append(output.view(img_size)[:30].cpu())
    images = torch.cat(images, dim=0)
    save_image(images,
               'imgs/recon_c{}_{}.png'.format(int(args.c), dataset),
               nrow=30)
    torch.save(model.state_dict(),
               'vae_c{}_{}.pt'.format(int(args.c), dataset))
    # z = p_z.sample()
    # for i in range(10):
    #     c = F.one_hot(torch.ones(z.size(0)).long()*i, num_classes=10).float().to(device)
    #     output = model.decoder(torch.cat((z, c), dim=-1))
    #     n = min(z.size(0), 8)
    #     save_image(output.view(z.size(0), 1, 28, 28)[:n].cpu(), 'imgs/recon_{}.png'.format(i), nrow=n)
    if args.tsne:
        datas, targets = [], []
        for i, (data, target) in enumerate(test_loader):
            datas.append(data), targets.append(target)
            if i >= 5:
                break
        data, target = torch.cat(datas, dim=0), torch.cat(targets, dim=0)
        c = F.one_hot(target.long(), num_classes=num).float()
        _, _, _, z = model(data.to(args.device), c.to(args.device))
        z, target = z.detach().cpu().numpy(), target.cpu().numpy()
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        z_2d = tsne.fit_transform(z)
        plt.figure(figsize=(6, 5))
        plot_embedding(z_2d, target)
        plt.savefig('tsnes/tsne_c{}_{}.png'.format(int(args.c), dataset))
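
For the MNIST/Fashion-MNIST branch, recon_loss is applied to flattened images. Its definition is not part of this snippet; a typical Bernoulli-decoder reconstruction term, summed over the batch like the MSE criterion used for the other datasets, would look roughly like the sketch below (an assumption, not the repository's code).

import torch.nn.functional as F


def recon_loss(output, data):
    # Pixel-wise binary cross-entropy summed over the batch; assumes the
    # decoder output has already been squashed to (0, 1), e.g. by a sigmoid.
    return F.binary_cross_entropy(output, data, reduction='sum')
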
Example 6
def run(args, data_iter, model, optimizer, epoch, train=True):
    batch_size = args.batch_size
    size = len(data_iter.dataset)
    if train:
        model.train()
    else:
        model.eval()
    # data_iter.init_epoch()
    re_loss = 0
    kl_divergence = 0
    discriminator_loss = 0
    nll = 0
    for i, (data, label) in enumerate(data_iter):
        data = data.to(args.device)
        recon, q_z, p_z, z = model(data)
        recon = recon.view(-1, data.size(-2), data.size(-1))
        reloss = recon_loss(recon, data)  # sum over batch
        kld = total_kld(q_z, p_z)  # sum over batch
        disloss = torch.zeros(1).to(args.device)

        if args.ro:
            disloss, r_reloss, r_kld = [], [], []
            for d in range(1, len(rotations)):
                angles = torch.tensor([d],
                                      dtype=torch.long,
                                      device=args.device).expand(data.size(0))
                r_data = batch_rotate(data.clone(), angles)
                r_recon, r_qz, r_pz, r_z = model(r_data)
                r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
                D_z = D(r_z)
                disloss.append(disc_loss(D_z, angles))  # sum over batch
                r_reloss.append(recon_loss(r_recon, r_data))
                r_kld.append(total_kld(r_qz, r_pz))
            disloss = sum(disloss)  # / (len(rotations)-1)
            r_reloss = sum(r_reloss)  # / (len(rotations)-1)
            r_kld = sum(r_kld)  # / (len(rotations)-1)

            # angles = torch.randint(0, 3, (data.size(0), )).to(args.device)
            # r_data = batch_rotate(data.clone(), angles)
            # r_recon, r_qz, r_pz, r_z = model(r_data)
            # r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
            # D_z = D(r_z)
            # disloss = disc_loss(D_z, angles) # sum over batch
            # r_reloss = recon_loss(r_recon, r_data)
            # r_kld = total_kld(r_qz, r_pz)

        if train:
            if args.ro:
                optimizer_D.zero_grad()
                D_loss = disloss / batch_size
                D_loss.backward(retain_graph=True)
                optimizer_D.step()

                optimizer.zero_grad()
                loss = (reloss + kld + r_reloss + r_kld - disloss) / batch_size
                loss.backward()
                optimizer.step()
            else:
                optimizer.zero_grad()
                loss = (reloss + kld) / batch_size
                loss.backward()
                optimizer.step()

        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        discriminator_loss += disloss.item() / size

    nll = re_loss + kl_divergence
    return nll, re_loss, kl_divergence, discriminator_loss
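
batch_rotate and the module-level rotations list are defined elsewhere; the call sites suggest angles indexes a set of 90-degree rotations applied per sample. A torch.rot90-based sketch under that assumption:

import torch


def batch_rotate(data, angles):
    # Rotate each sample in the batch by angles[i] quarter turns over the
    # spatial dimensions; works for (B, H, W) and (B, C, H, W) inputs.
    rotated = [torch.rot90(x, int(k), dims=(-2, -1))
               for x, k in zip(data, angles)]
    return torch.stack(rotated, dim=0)
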
Example 7
def run(args, data_iter, model, pad_id, optimizer, epoch, train=True):
    if train is True:
        model.train()
    else:
        model.eval()
    data_iter.init_epoch()
    batch_time = AverageMeter()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    re_loss = 0
    kl_divergence = 0
    flow_kl_divergence = 0
    mutual_information1, mutual_information2 = 0, 0
    seq_words = 0
    mmd_loss = 0
    negative_ll = 0
    iw_negative_ll = 0
    sum_log_j = 0
    start = time.time()
    end = time.time()
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        q_z, p_z, z, outputs, sum_log_jacobian, penalty, z0 = model(
            inputs, lengths - 1, pad_id)
        if args.loss_type == 'entropy':
            reloss = recon_loss(outputs, targets, pad_id, id=args.loss_type)
        else:
            reloss = recon_loss(inputs, outputs, pad_id, id=args.loss_type)

        kld = total_kld(q_z, p_z)

        if args.flow:
            f_kld = flow_kld(q_z, p_z, z, z0, sum_log_jacobian)
        else:
            f_kld = torch.zeros(1)

        mi_z = mutual_info(q_z, p_z, z0)
        nll = compute_nll(q_z, p_z, z, z0, sum_log_jacobian, reloss)

        if args.iw:
            iw_nll = model.iw_nll(q_z, p_z, inputs, targets, lengths - 1,
                                  pad_id, args.nsamples)
        else:
            iw_nll = torch.zeros(1)

        if args.flow:
            mi_flow = mutual_info_flow(q_z, p_z, z, z0, sum_log_jacobian)
        else:
            mi_flow = torch.zeros(1).to(z.device)

        mmd = torch.zeros(1).to(z.device)
        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) +
                                     i) if args.kla else 1.
        if args.mmd:
            # prior_samples = torch.randn(z.size(0), z.size(-1)).to(z.device)
            mmd = compute_mmd(p_z, q_z, args.kernel)
        if kld_weight > args.t:
            kld_weight = args.t
        if args.nokld:
            kld_weight = 0

        if train is True:
            optimizer.zero_grad()
            if args.flow:
                # loss = reloss / batch_size + kld_weight * (kld - torch.sum(sum_log_jacobian) + torch.sum(penalty)) / batch_size + (args.mmd_w - kld_weight) * mmd
                loss = reloss / batch_size + kld_weight * (q_z.log_prob(
                    z0).sum() - p_z.log_prob(z).sum()) / batch_size - (
                        torch.sum(sum_log_jacobian) - torch.sum(penalty)
                    ) / batch_size + (args.mmd_w - kld_weight) * mmd
            else:
                loss = (reloss + kld_weight * kld) / batch_size + (
                    args.mmd_w - kld_weight) * mmd

            loss.backward()
            optimizer.step()

        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        flow_kl_divergence += f_kld.item() * batch_size / size
        mutual_information1 += mi_z.item() * batch_size / size
        mutual_information2 += mi_flow.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        mmd_loss += mmd.item() * batch_size / size
        negative_ll += nll.item() * batch_size / size
        iw_negative_ll += iw_nll.item() * batch_size / size
        sum_log_j += torch.sum(sum_log_jacobian).item() / size
        batch_time.update(time.time() - end)

    if kl_divergence > 100:
        kl_divergence = 100
        flow_kl_divergence = 100
    if args.iw:
        nll_ppl = math.exp(iw_negative_ll * size / seq_words)
    else:
        nll_ppl = math.exp(negative_ll * size / seq_words)

    return re_loss, kl_divergence, flow_kl_divergence, mutual_information1, mutual_information2, mmd_loss, nll_ppl, negative_ll, iw_negative_ll, sum_log_j, start, batch_time
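
AverageMeter is the usual running-average utility from the PyTorch example scripts; a minimal version compatible with the batch_time.update(...) calls above (shown here as a common implementation, not necessarily the one used in this repository):

class AverageMeter:
    """Tracks the latest value, sum, count and running average of a scalar."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
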
Example 8
def train(data_iter, model, pad_id, optimizer, epoch):
    model.train()
    data_iter.init_epoch()
    batch_time = AverageMeter()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    seq_loss = 0.0
    bow_loss = 0.0
    kld = 0.0
    mi = 0.0
    tc = 0.0
    dwkl = 0.0
    seq_words = 0
    bow_words = 0
    log_c = 0
    mmd_loss = 0
    end = time.time()
    # if args.multi:
    #     model = model.module
    for i, batch in enumerate(tqdm(data_iter)):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        posterior, prior, z, seq_outputs, bow_outputs, bow_targets, log_copula, log_marginals = model(
            inputs, lengths - 1, pad_id)
        # print('debug')
        # posterior, prior = model.get_dist()
        batch_seq = seq_recon_loss(seq_outputs, targets, pad_id)
        batch_bow = bow_recon_loss(bow_outputs, bow_targets)
        # kld terms are averaged across the mini-batch
        batch_kld = total_kld(posterior, prior) / batch_size
        batch_mi, batch_tc, batch_dwkl = kld_decomp(posterior, prior, z)

        prior_samples = torch.randn(args.batch_size, z.size(-1)).to(z.device)
        mmd = compute_mmd(prior_samples, z)

        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) +
                                     i) if args.kla else 1.
        # kld_weight = min(1./10 * epoch, 1) if args.kla else 1
        if args.copa is True:
            # copula_weight = weight_schedule(args.epoch_size * (epoch- 1 )+ i)
            copula_weight = min(1. / args.epochs * epoch, 1)
            if args.copat is not None:
                copula_weight = copula_weight if copula_weight < args.copat else args.copat
        elif args.recopa is True:
            copula_weight = weight_schedule(args.epoch_size * (epoch - 1) + i)
            if args.copat is not None:
                copula_weight = 1 - copula_weight if copula_weight > args.copat else args.copat
            else:
                copula_weight = 1 - copula_weight
        else:
            copula_weight = args.cw
        optimizer.zero_grad()
        if args.decomp:
            kld_term = args.alpha * batch_mi + args.beta * batch_tc +\
                       args.gamma * batch_dwkl
        else:
            kld_term = batch_kld

        if args.copula is True:
            loss = (batch_seq + args.c * batch_bow
                    ) / batch_size + kld_weight * kld_term - copula_weight * (
                        log_copula.sum() +
                        1 * log_marginals.sum()) / batch_size
        else:
            loss = (batch_seq +
                    args.c * batch_bow) / batch_size + kld_weight * kld_term
        if args.mmd is True:
            loss += (2.5 - kld_weight) * mmd

        loss.backward()
        optimizer.step()

        seq_loss += batch_seq.item() / size
        bow_loss += batch_bow.item() / size
        kld += batch_kld.item() * batch_size / size
        mi += batch_mi.item() * batch_size / size
        tc += batch_tc.item() * batch_size / size
        dwkl += batch_dwkl.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        bow_words += torch.sum(bow_targets).item()
        log_c += torch.sum(log_copula).item() / size
        mmd_loss += mmd.item() / size

        batch_time.update(time.time() - end)

    # seq_ppl = math.exp(seq_loss * size / seq_words)
    seq_ppl = math.exp((seq_loss + kld - log_c) * size / seq_words)
    # bow_ppl = math.exp(bow_loss * size / bow_words)
    bow_ppl = math.exp(bow_loss * size / seq_words)
    return (seq_loss, bow_loss, kld, mi, tc, dwkl, seq_ppl, bow_ppl, log_c,
            mmd_loss, batch_time)
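
A minimal driver loop for this train function might look like the sketch below; train_iter, model, pad_id, optimizer and args are assumed to be set up by the surrounding script.

for epoch in range(1, args.epochs + 1):
    (seq_loss, bow_loss, kld, mi, tc, dwkl, seq_ppl, bow_ppl,
     log_c, mmd_loss, batch_time) = train(train_iter, model, pad_id,
                                          optimizer, epoch)
    print('epoch {:3d} | seq ppl {:8.2f} | bow ppl {:8.2f} | kld {:6.2f}'
          .format(epoch, seq_ppl, bow_ppl, kld))
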