def calc_ppl(sents, m):
    """Compute (average NLL per sentence, per-word perplexity) for `sents`.

    NLL is estimated with `model.nll_is` using `m` importance samples.
    Relies on module-level `args`, `vocab`, `model`, and `device`.
    """
    batches, _ = get_batches(sents, vocab, args.batch_size, device)
    with torch.no_grad():
        nll_total = sum(
            model.nll_is(inputs, targets, m).sum().item()
            for inputs, targets in batches)
    n_words = sum(len(s) + 1 for s in sents)  # +1 counts the <eos> token
    return nll_total / len(sents), np.exp(nll_total / n_words)
def encode(sents):
    """Encode sentences into sampled latent vectors z ~ q(z|x).

    Returns a numpy array with one row per input sentence, restored to
    the original input order (get_batches returns a permutation `order`).

    Fix: the original referenced the undefined name `batch_size`, which
    raises NameError when called; sibling helpers in this file use
    `args.batch_size`, so this now does the same.
    NOTE(review): this definition is shadowed by a later `encode` in the
    same file — confirm which one is intended to survive.
    """
    batches, order = get_batches(sents, vocab, args.batch_size, device)
    z = []
    for inputs, _ in batches:
        mu, logvar = model.encode(inputs)
        zi = reparameterize(mu, logvar)
        z.append(zi.detach().cpu().numpy())
    z = np.concatenate(z, axis=0)
    # Scatter rows back so row i corresponds to input sentence i.
    z_ = np.zeros_like(z)
    z_[np.array(order)] = z
    return z_
def encode(sents):
    """Encode sentences into latent codes.

    Uses the posterior mean when args.enc == 'mu', or a reparameterized
    sample when args.enc == 'z'.  Rows of the returned array are restored
    to the original input order.
    """
    assert args.enc in ('mu', 'z')
    batches, order = get_batches(sents, vocab, args.batch_size, device)
    chunks = []
    for inputs, _ in batches:
        mu, logvar, _, _ = model(inputs)
        zi = mu if args.enc == 'mu' else reparameterize(mu, logvar)
        chunks.append(zi.detach().cpu().numpy())
    stacked = np.concatenate(chunks, axis=0)
    # get_batches sorts sentences; undo the permutation `order`.
    restored = np.zeros_like(stacked)
    restored[np.array(order)] = stacked
    return restored
# Fix: this line began with a duplicated, dangling tail of calc_ppl
# ("total_nll += ... return ...") — a top-level `return` is a SyntaxError
# and the statements referenced names that only exist inside calc_ppl.
# The fragment is removed; only the script entry point remains.
if __name__ == '__main__':
    args = parser.parse_args()
    vocab = Vocab(os.path.join(args.checkpoint, 'vocab.txt'))
    set_seed(args.seed)
    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    model = get_model(os.path.join(args.checkpoint, 'model.pt'))

    # Report reconstruction/ELBO-style metrics on a dataset.
    if args.evaluate:
        sents = load_sent(args.data)
        batches, _ = get_batches(sents, vocab, args.batch_size, device)
        meters = evaluate(model, batches)
        print(' '.join(['{} {:.2f},'.format(k, meter.avg)
                        for k, meter in meters.items()]))

    # Importance-sampled NLL / perplexity (args.m samples).
    if args.ppl:
        sents = load_sent(args.data)
        nll, ppl = calc_ppl(sents, args.m)
        print('NLL {:.2f}, PPL {:.2f}'.format(nll, ppl))

    # Decode sentences from latents drawn from the standard normal prior.
    if args.sample:
        z = np.random.normal(size=(args.n, model.args.dim_z)).astype('f')
        sents = decode(z)
        write_sent(sents, os.path.join(args.checkpoint, args.output))
def main(args):
    """Train a text autoencoder (DAE / VAE / AAE) and checkpoint the best model.

    Builds the vocab from the training data, trains for args.epochs epochs,
    logs running averages of each loss term, and saves model.pt to
    args.save_dir whenever validation loss improves.
    """
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    log_file = os.path.join(args.save_dir, 'log.txt')
    logging(str(args), log_file)

    # Prepare data
    train_sents = load_sent(args.train)
    logging('# train sents {}, tokens {}'.format(
        len(train_sents), sum(len(s) for s in train_sents)), log_file)
    valid_sents = load_sent(args.valid)
    logging('# valid sents {}, tokens {}'.format(
        len(valid_sents), sum(len(s) for s in valid_sents)), log_file)
    vocab_file = os.path.join(args.save_dir, 'vocab.txt')
    # NOTE(review): the guard that skipped rebuilding an existing vocab is
    # commented out, so the vocab is rebuilt (and overwritten) on every run.
    # if not os.path.isfile(vocab_file):
    #     Vocab.build(train_sents, vocab_file, args.vocab_size)
    Vocab.build(train_sents, vocab_file, args.vocab_size)
    vocab = Vocab(vocab_file)
    logging('# vocab size {}'.format(vocab.size), log_file)

    set_seed(args.seed)
    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    # Dispatch on model type; KeyError for an unrecognized args.model_type.
    model = {'dae': DAE, 'vae': VAE, 'aae': AAE}[args.model_type](
        vocab, args).to(device)
    if args.load_model:
        # Warm-start from an existing checkpoint (dict with a 'model' entry).
        ckpt = torch.load(args.load_model)
        model.load_state_dict(ckpt['model'])
        model.flatten()
    logging('# model parameters: {}'.format(
        sum(x.data.nelement() for x in model.parameters())), log_file)

    train_batches, _ = get_batches(train_sents, vocab, args.batch_size, device)
    valid_batches, _ = get_batches(valid_sents, vocab, args.batch_size, device)
    best_val_loss = None
    for epoch in range(args.epochs):
        start_time = time.time()
        logging('-' * 80, log_file)
        model.train()
        # Running averages for each loss term reported by the model.
        meters = collections.defaultdict(lambda: AverageMeter())
        # Shuffle batch order each epoch instead of reshuffling sentences.
        indices = list(range(len(train_batches)))
        random.shuffle(indices)
        for i, idx in enumerate(indices):
            inputs, targets = train_batches[idx]
            losses = model.autoenc(inputs, targets, is_train=True)
            losses['loss'] = model.loss(losses)
            model.step(losses)
            for k, v in losses.items():
                meters[k].update(v.item())
            if (i + 1) % args.log_interval == 0:
                log_output = '| epoch {:3d} | {:5d}/{:5d} batches |'.format(
                    epoch + 1, i + 1, len(indices))
                for k, meter in meters.items():
                    log_output += ' {} {:.2f},'.format(k, meter.avg)
                    # Meters are reset after every report, so each log line
                    # averages only the last args.log_interval batches.
                    meter.clear()
                logging(log_output, log_file)

        valid_meters = evaluate(model, valid_batches)
        logging('-' * 80, log_file)
        log_output = '| end of epoch {:3d} | time {:5.0f}s | valid'.format(
            epoch + 1, time.time() - start_time)
        for k, meter in valid_meters.items():
            log_output += ' {} {:.2f},'.format(k, meter.avg)
        # Checkpoint when validation loss improves (or on the first epoch).
        if not best_val_loss or valid_meters['loss'].avg < best_val_loss:
            log_output += ' | saving model'
            ckpt = {'args': args, 'model': model.state_dict()}
            torch.save(ckpt, os.path.join(args.save_dir, 'model.pt'))
            best_val_loss = valid_meters['loss'].avg
        logging(log_output, log_file)
    logging('Done training', log_file)