# NOTE(review): this section arrived whitespace-collapsed; the statement nesting
# below is reconstructed from syntax and intent. Any enclosing branch that guards
# it starts before this chunk — TODO confirm indentation against the full file.

# Debug probe: take the prior latent states for `input`, then print the decoded
# tokens after each of 10 single-step refinements.
# NOTE(review): `input` is defined outside this chunk (it shadows the builtin).
z = nmt.compute_prior_states(input)
# z = torch.zeros((1, 6, OPTS.latentdim))
mask = torch.ones((1, z.shape[1]))
if torch.cuda.is_available():
    mask = mask.cuda()
    z = z.cuda()
init_z = z.clone()
for _ in range(10):
    z, tokens = nmt.refine(z, mask, n_steps=1, step_size=0.5, return_tokens=True)
    # z[:, 0] = init_z[:, 0]
    # z[:, -1] = init_z[:, -1]
    line = tgt_vocab.decode(tokens[0])
    print(" ".join(line))
# NOTE(review): this exit is unconditional here, so the file translation below
# is unreachable while the debug probe is active — guard or remove the probe.
raise SystemExit

result_path = OPTS.result_path

# Read data (close the corpus file instead of leaking the handle)
with open(test_tgt_corpus) as corpus_file:
    lines = corpus_file.readlines()

decode_times = []  # per-sentence decoding time in milliseconds
trains_stop_stdout_monitor()
try:
    with open(result_path, "w") as outf:
        for i, line in enumerate(lines):
            # Make a batch holding one sentence wrapped in <s> ... </s>
            tokens = tgt_vocab.encode("<s> {} </s>".format(
                line.strip()).split())
            x = torch.tensor([tokens])
            if torch.cuda.is_available():
                x = x.cuda()

            # Predict latent and target words from prior.
            # Start the clock per sentence; previously start_time was never set here.
            start_time = time.perf_counter()
            if OPTS.scorenet:
                targets = scorenet.translate(
                    x, n_iter=OPTS.Trefine_steps, step_size=1.0)
            else:
                targets = nmt.translate(x, refine_steps=OPTS.Trefine_steps)
            # Check for a failed decode BEFORE indexing into the result;
            # the old order crashed on None and never reached the fallback.
            if targets is None:
                target_tokens = [2, 2, 2]
            else:
                # assumes targets[0] is a (batch, length) token-id tensor — TODO confirm
                target_tokens = targets[0].cpu()[0].numpy().tolist()

            # Record decoding time (stopped after the result reached the CPU,
            # so the measurement covers the whole decode)
            end_time = time.perf_counter()
            decode_times.append((end_time - start_time) * 1000.)

            # Convert token IDs back to words, dropping ids <= 2
            # (presumably the special/padding tokens — TODO confirm against vocab)
            target_tokens = [t for t in target_tokens if t > 2]
            target_words = tgt_vocab.decode(target_tokens)
            target_sent = " ".join(target_words)
            outf.write(target_sent + "\n")
            # i + 1 so the progress display reaches 100% on the last sentence
            sys.stdout.write("\rtranslating: {:.1f}% ".format(
                float(i + 1) * 100 / len(lines)))
            sys.stdout.flush()
    sys.stdout.write("\n")
finally:
    # Restore the stdout monitor even if decoding raises part-way through.
    trains_restore_stdout_monitor()
print("Average decoding time: {:.0f}ms, std: {:.0f}".format(
    np.mean(decode_times), np.std(decode_times)))

# Translate multiple sentences in batch
if OPTS.batch_test:
    # Translate using only one GPU
    if not is_root_node():
        sys.exit()