Пример #1
0
def interact_model(model_name='117M',
                   restore_from=None,
                   seed=None,
                   nsamples=1,
                   step=1,
                   length=64,
                   prompt="\n",
                   clear=None,
                   maxlen=-1,
                   temperature=1,
                   top_k=0,
                   top_p=0,
                   penalize=0):
    """
    Interactively run the model, streaming generated text to stdout
    :model_name=117M : String, which model to use
    :restore_from=None : Directory to restore the checkpoint from. Defaults to
     models/<model_name>.
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :step=1 : Number of tokens to generate at a time. Must be smaller than length.
    :length=64 : Window size; use 1024 for maximum size per sample
    :prompt="\\n" : Prompt to start with. If it names a readable file, the file
     contents are used instead. Literal "\\n" and "\\t" sequences are expanded.
     An empty prompt starts from an <|endoftext|> token.
    :clear=None : If this string is encountered, clear the context window.
    :maxlen=-1 : if this many tokens are generated without
     encountering --clear, then print it and clear the context window.
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be 0.85
     with temperature 0.3 and top_k 40.

    Raises ValueError if length exceeds the model's n_ctx or if step is not
    smaller than length.
    """
    batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length > hparams.n_ctx:
        raise ValueError("Length can't be largeer than n_ctx: %s" %
                         hparams.n_ctx)
    if step > length:
        raise ValueError("Can't get samples longer than length: %s" % length)

    # The context fed to the model is capped at this many tokens.  A negative
    # cap (step == length) can never be satisfied and used to spin forever in
    # the trimming loop, so reject it up front.
    max_context = length - step - 1
    if max_context < 0:
        raise ValueError("Step must be smaller than length: %s" % length)

    def trim_context(tokens):
        """Return `tokens` limited to its last `max_context` entries.

        The same list object is returned when nothing has to be dropped.
        """
        overflow = len(tokens) - max_context
        return tokens[overflow:] if overflow > 0 else tokens

    def reset_context():
        """Forget the accumulated context and restart from the prompt tokens."""
        tflex.context_text = ""
        tflex.context_count = 0
        tflex.context_tokens = []
        tflex.first = True
        tflex.tokens = tflex.prompt_tokens[:]

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(hparams=hparams,
                                        length=step,
                                        context=context,
                                        batch_size=batch_size,
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p,
                                        penalize=penalize)

        saver = tflex.Saver(reshape=True)
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        while True:
            tflex.check_commands()
            if tflex.should_quit():
                break
            # Treat the prompt as a file path first; if it cannot be opened
            # or decoded, fall back to using the prompt string itself.
            try:
                with open(prompt) as f:
                    tflex.raw_text = f.read()
            except (OSError, ValueError):
                tflex.raw_text = prompt
            else:
                # Drop a single trailing newline / carriage return.
                if tflex.raw_text.endswith('\n'):
                    tflex.raw_text = tflex.raw_text[:-1]
                if tflex.raw_text.endswith('\r'):
                    tflex.raw_text = tflex.raw_text[:-1]
            # Expand escape sequences typed literally on the command line.
            tflex.raw_text = tflex.raw_text.replace('\\n', '\n')
            tflex.raw_text = tflex.raw_text.replace('\\t', '\t')
            tflex.context_tokens = enc.encode(
                tflex.raw_text) if tflex.raw_text else [50256]
            tflex.context_tokens = trim_context(tflex.context_tokens)
            tflex.prompt_tokens = tflex.context_tokens[:]
            tflex.first = True
            tflex.backlog = []
            tflex.backlog_count = 0
            tflex.context_text = ""
            tflex.context_count = 0
            while True:
                for tokens in generate_result(
                        context_tokens=tflex.context_tokens,
                        enc=enc,
                        output=output,
                        context=context,
                        nsamples=1,
                        batch_size=batch_size,
                        sess=sess):
                    tflex.tokens = tokens
                    if tflex.first:
                        # Echo the current context once before streaming.
                        sys.stdout.write(enc.decode(tflex.context_tokens))
                        sys.stdout.flush()
                        tflex.first = False
                    tflex.backlog.extend(tflex.tokens)
                    tflex.backlog_count += 1
                    # Hold tokens back until the newest one decodes to ASCII
                    # (or more than 16 batches have piled up), then flush
                    # the whole backlog at once.
                    if is_ascii(enc.decode([tflex.backlog[-1]
                                            ])) or tflex.backlog_count > 16:
                        text = enc.decode(tflex.backlog)
                        result = text
                        if clear is not None:
                            # Only print what precedes the clear marker.
                            result = text.split(clear, 1)[0]
                        sys.stdout.write(result)
                        sys.stdout.flush()
                        tflex.context_text += text
                        tflex.context_count += len(tflex.backlog)

                        tflex.reset_context = reset_context
                        too_long = maxlen > 0 and tflex.context_count > maxlen
                        saw_clear = (clear is not None
                                     and clear in tflex.context_text)
                        if too_long or saw_clear:
                            tflex.reset_context()
                        tflex.backlog = []
                        tflex.backlog_count = 0
                    tflex.check_commands()
                    tflex.context_tokens.extend(tflex.tokens)
                    tflex.context_tokens = trim_context(tflex.context_tokens)
Пример #2
0
# Overlay the model-specific hyperparameters on top of the defaults.
with open(os.path.join('models', model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))
# Enter the session by hand (no `with` block) so it stays open for the
# lifetime of the module instead of being closed after setup.
sess = tf.Session(graph=tf.Graph()).__enter__()
# The placeholder's batch dimension must agree with the batch_size handed to
# sample_sequence below, so use the same variable rather than a literal.
context = tf.placeholder(tf.int32, [batch_size, None])
# Seeding with None leaves both random number generators unseeded.
np.random.seed(None)
tf.set_random_seed(None)
output = sample.sample_sequence(
    hparams=hparams, length=length,
    context=context,
    batch_size=batch_size,
    temperature=temperature, top_k=top_k, top_p=top_p
)
saver = tflex.Saver()
# Fall back to the stock model directory when no checkpoint dir was given.
if restore_from is None:
    restore_from = os.path.join('models', model_name)
ckpt = tflex.latest_checkpoint(restore_from)
saver.restore(sess, ckpt)


# Request body accepted by the /text_rewrite endpoint.
class Article(BaseModel):
    # The text to be rewritten; the handler reads it via `article.text`.
    text: str


@app.post("/text_rewrite")
def text_rewrite(
    article: Article
):
    # Handler for POST /text_rewrite: walks the submitted text one sentence
    # at a time.
    text = article.text
    result = []
    # `nlp` is defined outside this snippet; presumably a pipeline with
    # sentence segmentation (it exposes `.sents`) -- TODO confirm.
    for sent in nlp(text).sents:
        # Surrounding whitespace is stripped from each sentence.
        # NOTE(review): the loop body appears to be cut off after this line.
        sent_text = sent.text.strip()
Пример #3
0
def main():
    """Train the model described by the command-line arguments.

    Parses `parser`'s arguments, builds the loss / sampling / optimizer graph,
    restores a checkpoint (unless --fresh_model is set), then runs the training
    loop until `tflex.should_quit()` is true or the user interrupts with
    Ctrl-C.  Along the way it periodically saves checkpoints, writes generated
    samples and reports the validation loss, as configured by the arguments.

    Raises ValueError when the requested sample length exceeds the model's
    context window.  Exits the process on an unknown optimizer name or when
    gradient accumulation is combined with memory-saving gradients.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout
    # `epsilon` is forwarded to sample_sequence; the reduced-precision dtypes
    # use a smaller magnitude than the float32 default.
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        print('Unknown dtype', args.dtype)
    # NOTE(review): despite its name, --float16 selects bfloat16 here.
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    # Non-negative command-line values override the model's stored shape.
    if args.n_ctx >= 0:
        hparams.n_ctx = args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd = args.n_embd
    if args.n_head >= 0:
        hparams.n_head = args.n_head
    if args.n_layer >= 0:
        hparams.n_layer = args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
        args.sample_ctx = hparams.n_ctx

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        # Next-token loss: logits at position i are scored against the token
        # at position i + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
        print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(args.learning_rate, global_step,
                    warmup_steps=args.learning_rate_warmup, initial_period_steps=args.learning_rate_period, learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable('learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
            """Push the learning rate into the graph and return its value.

            With the cosine schedule enabled nothing is loaded and `lr` is
            only evaluated.  `rate` may be a callable taking the step.
            """
            if not args.learning_rate_cos:
                if step is None:
                    step = global_step.eval(session=sess)
                if rate is None:
                    rate = args.learning_rate
                if callable(rate):
                    rate = rate(step)
                lr.load(rate, session=sess)
            return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
            """Prompt on stdin for a new learning rate and store it in args."""
            print("Current learn rate: %0.8f" % update_lr())
            print("New learn rate?")
            rate = input('')
            if not rate:
                print("Empty input; not changing anything.")
                return
            try:
                rate = float(rate)
            except ValueError:
                print("Invalid input; must be a float")
                return
            print("Setting learn rate to %0.8f" % rate)
            args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(learning_rate=lr, hparams=ada_hparams)
        else:
            # exit() accepts a single argument, so format the message first.
            exit('Bad optimizer: %s' % args.optimizer)

        #if tpu_addr:
        #    # https://pulsejet.github.io/blog/posts/tpu-without-estimator/
        #    from tensorflow.contrib.tpu.python.tpu import tpu_function
        #    tpu_function.get_tpu_context().set_number_of_shards(8)
        #    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        def make_sampler(dataset, enc, seed, combine):
            """Build a sampler: chunked for directories / .npz, text otherwise."""
            if os.path.isdir(dataset) or dataset.endswith('.npz'):
                chunks = load_dataset(enc, dataset, combine)
                data_sampler = Sampler(chunks, seed=seed)
                print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
            else:
                data_sampler = TextSampler(dataset, enc, seed=seed)
            return data_sampler

        print('Loading dataset...')
        # A negative --seed means "do not seed".
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed, combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc, seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'

            return tarfile_name

        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            #is_mounted()

            checkpoint_folder = os.path.join('checkpoint', run_name)

            if copy_folder:
                shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)

                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)

                shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            """Write a checkpoint for the current counter and record the counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            t0 = time.time()
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #copy_checkpoint_to_gdrive()

        @tflex.register_command
        def generate_samples():
            """Print args.sample_num samples and write them to the sample dir."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                # The last batch may yield more rows than are still needed.
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            """Average the loss over the fixed validation batches and log it."""
            if args.val_every <= 0:
                return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value in place of the val_loss tensor so the
            # summary records the mean over all validation batches.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            """Seconds since training started."""
            return time.time() - start_time

        def say(msg):
            """Print `msg` prefixed with timestamp, counter and elapsed time."""
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            """Draw args.batch_size training samples of args.sample_ctx tokens."""
            return [data_sampler.sample(args.sample_ctx) for _ in range(args.batch_size)]

        prev_time = time.time()
        # (weighted loss sum, weight sum) for an exponential moving average.
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        last_saved_time = elapsed()
        while True:
            try:
                now = elapsed()
                # Time-based saving (minutes) takes precedence over
                # step-based saving.
                if args.save_time > 0 and (((now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    # NOTE(review): this adds a new cast op to the graph on
                    # every step; consider converting on the host instead.
                    v_loss = tf.to_float(v_loss).eval()

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                    .format(
                        stamp=timestamp(),
                        counter=counter,
                        time=now - start_time,
                        delta=now - prev_time,
                        ops=args.sample_ctx * args.batch_size / (now - prev_time),
                        rate=v_rate,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                        step=current_step,
                        ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                    )
                if tflex.should_quit():
                    break

                prev_time = now
                # The debug dumps below run once and then switch themselves off.
                if args.debug_print_all_vars:
                    print('all variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.global_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_all_vars = False

                if args.debug_print_trainable_vars:
                    print('trainable variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.trainable_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    # Dropping into the debugger keeps the training loop alive.
                    import pdb
                    pdb.set_trace()
                else:
                    break
Пример #4
0
def interact_model(model_name='117M',
                   restore_from=None,
                   seed=None,
                   nsamples=1,
                   batch_size=1,
                   length=None,
                   temperature=1,
                   top_k=0,
                   top_p=0.0,
                   penalize=0,
                   prompt=None):
    """
    Repeatedly read a prompt and print model completions for it
    :model_name=117M : String, name of the model directory under models/
    :restore_from=None : Directory holding the checkpoint to restore; when None,
     models/<model_name> is used
    :seed=None : Integer seed for the random number generators; fix it to get
     reproducible results
    :nsamples=1 : Total number of samples to produce per prompt
    :batch_size=1 : Number of samples generated per run (only affects
     speed/memory).  Must divide nsamples.
    :length=None : Number of tokens to generate; None (default) uses half of the
     model's context window
    :temperature=1 : Float controlling randomness in the boltzmann distribution.
     Lower values give less random completions; near zero the model becomes
     deterministic and repetitive.  Higher values give more random completions.
    :top_k=0 : Integer controlling diversity.  1 considers a single word per
     step (deterministic), 40 considers 40 words per step.  0 (default) means
     no restriction.  40 generally is a good value.
    :top_p=0.0 : Float controlling diversity via nucleus sampling; overrides
     top_k when > 0.  A good setting is 0.9.
    :penalize=0.0 : Float controlling the "used" penalty (repetition reduction,
     similar to CTRL) when > 0.  A decent setting might be 0.85 with
     temperature 0.3 and top_k 40.
    :prompt=None : Prompt text, or the path of a file containing it.  When None,
     the prompt is read from standard input on every round.
    """
    batch_size = 1 if batch_size is None else batch_size
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    hparams_path = os.path.join('models', model_name, 'hparams.json')
    with open(hparams_path) as hparams_file:
        hparams.override_from_dict(json.load(hparams_file))

    if length is not None and length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)
    if length is None:
        length = hparams.n_ctx // 2

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            penalize=penalize,
        )

        saver = tflex.Saver()
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        saver.restore(sess, tflex.latest_checkpoint(restore_from))

        rounds = nsamples // batch_size
        bar = "=" * 40
        while True:
            # Resolve the prompt: stdin, a file's contents, or the literal text.
            if prompt is None:
                raw_text = input("Model prompt >>> ") or "\n"
            elif os.path.isfile(prompt):
                with open(prompt) as prompt_file:
                    raw_text = prompt_file.read()
            else:
                raw_text = prompt

            # Drop one trailing newline unless the prompt is only that newline.
            if raw_text.endswith('\n') and len(raw_text) > 1:
                raw_text = raw_text[:-1]
            print('Prompt:', repr(raw_text))

            context_tokens = enc.encode(raw_text)
            prompt_len = len(context_tokens)
            generated = 0
            for _ in range(rounds):
                feed = {context: [context_tokens] * batch_size}
                # Keep only the tokens produced after the prompt.
                out = sess.run(output, feed_dict=feed)[:, prompt_len:]
                for row in range(batch_size):
                    generated += 1
                    text = enc.decode(out[row])
                    print(bar + " SAMPLE " + str(generated) + " " + bar)
                    sys.stdout.write(raw_text)
                    print(text)
                    sys.stdout.flush()
            print("=" * 80)
Пример #5
0
def text_rewrite(
        input_filepath,
        output_filepath,
        model_name='1558M',
        restore_from=None,
        seed=None,
        batch_size=4,
        length=70,
        temperature=1,
        top_k=0,
        top_p=1,
        init_tpu=False
):
    """
    Rewrite a text file sentence by sentence with the model
    :input_filepath : Path of the UTF-8 text file to rewrite
    :output_filepath : Path the result is written to (UTF-8); the rewritten
     sentences are joined with a single space
    :model_name=1558M : String, which model to use
    :restore_from=None : Path handed to tflex.latest_checkpoint; defaults to
     models/<model_name>
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :batch_size=4 : Number of candidate rewrites sampled for every sentence
    :length=70 : Number of tokens generated for every candidate
    :temperature=1 : Passed through to sample.sample_sequence
    :top_k=0 : Passed through to sample.sample_sequence
    :top_p=1 : Passed through to sample.sample_sequence
    :init_tpu=False : If true, the session is created through tflex instead
     of tf
    Raises ValueError if length exceeds the model window (hparams.n_ctx).
    """
    nlp = English()
    nlp.add_pipe(nlp.create_pipe('sentencizer'))
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    # Same guard the other sampling entry points in this file apply.
    if length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    sess_initer = tflex if init_tpu else tf

    with sess_initer.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )
        saver = tflex.Saver()
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        with open(input_filepath, encoding='utf-8') as f:
            document = f.read()

        result = []
        for sent in nlp(document).sents:
            sent_text = sent.text.strip()
            if not sent_text:
                continue
            model_input = "ORIGINAL_SENT: {} >>>>>".format(sent_text)
            context_tokens = enc.encode(model_input)
            print(model_input)
            print('*' * 80)
            # Feed the same prompt in every batch row, then keep only the
            # tokens produced after the prompt.
            out = sess.run(output, feed_dict={
                context: [context_tokens] * batch_size
            })[:, len(context_tokens):]

            candidates = []
            for i in range(batch_size):
                candidate = enc.decode(out[i]) + "\n"
                print("SAMPLE {}".format(i))
                print(candidate)
                print('-' * 80)
                candidates.append(candidate)
            # get_best_candidate / add_new_line are defined elsewhere in the
            # file; the latter receives the unstripped sentence text.
            rephrased = get_best_candidate(sent_text, candidates)
            rephrased = add_new_line(sent.text, rephrased)
            result.append(rephrased)
            print("REPHRASED: {}".format(rephrased))
            print('=' * 80)

        with open(output_filepath, encoding='utf-8', mode='w') as f:
            f.write(" ".join(result))
Пример #6
0
def interact_model(model_name='117M',
                   restore_from=None,
                   seed=None,
                   nsamples=1,
                   step=1,
                   length=64,
                   prompt="\n",
                   clear=None,
                   maxlen=-1,
                   temperature=1,
                   top_k=0,
                   top_p=0,
                   penalize=0):
    """
    Restore the latest checkpoint and re-save it as saved/model-<step>

    Despite its name this variant never prompts or prints samples: it builds
    the sampling graph, restores the newest checkpoint through tflex and
    writes it back out with a plain tf.train.Saver.
    :model_name=117M : String, which model to use
    :restore_from=None : Path handed to tflex.latest_checkpoint; defaults to
     models/<model_name>
    :seed=None : Integer seed for random number generators
    :nsamples=1 : Only checked against the fixed batch size of 1
    :step=1 : Number of tokens the sampling graph is built for; must not
     exceed length
    :length=64 : Window size; must not exceed the model's n_ctx
    :prompt="\\n" : Not used by this variant
    :clear=None : Not used by this variant
    :maxlen=-1 : Not used by this variant
    :temperature=1 : Passed through to sample.sample_sequence
    :top_k=0 : Passed through to sample.sample_sequence
    :top_p=0.0 : Passed through to sample.sample_sequence
    :penalize=0.0 : Passed through to sample.sample_sequence
    Raises ValueError if length > n_ctx or step > length.
    """
    batch_size = 1
    assert nsamples % batch_size == 0

    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length > hparams.n_ctx:
        raise ValueError("Length can't be larger than n_ctx: %s" %
                         hparams.n_ctx)
    if step > length:
        raise ValueError("Can't get samples longer than length: %s" % length)

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        # The returned tensor is never run; the call is kept because it
        # presumably creates the model variables the savers below operate on.
        sample.sample_sequence(hparams=hparams,
                               length=step,
                               context=context,
                               batch_size=batch_size,
                               temperature=temperature,
                               top_k=top_k,
                               top_p=top_p,
                               penalize=penalize)

        saver = tflex.Saver(reshape=True)
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        # Take the step from the checkpoint's file name only, so that '-' or
        # '.' in directory names cannot confuse the parse, and fall back to 0
        # for checkpoints without a numeric suffix (e.g. "model.ckpt").
        suffix = os.path.basename(ckpt).rpartition('-')[2].split('.')[0]
        try:
            counter = int(suffix)
        except ValueError:
            counter = 0

        # tf.train.Saver.save does not create the target directory itself.
        os.makedirs('saved', exist_ok=True)
        saver2 = tf.train.Saver()
        saver2.save(sess, os.path.join('saved', 'model'), global_step=counter)
def sample_model(model_name='117M',
                 restore_from=None,
                 seed=None,
                 nsamples=0,
                 batch_size=1,
                 length=None,
                 temperature=1,
                 top_k=0,
                 top_p=0.0,
                 penalize=0):
    """
    Run the sample_model
    :model_name=117M : String, which model to use
    :restore_from=None : Path handed to tflex.latest_checkpoint; defaults to
     models/<model_name>
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=0 : Number of samples to return, if 0, continues to
     generate samples indefinitely.
    :batch_size=1 : Number of batches (only affects speed/memory).
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be 0.85
     with temperature 0.3 and top_k 40.
    Raises ValueError if length exceeds the model window (hparams.n_ctx).
    """
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tflex.Session(graph=tf.Graph()) as sess:
        np.random.seed(seed)
        tf.set_random_seed(seed)

        # Sampling starts from <|endoftext|>; the first output column is
        # sliced off (presumably that start token) before decoding.
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            penalize=penalize)[:, 1:]

        saver = tflex.Saver()
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                # A batch may hold more rows than are still wanted; stop at
                # exactly nsamples so batch_size only affects speed/memory.
                if nsamples != 0 and generated >= nsamples:
                    break
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
Пример #8
0
def interact_model(model_name='117M',
                   asker=None,
                   responder=None,
                   restore_from=None,
                   seed=None,
                   length=None,
                   temperature=1,
                   top_k=0,
                   top_p=0.0,
                   penalize=0,
                   prompt=None):
    """
    Interactively chat with the model
    :model_name=117M : String, which model to use
    :asker=None : Required. Name (present in the training dataset) that the
     user chats as
    :responder=None : Required. Name (present in the training dataset) that
     the model chats as
    :restore_from=None : Path handed to tflex.latest_checkpoint; defaults to
     models/<model_name>
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :length=None : Number of tokens in generated text, if None (default), is
     half of the model's window size (n_ctx)
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :penalize=0.0 : Float value controlling "used" penalty. Implements repetition
     reduction (similar to CTRL) if set to a value > 0. A decent setting might be 0.85
     with temperature 0.3 and top_k 40.
    :prompt=None : Not used by this function
    Raises ValueError if asker or responder is missing, or if length exceeds
    the model window (hparams.n_ctx).
    """
    # ValueError is still caught by callers that catch Exception.
    if asker is None:
        raise ValueError(
            "Add a name present in the training dataset that you will be chatting as"
        )
    if responder is None:
        raise ValueError(
            "Add a name present in the training dataset that gpt will be chatting as"
        )

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tflex.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [1, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(hparams=hparams,
                                        length=length,
                                        context=context,
                                        batch_size=1,
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p,
                                        penalize=penalize)

        saver = tflex.Saver()
        if restore_from is None:
            restore_from = os.path.join('models', model_name)
        ckpt = tflex.latest_checkpoint(restore_from)
        saver.restore(sess, ckpt)

        # Prompt plus generated tokens have to fit in the model window, so
        # only this many of the most recent history tokens can be fed.
        max_context = hparams.n_ctx - length

        history = ''
        timestamp = 1924862493344
        while True:
            timestamp = increase_time(timestamp)
            history += f'({timestamp}) {asker}: ' + input(f"{asker}: ")
            timestamp = increase_time(timestamp)
            history += f'\n ({timestamp}) {responder}: '

            context_tokens = enc.encode(history)
            # Without this the ever-growing conversation eventually overflows
            # the window. When length == n_ctx there is no room to reserve,
            # so the tokens are passed through untouched as before.
            if max_context > 0:
                context_tokens = context_tokens[-max_context:]

            out = sess.run(output,
                           feed_dict={context: [context_tokens]
                                      })[:, len(context_tokens):]
            # Keep only the responder's turn: cut where the model starts
            # writing the asker's next line.
            text = enc.decode(out[0]).split(f') {asker}', 1)[0]
            # The printed reply also drops the trailing partial "(timestamp".
            print(f'\n ({timestamp}) {responder}: ' + text.rsplit('(', 1)[0])
            # NOTE(review): the history keeps that partial "(timestamp"
            # fragment while the printed reply does not - confirm intended.
            history += text
            sys.stdout.flush()