Пример #1
0
def main():
    """Fine-tune a model from command-line arguments until interrupted.

    Builds the training graph (plus a separate validation graph when
    ``args.val_every > 0``), restores weights from a checkpoint, and then
    loops forever running optimizer steps. Every ``args.save_every`` steps a
    checkpoint is written, every ``args.sample_every`` steps generated text
    samples are written under ``SAMPLE_DIR``, and every ``args.val_every``
    steps (and at step 1) the validation loss is logged. A
    ``KeyboardInterrupt`` ends the loop and triggers one final save.

    Raises:
        ValueError: if ``args.sample_length`` exceeds ``hparams.n_ctx``.
        SystemExit: if ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``,
            or if gradient accumulation is combined with memory-saving
            gradients.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # The 355M model always uses memory-saving gradients; with Adam it also
    # restricts training to the transformer layers.
    if args.model_name == '355M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        # NOTE(review): presumably perturbs the input tokens according to
        # args.noise -- confirm against randomize().
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        # Next-token objective: logits at position t are scored against the
        # (un-randomized) token at position t + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        # all_vars is what gets saved/restored; train_vars is the (possibly
        # smaller) subset the optimizer updates.
        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            # exit() accepts a single argument, so the optimizer name has to
            # be formatted into the message rather than passed separately.
            exit('Bad optimizer: {}'.format(args.optimizer))

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            # The value produced by opt_apply is what gets logged and printed
            # as the loss in the accumulation path.
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        # Pick the checkpoint to restore: this run's latest, the pristine
        # model weights, or an explicit directory.
        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                # No separate validation set: validate on the training data.
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint for the current step and persist the step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate args.sample_num samples and write them under SAMPLE_DIR."""
            print('Generating samples...')
            # The same context is replicated across the whole batch.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Only the last generated sample is echoed to stdout (as bytes).
            print(text.encode('utf8'))
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average val_loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value in place of the val_loss tensor so the
            # summary records the mean over all batches.
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            """Return args.batch_size training samples of 1024 tokens each."""
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # (decayed loss sum, decayed weight sum); their ratio is an
        # exponentially weighted running average with decay 0.99.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
Пример #2
0
def train_main(dataset,
               model_name='1250M',
               seed=None,
               msg=True,
               batch_size=16,
               learning_rate=0.00002,
               sample_length=512,
               sample_num=1,
               sample_every=100,
               run_name='run1',
               restore_from='latest',
               save_every=1000,
               combine=50000):
    """Fine-tune a model with Horovod-distributed Adam until interrupted.

    Builds the training graph, restores a checkpoint, broadcasts the variables
    from rank 0, and then loops forever running training steps. Only the
    process with ``hvd.rank() == 0`` saves checkpoints, writes samples and
    prints progress. A ``KeyboardInterrupt`` ends the loop; rank 0 then saves
    once more.

    Args:
        dataset: passed through to ``load_dataset``.
        model_name: sub-directory of ``models`` holding ``hparams.json`` and
            the initial checkpoint.
        seed: seed given to ``np.random.seed`` and ``tf.set_random_seed``.
        msg: when truthy, gradients come from
            ``memory_saving_gradients.gradients``; otherwise ``opt.minimize``
            is used.
        batch_size: number of rows in the training/sampling placeholder.
        learning_rate: initial learning rate; it is decayed by a factor of
            0.999 every 200 global steps (staircase).
        sample_length: number of tokens to generate per sample; ``None``
            means ``hparams.n_ctx``.
        sample_num: number of samples written each time samples are generated.
        sample_every: generate samples every this many steps.
        run_name: sub-directory name used under ``CHECKPOINT_DIR`` and
            ``SAMPLE_DIR``.
        restore_from: ``'latest'``, ``'fresh'``, or a checkpoint directory.
        save_every: save a checkpoint every this many steps.
        combine: passed through to ``load_dataset``.

    Raises:
        ValueError: if ``sample_length`` exceeds ``hparams.n_ctx``.
    """

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
        print('n_ctx: ', hparams.n_ctx, 'n_head: ', hparams.n_head, 'n_embd: ',
              hparams.n_embd, 'n_layer: ', hparams.n_layer)

    if sample_length is None:
        sample_length = hparams.n_ctx
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    #device_map = { 0:2, 0:3, 1:2, 1:3 }
    #config.gpu_options.visible_device_list = str(device_map[hvd.rank()])
    # Expose only the GPU whose index equals this process's local rank.
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    # Step counter incremented by the optimizer; it drives the learning-rate
    # decay schedule below. This is separate from the Python-side `counter`.
    global_step = tf.Variable(0, trainable=False)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        # Next-token objective: logits at position t are scored against the
        # token at position t + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.9,
                                           top_k=40)

        #global_step = tf.Variable(0, trainable=False)
        counter = 1

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        #opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # l4rz 11/10/2019
        # Multiply the learning rate by 0.999 once every 200 global steps.
        decayed_lr = tf.train.exponential_decay(learning_rate,
                                                global_step,
                                                200,
                                                0.999,
                                                staircase=True)
        opt = tf.train.AdamOptimizer(decayed_lr)
        #opt = tf.train.GradientDescentOptimizer(decayed_lr)
        opt = hvd.DistributedOptimizer(opt)
        # this is original horovod
        #train_op = opt.minimize(loss, var_list=train_vars)
        # this is ours
        if (msg):
            print('Using memory saving gradients')
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            train_op = opt.apply_gradients(opt_grads, global_step=global_step)
        else:
            print('Not using memory saving gradients')
            #train_op = opt.minimize(loss, var_list=train_vars)
            # l4rz 11/10
            train_op = opt.minimize(loss,
                                    var_list=train_vars,
                                    global_step=global_step)
        # [1,2]<stderr>:TypeError: apply_gradients() missing 1 required positional argument: 'grads_and_vars'
        #summary_loss = tf.summary.scalar('loss', train_op)

        #_, lv = sess.run((train_op, loss), feed_dict={context: batch})

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        # Only the op is created here; it is executed by bcast.run() after the
        # checkpoint has been restored.
        print('Running hvd.broadcast_global_variables')
        bcast = hvd.broadcast_global_variables(0)
        print('Done')

        # Only the model's trainable variables are saved and restored.
        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        print('Running global_variables_initializer')
        sess.run(tf.global_variables_initializer())
        print('Done')

        # Pick the checkpoint to restore: this run's latest, the pristine
        # model weights, or an explicit directory.
        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        # comment out when running for 1st time
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        # uncomment when running for first time INIT THE MODEL
        #print('tf.global_variables_initializer()')
        #sess.run(tf.global_variables_initializer())

        # Run the broadcast op created above, now that weights are restored.
        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint for the current step and persist the step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate sample_num samples and write them under SAMPLE_DIR."""
            # The same context is replicated across the whole batch.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Only the last generated sample is echoed to stdout.
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # (decayed loss sum, decayed weight sum); their ratio is an
        # exponentially weighted running average with decay 0.99.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                # Checkpointing, sampling and progress output happen on
                # rank 0 only; every rank still runs the training step above.
                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f} lr={lr:.2e}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1],
                                lr=decayed_lr.eval()))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
Пример #3
0
def main():
    """Fine-tune a model on paired "from"/"to" datasets from CLI arguments.

    Builds the training graph (plus a validation graph when
    ``args.val_every > 0``), restores a checkpoint, concatenates each "from"
    entry with its matching "to" entry into one training example, and then
    runs optimizer steps until the step counter equals ``args.stop_after`` or
    the user interrupts. A checkpoint is always written on the way out.

    NOTE(review): ``model_name`` is read as a free variable here (it is not
    taken from ``args``); presumably a module-level global -- confirm.

    Raises:
        ValueError: if ``args.sample_length`` exceeds ``hparams.n_ctx``.
        SystemExit: if gradient accumulation is combined with memory-saving
            gradients.
    """
    args = parser.parse_args()
    # Best effort: record the log directory for external tooling. Only
    # filesystem errors are ignored, so programming errors still surface.
    try:
        logdir = os.path.join(CHECKPOINT_DIR, args.run_name)
        with open('logdir.txt', 'w') as z:
            z.write(logdir)
    except OSError:
        pass
    enc = get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # The 345M model always uses memory-saving gradients and trains only the
    # transformer layers.
    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        # Next-token objective: logits at position t are scored against the
        # token at position t + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=40)

        # all_vars is what gets saved/restored; train_vars is the (possibly
        # smaller) subset the optimizer updates.
        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars
        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            # The value produced by opt_apply is what gets logged and printed
            # as the loss in the accumulation path.
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        # Pick the checkpoint to restore: this run's latest, the pristine
        # model weights, or an explicit directory.
        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading train dataset...')
        # ques_name is unpacked but not used in this function.
        from_name, ques_name, to_name = name_parts(args.dataset)

        trn_chunks_from = load_dataset(enc, from_name, args.combine)
        trn_chunks_to = load_dataset(enc, to_name, args.combine)

        skip_delimeter = True
        char = '\t'
        trn_data_sampler_from = SamplerVal(trn_chunks_from,
                                           enc,
                                           char=char,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to,
                                         enc,
                                         char=char,
                                         skip_delimeter=skip_delimeter)

        # Build one training example per pair: "from" tokens followed by the
        # matching "to" tokens, truncated to HIDDEN_SIZE - 1 tokens.
        len_v = 0
        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            v = trn_data_sampler_from.get(i) + trn_data_sampler_to.get(i)
            v = v[:HIDDEN_SIZE - 1]
            len_v += len(v)
            data_sampler.append(v)

        if len_v < HIDDEN_SIZE:
            # Too few tokens overall: repeat the examples until the total is
            # at least HIDDEN_SIZE. The snapshot is taken once, so each pass
            # appends exactly one copy of the original examples; re-copying
            # the growing list inside the loop would double it every pass.
            mult = HIDDEN_SIZE // len_v + 1
            base_examples = data_sampler[:]
            for _ in range(mult):
                data_sampler.extend(base_examples)
            data_sampler = Sampler([np.array(data_sampler)])

        # NOTE(review): validation-set loading is disabled by the `and False`
        # below, so `val_chunks` is never assigned, and `chunks` is not
        # defined in this function either. Running with args.val_every > 0
        # therefore fails with NameError when the validation sampler is built.
        if args.val_every > 0 and False:
            val_chunks = load_dataset(
                enc, args.val_dataset,
                args.combine) if args.val_dataset else chunks
        if not isinstance(data_sampler, list):
            print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint, persist the step counter, copy tokenizer files."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

            # Copy encoder.json and vocab.bpe next to the checkpoint the first
            # time a save happens (skipped once encoder.json is present).
            GPT2_DIR_X = model_name
            cd = CHECKPOINT_DIR + "/" + args.run_name
            if not os.path.isfile(cd + '/' + 'encoder.json'):
                os.system("cp " + GPT2_DIR_X + '/' + 'encoder.json ' + cd +
                          '/.')
                os.system('cp ' + GPT2_DIR_X + "/" + 'vocab.bpe ' + cd + '/.')

        def generate_samples():
            """Generate args.sample_num samples and write them under SAMPLE_DIR."""
            print('Generating samples...')
            # Use a random "from" entry as the prompt. randrange excludes the
            # upper bound, keeping the index within the 0..total_size-1 range
            # iterated when the training examples were built; randint would
            # also return total_size itself.
            context_tokens = trn_data_sampler_from.get(
                random.randrange(trn_data_sampler_from.total_size))

            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Only the last generated sample is echoed to stdout.
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average val_loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value in place of the val_loss tensor so the
            # summary records the mean over all batches.
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            """Return a single-row batch holding one random training example.

            The result always has exactly one row, so it only matches the
            ``context`` placeholder when ``args.batch_size`` is 1.
            """
            if isinstance(data_sampler, list):
                # Draw from the whole example list. Bounding the index by the
                # batch size would restrict training to the first few
                # examples and could index past the end of a short list.
                idx = random.randrange(len(data_sampler))
            else:
                # data_sampler was replaced by a Sampler for tiny datasets;
                # its length API is not visible here, so the original
                # indexing is kept unchanged. NOTE(review): confirm Sampler
                # supports indexing.
                idx = random.randint(0, args.batch_size)
            return [data_sampler[idx]]

        # (decayed loss sum, decayed weight sum); their ratio is an
        # exponentially weighted running average with decay 0.99.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while counter != args.stop_after:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('\ninterrupted')

        finally:
            # Always checkpoint on the way out, whether the loop finished,
            # was interrupted, or raised.
            save()
Пример #4
0
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000,
               combine=50000):
    """Fine-tune a GPT-2 model on `dataset` with a Horovod-wrapped Adam optimizer.

    The function builds the language-model graph, restores weights from a
    checkpoint, then trains in an endless loop until a KeyboardInterrupt is
    received. Only the process with ``hvd.rank() == 0`` saves checkpoints,
    writes samples and prints progress.

    Args:
        dataset: Forwarded to ``load_dataset`` together with the encoder and
            ``combine``.
        model_name: Sub-directory of ``chatbot_model/trained_models`` that
            holds ``hparams.json`` and the pretrained checkpoint.
        seed: Seed given to both ``np.random.seed`` and ``tf.set_random_seed``.
        batch_size: Number of sequences per training step (and per sampling
            call).
        sample_length: ``length`` passed to ``sample.sample_sequence``.
            ``None`` selects ``hparams.n_ctx // 2``.
        sample_num: Number of samples collected per ``generate_samples`` call.
        sample_every: Generate samples whenever ``counter`` is a multiple of
            this value.
        run_name: Sub-directory of ``CHECKPOINT_DIR`` / ``SAMPLE_DIR`` used
            for this run.
        restore_from: ``'latest'`` (run checkpoint, falling back to the
            pretrained model), ``'fresh'`` (pretrained model), or a directory
            handed to ``tf.train.latest_checkpoint``.
        save_every: Save a checkpoint whenever ``counter`` is a multiple of
            this value.
        combine: Forwarded to ``load_dataset``; its meaning is defined there.

    Raises:
        ValueError: If ``sample_length`` exceeds ``hparams.n_ctx``.
    """

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    # Overlay the default hyper-parameters with the ones stored next to the
    # pretrained model.
    with open(
            os.path.join('chatbot_model', 'trained_models', model_name,
                         'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    # Restrict this process to the GPU matching its Horovod local rank.
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        # Next-token loss: the logits at position t are scored against the
        # token at position t + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.8,
                                           top_k=40)

        # Only variables whose name contains 'model' are trained and saved.
        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        opt = tf.train.AdamOptimizer()
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        sess.run(tf.global_variables_initializer())

        # Pick the checkpoint to restore according to `restore_from`.
        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('chatbot_model', 'trained_models',
                                 model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('chatbot_model', 'trained_models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        # NOTE(review): `ckpt` is not checked for None before restoring;
        # confirm a checkpoint always exists in the selected directory.
        saver.restore(sess, ckpt)

        # Push the restored weights from rank 0 to every other worker.
        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write a checkpoint for the current step and persist the step
            # number in the 'counter' file so a later run can resume from it.
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate `sample_num` samples, all conditioned on the same
            # context drawn from the dataset, and write them to
            # SAMPLE_DIR/run_name/samples-<counter>.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                # Keep at most the number of samples still missing.
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # NOTE(review): only the last generated sample is printed, and
            # `text` is unbound here when sample_num is 0.
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # (decayed sum of losses, decayed count); their ratio is the running
        # average printed below.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            # Train until interrupted; there is no other exit condition.
            while True:

                # Each training sequence is sampled with a fixed length of 1024.
                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                # Checkpointing, sampling and logging happen on rank 0 only.
                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1]))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            # Save a final checkpoint from rank 0 before leaving.
            if hvd.rank() == 0:
                save()
Пример #5
0
def main():
    
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.batch_size=args.batch_size
    hparams.seq_len=args.seq_len
    
    ##data_path
    args.train_data_path=args.data_dir+args.dataset+'/train.txt'
    args.eval_data_path=args.data_dir+args.dataset+'/dev.txt'
    args.test_data_path=args.data_dir+args.dataset+'/test.txt'
    args.eval_data_path=args.test_data_path                          ###Test mode only!
    args.gpt_save_path=args.gpt_save_dir+args.dataset+'/'
    args.dis_save_path=args.dis_save_dir+args.dataset+'/'
    
    args.gpt_sample_dir2=args.gpt_sample_dir+args.dataset+'/'
    args.dis_sample_dir2=args.dis_sample_dir+args.dataset+'/'
    
    args.log_path=args.log_dir+args.dataset+'/'
    maketree(args.gpt_save_dir)
    maketree(args.dis_save_dir)
    maketree(args.gpt_save_path)
    maketree(args.dis_save_path)
    maketree(args.gpt_sample_dir)
    maketree(args.dis_sample_dir)
    maketree(args.gpt_sample_dir2)
    maketree(args.dis_sample_dir2)
    
    maketree(args.log_dir)
    maketree(args.log_path)
    
    
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        scope_discri='distri'
        
        def get_dis_logit_and_prob_single_step(context, scope):
            """Run one discriminator on `context` and return ``(logit, prob)``.

            `context` is reshaped to rows of ``args.seq_len`` token ids, embedded
            through the 'emb' table of variable scope `scope` (created on first
            use, reused afterwards) and passed to ``dis``. ``prob`` is the
            sigmoid of the logit shifted by 1e-7.
            """
            with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
                token_ids = tf.reshape(context, [-1, args.seq_len])
                embedding_table = tf.get_variable(
                    name='emb',
                    initializer=tf.random.normal([hparams.n_vocab, 32], 0, 0.02))
                embedded = tf.nn.embedding_lookup(embedding_table, token_ids)
                logit = dis(embedded, scope=scope_discri)
                prob = tf.sigmoid(logit + 1e-7)
                return logit, prob
        
        def get_dis_logit_and_prob(context, context_len, scope):
            # Score every prefix of each sequence with the discriminator in
            # `scope` and reduce the per-prefix scores to their minimum.
            #
            # Returns (min score per sequence, exp of that min, per-prefix
            # scores after masking). The values named `log_prob` here are the
            # first return value (`logit`) of
            # get_dis_logit_and_prob_single_step.
            ##Pay attention to context_len here. temporary changes!!!!!!!!!!!!!!!!!!!
            # 0 for the first context_len-1 positions, 1e3 for the rest; added
            # to the scores below so the remaining positions cannot be the min.
            context_mask=(1-tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32))*1e3
            # NOTE(review): context_mask2 is never used after this line.
            context_mask2=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32)
            # Despite the name, `ones` is filled with the <|endoftext|> token id.
            ones=tf.ones(shape=[tf.shape(context_len)[0], args.seq_len], dtype=tf.int32)*enc.encoder['<|endoftext|>']
            input_tensor_list=[]
            # Build one copy of the batch per prefix: the first i+1 tokens are
            # kept, the remainder is replaced by <|endoftext|>.
            for i in range(1, args.seq_len):
                input_tensor_list.append(tf.concat([context[:, :i+1], ones[:,i+1:]], axis=1))
            # Stack all prefix copies along the batch axis for a single pass.
            input_tensor=tf.concat(input_tensor_list, axis=0)
            log_prob, _=get_dis_logit_and_prob_single_step(input_tensor, scope=scope)
            # Reshape to (seq_len-1, -1), then transpose so each row holds the
            # scores of one sequence.
            log_prob=tf.transpose(tf.reshape(log_prob, [args.seq_len-1, -1]))
            log_prob+=tf.cast(context_mask, tf.float32)
            log_prob_min=tf.reduce_min(log_prob, axis=1)
            prob_min=tf.exp(log_prob_min)
            return log_prob_min, prob_min, log_prob
        ##Build discriminator
        
        def build_dis_layer(scope):
            # Build the training graph of one discriminator layer living in
            # variable scope `scope`, and return its placeholders, losses,
            # train op, initializer, saver and output tensors as a 13-tuple
            # (see the return statement for the order).

            # Inputs: token ids plus lengths for positive and negative batches.
            context_pos_discri = tf.placeholder(tf.int32, [None, args.seq_len])
            context_pos_discri_len = tf.placeholder(tf.int32, [None])
            context_neg_discri = tf.placeholder(tf.int32, [None, args.seq_len])
            context_neg_discri_len = tf.placeholder(tf.int32, [None])
            
            # Targets: 1 for every positive sample, 0 for every negative one.
            label_pos_discri=tf.ones([tf.shape(context_pos_discri_len)[0]], dtype=tf.float32)
            label_neg_discri=tf.zeros([tf.shape(context_neg_discri_len)[0]], dtype=tf.float32)
            # `mask` receives the third return value of get_dis_logit_and_prob,
            # i.e. the per-prefix scores after masking, not a 0/1 mask.
            logit_pos_discri, prob_pos_discri, mask=get_dis_logit_and_prob(context_pos_discri, context_pos_discri_len, scope=scope)
            logit_neg_discri, _, _=get_dis_logit_and_prob(context_neg_discri, context_neg_discri_len, scope=scope)
        
            loss_pre_pos_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_pos_discri, logits=logit_pos_discri)
            loss_pos_discri=tf.reduce_mean(loss_pre_pos_discri)
            loss_pre_neg_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_neg_discri, logits=logit_neg_discri)
            loss_neg_discri=tf.reduce_mean(loss_pre_neg_discri)
            # Weighted average of both losses; args.pos_loss_weight scales the
            # positive term.
            loss_discri=(loss_pos_discri*args.pos_loss_weight+loss_neg_discri)/(1+args.pos_loss_weight)
        
            # Variables are selected by substring match of `scope` in the name.
            train_var_list_discri=[x for x in tf.global_variables() if scope in  x.name]
            train_op_discri=tf.train.AdamOptimizer().minimize(loss_discri, var_list=train_var_list_discri)
            # The same filter is evaluated again AFTER minimize() was called.
            # NOTE(review): presumably this is meant to also capture variables
            # the optimizer created under this scope - confirm before merging
            # the two lists.
            var_list_discri=[x for x in tf.global_variables() if scope in  x.name]
            initializer_discri=tf.variables_initializer(var_list_discri)
            saver_discri=tf.train.Saver(var_list=var_list_discri, max_to_keep=1)
            print('discri: {} build succeed!'.format(scope))
            return context_pos_discri,context_pos_discri_len, context_neg_discri,context_neg_discri_len, loss_pos_discri, loss_neg_discri, loss_discri, train_op_discri, initializer_discri, saver_discri, prob_pos_discri, mask, logit_pos_discri
        
        class dis_class:
            """Container for a stack of discriminator layers.

            Each entry of ``self.model`` is a dict holding the layer's variable
            scope name under 'scope' plus every tensor/op returned by
            ``build_dis_layer`` under the key of the same name.
            """

            def __init__(self, layer_num=1, scope=scope_discri):
                """Build `layer_num` layers with scopes ``scope + '0'``, ``scope + '1'``, ..."""
                self.model=[]
                self.dis=np.zeros([layer_num], dtype=np.float32)
                print(layer_num)
                for i in range(layer_num):
                    layer_scope=scope+str(i)
                    layer={'scope': layer_scope}
                    # The order below must match build_dis_layer's return tuple.
                    (layer['context_pos_discri'],
                     layer['context_pos_discri_len'],
                     layer['context_neg_discri'],
                     layer['context_neg_discri_len'],
                     layer['loss_pos_discri'],
                     layer['loss_neg_discri'],
                     layer['loss_discri'],
                     layer['train_op_discri'],
                     layer['initializer_discri'],
                     layer['saver_discri'],
                     layer['prob_pos_discri'],
                     layer['mask'],
                     layer['logit_pos_discri'])=build_dis_layer(layer_scope)
                    self.model.append(layer)

            def prob(self, context, context_len, layer=-1):
                """Return the product of the first `layer` layers' probabilities.

                ``layer=-1`` means "all layers". Each factor is the second
                return value of ``get_dis_logit_and_prob`` for that layer's scope.
                """
                if layer==-1:
                    layer=len(self.model)
                prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32)
                for item in self.model[:layer]:
                    _, layer_prob, _=get_dis_logit_and_prob(context, context_len, scope=item['scope'])
                    prob_final*=layer_prob
                return prob_final

            def log_prob_step(self, context, layer=-1):
                """Return the single-step scores of the first `layer` layers.

                ``layer=-1`` means "all layers". The first return value of
                ``get_dis_logit_and_prob_single_step`` is taken for every layer,
                expanded on axis 1 and concatenated along that axis.
                """
                if layer==-1:
                    layer=len(self.model)
                # The original also created a tf.ones tensor and bound the
                # per-layer probability here; neither was ever used.
                log_prob_list=[]
                for item in self.model[:layer]:
                    log_prob, _=get_dis_logit_and_prob_single_step(context, scope=item['scope'])
                    log_prob_list.append(tf.expand_dims(log_prob, 1))
                log_prob_final=tf.concat(log_prob_list, axis=1)
                return log_prob_final
        
        Dis=dis_class(layer_num=args.layer_num)
        
        context = tf.placeholder(tf.int32, [None, None])
        context_len=tf.placeholder(tf.int32, [None])
        context_mask=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32)
        context_in=context
        output = model.model(hparams=hparams, X=context_in)
        loss_tensor = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1])*context_mask
        
        loss=tf.reduce_sum(loss_tensor, axis=1)/(tf.reduce_sum(context_mask, axis=1)+1e-7)
        loss_sen=tf.reduce_sum(loss)
        loss=tf.reduce_mean(loss)
        
        
        if args.val_every > 0:
            def transform_np(x, lift=args.exponential_param):
                """NumPy version: ``lift * ((x - 0.5) + |x - 0.5|) ** 2``.

                The result is zero wherever ``x <= 0.5``.
                """
                centered = x - 0.5
                doubled_positive_part = centered + np.abs(centered)
                return lift * doubled_positive_part ** 2
            def transform(x, lift=args.exponential_param):
                """TensorFlow version: ``lift * ((x - 0.5) + |x - 0.5|) ** 2``.

                The result is zero wherever ``x <= 0.5``.
                """
                centered = x - 0.5
                doubled_positive_part = centered + tf.abs(centered)
                return lift * doubled_positive_part ** 2
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, args.seq_len])
            val_context_len=tf.placeholder(tf.int32, [args.batch_size])
            NLL_bias=tf.placeholder(tf.float32, [])
            val_context_mask=tf.sequence_mask(val_context_len-1, args.seq_len-1, dtype=tf.float32)
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss_tensor =tf.nn.sparse_softmax_cross_entropy_with_logits(labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])*val_context_mask
            val_context_prob_cut=Dis.prob(val_context, val_context_len)
            val_NLL_cut=tf.log(val_context_prob_cut+1e-7)
            
            val_loss=tf.reduce_sum(val_loss_tensor, axis=1)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)
            val_loss_cut=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)-val_NLL_cut/tf.cast(val_context_len, tf.float32)
            
            val_loss_sum=tf.reduce_sum(val_loss_tensor, axis=1)
            val_loss_cut_sum=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)-val_NLL_cut
            
            val_loss_mean=tf.reduce_mean(val_loss)
            val_loss_cut_mean=tf.reduce_mean(val_loss_cut)
            val_loss_summary = tf.summary.scalar('val_loss', val_loss_mean)


        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.seq_len,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            start_token=enc.encoder['<|endoftext|>'])

        start_token=enc.encoder['<|endoftext|>']

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer:', args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        data_list, data_len = load_dataset(enc, args.train_data_path, args.seq_len)
        data_sampler = Sampler(data_list, data_len )
        if args.val_every > 0:
            val_data_list, val_data_len = load_dataset(enc, args.eval_data_path, args.seq_len)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_data_list, val_data_len, seed=1)
            val_batches = [val_data_sampler.sample(args.batch_size) for _ in range(args.val_batch_count)]

        counter = 0
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Checkpoint the GPT variables and record the current step.

            Writes ``model-<counter>`` under ``CHECKPOINT_DIR/<run_name>`` and
            stores the step number in the file at ``counter_path``.
            """
            run_dir = os.path.join(CHECKPOINT_DIR, args.run_name)
            maketree(run_dir)
            print('Saving', os.path.join(run_dir, 'model-{}').format(counter))
            checkpoint_prefix = os.path.join(run_dir, 'model')
            saver.save(sess, checkpoint_prefix, global_step=counter)
            with open(counter_path, 'w') as counter_file:
                counter_file.write(str(counter) + '\n')
        
        
        def train_step_discri(layer_id=0, mask_train_epoch=0):
            """Run one optimisation step of discriminator `layer_id`; return its loss.

            Positives are drawn from the training sampler, negatives are
            generated by ``generate_negative_sample``. `mask_train_epoch` is
            accepted but not used.
            """
            pos_samples, pos_samples_len = data_sampler.sample(args.batch_size)
            neg_samples = generate_negative_sample(layer_id=layer_id)
            neg_samples_len = get_array_len(neg_samples)

            layer = Dis.model[layer_id]
            feed = {
                layer['context_pos_discri']: pos_samples,
                layer['context_pos_discri_len']: pos_samples_len,
                layer['context_neg_discri']: neg_samples,
                layer['context_neg_discri_len']: neg_samples_len,
            }
            fetched = sess.run(
                [layer['train_op_discri'], layer['loss_discri']],
                feed_dict=feed)
            return fetched[1]
        
        def generate_negative_samples(layer_id, generate_num=args.batch_size):
            """Collect `generate_num` negative samples for layer `layer_id`.

            Calls ``generate_negative_sample`` repeatedly, concatenates the
            batches along axis 0 and truncates to exactly `generate_num` rows.

            Raises:
                RuntimeError: If a generated batch is empty (the loop could
                    otherwise never finish).
            """
            batches = []
            collected = 0
            while collected < generate_num:
                batch = generate_negative_sample(layer_id=layer_id)
                if len(batch) == 0:
                    raise RuntimeError(
                        'generate_negative_sample returned an empty batch')
                batches.append(batch)
                collected += len(batch)
            # np.concatenate always returns a fresh array, so no extra copy of
            # the individual batches is needed before truncating.
            return np.concatenate(batches, axis=0)[:generate_num]
        
        def get_array_len(sample_array):
            """Return, per row, the index of the first <|endoftext|> after position 0.

            Position 0 is skipped. When no such token is found the last index
            of the row is returned, as before. Rows with fewer than two tokens
            yield 0; previously they reused the index of the preceding row, or
            raised NameError when they came first.

            Returns:
                An int32 NumPy array with one entry per row of `sample_array`.
            """
            end_token = enc.encoder['<|endoftext|>']
            lens = []
            for item in sample_array:
                # Fallback when the row has no end token after position 0.
                end = max(len(item) - 1, 0)
                for pos in range(1, len(item)):
                    if item[pos] == end_token:
                        end = pos
                        break
                lens.append(end)
            return np.array(lens).astype(np.int32)
        
        def generate_discri_sample3(layer_id=-1, sample_size=10000, save_path='/mnt/cephfs_new_wj/mlnlp/miaoning/Experiment/gpt-2-sep/samples/discri/sample2.txt'):
            """Generate at least `sample_size` text samples and write them to `save_path`.

            Every generated sequence is decoded; the text between the first
            and second '<|endoftext|>' marker is kept, cut at its first
            newline. Samples are written one per line.
            """
            collected = []
            while len(collected) < sample_size:
                for token_ids in generate_negative_sample(layer_id):
                    decoded = enc.decode(token_ids)
                    after_start = decoded.split('<|endoftext|>')[1]
                    collected.append(after_start.split('\n')[0])
                print(len(collected))
            with open(save_path, 'w') as out_file:
                out_file.write('\n'.join(collected))
        
        
        def eval_discri_NLL(layer_id=0):
            """Evaluate discriminator `layer_id` on the validation batches.

            For every validation batch a fresh batch of negatives is
            generated. Returns ``(mean positive loss, mean negative loss)``
            averaged over the batches.
            """
            pos_losses, neg_losses = [], []
            for pos_samples, pos_samples_len in tqdm.tqdm(val_batches):
                neg_samples = generate_negative_sample(layer_id=layer_id)
                neg_samples_len = get_array_len(neg_samples)
                layer = Dis.model[layer_id]

                pos_feed = {
                    layer['context_pos_discri']: pos_samples,
                    layer['context_pos_discri_len']: pos_samples_len,
                }
                # The 'mask' tensor is fetched alongside the loss, as before,
                # but its value is not used.
                loss_pos, _ = sess.run(
                    [layer['loss_pos_discri'], layer['mask']],
                    feed_dict=pos_feed)

                neg_feed = {
                    layer['context_neg_discri']: neg_samples,
                    layer['context_neg_discri_len']: neg_samples_len,
                }
                loss_neg = sess.run(layer['loss_neg_discri'], feed_dict=neg_feed)

                pos_losses.append(loss_pos)
                neg_losses.append(loss_neg)
            return np.mean(pos_losses), np.mean(neg_losses)
        
        def get_discri_quantile(layer_id=0, quantile=0.85):
            """Return the positive logit at rank ``int(n * (1 - quantile))``.

            Only the FIRST validation batch is evaluated (the loop breaks
            after one iteration). As side effects the batch's 'mask' tensor is
            pickled to 'mask.pkl' and the sorted logits to 'logits.pkl'.
            """
            collected_logits = []
            for pos_samples, pos_samples_len in tqdm.tqdm(val_batches):
                layer = Dis.model[layer_id]
                feed = {
                    layer['context_pos_discri']: pos_samples,
                    layer['context_pos_discri_len']: pos_samples_len,
                }
                logits, mask = sess.run(
                    [layer['logit_pos_discri'], layer['mask']],
                    feed_dict=feed)
                print(np.min(mask, axis=1)[:10])
                print(logits[:10])
                with open('mask.pkl', 'wb') as mask_file:
                    pkl.dump(mask, mask_file)
                collected_logits.extend(logits)
                # Deliberately stop after the first batch.
                break
            ranked = sorted(collected_logits)
            with open('logits.pkl', 'wb') as logits_file:
                pkl.dump(ranked, logits_file)
            print('finish')
            cut = int(len(collected_logits) * (1 - quantile))
            return ranked[cut]
        
        def train_discri(train_step, eval_every, train_layer_list=list(range(len(Dis.model)))):
            """Train the discriminator layers in `train_layer_list` one after another.

            Each layer runs for up to `train_step` steps. Every `eval_every`
            steps the layer is evaluated; whenever the evaluation loss improves
            on the best value seen so far (after the initial evaluation at
            step 0) the layer is saved under ``args.dis_save_path/<layer_id>/``.
            Training of a layer stops early after four non-improving
            evaluations, counted only from step 200 onwards.

            Returns:
                The best evaluation loss of the last trained layer, or None
                when no evaluation ran (empty layer list or ``train_step == 0``).
            """
            print('Start Discri training')
            # Defined up front so the final return cannot hit an unbound name.
            eval_loss_old = None
            for layer_id in train_layer_list:
                flag = 0  # consecutive evaluations without improvement
                for epoch in range(train_step):
                    if epoch % eval_every == 0:
                        eval_NLL_pos, eval_NLL_neg = eval_discri_NLL(layer_id)
                        eval_loss = (eval_NLL_pos*args.pos_loss_weight+eval_NLL_neg)/(args.pos_loss_weight+1)
                        print('layer_id:{} discri eval loss:{}'.format(layer_id, eval_loss))
                        print('layer_id:{} discri NLL pos: {}, discri NLL neg: {}'.format(layer_id, eval_NLL_pos, eval_NLL_neg))
                        print(epoch)
                        if epoch == 0:
                            # The first evaluation only sets the baseline.
                            eval_loss_old = eval_loss
                        else:
                            print(eval_loss, eval_loss_old)
                            if eval_loss < eval_loss_old:
                                eval_loss_old = eval_loss
                                save_path = args.dis_save_path+str(layer_id)+'/'
                                if not os.path.isdir(save_path):
                                    os.mkdir(save_path)
                                Dis.model[layer_id]['saver_discri'].save(sess, save_path+'a')
                                print('model discri saved!')
                                flag = 0
                            elif epoch >= 200:
                                flag += 1
                            if flag >= 4:
                                break
                    train_loss = train_step_discri(layer_id)
                    print('layer_id:{} discri train loss:{}'.format(layer_id, train_loss))
            return eval_loss_old
        
        tf_sample_0 = sample_link.sample_sequence(
                    hparams=hparams,
                    length=args.seq_len,
                    context=context,
                    batch_size=args.batch_size,
                    temperature=1.0,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    start_token=enc.encoder['<|endoftext|>'])
        tf_sample_dict={}
        
        def generate_negative_sample(layer_id=0):
            """Sample one batch of sequences filtered by the first `layer_id` layers.

            ``layer_id == 0`` uses the plain sampler ``tf_sample_0``. Any other
            value uses a threshold sampler built for that layer count, created
            once and cached in ``tf_sample_dict``; ``-1`` selects all layers of
            ``Dis``. Each sequence starts from the first token of a training
            example, and everything after its second start token is
            overwritten with the start token.
            """
            if layer_id == 0:
                sampler_op = tf_sample_0
            else:
                if layer_id == -1:
                    layer_id = len(Dis.model)
                if layer_id not in tf_sample_dict:
                    # Build the sampling op once per layer count and reuse it.
                    tf_sample_dict[layer_id] = sample_link.sample_sequence_ISMC_threshold(
                        Dis=Dis,
                        layer=layer_id,
                        hparams=hparams,
                        length=args.seq_len,
                        context=context,
                        batch_size=args.batch_size,
                        temperature=1.0,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        start_token=enc.encoder['<|endoftext|>'])
                sampler_op = tf_sample_dict[layer_id]

            # Prompt every sequence with the first token of a training example.
            prompts = data_sampler.sample(args.batch_size)[0][:, 0:1]
            generated = sess.run(sampler_op, feed_dict={context: prompts})
            out = generated[:, :args.seq_len]

            for row in out:
                start_tokens_seen = 0
                for pos in range(len(row)):
                    if start_tokens_seen == 2:
                        # Past the second start token: blank out the tail.
                        row[pos] = start_token
                    elif row[pos] == start_token:
                        start_tokens_seen += 1
            return out

        def validation():
            """Compute the mean validation loss, log it as a summary and return it."""
            print('Calculating validation loss...')
            start_time = time.time()
            losses = [
                sess.run(val_loss_mean,
                         feed_dict={val_context: batch[0],
                                    val_context_len: batch[1]})
                for batch in tqdm.tqdm(val_batches)
            ]
            v_val_loss = np.mean(losses)
            # Feed the averaged value into the summary op so the logged scalar
            # covers all validation batches.
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss_mean: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            elapsed = time.time() - start_time
            message = '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
            print(message.format(counter=counter, time=elapsed, loss=v_val_loss))
            return v_val_loss

        def validation_cut(NLL_bias_0=0):
            """Compute and return the mean "cut" validation loss.

            Evaluates ``val_loss_cut_mean`` on every validation batch with
            ``NLL_bias`` fed as `NLL_bias_0`, and averages the results.
            """
            print('Calculating validation loss...')
            # Timestamp taken locally, as in validation(); the original read a
            # `start_time` that this function never assigned.
            start_time = time.time()
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss_cut_mean, feed_dict={val_context: batch[0], val_context_len: batch[1], NLL_bias:NLL_bias_0}))
            v_val_loss = np.mean(losses)
            print(
                '[{counter} | {time:2.2f}] validation cut loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))
            return v_val_loss

        def sample_batch():
            """Return ``args.batch_size`` results of ``data_sampler.sample(1024)``."""
            batch = []
            for _ in range(args.batch_size):
                batch.append(data_sampler.sample(1024))
            return batch
        
        def train_gpt():
            # Fine-tune the GPT model with early stopping on validation loss.
            #
            # The loop only ends when a validation pass fails to improve on
            # the best loss seen so far. NOTE(review): with
            # args.val_every <= 0 no validation runs, so the loop has no exit.
            val_loss_old=10000.0
            # (decayed sum of losses, decayed count) for the running average.
            avg_loss = (0.0, 0.0)
            start_time = time.time()
            # Local step counter; it is independent of the outer `counter`.
            counter=0
            while True:
                #pretraining
                # Periodic saving and sampling are disabled here.
                if counter % args.save_every == 0:
                    pass
                    #save()
                if counter % args.sample_every == 0:
                    pass
                    #generate_samples()
                # Validate every args.val_every steps, and also at step 1.
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    val_loss_1=validation()
                    print(str(counter //args.val_every))
                    if val_loss_1>=val_loss_old:
                        # No improvement: stop, keeping the last saved weights.
                        print('pre-training ends!')
                        break
                    else:
                        val_loss_old=val_loss_1
                        saver.save(sess, args.gpt_save_path+'a')
                        print('save succeed!')

                if args.accumulate_gradients > 1:
                    # Accumulate gradients over several batches, then apply
                    # them once; the apply op's result is used as the loss.
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch, batch_len=data_sampler.sample(args.batch_size)
                        sess.run(
                            opt_compute, feed_dict={context: batch, context_len:batch_len})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch, batch_len=data_sampler.sample(args.batch_size)
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch, context_len:batch_len})

                summary_log.add_summary(v_summary, counter)

                # Running average with decay 0.9.
                avg_loss = (avg_loss[0] * 0.9 + v_loss,
                            avg_loss[1] * 0.9 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        class log_writer:
            """Tiny line-oriented logger that appends to a single text file.

            Constructing an instance creates the file at ``path`` or empties an
            existing one.  Calling the instance appends one line and, when
            ``verbose`` is true, also prints the same text to stdout.
            """

            def __init__(self, path):
                """Store ``path`` and start from an empty log file."""
                self.path = path
                # Opening in write mode is enough to create/truncate the file.
                with open(path, 'w') as handle:
                    handle.write('')

            def __call__(self, string, verbose=False):
                """Append ``string`` followed by a newline; echo it if ``verbose``."""
                with open(self.path, 'a') as handle:
                    handle.write(string + '\n')
                if not verbose:
                    return
                print(string)
        
        # Experiment driver: optional fine-tuning, restore of the fine-tuned
        # weights, optional baseline evaluation, then the per-layer "tailoring"
        # stage.  Ctrl-C aborts the whole sequence with a short message.
        try:
            if args.finetune:
                # Fine-tune GPT-2 (early stopping is handled inside train_gpt).
                train_gpt() 
            if True:
                # Always restore the latest fine-tuned checkpoint found under
                # args.gpt_save_path.
                save_path=tf.train.latest_checkpoint(args.gpt_save_path)
                saver.restore(sess, save_path)
                print('Load gpt2 succeeded!')
            if args.evaluate_finetune:
                # Evaluate the fine-tuning baseline: print the value returned by validation().
                print(validation())
            if args.evaluate_finetune:
                # Reverse perplexity of the fine-tuning baseline: write samples
                # to a file, pass that file as training data to train.file_f,
                # and log the returned score.
                sample_path=args.gpt_sample_dir2+'sample.txt'
                generate_discri_sample3(layer_id=0, sample_size=3000, save_path=sample_path)
                rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path)
                Log_writer=log_writer(args.log_path+'finetune')
                Log_writer('finetuning_rev_ppl: {}'.format(rev_ppl), verbose=True)
            ## Begin tailoring
            if True:
                # Creating the writer truncates the 'discri' log file.
                Log_writer=log_writer(args.log_path+'discri')
                for layer in range(args.layer_num):
                    print(layer)
                    if args.train_tailor:
                        # Train the ratio estimator for this layer.
                        # NOTE(review): meaning of 500 and 10 is defined by train_discri — confirm.
                        train_discri(500, 10, [layer])
                    if True:
                        # Restore the ratio estimators of every layer up to and
                        # including the current one from their own directories.
                        for layer_id in range(layer+1):
                            save_path=args.dis_save_path+str(layer_id)+'/'
                            print(save_path)
                            save_path=tf.train.latest_checkpoint(save_path)
                            print(save_path)
                            Dis.model[layer_id]['saver_discri'].restore(sess, save_path)
                    if False:
                        # Disabled branch: it would read quantile.pkl and
                        # discard the loaded object.
                        with open(args.dis_sample_dir2+'quantile.pkl', 'rb') as f:
                            pkl.load(f)
                        print('Load dis model succeeded!')
                    if True:
                        # Compute this layer's threshold (quantile 0.85 for the
                        # first layer, 0.9 for later ones) and re-dump the whole
                        # Dis.dis object to quantile.pkl.
                        if layer==0:
                            quantile=0.85
                        else:
                            quantile=0.9
                        Dis.dis[layer]=get_discri_quantile(layer, quantile)
                        with open(args.dis_sample_dir2+'quantile.pkl', 'wb') as g:
                            pkl.dump(Dis.dis, g)
                        print(Dis.dis)
                    if args.evaluate_tailor:
                        # Generate samples using layers up to the current one
                        # (ERS) and log their reverse perplexity.
                        sample_path=args.dis_sample_dir2+'_sample_layer_'+str(layer)
                        generate_discri_sample3(layer_id=layer+1, sample_size=3000, save_path=sample_path)
                        rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path)
                        Log_writer('layer: {}, dis_rev_ppl: {}'.format(layer, rev_ppl), verbose=True)
        except KeyboardInterrupt:
            print('interrupted')
Пример #6
0
def main():
    """Fine-tune a GPT-2 model, syncing checkpoints through the storage helpers.

    Steps, in order:
      1. Parse arguments, resolve the remote folder for ``args.gdir`` and
         download its checkpoint.
      2. Build the training, validation and sampling graph.
      3. Restore weights (latest run checkpoint, fresh model, or a given path).
      4. Load the dataset, optionally export it as a compressed ``.npz``.
      5. Train until ``check_quota()`` is true after a save, or an exception
         stops the loop (it is reported and swallowed).

    Returns:
        None.
    """
    args = parser.parse_args()
    folder_id = get_id(args.gdir)
    #xmpp = SendMsgBot(jid, password, to, "Starting GPT-2")
    #xmpp.register_plugin('xep_0030') # Service Discovery
    #xmpp.register_plugin('xep_0199') # XMPP Ping
    #xmpp.connect()
    #threading = Thread(target=xmpp.process, daemon=True).start()
    download_checkpoint(folder_id)
    #send_m('checkpoint downloaded')
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        # if args.optimizer == 'adam':
            # args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        # Training graph: next-token cross entropy, labels shifted by one.
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            # Validation graph: same loss on a separate placeholder, no noise.
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            # exit() takes a single argument; the former two-argument call
            # raised TypeError instead of reporting the optimizer name.
            exit('Bad optimizer: {}'.format(args.optimizer))

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(os.path.join(CHECKPOINT_DIR, args.run_name))
        saver = tf.train.Saver(var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        #send_m('Loading  ' + str(ckpt))
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        #send_m('Loading dataset...')
        #chunks = load_dataset(enc, args.dataset, args.combine)
        # NOTE(review): the dataset is always looked up under the 'run1'
        # directory, regardless of args.run_name — confirm this is intended.
        ds_path = f'{CHECKPOINT_DIR}//run1//{args.dataset}'
        chunks = load_dataset(enc, ds_path, args.combine)
        data_sampler = Sampler(chunks)
        print(f'{ds_path} has', data_sampler.total_size, 'tokens')
        if args.val_every > 0:
            val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks
        if args.enc:
            # Optionally export the encoded dataset and upload it.
            print(colored(f'Trying writing Data.npz encoded from this dataset to {args.enc}', 'red'))
            np.savez_compressed(args.enc, *chunks)
            upload_npz(args.enc, folder_id)
        #send_m(f'{args.dataset} has ' + str(data_sampler.total_size) + ' tokens' + '     Start training...')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint and the step counter, then sync via save_gdisk."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            save_gdisk(counter, folder_id)

        def generate_samples():
            """Generate ``args.sample_num`` samples and write them to the samples dir."""
            print('Generating samples...')
            #send_m('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Echo the last sample only; guard against sample_num <= 0, where
            # nothing was generated and the old ``print(text)`` raised NameError.
            if all_text:
                print(all_text[-1])
            #send_m(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average the loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value in place of the val_loss tensor so the
            # summary records the mean over all batches.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                    .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            """Return ``args.batch_size`` results of ``data_sampler.sample(1024)``."""
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # (decayed loss sum, decayed weight sum) for the running average.
        avg_loss = (0.0, 0.0)
        start_time = time.time()
        last_time = time.time()
        # cur_counter counts steps of this process only; min_loss is the best
        # running average printed so far (starting threshold 2.0).
        cur_counter, min_loss = 1, 2.0
        print(colored(f'Model  >>> {args.gdir}\nLearning rate is {args.learning_rate}', 'blue'))
        print(colored(f'model optimizer >>> {args.optimizer}\nRestricted to train only transformer layer={args.only_train_transformer_layers}', 'blue'))
        #send_m(f'Model  >>> {args.model_name}\nLearning rate is {args.learning_rate}')
        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                    if check_quota():
                        return  # exit train
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run((opt_apply, loss, summaries), feed_dict={context: sample_batch()})
                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                a_loss = avg_loss[0] / avg_loss[1]
                time_all = int((time.time() - start_time) / 60)
                time_iter = time.time() - last_time
                stats = f'[{counter} | {cur_counter} | {time_all}m | {time_iter:2.2f}s] loss={v_loss:2.2f} avg={a_loss:2.2f}'
                # Print every 50th step: red when the running average is above
                # the best printed value, yellow otherwise.
                if not(cur_counter % 50):
                    print(colored(stats, 'red' if a_loss > min_loss else 'yellow'))
                    if a_loss < min_loss:
                        min_loss = a_loss
                    #send_m(stats)
                last_time = time.time()
                counter += 1
                cur_counter += 1
        except Exception as e:
            #send_m('Stoped  ' + str(e.__class__))
            print('Stoped', e.__class__)
            # Keep the best-effort behaviour (no re-raise) but do not lose the
            # error message and traceback needed to diagnose the failure.
            import traceback
            traceback.print_exc()
Пример #7
0
def main():
    """Fine-tune a GPT-2 model with optional rematerialised gradients.

    Parses arguments, builds the training (static shape), validation and
    sampling graphs, restores a checkpoint, loads the dataset and trains
    until interrupted.  On ``KeyboardInterrupt`` a final checkpoint is saved.

    The model directory is ``args.models_dir/args.model_name`` for the
    encoder, ``hparams.json`` and the base checkpoint alike.
    """
    args = parser.parse_args()
    # Use one model directory everywhere: the encoder already honours
    # --models_dir, so hparams and the base checkpoint must as well.
    model_dir = os.path.join(args.models_dir, args.model_name)
    enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(model_dir, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session() as sess:
        # Fully static shape required to make memory accounting in
        # twremat accurate.
        train_context = tf.placeholder(tf.int32, [args.batch_size, 1024])
        train_context_in = randomize(train_context, hparams, args.noise)
        train_output = model.model(hparams=hparams, X=train_context_in)
        train_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=train_context[:, 1:],
                logits=train_output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        # Sampling uses its own placeholder with a dynamic sequence length.
        sample_context = tf.placeholder(tf.int32, [args.batch_size, None])
        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=sample_context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            print('Using Adam optimizer', file=sys.stderr)
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            print('Using SGD optimizer', file=sys.stderr)
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            # exit() takes a single argument; the former two-argument call
            # raised TypeError instead of reporting the optimizer name.
            exit('Bad optimizer: {}'.format(args.optimizer))

        if args.memory_saving_gradients:
            if tf.VERSION >= '2':
                exit(
                    'Memory saving gradients are not supported in tensorflow 2.x'
                )
            import memory_saving_gradients
            opt_grads = memory_saving_gradients.gradients(
                train_loss, train_vars)
        elif args.twremat:
            import tfremat
            opt_grads = tf.gradients(train_loss, train_vars)
            (train_loss, opt_grads) = tfremat.tf_remat(
                (train_loss, opt_grads), memlimit=args.twremat_memlimit)
        else:
            opt_grads = tf.gradients(train_loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', train_loss)

        # if args.twremat:
        #     import tfremat
        #     # Applying tfremat to opt_apply has more accurate
        #     # accounting but is a bit iffier since side effecting ops
        #     # have more restrictions for correctness. If in doubt
        #     # revert back to version using opt_grads above.
        #     (opt_apply, train_loss, summary_loss) = (
        #         tfremat.tf_remat((opt_apply, train_loss, summary_loss), memlimit=args.twremat_memlimit))

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(model_dir)
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(model_dir)
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint for the current step and record the step number."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate ``args.sample_num`` samples and write them to the samples dir."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(tf_sample,
                               feed_dict={
                                   sample_context:
                                   args.batch_size * [context_tokens]
                               })
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Echo the last sample only; guard against sample_num <= 0, where
            # nothing was generated and the old ``print(text)`` raised NameError.
            if all_text:
                print(all_text[-1])
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average the loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value in place of the val_loss tensor so the
            # summary records the mean over all batches.
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            """Return ``args.batch_size`` results of ``data_sampler.sample(1024)``."""
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # (decayed loss sum, decayed weight sum) for the running average.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        # print('Evaluating grads..')
        # tf2.profiler.experimental.start('logdir')
        # sess.run((opt_apply, train_loss, summaries), feed_dict={train_context: sample_batch()})
        # tf2.profiler.experimental.stop()
        # print('Succeeded')
        # exit()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, train_loss, summaries),
                    feed_dict={train_context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop: keep the progress made so far.
            print('interrupted')
            save()
Пример #8
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    files = [f for f in os.listdir(checkpoint_path)]
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise (fnf_error)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        val_context = tf.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context,
                                 reuse=True)  # added reuse=True
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:,
                                   1:], logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks

    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[
            val_data_sampler.sample(1024) for _ in range(val_batch_size)
        ] for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        """Checkpoint the session under step ``counter - 1`` and record it.

        Calls ``maketree(checkpoint_path)`` first (presumably to make sure
        the directory exists), saves through ``saver`` using the ``model``
        prefix, then overwrites the ``counter`` file with the step number.
        """
        # The counter file stores counter - 1; the resume path adds the 1
        # back when it reads the file.
        step = counter - 1
        model_prefix = os.path.join(checkpoint_path, 'model')

        maketree(checkpoint_path)
        announced = os.path.join(checkpoint_path, 'model-{}').format(step)
        print('Saving', announced)
        saver.save(sess, model_prefix, global_step=step)
        with open(counter_path, 'w') as fp:
            fp.write('{}\n'.format(step))

    def generate_samples():
        """Generate ``sample_num`` text samples and write them to a file.

        One context drawn with ``data_sampler.sample(1)`` is repeated
        ``batch_size`` times and fed to ``tf_sample``; batches are run until
        ``sample_num`` outputs have been decoded. The most recent sample is
        printed, and all samples are written, joined by newlines, to
        ``SAMPLE_DIR/run_name/samples-<counter>``.

        When ``sample_num`` is zero or negative no batch is run, nothing is
        printed, and an empty samples file is written.
        """
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            # The final batch may yield more outputs than are still needed,
            # so only take as many as remain.
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        # Echo only the last sample. Guarded: when sample_num <= 0 the loop
        # body never runs, so there is no sample to print.
        if all_text:
            print(all_text[-1])
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        """Evaluate ``val_loss`` over every batch in ``val_batches``.

        The per-batch losses are averaged with ``np.mean``; the mean is
        logged to ``summary_log`` at the current ``counter`` and printed
        together with the time elapsed since ``start_time``.

        Returns:
            The mean validation loss across ``val_batches``.
        """
        print('Calculating validation loss...')
        batch_losses = [
            sess.run(val_loss, feed_dict={val_context: batch})
            for batch in tqdm(val_batches)
        ]
        mean_loss = np.mean(batch_losses)

        # Feed the averaged value in place of the val_loss tensor so the
        # summary op is evaluated on the mean rather than on a single batch.
        summary = sess.run(val_loss_summary, feed_dict={val_loss: mean_loss})
        summary_log.add_summary(summary, counter)
        summary_log.flush()

        report = '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
        print(report.format(counter=counter,
                            time=time.time() - start_time,
                            loss=mean_loss))
        return mean_loss

    def sample_batch():
        """Return a list of ``batch_size`` draws of ``data_sampler.sample(1024)``."""
        batch = []
        for _ in range(batch_size):
            batch.append(data_sampler.sample(1024))
        return batch

    if overwrite and restore_from == 'latest':
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    #Trying out a change to finetune that saves only when validation loss decreases
    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                #save()
                return
            # if (counter - 1) % save_every == 0 and counter > 1:
            #     save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            # validation code
            if val_every > 0 and counter == 1:
                v_val_loss = validation()
                save()
            elif val_every > 0 and counter == counter_base:
                v_val_loss = validation()
            elif val_every > 0 and (counter % val_every == 0):
                new_v_val_loss = validation()
                if new_v_val_loss < v_val_loss:
                    v_val_loss = new_v_val_loss
                    save()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Пример #9
0
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '774M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        #this line is to hopefully reduce memory usage (found on Twitter: https://twitter.com/BasedBlue/status/1169601983046672385?s=20)
        edgeindex = -1 * args.layers_to_train
        train_vars = all_vars[edgeindex:]
        print("Training", args.layers_to_train, "raw layers out of", len(all_vars))

        train_vars = [v for v in train_vars if '/h' in v.name] if args.only_train_transformer_layers else train_vars
        print("Training", len(train_vars), "net layers out of", len(all_vars))

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'adafactor':
            opt = AdafactorOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer:', args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Checkpoint the session at the current ``counter`` value.

            Calls ``maketree`` on the run's checkpoint directory (presumably
            to make sure it exists), saves through ``saver`` using the
            ``model`` prefix, then overwrites the ``counter`` file with the
            step number so a later run can read it back.
            """
            run_dir = os.path.join(CHECKPOINT_DIR, args.run_name)
            maketree(run_dir)

            announced = os.path.join(run_dir, 'model-{}').format(counter)
            print('Saving', announced)

            saver.save(sess, os.path.join(run_dir, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write('{}\n'.format(counter))

        def generate_samples():
            """Generate ``args.sample_num`` text samples and write them out.

            One context drawn with ``data_sampler.sample(1)`` is repeated
            ``args.batch_size`` times and fed to ``tf_sample``; batches are
            run until ``args.sample_num`` outputs have been decoded. The most
            recent sample is printed, and all samples are written, joined by
            newlines, to ``SAMPLE_DIR/<run_name>/samples-<counter>`` using
            ``args.encoding``.

            When ``args.sample_num`` is zero or negative no batch is run, no
            sample is printed, and an empty samples file is written.
            """
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                # The final batch may yield more outputs than are still
                # needed, so only take as many as remain.
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Echo only the last sample. Guarded: when sample_num <= 0 the
            # loop body never runs, so there is no sample to print.
            if all_text:
                print(all_text[-1])
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            """Return ``args.batch_size`` draws of ``data_sampler.sample(1024)``."""
            return [
                data_sampler.sample(1024)
                for _ in range(args.batch_size)
            ]


        avg_loss = (0.0, 0.0)
        bval_loss = (0.0, 0.0)
        start_time = time.time()
        best_val_loss = 99
        missed_val_checkpoints = 0

        try:
            while counter < args.stop_after:
                if counter % args.sample_every == 0:
                    generate_samples()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.98 + v_loss,
                            avg_loss[1] * 0.98 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                if args.val_every > 0 and counter % args.val_every == 0:
                    valbatch = [val_data_sampler.sample(1024) for _ in range(args.batch_size)]
                    valacc = sess.run(loss, feed_dict={context: valbatch})
                    bval_loss = (bval_loss[0] * 0.9 + valacc, bval_loss[1] * 0.9 + 1.0)
                    av_val_loss = bval_loss[0] / bval_loss[1]
                    av_train_loss = avg_loss[0] / avg_loss[1]
                    print(
                        '[{counter} | {time:2.2f}] VAL_loss={loss:2.4f} VAL_avg={avg:2.4f} best={best:2.4f}'
                        .format(
                            counter=counter,
                            time=time.time() - start_time,
                            loss=valacc,
                            avg=av_val_loss,
                            best=best_val_loss))
                    if counter >= args.save_every and counter % args.save_every == 0: # check for validation checkpoints every save_every iterations.
                        if av_val_loss < best_val_loss and av_val_loss > av_train_loss: # got a good one from validation, save a checkpoint (every save_every) -- but don't save before val loss goes above train loss
                            save()
                            best_val_loss = av_val_loss
                            missed_val_checkpoints = 0
                        else: # missed a validation checkpoint. tolerate like 10 of these.
                            if av_val_loss > av_train_loss: # don't count a missed checkpoint while val loss is under training loss
                                missed_val_checkpoints += 1
                    if missed_val_checkpoints > 19: # missed too many save opportunities, stop training
                        counter = args.stop_after + 1
                        print('stopping training due to val loss not improving.')

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')