Example #1
def load_gpt2(sess,
              run_name="run1",
              checkpoint_dir="checkpoint",
              model_name=None,
              model_dir='models',
              multi_gpu=False,
              scope=None):
    """Loads the model checkpoint or existing model into a TensorFlow session
    for repeated predictions.
    """
    if scope is not None:
        with tf.compat.v1.variable_scope(scope):
            if model_name:
                checkpoint_path = os.path.join(model_dir, model_name)
            else:
                checkpoint_path = os.path.join(checkpoint_dir, run_name)

            hparams = model.default_hparams()
            with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
                hparams.override_from_dict(json.load(f))

            context = tf.compat.v1.placeholder(tf.int32, [1, None])

            gpus = []
            if multi_gpu:
                gpus = get_available_gpus()

            output = model.model(hparams=hparams, X=context, gpus=gpus)

            ckpt = tf.train.latest_checkpoint(checkpoint_path)
            saver = tf.compat.v1.train.Saver(allow_empty=True)
            sess.run(tf.compat.v1.global_variables_initializer())

            if model_name:
                print('Loading pretrained model', ckpt)
            else:
                print('Loading checkpoint', ckpt)
            saver.restore(sess, ckpt)
    else:
        if model_name:
            checkpoint_path = os.path.join(model_dir, model_name)
        else:
            checkpoint_path = os.path.join(checkpoint_dir, run_name)

        hparams = model.default_hparams()
        with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

        context = tf.compat.v1.placeholder(tf.int32, [1, None])

        gpus = []
        if multi_gpu:
            gpus = get_available_gpus()

        output = model.model(hparams=hparams, X=context, gpus=gpus)

        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        saver = tf.compat.v1.train.Saver(allow_empty=True)
        sess.run(tf.compat.v1.global_variables_initializer())

        if model_name:
            print('Loading pretrained model', ckpt)
        else:
            print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)
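
For context, a minimal calling sketch for the load_gpt2() example above. It assumes the function's module-level imports (os, json, tf, model) are in place and that a trained checkpoint already exists under checkpoint/run1; the session handling shown here is illustrative, not part of the example.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # the example builds a TF1-style graph
sess = tf.compat.v1.Session()
load_gpt2(sess, run_name="run1", checkpoint_dir="checkpoint")
# ...reuse `sess` for repeated predictions (e.g. sampling)...
sess.close()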
Example #2
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Create the directory if it does not already exist.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
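
A minimal calling sketch for the finetune() example above. It assumes the 124M model has already been fetched into models/124M (e.g. via download_gpt2()) and that "dataset.txt" is a placeholder path to a plain-text corpus; everything outside the example itself is illustrative.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # finetune() uses TF1-style placeholders
sess = tf.compat.v1.Session()

finetune(sess,
         dataset="dataset.txt",    # placeholder corpus path
         model_name="124M",
         steps=1000,               # stop after 1000 optimization steps
         restore_from="fresh",     # start from the downloaded 124M weights
         run_name="run1",
         print_every=10,
         sample_every=200,
         save_every=500)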
Example #3
def one_lr_cycle(sess,
                 dataset,
                 steps=10000,
                 model_name='117M',
                 combine=50000,
                 batch_size=1,
                 initial_lr=1e-10,
                 final_lr=1,
                 accumulate_gradients=5,
                 restore_from='fresh',
                 run_name='run1',
                 max_checkpoints=1,
                 use_memory_saving_gradients=False,
                 only_train_transformer_layers=False,
                 overwrite=False):
    """Does one LR half-cycle from initial to final over steps iterations using CLR algorithm
    https://github.com/bckenstler/CLR

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    CHECKPOINT_DIR = 'checkpoint'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        # Create the directory if it does not already exist.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        if file not in files:
            try:
                shutil.copyfile(os.path.join('models', model_name, file),
                                os.path.join(checkpoint_path, file))
            except FileNotFoundError as fnf_error:
                print(
                    "You need to download the GPT-2 model first via download_gpt2()"
                )
                raise fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if model_name != '117M':
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    current_iter = 0
    learning_rate = tf.placeholder(tf.float32, shape=[])

    def get_lr():
        cycle = np.floor(1 + current_iter / (2 * steps))
        x = np.abs(current_iter / steps - 2 * cycle + 1)
        lr = initial_lr + (final_lr - initial_lr) * np.maximum(
            0, (1 - x))  # * scale_fn(x)
        return lr

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars
    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))

    start_time = time.time()

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                return
            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute,
                             feed_dict={
                                 context: sample_batch(),
                                 learning_rate: get_lr()
                             })
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={
                        context: sample_batch(),
                        learning_rate: get_lr()
                    })

            summary_log.add_summary(v_summary, counter)

            print('[{counter} | {time:2.2f}] loss={loss:3.14f} lr={lr:2.14f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_loss,
                         lr=get_lr()))

            counter += 1
            current_iter += 1
    except KeyboardInterrupt:
        print('interrupted')
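
For reference, get_lr() above is the rising half of a triangular CLR cycle: the learning rate grows linearly from initial_lr to final_lr over the first `steps` iterations. A standalone sketch of the same arithmetic (the function name and sample values here are illustrative):

import numpy as np

def triangular_lr(current_iter, steps, initial_lr=1e-10, final_lr=1.0):
    # Same half-cycle arithmetic as get_lr() in the example above.
    cycle = np.floor(1 + current_iter / (2 * steps))
    x = np.abs(current_iter / steps - 2 * cycle + 1)
    return initial_lr + (final_lr - initial_lr) * np.maximum(0, 1 - x)

print(triangular_lr(0, 10000))      # ~1e-10 (start of the range test)
print(triangular_lr(5000, 10000))   # ~0.5   (halfway through)
print(triangular_lr(10000, 10000))  # ~1.0   (end of the half-cycle)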
Example #4
    def model_fn(features, labels, mode, params):
        tf.logging.info('*** Features ***')
        for name in sorted(features.keys()):
            tf.logging.info('  name = %s, shape = %s' %
                            (name, features[name].shape))

        input_ids = features['input_ids']

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        output = model.model(hparams=hparams, X=input_ids)
        loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=input_ids[:, 1:], logits=output['logits'][:, :-1]))

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        (
            assignment_map,
            initialized_variable_names,
        ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)

        def tpu_scaffold():
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
        tf.logging.info('**** Trainable Variables ****')
        for var in tvars:
            init_string = ''
            if var.name in initialized_variable_names:
                init_string = ', *INIT_FROM_CKPT*'
            tf.logging.info('  name = %s, shape = %s%s', var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, True)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(loss, input_ids, logits):
                # Next-token accuracy: compare the shifted inputs against the
                # argmax of the corresponding logits.
                predictions = tf.argmax(logits[:, :-1],
                                        axis=-1,
                                        output_type=tf.int32)
                labels = input_ids[:, 1:]
                accuracy = tf.metrics.accuracy(
                    labels=labels,
                    predictions=predictions,
                )
                mean_loss = tf.metrics.mean(values=loss)

                return {
                    'next_token_accuracy': accuracy,
                    'next_token_loss': mean_loss,
                }

            eval_metrics = (metric_fn, [loss, input_ids, output['logits']])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn,
            )
        else:
            raise ValueError('Only TRAIN and EVAL modes are supported: %s' %
                             (mode))

        return output_spec
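
A model_fn like the one above is typically passed to a TPUEstimator rather than called directly. A minimal wiring sketch under that assumption; the RunConfig values and the input_fn are placeholders, not taken from the example:

import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    model_dir='checkpoint/run1',   # placeholder output directory
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100),
)

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=False,          # the same model_fn also runs on CPU/GPU
    model_fn=model_fn,
    config=run_config,
    train_batch_size=8,
    eval_batch_size=8,
)

# input_fn must yield batches of {'input_ids': <int32 [batch, seq_len]>}.
# estimator.train(input_fn=input_fn, max_steps=num_train_steps)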
Example #5
def finetune(sess,
             dataset,
             steps=-1,
             model_name='117M',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             model_load=False):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    CHECKPOINT_DIR = 'checkpoint'
    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)

    def maketree(path):
        # Create the directory if it does not already exist.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    if not model_load:
        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
            shutil.copyfile(os.path.join('models', model_name, file),
                            os.path.join(checkpoint_path, file))

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    context = tf.placeholder(tf.int32, [batch_size, None])
    output = model.model(hparams=hparams, X=context)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    if accumulate_gradients > 1:
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=learning_rate),
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        opt_apply = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss, var_list=train_vars)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(checkpoint_path)

    saver = tf.train.Saver(var_list=train_vars, max_to_keep=max_checkpoints)
    sess.run(tf.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    if model_load:
        return

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path):
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    try:
        while True:
            if counter == steps:
                save()
                return
            if counter % save_every == 0:
                save()
            if counter % sample_every == 0:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Example #6
    def model_fn(features, labels, mode, params):
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]

        output = model.model(hparams=hparams, X=input_ids)
        loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=input_ids[:, 1:], logits=output["logits"][:, :-1]))

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                use_tpu,
                optimizer,
                poly_power,
                start_warmup_step,
                use_memory_saving_gradients=use_memory_saving_gradients)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(loss):
                """Evaluation metric Fn which runs on CPU."""
                perplexity = tf.exp(tf.reduce_mean(loss))
                bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
                return {
                    "perplexity": tf.metrics.mean(perplexity),
                    "bpc": tf.metrics.mean(bpc),
                }

            if FLAGS.use_tpu:
                with tf.colocate_with(loss):
                    loss = tf.contrib.tpu.cross_replica_sum(loss) \
                              / FLAGS.num_tpu_cores
            metric_loss = tf.tile(tf.reshape(loss, [1, 1]),
                                  [FLAGS.eval_batch_size, 1])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(metric_fn, [metric_loss]),
                scaffold_fn=scaffold_fn)

            # eval_metrics = (metric_fn, {"loss":loss})
            # output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            #     mode=mode,
            #     loss=loss,
            #     eval_metrics=eval_metrics,
            #     scaffold_fn=scaffold_fn,
            # )
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
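
As a quick sanity check of the metric_fn above: perplexity is the exponential of the mean per-token loss, and the value labelled "bpc" is that loss converted from nats to bits (so strictly bits per token). The conversion, with an illustrative loss value:

import math

mean_loss = 3.0                      # illustrative mean cross-entropy in nats
perplexity = math.exp(mean_loss)     # ~20.09
bits = mean_loss / math.log(2)       # ~4.33 bits per token
print(perplexity, bits)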
Example #7
def finetune(
    sess,
    dataset,
    steps=-1,
    model_name="124M",
    model_dir="models",
    combine=50000,
    batch_size=1,
    learning_rate=0.0001,
    accumulate_gradients=5,
    restore_from="latest",
    run_name="run1",
    checkpoint_dir="checkpoint",
    sample_every=100,
    sample_length=1023,
    sample_num=1,
    sample_prefix="",
    multi_gpu=False,
    save_every=1000,
    print_every=1,
    sample_dir="samples",
    max_checkpoints=1,
    use_memory_saving_gradients=False,
    only_train_transformer_layers=False,
    optimizer="adafactor",
    overwrite=False,
):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    os.makedirs(checkpoint_path, exist_ok=True)
    files = os.listdir(checkpoint_path)
    for file_name in ["hparams.json", "encoder.json", "vocab.bpe"]:
        try:
            shutil.copyfile(
                os.path.join(model_dir, model_name, file_name),
                os.path.join(checkpoint_path, file_name),
            )
        except FileNotFoundError as fnf_error:
            raise RuntimeError(
                "You need to download the GPT-2 model first via download_gpt2()"
            ) from fnf_error

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, "hparams.json")) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx
        )

    if model_name > "124M":
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])

    if multi_gpu:
        gpus = get_available_gpus()
    else:
        gpus = []

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output["logits"][:, :-1]
        )
    )

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=sample_length,
        context=context,
        batch_size=batch_size,
        temperature=1.0,
        top_k=40,
    )

    all_vars = [v for v in tf.compat.v1.trainable_variables() if "model" in v.name]
    train_vars = (
        [v for v in all_vars if "/h" in v.name]
        if only_train_transformer_layers
        else all_vars
    )

    if optimizer == "adam":
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == "sgd":
        opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
    elif optimizer == "adafactor":
        params = {}
        params["decay_type"] = "adam"
        params["beta1"] = 0.0
        params["beta2"] = 0.999
        if params["decay_type"] == "adam":
            decay_rate = adafactor_decay_rate_adam(params["beta2"])
        elif params["decay_type"] == "pow":
            decay_rate = adafactor_decay_rate_pow(params["decay_exponent"])
        else:
            raise ValueError("unknown optimizer_adafactor_decay_type")

        if not "weight_decay" in params.keys():
            opt = AdafactorOptimizer(
                learning_rate=learning_rate,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="Adafactor",
            )
        else:
            AdafactorWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
                AdafactorOptimizer
            )

            opt = AdafactorWOptimizer(
                weight_decay=params["weight_decay"] * learning_rate,
                learning_rate=learning_rate,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="AdafactorW",
            )
    else:
        raise ValueError(f"Unknown optimizer {optimizer}")

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar("loss", opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar("loss", loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars, max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == "latest":
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    elif restore_from == "fresh":
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print("Loading checkpoint", ckpt)
    saver.restore(sess, ckpt)

    print("Loading dataset...")
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)
    print(f"Dataset has {data_sampler.total_size} tokens.")
    if not data_sampler.total_size:
        raise ValueError("Dataset is empty.")

    print("Training...")

    counter = 1
    counter_path = os.path.join(checkpoint_path, "counter")
    if os.path.exists(counter_path) and restore_from == "latest":
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, "r") as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        os.makedirs(checkpoint_path, exist_ok=True)
        print("Saving", os.path.join(checkpoint_path, "model-{}").format(counter - 1))
        saver.save(
            sess, os.path.join(checkpoint_path, "model"), global_step=counter - 1
        )
        with open(counter_path, "w") as fp:
            fp.write(str(counter - 1) + "\n")

    def generate_samples():
        if sample_prefix:
            context_tokens = enc.encode(sample_prefix)
        else:
            context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(
                tf_sample, feed_dict={context: batch_size * [context_tokens]}
            )
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = "======== SAMPLE {} ========\n{}\n".format(index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        os.makedirs(os.path.join(sample_dir, run_name), exist_ok=True)
        with open(
            os.path.join(sample_dir, run_name, "samples-{}").format(counter), "w"
        ) as fp:
            fp.write("\n".join(all_text))

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == "latest":
        for file in files:
            if file.startswith("model") or file.startswith("events"):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                save()
                return
            if (counter - 1) % save_every == 0 and counter > 1:
                save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss), feed_dict={context: sample_batch()}
                )

            summary_log.add_summary(v_summary, counter)

            if counter % print_every == 0:
                avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)

                print(
                    "[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}".format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                    )
                )

            counter += 1
    except KeyboardInterrupt:
        print("interrupted")
        save()