Example #1
0
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Reads ``flags.FLAGS`` and forwards it unchanged to the handler whose
    name matches ``config.mode``.

    Args:
        _: Unused positional argument supplied by the caller.

    Raises:
        SystemExit: With status 1 when ``config.mode`` matches no handler.
    """
    # Local import keeps this block self-contained; `sys.exit` is used
    # instead of the `exit` helper, which is only installed by `site`.
    import sys

    config = flags.FLAGS
    mode = config.mode

    if mode == "get_vocab":
        get_vocab(config)
    elif mode == "prepare":
        prepare(config)
    elif mode == "train":
        train(config)
    elif mode == "train_rl":
        train_rl(config)
    elif mode == "train_qpp":
        train_qpp(config)
    elif mode == "train_qap":
        train_qap(config)
    elif mode == "train_qqp_qap":
        train_qqp_qap(config)
    elif mode == "test":
        test(config)
    else:
        print("Unknown mode")
        # A non-zero status lets shells and schedulers see the failure;
        # the previous status 0 reported a bad invocation as success.
        sys.exit(1)
Example #2
0
def main(argv):
    """Interactively decode text with a decoder restored from a checkpoint.

    Builds the sampling strategy named by ``FLAGS.strategy``, loads the
    vocabulary, restores the latest checkpoint under
    ``FLAGS.checkpoint_path`` and then repeatedly prompts for a seed text,
    printing the decoded continuation for each one.

    Args:
        argv: Unused command-line arguments supplied by the caller.

    Raises:
        RuntimeError: If ``FLAGS.strategy`` is neither ``"random"`` nor
            ``"top-k"``, or if no checkpoint is found.
    """
    if FLAGS.strategy == "random":
        strategy = RandomSamplingStrategy(temperature=FLAGS.temperature)
    elif FLAGS.strategy == "top-k":
        strategy = TopKSamplingStrategy(k=FLAGS.k,
                                        temperature=FLAGS.temperature)
    else:
        raise RuntimeError("Unsupported strategy '{}'".format(FLAGS.strategy))

    # Vocab
    vocab = get_vocab(str(Path(FLAGS.vocab)))

    # Load model
    transformer_decoder = transformer.TransformerOnlyDecoder()

    # Global step counter; restored together with the model weights.
    global_step = tf.Variable(0,
                              name="global_step",
                              trainable=False,
                              dtype=tf.int64)

    # Restore from checkpoint
    checkpoint_path = Path(FLAGS.checkpoint_path)
    ckpt = tf.train.Checkpoint(transformer_decoder=transformer_decoder,
                               global_step=global_step)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              str(checkpoint_path),
                                              max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Restored checkpoint from: {}".format(
            ckpt_manager.latest_checkpoint))
    else:
        raise RuntimeError("Couldn't load from checkpoint")

    while True:
        try:
            seed_text = input("Seed text:\n")
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D, exhausted pipe) or Ctrl-C at the prompt
            # ends the session instead of surfacing a traceback.
            break
        decoded = decode(seed_text,
                         vocab,
                         transformer_decoder,
                         strategy,
                         max_len=FLAGS.max_len)
        print(decoded)
Example #3
0
def main(argv):
    """Train the decoder-only transformer, resuming from a checkpoint if one exists.

    Sets up the model, optimizer, step/example counters, checkpoint manager,
    summary writer and dataset, then hands everything to ``train_loop``.
    A ``KeyboardInterrupt`` raised during training is swallowed so the
    function returns normally.

    Args:
        argv: Unused command-line arguments supplied by the caller.
    """
    args = flags.FLAGS

    # Model sized by the vocabulary.
    n_tokens = get_vocab(Path(args.vocab)).vocab_size
    model = transformer.TransformerOnlyDecoder(n_tokens)

    # Optimizer and its learning rate.
    optimizer, learning_rate = get_optimizer()

    # Persistent counters.
    step_counter = tf.Variable(0, name="global_step", trainable=False, dtype=tf.int64)
    examples_counter = tf.Variable(0, name="num_examples_processed", trainable=False, dtype=tf.int64)

    # Checkpoint bookkeeping; the keyword names define the saved object graph.
    ckpt_dir = Path(args.checkpoint_path)
    checkpoint = tf.train.Checkpoint(
        transformer_decoder=model,
        optimizer=optimizer,
        global_step=step_counter,
        num_examples_processed=examples_counter,
    )
    manager = tf.train.CheckpointManager(checkpoint, str(ckpt_dir), max_to_keep=5)
    latest = manager.latest_checkpoint
    if latest:
        checkpoint.restore(latest)
        print("Restored checkpoint from: {}".format(latest))

    # TensorBoard event files live next to the checkpoints.
    summary_writer = tf.summary.create_file_writer(str(ckpt_dir / "events"))

    # Training data; the restored step count is passed on as `skip`.
    dataset = get_dataset(
        Path(args.data),
        hp.get("max_tokens"),
        hp.get("max_seq_len"),
        hp.get("shuffle_buffer"),
        skip=step_counter.numpy(),
    )

    try:
        train_loop(
            dataset,
            model,
            step_counter,
            examples_counter,
            manager,
            optimizer,
            learning_rate,
            summary_writer,
            args.checkpoint_every,
            args.summarize_every,
            args.continuous,
        )
    except KeyboardInterrupt:
        # Interrupting training is treated as a normal way to stop.
        pass
Example #4
0
def evaluate(vocab_path: Path,
             checkpoint_path: Path,
             dataset_path: Path,
             batch_size: int,
             take: int = None,
             shuffle_buffer: int = None):
    """Evaluate the latest checkpoint on a dataset and log results to TensorBoard.

    Computes token accuracy and base-2 log-perplexity over the dataset
    (target positions equal to 0 are given zero weight), decodes five
    shuffled examples with random and top-5 sampling, and writes everything
    as summaries under ``checkpoint_path / "<dataset stem>_eval"``.

    Args:
        vocab_path: Location of the vocabulary passed to ``get_vocab``.
        checkpoint_path: Directory searched for the latest checkpoint.
        dataset_path: Dataset location passed to ``get_dataset``.
        batch_size: Forwarded to ``get_dataset``.
        take: Forwarded to ``get_dataset``.
        shuffle_buffer: Forwarded to ``get_dataset``.

    Returns:
        A dict with float entries ``"token_accuracy"`` and
        ``"log_perplexity"``.

    Raises:
        RuntimeError: If no checkpoint is found in ``checkpoint_path``.
    """
    vocab = get_vocab(str(vocab_path))
    model = transformer.TransformerOnlyDecoder(vocab.vocab_size)
    global_step = tf.Variable(0,
                              name="global_step",
                              trainable=False,
                              dtype=tf.int64)

    # Restore model weights and step counter; evaluation needs a checkpoint.
    checkpoint = tf.train.Checkpoint(transformer_decoder=model,
                                     global_step=global_step)
    manager = tf.train.CheckpointManager(checkpoint,
                                         str(checkpoint_path),
                                         max_to_keep=5)
    latest = manager.latest_checkpoint
    if not latest:
        raise RuntimeError("Couldn't load from checkpoint")
    checkpoint.restore(latest)
    print("Restored checkpoint from: {}".format(latest))

    eval_ds = get_dataset(str(dataset_path),
                          batch_size=batch_size,
                          take=take,
                          shuffle_buffer=shuffle_buffer)

    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(
        "token_accuracy")
    perplexity_metric = tf.keras.metrics.Mean("log_perplexity")

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int64)],
        experimental_relax_shapes=True)
    def evaluation_step(batch):
        # Inputs drop the last position, targets drop the first.
        inputs, targets = batch[:, :-1], batch[:, 1:]

        attention_mask = transformer.create_masks(inputs)
        logits, _ = model(inputs, False,
                          attention_mask)  # TODO: Visualise attentions

        # Positions whose target id is 0 get zero weight in both metrics.
        not_padding = tf.math.logical_not(tf.math.equal(targets, 0))
        accuracy_metric(targets, logits, sample_weight=not_padding)

        # Dividing by log(2) converts the cross-entropy to base 2.
        cross_entropy_bits = (
            tf.nn.sparse_softmax_cross_entropy_with_logits(targets, logits) /
            tf.math.log(2.0))
        perplexity_metric(cross_entropy_bits, sample_weight=not_padding)

    for batch in eval_ds:
        evaluation_step(batch)

    # Decode five examples, each seeded with its first four tokens.
    sample_ds = get_dataset(str(dataset_path), batch_size=1,
                            take=None).shuffle(1000, seed=42).take(5)
    decoded_triples = []
    for example in sample_ds:
        tokens = example[0]
        reference = vocab.decode(tokens.numpy())
        random_text = vocab.decode(
            decode_encoded(tokens[:4].numpy(), model, vocab.end_idx,
                           RandomSamplingStrategy()))
        top_5_text = vocab.decode(
            decode_encoded(tokens[:4].numpy(), model, vocab.end_idx,
                           TopKSamplingStrategy(5)))
        decoded_triples.append((reference, random_text, top_5_text))

    # TensorBoard events go to a per-dataset directory beside the checkpoints.
    eval_log_dir = str(checkpoint_path / (dataset_path.stem + "_eval"))
    writer = tf.summary.create_file_writer(eval_log_dir)

    with writer.as_default():
        step = global_step.numpy()
        tf.summary.scalar("token_accuracy", accuracy_metric.result(), step)
        tf.summary.scalar("log_perplexity", perplexity_metric.result(), step)

        for index, (reference, random_text, top_5_text) in enumerate(
                decoded_triples, start=1):
            tf.summary.text(
                "decoded_example_{}".format(index),
                tf.convert_to_tensor(
                    render_markdown(reference, random_text, top_5_text)),
                step)

    return {
        "token_accuracy": float(accuracy_metric.result().numpy()),
        "log_perplexity": float(perplexity_metric.result().numpy())
    }
Example #5
0
import sg_utils
import cap_eval_utils

imset = 'train'
coco_caps = COCO('../data/captions_train2014.json')

# Mapping from fine-grained tags to the coarse tag used when reporting
# final statistics.
mapping = {
    'NNS': 'NN', 'NNP': 'NN', 'NNPS': 'NN', 'NN': 'NN',
    'VB': 'VB', 'VBD': 'VB', 'VBN': 'VB', 'VBZ': 'VB', 'VBP': 'VB', 'VBG': 'VB',
    'JJR': 'JJ', 'JJS': 'JJ', 'JJ': 'JJ', 'DT': 'DT', 'PRP': 'PRP', 'PRP$': 'PRP', 'IN': 'IN',
}

# Punctuation tokens to be removed from the sentences.
punctuations = [
    "''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
    ".", "?", "!", ",", ":", "-", "--", "...", ";",
]

vocab = preprocess.get_vocab(imset, coco_caps, punctuations, mapping)

# The keys are passed as a concrete list: on Python 3 `keys()` is a view,
# which cannot be pickled (presumably what save_variables does -- confirm).
sg_utils.save_variables('vocab_' + imset + '.pkl',
                        [vocab[x] for x in vocab.keys()],
                        list(vocab.keys()),
                        overwrite=True)


##
N_WORDS = 1000
vocab = preprocess.get_vocab_top_k(vocab, N_WORDS)
image_ids = coco_caps.getImgIds()
counts = preprocess.get_vocab_counts(image_ids, coco_caps, 5, vocab)
# `np.float` was only an alias of the builtin `float` and was removed in
# NumPy 1.24; the builtin yields the same float64 arrays.
P = np.zeros((N_WORDS, 1), dtype=float)
R = np.zeros((N_WORDS, 1), dtype=float)
for i, w in enumerate(vv['words']): 
import sg_utils
import cap_eval_utils

imset = 'train'
coco_caps = COCO('../data/captions_train2014.json')

# Mapping from fine-grained tags to the coarse tag used when reporting
# final statistics.
mapping = {
    'NNS': 'NN', 'NNP': 'NN', 'NNPS': 'NN', 'NN': 'NN',
    'VB': 'VB', 'VBD': 'VB', 'VBN': 'VB', 'VBZ': 'VB', 'VBP': 'VB', 'VBG': 'VB',
    'JJR': 'JJ', 'JJS': 'JJ', 'JJ': 'JJ', 'DT': 'DT', 'PRP': 'PRP', 'PRP$': 'PRP', 'IN': 'IN',
}

# Punctuation tokens to be removed from the sentences.
punctuations = [
    "''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
    ".", "?", "!", ",", ":", "-", "--", "...", ";",
]

vocab = preprocess.get_vocab(imset, coco_caps, punctuations, mapping)

# The keys are passed as a concrete list: on Python 3 `keys()` is a view,
# which cannot be pickled (presumably what save_variables does -- confirm).
sg_utils.save_variables('vocab_' + imset + '.pkl',
                        [vocab[x] for x in vocab.keys()],
                        list(vocab.keys()),
                        overwrite=True)

##
N_WORDS = 1000
vocab = preprocess.get_vocab_top_k(vocab, N_WORDS)
image_ids = coco_caps.getImgIds()
counts = preprocess.get_vocab_counts(image_ids, coco_caps, 5, vocab)
# `np.float` was only an alias of the builtin `float` and was removed in
# NumPy 1.24; the builtin yields the same float64 arrays.
P = np.zeros((N_WORDS, 1), dtype=float)
R = np.zeros((N_WORDS, 1), dtype=float)
# NOTE(review): `vv` is not defined in the visible code -- possibly meant to
# be `vocab`; confirm before running.
for i, w in enumerate(vv['words']):
    # One column of `counts` per word; results fill row i of P and R.
    P[i], R[i] = cap_eval_utils.human_agreement(counts[:, i], 5)
Example #7
0
# ----------------------------------------------------------------------
# Preprocessing parameters
# ----------------------------------------------------------------------

# Number of words used in prediction (alternative value noted: 2).
n = 9
# Minimum number of occurrences of a word for it to enter the vocabulary
# (alternative value noted: 1).
min_occurences = 10
# Examples per batch (alternative value noted: 1).
batch_size = 32

# ----------------------------------------------------------------------
# Preprocessing
# ----------------------------------------------------------------------

lotr_full_text = preprocess.load_full_text()

# Lookup tables in both directions between words and ids.
word_to_id, id_to_word = preprocess.get_vocab(
    lotr_full_text,
    min_occurences,
)

# The full text re-expressed as a sequence of ids.
lotr_full_ids = [word_to_id[token] for token in lotr_full_text]

training_dataset = preprocess.get_tensor_dataset(lotr_full_ids, n)

# Shuffled batches; the final incomplete batch is dropped.
training_loader = DataLoader(
    training_dataset,
    shuffle=True,
    drop_last=True,
    batch_size=batch_size,
)

# ----------------------------------------------------------------------
# Network parameters
# ----------------------------------------------------------------------

# Size parameters: one more than the number of entries in word_to_id.
vocab_size = 1 + len(word_to_id)