Exemplo n.º 1
0
def predict(sess,
            text,
            run_name='run1',
            checkpoint_dir='checkpoint',
            model_name=None,
            model_dir='models',
            sample_dir='samples',
            seed=None,
            temperature=1,
            top_k=2):
    """Returns the `top_k` highest-scoring next-token candidates for `text`.

    Args:
        sess: TensorFlow session used to evaluate the sampling graph.
        text: Input string; it is encoded and fed as a batch of one.
        run_name: Sub-directory of `checkpoint_dir` holding the run; used
            only when `model_name` is falsy.
        checkpoint_dir: Root directory of finetuned checkpoints.
        model_name: If given, files are read from `model_dir/model_name`
            instead of `checkpoint_dir/run_name`.
        model_dir: Root directory of downloaded base models.
        sample_dir: Unused by this function.
        seed: Seed applied to both NumPy and TensorFlow RNGs.
        temperature: Forwarded to `sample.sample_sequence`.
        top_k: Number of candidates to return.

    Returns:
        A `(tokens, scores)` tuple: the decoded candidates (split on
        whitespace) and the matching entries of the score array.
    """
    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    # Batch of exactly one sequence of arbitrary length.
    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    context_tokens = enc.encode(text)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    # NOTE(review): this relies on a `sample_sequence` variant that returns
    # a score vector rather than sampled tokens; presumably 1-D over the
    # vocabulary -- confirm against the `sample` module.
    proba_t = sample.sample_sequence(hparams=hparams,
                                     context=context,
                                     temperature=temperature)
    proba = sess.run(proba_t, feed_dict={context: [context_tokens]})

    # argsort is ascending, so flip it to get the best candidates first.
    top_k_idxs = np.flip(np.argsort(proba))[:top_k]

    # NOTE(review): the ids are decoded as one string and then split on
    # whitespace, so the number of returned strings can differ from `top_k`
    # when a token carries no whitespace boundary.
    top_k_tokens = enc.decode(top_k_idxs).split()
    top_k_proba = proba[top_k_idxs]

    return top_k_tokens, top_k_proba
Exemplo n.º 2
0
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Args:
        sess: TensorFlow session holding the loaded model.
        run_name: Sub-directory of `checkpoint_dir`; used when `model_name`
            is falsy.
        checkpoint_dir: Root directory of finetuned checkpoints.
        model_name: If given, read from `model_dir/model_name` instead.
        model_dir: Root directory of downloaded base models.
        sample_dir: Unused by this function.
        return_as_list: If true, return the generated texts as a list.
        truncate: Literal string; each sample is cut just before its first
            occurrence.
        destination_path: If given, samples are written to this file
            instead of being printed.
        sample_delim: Separator emitted after each sample; forced to ''
            when `nsamples == 1`.
        prefix: Optional conditioning text; '' is treated as None.
        seed: Seed applied to both NumPy and TensorFlow RNGs.
        nsamples: Total number of samples; must be a multiple of
            `batch_size`.
        batch_size: Samples produced per session run (None means 1).
        length: Requested number of tokens, capped so that prefix plus
            output stays within 1023 tokens.
        temperature, top_k, top_p: Forwarded to `sample.sample_sequence`.
        include_prefix: When false (and `truncate` is set), the prefix is
            stripped from the returned text.

    Returns:
        The list of generated texts if `return_as_list` is true, otherwise
        None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if prefix:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        context_tokens = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    # The first returned token is sliced off; when a prefix is used it is
    # restored below by decoding the first prefix token.
    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p)[:, 1:]

    # The truncation pattern does not depend on the sample, so build it once.
    pattern = None
    if truncate:
        truncate_esc = re.escape(truncate)
        if prefix and not include_prefix:
            pattern = '(?:{})(.*?)(?:{})'.format(re.escape(prefix),
                                                 truncate_esc)
        else:
            pattern = '(.*?)(?:{})'.format(truncate_esc)

    f = open(destination_path, 'w') if destination_path else None
    gen_texts = []
    # try/finally guarantees the destination file is closed even if the
    # session run or decoding raises midway.
    try:
        generated = 0
        while generated < nsamples:
            if not prefix:
                out = sess.run(output)
            else:
                out = sess.run(
                    output,
                    feed_dict={context: batch_size * [context_tokens]})
            for i in range(batch_size):
                generated += 1
                gen_text = enc.decode(out[i])
                if prefix:
                    gen_text = enc.decode(context_tokens[:1]) + gen_text
                if pattern is not None:
                    trunc_text = re.search(pattern, gen_text, re.S)
                    if trunc_text:
                        gen_text = trunc_text.group(1)
                gen_text = gen_text.lstrip('\n')
                if f is not None:
                    f.write("{}\n{}".format(gen_text, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen_text, sample_delim), end='')
                gen_texts.append(gen_text)
    finally:
        if f is not None:
            f.close()

    if return_as_list:
        return gen_texts
Exemplo n.º 3
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Notes on behaviour specific to this implementation:
        * `steps` <= 0 (or None) trains until interrupted with Ctrl-C.
        * For any `model_name` other than '117M'/'124M', memory-saving
          gradients and transformer-only training are forced on and
          gradient accumulation is disabled.
        * `optimizer` must be 'adam' or 'sgd'.
        * `sample_every` is accepted but not used here; samples are
          generated on every save and whenever a new best loss is seen.
        * With `overwrite=True` and `restore_from='latest'`, existing
          'model*' and 'events*' files in the run directory are removed
          after the weights have been restored.

    Raises:
        FileNotFoundError: If the base model files are missing.
        ValueError: If `sample_length` exceeds the model window or
            `optimizer` is not recognised.
    """

    assert model_name not in [
        '774M', '1558M'
    ] or multi_gpu, "[ERROR] Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Tolerate only "already exists"; permission and other errors
        # surface here instead of failing confusingly later.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    # Snapshot of the run directory before training, used by `overwrite`.
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError:
            print(
                "[ERROR] You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "[ERROR] Can't get samples longer than window size: %s" %
            hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Next-token loss: logits at position i are scored against token i + 1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        # Fail early with a clear message instead of a NameError on `opt`.
        raise ValueError(
            "[ERROR] Unknown optimizer {!r}; expected 'adam' or 'sgd'.".format(
                optimizer))

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "[INFO] Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # In this mode the value of `opt_apply` is what gets logged as loss.
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summaries = [summary_loss]
    # compat.v1 namespace, consistent with the other summary calls here.
    summary_op = tf.compat.v1.summary.merge(summaries)
    summary_writer = tf.compat.v1.summary.FileWriter(checkpoint_path,
                                                     sess.graph)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('[INFO] Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('[INFO] Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('[INFO] dataset has', data_sampler.total_size, 'tokens')
    print('[INFO] Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # Writes a checkpoint for step `counter - 1` and records that step.
        maketree(checkpoint_path)
        print('[INFO] Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples(header=True):
        # Samples `sample_num` texts from a one-token context, writes them
        # all to SAMPLE_DIR/run_name/samples-<counter>, and returns the
        # last text generated.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                if header:
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

        return text

    def sample_batch(sample_size=1024):
        return [data_sampler.sample(sample_size) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    # (weighted loss sum, weight sum) for an exponential moving average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    best_loss = 1000.00

    # None/0 mean "no step limit", the same as the default of -1; this also
    # keeps the `steps > 0` comparison below from raising on None.
    steps = int(steps) if steps else -1

    try:
        while True:
            t = time.time() - start_time

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_op))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_op),
                    feed_dict={context: sample_batch()})

            monitoring.push_metric('loss', v_loss, counter, t)
            summary_writer.add_summary(v_summary, counter)

            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
                sample_text = generate_samples(False)
                # NOTE(review): a new summary op is added to the graph on
                # every save, so the graph grows over a long run.
                summary_sample = tf.compat.v1.summary.text(
                    'generated_sample', tf.convert_to_tensor(sample_text))
                text = sess.run(summary_sample)
                summary_writer.add_summary(text, counter)

                monitoring.push_text('generated_text', sample_text, counter,
                                     v_loss, t)

            # Record a sample whenever the loss improves; losses at or
            # below 1.00 are ignored by this check.
            if v_loss < best_loss and v_loss > 1.00:
                best_loss = v_loss
                loss_text = generate_samples(False)
                loss_sample_text = "===== LOSS : " + str(
                    best_loss) + "=====" + loss_text
                summary_loss_sample = tf.compat.v1.summary.text(
                    'best_loss_sample', tf.convert_to_tensor(loss_sample_text))
                loss_text_ = sess.run(summary_loss_sample)
                summary_writer.add_summary(loss_text_, counter)
                monitoring.push_text('best_loss_text', loss_text, counter,
                                     best_loss, t)
                monitoring.push_metric('best_loss', best_loss, counter, t)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                monitoring.push_metric('avg_loss', avg_loss[0] / avg_loss[1],
                                       counter, t)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=t,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('[STOP] interrupted')
        save()
Exemplo n.º 4
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='117M',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Notes on behaviour specific to this implementation:
        * `steps` <= 0 trains until interrupted with Ctrl-C.
        * For any `model_name` other than '117M', memory-saving gradients
          and transformer-only training are forced on and gradient
          accumulation is disabled.
        * With `model_load=True` the base model files are not copied, the
          graph is built and the checkpoint restored into `sess`, and the
          function returns without training.

    Raises:
        ValueError: If `sample_length` exceeds the model window.
    """

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        # Tolerate only "already exists"; a bare `except` would also hide
        # permission errors and KeyboardInterrupt.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    # Next-token loss: logits at position i are scored against token i + 1.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars
    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # In this mode the value of `opt_apply` is what gets logged as loss.
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    # Load-only mode: the weights are now in `sess`; skip training.
    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # Writes a checkpoint for step `counter - 1` and records that step.
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        # Samples `sample_num` texts from a one-token context, prints the
        # last one and writes all of them to
        # SAMPLE_DIR/run_name/samples-<counter>.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    # (weighted loss sum, weight sum) for an exponential moving average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Exemplo n.º 5
0
def generate(
    sess,
    run_name="run1",
    checkpoint_dir="checkpoint",
    model_name=None,
    model_dir="models",
    sample_dir="samples",
    return_as_list=False,
    truncate=None,
    destination_path=None,
    sample_delim="=" * 20 + "\n",
    prefix=None,
    seed=None,
    nsamples=1,
    batch_size=1,
    length=1023,
    temperature=0.7,
    top_k=0,
    top_p=0.0,
    include_prefix=True,
    enc=None,
    hparams=None,
):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Args:
        sess: TensorFlow session holding the loaded model.
        run_name, checkpoint_dir, model_name, model_dir: Forwarded to
            `get_checkpoint_path` to locate the model files.
        sample_dir: Unused by this function.
        return_as_list: If true, return the generated texts as a list.
        truncate: Literal string; each sample is cut just before its first
            occurrence.
        destination_path: If given, samples are written to this file
            instead of being printed.
        sample_delim: Separator emitted after each sample; forced to ""
            when `nsamples == 1`.
        prefix: Optional conditioning text; "" is treated as None.
        seed: Seed applied to both NumPy and TensorFlow RNGs.
        nsamples: Total number of samples; must be a multiple of
            `batch_size`.
        batch_size: Samples produced per session run (None means 1).
        length: Requested number of tokens, capped so that prefix plus
            output stays within 1023 tokens.
        temperature, top_k, top_p: Forwarded to `sample.sample_sequence`.
        include_prefix: When false (and `truncate` is set), the prefix is
            stripped from the returned text.
        enc: Optional pre-built encoder; loaded from the checkpoint path
            when None.
        hparams: Optional pre-built hyperparameters; loaded from the
            checkpoint path when None.

    Returns:
        The list of generated texts if `return_as_list` is true, otherwise
        None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ""

    if prefix == "":
        prefix = None

    # NOTE(review): the original passed `checkpoint_dir` twice and never
    # used `model_dir`. The helper is defined elsewhere; this assumes its
    # fourth parameter is the model directory -- confirm.
    checkpoint_path = get_checkpoint_path(run_name, checkpoint_dir, model_name,
                                          model_dir)

    if enc is None:
        enc = get_encoder(checkpoint_path)
    if hparams is None:
        hparams = get_hparams(checkpoint_path)

    if prefix:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        context_tokens = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    # The first returned token is sliced off; when a prefix is used it is
    # restored below by decoding the first prefix token.
    output = sample.sample_sequence(
        hparams=hparams,
        length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
        start_token=enc.encoder["<|endoftext|>"] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )[:, 1:]

    # The truncation pattern does not depend on the sample, so build it once.
    pattern = None
    if truncate:
        truncate_esc = re.escape(truncate)
        if prefix and not include_prefix:
            pattern = "(?:{})(.*?)(?:{})".format(re.escape(prefix),
                                                 truncate_esc)
        else:
            pattern = "(.*?)(?:{})".format(truncate_esc)

    f = open(destination_path, "w") if destination_path else None
    gen_texts = []
    # try/finally guarantees the destination file is closed even if the
    # session run or decoding raises midway.
    try:
        generated = 0
        while generated < nsamples:
            if not prefix:
                out = sess.run(output)
            else:
                out = sess.run(
                    output,
                    feed_dict={context: batch_size * [context_tokens]})
            for i in range(batch_size):
                generated += 1
                gen_text = enc.decode(out[i])
                if prefix:
                    gen_text = enc.decode(context_tokens[:1]) + gen_text
                if pattern is not None:
                    trunc_text = re.search(pattern, gen_text, re.S)
                    if trunc_text:
                        gen_text = trunc_text.group(1)
                gen_text = gen_text.lstrip("\n")
                if f is not None:
                    f.write("{}\n{}".format(gen_text, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen_text, sample_delim), end="")
                gen_texts.append(gen_text)
    finally:
        if f is not None:
            f.close()

    if return_as_list:
        return gen_texts
Exemplo n.º 6
0
def generate(sess,
             run_name='run1',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Unlike a single-pass generator, this variant keeps sampling in rounds
    until every sample in the batch is marked finished, feeding the tail
    of each round's output back in as the next round's context.

    Args:
        sess: TensorFlow session holding the loaded model.
        run_name: Sub-directory of 'checkpoint' holding the model files.
        return_as_list: If true, return the generated texts as a list.
        truncate: Literal string that ends a sample; required when
            `length` is falsy.
        destination_path: If given, samples are written to this file
            instead of being printed.
        sample_delim: Separator emitted after each sample; forced to ''
            when `nsamples == 1`.
        prefix: Optional conditioning text; '' is treated as None.
        seed: Seed applied to both NumPy and TensorFlow RNGs.
        nsamples: Total number of samples; must be a multiple of
            `batch_size`.
        batch_size: Samples produced per round (None means 1).
        length: Token budget per sample; falsy means "until `truncate`".
        temperature, top_k, top_p: Forwarded to `sample.sample_sequence`.
        include_prefix: When false (and `truncate` is set), the prefix is
            stripped from the matched text.
        split_context: Fraction of each round's output (taken from the
            end) that is reused as the next round's context.

    Returns:
        The list of generated texts if `return_as_list` is true, otherwise
        None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if not length:
        # Implicit concatenation keeps indentation out of the message (the
        # original backslash continuation embedded a run of spaces).
        assert truncate is not None, (
            "If generating a non-fixed length sample, "
            "must have a truncation term.")

    CHECKPOINT_DIR = 'checkpoint'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [batch_size, None])
    if prefix:
        context_tokens = [enc.encode(prefix)] * batch_size
    else:
        context_tokens = [[enc.encoder['<|endoftext|>']] for _ in range(batch_size)]
    np.random.seed(seed)
    tf.set_random_seed(seed)

    f = open(destination_path, 'w') if destination_path else None
    generated = 0
    gen_texts = []
    # try/finally guarantees the destination file is closed even if a
    # session run or decoding step raises midway.
    try:
        # NOTE(review): `context_tokens` is not reset between batches, so
        # when nsamples > batch_size later batches start from the previous
        # batch's trailing context rather than from `prefix` -- confirm
        # whether that is intended.
        while generated < nsamples:
            gen_text = [''] * batch_size
            truncated = [False] * batch_size
            total_tokens = 0

            while False in truncated:
                num_tokens = 1023 - (len(context_tokens[0]))
                # NOTE(review): a new sampling subgraph is built on every
                # round, so the TF graph grows with the number of rounds.
                output = sample.sample_sequence(
                    hparams=hparams,
                    length=min(length if length else 1023, num_tokens),
                    context=context,
                    batch_size=batch_size,
                    temperature=temperature, top_k=top_k, top_p=top_p
                )[:, 1:]

                out = sess.run(output, feed_dict={
                        context: context_tokens
                    })

                total_tokens += num_tokens

                for i in range(batch_size):
                    # Reset per sample: the original only assigned this
                    # inside `if truncate:`, so reading it below raised
                    # UnboundLocalError whenever `truncate` was not given.
                    trunc_text = None
                    text = enc.decode(out[i])
                    if prefix:
                        text = enc.decode(context_tokens[i][:1]) + text
                    if truncate or all(gen_text):
                        # Carry the tail of this round's output forward as
                        # the next round's context.
                        context_tokens[i] = out[i][int(len(out[i])*(1-split_context)):]
                        if gen_text[i] != '':
                            # Keep only what follows the last sentence
                            # fragment already accumulated.
                            split = re.split('[.!?]', gen_text[i])
                            text = text.partition(list(filter(None, split))[-1])[-1]

                        if truncate:
                            truncate_esc = re.escape(truncate)
                            if prefix and not include_prefix:
                                prefix_esc = re.escape(prefix)
                                pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                                     truncate_esc)
                            else:
                                pattern = '(.*?)(?:{})'.format(truncate_esc)

                            trunc_text = re.search(pattern, text, re.S)
                            if trunc_text:
                                text = trunc_text.group(1)

                    if not truncated[i]:
                        gen_text[i] += text.lstrip('\n')
                    # Finished once the truncation term was found or the
                    # token budget is used up.
                    if trunc_text or (length is not None and total_tokens >= length-1):
                        truncated[i] = True

            for gen in gen_text:
                if f is not None:
                    f.write("{}\n{}".format(gen, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen, sample_delim), end='')
                gen_texts.append(gen)

            generated += batch_size
    finally:
        if f is not None:
            f.close()

    if return_as_list:
        return gen_texts
Exemplo n.º 7
0
def generate(sess,
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             model_name='117M',
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             run_name='run1'):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Args:
        sess: TensorFlow session holding the loaded model.
        return_as_list: If true, return the generated texts as a list.
        truncate: Literal string; each sample is cut just before its first
            occurrence.
        destination_path: If given, samples are written to this file
            instead of being printed.
        sample_delim: Separator emitted after each sample; forced to ''
            when `nsamples == 1`.
        prefix: Optional conditioning text.
        model_name: Unused by this function.
        seed: Seed applied to both NumPy and TensorFlow RNGs.
        nsamples: Total number of samples; must be a multiple of
            `batch_size`.
        batch_size: Samples produced per session run (None means 1).
        length: Number of tokens forwarded to `sample.sample_sequence`.
        temperature, top_k: Forwarded to `sample.sample_sequence`.
        run_name: Sub-directory of 'checkpoint' holding the model files.

    Returns:
        The list of generated texts if `return_as_list` is true, otherwise
        None.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix:
        context = tf.placeholder(tf.int32, [batch_size, None])

    CHECKPOINT_DIR = 'checkpoint'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    np.random.seed(seed)
    tf.set_random_seed(seed)

    # The first returned token is sliced off; when a prefix is used it is
    # restored below by decoding the first prefix token.
    output = sample.sample_sequence(
        hparams=hparams,
        length=length,
        start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
        context=context if prefix else None,
        batch_size=batch_size,
        temperature=temperature,
        top_k=top_k)[:, 1:]

    if prefix:
        context_tokens = enc.encode(prefix)

    # Escape the truncation term so it matches literally; unescaped,
    # '<|endoftext|>' would be parsed as a regex alternation.
    pattern = None
    if truncate:
        pattern = r'(.*?)(?:{})'.format(re.escape(truncate))

    f = open(destination_path, 'w') if destination_path else None
    gen_texts = []
    # try/finally guarantees the destination file is closed even if the
    # session run or decoding raises midway.
    try:
        generated = 0
        while generated < nsamples:
            if not prefix:
                out = sess.run(output)
            else:
                out = sess.run(
                    output,
                    feed_dict={context: batch_size * [context_tokens]})
            for i in range(batch_size):
                generated += 1
                gen_text = enc.decode(out[i])
                if prefix:
                    # Restore the whole first token, not just the first
                    # character of the prefix: a token can span several
                    # characters.
                    gen_text = enc.decode(context_tokens[:1]) + gen_text
                if pattern is not None:
                    trunc_text = re.search(pattern, gen_text, re.S)
                    if trunc_text:
                        gen_text = trunc_text.group(1)
                if f is not None:
                    f.write("{}\n{}".format(gen_text, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen_text, sample_delim))
                gen_texts.append(gen_text)
    finally:
        if f is not None:
            f.close()

    if return_as_list:
        return gen_texts
Exemplo n.º 8
0
def finetune(
    sess,
    dataset,
    steps=-1,
    model_name="124M",
    model_dir="models",
    combine=50000,
    batch_size=1,
    learning_rate=0.0001,
    accumulate_gradients=5,
    restore_from="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    sample_every=100,
    sample_length=1023,
    sample_num=1,
    multi_gpu=False,
    save_every=1000,
    print_every=1,
    max_checkpoints=1,
    use_memory_saving_gradients=False,
    only_train_transformer_layers=False,
    optimizer="adam",
    overwrite=False,
):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Copies the base model's ``hparams.json``, ``encoder.json`` and
    ``vocab.bpe`` into ``checkpoint_dir/run_name``, builds the loss,
    sampling and optimizer ops, restores a checkpoint and then runs the
    training loop until ``steps`` iterations have been done or the user
    interrupts with Ctrl-C (a checkpoint is saved in both cases).

    Args:
        sess: Session used for every ``sess.run`` and for checkpoint
            save/restore.
        dataset: Forwarded to ``load_dataset`` together with ``combine``.
        steps: Number of iterations to run in this call; a value <= 0
            trains until interrupted.
        model_name: Sub-directory of ``model_dir`` holding the base model.
            For any name other than "117M"/"124M" the function forces
            memory-saving gradients, transformer-layer-only training and
            ``accumulate_gradients = 1``.
        model_dir: Directory containing the base models.
        combine: Forwarded to ``load_dataset``.
        batch_size: Number of sequences per training batch.
        learning_rate: Learning rate given to the optimizer.
        accumulate_gradients: When > 1, gradients from that many batches
            are accumulated through ``AccumulatingOptimizer`` per step.
        restore_from: "latest" (resume the run, falling back to the base
            model), "fresh" (base model) or a checkpoint directory.
        run_name: Sub-directory of ``checkpoint_dir`` (and of "samples")
            used for this run.
        checkpoint_dir: Directory containing the run directories.
        sample_every: Write samples every this many steps.
        sample_length: Length passed to ``sample.sample_sequence``; must
            not exceed ``hparams.n_ctx``.
        sample_num: Number of samples written each time.
        multi_gpu: When true, ``get_available_gpus()`` is passed to the
            model.
        save_every: Save a checkpoint every this many steps.
        print_every: Print the loss every this many steps.
        max_checkpoints: ``max_to_keep`` for the saver.
        use_memory_saving_gradients: Use ``memory_saving_gradients`` to
            compute gradients (not supported with accumulation).
        only_train_transformer_layers: Train only variables whose name
            contains "/h".
        optimizer: "adam" or "sgd".
        overwrite: With ``restore_from == "latest"``, delete the
            "model*"/"events*" files that were in the run directory when
            the function started, then save immediately.

    Raises:
        FileNotFoundError: A base model file is missing.
        ValueError: ``sample_length`` exceeds the model window, or
            ``optimizer`` is not a supported name.
    """

    SAMPLE_DIR = "samples"

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Create the directory tree, tolerating one that already exists.
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    # Snapshot of the run directory taken before the base model files are
    # copied in; only these files are candidates for `overwrite` cleanup.
    files = os.listdir(checkpoint_path)
    for file in ["hparams.json", "encoder.json", "vocab.bpe"]:
        try:
            shutil.copyfile(
                os.path.join(model_dir, model_name, file),
                os.path.join(checkpoint_path, file),
            )
        except FileNotFoundError:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ["117M", "124M"]:
        # Larger models: force the memory-friendly configuration.
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Next-token loss: logits at position t are scored against token t + 1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output["logits"][:, :-1]))

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40,
    )

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if "model" in v.name
    ]
    train_vars = ([v for v in all_vars if "/h" in v.name]
                  if only_train_transformer_layers else all_vars)

    if optimizer == "adam":
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == "sgd":
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        # Previously `opt` stayed unbound and the failure surfaced later
        # as an unrelated-looking error.
        raise ValueError(
            "Unknown optimizer {!r}; expected 'adam' or 'sgd'.".format(
                optimizer))

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar("loss", opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar("loss", loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == "latest":
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == "fresh":
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print("Loading checkpoint", ckpt)
    saver.restore(sess, ckpt)

    print("Loading dataset...")
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print("dataset has", data_sampler.total_size, "tokens")
    print("Training...")

    counter = 1
    counter_path = os.path.join(checkpoint_path, "counter")
    if os.path.exists(counter_path) and restore_from == "latest":
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, "r") as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # Write a checkpoint for the last completed step and record that
        # step number in the "counter" file.
        maketree(checkpoint_path)
        print("Saving",
              os.path.join(checkpoint_path, "model-{}").format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, "model"),
                   global_step=counter - 1)
        with open(counter_path, "w") as fp:
            fp.write(str(counter - 1) + "\n")

    def generate_samples():
        # Sample `sample_num` texts from a one-token context drawn from the
        # dataset, print the last one and write all of them to disk.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = "======== SAMPLE {} ========\n{}\n".format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        if all_text:
            # Guarded: with sample_num == 0 nothing was generated.
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             "samples-{}").format(counter), "w") as fp:
            fp.write("\n".join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == "latest":
        for file in files:
            if file.startswith("model") or file.startswith("events"):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    # (weighted loss sum, weight sum) of an exponentially decayed average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    "[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}"
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                    ))

            counter += 1
    except KeyboardInterrupt:
        print("interrupted")
        save()
Exemplo n.º 9
0
def generate(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             sample_dir='samples',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Text is produced in passes. The first pass fills the window after the
    prefix; later passes re-feed the tail of the previous output as context
    so that generation can continue until ``length`` tokens have been
    accounted for or the ``truncate`` pattern is found.

    Args:
        sess: Session used to run the sampling ops.
        run_name: Run directory under ``checkpoint_dir`` (used when
            ``model_name`` is not given).
        checkpoint_dir: Directory containing run directories.
        model_name: If set, load encoder/hparams from
            ``model_dir/model_name`` instead of the run directory.
        model_dir: Directory containing base models.
        sample_dir: Not referenced by this function.
        return_as_list: Return the generated texts instead of printing.
        truncate: Literal string at which each sample is cut (escaped
            before being used in a regex). Required when ``length`` is
            falsy.
        destination_path: If set, samples are written to this file instead
            of being printed.
        sample_delim: Text written after each sample; forced to '' when
            ``nsamples == 1``.
        prefix: Text the samples start from; '' is treated as no prefix.
        seed: Seed for numpy and TensorFlow.
        nsamples: Number of samples; must be a multiple of ``batch_size``.
        batch_size: Samples generated per pass; ``None`` means 1.
        length: Target number of tokens, or ``None`` for "until
            ``truncate`` is found".
        temperature, top_k, top_p: Forwarded to ``sample.sample_sequence``.
        include_prefix: When false (with ``prefix`` and ``truncate`` set),
            the prefix is matched but left out of the kept text.
        split_context: Fraction (strictly between 0 and 1) of the 1023
            token window used to derive the split position.

    Returns:
        The list of generated strings if ``return_as_list`` is true,
        otherwise ``None``.
    """

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    if prefix == '':
        prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
                sample, must have a truncation term."

    assert 0 < split_context < 1

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])

    if prefix:
        prefix_enc = enc.encode(prefix)

    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    def capped_length(window, used=0):
        # Tokens to request in one pass: what is left of `length`, capped
        # at `window`. `length=None` (no fixed length) uses the full
        # window; previously `min(None, ...)` raised TypeError.
        if length is None:
            return window
        return min(length - used, window)

    def build_sampler(num_tokens):
        # Sampling op for `num_tokens` tokens; the leading token of each
        # row is dropped from the result.
        return sample.sample_sequence(
            hparams=hparams,
            length=num_tokens,
            start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
            context=context if prefix else None,
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p)[:, 1:]

    output = build_sampler(
        capped_length(1023 - (len(prefix_enc) if prefix else 0)))

    split_length = int(1023 * split_context)
    # Invariant: `split_output` is always the op built for
    # `split_output_length` tokens.
    split_output_length = capped_length(1023 - split_length)
    split_output = build_sampler(split_output_length)

    dest_file = open(destination_path, 'w') if destination_path else None
    generated = 0
    gen_texts = []
    try:
        while generated < nsamples:
            gen_text = [np.array([])] * batch_size
            truncated = [False] * batch_size

            if prefix:
                context_tokens = [prefix_enc] * batch_size
            else:
                context_tokens = [[enc.encoder['<|endoftext|>']]] * batch_size

            total_tokens = len(context_tokens[0])
            generated_once = False

            while False in truncated:
                num_tokens = 1023 - (len(context_tokens[0]))
                if generated_once:
                    new_split_output_length = capped_length(
                        1023 - split_length, total_tokens)
                    if new_split_output_length != split_output_length:
                        split_output = build_sampler(new_split_output_length)
                        # Remember what was built so a later batch does not
                        # reuse an op created for a different length.
                        split_output_length = new_split_output_length
                    out = sess.run(split_output,
                                   feed_dict={context: context_tokens})

                else:
                    out = sess.run(output, feed_dict={context: context_tokens})

                total_tokens += num_tokens
                for i in range(batch_size):
                    text = out[i]
                    trunc_text = ""
                    if prefix:
                        # Put back the first context token that the
                        # sampler slice dropped.
                        text = np.append(context_tokens[i][:1], text)
                    # len() instead of array truthiness: bool() of a
                    # multi-element numpy array raises ValueError.
                    if truncate or all(len(g) for g in gen_text):
                        context_tokens[i] = out[i][(1023 - split_length - 1):]
                        if generated_once:
                            text = out[i][split_length:]

                        if truncate:
                            to_trunc = enc.decode(text)
                            truncate_esc = re.escape(truncate)
                            if prefix and not include_prefix:
                                prefix_esc = re.escape(prefix)
                                pattern = '(?:{})(.*?)(?:{})'.format(
                                    prefix_esc, truncate_esc)
                            else:
                                pattern = '(.*?)(?:{})'.format(truncate_esc)

                            trunc_text = re.search(pattern, to_trunc, re.S)
                            if trunc_text:
                                # Re-encode so accumulation stays in tokens
                                # and decoding happens once at the end.
                                text = enc.encode(trunc_text.group(1))

                    if not truncated[i]:
                        gen_text[i] = np.concatenate((gen_text[i], text),
                                                     axis=None)
                        if trunc_text or (length is not None
                                          and total_tokens >= length - 1):
                            truncated[i] = True
                            gen = enc.decode(gen_text[i]).lstrip('\n')
                            if dest_file:
                                dest_file.write("{}\n{}".format(
                                    gen, sample_delim))
                            if not return_as_list and not destination_path:
                                print("{}\n{}".format(gen, sample_delim),
                                      end='')
                            gen_texts.append(gen)
                generated_once = True

            generated += batch_size
    finally:
        # Close the destination file even if sampling fails part-way.
        if dest_file:
            dest_file.close()

    if return_as_list:
        return gen_texts
Exemplo n.º 10
0
def generate(sess,
             run_name='run1',
             return_as_list=False,
             truncate=None,
             destination_path=None,
             sample_delim='=' * 20 + '\n',
             prefix=None,
             seed=None,
             nsamples=1,
             batch_size=1,
             length=1023,
             temperature=0.7,
             top_k=0,
             top_p=0.0,
             include_prefix=True,
             split_context=0.5,
             batch_prefix=None):
    """Generates text from a model loaded into memory.

    Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py

    Args:
        sess: Session used to run the sampling ops.
        run_name: Run directory under "checkpoint" to load from.
        return_as_list: Return the generated texts instead of printing.
        truncate: Literal string at which each sample is cut (escaped
            before being used in a regex). Required when ``length`` is
            falsy.
        destination_path: If set, samples are written to this file instead
            of being printed.
        sample_delim: Text written after each sample; forced to '' when
            ``nsamples == 1``.
        prefix: Single prefix shared by every sample in a batch; mutually
            exclusive with ``batch_prefix``.
        seed: Seed for numpy and TensorFlow.
        nsamples: Number of samples; must be a multiple of ``batch_size``.
        batch_size: Samples generated per pass; ``None`` means 1.
        length: Target number of tokens; falsy means up to 1023 per pass.
        temperature, top_k, top_p: Forwarded to ``sample.sample_sequence``.
        include_prefix: When false (with ``prefix`` and ``truncate`` set),
            the prefix is matched but left out of the kept text.
        split_context: Fraction of each pass's output that is carried over
            as context for the next pass.
        batch_prefix: One prefix per batch entry (``len == batch_size``).
            Shorter encodings are left-padded with token id 0 so every
            row has the same length.

    Returns:
        The list of generated strings if ``return_as_list`` is true,
        otherwise ``None``.
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    if nsamples == 1:
        sample_delim = ''

    assert not (prefix and batch_prefix)
    assert not batch_prefix or len(batch_prefix) == batch_size
    if prefix == '':
        prefix = None
    if batch_prefix == []:
        batch_prefix = None

    if not length:
        assert truncate is not None, "If generating a non-fixed length \
        sample, must have a truncation term."

    CHECKPOINT_DIR = 'checkpoint'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context = tf.placeholder(tf.int32, [batch_size, None])
    if prefix:
        initial_context = [enc.encode(prefix)] * batch_size
    elif batch_prefix:
        initial_context = [enc.encode(pre) for pre in batch_prefix]
        ml = max(len(p) for p in initial_context)
        for p in initial_context:
            # Left-pad with token id 0 up to the longest prefix (one
            # slice assignment instead of repeated insert(0, 0)).
            p[:0] = [0] * (ml - len(p))
    else:
        initial_context = [[enc.encoder['<|endoftext|>']]
                           for _ in range(batch_size)]
    np.random.seed(seed)
    tf.set_random_seed(seed)

    dest_file = open(destination_path, 'w') if destination_path else None
    generated = 0
    gen_texts = []
    try:
        while generated < nsamples:
            # Start every batch from the original prefix; the inner loop
            # rebinds entries of `context_tokens`, which used to leak into
            # the next batch.
            context_tokens = list(initial_context)
            # One independent list per sample. `[[]] * batch_size` made
            # every entry the same list, so each sample accumulated the
            # tokens of all the others.
            gen_text = [[] for _ in range(batch_size)]
            truncated = [False] * batch_size
            total_tokens = 0

            while False in truncated:
                num_tokens = 1023 - (len(context_tokens[0]))
                # NOTE(review): sample_sequence is called on every pass,
                # which presumably adds new ops to the graph each time --
                # confirm whether caching per length is wanted.
                output = sample.sample_sequence(
                    hparams=hparams,
                    length=min(length if length else 1023, num_tokens),
                    context=context,
                    batch_size=batch_size,
                    temperature=temperature, top_k=top_k, top_p=top_p
                )[:, 1:]

                out = sess.run(output, feed_dict={
                        context: context_tokens
                    })

                total_tokens += num_tokens

                for i in range(batch_size):
                    if truncated[i]:
                        continue
                    text = out[i]
                    trunc_text = ""
                    if prefix or batch_prefix:
                        # Put back the first context token that the
                        # sampler slice dropped.
                        text = np.append(context_tokens[i][:1], text)
                    if truncate or all(gen_text):
                        # Carry the tail of this output over as the next
                        # context and keep only the newly generated part.
                        context_tokens[i] = out[i][int(len(out[i])*(1-split_context)):]
                        if gen_text[i]:
                            text = out[i][int(len(out[i])*(split_context)):]

                        if truncate:
                            to_trunc = enc.decode(text)
                            truncate_esc = re.escape(truncate)
                            if prefix and not include_prefix:
                                prefix_esc = re.escape(prefix)
                                pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
                                                                     truncate_esc)
                            else:
                                pattern = '(.*?)(?:{})'.format(truncate_esc)

                            trunc_text = re.search(pattern, to_trunc, re.S)
                            if trunc_text:
                                # Re-encode so accumulation stays in tokens.
                                text = enc.encode(trunc_text.group(1))

                    gen_text[i].append(text)
                    if trunc_text or (length is not None and total_tokens >= length-1):
                        truncated[i] = True

            for gen in gen_text:
                gen = [enc.decode(g).lstrip('\n') for g in gen]
                gen = ''.join(gen)
                if dest_file:
                    dest_file.write("{}\n{}".format(gen, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen, sample_delim), end='')
                gen_texts.append(gen)

            generated += batch_size
    finally:
        # Close the destination file even if sampling fails part-way.
        if dest_file:
            dest_file.close()

    if return_as_list:
        return gen_texts
Exemplo n.º 11
0
# Training graph: the placeholder leaves both dimensions unspecified.
context_none = tf.compat.v1.placeholder(tf.int32, [None, None])
gpus = get_available_gpus()

output = model.model(hparams=hparams, X=context_none, gpus=gpus)
# Next-token loss: logits at position t are scored against token t + 1.
loss = tf.reduce_mean(
    input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=context_none[:, 1:], logits=output['logits'][:, :-1]))

# Sampling graph: a separate placeholder fixed to a single row.
context = tf.compat.v1.placeholder(tf.int32, [1, None])

tf_sample = sample.sample_sequence(
    hparams=hparams,
    length=sample_length,
    context=context,
    batch_size=1,
    temperature=1.0,
    top_k=40,
)

all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
# Optionally restrict training to variables whose name contains '/h'.
train_vars = ([v for v in all_vars if '/h' in v.name]
              if only_train_transformer_layers else all_vars)
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

if accumulate_gradients > 1:
    if use_memory_saving_gradients:
        exit(
            'Memory saving gradients are not implemented for gradient accumulation yet.'
        )
    opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
Exemplo n.º 12
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0
             ):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Copies the base model's ``hparams.json``, ``encoder.json`` and
    ``vocab.bpe`` into ``checkpoint_dir/run_name``, builds the loss,
    optional validation loss, sampling and optimizer ops, restores a
    checkpoint and then runs the training loop until ``steps`` iterations
    have been done or the user interrupts with Ctrl-C (a checkpoint is
    saved in both cases).

    Args:
        sess: Session used for every ``sess.run`` and for checkpoint
            save/restore.
        dataset: Forwarded to ``load_dataset`` together with ``combine``.
        steps: Number of iterations to run in this call; a value <= 0
            trains until interrupted.
        model_name: Sub-directory of ``model_dir`` holding the base model.
        model_dir: Directory containing the base models.
        combine: Forwarded to ``load_dataset``.
        batch_size: Number of sequences per training batch.
        learning_rate: Learning rate given to the optimizer.
        accumulate_gradients: When > 1, gradients from that many batches
            are accumulated through ``AccumulatingOptimizer`` per step.
        restore_from: 'latest' (resume the run, falling back to the base
            model), 'fresh' (base model) or a checkpoint directory.
        run_name: Sub-directory of ``checkpoint_dir`` (and of 'samples')
            used for this run.
        checkpoint_dir: Directory containing the run directories.
        sample_every: Write samples every this many steps.
        sample_length: Length passed to ``sample.sample_sequence``; must
            not exceed ``hparams.n_ctx``.
        sample_num: Number of samples written each time.
        multi_gpu: When true, ``get_available_gpus()`` is passed to the
            model.
        save_every: Save a checkpoint every this many steps.
        print_every: Print the loss every this many steps (and at step 1).
        max_checkpoints: ``max_to_keep`` for the saver.
        use_memory_saving_gradients: Use ``memory_saving_gradients`` to
            compute gradients (not supported with accumulation).
        only_train_transformer_layers: Train only variables whose name
            contains '/h'.
        optimizer: 'adam' or 'sgd'.
        overwrite: With ``restore_from == 'latest'``, delete the
            'model*'/'events*' files that were in the run directory when
            the function started, then save immediately.
        val_dataset: Dataset for validation; when falsy the training
            chunks are reused.
        val_batch_size: Sequences per validation batch.
        val_batch_count: Number of validation batches, sampled once.
        val_every: Run validation at step 1 and whenever the step number
            is a multiple of this value; 0 disables validation.

    Raises:
        FileNotFoundError: A base model file is missing.
        ValueError: ``sample_length`` exceeds the model window, or
            ``optimizer`` is not a supported name.
    """

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Create the directory tree, tolerating one that already exists.
        # Only FileExistsError is ignored so that real failures (e.g.
        # permissions) are no longer hidden by a bare `except`.
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

    maketree(checkpoint_path)
    # Snapshot of the run directory taken before the base model files are
    # copied in; only these files are candidates for `overwrite` cleanup.
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError:
            print("You need to download the GPT-2 model first via download_gpt2()")
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    # Next-token loss: logits at position t are scored against token t + 1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        # Same tf.compat.v1 API as the rest of the function; the bare
        # tf.placeholder / tf.summary.scalar names used here before are
        # not available in the TF2 namespace.
        val_context = tf.compat.v1.placeholder(tf.int32,
                                               [val_batch_size, None])
        # reuse=True: share the variables of the training model.
        val_output = model.model(hparams=hparams, X=val_context, reuse=True)
        val_loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40)

    all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
    else:
        # Previously `opt` stayed unbound and the failure surfaced later
        # as an unrelated-looking error.
        raise ValueError(
            "Unknown optimizer {!r}; expected 'adam' or 'sgd'.".format(
                optimizer))

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=opt,
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(
        var_list=all_vars,
        max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(
            os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks

    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[val_data_sampler.sample(1024) for _ in range(val_batch_size)]
                       for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # Write a checkpoint for the last completed step and record that
        # step number in the 'counter' file.
        maketree(checkpoint_path)
        print(
            'Saving',
            os.path.join(checkpoint_path,
                         'model-{}').format(counter-1))
        saver.save(
            sess,
            os.path.join(checkpoint_path, 'model'),
            global_step=counter-1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter-1) + '\n')

    def generate_samples():
        # Sample `sample_num` texts from a one-token context drawn from the
        # dataset, print the last one and write all of them to disk.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample,
                feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        if all_text:
            # Guarded: with sample_num == 0 nothing was generated.
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        # Average the loss over the fixed validation batches and log it.
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        # Feed the averaged value in place of the `val_loss` tensor so the
        # summary records the mean over all batches.
        v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print(
            '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_val_loss))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    # (weighted loss sum, weight sum) of an exponentially decayed average.
    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            # validation code
            if val_every > 0 and (counter % val_every == 0 or counter == 1):
                validation()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(
                        opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Exemplo n.º 13
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             mixed_precision=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    The run directory is ``os.path.join(checkpoint_dir, run_name)``. The
    base model's ``hparams.json``, ``encoder.json`` and ``vocab.bpe`` are
    copied into it, checkpoints and the step counter are written there, and
    generated samples are written under ``samples/<run_name>``.

    Args:
        sess: TensorFlow session to train in. When ``None``, a new one is
            created with ``start_tf_sess()``.
        dataset: Dataset specification, forwarded to ``load_dataset``.
        steps: Number of training steps to run in this call. A
            non-positive (or falsy) value trains until interrupted.
        model_name: Sub-directory of ``model_dir`` holding the base model.
            Any name other than ``'117M'``/``'124M'`` forces memory-saving
            gradients, transformer-layer-only training and disables
            gradient accumulation.
        model_dir: Directory containing the downloaded base models.
        combine: Forwarded to ``load_dataset``.
        batch_size: Number of sampled sequences fed per training step.
        learning_rate: Learning rate passed to the optimizer.
        accumulate_gradients: When greater than 1, gradients are
            accumulated over this many batches before each update.
        restore_from: ``'latest'`` (resume the run, falling back to the
            base model), ``'fresh'`` (always start from the base model) or
            a directory to load the latest checkpoint from.
        run_name: Name of the run; selects the run directory.
        checkpoint_dir: Parent directory of all run directories.
        sample_every: Generate samples every this many steps.
        sample_length: Length requested from the sampler; must not exceed
            ``hparams.n_ctx``.
        sample_num: Number of samples to generate each time.
        save_every: Save a checkpoint every this many steps.
        print_every: Print the loss every this many steps.
        max_checkpoints: ``max_to_keep`` value for the checkpoint saver.
        use_memory_saving_gradients: Compute gradients with
            ``memory_saving_gradients.gradients`` instead of
            ``tf.gradients``.
        only_train_transformer_layers: Train only the model variables
            whose name contains ``'/h'``.
        optimizer: One of ``'adam'``, ``'sgd'``, ``'sm3'`` or
            ``'adafactor'``.
        overwrite: With ``restore_from='latest'``, delete the ``model*``
            and ``events*`` files that were in the run directory when the
            call started and save a checkpoint immediately.
        mixed_precision: Request the mixed-precision graph rewrite; it is
            only enabled when the version and GPU checks below pass.

    Raises:
        FileNotFoundError: If a base-model file cannot be copied.
        ValueError: If ``sample_length`` exceeds the model window, or
            ``optimizer`` is not a recognized name.
        SystemExit: If memory-saving gradients are requested together
            with gradient accumulation.
    """

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Create the directory tree; an already existing tree is fine.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    # Snapshot of the run directory *before* the base-model files are
    # copied in; used by the `overwrite` clean-up further down.
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        # Any other model gets the memory-friendly settings forced on.
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    # Next-token loss: logits at position i are scored against token i + 1.
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    if only_train_transformer_layers:
        train_vars = [v for v in all_vars if '/h' in v.name]
    else:
        train_vars = all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    elif optimizer == 'sm3':
        opt = SM3Optimizer(learning_rate=learning_rate, momentum=0.9)
    elif optimizer == 'adafactor':
        opt = AdafactorOptimizer(learning_rate=learning_rate)
    else:
        # Fail here with a clear message instead of a NameError on `opt`
        # further down.
        raise ValueError(
            "Unknown optimizer %r; expected one of 'adam', 'sgd', 'sm3', "
            "'adafactor'" % (optimizer,))

    def mp_check_tf_version():
        # check TensorFlow >= 1.14
        # NOTE(review): any 2.x release is reported as unsupported here,
        # so mixed precision is never enabled on TF 2 -- confirm whether
        # that is intended.
        tf_version_list = tf.__version__.split(".")
        if int(tf_version_list[0]) < 2:
            return int(tf_version_list[1]) >= 14
        return False

    def mp_check_tensor_core_gpu_present():
        from tensorflow.python.client import device_lib
        # check Compute Capability >= 7.0
        local_device_protos = device_lib.list_local_devices()
        for line in local_device_protos:
            if "compute capability" in str(line):
                compute_capability = float(
                    line.physical_device_desc.split("compute capability: ")
                    [-1])
                if compute_capability >= 7.0:
                    return True
        return False

    if mixed_precision and mp_check_tf_version(
    ) and mp_check_tensor_core_gpu_present():
        if isinstance(opt, (tf.keras.optimizers.Optimizer,
                            tf.compat.v1.train.Optimizer)):
            opt = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
                opt)
        else:
            mixed_precision = False
    else:
        mixed_precision = False

    print('Mixed precision enabled: %s' % mixed_precision)

    # Fix warning: https://github.com/tensorflow/tensorflow/blob/233d3d/tensorflow/python/training/experimental/mixed_precision.py#L355-L357
    if sess is None:
        sess = start_tf_sess()

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            # Same effect as the exit() site builtin, without depending on
            # the `site` module being loaded.
            raise SystemExit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # `counter` is the step about to run, so the last finished step is
        # counter - 1.
        maketree(checkpoint_path)
        # Format only the file name so braces in the directory names are
        # never treated as format fields.
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}'.format(counter - 1)))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        # Every sample is conditioned on the same single sampled token.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        # Only the last sample is echoed to stdout; guard against the case
        # where no sample was produced at all (sample_num <= 0).
        if all_text:
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        samples_path = os.path.join(SAMPLE_DIR, run_name,
                                    'samples-{}'.format(counter))
        # Decoded text is arbitrary Unicode; do not rely on the platform's
        # default encoding.
        with open(samples_path, 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith(('model', 'events')):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    # Normalize falsy values (None, 0, '') to 0, i.e. "no step limit", so
    # the comparison in the loop never sees a non-integer.
    steps = int(steps) if steps else 0

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                # Running sums with 0.99 decay; the ratio printed below is
                # the decayed mean of the losses seen on print steps.
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
    finally:
        # Flush pending summaries and release the event file handle.
        summary_log.close()