Example #1
def make_sampler(dataset, enc, seed, combine):
    # token datasets (directories or .npz dumps) go through load_dataset;
    # anything else is treated as raw text
    if os.path.isdir(dataset) or dataset.endswith('.npz'):
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks, seed=seed)
        print('dataset has', data_sampler.total_size, 'tokens in', len(chunks),
              'chunks')
    else:
        data_sampler = TextSampler(dataset, enc, seed=seed)
    return data_sampler
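
A minimal usage sketch for make_sampler (hypothetical wiring; it assumes encoder.get_encoder and the project's load_dataset/Sampler/TextSampler helpers are importable):

import encoder  # tokenizer module from the same project

enc = encoder.get_encoder('117M')
sampler = make_sampler('data/corpus.npz', enc, seed=42, combine=50000)
tokens = sampler.sample(1024)  # draw a 1024-token training chunk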
Example #2
def train_main(dataset,
               model_name='1250M',
               seed=None,
               msg=True,
               batch_size=16,
               learning_rate=0.00002,
               sample_length=512,
               sample_num=1,
               sample_every=100,
               run_name='run1',
               restore_from='latest',
               save_every=1000,
               combine=50000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
        print('n_ctx: ', hparams.n_ctx, 'n_head: ', hparams.n_head, 'n_embd: ',
              hparams.n_embd, 'n_layer: ', hparams.n_layer)

    if sample_length is None:
        sample_length = hparams.n_ctx
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    global_step = tf.Variable(0, trainable=False)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
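        # Next-token objective: logits at position t are scored against the
        # token at position t + 1, so the labels drop the first token and the
        # logits drop the last one in the cross-entropy below.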
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.9,
                                           top_k=40)


        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        # l4rz 11/10/2019: swapped the plain Adam optimizer for one with an
        # exponentially decayed learning rate
        decayed_lr = tf.train.exponential_decay(learning_rate,
                                                global_step,
                                                200,
                                                0.999,
                                                staircase=True)
        opt = tf.train.AdamOptimizer(decayed_lr)
        opt = hvd.DistributedOptimizer(opt)
        # the stock Horovod path would be opt.minimize(loss, var_list=train_vars);
        # this fork optionally routes through memory-saving gradients instead
        if msg:
            print('Using memory saving gradients')
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            train_op = opt.apply_gradients(opt_grads, global_step=global_step)
        else:
            print('Not using memory saving gradients')
            train_op = opt.minimize(loss,
                                    var_list=train_vars,
                                    global_step=global_step)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        print('Running hvd.broadcast_global_variables')
        bcast = hvd.broadcast_global_variables(0)
        print('Done')

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        print('Running global_variables_initializer')
        sess.run(tf.global_variables_initializer())
        print('Done')

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        # (For a brand-new model with no checkpoint at all, skip saver.restore
        # above and rely on the global_variables_initializer already run.)

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f} lr={lr:.2e}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1],
                                lr=decayed_lr.eval()))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
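
This variant assumes Horovod has been initialized before train_main is called; a minimal sketch of the surrounding boilerplate (the dataset path is illustrative):

import horovod.tensorflow as hvd

hvd.init()  # must run before hvd.rank()/hvd.local_rank() are queried
train_main(dataset='data/corpus.npz')

Each process then pins itself to a single GPU via config.gpu_options.visible_device_list = str(hvd.local_rank()), as done inside train_main above.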
Example #3
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if args.model_name == '355M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = ([v for v in all_vars if '/h' in v.name]
                      if args.only_train_transformer_layers else all_vars)

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
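            # Note: opt_apply here is expected to evaluate to the averaged
            # accumulated loss (it feeds the scalar summary below and is
            # fetched as v_loss in the training loop), not a bare train op.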
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text.encode('utf8'))
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
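
The randomize(context, hparams, args.noise) helper used above is not shown in this example. A plausible sketch (an assumption, not the repo's verbatim code) that corrupts each input token with probability noise:

import tensorflow as tf

def randomize(context, hparams, p):
    # with probability p, replace each token with a uniform random vocab id
    if p > 0:
        mask = tf.random.uniform(shape=tf.shape(context)) < p
        noise = tf.random.uniform(shape=tf.shape(context), minval=0,
                                  maxval=hparams.n_vocab, dtype=tf.int32)
        return tf.where(mask, noise, context)
    else:
        return context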
Example #4
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        if args.accumulate_gradients > 1:
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt_apply = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate).minimize(loss,
                                                           var_list=train_vars)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            print('uploading to Google Drive...')
            gdriveUp = os.popen(
                'cp -r /content/gpt-2/checkpoint/ /content/drive/My\\ Drive/' +
                GDRIVE_DIR).read()
            print(gdriveUp)

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
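
The save() hook above shells out via os.popen to mirror checkpoints into a mounted Google Drive folder (a common Colab pattern). A slightly more robust equivalent, sketched with subprocess so the path is passed as a real argument instead of being escaped by hand:

import subprocess

def upload_checkpoints(gdrive_dir):
    # copy the checkpoint tree into the mounted Drive directory
    subprocess.run(
        ['cp', '-r', '/content/gpt-2/checkpoint/',
         '/content/drive/My Drive/' + gdrive_dir],
        check=True)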
Example #5
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000,
               combine=50000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(
            os.path.join('chatbot_model', 'trained_models', model_name,
                         'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.8,
                                           top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        opt = tf.train.AdamOptimizer()
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('chatbot_model', 'trained_models',
                                 model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('chatbot_model', 'trained_models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1]))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
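
Unlike Example #2, this variant keeps Adam's default learning rate. With hvd.DistributedOptimizer the effective batch size grows with the number of workers, so the usual Horovod convention is to scale the base rate by hvd.size(); a one-line sketch (the 1e-4 base value is illustrative, not from the source):

opt = tf.train.AdamOptimizer(learning_rate=1e-4 * hvd.size())
opt = hvd.DistributedOptimizer(opt)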
Example #6
def main():
    args = parser.parse_args()
    try:
        logdir = os.path.join(CHECKPOINT_DIR, args.run_name)
        with open('logdir.txt', 'w') as z:
            z.write(logdir)
    except Exception:  # best-effort; training proceeds without logdir.txt
        pass
    # note: model_name (unlike args.model_name) is assumed to be defined at
    # module level in this variant
    enc = get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars
        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(model_name)
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(model_name)
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading train dataset...')
        from_name, ques_name, to_name = name_parts(args.dataset)

        trn_chunks_from = load_dataset(enc, from_name, args.combine)
        # the middle ("ques") stream is not loaded: only from/to are used
        trn_chunks_to = load_dataset(enc, to_name, args.combine)

        skip_delimeter = True
        char = '\t'
        trn_data_sampler_from = SamplerVal(trn_chunks_from,
                                           enc,
                                           char=char,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to,
                                         enc,
                                         char=char,
                                         skip_delimeter=skip_delimeter)

        len_v = 0
        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            # concatenate the source and target token streams for each pair;
            # the optional 'Q:'/'A:' separators from the original are disabled
            v = trn_data_sampler_from.get(i) + trn_data_sampler_to.get(i)
            v = v[:HIDDEN_SIZE - 1]
            len_v += len(v)
            data_sampler.append(v)

        if len_v < HIDDEN_SIZE:
            # dataset smaller than one context window: duplicate it until it
            # fits, then wrap it in a Sampler; otherwise data_sampler stays a
            # plain list of sequences
            mult = HIDDEN_SIZE // len_v + 1
            for i in range(mult):
                x = data_sampler[:]
                data_sampler.extend(x)
            data_sampler = Sampler([np.array(data_sampler)])

        if args.val_every > 0 and False:
            # deliberately disabled ('and False'): `chunks` is never defined in
            # this variant, so enabling this requires wiring up a dataset load
            val_chunks = load_dataset(
                enc, args.val_dataset,
                args.combine) if args.val_dataset else chunks
        if not isinstance(data_sampler, list):
            print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

            # copy the tokenizer files next to the checkpoint so the run
            # directory is self-contained
            cd = CHECKPOINT_DIR + "/" + args.run_name
            if not os.path.isfile(cd + '/' + 'encoder.json'):
                os.system('cp ' + model_name + '/encoder.json ' + cd + '/.')
                os.system('cp ' + model_name + '/vocab.bpe ' + cd + '/.')

        def generate_samples():
            print('Generating samples...')
            # seed generation with a random source-side example
            context_tokens = trn_data_sampler_from.get(
                random.randint(0, trn_data_sampler_from.total_size - 1))

            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            # draw one of the pre-built (source + target) sequences at random;
            # the original indexed random.randint(0, args.batch_size), which
            # could step past short datasets and ignored most of a long one
            return [data_sampler[random.randint(0, len(data_sampler) - 1)]]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while counter != args.stop_after:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('\ninterrupted')

        finally:
            save()
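
The name_parts() helper that derives the three parallel data files is not shown anywhere in this example. A hypothetical sketch consistent with how its results are consumed above (one source, one question, one target file per dataset prefix; the suffixes are guesses):

def name_parts(dataset):
    # hypothetical: 'data/train' -> 'data/train.from', 'data/train.ques',
    # 'data/train.to'
    return dataset + '.from', dataset + '.ques', dataset + '.to'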
Example #7
def main():
    
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.batch_size=args.batch_size
    hparams.seq_len=args.seq_len
    
    # data paths
    args.train_data_path=args.data_dir+args.dataset+'/train.txt'
    args.eval_data_path=args.data_dir+args.dataset+'/dev.txt'
    args.test_data_path=args.data_dir+args.dataset+'/test.txt'
    args.eval_data_path=args.test_data_path  # test mode only!
    args.gpt_save_path=args.gpt_save_dir+args.dataset+'/'
    args.dis_save_path=args.dis_save_dir+args.dataset+'/'
    
    args.gpt_sample_dir2=args.gpt_sample_dir+args.dataset+'/'
    args.dis_sample_dir2=args.dis_sample_dir+args.dataset+'/'
    
    args.log_path=args.log_dir+args.dataset+'/'
    maketree(args.gpt_save_dir)
    maketree(args.dis_save_dir)
    maketree(args.gpt_save_path)
    maketree(args.dis_save_path)
    maketree(args.gpt_sample_dir)
    maketree(args.dis_sample_dir)
    maketree(args.gpt_sample_dir2)
    maketree(args.dis_sample_dir2)
    
    maketree(args.log_dir)
    maketree(args.log_path)
    
    
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        scope_discri='distri'
        
        def get_dis_logit_and_prob_single_step(context, scope):
            with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
                context=tf.reshape(context, [-1, args.seq_len])
                emb=tf.get_variable(name='emb', initializer=tf.random.normal([hparams.n_vocab, 32], 0, 0.02))
                context_emb=tf.nn.embedding_lookup(emb, context)
                logit=dis(context_emb, scope=scope_discri)
                prob=tf.sigmoid(logit+1e-7)
            return logit, prob
        
        def get_dis_logit_and_prob(context, context_len, scope):
            # NOTE: the context_len handling below is flagged as a temporary
            # change in the original; review before reusing
            context_mask=(1-tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32))*1e3
            context_mask2=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32)
            ones=tf.ones(shape=[tf.shape(context_len)[0], args.seq_len], dtype=tf.int32)*enc.encoder['<|endoftext|>']
            input_tensor_list=[]
            for i in range(1, args.seq_len):
                input_tensor_list.append(tf.concat([context[:, :i+1], ones[:,i+1:]], axis=1))
            input_tensor=tf.concat(input_tensor_list, axis=0)
            log_prob, _=get_dis_logit_and_prob_single_step(input_tensor, scope=scope)
            log_prob=tf.transpose(tf.reshape(log_prob, [args.seq_len-1, -1]))
            log_prob+=tf.cast(context_mask, tf.float32)
            log_prob_min=tf.reduce_min(log_prob, axis=1)
            prob_min=tf.exp(log_prob_min)
            return log_prob_min, prob_min, log_prob
        ##Build discriminator
        
        def build_dis_layer(scope):
            context_pos_discri = tf.placeholder(tf.int32, [None, args.seq_len])
            context_pos_discri_len = tf.placeholder(tf.int32, [None])
            context_neg_discri = tf.placeholder(tf.int32, [None, args.seq_len])
            context_neg_discri_len = tf.placeholder(tf.int32, [None])
            
            label_pos_discri=tf.ones([tf.shape(context_pos_discri_len)[0]], dtype=tf.float32)
            label_neg_discri=tf.zeros([tf.shape(context_neg_discri_len)[0]], dtype=tf.float32)
            logit_pos_discri, prob_pos_discri, mask=get_dis_logit_and_prob(context_pos_discri, context_pos_discri_len, scope=scope)
            logit_neg_discri, _, _=get_dis_logit_and_prob(context_neg_discri, context_neg_discri_len, scope=scope)
        
            loss_pre_pos_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_pos_discri, logits=logit_pos_discri)
            loss_pos_discri=tf.reduce_mean(loss_pre_pos_discri)
            loss_pre_neg_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_neg_discri, logits=logit_neg_discri)
            loss_neg_discri=tf.reduce_mean(loss_pre_neg_discri)
            loss_discri=(loss_pos_discri*args.pos_loss_weight+loss_neg_discri)/(1+args.pos_loss_weight)
        
            train_var_list_discri=[x for x in tf.global_variables() if scope in x.name]
            train_op_discri=tf.train.AdamOptimizer().minimize(loss_discri, var_list=train_var_list_discri)
            # the same variable selection is reused for init and checkpointing
            var_list_discri=train_var_list_discri
            initializer_discri=tf.variables_initializer(var_list_discri)
            saver_discri=tf.train.Saver(var_list=var_list_discri, max_to_keep=1)
            print('discri: {} build succeed!'.format(scope))
            return (context_pos_discri, context_pos_discri_len,
                    context_neg_discri, context_neg_discri_len,
                    loss_pos_discri, loss_neg_discri, loss_discri,
                    train_op_discri, initializer_discri, saver_discri,
                    prob_pos_discri, mask, logit_pos_discri)
        
        class dis_class:
            def __init__(self, layer_num=1, scope=scope_discri):
                self.model=[]
                self.dis=np.zeros([layer_num], dtype=np.float32)
                print(layer_num)
                for i in range(layer_num):
                    layer={'scope': scope+str(i)}
                    (layer['context_pos_discri'], layer['context_pos_discri_len'],
                     layer['context_neg_discri'], layer['context_neg_discri_len'],
                     layer['loss_pos_discri'], layer['loss_neg_discri'],
                     layer['loss_discri'], layer['train_op_discri'],
                     layer['initializer_discri'], layer['saver_discri'],
                     layer['prob_pos_discri'], layer['mask'],
                     layer['logit_pos_discri']) = build_dis_layer(scope+str(i))
                    self.model.append(layer)
            def prob(self, context, context_len, layer=-1):
                if layer==-1:
                    layer=len(self.model)
                prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32)
                for i in range(layer):
                    item=self.model[i]
                    scope=item['scope']
                    _, prob, _=get_dis_logit_and_prob(context, context_len, scope=scope)
                    prob_final*=prob
                return prob_final
            def log_prob_step(self, context, layer=-1):
                if layer==-1:
                    layer=len(self.model)
                prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32)
                log_prob_list=[]
                for i in range(layer):
                    item=self.model[i]
                    scope=item['scope']
                    log_prob, prob=get_dis_logit_and_prob_single_step(context, scope=scope)
                    log_prob_list.append(tf.expand_dims(log_prob, 1))
                log_prob_final=tf.concat(log_prob_list, axis=1)
                return log_prob_final
        
        Dis=dis_class(layer_num=args.layer_num)
        
        context = tf.placeholder(tf.int32, [None, None])
        context_len=tf.placeholder(tf.int32, [None])
        context_mask=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32)
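        # context_mask zeroes positions past each sequence's true length, so
        # the loss below is a masked mean over valid next-token predictions
        # (the +1e-7 guards against division by an empty mask).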
        context_in=context
        output = model.model(hparams=hparams, X=context_in)
        loss_tensor = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1])*context_mask
        
        loss=tf.reduce_sum(loss_tensor, axis=1)/(tf.reduce_sum(context_mask, axis=1)+1e-7)
        loss_sen=tf.reduce_sum(loss)
        loss=tf.reduce_mean(loss)
        
        
        if args.val_every > 0:
            def transform_np(x, lift=args.exponential_param):
                x=x-0.5
                x=x+np.abs(x)
                return lift*x**2
            def transform(x, lift=args.exponential_param):
                x=x-0.5
                x=x+tf.abs(x)
                return lift*x**2
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, args.seq_len])
            # the original used args.batch_size for this shape; val_batch_size
            # matches val_context above
            val_context_len=tf.placeholder(tf.int32, [args.val_batch_size])
            NLL_bias=tf.placeholder(tf.float32, [])
            val_context_mask=tf.sequence_mask(val_context_len-1, args.seq_len-1, dtype=tf.float32)
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss_tensor =tf.nn.sparse_softmax_cross_entropy_with_logits(labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])*val_context_mask
            val_context_prob_cut=Dis.prob(val_context, val_context_len)
            val_NLL_cut=tf.log(val_context_prob_cut+1e-7)
            
            val_loss=tf.reduce_sum(val_loss_tensor, axis=1)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)
            val_loss_cut=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)-val_NLL_cut/tf.cast(val_context_len, tf.float32)
            
            val_loss_sum=tf.reduce_sum(val_loss_tensor, axis=1)
            val_loss_cut_sum=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)-val_NLL_cut
            
            val_loss_mean=tf.reduce_mean(val_loss)
            val_loss_cut_mean=tf.reduce_mean(val_loss_cut)
            val_loss_summary = tf.summary.scalar('val_loss', val_loss_mean)


        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.seq_len,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            start_token=enc.encoder['<|endoftext|>'])

        start_token=enc.encoder['<|endoftext|>']

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        data_list, data_len = load_dataset(enc, args.train_data_path, args.seq_len)
        data_sampler = Sampler(data_list, data_len )
        if args.val_every > 0:
            val_data_list, val_data_len = load_dataset(enc, args.eval_data_path, args.seq_len)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_data_list, val_data_len, seed=1)
            val_batches = [val_data_sampler.sample(args.batch_size) for _ in range(args.val_batch_count)]

        counter = 0
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
        
        
        def train_step_discri(layer_id=0, mask_train_epoch=0):
            pos_samples, pos_samples_len = data_sampler.sample(args.batch_size)
            neg_samples = generate_negative_sample(layer_id=layer_id)
            neg_samples_len = get_array_len(neg_samples)
            m = Dis.model[layer_id]
            _, loss = sess.run(
                [m['train_op_discri'], m['loss_discri']],
                feed_dict={m['context_pos_discri']: pos_samples,
                           m['context_pos_discri_len']: pos_samples_len,
                           m['context_neg_discri']: neg_samples,
                           m['context_neg_discri_len']: neg_samples_len})
            return loss
        
        def generate_negative_samples(layer_id, generate_num=args.batch_size):
            result_list = []
            generate_num_now = 0
            while generate_num_now < generate_num:
                sample_id = generate_negative_sample(layer_id=layer_id)
                selected_id_list = np.arange(len(sample_id))
                result_list.append(sample_id[selected_id_list])
                generate_num_now += len(selected_id_list)
            return np.concatenate(result_list, axis=0)[:generate_num]
        
        def get_array_len(sample_array):
            # Effective length of each sample: index of the first <|endoftext|>
            # after position 0, or the last index if none is found.
            lens = []
            for item in sample_array:
                i = len(item) - 1  # fallback so `i` is defined for length-1 items
                for i in range(1, len(item)):
                    if item[i] == enc.encoder['<|endoftext|>']:
                        break
                lens.append(i)
            return np.array(lens).astype(np.int32)
        
        def generate_discri_sample3(layer_id=-1, sample_size=10000,
                                    save_path='/mnt/cephfs_new_wj/mlnlp/miaoning/Experiment/gpt-2-sep/samples/discri/sample2.txt'):
            samples = []
            while len(samples) < sample_size:
                sample_id = generate_negative_sample(layer_id)
                for i in range(len(sample_id)):
                    # Keep the text between the first <|endoftext|> and the first newline.
                    sample_tem = enc.decode(sample_id[i]).split('<|endoftext|>')[1].split('\n')[0]
                    samples.append(sample_tem)
                print(len(samples))
            with open(save_path, 'w') as g:
                g.write('\n'.join(samples))
        
        
        def eval_discri_NLL(layer_id=0):
            m = Dis.model[layer_id]
            losses_pos = []
            losses_neg = []
            for batch in tqdm.tqdm(val_batches):
                pos_samples, pos_samples_len = batch
                neg_samples = generate_negative_sample(layer_id=layer_id)
                neg_samples_len = get_array_len(neg_samples)
                loss_pos, mask = sess.run(
                    [m['loss_pos_discri'], m['mask']],
                    feed_dict={m['context_pos_discri']: pos_samples,
                               m['context_pos_discri_len']: pos_samples_len})
                loss_neg = sess.run(
                    m['loss_neg_discri'],
                    feed_dict={m['context_neg_discri']: neg_samples,
                               m['context_neg_discri_len']: neg_samples_len})
                losses_pos.append(loss_pos)
                losses_neg.append(loss_neg)
            return np.mean(losses_pos), np.mean(losses_neg)
        
        def get_discri_quantile(layer_id=0, quantile=0.85):
            m = Dis.model[layer_id]
            logits_list = []
            for batch in tqdm.tqdm(val_batches):
                pos_samples, pos_samples_len = batch
                logits, mask = sess.run(
                    [m['logit_pos_discri'], m['mask']],
                    feed_dict={m['context_pos_discri']: pos_samples,
                               m['context_pos_discri_len']: pos_samples_len})
                print(np.min(mask, axis=1)[:10])
                print(logits[:10])
                with open('mask.pkl', 'wb') as g:
                    pkl.dump(mask, g)
                logits_list.extend(list(logits))
                break  # NOTE: only the first validation batch feeds the estimate
            with open('logits.pkl', 'wb') as g:
                pkl.dump(sorted(logits_list), g)
            print('finish')
            return sorted(logits_list)[int(len(logits_list) * (1 - quantile))]
        
        def train_discri(train_step, eval_every, train_layer_list=list(range(len(Dis.model)))):
            #sess.run(initializer_discri)
            print('Start Discri training')
            train_losses = []
            for layer_id in train_layer_list:
                flag = 0
                for epoch in range(train_step):
                    if epoch % eval_every == 0:
                        if train_losses:
                            print('layer_id:{} mean train loss:{}'.format(layer_id, np.mean(train_losses)))
                        train_losses = []

                        eval_NLL_pos, eval_NLL_neg = eval_discri_NLL(layer_id)
                        eval_loss = (eval_NLL_pos * args.pos_loss_weight + eval_NLL_neg) / (args.pos_loss_weight + 1)
                        print('layer_id:{} discri eval loss:{}'.format(layer_id, eval_loss))
                        print('layer_id:{} discri NLL pos: {}, discri NLL neg: {}'.format(layer_id, eval_NLL_pos, eval_NLL_neg))
                        print(epoch)
                        if epoch == 0:
                            eval_loss_old = eval_loss
                        else:
                            print(eval_loss, eval_loss_old)
                            if eval_loss < eval_loss_old:
                                eval_loss_old = eval_loss
                                save_path = args.dis_save_path + str(layer_id) + '/'
                                if not os.path.isdir(save_path):
                                    os.mkdir(save_path)
                                Dis.model[layer_id]['saver_discri'].save(sess, save_path + 'a')
                                print('model discri saved!')
                                flag = 0
                            else:
                                # Only count stale evaluations after a warm-up period.
                                if epoch >= 200:
                                    flag += 1
                            if flag >= 4:
                                # Early stopping: four evaluations without improvement.
                                break
                    train_loss = train_step_discri(layer_id)
                    print('layer_id:{} discri train loss:{}'.format(layer_id, train_loss))
                    train_losses.append(train_loss)
            return eval_loss_old
        
        tf_sample_0 = sample_link.sample_sequence(
            hparams=hparams,
            length=args.seq_len,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            start_token=enc.encoder['<|endoftext|>'])
        tf_sample_dict = {}
        
        def generate_negative_sample(layer_id=0):
            # Output the filtered result of layer `layer_id - 1`.
            if layer_id == 0:
                tf_sample = tf_sample_0
            else:
                if layer_id == -1:
                    layer_id = len(Dis.model)
                if layer_id in tf_sample_dict:
                    tf_sample = tf_sample_dict[layer_id]
                else:
                    tf_sample = sample_link.sample_sequence_ISMC_threshold(
                        Dis=Dis,
                        layer=layer_id,
                        hparams=hparams,
                        length=args.seq_len,
                        context=context,
                        batch_size=args.batch_size,
                        temperature=1.0,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        start_token=enc.encoder['<|endoftext|>'])
                    tf_sample_dict[layer_id] = tf_sample

            sample = data_sampler.sample(args.batch_size)[0][:, 0:1]
            out = sess.run(
                    tf_sample,
                    feed_dict={context: sample})[:, :args.seq_len]
            # Overwrite everything after the second <|endoftext|> with the start
            # token so trailing junk never reaches the discriminators.
            for i in range(len(out)):
                flag = 0
                for j in range(len(out[i])):
                    if flag == 2:
                        out[i][j] = start_token
                        continue
                    if out[i][j] == start_token:
                        flag += 1
            return out

        def validation():
            print('Calculating validation loss...')
            start_time = time.time()
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss_mean,
                                       feed_dict={val_context: batch[0],
                                                  val_context_len: batch[1]}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss_mean: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))
            return v_val_loss

        def validation_cut(NLL_bias_0=0):
            print('Calculating validation cut loss...')
            start_time = time.time()  # was missing; used in the status line below
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss_cut_mean,
                                       feed_dict={val_context: batch[0],
                                                  val_context_len: batch[1],
                                                  NLL_bias: NLL_bias_0}))
            v_val_loss = np.mean(losses)
            print(
                '[{counter} | {time:2.2f}] validation cut loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))
            return v_val_loss

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]
        
        def train_gpt():
            val_loss_old = 10000.0
            avg_loss = (0.0, 0.0)
            start_time = time.time()
            counter = 0
            while True:
                # Pretraining: checkpointing and sampling are disabled here.
                if counter % args.save_every == 0:
                    pass  # save()
                if counter % args.sample_every == 0:
                    pass  # generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    val_loss_1 = validation()
                    print(str(counter // args.val_every))
                    if val_loss_1 >= val_loss_old:
                        print('pre-training ends!')
                        break
                    else:
                        val_loss_old = val_loss_1
                        saver.save(sess, args.gpt_save_path + 'a')
                        print('save succeeded!')

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch, batch_len=data_sampler.sample(args.batch_size)
                        sess.run(
                            opt_compute, feed_dict={context: batch, context_len:batch_len})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch, batch_len=data_sampler.sample(args.batch_size)
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch, context_len:batch_len})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.9 + v_loss,
                            avg_loss[1] * 0.9 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        class log_writer:
            def __init__(self, path):
                self.path=path
                with open(path, 'w') as g:
                    g.write('')
            def __call__(self, string, verbose=False):
                with open(self.path, 'a') as g:
                    g.write(string+'\n')
                if verbose:
                    print(string)
        
        try:
            if args.finetune:
                #Finetune GPT-2
                train_gpt() 
            if True:
                #Restore Finetuned model
                save_path=tf.train.latest_checkpoint(args.gpt_save_path)
                saver.restore(sess, save_path)
                print('Load gpt2 succeeded!')
            if args.evaluate_finetune:
                # Evaluate the finetuning baseline
                print(validation())
                # Calculate reverse-ppl for the finetuning baseline
                sample_path = args.gpt_sample_dir2 + 'sample.txt'
                generate_discri_sample3(layer_id=0, sample_size=3000, save_path=sample_path)
                rev_ppl = train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path)
                Log_writer = log_writer(args.log_path + 'finetune')
                Log_writer('finetuning_rev_ppl: {}'.format(rev_ppl), verbose=True)
            ##Begin tailoring
            if True:
                Log_writer=log_writer(args.log_path+'discri')
                for layer in range(args.layer_num):
                    print(layer)
                    if args.train_tailor:
                        #Train ratio estimator
                        train_discri(500, 10, [layer])
                    if True:
                        #Restore ratio estimator
                        for layer_id in range(layer+1):
                            save_path=args.dis_save_path+str(layer_id)+'/'
                            print(save_path)
                            save_path=tf.train.latest_checkpoint(save_path)
                            print(save_path)
                            Dis.model[layer_id]['saver_discri'].restore(sess, save_path)
                    if False:
                        # Load previously saved quantiles for analysis (disabled)
                        with open(args.dis_sample_dir2+'quantile.pkl', 'rb') as f:
                            pkl.load(f)
                        print('Load dis model succeeded!')
                    if True:
                        if layer==0:
                            quantile=0.85
                        else:
                            quantile=0.9
                        Dis.dis[layer]=get_discri_quantile(layer, quantile)
                        with open(args.dis_sample_dir2+'quantile.pkl', 'wb') as g:
                            pkl.dump(Dis.dis, g)
                        print(Dis.dis)
                    if args.evaluate_tailor:
                        #Generate sample for ERS and calculate reverse-ppl
                        sample_path=args.dis_sample_dir2+'_sample_layer_'+str(layer)
                        generate_discri_sample3(layer_id=layer+1, sample_size=3000, save_path=sample_path)
                        rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path)
                        Log_writer('layer: {}, dis_rev_ppl: {}'.format(layer, rev_ppl), verbose=True)
        except KeyboardInterrupt:
            print('interrupted')
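The threshold returned by get_discri_quantile above is just an empirical quantile of the discriminator logits over real samples: sorting the scores and indexing at int(len * (1 - quantile)) leaves roughly a `quantile` fraction of them above the cutoff. A minimal standalone sketch of that idea, assuming plain NumPy (the names here are illustrative, not from the script):

import numpy as np

def empirical_threshold(scores, quantile=0.85):
    # Sort ascending and index at the (1 - quantile) position, so that
    # roughly `quantile` of the scores land at or above the returned value.
    ordered = sorted(scores)
    return ordered[int(len(ordered) * (1 - quantile))]

scores = np.random.randn(1000)
t = empirical_threshold(scores, 0.85)
print((scores >= t).mean())  # ~0.85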
Example #8
0
def train(dataset,
          model_in_path,
          model_out_path,
          model_name='117M',
          steps=1000,
          combine=50000,
          batch_size=1,
          learning_rate=0.00002,
          accumulate_gradients=1,
          memory_saving_gradients=False,
          only_train_transformer_layers=False,
          optimizer='adam',
          noise=0.0,
          top_k=40,
          top_p=0.0,
          restore_from='latest',
          sample_every=100,
          sample_length=1023,
          sample_num=1,
          save_every=1000,
          val_dataset=None):
    # Reset the TF computation graph
    tf.reset_default_graph()
    # Get the checkpoint and sample directories
    #checkpoint_dir = os.path.dirname(model_path)
    #sample_dir = checkpoint_dir
    #run_name = os.path.basename(model_path)
    # Load the encoder
    enc = get_encoder(model_in_path)
    hparams = model.default_hparams()
    with open(os.path.join(model_in_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # Size matters
    if model_name == '345M':
        memory_saving_gradients = True
        if optimizer == 'adam':
            only_train_transformer_layers = True

    # Configure TF
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    # Start the session
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        context_in = randomize(context, hparams, noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=1.0,
                                           top_k=top_k,
                                           top_p=top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if only_train_transformer_layers else all_vars

        if optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        elif optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        else:
            exit('Bad optimizer: ' + optimizer)

        if accumulate_gradients > 1:
            if memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            #os.path.join(checkpoint_dir, run_name)
            model_out_path)

        saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                #os.path.join(checkpoint_dir, run_name)
                model_in_path)
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    model_in_path)  #os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                model_in_path)  #os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')
        counter = 1
        counter_path = os.path.join(
            model_in_path,
            'counter')  #os.path.join(checkpoint_dir, run_name, 'counter')
        if restore_from == 'latest' and os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            #maketree(os.path.join(checkpoint_dir, run_name))
            maketree(model_out_path)
            print(
                'Saving',
                #os.path.join(checkpoint_dir, run_name, 'model-{}').format(counter)
                os.path.join(model_out_path, 'model-{}').format(counter))
            saver.save(
                sess,
                #os.path.join(checkpoint_dir, run_name, 'model'),
                os.path.join(model_out_path, 'model'),
                global_step=counter)
            with open(os.path.join(model_out_path, 'counter'), 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            #maketree(os.path.join(sample_dir, run_name))
            maketree(model_out_path)
            with open(
                    os.path.join(model_out_path, 'samples-{}').format(counter),
                    'w') as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        stop = steps + counter

        try:
            while counter < stop + 1:
                if counter % save_every == 0:
                    save()
                '''
                if counter % sample_every == 0:
                    generate_samples()
                '''

                if accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
            print('done!')
            save()
        except KeyboardInterrupt:
            print('interrupted')
            save()
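Because this variant wraps the whole fine-tuning loop in a single train() function, it can be driven directly from Python. A minimal, hypothetical invocation is sketched below; the paths are placeholders, not from the source:

# Hypothetical call; adjust the paths to your own layout.
train(dataset='data/corpus.txt',         # text file, directory, or .npz accepted by load_dataset
      model_in_path='models/117M',       # directory holding hparams.json and the base checkpoint
      model_out_path='checkpoint/run1',  # where new checkpoints and the counter file are written
      steps=1000,
      batch_size=1,
      save_every=500)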
Example #9
0
def main():
    args = parser.parse_args()
    folder_id = get_id(args.gdir)
    #xmpp = SendMsgBot(jid, password, to, "Starting GPT-2")
    #xmpp.register_plugin('xep_0030') # Service Discovery
    #xmpp.register_plugin('xep_0199') # XMPP Ping
    #xmpp.connect()
    #threading = Thread(target=xmpp.process, daemon=True).start()
    download_checkpoint(folder_id)
    #send_m('checkpoint downloaded')
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        # if args.optimizer == 'adam':
            # args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(os.path.join(CHECKPOINT_DIR, args.run_name))
        saver = tf.train.Saver(var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        #send_m('Loading  ' + str(ckpt))
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        #send_m('Loading dataset...')
        #chunks = load_dataset(enc, args.dataset, args.combine)
        ds_path = os.path.join(CHECKPOINT_DIR, 'run1', args.dataset)
        chunks = load_dataset(enc, ds_path, args.combine)
        data_sampler = Sampler(chunks)
        print(f'{ds_path} has', data_sampler.total_size, 'tokens')
        if args.val_every > 0:
            val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks
        if args.enc:
            print(colored(f'Writing Data.npz encoded from this dataset to {args.enc}', 'red'))
            np.savez_compressed(args.enc, *chunks)
            upload_npz(args.enc, folder_id)
        #send_m(f'{args.dataset} has ' + str(data_sampler.total_size) + ' tokens' + '     Start training...')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            save_gdisk(counter, folder_id)

        def generate_samples():
            print('Generating samples...')
            #send_m('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            #send_m(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                    .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()
        last_time = time.time()
        cur_counter, min_loss = 1, 2.0
        print(colored(f'Model  >>> {args.gdir}\nLearning rate is {args.learning_rate}', 'blue'))
        print(colored(f'model optimizer >>> {args.optimizer}\nRestricted to training only transformer layers={args.only_train_transformer_layers}', 'blue'))
        #send_m(f'Model  >>> {args.model_name}\nLearning rate is {args.learning_rate}')
        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                    if check_quota():
                        return  # exit train
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run((opt_apply, loss, summaries), feed_dict={context: sample_batch()})
                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                a_loss = avg_loss[0] / avg_loss[1]
                time_all = int((time.time() - start_time) / 60)
                time_iter = time.time() - last_time
                stats = f'[{counter} | {cur_counter} | {time_all}m | {time_iter:2.2f}s] loss={v_loss:2.2f} avg={a_loss:2.2f}'
                if not(cur_counter % 50):
                    print(colored(stats, 'red' if a_loss > min_loss else 'yellow'))
                    if a_loss < min_loss:
                        min_loss = a_loss
                    #send_m(stats)
                last_time = time.time()
                counter += 1
                cur_counter += 1
        except Exception as e:
            #send_m('Stopped  ' + str(e.__class__))
            print('Stopped', e.__class__)
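AccumulatingOptimizer is imported from elsewhere in these scripts, so its internals never appear here. The reset/compute/apply pattern used above suggests a wrapper that adds gradients into buffer variables across several micro-batches and then applies their average once. A rough TF1-style sketch of that pattern follows; this is an illustration under those assumptions, not the actual class (note apply_gradients returns the mean loss, matching how opt_apply is fed to tf.summary.scalar above):

import tensorflow as tf  # TF1-style graph API assumed

class AccumulatingOptimizerSketch:
    """Illustrative gradient-accumulation wrapper, not the real AccumulatingOptimizer."""

    def __init__(self, opt, var_list):
        self.opt = opt
        self.var_list = var_list
        self.accum = [tf.Variable(tf.zeros_like(v), trainable=False) for v in var_list]
        self.count = tf.Variable(0.0, trainable=False)
        self.total_loss = tf.Variable(0.0, trainable=False)

    def reset(self):
        # Zero the gradient buffers and counters before a new accumulation round.
        ops = [a.assign(tf.zeros_like(a)) for a in self.accum]
        ops += [self.count.assign(0.0), self.total_loss.assign(0.0)]
        return tf.group(*ops)

    def compute_gradients(self, loss):
        # Add this micro-batch's gradients (assumed non-None) into the buffers.
        grads = tf.gradients(loss, self.var_list)
        ops = [a.assign_add(g) for a, g in zip(self.accum, grads)]
        ops += [self.count.assign_add(1.0), self.total_loss.assign_add(loss)]
        return tf.group(*ops)

    def apply_gradients(self):
        # Apply the averaged gradients, then return the mean loss over the round.
        grads = [(a / self.count, v) for a, v in zip(self.accum, self.var_list)]
        apply_op = self.opt.apply_gradients(grads)
        with tf.control_dependencies([apply_op]):
            return self.total_loss / self.count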
Example #10
0
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session() as sess:
        # Fully static shape required to make memory accounting in
        # twremat accurate.
        train_context = tf.placeholder(tf.int32, [args.batch_size, 1024])
        train_context_in = randomize(train_context, hparams, args.noise)
        train_output = model.model(hparams=hparams, X=train_context_in)
        train_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=train_context[:, 1:],
                logits=train_output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        sample_context = tf.placeholder(tf.int32, [args.batch_size, None])
        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=sample_context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            print('Using Adam optimizer', file=sys.stderr)
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            print('Using SGD optimizer', file=sys.stderr)
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.memory_saving_gradients:
            if tf.VERSION >= '2':
                exit(
                    'Memory saving gradients are not supported in tensorflow 2.x'
                )
            import memory_saving_gradients
            opt_grads = memory_saving_gradients.gradients(
                train_loss, train_vars)
        elif args.twremat:
            import tfremat
            opt_grads = tf.gradients(train_loss, train_vars)
            (train_loss, opt_grads) = tfremat.tf_remat(
                (train_loss, opt_grads), memlimit=args.twremat_memlimit)
        else:
            opt_grads = tf.gradients(train_loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', train_loss)

        # if args.twremat:
        #     import tfremat
        #     # Applying tfremat to opt_apply has more accurate
        #     # accounting but is a bit iffier since side effecting ops
        #     # have more restrictions for correctness. If in doubt
        #     # revert back to version using opt_grads above.
        #     (opt_apply, train_loss, summary_loss) = (
        #         tfremat.tf_remat((opt_apply, train_loss, summary_loss), memlimit=args.twremat_memlimit))

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(tf_sample,
                               feed_dict={
                                   sample_context:
                                   args.batch_size * [context_tokens]
                               })
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        # print('Evaluating grads..')
        # tf2.profiler.experimental.start('logdir')
        # sess.run((opt_apply, train_loss, summaries), feed_dict={train_context: sample_batch()})
        # tf2.profiler.experimental.stop()
        # print('Succeeded')
        # exit()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, train_loss, summaries),
                    feed_dict={train_context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
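The avg value printed by the loop above is a bias-corrected exponential moving average: numerator and denominator both decay by 0.99 each step, so the ratio is a sensible average even in the first few iterations, when a plain EMA initialized at zero would be biased low. A self-contained sketch of the same bookkeeping:

# Bias-corrected EMA of the loss, as used in the training loops above.
avg_loss = (0.0, 0.0)
for v_loss in [4.0, 3.5, 3.2, 3.0]:  # stand-in loss values
    avg_loss = (avg_loss[0] * 0.99 + v_loss,
                avg_loss[1] * 0.99 + 1.0)
    print('avg={:2.2f}'.format(avg_loss[0] / avg_loss[1]))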
Example #11
0
def main():
    args = parser.parse_args()
    enc = get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join(args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    acc_total = 0
    acc_over_time = []
    loss_avg_over_time = []

    if args.val_every > 0:

        # val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
        val_context = tf.placeholder(np.int32, [1, None])

        val_output = model.model(hparams=hparams, X=val_context)
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:,
                                   1:], logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample_val = sample.sample_sequence(
            hparams=hparams,
            length=1,  #args.sample_length,
            context=val_context,
            batch_size=1,  #args.batch_size,
            temperature=10.001,
            top_k=1)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars
        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(args.model_name)
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(args.model_name)
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading train dataset...')
        from_name, ques_name, to_name = name_parts(
            args.dataset)  #'../data/train.from')

        trn_chunks_from = load_dataset(
            enc, from_name, args.combine) if args.val_dataset else chunks
        trn_chunks_ques = load_dataset(
            enc, ques_name, args.combine) if args.val_dataset else chunks
        trn_chunks_to = load_dataset(
            enc, to_name, args.combine) if args.val_dataset else chunks

        skip_delimeter = True
        trn_data_sampler_from = SamplerVal(trn_chunks_from,
                                           enc,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_ques = SamplerVal(trn_chunks_ques,
                                           enc,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to,
                                         enc,
                                         skip_delimeter=skip_delimeter)

        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            v = (
                trn_data_sampler_from.get(i) + trn_data_sampler_ques.get(i) +
                enc.encode('. ') + trn_data_sampler_to.get(i)  # +
                #enc.encode('<|endoftext|>')
            )
            # v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v) )]
            if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                continue
            v = v[:HIDDEN_SIZE - 1]
            data_sampler.append(v)

        #chunks = load_dataset(enc, args.dataset, args.combine)
        if not args.train_special:
            data_sampler = Sampler([np.array(data_sampler)])

        if args.val_every > 0:
            print('Loading validation dataset...')
            #val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks

            from_name, ques_name, to_name = name_parts(args.val_dataset)

            val_chunks_from = load_dataset(
                enc, from_name, args.combine) if args.val_dataset else trn_chunks_from
            val_chunks_ques = load_dataset(
                enc, ques_name, args.combine) if args.val_dataset else trn_chunks_ques
            val_chunks_to = load_dataset(
                enc, to_name, args.combine) if args.val_dataset else trn_chunks_to

        if not args.train_special:
            print('train dataset has', data_sampler.total_size, 'tokens')
        else:
            print('train dataset has', len(data_sampler), 'examples')
        print('Training...')

        if args.val_every > 0:

            val_data_sampler_from = SamplerVal(val_chunks_from, enc)
            val_data_sampler_ques = SamplerVal(val_chunks_ques, enc)
            val_data_sampler_to = SamplerVal(val_chunks_to, enc)

            if args.val_batch_count == -1:
                args.val_batch_count = val_data_sampler_from.total_size

            val_batches = []
            for i in range(args.val_batch_count):
                v = (val_data_sampler_from.get(i) +
                     val_data_sampler_ques.get(i) + enc.encode('. ')
                     )  #+ val_data_sampler_to.get(i)

                #v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v) )]
                if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                    continue
                v = v[:HIDDEN_SIZE]
                val_batches.append(v)

        print('val dataset has', len(val_batches), 'examples')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        txt_file_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                     args.run_name + '.summary.txt')

        def save_summary(message=None):
            if message is None:
                txt = ''
                fmt = '{valid:2.2f}'
                if not os.path.exists(txt_file_path):
                    a = vars(args)
                    txt += 'Summary for ' + args.run_name + '\n'
                    txt += str(datetime.datetime.now()) + '\n\n'
                    txt += json.dumps(a) + '\n'
                    txt += '-----\n'
                    pass
                txt += str(datetime.datetime.now()) + '\n'

                txt += 'acc: ' + ', '.join(
                    [fmt.format(valid=i) for i in acc_over_time]) + '\n'
                txt += 'loss: ' + ', '.join(
                    [fmt.format(valid=i) for i in loss_avg_over_time]) + '\n'
                txt += 'counter: ' + str(counter) + '\n'
                txt += 'time elapsed: ' + str(time.time() - start_time) + '\n'
                txt += '-----\n'
            else:
                txt = message
            print(txt)
            with open(txt_file_path, 'a') as f:
                f.write(txt + '\n')

        def save():
            if args.test:
                return

            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        '''
        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))
        '''

        def print_status(word=None,
                         acc_total_in=0,
                         size=0,
                         v_loss_in=0.0,
                         shorten=False):
            v_loss = v_loss_in
            acc_out = 0
            acc_total = 0
            loss_out = 0.0
            if word is None:
                word = 'progress'
            if acc_total_in != 0 and size != 0:
                acc_out = acc_total_in / size * 100
                acc_total = size
            if avg_loss[1] == 0.0 or avg_loss[0] == 0.0:
                loss_out = 0.0
                v_loss = 0.0
            elif not np.isnan(avg_loss[0]) and not np.isnan(avg_loss[1]):
                loss_out = avg_loss[0] / avg_loss[1]
            print(word + ' [' + args.run_name + ']' +
                  ' [{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_loss,
                         avg=loss_out),
                  'acc=' + str(acc_out),
                  end=' ')
            print('total=' + str(acc_total), end=' ')
            if len(acc_over_time) > 0 and not shorten:
                print('last-acc=' + str(acc_over_time[-1]))
            else:
                print()

        def sample_batch(counter=0, randomize=False, pad_start=False):
            #print(enc.encode('<|endoftext|>'), 'eot')
            #print(data_sampler.sample(1024))
            if not args.train_special:
                return [
                    data_sampler.sample(HIDDEN_SIZE)[0]
                    for _ in range(args.batch_size)
                ]
            else:
                num = 0
                z = []
                # Retry until the sample fits the window (at most 5 tries).
                while (len(z) == 0 or len(z[0]) > HIDDEN_SIZE) and num <= 5:
                    if randomize:
                        r = random.randint(1, 4)
                    else:
                        r = 0
                    #print('train special', r)
                    if pad_start:
                        pad = HIDDEN_SIZE - (len(data_sampler[counter]) - r)
                    else:
                        pad = 0

                    if randomize:
                        # Left-pad with space tokens and drop the last r tokens.
                        z = [[enc.encode(' ')[0]
                              for _ in range(pad)] + data_sampler[counter][:-r]
                             for _ in range(args.batch_size)]
                    else:
                        z = [data_sampler[counter]]
                    #print(enc.decode(z[0]))
                    num += 1
                    if num == 5:
                        # Give up: truncate each sequence to the window size.
                        z = [seq[len(seq) - HIDDEN_SIZE:] for seq in z]
                        print('cannot get sample_batch')
                        break

                return z

        def validation_by_sample():
            print('Generating validation...')
            global acc_total
            if args.val_with_loss:
                losses = []
                for batch in tqdm.tqdm(val_batches):
                    batch = np.reshape(batch, [1, -1])
                    v = sess.run(val_loss, feed_dict={val_context: batch})
                    #print(v, 'v')
                    losses.append(v)
                v_val_loss = np.mean(losses)
                v_summary = sess.run(val_loss_summary,
                                     feed_dict={val_loss: v_val_loss})
                summary_log.add_summary(v_summary, counter)
                summary_log.flush()
                print(
                    '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                    format(counter=counter,
                           time=time.time() - start_time,
                           loss=v_val_loss))
            acc_total = 0
            generated = 0
            for _ in range(len(val_batches)):

                val_batches_in = val_batches[generated]
                val_batches_in = val_batches_in[:1024]
                context_tokens = np.reshape(val_batches_in, [1, -1])
                #print(val_batches_in)
                text_in = enc.decode(val_batches_in)
                #print(text_in)

                #print(context_tokens, 'ct1')
                for x in range(GENERATE_SIZE):

                    out = sess.run(tf_sample_val,
                                   feed_dict={val_context: context_tokens})
                    #print(out[0][-x:])
                    #print(enc.decode(out[0][-x:]))
                    context_tokens = out

                compare = enc.decode(
                    val_data_sampler_to.get(generated))  # + ' <|endoftext|>'
                compare = ' '.join(compare.split())  # normalize whitespace

                generated += 1

                text = enc.decode(out[0])

                text_returned = ''
                text_original = ''
                if text.startswith(text_in):
                    text_returned = text[len(text_in):]
                    #print('-',text_returned,'-')

                if args.train_special:
                    text_original = text
                    text = text_returned

                if text.strip().endswith('.'):  ## remove trailing period
                    text = text.strip()[:-1]

                if text.strip().endswith('<|endoftext|>'):
                    text = text.strip()[:-len('<|endoftext|>')]

                t_vals = text.split(' ')
                if '<' in t_vals[-1] or '>' in t_vals[-1]:
                    t_vals = t_vals[:-1]

                num = 0
                while t_vals and t_vals[-1] == '' and num < 10:
                    t_vals = t_vals[:-1]
                    num += 1

                #print(t_vals)
                t_vals = [i for i in t_vals if i != '']
                #print(t_vals)

                text = ' '.join(t_vals)

                if compare.strip().endswith('.'):
                    compare = compare.strip()[:-1]

                if compare.strip().endswith('<|endoftext|>'):
                    compare = compare.strip()[:-len('<|endoftext|>')]

                notification = ''
                len_bar = 40
                if text.strip().lower().endswith(compare.strip().lower()):
                    acc_total += 1
                    notification = 'vv CORRECT vv'
                    len_bar = 40 - len(notification)
                elif text_returned.strip().lower().startswith(
                        compare.strip().lower()):
                    acc_total += 1
                    notification = 'vv CORRECT_INITIAL vv'
                    len_bar = 40 - len(notification)

                print(notification + "=" * len_bar + " SAMPLE " +
                      str(generated) + " " + "=" * len_bar + notification)
                if args.train_special:
                    print(text_original)
                else:
                    print(text)
                print_status('old values',
                             acc_total_in=acc_total,
                             size=generated)
            print("=" * 80)
            return acc_total

        avg_loss = (0.0, 0.0)
        start_time = time.time()
        count_success = 0
        count_success_with_skips = 0
        acc = 0.0
        acc_total = 0

        try:
            if args.test:
                v_loss = 0.0
                dataset = re.sub('train', 'test', args.dataset)
                print(dataset)
                from_name, ques_name, to_name = name_parts(dataset)

                test_chunks_from = load_dataset(enc, from_name, args.combine)
                test_chunks_ques = load_dataset(enc, ques_name, args.combine)
                test_chunks_to = load_dataset(enc, to_name, args.combine)

                val_data_sampler_from = SamplerVal(test_chunks_from, enc)
                val_data_sampler_ques = SamplerVal(test_chunks_ques, enc)
                val_data_sampler_to = SamplerVal(test_chunks_to, enc)

                if args.val_batch_count == -1:
                    args.val_batch_count = val_data_sampler_from.total_size

                val_batches = []
                for i in range(args.val_batch_count):
                    v = (val_data_sampler_from.get(i) +
                         val_data_sampler_ques.get(i) + enc.encode('. ')
                         )  # + val_data_sampler_to.get(i)

                    # v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v) )]
                    if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                        continue
                    val_batches.append(v)

                acc_total = validation_by_sample()
                acc = acc_total / len(val_batches) * 100

                print(acc, 'test accuracy')
                save_summary('Accuracy with test set ' + str(acc) + '\n')
                exit()

            while counter != args.stop_after:
                #model_summary()

                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    #generate_samples()
                    pass
                if args.val_every > 0 and (counter % args.val_every
                                           == 0):  # or counter == 1):
                    acc_total = validation_by_sample()
                    acc = acc_total / len(val_batches) * 100

                    acc_over_time.append(acc)
                    if avg_loss[1] > 0.0:
                        loss_avg_over_time.append(avg_loss[0] / avg_loss[1])
                    else:
                        loss_avg_over_time.append(0)

                # val_batches only exists when validation is enabled.
                counter_in = counter % len(val_batches) if args.val_every > 0 else 0
                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch(counter_in)})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch(counter_in)})

                summary_log.add_summary(v_summary, counter)

                #if True:
                if not np.isnan(avg_loss[0]) and not np.isnan(avg_loss[1]):

                    avg_loss = (avg_loss[0] * 0.99 + v_loss,
                                avg_loss[1] * 0.99 + 1.0)

                if args.val_every > 0 and counter % args.val_every == 1:
                    if float(acc) == 100.0:
                        #save()

                        print('validation accuracy 100',
                              time.time() - start_time)
                        count_success += 1
                        count_success_with_skips += 1
                        if count_success >= 2 or count_success_with_skips >= 4:
                            #save_summary()

                            exit()
                    else:
                        count_success = 0

                print_status(acc_total_in=acc_total,
                             size=len(val_batches),
                             v_loss_in=v_loss,
                             shorten=True)

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
        finally:
            save()
            save_summary()
            print('save weights/summary and exit.')
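Two patterns from the loop above, distilled into standalone sketches. First, the correctness test in validation_by_sample() counts a sample as correct when the generated text ends with the expected answer, or when the raw continuation starts with it, case-insensitively. A minimal sketch with illustrative strings (is_correct and its arguments are names invented here, not from the original):

# Minimal sketch of the correctness test used in validation_by_sample().
def is_correct(generated, continuation, expected):
    expected = expected.strip().lower()
    if generated.strip().lower().endswith(expected):
        return True  # the 'vv CORRECT vv' case
    if continuation.strip().lower().startswith(expected):
        return True  # the 'vv CORRECT_INITIAL vv' case
    return False

print(is_correct('The capital is Paris', 'Paris.', 'paris'))  # True

Second, the avg_loss pair is an exponentially decayed running average: the first element is a decayed sum of losses, the second a decayed count, and their ratio is the smoothed value printed as avg. A self-contained sketch with illustrative loss values:

# Sketch of the decayed running average kept in avg_loss above.
DECAY = 0.99  # the 0.99 factor from the training loop

def update_avg(avg, v_loss, decay=DECAY):
    # Decayed sum of losses and decayed count; their ratio is the average.
    return (avg[0] * decay + v_loss, avg[1] * decay + 1.0)

avg = (0.0, 0.0)
for v_loss in [4.0, 3.5, 3.2, 3.0]:
    avg = update_avg(avg, v_loss)
    print('avg = {:2.2f}'.format(avg[0] / avg[1]))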
Example #12
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # Create the directory tree if it does not already exist.
        os.makedirs(path, exist_ok=True)

    maketree(checkpoint_path)
    files = os.listdir(checkpoint_path)
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        # Larger models need memory-saving tricks to fit on a single GPU.
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        val_context = tf.compat.v1.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context,
                                 reuse=True)  # share weights with the training graph
        val_loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:],
                logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        raise ValueError('Unknown optimizer: {}'.format(optimizer))

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # AccumulatingOptimizer.apply_gradients() returns the mean accumulated
        # loss, so it doubles as the loss tensor for the summary.
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks

    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[
            val_data_sampler.sample(1024) for _ in range(val_batch_size)
        ] for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        v_summary = sess.run(val_loss_summary,
                             feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.format(
            counter=counter, time=time.time() - start_time, loss=v_val_loss))
        return v_val_loss

    def sample_batch():
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    # Trying out a change to finetune that saves only when validation loss decreases.
    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                #save()
                return
            # if (counter - 1) % save_every == 0 and counter > 1:
            #     save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            # validation code
            if val_every > 0 and counter == 1:
                v_val_loss = validation()
                save()
            elif val_every > 0 and counter == counter_base:
                v_val_loss = validation()
            elif val_every > 0 and (counter % val_every == 0):
                new_v_val_loss = validation()
                if new_v_val_loss < v_val_loss:
                    v_val_loss = new_v_val_loss
                    save()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
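As noted in the docstring, a hypothetical call to finetune() might look like the following. The corpus path, run name, and session setup are illustrative assumptions, not part of the example:

# Hypothetical usage sketch for finetune() above; file names are placeholders.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # finetune() builds a TF1-style graph
sess = tf.compat.v1.Session()
finetune(sess,
         'corpus.txt',            # plain-text training data
         model_name='124M',       # must be downloaded first via download_gpt2()
         steps=1000,              # stop after 1000 optimizer steps
         run_name='run1',
         val_dataset='val.txt',   # optional held-out set for checkpoint gating
         val_every=100,           # validate every 100 steps
         save_every=500)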
Example #13
0
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '774M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        #this line is to hopefully reduce memory usage (found on Twitter: https://twitter.com/BasedBlue/status/1169601983046672385?s=20)
        edgeindex = -1 * args.layers_to_train
        train_vars = all_vars[edgeindex:]
        print("Training", args.layers_to_train, "raw layers out of", len(all_vars))

        train_vars = [v for v in train_vars if '/h' in v.name] if args.only_train_transformer_layers else train_vars
        print("Training", len(train_vars), "net layers out of", len(all_vars))

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'adafactor':
            opt = AdafactorOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            ret = [data_sampler.sample(1024) for _ in range(args.batch_size)]
            # print (enc.decode(ret[0]))
            return ret


        avg_loss = (0.0, 0.0)
        bval_loss = (0.0, 0.0)
        start_time = time.time()
        best_val_loss = 99
        missed_val_checkpoints = 0

        try:
            while counter < args.stop_after:
                if counter % args.sample_every == 0:
                    generate_samples()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.98 + v_loss,
                            avg_loss[1] * 0.98 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                if args.val_every > 0 and counter % args.val_every == 0:
                    valbatch = [val_data_sampler.sample(1024) for _ in range(args.batch_size)]
                    valacc = sess.run(loss, feed_dict={context: valbatch})
                    bval_loss = (bval_loss[0] * 0.9 + valacc, bval_loss[1] * 0.9 + 1.0)
                    av_val_loss = bval_loss[0] / bval_loss[1]
                    av_train_loss = avg_loss[0] / avg_loss[1]
                    print(
                        '[{counter} | {time:2.2f}] VAL_loss={loss:2.4f} VAL_avg={avg:2.4f} best={best:2.4f}'
                        .format(
                            counter=counter,
                            time=time.time() - start_time,
                            loss=valacc,
                            avg=av_val_loss,
                            best=best_val_loss))
                    # Check for a validation checkpoint every save_every iterations.
                    if counter >= args.save_every and counter % args.save_every == 0:
                        if av_val_loss < best_val_loss and av_val_loss > av_train_loss:
                            # Validation improved: save a checkpoint -- but never
                            # before val loss has risen above train loss.
                            save()
                            best_val_loss = av_val_loss
                            missed_val_checkpoints = 0
                        elif av_val_loss > av_train_loss:
                            # Missed a validation checkpoint; misses are only
                            # counted once val loss exceeds training loss.
                            missed_val_checkpoints += 1
                    if missed_val_checkpoints > 19:
                        # Too many missed save opportunities: stop training.
                        counter = args.stop_after + 1
                        print('stopping training due to val loss not improving.')

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')