def main():
    args = parser.parse_args()
    enc = encoder_sp.get_encoder(args.model_name)
    print('Reading files')
    chunks = load_dataset(enc,
                          args.in_text,
                          args.combine,
                          encoding=args.encoding)
    print('Writing', args.out_npz)
    np.savez_compressed(args.out_npz, *chunks)
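# Distributed (Horovod) fine-tuning loop: Adam with an exponentially decayed
# learning rate and optional memory-saving gradients for larger models.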
def train_main(dataset,
               model_name='1250M',
               seed=None,
               msg=True,
               batch_size=16,
               learning_rate=0.00002,
               sample_length=512,
               sample_num=1,
               sample_every=100,
               run_name='run1',
               restore_from='latest',
               save_every=1000,
               combine=50000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
        print('n_ctx: ', hparams.n_ctx, 'n_head: ', hparams.n_head, 'n_embd: ',
              hparams.n_embd, 'n_layer: ', hparams.n_layer)

    if sample_length is None:
        sample_length = hparams.n_ctx
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    #device_map = { 0:2, 0:3, 1:2, 1:3 }
    #config.gpu_options.visible_device_list = str(device_map[hvd.rank()])
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    global_step = tf.Variable(0, trainable=False)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.9,
                                           top_k=40)

        #global_step = tf.Variable(0, trainable=False)
        counter = 1

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        #opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # l4rz 11/10/2019
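        # Exponential LR decay: the rate is multiplied by 0.999 every 200
        # global steps (staircase), i.e. roughly a 0.5% drop per 1000 steps.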
        decayed_lr = tf.train.exponential_decay(learning_rate,
                                                global_step,
                                                200,
                                                0.999,
                                                staircase=True)
        opt = tf.train.AdamOptimizer(decayed_lr)
        #opt = tf.train.GradientDescentOptimizer(decayed_lr)
        opt = hvd.DistributedOptimizer(opt)
        # this is original horovod
        #train_op = opt.minimize(loss, var_list=train_vars)
        # this is ours
        if msg:
            print('Using memory saving gradients')
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            train_op = opt.apply_gradients(opt_grads, global_step=global_step)
        else:
            print('Not using memory saving gradients')
            #train_op = opt.minimize(loss, var_list=train_vars)
            # l4rz 11/10
            train_op = opt.minimize(loss,
                                    var_list=train_vars,
                                    global_step=global_step)
        # [1,2]<stderr>:TypeError: apply_gradients() missing 1 required positional argument: 'grads_and_vars'
        #summary_loss = tf.summary.scalar('loss', train_op)

        #_, lv = sess.run((train_op, loss), feed_dict={context: batch})

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        print('Running hvd.broadcast_global_variables')
        bcast = hvd.broadcast_global_variables(0)
        print('Done')

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
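        # var_list=train_vars means only the model weights are checkpointed;
        # Adam moment estimates and global_step are not persisted.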

        print('Running global_variables_initializer')
        sess.run(tf.global_variables_initializer())
        print('Done')

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        # For a first run with no existing checkpoint, comment out the two
        # restore lines below and uncomment the initializer further down.
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        # uncomment when running for first time INIT THE MODEL
        #print('tf.global_variables_initializer()')
        #sess.run(tf.global_variables_initializer())

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
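        # (decayed sum, decayed count): their ratio below is an exponentially
        # weighted moving average of the loss with decay factor 0.99.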
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f} lr={lr:.2e}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1],
                                lr=decayed_lr.eval()))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
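# Unconditional sampling from a stored model, seeded with the <|endoftext|>
# token; prints samples until nsamples is reached (or indefinitely if 0).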
def sample_model(
    model_name='1250M',
    seed=None,
    nsamples=0,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=0.0
):
    """
    Run the sample_model
    :model_name=1250M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=0 : Number of samples to return, if 0, continues to
     generate samples indefinitely.
    :batch_size=1 : Number of batches (only affects speed/memory).
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in Boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    """
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        np.random.seed(seed)
        tf.set_random_seed(seed)

        output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
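# A second, simpler Horovod fine-tuning variant: it uses the encoder_sp
# (presumably SentencePiece) tokenizer and plain Adam, without learning-rate
# decay or memory-saving gradients.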
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000,
               combine=50000):

    enc = encoder_sp.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=sample_length,
            context=context,
            batch_size=batch_size,
            temperature=0.8,
            top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        opt = tf.train.AdamOptimizer()
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)

        sess.run(tf.global_variables_initializer())


        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size, 'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample, feed_dict={context: batch_size*[context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(
                            counter=counter,
                            time=time.time() - start_time,
                            loss=lv,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
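# Interactive sampling: reads a prompt from stdin and prints nsamples
# completions conditioned on it.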
def interact_model(
    model_name='1250M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=0.8,
    top_k=40,
    run_name='run1',
):
    """
    Interactively run the model
    :model_name=1250M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=0.8 : Float value controlling randomness in Boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=40 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 is a special
     setting meaning no restrictions. 40 (the default here) generally is a
     good value.
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(hparams=hparams,
                                        length=length,
                                        context=context,
                                        batch_size=batch_size,
                                        temperature=temperature,
                                        top_k=top_k)

        saver = tf.train.Saver()
        #ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name, 'checkpoint/%s' % run_name))
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output,
                               feed_dict={
                                   context:
                                   [context_tokens for _ in range(batch_size)]
                               })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " +
                          "=" * 40)
                    print(text)
            print("=" * 80)
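# Single-process fine-tuning entry point: optional gradient accumulation,
# memory-saving gradients, periodic validation, TensorBoard summaries, and
# checkpoint/sample management.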
def main():
    args = parser.parse_args()
    enc = encoder_sp.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True
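    # Presumably the overrides above keep the larger 345M model within GPU
    # memory: memory-saving gradients always, and with Adam only the
    # transformer blocks are trained.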

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        if args.only_train_transformer_layers:
            train_vars = [v for v in all_vars if '/h' in v.name]
        else:
            train_vars = all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
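            # Gradient accumulation: reset the accumulators, run several
            # micro-batches through compute_gradients, then apply_gradients()
            # applies the combined update and returns the loss value that
            # feeds the summary below.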
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
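# Paraphrase generation for the Strong-Paraphrase-Generation-2020 corpus: for
# each source document, paragraphs are grouped by dominant LDA topic, split
# into sentence groups, and each group is rewritten by sampling continuations
# from the model; the candidate with the lowest ROUGE-L overlap against the
# source is kept, provided LDA topic similarity stays above a threshold.
# Helpers such as safe_list_get(), cossim(), Rouge, LdaModel and
# LDA_MODEL_NAME are assumed to be imported/defined elsewhere in this file.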
def generate_with_model(
        model_name='117M',
        seed=None,
        nsamples=1,
        batch_size=1,
        length=None,
        temperature=1,
        top_k=0,
        top_p=0.0
):
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder_sp.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)
        root = etree.parse(
            r'C:\Users\kiva0319\IdeaProjects\hrdmd1803\Strong-Paraphrase-Generation-2020\raw_data\paraphrases.xml')

        root = root.getroot()
        corpus = etree.SubElement(root, "corpus")
        topic_save_percentage = 0.7
        number_generated_examples_at_stage = 10
        minimum_acceptable_similarity = 0.8
        max_intersection_similarity = 0.6
        indexs = list(range(len(root[1])))
        random.Random(3).shuffle(indexs)
        lda_model = LdaModel.load(LDA_MODEL_NAME)
        id2word = lda_model.id2word
        rouge = Rouge()

        for i in tqdm(indexs):
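            # Each corpus element is expected to carry, in order: pair id, the
            # two document ids, both titles, a Jaccard score and a paraphrase
            # class (see the element[...] accesses below).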
            element = root[1][i]
            id = element[0].text
            id_1 = element[1].text
            id_2 = element[2].text
            title_1 = element[3].text
            title_2 = element[4].text
            jaccard = element[5].text
            clas = element[6].text
            text_1 = "none"
            with open(
                    "C:/Users/kiva0319/IdeaProjects/hrdmd1803/Strong-Paraphrase-Generation-2020/download/v1/" + id_1 + ".txt",
                    'r', encoding="utf-8") as file:
                text = file.read()
                if len(text) < 50:
                    print("bad file id =", id_1)
                    continue
                text_1 = text
            paragraphs = text_1.split("\n\n")

            topic_to_paragraphs = {i: [] for i in range(200)}
            topic_met_num = {i: 0 for i in range(200)}
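            # The 200 above presumably matches the topic count of the loaded
            # LDA model. Each paragraph is tagged with its top-3 topics below;
            # safe_list_get() is assumed to return the topic id at position j
            # (or a negative sentinel when the row is shorter).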

            for paragraph_num, paragraph in enumerate(paragraphs):
                corpus = [id2word.doc2bow(paragraph.split(" "))]
                row = lda_model[corpus][0]
                row = sorted(row, key=lambda x: (x[1]), reverse=True)
                for j in range(3):
                    topic_num = safe_list_get(row, j)
                    if topic_num >= 0:
                        topic_met_num[topic_num] += 1
                        topic_to_paragraphs[topic_num].append(paragraph_num)
            topic_count = 0
            main_topic = 0
            selected_paragraphs = set()
            paragraphs_grouped_by_topic = []

            topic_limit = int(float(len(topic_met_num)) * topic_save_percentage)
            for topic_num in sorted(topic_met_num, key=topic_met_num.get, reverse=True):
                if topic_count == 0:
                    main_topic = topic_num
                if topic_count > topic_limit:
                    break

                paragraphs_grouped_by_topic.append([])
                for tn in topic_to_paragraphs[topic_num]:
                    if tn not in selected_paragraphs:
                        paragraphs_grouped_by_topic[topic_count].append(paragraphs[tn])
                        selected_paragraphs.add(tn)
                topic_count += 1

            text_by_topic = []

            for topic_paragraphs in paragraphs_grouped_by_topic:
                text_by_topic.append(" ".join(topic_paragraphs))
            sentnces_group_by_topic = []

            for text in text_by_topic:
                ordinary_sentences = []
                sentnces_group = []
                sentnces = text.strip().split(". ")

                # Sentences containing quotes, four-digit years, capitalized
                # names, or question marks each form their own group; the
                # remaining "ordinary" sentences are pooled into one group.
                for sentence in sentnces:
                    if len(re.findall(r'\"(.+?)\"', sentence)) > 0:
                        sentnces_group.append([sentence])
                    elif len(re.findall(r'[1-3][0-9]{3}', sentence)) > 0:
                        sentnces_group.append([sentence])
                    elif len(re.findall(r"[A-Z][a-z]+", sentence[1:])) > 0:
                        sentnces_group.append([sentence])
                    elif len(re.findall(r'[?]', sentence)) > 0:
                        sentnces_group.append([sentence])
                    else:
                        ordinary_sentences.append(sentence)
                sentnces_group.append(ordinary_sentences)
                text_groups = []

                for group in sentnces_group:
                    text_groups.append(". ".join(group))
                sentnces_group_by_topic.append(text_groups)
            result = []
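            # For every sentence group: sample continuations repeatedly and
            # keep the candidate with the lowest ROUGE-L F-score against the
            # source (i.e. the most rewritten), stopping early once LDA topic
            # similarity falls below minimum_acceptable_similarity; only
            # candidates whose ROUGE-1 overlap stays under
            # max_intersection_similarity are collected.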

            for text_groups in sentnces_group_by_topic:
                for raw_text in text_groups:
                    context_tokens = enc.encode(raw_text)
                    samples = []
                    min_rl = 2
                    best_sample = raw_text
                    for s_num in range(number_generated_examples_at_stage):
                        vec_1 = lda_model[id2word.doc2bow(raw_text.split(" "))]
                        vec_2 = lda_model[id2word.doc2bow(best_sample.split(" "))]
                        if (cossim(vec_1, vec_2)) < minimum_acceptable_similarity:
                            break
                        for _ in range(nsamples // batch_size):
                            out = sess.run(output, feed_dict={
                                context: [context_tokens for _ in range(batch_size)]
                            })[:, len(context_tokens):]
                            for i in range(batch_size):
                                text = enc.decode(out[i])
                                samples.append(text)
                        for candidate in samples:
                            sc = rouge.get_scores(raw_text, candidate)[0]
                            r = sc['rouge-l']['f']
                            if r < min_rl:
                                min_rl = r
                                best_sample = candidate
                    if rouge.get_scores(raw_text, best_sample)[0]['rouge-1']['f'] < max_intersection_similarity:
                        result.append(best_sample)
            random.shuffle(result)
            # print(" ".join(result))
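            # Note: returning here ends the outer document loop, so only the
            # first successfully processed document produces output.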
            return " ".join(result)