import json
import os

import numpy as np
import torch

# Project-local import; the module path is assumed from context.
from model import SVAE


def main(args):

    def interpolate(start, end, steps):
        # Linearly interpolate between two latent vectors, keeping both
        # endpoints (hence steps + 2 points per dimension).
        interpolation = np.zeros((start.shape[0], steps + 2))
        for dim, (s, e) in enumerate(zip(start, end)):
            interpolation[dim] = np.linspace(s, e, steps + 2)
        return interpolation.T

    def idx2word(sent_list, i2w, pad_idx):
        # Map index sequences back to whitespace-joined sentences,
        # skipping padding tokens.
        sent = []
        for s in sent_list:
            sent.append(" ".join(i2w[str(int(idx))]
                                 for idx in s if int(idx) != pad_idx))
        return sent

    with open(args.data_dir + '/vocab.json', 'r') as file:
        vocab = json.load(file)
    w2i, i2w = vocab['w2i'], vocab['i2w']

    # Load model
    model = SVAE(
        vocab_size=len(w2i),
        embed_dim=args.embedding_dimension,
        hidden_dim=args.hidden_dimension,
        latent_dim=args.latent_dimension,
        teacher_forcing=False,
        dropout=args.dropout,
        n_direction=(2 if args.bidirectional else 1),
        n_parallel=args.n_layer,
        max_src_len=args.max_src_length,  # affects the inference stage
        max_tgt_len=args.max_tgt_length,
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
    )

    path = os.path.join('checkpoint', args.load_checkpoint)
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    model.load_state_dict(torch.load(path))
    print("Model loaded from %s" % path)

    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    # Sample sentences from the prior
    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(sent_list=samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # Decode along a linear interpolation between two random latent points
    z1 = torch.randn([args.latent_dimension]).numpy()
    z2 = torch.randn([args.latent_dimension]).numpy()
    z = torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float()
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(sent_list=samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
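
# The command-line entry point is not shown above; this is a minimal sketch
# that wires up only the flags main() actually reads. The flag names mirror
# the args.* attributes used in the code, but the defaults are assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data')
    parser.add_argument('--load_checkpoint', type=str, required=True)
    parser.add_argument('--num_samples', type=int, default=10)
    parser.add_argument('--embedding_dimension', type=int, default=300)
    parser.add_argument('--hidden_dimension', type=int, default=256)
    parser.add_argument('--latent_dimension', type=int, default=16)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--bidirectional', action='store_true')
    parser.add_argument('--n_layer', type=int, default=1)
    parser.add_argument('--max_src_length', type=int, default=50)
    parser.add_argument('--max_tgt_length', type=int, default=50)
    main(parser.parse_args())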
# ---------------------------------------------------------------------------
# Training script (a separate file from the inference script above; each
# defines its own main()).
# ---------------------------------------------------------------------------
from collections import OrderedDict
from multiprocessing import cpu_count

import torch
from torch.utils.data import DataLoader

# Project-local imports; module paths are assumed from context.
from model import SVAE
from data import seq_data
from tracker import Tracker


def main(args):
    splits = ['train', 'valid'] + (['dev'] if args.test else [])

    print(args)

    # Load dataset
    datasets = OrderedDict()
    for split in splits:
        datasets[split] = seq_data(data_dir=args.data_dir,
                                   split=split,
                                   mt=args.mt,
                                   create_data=args.create_data,
                                   max_src_len=args.max_src_length,
                                   max_tgt_len=args.max_tgt_length,
                                   min_occ=args.min_occ)
    print('Data OK')

    # Load model
    model = SVAE(
        vocab_size=datasets['train'].vocab_size,
        embed_dim=args.embedding_dimension,
        hidden_dim=args.hidden_dimension,
        latent_dim=args.latent_dimension,
        # word_drop=args.word_dropout,
        teacher_forcing=args.teacher_forcing,
        dropout=args.dropout,
        n_direction=args.bidirectional,
        n_parallel=args.n_layer,
        attn=args.attention,
        max_src_len=args.max_src_length,  # affects the inference stage
        max_tgt_len=args.max_tgt_length,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx)

    if args.fasttext:
        # Initialize embeddings from pretrained fastText vectors.
        prt = torch.load(args.data_dir + '/prt_fasttext.model')
        model.load_prt(prt)
    print('Model OK')

    if torch.cuda.is_available():
        model = model.cuda()
    device = model.device

    # Training phase with validation (early stopping)
    tracker = Tracker(patience=10, verbose=True)  # records history & handles early stopping
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    step = 0
    for epoch in range(args.epochs):
        for split in splits:
            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.n_batch,
                                     shuffle=(split == 'train'),
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())
            if split == 'train':
                model.train()
            else:
                model.eval()

            for i, data in enumerate(data_loader):
                src, srclen, tgt, tgtlen = \
                    data['src'], data['srclen'], data['tgt'], data['tgtlen']

                # Forward pass
                logits, (mu, logv, z), generations = model(src, srclen,
                                                           tgt, tgtlen, split)

                # Forward pass for the ground truth
                # h_pred, h_tgt = model.forward_gt(generations, tgt, tgtlen)

                # Loss: NLL plus KL term weighted by the annealing factor KL_W
                NLL, KL, KL_W = model.loss(logits, tgt.to(device),
                                           data['tgtlen'], mu, logv,
                                           step, args.k, args.x0, args.af)
                # GLOBAL = model.global_loss(h_pred, h_tgt)
                GLOBAL = 0
                loss = (NLL + KL * KL_W + GLOBAL) / data['src'].size(0)

                # Backward pass & optimization (training split only)
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # Record & report per batch
                if i % 50 == 0 or i + 1 == len(data_loader):
                    print("{} Phase - Batch {}/{}, Loss: {}, NLL: {}, "
                          "KL: {}, KL-W: {}, G: {}"
                          .format(split.upper(), i, len(data_loader) - 1,
                                  loss, NLL, KL, KL_W, GLOBAL))

                tracker._elbo(torch.Tensor([loss]))
                if split == 'valid':
                    tracker.record(tgt, generations, datasets['train'].i2w,
                                   datasets['train'].pad_idx,
                                   datasets['train'].eos_idx,
                                   datasets['train'].unk_idx, z)

            # Save & report per epoch
            if split == 'valid':
                tracker.dumps(epoch, args.dump_file)  # dump the predicted text
            else:
                tracker._save_checkpoint(epoch, args.model_file,
                                         model.state_dict())  # save the checkpoint
            print("{} Phase - Epoch {}, Mean ELBO: {}".format(
                split.upper(), epoch, torch.mean(tracker.elbo)))
            tracker._purge()
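
# The KL weight KL_W returned by model.loss above is derived from
# (step, args.k, args.x0, args.af). That implementation lives inside the
# model and is not shown here; the sketch below is a common annealing
# schedule (Bowman et al., 2016) consistent with those argument names.
# It is an assumption, not the repository's actual code.
import numpy as np


def kl_anneal_weight(step, k, x0, af='logistic'):
    """Hypothetical KL-annealing schedule: 'logistic' ramps the weight
    smoothly from 0 toward 1 around step x0 with slope k; 'linear'
    ramps it linearly, reaching 1 at step x0."""
    if af == 'logistic':
        return float(1 / (1 + np.exp(-k * (step - x0))))
    elif af == 'linear':
        return min(1.0, step / x0)
    raise ValueError('unknown annealing function: %s' % af)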