Example #1
# These snippets rely on project-level helpers (seq_recon_loss, bow_recon_loss,
# total_kld, weight_schedule, compute_mmd, kld_decomp, AverageMeter) and on a
# global `args` namespace defined elsewhere in the repository.
import math

import torch


def evaluate(data_iter, model, pad_id):
    model.eval()
    data_iter.init_epoch()
    size = len(data_iter.data())
    seq_loss = 0.0
    bow_loss = 0.0
    kld_z = 0.0
    kld_t = 0.0
    seq_words = 0
    bow_words = 0
    for batch in data_iter:
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths-1, pad_id)
        batch_seq = seq_recon_loss(
            results.seq_outputs, targets, pad_id
        )
        batch_bow = bow_recon_loss(
            results.bow_outputs, results.bow_targets
        )
        batch_kld_z = total_kld(results.posterior_z)
        batch_kld_t = total_kld(results.posterior_t,
                                results.prior_t).to(inputs.device)
        seq_loss += batch_seq.item() / size
        bow_loss += batch_bow.item() / size        
        kld_z += batch_kld_z.item() / size
        kld_t += batch_kld_t.item() / size
        seq_words += torch.sum(lengths-1).item()
        bow_words += torch.sum(results.bow_targets).item()
    seq_ppl = math.exp(seq_loss * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld_z, kld_t,
            seq_ppl, bow_ppl)
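
The reconstruction helpers used above are defined elsewhere in the project. Below is a minimal sketch of what `seq_recon_loss` and `bow_recon_loss` could look like, assuming `seq_outputs` holds per-token logits of shape (batch, length, vocab) and `bow_outputs` holds unnormalised bag-of-words scores over the vocabulary; the exact reductions are assumptions, chosen to match the per-word normalisation used in the loop above.

import torch
import torch.nn.functional as F


def seq_recon_loss(seq_outputs, targets, pad_id):
    # Token-level cross entropy summed over the batch; padded positions are
    # ignored so they contribute neither to the loss nor to the word counts.
    vocab_size = seq_outputs.size(-1)
    return F.cross_entropy(seq_outputs.reshape(-1, vocab_size),
                           targets.reshape(-1),
                           ignore_index=pad_id,
                           reduction='sum')


def bow_recon_loss(bow_outputs, bow_targets):
    # Bag-of-words reconstruction: negative log-likelihood of the observed
    # word counts under the predicted word distribution, summed over the batch.
    return -torch.sum(bow_targets * torch.log_softmax(bow_outputs, dim=-1))
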
Example #2
def train_alt(data_iter, model, pad_id, optimizer, epoch):
    model.train()
    data_iter.init_epoch()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    seq_loss = 0.0
    bow_loss = 0.0
    kld_z = 0.0
    kld_t = 0.0    
    seq_words = 0
    bow_words = 0
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths-1, pad_id)
        batch_bow = bow_recon_loss(
            results.bow_outputs, results.bow_targets
        )
        batch_kld_t = total_kld(results.posterior_t,
                                results.prior_t).to(inputs.device)
        
        bow_loss += batch_bow.item() / size
        kld_t += batch_kld_t.item() / size
        bow_words += torch.sum(results.bow_targets).item()
        optimizer.zero_grad()
        kld_term = batch_kld_t
        loss = batch_bow + kld_term
        loss.backward()
        optimizer.step()
    data_iter.init_epoch()        
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths-1, pad_id)
        batch_seq = seq_recon_loss(
            results.seq_outputs, targets, pad_id
        )
        batch_kld_z = total_kld(results.posterior_z)
        
        seq_loss += batch_seq.item() / size
        kld_z += batch_kld_z.item() / size
        seq_words += torch.sum(lengths-1).item()
        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) + i) if args.kla else 1.
        optimizer.zero_grad()
        kld_term = batch_kld_z
        loss = batch_seq + kld_weight * kld_term
        loss.backward()
        optimizer.step()
    seq_ppl = math.exp(seq_loss * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld_z, kld_t,
            seq_ppl, bow_ppl)
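
`total_kld` and the KL-annealing `weight_schedule` used above are also project-level helpers. A plausible sketch, assuming the posteriors/priors are `torch.distributions.Normal` objects and that annealing follows the common sigmoid schedule (the constants `x0` and `k` are illustrative assumptions):

import math

import torch
from torch.distributions import Normal, kl_divergence


def total_kld(posterior, prior=None):
    # KL(q || p) summed over latent dimensions and over the batch;
    # with no explicit prior, a standard normal prior is assumed.
    if prior is None:
        prior = Normal(torch.zeros_like(posterior.loc),
                       torch.ones_like(posterior.scale))
    return kl_divergence(posterior, prior).sum()


def weight_schedule(step, x0=5000, k=0.0025):
    # Sigmoid annealing weight in (0, 1) as a function of the global step.
    return float(1.0 / (1.0 + math.exp(-k * (step - x0))))
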
Example #3
def evaluate(data_iter, model, pad_id):
    model.eval()
    data_iter.init_epoch()
    size = len(data_iter.data())
    seq_loss = 0.0
    bow_loss = 0.0
    kld = 0.0
    mi = 0.0
    tc = 0.0
    dwkl = 0.0
    seq_words = 0
    bow_words = 0
    log_c = 0
    mmd_loss = 0
    if args.multi:
        model = model.module
    for batch in data_iter:
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        (posterior, prior, z, seq_outputs, bow_outputs, bow_targets,
         log_copula, log_marginals) = model(inputs, lengths - 1, pad_id)
        batch_seq = seq_recon_loss(seq_outputs, targets, pad_id)
        batch_bow = bow_recon_loss(bow_outputs, bow_targets)
        # kld terms are averaged across the mini-batch
        batch_kld = total_kld(posterior, prior) / batch_size
        batch_mi, batch_tc, batch_dwkl = kld_decomp(posterior, prior, z)

        prior_samples = torch.randn(args.batch_size, z.size(-1)).to(z.device)
        mmd = compute_mmd(prior_samples, z)

        iw_nll = model.iw_nll(posterior, prior, inputs, targets, lengths - 1,
                              pad_id, args.nsamples)

        seq_loss += batch_seq.item() / size
        # bow_loss += batch_bow.item() / size
        # the importance-weighted NLL is accumulated in place of the BoW loss
        bow_loss += iw_nll.item() * batch_size / size
        kld += batch_kld.item() * batch_size / size
        mi += batch_mi.item() * batch_size / size
        tc += batch_tc.item() * batch_size / size
        dwkl += batch_dwkl.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        bow_words += torch.sum(bow_targets).item()
        log_c += torch.sum(log_copula).item() / size
        mmd_loss += mmd.item() / size
    # seq_ppl = math.exp(seq_loss * size / seq_words)
    # perplexity from the full objective: reconstruction + KL - copula term
    seq_ppl = math.exp((seq_loss + kld - log_c) * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld, mi, tc, dwkl, seq_ppl, bow_ppl, log_c,
            mmd_loss)
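
`compute_mmd` compares prior samples with the aggregated posterior samples z. A minimal sketch using a Gaussian kernel; the kernel choice and bandwidth are assumptions (inverse multiquadric kernels are another common choice):

import torch


def gaussian_kernel(x, y):
    # Pairwise RBF kernel between sample sets of shape (n, d) and (m, d),
    # with the bandwidth tied to the latent dimensionality.
    dim = x.size(1)
    diff = x.unsqueeze(1) - y.unsqueeze(0)          # (n, m, d)
    return torch.exp(-diff.pow(2).sum(dim=2) / dim)


def compute_mmd(x, y):
    # Squared maximum mean discrepancy between the two sample sets.
    return (gaussian_kernel(x, x).mean()
            + gaussian_kernel(y, y).mean()
            - 2 * gaussian_kernel(x, y).mean())
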
Example #4
def train(data_iter, model, pad_id, optimizer, epoch):
    model.train()
    data_iter.init_epoch()
    batch_time = AverageMeter()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    seq_loss = 0.0
    bow_loss = 0.0
    kld = 0.0
    mi = 0.0
    tc = 0.0
    dwkl = 0.0
    seq_words = 0
    bow_words = 0
    log_c = 0
    mmd_loss = 0
    end = time.time()
    # if args.multi:
    #     model = model.module
    for i, batch in enumerate(tqdm(data_iter)):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        (posterior, prior, z, seq_outputs, bow_outputs, bow_targets,
         log_copula, log_marginals) = model(inputs, lengths - 1, pad_id)
        batch_seq = seq_recon_loss(seq_outputs, targets, pad_id)
        batch_bow = bow_recon_loss(bow_outputs, bow_targets)
        # kld terms are averaged across the mini-batch
        batch_kld = total_kld(posterior, prior) / batch_size
        batch_mi, batch_tc, batch_dwkl = kld_decomp(posterior, prior, z)

        prior_samples = torch.randn(args.batch_size, z.size(-1)).to(z.device)
        mmd = compute_mmd(prior_samples, z)

        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) +
                                     i) if args.kla else 1.
        # kld_weight = min(1./10 * epoch, 1) if args.kla else 1
        if args.copa:
            # copula_weight = weight_schedule(args.epoch_size * (epoch - 1) + i)
            copula_weight = min(1. / args.epochs * epoch, 1)
            if args.copat is not None:
                copula_weight = min(copula_weight, args.copat)
        elif args.recopa:
            copula_weight = weight_schedule(args.epoch_size * (epoch - 1) + i)
            if args.copat is not None:
                copula_weight = 1 - copula_weight if copula_weight > args.copat else args.copat
            else:
                copula_weight = 1 - copula_weight
        else:
            copula_weight = args.cw
        optimizer.zero_grad()
        if args.decomp:
            kld_term = args.alpha * batch_mi + args.beta * batch_tc +\
                       args.gamma * batch_dwkl
        else:
            kld_term = batch_kld

        if args.copula:
            loss = ((batch_seq + args.c * batch_bow) / batch_size
                    + kld_weight * kld_term
                    - copula_weight * (log_copula.sum() + log_marginals.sum()) / batch_size)
        else:
            loss = ((batch_seq + args.c * batch_bow) / batch_size
                    + kld_weight * kld_term)
        if args.mmd:
            loss += (2.5 - kld_weight) * mmd

        loss.backward()
        optimizer.step()

        seq_loss += batch_seq.item() / size
        bow_loss += batch_bow.item() / size
        kld += batch_kld.item() * batch_size / size
        mi += batch_mi.item() * batch_size / size
        tc += batch_tc.item() * batch_size / size
        dwkl += batch_dwkl.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        bow_words += torch.sum(bow_targets).item()
        log_c += torch.sum(log_copula).item() / size
        mmd_loss += mmd.item() / size

        batch_time.update(time.time() - end)
        end = time.time()  # reset so the meter tracks per-batch time

    # seq_ppl = math.exp(seq_loss * size / seq_words)
    seq_ppl = math.exp((seq_loss + kld - log_c) * size / seq_words)
    # bow_ppl = math.exp(bow_loss * size / bow_words)
    # here the BoW perplexity is normalised by the number of sequence words
    bow_ppl = math.exp(bow_loss * size / seq_words)
    return (seq_loss, bow_loss, kld, mi, tc, dwkl, seq_ppl, bow_ppl, log_c,
            mmd_loss, batch_time)
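
For context, a rough sketch of the remaining pieces: an `AverageMeter` following the standard PyTorch-example pattern, and an outer loop wiring `train` and `evaluate` together. The optimizer choice and the names `train_iter`, `valid_iter`, and `args.lr` are assumptions, not taken from the snippets above.

class AverageMeter:
    # Tracks the latest value and a running average (used here for batch time).
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Illustrative outer loop.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
for epoch in range(1, args.epochs + 1):
    train_stats = train(train_iter, model, pad_id, optimizer, epoch)
    with torch.no_grad():
        val_stats = evaluate(valid_iter, model, pad_id)
    # indices 6 and 7 are seq_ppl and bow_ppl in the tuple returned above
    print('epoch {:d} | val seq ppl {:.2f} | val bow ppl {:.2f}'.format(
        epoch, val_stats[6], val_stats[7]))
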