Beispiel #1
0
def load_gpt2(sess,
              run_name="run1",
              checkpoint_dir="checkpoint",
              model_name=None,
              model_dir='models'):
    """Build the model graph and restore the newest checkpoint into `sess`.

    Args:
        sess: TensorFlow session that receives the restored variables.
        run_name: Sub-directory of `checkpoint_dir` holding a finetuned run;
            used only when `model_name` is falsy.
        checkpoint_dir: Root directory of finetuned runs.
        model_name: When truthy, weights are read from
            `model_dir/model_name` instead of the run directory.
        model_dir: Root directory of the pretrained models.
    """
    # Exactly one of the two locations is used; `model_name` wins.
    weights_dir = (os.path.join(model_dir, model_name)
                   if model_name else os.path.join(checkpoint_dir, run_name))

    hparams = model.default_hparams()
    with open(os.path.join(weights_dir, 'hparams.json')) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    # The graph is built only for its variables; the returned tensors are
    # not needed here.
    tokens_in = tf.compat.v1.placeholder(tf.int32, [1, None])
    model.model(hparams=hparams, X=tokens_in)

    latest = tf.train.latest_checkpoint(weights_dir)
    restorer = tf.compat.v1.train.Saver(allow_empty=True)
    sess.run(tf.compat.v1.global_variables_initializer())

    banner = 'Loading pretrained model' if model_name else 'Loading checkpoint'
    print(banner, latest)
    restorer.restore(sess, latest)
Beispiel #2
0
def load_gpt2(sess,
              models_dir='models/',
              run_name="117M",
              checkpoint_name=None):
    """Build the model graph and restore a checkpoint into `sess`.

    Args:
        sess: TensorFlow session that receives the restored variables.
        models_dir: Directory that contains the run directories.
        run_name: Sub-directory of `models_dir` to load from.
        checkpoint_name: Explicit checkpoint file name inside the run
            directory. When falsy, the latest checkpoint found there is used.
    """
    run_dir = os.path.join(models_dir, run_name)

    hparams = model.default_hparams()
    with open(os.path.join(run_dir, 'hparams.json')) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    # Only the variables created by the graph matter for restoring.
    tokens_in = tf.placeholder(tf.int32, [1, None])
    model.model(hparams=hparams, X=tokens_in)

    ckpt = (os.path.join(run_dir, checkpoint_name)
            if checkpoint_name else tf.train.latest_checkpoint(run_dir))
    restorer = tf.train.Saver(allow_empty=True)
    sess.run(tf.global_variables_initializer())

    print('Loading checkpoint', ckpt)
    restorer.restore(sess, ckpt)
Beispiel #3
0
def get_logitsTF(sess,
                 checkpoint='latest',
                 run_name="run1",
                 checkpoint_dir="checkpoint",
                 model_name=None,
                 model_dir='models',
                 multi_gpu=False):
    """Build the model graph and evaluate its `logits` tensor in `sess`.

    Args:
        sess: TensorFlow session used to evaluate the logits.
        checkpoint: Not used by this function.
        run_name: Sub-directory of `checkpoint_dir`; used when `model_name`
            is falsy.
        checkpoint_dir: Root directory of finetuned runs.
        model_name: When truthy, hyperparameters are read from
            `model_dir/model_name`.
        model_dir: Root directory of the pretrained models.
        multi_gpu: When true, the result of `get_available_gpus()` is passed
            to the model; otherwise an empty list is passed.

    Returns:
        The evaluated `logits` entry of the model output.

    NOTE(review): no checkpoint is restored and no `feed_dict` is supplied
    for the input placeholder here -- confirm that callers expect this.
    """
    source_dir = (os.path.join(model_dir, model_name)
                  if model_name else os.path.join(checkpoint_dir, run_name))

    hparams = model.default_hparams()
    with open(os.path.join(source_dir, 'hparams.json')) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    tokens_in = tf.compat.v1.placeholder(tf.int32, [1, None])

    device_list = get_available_gpus() if multi_gpu else []
    graph_out = model.model(hparams=hparams, X=tokens_in, gpus=device_list)

    return graph_out['logits'].eval(session=sess)
def embed(sess,
          run_name='run1',
          destination_path="X.p",
          sentences=None,
          batch_size=1,
          layer_type="h",
          save=False,
          multiple_lists=False):
    """Run sentences through the model and collect one output per sentence.

    Args:
        sess: TensorFlow session with the model weights loaded.
        run_name: Sub-directory of ``checkpoint`` holding the encoder files
            and ``hparams.json``.
        destination_path: File the pickled result is written to when `save`
            is true.
        sentences: A single string, a sequence of strings, or (with
            `multiple_lists=True`) a sequence of sequences of strings.
        batch_size: First dimension of the input placeholder; each sentence
            is repeated `batch_size` times and only row 0 is kept.
        layer_type: Key of the model output dict to fetch.
        save: When true, pickle the result to `destination_path`.
        multiple_lists: When true, `sentences` is treated as several groups
            and one list of embeddings is returned per group.

    Returns:
        A list of embeddings (one per sentence), or a list of such lists
        when `multiple_lists` is true.

    Raises:
        ValueError: If `sentences` is None.
    """
    if sentences is None:
        raise ValueError("`sentences` must be provided.")

    if multiple_lists:
        prefixes = sentences
    else:
        # A bare string must be wrapped, otherwise the inner loop below
        # would iterate over it character by character.
        group = [sentences] if isinstance(sentences, str) else sentences
        prefixes = [group]

    checkpoint_path = os.path.join('checkpoint', run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    embeddings = []
    context = tf.placeholder(tf.int32, [batch_size, None])
    lm_output = model.model(hparams=hparams,
                            X=context,
                            past=None,
                            reuse=tf.AUTO_REUSE,
                            emb=True)
    with tqdm(total=len(prefixes)) as pbar:
        for prefix in prefixes:
            pref_emb = []
            for p in prefix:
                context_tokens = enc.encode(p)
                e = sess.run(
                    lm_output[layer_type],
                    feed_dict={context: batch_size * [context_tokens]})
                # Every batch row holds the same tokens; keep the first.
                pref_emb.append(e[0])
            embeddings.append(pref_emb)
            pbar.update(1)

    if save:
        # Context manager closes the file even if pickling fails.
        with open(destination_path, 'wb') as out_file:
            pickle.dump(embeddings, out_file)

    if multiple_lists:
        return embeddings
    return embeddings[0]
Beispiel #5
0
def get_text_rankings(sess, string):
    '''Rank every token of `string` under the model.

    Returns a list of (token_string, rank) pairs, one per encoded token of
    the input. The rank is the number of vocabulary entries whose logit is
    strictly greater than the token's own logit, so 0 means most likely.

    The checkpoint location ('checkpoint/run1') and the random seed (42)
    are fixed inside this function.
    '''
    batch_size = 1

    checkpoint_path = os.path.join('checkpoint', 'run1')
    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    # 50256 is prepended as a start token (presumably the <|endoftext|> id
    # -- confirm against the encoder).
    context_tokens = [50256] + enc.encode(string)
    np.random.seed(42)
    tf.compat.v1.set_random_seed(42)

    past_t, _prev_t, tokens_t, logits_t = sample_sequence(
        hparams=hparams,
        length=len(context_tokens) - 1,
        start_token=None,
        batch_size=batch_size,
        context=context,
        temperature=1,
        top_k=0,
        top_p=0.0)
    tokens_t = tokens_t[:, 1:]

    pas, out, alt = sess.run(
        [past_t, tokens_t, logits_t],
        feed_dict={context: batch_size * [context_tokens]})

    print(enc.decode(out[0]))
    print('generated tokens shape: ', out[0].shape)
    print('pas shape: ', pas.shape)

    # For each real input token (the start token is skipped) count how many
    # logits at that position beat the token's own logit.
    return [(enc.decode([token]),
             (alt[0][position] > alt[0][position][token]).sum())
            for position, token in enumerate(context_tokens[1:])]
Beispiel #6
0
def get_logits(sess,
               run_name='run1',
               checkpoint_dir='checkpoint',
               model_name=None,
               model_dir='models',
               prefix="<|endoftext|>",
               all=False):
    """Return the model's logits for `prefix`.

    Args:
        sess: TensorFlow session with the model weights loaded.
        run_name: Sub-directory of `checkpoint_dir`; used when `model_name`
            is falsy.
        checkpoint_dir: Root directory of finetuned runs.
        model_name: When truthy, encoder and hyperparameters are read from
            `model_dir/model_name`.
        model_dir: Root directory of the pretrained models.
        prefix: Non-empty text to condition on. Only its last 1023 encoded
            tokens are fed to the model.
        all: When true, return the logits at every position; otherwise only
            those for the next token. (The name shadows the builtin but is
            kept because callers may pass it by keyword.)

    Returns:
        `logits[0, :, :]` when `all` is true, else `logits[0, -1, :]`.

    Raises:
        ValueError: If `prefix` is empty or None. Previously this surfaced
            as an UnboundLocalError after the encoder had been loaded.
    """
    if not prefix:
        raise ValueError("`prefix` must be a non-empty string.")

    batch_size = 1

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    # Keep only the tail of long prompts.
    context_tokens = enc.encode(prefix)[-1023:]

    def step(hparams, tokens, past=None):
        """Run the model once and return its logits and `present` state."""
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    output = step(hparams, context)

    out = sess.run(output, feed_dict={context: batch_size * [context_tokens]})

    if all:
        # One row of logits per input position.
        return out['logits'][0, :, :]
    # Logits for the token that would follow the prefix.
    return out['logits'][0, -1, :]
def predict(sess,
            text,
            run_name='run1',
            checkpoint_dir='checkpoint',
            model_name=None,
            model_dir='models',
            sample_dir='samples',
            seed=None,
            temperature=1,
            top_k=2):
    """Return the `top_k` highest-scoring tokens that follow `text`.

    Args:
        sess: TensorFlow session with the model weights loaded.
        text: Prompt to encode and feed to the model.
        run_name: Sub-directory of `checkpoint_dir`; used when `model_name`
            is falsy.
        checkpoint_dir: Root directory of finetuned runs.
        model_name: When truthy, files are read from `model_dir/model_name`.
        model_dir: Root directory of the pretrained models.
        sample_dir: Not used by this function.
        seed: Seed passed to both NumPy and TensorFlow.
        temperature: Forwarded to `sample.sample_sequence`.
        top_k: Number of entries to return.

    Returns:
        A tuple `(tokens, scores)`: the decoded top entries split on
        whitespace, and the matching values taken from the model output.

    NOTE(review): `sample.sample_sequence` is assumed to return a per-token
    score vector here (it is called without `length`) -- confirm.
    """
    source_dir = (os.path.join(model_dir, model_name)
                  if model_name else os.path.join(checkpoint_dir, run_name))

    enc = encoder.get_encoder(source_dir)
    hparams = model.default_hparams()
    with open(os.path.join(source_dir, 'hparams.json')) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    tokens_in = tf.compat.v1.placeholder(tf.int32, [1, None])
    prompt_tokens = enc.encode(text)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    # Timing hooks: the timestamps bracket graph construction and the
    # session run, for ad-hoc profiling.
    start_time = time.time()
    scores_t = sample.sample_sequence(hparams=hparams,
                                      context=tokens_in,
                                      temperature=temperature)
    start_time = time.time()
    scores = sess.run(scores_t, feed_dict={tokens_in: [prompt_tokens]})

    # Ascending argsort flipped gives descending order; keep the best ones.
    ranked = np.flip(np.argsort(scores))
    best = ranked[:top_k]

    return enc.decode(best).split(), scores[best]
Beispiel #8
0
def load_gpt2(
    sess,
    checkpoint="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    model_name=None,
    model_dir="models",
    multi_gpu=False,
):
    """Build the model graph and restore a checkpoint into `sess`.

    Args:
        sess: TensorFlow session that receives the restored variables.
        checkpoint: ``"latest"`` to use the newest checkpoint in the chosen
            directory, otherwise a checkpoint file name inside it.
        run_name: Sub-directory of `checkpoint_dir`; used when `model_name`
            is falsy.
        checkpoint_dir: Root directory of finetuned runs.
        model_name: When truthy, weights are read from
            `model_dir/model_name`.
        model_dir: Root directory of the pretrained models.
        multi_gpu: When true, the result of `get_available_gpus()` is passed
            to the model; otherwise an empty list is passed.
    """
    weights_dir = (os.path.join(model_dir, model_name)
                   if model_name else os.path.join(checkpoint_dir, run_name))

    hparams = model.default_hparams()
    with open(os.path.join(weights_dir, "hparams.json")) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    tokens_in = tf.compat.v1.placeholder(tf.int32, [1, None])
    device_list = get_available_gpus() if multi_gpu else []

    # Built only for its variables; the output tensors are discarded.
    model.model(hparams=hparams, X=tokens_in, gpus=device_list)

    ckpt = (tf.train.latest_checkpoint(weights_dir)
            if checkpoint == "latest"
            else os.path.join(weights_dir, checkpoint))

    restorer = tf.compat.v1.train.Saver(allow_empty=True)
    sess.run(tf.compat.v1.global_variables_initializer())

    banner = "Loading pretrained model" if model_name else "Loading checkpoint"
    print(banner, ckpt)
    restorer.restore(sess, ckpt)
def my_load_gpt2(_sess,
                 _checkpoint='latest',
                 _run_name="run1",
                 _checkpoint_dir="checkpoint",
                 _model_name=None,
                 _model_dir='models',
                 _multi_gpu=False):
    """
    Build the model graph and restore a checkpoint into `_sess`.

    Args:
        _sess: TensorFlow session that receives the restored variables.
        _checkpoint: 'latest' to use the newest checkpoint in the chosen
            directory, otherwise a checkpoint file name inside it.
        _run_name: Sub-directory of `_checkpoint_dir`; used when
            `_model_name` is falsy.
        _checkpoint_dir: Root directory of finetuned runs.
        _model_name: When truthy, weights are read from
            `_model_dir/_model_name`.
        _model_dir: Root directory of the pretrained models.
        _multi_gpu: When true, the result of `get_available_gpus()` is
            passed to the model; otherwise an empty list is passed.
    """
    checkpoint_path = (os.path.join(_model_dir, _model_name)
                       if _model_name
                       else os.path.join(_checkpoint_dir, _run_name))
    print(f"\ncheckpoint_path = {checkpoint_path}\n")

    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as hparams_file:
        overrides = json.load(hparams_file)
    hparams.override_from_dict(overrides)

    tokens_in = tf.compat.v1.placeholder(tf.int32, [1, None])
    device_list = get_available_gpus() if _multi_gpu else []

    # Built only for its variables; the output tensors are discarded.
    model.model(hparams=hparams, X=tokens_in, gpus=device_list)

    ckpt = (tf.train.latest_checkpoint(checkpoint_path)
            if _checkpoint == 'latest'
            else os.path.join(checkpoint_path, _checkpoint))

    restorer = tf.compat.v1.train.Saver(allow_empty=True)
    _sess.run(tf.compat.v1.global_variables_initializer())

    if _model_name:
        print(f"\nLoading pretrained model :: {ckpt}\n")
    else:
        print(f"\nLoading checkpoint :: {ckpt}\n")
    restorer.restore(_sess, ckpt)
Beispiel #10
0
def generate(sess,
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             model_name='117M',
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             run_name='run1'):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Args:
        sess: TensorFlow session with the model weights loaded.
        return_as_list: When true, return the generated texts as a list.
        truncate: Pattern at which each sample is cut. It is inserted into
            a regular expression as-is, so regex metacharacters are active.
        destination_path: When set, samples are written to this file
            instead of being printed.
        sample_delim: Separator emitted after each sample; forced to ''
            when `nsamples` is 1.
        prefix: Text to condition on. When falsy, generation starts from
            the ``<|endoftext|>`` token.
        model_name: Not used by this function.
        seed: Seed passed to both NumPy and TensorFlow.
        nsamples: Total number of samples; must be a multiple of
            `batch_size`.
        batch_size: Samples generated per session run (None means 1).
        length: Forwarded to `sample.sample_sequence`.
        temperature: Forwarded to `sample.sample_sequence`.
        top_k: Forwarded to `sample.sample_sequence`.
        run_name: Sub-directory of ``checkpoint`` holding the encoder files
            and ``hparams.json``.

    Returns:
        The list of generated texts when `return_as_list` is true,
        otherwise None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix:
        context = tf.placeholder(tf.int32, [batch_size, None])

    checkpoint_path = os.path.join('checkpoint', run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    np.random.seed(seed)
    tf.set_random_seed(seed)

    # The leading token (start token or first prefix token) is sliced off.
    output = sample.sample_sequence(
        hparams=hparams,
        length=length,
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k)[:, 1:]

    context_tokens = enc.encode(prefix) if prefix else None

    out_file = open(destination_path, 'w') if destination_path else None
    try:
        generated = 0
        gen_texts = []
        while generated < nsamples:
            if not prefix:
                out = sess.run(output)
            else:
                out = sess.run(
                    output,
                    feed_dict={context: batch_size * [context_tokens]})
            for i in range(batch_size):
                generated += 1
                gen_text = enc.decode(out[i])
                if prefix:
                    # Restore the whole first prefix token that the [:, 1:]
                    # slice removed. Prepending only prefix[0] lost the
                    # rest of that token whenever it spanned more than one
                    # character.
                    gen_text = enc.decode(context_tokens[:1]) + gen_text
                if truncate:
                    trunc_text = re.search(r'(.*?)(?:{})'.format(truncate),
                                           gen_text, re.S)
                    if trunc_text:
                        gen_text = trunc_text.group(1)
                if out_file is not None:
                    out_file.write("{}\n{}".format(gen_text, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen_text, sample_delim))
                gen_texts.append(gen_text)
    finally:
        # Close the output file even when generation fails midway.
        if out_file is not None:
            out_file.close()

    if return_as_list:
        return gen_texts
Beispiel #11
0
def generate(sess,
             run_name='run1',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Text is produced in rounds: each round samples up to the remaining
    window (1023 minus the current context length), and the tail of the
    output becomes the context of the next round until every sample in the
    batch is marked as finished.

    Args:
        sess: TensorFlow session with the model weights loaded.
        run_name: Sub-directory of ``checkpoint`` holding the encoder files
            and ``hparams.json``.
        return_as_list: When true, return the generated texts as a list.
        truncate: Literal string at which each sample is cut (it is
            regex-escaped before use). Required when `length` is falsy.
        destination_path: When set, samples are written to this file
            instead of being printed.
        sample_delim: Separator emitted after each sample; forced to ''
            when `nsamples` is 1.
        prefix: Text to condition on; '' is treated like None, in which
            case the ``<|endoftext|>`` token is used as context.
        seed: Seed passed to both NumPy and TensorFlow.
        nsamples: Total number of samples; must be a multiple of
            `batch_size`.
        batch_size: Samples generated per round (None means 1).
        length: Token budget per sample; falsy means "until `truncate`
            matches".
        temperature: Forwarded to `sample.sample_sequence`.
        top_k: Forwarded to `sample.sample_sequence`.
        top_p: Forwarded to `sample.sample_sequence`.
        include_prefix: When false (and both `prefix` and `truncate` are
            set), the prefix is excluded from the truncated text.
        split_context: Fraction of a round's output that is carried over as
            context for the next round.

    Returns:
        The list of generated texts when `return_as_list` is true,
        otherwise None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
        sample, must have a truncation term."

    checkpoint_path = os.path.join('checkpoint', run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [batch_size, None])
    if prefix:
        context_tokens = [enc.encode(prefix)] * batch_size
    else:
        context_tokens = [[enc.encoder['<|endoftext|>']] for _ in range(batch_size)]
    np.random.seed(seed)
    tf.set_random_seed(seed)

    out_file = open(destination_path, 'w') if destination_path else None
    try:
        generated = 0
        gen_texts = []
        while generated < nsamples:
            gen_text = [''] * batch_size
            truncated = [False] * batch_size
            total_tokens = 0

            while False in truncated:
                num_tokens = 1023 - (len(context_tokens[0]))
                # NOTE(review): the sampling op is rebuilt on every round;
                # confirm whether it can be constructed once instead.
                output = sample.sample_sequence(
                    hparams=hparams,
                    length=min(length if length else 1023, num_tokens),
                    context=context,
                    batch_size=batch_size,
                    temperature=temperature, top_k=top_k, top_p=top_p
                )[:, 1:]

                out = sess.run(output, feed_dict={
                        context: context_tokens
                    })

                total_tokens += num_tokens

                for i in range(batch_size):
                    # Reset for every sample. Without this the name was
                    # unbound whenever `truncate` was not given, so the
                    # check at the end of this loop raised
                    # UnboundLocalError.
                    trunc_text = None

                    text = enc.decode(out[i])
                    if prefix:
                        # Re-attach the first context token removed by the
                        # [:, 1:] slice above.
                        text = enc.decode(context_tokens[i][:1]) + text
                    if truncate or all(gen_text):
                        # Carry the tail of this round over as next context.
                        context_tokens[i] = out[i][int(len(out[i])*(1-split_context)):]
                        if gen_text[i] != '':
                            # Drop the part of `text` that overlaps with what
                            # was already collected, using the last sentence
                            # fragment as the anchor.
                            split = re.split('[.!?]', gen_text[i])
                            text = text.partition(list(filter(None, split))[-1])[-1]

                        if truncate:
                            truncate_esc = re.escape(truncate)
                            if prefix and not include_prefix:
                                prefix_esc = re.escape(prefix)
                                pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                                     truncate_esc)
                            else:
                                pattern = '(.*?)(?:{})'.format(truncate_esc)

                            trunc_text = re.search(pattern, text, re.S)
                            if trunc_text:
                                text = trunc_text.group(1)

                    if not truncated[i]:
                        gen_text[i] += text.lstrip('\n')
                    if trunc_text or (length is not None and total_tokens >= length-1):
                        truncated[i] = True

            for gen in gen_text:
                if out_file is not None:
                    out_file.write("{}\n{}".format(gen, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen, sample_delim), end='')
                gen_texts.append(gen)

            generated += batch_size
    finally:
        # Close the output file even when generation fails midway.
        if out_file is not None:
            out_file.close()

    if return_as_list:
        return gen_texts
Beispiel #12
0
def get_hparams(checkpoint_path):
    """Return the default hyperparameters overridden by `hparams.json`.

    Args:
        checkpoint_path: Directory that contains ``hparams.json``.

    Returns:
        The object from `model.default_hparams()` after the JSON overrides
        have been applied to it.
    """
    result = model.default_hparams()
    json_path = os.path.join(checkpoint_path, "hparams.json")
    with open(json_path) as handle:
        overrides = json.load(handle)
    result.override_from_dict(overrides)
    return result
Beispiel #13
0
def finetune(
    sess,
    dataset,
    steps=-1,
    model_name="124M",
    model_dir="models",
    combine=50000,
    batch_size=1,
    learning_rate=0.0001,
    accumulate_gradients=5,
    restore_from="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    sample_every=100,
    sample_length=1023,
    sample_num=1,
    multi_gpu=False,
    save_every=1000,
    print_every=1,
    max_checkpoints=1,
    use_memory_saving_gradients=False,
    only_train_transformer_layers=False,
    optimizer="adam",
    overwrite=False,
):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Summary of what this function does with its arguments:

    Args:
        sess: TensorFlow session used for training, saving and sampling.
        dataset: Passed to `load_dataset` together with `combine`.
        steps: Number of steps to train in this call; values <= 0 train
            until interrupted with Ctrl-C.
        model_name: Sub-directory of `model_dir` with the base model. For
            names other than "117M"/"124M" memory-saving gradients and
            transformer-only training are forced on and gradient
            accumulation is disabled.
        model_dir: Root directory of the pretrained models.
        combine: Passed to `load_dataset`.
        batch_size: Number of sequences per training batch.
        learning_rate: Learning rate of the optimizer.
        accumulate_gradients: Number of batches accumulated per update when
            greater than 1.
        restore_from: "latest" (resume the run, falling back to the base
            model), "fresh" (base model), or a directory to restore from.
        run_name: Sub-directory of `checkpoint_dir` for this run.
        checkpoint_dir: Root directory of finetuned runs.
        sample_every: Generate samples every this many steps.
        sample_length: Length of generated samples; must not exceed
            `hparams.n_ctx`.
        sample_num: Number of samples to generate each time.
        multi_gpu: When true, the result of `get_available_gpus()` is passed
            to the model.
        save_every: Save a checkpoint every this many steps.
        print_every: Print the loss every this many steps.
        max_checkpoints: `max_to_keep` of the Saver.
        use_memory_saving_gradients: Use `memory_saving_gradients` instead
            of `tf.gradients` (not available with gradient accumulation).
        only_train_transformer_layers: Train only variables whose name
            contains "/h".
        optimizer: "adam" or "sgd".
        overwrite: With `restore_from="latest"`, delete existing model and
            event files of the run and save immediately.

    Raises:
        ValueError: If `optimizer` is not "adam" or "sgd", or if
            `sample_length` exceeds the model's context window.
        FileNotFoundError: If the base model files are missing.
    """

    # Fail fast: an unknown optimizer used to surface only after the graph
    # was built, as an UnboundLocalError on `opt`.
    if optimizer not in ("adam", "sgd"):
        raise ValueError(
            "Unknown optimizer {!r}; expected 'adam' or 'sgd'.".format(
                optimizer))

    SAMPLE_DIR = "samples"

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        """Create `path` (and parents); an existing path is not an error."""
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    # Snapshot taken before anything is copied in; used by `overwrite`.
    files = os.listdir(checkpoint_path)
    for file in ["hparams.json", "encoder.json", "vocab.bpe"]:
        try:
            shutil.copyfile(
                os.path.join(model_dir, model_name, file),
                os.path.join(checkpoint_path, file),
            )
        except FileNotFoundError:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ["117M", "124M"]:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Next-token loss: logits at position t are scored against token t+1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output["logits"][:, :-1]))

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40,
    )

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if "model" in v.name
    ]
    train_vars = ([v for v in all_vars if "/h" in v.name]
                  if only_train_transformer_layers else all_vars)

    if optimizer == "adam":
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    else:  # "sgd" -- validated at the top of the function
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # NOTE(review): exit() terminates the whole interpreter; kept
            # for backward compatibility.
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar("loss", opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar("loss", loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == "latest":
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == "fresh":
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print("Loading checkpoint", ckpt)
    saver.restore(sess, ckpt)

    print("Loading dataset...")
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print("dataset has", data_sampler.total_size, "tokens")
    print("Training...")

    counter = 1
    counter_path = os.path.join(checkpoint_path, "counter")
    if os.path.exists(counter_path) and restore_from == "latest":
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, "r") as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        """Write a checkpoint for the last finished step and record it."""
        maketree(checkpoint_path)
        print("Saving",
              os.path.join(checkpoint_path, "model-{}").format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, "model"),
                   global_step=counter - 1)
        with open(counter_path, "w") as fp:
            fp.write(str(counter - 1) + "\n")

    def generate_samples():
        """Sample `sample_num` texts, print the last one, save them all."""
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = "======== SAMPLE {} ========\n{}\n".format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             "samples-{}").format(counter), "w") as fp:
            fp.write("\n".join(all_text))

    def sample_batch():
        """Draw one training batch from the dataset sampler."""
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == "latest":
        for file in files:
            if file.startswith("model") or file.startswith("events"):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    # Running (weighted sum, weight) pair for the smoothed loss display.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    "[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}"
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                    ))

            counter += 1
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop an open-ended run; keep progress.
        print("interrupted")
        save()
Beispiel #14
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             mixed_precision=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    The tokenizer and hparams files of the base model are copied into
    ``checkpoint_dir/run_name``, the training graph is built, weights are
    restored, and training runs until ``steps`` iterations have completed or
    the user interrupts with Ctrl-C. A checkpoint is saved in both cases.

    Args:
        sess: TensorFlow session to train in. If None, one is created with
            ``start_tf_sess()`` after the optimizer has been configured.
        dataset: Forwarded to ``load_dataset`` together with ``combine``.
        steps: Number of iterations to run in this call. Zero, negative or
            None means "train until interrupted".
        model_name: Sub-directory of ``model_dir`` holding the base model.
            For names other than '117M'/'124M' memory-saving gradients and
            transformer-only training are forced on and gradient
            accumulation is forced off.
        model_dir: Directory containing the downloaded base models.
        combine: Forwarded to ``load_dataset``.
        batch_size: Number of sequences per training batch.
        learning_rate: Learning rate handed to the optimizer.
        accumulate_gradients: When > 1, number of batches whose gradients
            are accumulated before each update.
        restore_from: 'latest' (resume the run, or start from the base model
            if the run has no checkpoint), 'fresh' (always start from the
            base model), or a directory containing a checkpoint.
        run_name: Sub-directory of ``checkpoint_dir`` (and of the samples
            directory) used for this run.
        checkpoint_dir: Root directory for run checkpoints.
        sample_every: Generate samples every this many steps.
        sample_length: Number of tokens per generated sample; must not
            exceed ``hparams.n_ctx``.
        sample_num: Number of samples written each time.
        save_every: Save a checkpoint every this many steps.
        print_every: Print the loss every this many steps.
        max_checkpoints: ``max_to_keep`` for the checkpoint saver.
        use_memory_saving_gradients: Compute gradients with
            ``memory_saving_gradients.gradients``.
        only_train_transformer_layers: Train only variables whose name
            contains '/h'.
        optimizer: One of 'adam', 'sgd', 'sm3' or 'adafactor'.
        overwrite: With ``restore_from == 'latest'``, delete the 'model*'
            and 'events*' files that were present in the run directory at
            start-up and immediately save a new checkpoint.
        mixed_precision: Request the mixed-precision graph rewrite. It is
            disabled again (and reported) when the TensorFlow version or
            GPU checks below do not pass.

    Raises:
        FileNotFoundError: If the base model files are missing.
        ValueError: If ``sample_length`` exceeds the model window, the
            optimizer name is unknown, or no checkpoint can be found.
    """

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Only "already exists" is tolerated; permission problems and other
        # OS errors propagate instead of being silently swallowed.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    # Listing is taken before this call writes anything, so `overwrite`
    # below only removes files that were already present.
    files = os.listdir(checkpoint_path)
    for fname in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, fname),
                            os.path.join(checkpoint_path, fname))
        except FileNotFoundError:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    # Next-token loss: position t's logits are scored against token t + 1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    if only_train_transformer_layers:
        train_vars = [v for v in all_vars if '/h' in v.name]
    else:
        train_vars = all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    elif optimizer == 'sm3':
        opt = SM3Optimizer(learning_rate=learning_rate, momentum=0.9)
    elif optimizer == 'adafactor':
        opt = AdafactorOptimizer(learning_rate=learning_rate)
    else:
        # Previously an unknown name left `opt` unbound and surfaced later
        # as a NameError far from the cause.
        raise ValueError(
            "Unknown optimizer %r; expected 'adam', 'sgd', 'sm3' or "
            "'adafactor'." % (optimizer,))

    def mp_check_tf_version():
        # check TensorFlow >= 1.14
        tf_version_list = tf.__version__.split(".")
        if int(tf_version_list[0]) < 2:
            return int(tf_version_list[1]) >= 14
        # NOTE(review): TF 2.x has always fallen through here with a falsy
        # result, i.e. mixed precision stays off on TF 2.x. Made explicit;
        # confirm whether 2.x should be accepted.
        return False

    def mp_check_tensor_core_gpu_present():
        from tensorflow.python.client import device_lib
        # check Compute Capability >= 7.0
        local_device_protos = device_lib.list_local_devices()
        for line in local_device_protos:
            if "compute capability" in str(line):
                compute_capability = float(
                    line.physical_device_desc.split("compute capability: ")
                    [-1])
                if compute_capability >= 7.0:
                    return True
        return False

    if mixed_precision and mp_check_tf_version(
    ) and mp_check_tensor_core_gpu_present():
        if isinstance(opt, tf.keras.optimizers.Optimizer) or isinstance(
                opt, tf.compat.v1.train.Optimizer):
            opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
                opt)
        else:
            mixed_precision = False
    else:
        mixed_precision = False

    print('Mixed precision enabled: %s' % mixed_precision)

    # Fix warning: https://github.com/tensorflow/tensorflow/blob/233d3d/tensorflow/python/training/experimental/mixed_precision.py#L355-L357
    # (presumably the reason the session is only created at this point,
    # after the graph rewrite has been requested - confirm against the link)
    if sess is None:
        sess = start_tf_sess()

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # Kept as exit() for backward compatibility: raises SystemExit.
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    base_model_path = os.path.join(model_dir, model_name)
    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(base_model_path)
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(base_model_path)
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    if ckpt is None:
        # saver.restore(sess, None) would fail anyway; fail here with a
        # message that says where we looked.
        raise ValueError(
            "No checkpoint found to restore (restore_from=%r, run "
            "directory %r, base model %r)." %
            (restore_from, checkpoint_path, base_model_path))
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # `counter` is the step about to run, so the last finished step is
        # counter - 1; that is the number recorded in the checkpoint.
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path,
                           'model-{}'.format(counter - 1)))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        # Every sample is seeded with the same single token drawn from the
        # training data.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        # Echo the last sample only; guarded so sample_num <= 0 no longer
        # hits an unbound local.
        if all_text:
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        # Format only the file name so braces in run_name cannot break it.
        sample_file = os.path.join(SAMPLE_DIR, run_name,
                                   'samples-{}'.format(counter))
        with open(sample_file, 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for fname in files:
            if fname.startswith('model') or fname.startswith('events'):
                os.remove(os.path.join(checkpoint_path, fname))
        save()

    # (weighted loss sum, weight sum) of an exponential moving average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    # None/0 mean "no step limit"; previously None crashed on `None > 0`.
    steps = int(steps) if steps else 0

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Beispiel #15
0
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Args:
        sess: TensorFlow session the sampling graph is run in.
        run_name: Run directory under ``checkpoint_dir`` to read the
            encoder and hparams from (used when ``model_name`` is not set).
        checkpoint_dir: Root directory of finetuned runs.
        model_name: If set, read the encoder and hparams from
            ``model_dir/model_name`` instead of the run directory.
        model_dir: Root directory of the base models.
        sample_dir: Unused by this function.
        return_as_list: Return the generated texts instead of printing them.
        truncate: If set, each sample is cut just before the first
            occurrence of this string; samples without it are kept whole.
        destination_path: If set, samples are written to this file instead
            of being printed.
        sample_delim: Text emitted after each sample; forced to '' when
            ``nsamples == 1``.
        prefix: Optional text every sample is conditioned on. An empty
            string is treated as no prefix.
        seed: Seed for NumPy and TensorFlow random state.
        nsamples: Total number of samples; must be a multiple of
            ``batch_size``.
        batch_size: Samples generated per session run (None means 1).
        length: Tokens to generate, capped at 1023 minus the prefix length.
        temperature: Forwarded to ``sample.sample_sequence``.
        top_k: Forwarded to ``sample.sample_sequence``.
        top_p: Forwarded to ``sample.sample_sequence``.
        include_prefix: With ``truncate`` and ``prefix`` set, False drops
            the prefix from the truncated text.

    Returns:
        The list of generated strings if ``return_as_list`` is true,
        otherwise None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if prefix:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        context_tokens = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    # The first output column is sliced off here; with a prefix the first
    # prefix token is decoded and prepended again in the loop below.
    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    # The truncation pattern depends only on the arguments, so build it
    # once instead of once per sample.
    pattern = None
    if truncate:
        truncate_esc = re.escape(truncate)
        if prefix and not include_prefix:
            prefix_esc = re.escape(prefix)
            pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc, truncate_esc)
        else:
            pattern = '(.*?)(?:{})'.format(truncate_esc)

    gen_texts = []
    out_file = open(destination_path, 'w') if destination_path else None
    try:
        generated = 0
        while generated < nsamples:
            if not prefix:
                out = sess.run(output)
            else:
                out = sess.run(
                    output,
                    feed_dict={context: batch_size * [context_tokens]})
            for i in range(batch_size):
                generated += 1
                gen_text = enc.decode(out[i])
                if prefix:
                    gen_text = enc.decode(context_tokens[:1]) + gen_text
                if pattern is not None:
                    trunc_text = re.search(pattern, gen_text, re.S)
                    if trunc_text:
                        gen_text = trunc_text.group(1)
                gen_text = gen_text.lstrip('\n')
                if out_file is not None:
                    out_file.write("{}\n{}".format(gen_text, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen_text, sample_delim), end='')
                gen_texts.append(gen_text)
    finally:
        # Close the destination file even when sampling or decoding raises;
        # the handle used to leak on any exception.
        if out_file is not None:
            out_file.close()

    if return_as_list:
        return gen_texts
Beispiel #16
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0
             ):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    The tokenizer and hparams files of the base model are copied into
    ``checkpoint_dir/run_name``, the training graph is built, weights are
    restored, and training runs until ``steps`` iterations have completed or
    the user interrupts with Ctrl-C. A checkpoint is saved in both cases.

    Args:
        sess: TensorFlow session to train in.
        dataset: Forwarded to ``load_dataset`` together with ``combine``.
        steps: Number of iterations to run in this call. Zero, negative or
            None means "train until interrupted".
        model_name: Sub-directory of ``model_dir`` holding the base model.
        model_dir: Directory containing the downloaded base models.
        combine: Forwarded to ``load_dataset``.
        batch_size: Number of sequences per training batch.
        learning_rate: Learning rate handed to the optimizer.
        accumulate_gradients: When > 1, number of batches whose gradients
            are accumulated before each update.
        restore_from: 'latest' (resume the run, or start from the base model
            if the run has no checkpoint), 'fresh' (always start from the
            base model), or a directory containing a checkpoint.
        run_name: Sub-directory of ``checkpoint_dir`` (and of the samples
            directory) used for this run.
        checkpoint_dir: Root directory for run checkpoints.
        sample_every: Generate samples every this many steps.
        sample_length: Number of tokens per generated sample; must not
            exceed ``hparams.n_ctx``.
        sample_num: Number of samples written each time.
        multi_gpu: When true, the result of ``get_available_gpus()`` is
            passed to ``model.model`` for the training graph.
        save_every: Save a checkpoint every this many steps.
        print_every: Print the loss every this many steps (and at step 1).
        max_checkpoints: ``max_to_keep`` for the checkpoint saver.
        use_memory_saving_gradients: Compute gradients with
            ``memory_saving_gradients.gradients``.
        only_train_transformer_layers: Train only variables whose name
            contains '/h'.
        optimizer: Either 'adam' or 'sgd'.
        overwrite: With ``restore_from == 'latest'``, delete the 'model*'
            and 'events*' files that were present in the run directory at
            start-up and immediately save a new checkpoint.
        val_dataset: Dataset used for validation; when not set, the
            training chunks are reused.
        val_batch_size: Sequences per validation batch.
        val_batch_count: Number of fixed validation batches.
        val_every: Run validation at step 1 and every this many steps;
            0 disables validation.

    Raises:
        FileNotFoundError: If the base model files are missing.
        ValueError: If ``sample_length`` exceeds the model window, the
            optimizer name is unknown, or no checkpoint can be found.
    """

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Only "already exists" is tolerated; permission problems and other
        # OS errors propagate instead of being silently swallowed.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    # Listing is taken before this call writes anything, so `overwrite`
    # below only removes files that were already present.
    files = os.listdir(checkpoint_path)
    for fname in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, fname),
                            os.path.join(checkpoint_path, fname))
        except FileNotFoundError:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = get_available_gpus() if multi_gpu else []

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Next-token loss: position t's logits are scored against token t + 1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        # Uses the same tf.compat.v1 entry points as the training graph;
        # the bare tf.placeholder / tf.summary.scalar spellings are the
        # TF1-only forms and were inconsistent with the rest of the block.
        val_context = tf.compat.v1.placeholder(tf.int32,
                                               [val_batch_size, None])
        # reuse=True so the validation graph is built with the `reuse` flag
        # set rather than as a fresh model.
        val_output = model.model(hparams=hparams, X=val_context, reuse=True)
        val_loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    if only_train_transformer_layers:
        train_vars = [v for v in all_vars if '/h' in v.name]
    else:
        train_vars = all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        # Previously an unknown name left `opt` unbound and surfaced later
        # as a NameError far from the cause.
        raise ValueError(
            "Unknown optimizer %r; expected 'adam' or 'sgd'." % (optimizer,))

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # Kept as exit() for backward compatibility: raises SystemExit.
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=opt,
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(
        var_list=all_vars,
        max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    base_model_path = os.path.join(model_dir, model_name)
    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(base_model_path)
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(base_model_path)
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    if ckpt is None:
        # saver.restore(sess, None) would fail anyway; fail here with a
        # message that says where we looked.
        raise ValueError(
            "No checkpoint found to restore (restore_from=%r, run "
            "directory %r, base model %r)." %
            (restore_from, checkpoint_path, base_model_path))
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            # No separate validation set: validate on the training chunks.
            val_chunks = chunks

    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[val_data_sampler.sample(1024)
                        for _ in range(val_batch_size)]
                       for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # `counter` is the step about to run, so the last finished step is
        # counter - 1; that is the number recorded in the checkpoint.
        maketree(checkpoint_path)
        print(
            'Saving',
            os.path.join(checkpoint_path,
                         'model-{}'.format(counter - 1)))
        saver.save(
            sess,
            os.path.join(checkpoint_path, 'model'),
            global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        # Every sample is seeded with the same single token drawn from the
        # training data.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        # Echo the last sample only; guarded so sample_num <= 0 no longer
        # hits an unbound local.
        if all_text:
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        # Format only the file name so braces in run_name cannot break it.
        sample_file = os.path.join(SAMPLE_DIR, run_name,
                                   'samples-{}'.format(counter))
        with open(sample_file, 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        # Feed the averaged value in place of the val_loss tensor so the
        # summary op records the mean over all validation batches.
        v_summary = sess.run(val_loss_summary,
                             feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print(
            '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_val_loss))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for fname in files:
            if fname.startswith('model') or fname.startswith('events'):
                os.remove(os.path.join(checkpoint_path, fname))
        save()

    # (weighted loss sum, weight sum) of an exponential moving average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    # None/0 mean "no step limit"; previously None crashed on `None > 0`.
    steps = int(steps) if steps else 0

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            # validation code
            if val_every > 0 and (counter % val_every == 0 or counter == 1):
                validation()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(
                        opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Beispiel #17
0
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Compared with a single-pass generator, this variant keeps sampling in
    repeated passes until every sample in the batch is marked finished,
    carrying the tail of the previous output forward as the next context.

    Args:
        sess: TensorFlow session the sampling graphs are run in.
        run_name: Run directory under ``checkpoint_dir`` (used when
            ``model_name`` is not set).
        checkpoint_dir: Root directory of finetuned runs.
        model_name: If set, read the encoder and hparams from
            ``model_dir/model_name`` instead.
        model_dir: Root directory of the base models.
        sample_dir: Unused by this function.
        return_as_list: Return the generated texts instead of printing.
        truncate: String at which a sample is cut; finding it also marks
            the sample as finished.
        destination_path: If set, samples are written to this file.
        sample_delim: Text emitted after each sample; forced to '' when
            ``nsamples == 1``.
        prefix: Optional conditioning text ('' is treated as None).
        seed: Seed for NumPy and TensorFlow random state.
        nsamples: Total number of samples; must be a multiple of
            ``batch_size``.
        batch_size: Samples per session run (None means 1).
        length: Target number of tokens; a falsy value requires
            ``truncate`` to be given.
        temperature: Forwarded to ``sample.sample_sequence``.
        top_k: Forwarded to ``sample.sample_sequence``.
        top_p: Forwarded to ``sample.sample_sequence``.
        include_prefix: With ``truncate`` and ``prefix`` set, False drops
            the prefix from the truncated text.
        split_context: Fraction in (0, 1); ``int(1023 * split_context)`` is
            used as the split point between carried-over context and new
            text.

    Returns:
        The list of generated strings if ``return_as_list`` is true,
        otherwise None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    # NOTE(review): a falsy length passes this check when truncate is set,
    # but min(length, ...) and `length - total_tokens` below still use it;
    # with length=None those cannot be evaluated in Python 3.
    if not length:
        assert truncate is not None, "If generating a non-fixed length \
                sample, must have a truncation term."

    assert 0 < split_context < 1

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    # Created unconditionally, but only wired into the sampling graphs
    # below when a prefix is given (`context if prefix else None`).
    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])

    if prefix:
        prefix_enc = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    # Sampling op for the first pass; the first output column is dropped.
    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(prefix_enc) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    # Sampling op for later passes, with its length capped at
    # 1023 - split_length.
    split_length = int(1023 * split_context)
    split_output_length = min(length, 1023 - split_length)
    split_output = sample.sample_sequence(
        hparams=hparams,
        length=split_output_length,
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    # NOTE(review): the file is only closed by the f.close() at the end, so
    # it stays open if anything in between raises.
    if destination_path:
        f = open(destination_path, 'w')
    generated = 0
    gen_texts = []
    while generated < nsamples:
        # Per-sample accumulated tokens and "finished" flags for this batch.
        gen_text = [np.array([])] * batch_size
        truncated = [False] * batch_size

        if prefix:
            context_tokens = [prefix_enc] * batch_size
        else:
            context_tokens = [[enc.encoder['<|endoftext|>']]] * batch_size

        total_tokens = len(context_tokens[0])
        generated_once = False

        # Keep running passes until every sample in the batch is finished.
        while False in truncated:
            num_tokens = 1023 - (len(context_tokens[0]))
            if generated_once:
                new_split_output_length = min(length - total_tokens,
                                              1023 - split_length)
                # NOTE(review): split_output_length is never reassigned, so
                # this always compares against its initial value and a new
                # sampling op is added to the graph on every pass where the
                # two differ.
                if new_split_output_length != split_output_length:
                    split_output = sample.sample_sequence(
                        hparams=hparams,
                        length=new_split_output_length,
                        start_token=enc.encoder['<|endoftext|>']
                        if not prefix else None,
                        context=context if prefix else None,
                        batch_size=batch_size,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p)[:, 1:]
                out = sess.run(split_output,
                               feed_dict={context: context_tokens})

            else:
                out = sess.run(output, feed_dict={context: context_tokens})

            # Counts the remaining window size, not len(out[i]).
            total_tokens += num_tokens
            for i in range(batch_size):
                text = out[i]
                # Falsy default; replaced by a match object (or None) when
                # the truncation search runs below.
                trunc_text = ""
                if prefix:
                    # Put back the first context token that the [:, 1:]
                    # slice on the sampling op removed.
                    text = np.append(context_tokens[i][:1], text)
                # NOTE(review): when truncate is falsy, all(gen_text) calls
                # bool() on NumPy arrays, which raises ValueError for arrays
                # with more than one element - confirm the intended check
                # (non-empty?).
                if truncate or all(gen_text):
                    # Carry the tail of this output forward as next context.
                    context_tokens[i] = out[i][(1023 - split_length - 1):]
                    if generated_once:
                        # On later passes keep only the part of the output
                        # from index split_length onwards.
                        text = out[i][split_length:]

                    if truncate:
                        to_trunc = enc.decode(text)
                        truncate_esc = re.escape(truncate)
                        if prefix and not include_prefix:
                            prefix_esc = re.escape(prefix)
                            pattern = '(?:{})(.*?)(?:{})'.format(
                                prefix_esc, truncate_esc)
                        else:
                            pattern = '(.*?)(?:{})'.format(truncate_esc)

                        trunc_text = re.search(pattern, to_trunc, re.S)
                        if trunc_text:
                            text = enc.encode(trunc_text.group(1))
                            # better to re-encode here than decode every generation cycle, I think

                if not truncated[i]:
                    # axis=None flattens both inputs before joining them.
                    gen_text[i] = np.concatenate((gen_text[i], text),
                                                 axis=None)
                    # Finished when the truncation string was found or the
                    # token budget is used up; emit the sample right away.
                    if trunc_text or (length is not None
                                      and total_tokens >= length - 1):
                        truncated[i] = True
                        gen = enc.decode(gen_text[i]).lstrip('\n')
                        if destination_path:
                            f.write("{}\n{}".format(gen, sample_delim))
                        if not return_as_list and not destination_path:
                            print("{}\n{}".format(gen, sample_delim), end='')
                        gen_texts.append(gen)
            generated_once = True

        generated += batch_size

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
Beispiel #18
0
def generate(sess,
             run_name='run1',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5,
             batch_prefix=None):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    This variant samples in repeated passes until every sample in the batch
    is marked finished, and accepts either one shared ``prefix`` or one
    prefix per batch entry via ``batch_prefix``.

    Args:
        sess: TensorFlow session the sampling graph is run in.
        run_name: Run directory under 'checkpoint' to read the encoder and
            hparams from.
        return_as_list: Return the generated texts instead of printing.
        truncate: String at which a sample is cut; finding it also marks
            the sample as finished.
        destination_path: If set, samples are written to this file.
        sample_delim: Text emitted after each sample; forced to '' when
            ``nsamples == 1``.
        prefix: Optional conditioning text shared by the whole batch
            ('' is treated as None). Mutually exclusive with
            ``batch_prefix``.
        seed: Seed for NumPy and TensorFlow random state.
        nsamples: Total number of samples; must be a multiple of
            ``batch_size``.
        batch_size: Samples per session run (None means 1).
        length: Target number of tokens; a falsy value requires
            ``truncate`` to be given.
        temperature: Forwarded to ``sample.sample_sequence``.
        top_k: Forwarded to ``sample.sample_sequence``.
        top_p: Forwarded to ``sample.sample_sequence``.
        include_prefix: With ``truncate`` and ``prefix`` set, False drops
            the prefix from the truncated text.
        split_context: Fraction of each output used as the split point
            between carried-over context and new text.
        batch_prefix: Optional list of ``batch_size`` prefixes, one per
            batch entry ([] is treated as None).

    Returns:
        The list of generated strings if ``return_as_list`` is true,
        otherwise None.
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    assert not (prefix and batch_prefix)
    assert not batch_prefix or len(batch_prefix) == batch_size
    if prefix == '':
        prefix = None
    if batch_prefix == []:
        batch_prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
        sample, must have a truncation term."

    CHECKPOINT_DIR = 'checkpoint'
    # SAMPLE_DIR is not referenced anywhere else in this function.
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [batch_size, None])
    if prefix:
        context_tokens = [enc.encode(prefix)] * batch_size
    elif batch_prefix:
        # context_tokens = [enc.encode(batch_prefix[0])] * batch_size
        context_tokens = [enc.encode(pre) for pre in batch_prefix]
        # print([enc.decode(c) for c in context_tokens])
        # assert all(len(context_tokens[0]) == len(c) for c in context_tokens)
        # Left-pad shorter prefixes with token id 0 so every row has the
        # same length as the longest one.
        ml = max([len(p) for p in context_tokens])
        for p in context_tokens:
           while len(p) < ml:
               p.insert(0, 0)
        # padding front :)
    else:
        context_tokens = [[enc.encoder['<|endoftext|>']] for _ in range(batch_size)]
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # NOTE(review): the file is only closed by the f.close() at the end, so
    # it stays open if anything in between raises.
    if destination_path:
        f = open(destination_path, 'w')
    generated = 0
    gen_texts = []
    # NOTE(review): context_tokens is built once above and overwritten
    # inside the loop, so a second batch starts from the carried-over
    # context of the previous batch rather than from the prefix - confirm
    # this is intended.
    while generated < nsamples:
        # gen_text = [''] * batch_size
        # NOTE(review): [[]] * batch_size holds batch_size references to one
        # list, and `gen_text[i] += [text]` below extends it in place, so
        # for batch_size > 1 all entries share the same contents.
        gen_text = [[]] * batch_size
        truncated = [False] * batch_size
        total_tokens = 0

        # Keep running passes until every sample in the batch is finished.
        while False in truncated:
            num_tokens = 1023 - (len(context_tokens[0]))
            # NOTE(review): a new sampling op is added to the graph on every
            # pass of this loop.
            output = sample.sample_sequence(
                hparams=hparams,
                length=min(length if length else 1023, num_tokens),
                context=context,
                batch_size=batch_size,
                temperature=temperature, top_k=top_k, top_p=top_p
            )[:, 1:]

            out = sess.run(output, feed_dict={
                    context: context_tokens
                })

            # Counts the remaining window size, not len(out[i]).
            total_tokens += num_tokens

            for i in range(batch_size):
                if truncated[i]:
                    continue
                text = out[i]
                trunc_text = "" #added to patch to fix unassigned variable
                if prefix or batch_prefix:
                    # Put back the first context token that the [:, 1:]
                    # slice on the sampling op removed.
                    text = np.append(context_tokens[i][:1], text)
                if truncate or all(gen_text):
                    # Carry the tail of this output forward as next context.
                    context_tokens[i] = out[i][int(len(out[i])*(1-split_context)):]
                    if gen_text[i]:
                        # Text was already collected: keep only the part of
                        # this output after the split point.
                        text = out[i][int(len(out[i])*(split_context)):]

                        # OLD, leaving for now for xref
                        # split = re.split('[.!?]', gen_text[i])
                        # text = text.partition(list(filter(None, split))[-1])[-1]
                        # ok so the idea is, split up gen_text, which is cumulative
                        # then you get the latest in gen_text, and find where it is in the new text
                        # then you split the new text on that, and get everything after it.
                        # but it seems like it sometimes chooses an empty string...
                        # yo just leave it as tokens until the end tho and use count

                    if truncate:
                        to_trunc = enc.decode(text)
                        truncate_esc = re.escape(truncate)
                        if prefix and not include_prefix:
                            prefix_esc = re.escape(prefix)
                            pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                                 truncate_esc)
                        else:
                            pattern = '(.*?)(?:{})'.format(truncate_esc)

                        trunc_text = re.search(pattern, to_trunc, re.S)
                        if trunc_text:
                            text = enc.encode(trunc_text.group(1)) #inefficient, but let's just get this working for now

                if not truncated[i]:
                    gen_text[i] += [text] #.lstrip('\n')
                # Finished when the truncation string was found or the token
                # budget is used up.
                if trunc_text or (length is not None and total_tokens >= length-1):
                    truncated[i] = True

        # Decode each collected token chunk, strip leading newlines per
        # chunk, and join the chunks into the final sample text.
        for gen in gen_text:
            gen = [enc.decode(g).lstrip('\n') for g in gen]
            gen = ''.join(gen)
            if destination_path:
                f.write("{}\n{}".format(gen, sample_delim))
            if not return_as_list and not destination_path:
                print("{}\n{}".format(gen, sample_delim), end='')
            gen_texts.append(gen)

        generated += batch_size

    if destination_path:
        f.close()

    if return_as_list:
        return gen_texts
Beispiel #19
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='117M',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Behavior specific to this implementation:

    * Checkpoints (and the `counter` step file) live in
      `checkpoint/<run_name>`; generated samples are written to
      `samples/<run_name>/samples-<step>`.
    * Unless `model_load` is true, `hparams.json`, `encoder.json` and
      `vocab.bpe` are copied from `models/<model_name>` into the checkpoint
      directory first.
    * For any `model_name` other than '117M', memory saving gradients and
      transformer-layer-only training are forced on and gradient
      accumulation is disabled.
    * `restore_from` may be 'latest' (resume the run, falling back to the
      weights in `models/<model_name>`), 'fresh' (always use
      `models/<model_name>`), or a directory containing a checkpoint.
    * With `model_load=True` the function returns right after restoring the
      weights, without loading the dataset or training.
    * With `steps > 0` training stops after that many steps, saving first;
      otherwise it runs until interrupted with Ctrl-C, which also saves.

    Raises:
        ValueError: if `sample_length` exceeds `hparams.n_ctx`, or if no
            checkpoint can be found to restore.
        SystemExit: if gradient accumulation is combined with memory saving
            gradients.
    """

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        # exist_ok tolerates an already-existing directory without hiding
        # unrelated failures (e.g. permissions) the way a bare except would.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    # Next-token loss: position i of the logits is scored against token i+1.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars
    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # Same effect as the site-provided exit(), but does not depend on
            # the `site` module having installed that helper.
            raise SystemExit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    if ckpt is None:
        # Fail with an actionable message instead of handing None to the
        # saver.
        raise ValueError(
            "No checkpoint found to restore (restore_from=%r, model_name=%r)."
            % (restore_from, model_name))
    saver.restore(sess, ckpt)

    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # `counter` is the step about to run, so the last completed step is
        # counter - 1; that is the number recorded in the checkpoint name
        # and in the counter file.
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        # Every sample in a batch is seeded with the same single-token
        # context drawn from the dataset.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        # Only the last sample is echoed to stdout; guard so that
        # sample_num == 0 does not reference an unassigned variable.
        if all_text:
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    # (weighted loss sum, weight sum) for an exponentially decayed average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Beispiel #20
0
# Command-line flags selecting the hardware target and which phases to run.
flags.DEFINE_integer(
    'num_tpu_cores',
    8,
    'Only used if `use_tpu` is True. Total number of TPU cores to use.',
)

# NOTE(review): this flag is registered through `tf.flags` while the others
# use `flags` -- presumably the same module under two names; confirm.
tf.flags.DEFINE_string('master', None, '[Optional] TensorFlow master URL.')

flags.DEFINE_bool('do_train', False, 'Whether to run training.')

flags.DEFINE_bool('do_eval', False, 'Whether to run eval on the dev set.')

flags.DEFINE_bool('use_tpu', True, 'Whether to use TPU or GPU/CPU.')

# Module-level hyperparameters: start from the model defaults, then apply the
# overrides stored in `base-hparams.json`. The relative path is opened as soon
# as these statements execute, so the file must be present in the working
# directory at that point.
# https://storage.googleapis.com/gpt-2/models/117M/hparams.json
hparams = model.default_hparams()
with open('base-hparams.json') as f:
    hparams.override_from_dict(json.load(f))


def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Work out which of `tvars` can be initialized from `init_checkpoint`.

    Args:
        tvars: iterable of variables; each must expose a `.name` such as
            "scope/kernel:0". A trailing ":<digits>" output suffix is
            stripped before matching against checkpoint names.
        init_checkpoint: checkpoint path handed to
            `tf.train.list_variables`.

    Returns:
        A tuple `(assignment_map, initialized_variable_names)`:

        * `assignment_map` maps each checkpoint variable name that also
          exists among `tvars` to itself, in checkpoint listing order.
        * `initialized_variable_names` contains both the bare name and the
          name with a ":0" suffix for every matched variable, so callers can
          test membership with `var.name`.
    """
    initialized_variable_names = {}

    # Index the graph variables by their name without the ":N" suffix.
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match('^(.*):\\d+$', name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    # Keep only checkpoint entries that have a same-named graph variable.
    assignment_map = collections.OrderedDict()
    for ckpt_name, _shape in tf.train.list_variables(init_checkpoint):
        if ckpt_name not in name_to_variable:
            continue
        assignment_map[ckpt_name] = ckpt_name
        initialized_variable_names[ckpt_name] = 1
        initialized_variable_names[ckpt_name + ':0'] = 1

    # Without this return the function yields None and callers cannot unpack
    # the two maps.
    return (assignment_map, initialized_variable_names)
Beispiel #21
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Behavior specific to this implementation:

    * Checkpoints, TensorBoard summaries and the `counter` step file are
      written to `<checkpoint_dir>/<run_name>`; samples go to
      `samples/<run_name>/samples-<step>`.
    * `hparams.json`, `encoder.json` and `vocab.bpe` are always copied from
      `<model_dir>/<model_name>` into the checkpoint directory.
    * For any `model_name` other than '117M'/'124M', memory saving gradients
      and transformer-layer-only training are forced on and gradient
      accumulation is disabled.
    * `optimizer` must be 'adam' or 'sgd'.
    * `steps` > 0 trains for exactly that many steps, then saves and
      returns; `None`, 0 or a negative value trains until Ctrl-C (which also
      saves).
    * Every `save_every` steps a checkpoint is saved and a sample is
      generated; a sample is also generated whenever the loss reaches a new
      minimum above 1.0. Losses and samples are forwarded to `monitoring`
      and to the TensorBoard summary writer.
    * With `overwrite=True` and `restore_from='latest'`, previously existing
      `model*`/`events*` files in the checkpoint directory are removed
      (after the weights have been restored) and a fresh checkpoint is
      saved.
    * `sample_every` is accepted but not used by this implementation.

    Raises:
        AssertionError: for '774M'/'1558M' without `multi_gpu`.
        FileNotFoundError: if the pretrained model files are missing.
        ValueError: if `sample_length` exceeds `hparams.n_ctx`, if
            `optimizer` is unknown, or if no checkpoint can be found.
        SystemExit: if gradient accumulation is combined with memory saving
            gradients.
    """

    assert model_name not in [
        '774M', '1558M'
    ] or multi_gpu, "[ERROR] Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # exist_ok tolerates an already-existing directory without hiding
        # unrelated failures (e.g. permissions) the way a bare except would.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    # Snapshot of what was in the run directory before this call touched it;
    # used by the `overwrite` cleanup further down.
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError:
            print(
                "[ERROR] You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "[ERROR] Can't get samples longer than window size: %s" %
            hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Next-token loss: position i of the logits is scored against token i+1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        # Previously an unknown name left `opt` undefined and failed later
        # with an unrelated-looking error.
        raise ValueError(
            "[ERROR] Unknown optimizer %r; expected 'adam' or 'sgd'." %
            (optimizer, ))

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # Same effect as the site-provided exit(), but does not depend on
            # the `site` module having installed that helper.
            raise SystemExit(
                "[INFO] Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summaries = [summary_loss]
    # Use the compat.v1 endpoint like every other summary call here.
    summary_op = tf.compat.v1.summary.merge(summaries)
    summary_writer = tf.compat.v1.summary.FileWriter(checkpoint_path,
                                                     sess.graph)

    # Text summaries are built once and fed through a placeholder. Creating
    # a new summary op inside the training loop would add nodes to the graph
    # on every checkpoint / best-loss event.
    summary_text_input = tf.compat.v1.placeholder(tf.string, [])
    summary_sample = tf.compat.v1.summary.text('generated_sample',
                                               summary_text_input)
    summary_loss_sample = tf.compat.v1.summary.text('best_loss_sample',
                                                    summary_text_input)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('[INFO] Loading checkpoint', ckpt)
    if ckpt is None:
        # Fail with an actionable message instead of handing None to the
        # saver.
        raise ValueError(
            "[ERROR] No checkpoint found to restore (restore_from=%r, model_name=%r)."
            % (restore_from, model_name))
    saver.restore(sess, ckpt)

    print('[INFO] Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('[INFO] dataset has', data_sampler.total_size, 'tokens')
    print('[INFO] Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # The checkpoint name and the counter file both record counter - 1.
        maketree(checkpoint_path)
        print('[INFO] Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples(header=True):
        """Generate `sample_num` samples, write them to disk, return the last.

        Returns an empty string when `sample_num` is 0.
        """
        # Every sample in a batch is seeded with the same single-token
        # context drawn from the dataset.
        context_tokens = data_sampler.sample(1)
        all_text = []
        text = ''
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                if header:
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                all_text.append(text)
                index += 1
        # Only the last sample is echoed to stdout.
        if all_text:
            print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

        return text

    def sample_batch(sample_size=1024):
        return [data_sampler.sample(sample_size) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith(('model', 'events')):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    # (weighted loss sum, weight sum) for an exponentially decayed average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    best_loss = 1000.00

    # None/0 mean "no step limit", same as a negative value; normalizing here
    # keeps the `steps > 0` comparison below valid for None.
    steps = int(steps) if steps else -1

    try:
        while True:
            # Check the limit before running a step so exactly `steps`
            # updates are applied and the saved step number (counter - 1)
            # is the last update that actually ran.
            if steps > 0 and counter == (counter_base + steps):
                save()
                return

            t = time.time() - start_time

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_op))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_op),
                    feed_dict={context: sample_batch()})

            monitoring.push_metric('loss', v_loss, counter, t)
            summary_writer.add_summary(v_summary, counter)

            if (counter - 1) % save_every == 0 and counter > 1:
                save()
                sample_text = generate_samples(False)
                sample_summary = sess.run(
                    summary_sample,
                    feed_dict={summary_text_input: sample_text})
                summary_writer.add_summary(sample_summary, counter)

                monitoring.push_text('generated_text', sample_text, counter,
                                     v_loss, t)

            if v_loss < best_loss and v_loss > 1.00:
                best_loss = v_loss
                loss_text = generate_samples(False)
                loss_sample_text = "===== LOSS : " + str(
                    best_loss) + "=====" + loss_text
                best_summary = sess.run(
                    summary_loss_sample,
                    feed_dict={summary_text_input: loss_sample_text})
                summary_writer.add_summary(best_summary, counter)
                monitoring.push_text('best_loss_text', loss_text, counter,
                                     best_loss, t)
                monitoring.push_metric('best_loss', best_loss, counter, t)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                monitoring.push_metric('avg_loss', avg_loss[0] / avg_loss[1],
                                       counter, t)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=t,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('[STOP] interrupted')
        save()
Beispiel #22
0
def one_lr_cycle(sess,
                 dataset,
                 steps=10000,
                 model_name='117M',
                 combine=50000,
                 batch_size=1,
                 intial_lr=1e-10,
                 final_lr=1,
                 accumulate_gradients=5,
                 restore_from='fresh',
                 run_name='run1',
                 max_checkpoints=1,
                 use_memory_saving_gradients=False,
                 only_train_transformer_layers=False,
                 overwrite=False):
    """Does one LR half-cycle from initial to final over steps iterations using CLR algorithm
    https://github.com/bckenstler/CLR

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    The learning rate is fed per step and rises linearly from `intial_lr`
    (iteration 0) towards `final_lr` (iteration `steps`). Each step prints
    the loss and the learning rate used, and writes the loss summary to
    `checkpoint/<run_name>`. This function restores weights but never saves
    a checkpoint itself.

    For any `model_name` other than '117M', memory saving gradients and
    transformer-layer-only training are forced on and gradient accumulation
    is disabled. With gradient accumulation the optimizer is Adam; without
    it, plain gradient descent.

    Raises:
        ValueError: if `steps` is not positive, or if no checkpoint can be
            found to restore.
        FileNotFoundError: if the pretrained model files are missing.
        SystemExit: if gradient accumulation is combined with memory saving
            gradients.
    """

    # The schedule divides by `steps`, so a non-positive value is either a
    # division by zero or a meaningless learning rate.
    if steps <= 0:
        raise ValueError("steps must be a positive integer, got %r" %
                         (steps, ))

    CHECKPOINT_DIR = 'checkpoint'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        # exist_ok tolerates an already-existing directory without hiding
        # unrelated failures (e.g. permissions) the way a bare except would.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    # Snapshot of what was in the run directory before this call touched it;
    # used to skip already-present model files and by the `overwrite` cleanup.
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        if file not in files:
            try:
                shutil.copyfile(os.path.join('models', model_name, file),
                                os.path.join(checkpoint_path, file))
            except FileNotFoundError:
                print(
                    "You need to download the GPT-2 model first via download_gpt2()"
                )
                raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    # Next-token loss: position i of the logits is scored against token i+1.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    current_iter = 0
    # Scalar placeholder so a different learning rate can be fed every step.
    learning_rate = tf.placeholder(tf.float32, shape=[])

    def get_lr():
        # Reads `current_iter` from the enclosing scope, so the value tracks
        # the training loop below.
        cycle = np.floor(1 + current_iter / (2 * steps))
        x = np.abs(current_iter / steps - 2 * cycle + 1)
        lr = intial_lr + (final_lr - intial_lr) * np.maximum(
            0, (1 - x))  # * scale_fn(x)
        return lr

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars
    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # Same effect as the site-provided exit(), but does not depend on
            # the `site` module having installed that helper.
            raise SystemExit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    if ckpt is None:
        # Fail with an actionable message instead of handing None to the
        # saver.
        raise ValueError(
            "No checkpoint found to restore (restore_from=%r, model_name=%r)."
            % (restore_from, model_name))
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith(('model', 'events')):
                os.remove(os.path.join(checkpoint_path, file))

    start_time = time.time()

    try:
        while True:
            if counter == (counter_base + steps):
                return

            # One learning rate per step, shared by the update and the log
            # line.
            lr = get_lr()
            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={
                                 context: sample_batch(),
                                 learning_rate: lr
                             })
                # The Adam optimizer wrapped by AccumulatingOptimizer was
                # built on the `learning_rate` placeholder, so the apply
                # step needs the value fed as well.
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss),
                                               feed_dict={learning_rate: lr})
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={
                        context: sample_batch(),
                        learning_rate: lr
                    })

            summary_log.add_summary(v_summary, counter)

            print('[{counter} | {time:2.2f}] loss={loss:3.14f} lr={lr:2.14f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_loss,
                         lr=lr))

            counter += 1
            current_iter += 1
    except KeyboardInterrupt:
        print('interrupted')
Beispiel #23
0
def main(_):
    """Build a TPUEstimator from the command-line flags and run it.

    Hyperparameters come from the model defaults overridden by the JSON in
    `FLAGS.config_file`. Input files are every match of the comma-separated
    glob patterns in `FLAGS.input_file`. Training runs when `do_train` is
    set and evaluation when `do_eval` is set; evaluation metrics are logged
    and written to `<output_dir>/eval_results.txt`.

    Raises:
        ValueError: if neither `do_train` nor `do_eval` is set.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf_logger = tf.get_logger()
    tf_logger.propagate = False

    if not (FLAGS.do_train or FLAGS.do_eval):
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    hparams = model.default_hparams()
    with tf.gfile.GFile(FLAGS.config_file) as config_fp:
        overrides = json.load(config_fp)
    hparams.override_from_dict(overrides)

    tf.gfile.MakeDirs(FLAGS.output_dir)
    input_files = [
        path
        for pattern in FLAGS.input_file.split(",")
        for path in tf.gfile.Glob(pattern)
    ]

    # A cluster resolver is only needed when a named TPU is requested.
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    tpu_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop,
        num_shards=FLAGS.num_tpu_cores,
        per_host_input_for_training=(
            tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2),
    )
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        tpu_config=tpu_config,
    )

    model_fn = model_fn_builder(
        hparams=hparams,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
        use_memory_saving_gradients=FLAGS.use_memory_saving_gradients)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
    )

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.batch_size)
        estimator.train(
            input_fn=input_fn_builder(
                input_files=input_files,
                max_seq_length=FLAGS.max_seq_length,
                is_training=True,
            ),
            max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_metrics = estimator.evaluate(
            input_fn=input_fn_builder(
                input_files=input_files,
                max_seq_length=FLAGS.max_seq_length,
                is_training=False,
            ),
            steps=FLAGS.max_eval_steps)

        results_path = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(results_path, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(eval_metrics.keys()):
                value_text = str(eval_metrics[key])
                tf.logging.info("  %s = %s", key, value_text)
                writer.write("%s = %s\n" % (key, value_text))
Beispiel #24
0
def model_fn_builder(
    init_checkpoint,
    learning_rate,
    num_train_steps,
    num_warmup_steps,
    config,
):
    """Build a `model_fn` closure suitable for `tf.contrib.tpu.TPUEstimator`.

    Hyperparameters start from `model.default_hparams()` and are overridden
    with the entries of the JSON file found at `config`.

    Args:
        init_checkpoint: Checkpoint to initialise trainable variables from.
            When falsy, no checkpoint initialisation is set up.
        learning_rate: Forwarded to `optimization.create_optimizer`.
        num_train_steps: Forwarded to `optimization.create_optimizer`.
        num_warmup_steps: Forwarded to `optimization.create_optimizer`.
        config: Path to a JSON file (opened through `tf.gfile.GFile`) whose
            key/value pairs override the default hyperparameters.

    Returns:
        A `model_fn(features, labels, mode, params)` callable that returns a
        `TPUEstimatorSpec` for TRAIN and EVAL modes.
    """
    hparams = model.default_hparams()
    with tf.gfile.GFile(config, "r") as reader:
        # Keep the parsed dict in its own name instead of rebinding the
        # `config` path argument.
        hparams_overrides = json.loads(reader.read())
    hparams.override_from_dict(hparams_overrides)

    def model_fn(features, labels, mode, params):
        """Build the language-model graph for one Estimator mode.

        Only `features['input_ids']` is read; `labels` and `params` are
        accepted because the Estimator calls `model_fn` with them.

        Raises:
            ValueError: If `mode` is neither TRAIN nor EVAL.
        """
        tf.logging.info('*** Features ***')
        for name in sorted(features.keys()):
            tf.logging.info('  name = %s, shape = %s', name,
                            features[name].shape)

        input_ids = features['input_ids']

        output = model.model(hparams=hparams, X=input_ids)

        # Next-token objective: the logits at position t are scored against
        # the token at position t + 1, so drop the last logit and the first
        # token to align them.
        shifted_logits = output['logits'][:, :-1]
        targets = input_ids[:, 1:]
        per_token_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=targets, logits=shifted_logits)
        loss = tf.reduce_mean(input_tensor=per_token_loss)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None

        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)

            def tpu_scaffold():
                # Checkpoint initialisation is deferred into the scaffold
                # function that is handed to TPUEstimatorSpec.
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold

        tf.logging.info('**** Trainable Variables ****')
        for var in tvars:
            init_string = ''
            if var.name in initialized_variable_names:
                init_string = ', *INIT_FROM_CKPT*'
            tf.logging.info('  name = %s, shape = %s%s', var.name, var.shape,
                            init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # NOTE(review): the trailing positional `True` is presumably the
            # optimizer's `use_tpu` flag -- confirm against
            # `optimization.create_optimizer`.
            train_op = optimization.create_optimizer(loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, True)

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
            )

        if mode == tf.estimator.ModeKeys.EVAL:
            # Reduce the logits to token ids here so that only small
            # batch-major tensors are handed to `metric_fn`, rather than the
            # full logits tensor.
            predictions = tf.argmax(shifted_logits,
                                    axis=-1,
                                    output_type=tf.int32)

            def metric_fn(per_token_loss, targets, predictions):
                """Return streaming next-token accuracy and mean loss."""
                accuracy = tf.metrics.accuracy(
                    labels=targets,
                    predictions=predictions,
                )
                mean_loss = tf.metrics.mean(values=per_token_loss)

                # The metric keys are kept as-is so the names written to the
                # eval results stay stable, even though the values are
                # next-token (not next-sentence) statistics.
                return {
                    'next_sentence_accuracy': accuracy,
                    'next_sentence_loss': mean_loss,
                }

            # Every tensor passed here keeps the batch dimension; the scalar
            # `loss` and the `output` dict are deliberately not included.
            eval_metrics = (metric_fn, [per_token_loss, targets, predictions])
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn,
            )

        raise ValueError('Only TRAIN and EVAL modes are supported: %s' %
                         (mode))

    return model_fn