def evaluate(data_iter, model, pad_id):
    model.eval()
    data_iter.init_epoch()
    size = len(data_iter.data())
    seq_loss = 0.0
    bow_loss = 0.0
    kld_z = 0.0
    kld_t = 0.0
    seq_words = 0
    bow_words = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for batch in data_iter:
            texts, lengths = batch.text
            batch_size = texts.size(0)
            # shift by one position for teacher forcing
            inputs = texts[:, :-1].clone()
            targets = texts[:, 1:].clone()
            results = model(inputs, lengths - 1, pad_id)
            batch_seq = seq_recon_loss(results.seq_outputs, targets, pad_id)
            batch_bow = bow_recon_loss(results.bow_outputs, results.bow_targets)
            batch_kld_z = total_kld(results.posterior_z)
            batch_kld_t = total_kld(results.posterior_t, results.prior_t).to(inputs.device)
            seq_loss += batch_seq.item() / size
            bow_loss += batch_bow.item() / size
            kld_z += batch_kld_z.item() / size
            kld_t += batch_kld_t.item() / size
            seq_words += torch.sum(lengths - 1).item()
            bow_words += torch.sum(results.bow_targets).item()
    seq_ppl = math.exp(seq_loss * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return seq_loss, bow_loss, kld_z, kld_t, seq_ppl, bow_ppl
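

# --- Illustrative sketch (not part of the original code) --------------------
# The helpers seq_recon_loss, bow_recon_loss, and total_kld are defined
# elsewhere in the repository; the versions below are only a minimal guess at
# what they compute, included for reference.  The `_sketch_` prefix marks them
# as hypothetical so they do not shadow the real implementations.
import torch
import torch.nn.functional as F


def _sketch_seq_recon_loss(seq_outputs, targets, pad_id):
    # Token-level cross entropy summed over the batch, ignoring padding.
    # Assumes seq_outputs are unnormalized logits of shape (batch, seq_len, vocab).
    vocab_size = seq_outputs.size(-1)
    return F.cross_entropy(seq_outputs.reshape(-1, vocab_size),
                           targets.reshape(-1),
                           ignore_index=pad_id, reduction='sum')


def _sketch_bow_recon_loss(bow_outputs, bow_targets):
    # Bag-of-words reconstruction: negative log-likelihood of the word counts,
    # assuming bow_outputs are per-document log-probabilities over the vocabulary.
    return -(bow_targets * bow_outputs).sum()


def _sketch_total_kld(posterior, prior=None):
    # KL(q || p) summed over the batch; falls back to a standard normal prior
    # when no explicit prior distribution is passed.
    if prior is None:
        prior = torch.distributions.Normal(
            torch.zeros_like(posterior.mean), torch.ones_like(posterior.stddev))
    return torch.distributions.kl_divergence(posterior, prior).sum()
# -----------------------------------------------------------------------------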


def train_alt(data_iter, model, pad_id, optimizer, epoch):
    model.train()
    data_iter.init_epoch()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    seq_loss = 0.0
    bow_loss = 0.0
    kld_z = 0.0
    kld_t = 0.0
    seq_words = 0
    bow_words = 0
    # first pass: update the bag-of-words (topic) branch only
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths - 1, pad_id)
        batch_bow = bow_recon_loss(results.bow_outputs, results.bow_targets)
        batch_kld_t = total_kld(results.posterior_t, results.prior_t).to(inputs.device)
        bow_loss += batch_bow.item() / size
        kld_t += batch_kld_t.item() / size
        bow_words += torch.sum(results.bow_targets).item()
        optimizer.zero_grad()
        kld_term = batch_kld_t
        loss = batch_bow + kld_term
        loss.backward()
        optimizer.step()
    # second pass: update the sequence branch with an (optionally annealed) KL weight
    data_iter.init_epoch()
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        results = model(inputs, lengths - 1, pad_id)
        batch_seq = seq_recon_loss(results.seq_outputs, targets, pad_id)
        batch_kld_z = total_kld(results.posterior_z)
        seq_loss += batch_seq.item() / size
        kld_z += batch_kld_z.item() / size
        seq_words += torch.sum(lengths - 1).item()
        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) + i) if args.kla else 1.
        optimizer.zero_grad()
        kld_term = batch_kld_z
        loss = batch_seq + kld_weight * kld_term
        loss.backward()
        optimizer.step()
    seq_ppl = math.exp(seq_loss * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return seq_loss, bow_loss, kld_z, kld_t, seq_ppl, bow_ppl
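

# --- Illustrative sketch (not part of the original code) --------------------
# train_alt and train call weight_schedule(step) when --kla (KL annealing) is
# on.  Its exact form is defined elsewhere; the function below is only a guess
# at a typical linear warm-up from 0 to 1.  The name and the warm-up length
# are hypothetical.
def _sketch_weight_schedule(step, warmup_steps=10000):
    # Linearly increase the KL weight from 0 to 1, then hold it at 1.
    return min(1.0, step / float(warmup_steps))
# -----------------------------------------------------------------------------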


def evaluate(data_iter, model, pad_id):
    model.eval()
    data_iter.init_epoch()
    size = len(data_iter.data())
    seq_loss = 0.0
    bow_loss = 0.0
    kld = 0.0
    mi = 0.0
    tc = 0.0
    dwkl = 0.0
    seq_words = 0
    bow_words = 0
    log_c = 0.0
    mmd_loss = 0.0
    if args.multi:
        # unwrap DataParallel so model-specific methods (e.g. iw_nll) are reachable
        model = model.module
    with torch.no_grad():  # no gradients needed during evaluation
        for batch in data_iter:
            texts, lengths = batch.text
            batch_size = texts.size(0)
            inputs = texts[:, :-1].clone()
            targets = texts[:, 1:].clone()
            (posterior, prior, z, seq_outputs, bow_outputs,
             bow_targets, log_copula, log_marginals) = model(inputs, lengths - 1, pad_id)
            batch_seq = seq_recon_loss(seq_outputs, targets, pad_id)
            batch_bow = bow_recon_loss(bow_outputs, bow_targets)
            # kld terms are averaged across the mini-batch
            batch_kld = total_kld(posterior, prior) / batch_size
            batch_mi, batch_tc, batch_dwkl = kld_decomp(posterior, prior, z)
            prior_samples = torch.randn(args.batch_size, z.size(-1)).to(z.device)
            mmd = compute_mmd(prior_samples, z)
            iw_nll = model.iw_nll(posterior, prior, inputs, targets,
                                  lengths - 1, pad_id, args.nsamples)
            seq_loss += batch_seq.item() / size
            # bow_loss += batch_bow.item() / size
            # report the importance-weighted NLL in place of the plain BoW loss
            bow_loss += iw_nll.item() * batch_size / size
            kld += batch_kld.item() * batch_size / size
            mi += batch_mi.item() * batch_size / size
            tc += batch_tc.item() * batch_size / size
            dwkl += batch_dwkl.item() * batch_size / size
            seq_words += torch.sum(lengths - 1).item()
            bow_words += torch.sum(bow_targets).item()
            log_c += torch.sum(log_copula).item() / size
            mmd_loss += mmd.item() / size
    # seq_ppl = math.exp(seq_loss * size / seq_words)
    seq_ppl = math.exp((seq_loss + kld - log_c) * size / seq_words)
    bow_ppl = math.exp(bow_loss * size / bow_words)
    return (seq_loss, bow_loss, kld, mi, tc, dwkl,
            seq_ppl, bow_ppl, log_c, mmd_loss)
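

# --- Illustrative sketch (not part of the original code) --------------------
# compute_mmd(prior_samples, z) is used above to compare posterior samples with
# samples from the prior.  The estimator below is a common RBF-kernel maximum
# mean discrepancy, included only as a reference; the real kernel and bandwidth
# choice may differ.  The `_sketch_` name is hypothetical.
def _sketch_compute_mmd(x, y):
    def rbf_kernel(a, b):
        # Pairwise squared distances with a bandwidth tied to the dimensionality.
        dist = torch.cdist(a, b) ** 2
        return torch.exp(-dist / float(a.size(1)))

    # Biased MMD^2 estimate between the two sample sets.
    return (rbf_kernel(x, x).mean() + rbf_kernel(y, y).mean()
            - 2 * rbf_kernel(x, y).mean())
# -----------------------------------------------------------------------------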


def train(data_iter, model, pad_id, optimizer, epoch):
    model.train()
    data_iter.init_epoch()
    batch_time = AverageMeter()
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    seq_loss = 0.0
    bow_loss = 0.0
    kld = 0.0
    mi = 0.0
    tc = 0.0
    dwkl = 0.0
    seq_words = 0
    bow_words = 0
    log_c = 0.0
    mmd_loss = 0.0
    end = time.time()
    # if args.multi:
    #     model = model.module
    for i, batch in enumerate(tqdm(data_iter)):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        (posterior, prior, z, seq_outputs, bow_outputs,
         bow_targets, log_copula, log_marginals) = model(inputs, lengths - 1, pad_id)
        batch_seq = seq_recon_loss(seq_outputs, targets, pad_id)
        batch_bow = bow_recon_loss(bow_outputs, bow_targets)
        # kld terms are averaged across the mini-batch
        batch_kld = total_kld(posterior, prior) / batch_size
        batch_mi, batch_tc, batch_dwkl = kld_decomp(posterior, prior, z)
        prior_samples = torch.randn(args.batch_size, z.size(-1)).to(z.device)
        mmd = compute_mmd(prior_samples, z)
        kld_weight = weight_schedule(args.epoch_size * (epoch - 1) + i) if args.kla else 1.
        # kld_weight = min(1. / 10 * epoch, 1) if args.kla else 1
        if args.copa is True:
            # copula_weight = weight_schedule(args.epoch_size * (epoch - 1) + i)
            copula_weight = min(1. / args.epochs * epoch, 1)
            if args.copat is not None:
                # cap the annealed weight at the threshold args.copat
                copula_weight = copula_weight if copula_weight < args.copat else args.copat
        elif args.recopa is True:
            # reverse annealing: the copula weight decreases over training
            copula_weight = weight_schedule(args.epoch_size * (epoch - 1) + i)
            if args.copat is not None:
                copula_weight = 1 - copula_weight if copula_weight > args.copat else args.copat
            else:
                copula_weight = 1 - copula_weight
        else:
            copula_weight = args.cw
        optimizer.zero_grad()
        if args.decomp:
            # decomposed KL: mutual information + total correlation + dimension-wise KL
            kld_term = (args.alpha * batch_mi + args.beta * batch_tc +
                        args.gamma * batch_dwkl)
        else:
            kld_term = batch_kld
        if args.copula is True:
            loss = ((batch_seq + args.c * batch_bow) / batch_size
                    + kld_weight * kld_term
                    - copula_weight * (log_copula.sum() + 1 * log_marginals.sum()) / batch_size)
        else:
            loss = (batch_seq + args.c * batch_bow) / batch_size + kld_weight * kld_term
        if args.mmd is True:
            loss += (2.5 - kld_weight) * mmd
        loss.backward()
        optimizer.step()
        seq_loss += batch_seq.item() / size
        bow_loss += batch_bow.item() / size
        kld += batch_kld.item() * batch_size / size
        mi += batch_mi.item() * batch_size / size
        tc += batch_tc.item() * batch_size / size
        dwkl += batch_dwkl.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        bow_words += torch.sum(bow_targets).item()
        log_c += torch.sum(log_copula).item() / size
        mmd_loss += mmd.item() / size
        batch_time.update(time.time() - end)
        end = time.time()  # reset the timer so batch_time measures per-batch time
    # seq_ppl = math.exp(seq_loss * size / seq_words)
    seq_ppl = math.exp((seq_loss + kld - log_c) * size / seq_words)
    # bow_ppl = math.exp(bow_loss * size / bow_words)
    bow_ppl = math.exp(bow_loss * size / seq_words)
    return (seq_loss, bow_loss, kld, mi, tc, dwkl,
            seq_ppl, bow_ppl, log_c, mmd_loss, batch_time)
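

# --- Illustrative sketch (not part of the original code) --------------------
# kld_decomp(posterior, prior, z) is assumed to return the mutual-information,
# total-correlation, and dimension-wise-KL terms of the KL decomposition used
# when --decomp is set (beta-TC-VAE style).  The version below is the simple
# minibatch estimator without any dataset-size correction, written only as a
# reference; the real estimator may differ.  Relies on the module-level
# torch / math imports of this file.
def _sketch_kld_decomp(posterior, prior, z):
    # posterior: Normal with batch shape (batch, dim); prior broadcastable to it;
    # z: (batch, dim), one sample per example.
    batch_size = z.size(0)
    # log q(z_i | x_j) for every pair (i, j): shape (batch, batch, dim)
    log_qz_pair = posterior.log_prob(z.unsqueeze(1))
    log_qz_x = posterior.log_prob(z).sum(-1)  # log q(z_i | x_i)
    log_pz = prior.log_prob(z).sum(-1)        # log p(z_i)
    # minibatch estimates of log q(z_i) and log prod_d q(z_i,d)
    log_qz = torch.logsumexp(log_qz_pair.sum(-1), dim=1) - math.log(batch_size)
    log_qz_marginals = (torch.logsumexp(log_qz_pair, dim=1) - math.log(batch_size)).sum(-1)
    mi = (log_qz_x - log_qz).mean()            # index-code mutual information
    tc = (log_qz - log_qz_marginals).mean()    # total correlation
    dwkl = (log_qz_marginals - log_pz).mean()  # dimension-wise KL
    return mi, tc, dwkl
# -----------------------------------------------------------------------------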