def test_dual_chain_rewrite():
    """Gradient through two independent tanh chains with checkpointing.

    Builds two length-5 chains, checkpoints the middle node of each, and
    verifies that peak CPU memory stays within the expected budget.
    """
    tf.reset_default_graph()
    device_ctx = tf.device('/cpu:0')
    device_ctx.__enter__()

    chain_len = 5
    chain_a = make_chain_tanh_constant(chain_len, "a")
    chain_b = make_chain_tanh_constant(chain_len, "b")

    first_a, first_b = chain_a[0], chain_b[0]
    last_a, last_b = chain_a[-1], chain_b[-1]

    grad = memory_saving_gradients.gradients(
        [last_a + last_b], [first_a, first_b],
        checkpoints=[chain_a[2], chain_b[2]])

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun([grad[0].op, grad[1].op])

    peak_memory = cpu_peak()
    # Baseline usage is 2*n activations + the default ygrad node + 2
    # gradient nodes; temporarily dropping a1/b1 saves 2 units and moves
    # the peak to a point where the final addition's activation is already
    # freed (another -1).
    expected_peak = (2 * (chain_len - 1) + 1) * 10**6
    util.report_memory(peak_memory, expected_peak)

    # Node scheduling across the two independent chains varies, so allow
    # generous slack before failing.
    if not REMOVE_ASSERTS:
        assert (peak_memory - expected_peak) < 4.1e6, "Difference too large."
def test_chain_rewrite(linearize=False):
    """Take chain of length 5, save 2 nodes, make sure 2 units of RAM is
    saved."""
    tf.reset_default_graph()
    device_ctx = tf.device('/cpu:0')
    device_ctx.__enter__()

    chain_len = 5

    a0, a1, a2, a3, a4 = make_chain_tanh(chain_len)
    grad = memory_saving_gradients.gradients([a4], [a0],
                                             checkpoints=[a1, a3])[0]
    # Two activations are recomputed, so the peak drops by two units.
    expected_peak = (chain_len + 1 - 2) * 10**6

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)
    if linearize:
        linearize_lib.linearize()

    peak_memory = cpu_peak()
    util.report_memory(peak_memory, expected_peak)

    if not REMOVE_ASSERTS:
        assert (peak_memory - expected_peak) < 1e6 + 10000, \
            "Difference too large."
def test_resnet_rewrite(linearize=False):
    """Gradient through a small resnet with a single checkpointed node.

    Args:
      linearize: if True, linearize execution order before running the
        gradient op.
    """
    tf.reset_default_graph()
    tf_dev = tf.device('/cpu:0')
    tf_dev.__enter__()

    n = 6

    nodes = make_resnet(n)
    a0 = nodes[0]
    a = nodes[-1]

    # Fix: the original also assigned checkpoints = [nodes[3], nodes[5]]
    # but never used it — the gradients() call below is the single source
    # of truth for the checkpoint choice. The unused `added` result of
    # linearize() is likewise dropped.
    grad = memory_saving_gradients.gradients([a], [a0],
                                             checkpoints=[nodes[2]])[0]
    if linearize:
        linearize_lib.linearize(grad.op)

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)

    peak_memory = cpu_peak()
    # 1 unit for the activation of each tanh node + 1 for the initial
    # backprop node + 1 temporary for computing the adds,
    # -1 for discarding, then recomputing a1_tanh.
    expected_peak = (n - 1) * 10**6
    util.report_memory(peak_memory, expected_peak)

    if not REMOVE_ASSERTS:
        assert (peak_memory - expected_peak) < 1.1 * 10**6, \
            "Difference too large."
# Ejemplo n.º 4
# 0
def test_resnet_rewrite(linearize=False):
    """Gradient through a small resnet with a single checkpointed node.

    Args:
        linearize: if True, linearize execution order before running the
            gradient op.
    """
    tf.reset_default_graph()
    tf_dev = tf.device('/cpu:0')
    tf_dev.__enter__()

    n = 6

    nodes = make_resnet(n)
    a0 = nodes[0]
    a = nodes[-1]

    # Fix: removed the dead assignment checkpoints = [nodes[3], nodes[5]];
    # the gradients() call below uses nodes[2] and the variable was never
    # read. The unused `added` result of linearize() is also dropped.
    grad = memory_saving_gradients.gradients([a], [a0],
                                             checkpoints=[nodes[2]])[0]
    if linearize:
        linearize_lib.linearize(grad.op)

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)

    peak_memory = cpu_peak()
    # 1 for activation of each tanh node + 1 for initial backprop node
    # + 1 temporary memory for computing the adds,
    # -1 for discarding, then recomputing a1_tanh
    expected_peak = (n - 1) * 10**6
    util.report_memory(peak_memory, expected_peak)

    if not REMOVE_ASSERTS:
        assert (peak_memory -
                expected_peak) < 1.1 * 10**6, "Difference too large."
# Ejemplo n.º 5
# 0
def test_chain_rewrite(linearize=False):
    """Take chain of length 5, save 2 nodes, make sure 2 units of RAM is
    saved."""

    tf.reset_default_graph()
    dev = tf.device('/cpu:0')
    dev.__enter__()

    length = 5

    a0, a1, a2, a3, a4 = make_chain_tanh(length)
    checkpointed = [a1, a3]
    grad = memory_saving_gradients.gradients([a4], [a0],
                                             checkpoints=checkpointed)[0]
    # Subtract 2 from the budget since two activations get recomputed.
    expected_peak = (length - 1) * 10**6

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)
    if linearize:
        linearize_lib.linearize()

    observed_peak = cpu_peak()
    util.report_memory(observed_peak, expected_peak)

    if not REMOVE_ASSERTS:
        assert (observed_peak - expected_peak) < 1e6 + 10000, \
            "Difference too large."
# Ejemplo n.º 6
# 0
def test_dual_chain_rewrite():
    """Runs regular chain gradient, makes sure memory usage makes sense."""

    tf.reset_default_graph()
    dev = tf.device('/cpu:0')
    dev.__enter__()

    length = 5
    left_chain = make_chain_tanh_constant(length, "a")
    right_chain = make_chain_tanh_constant(length, "b")

    left_in, right_in = left_chain[0], right_chain[0]
    left_out, right_out = left_chain[-1], right_chain[-1]

    middles = [left_chain[2], right_chain[2]]
    grad = memory_saving_gradients.gradients([left_out + right_out],
                                             [left_in, right_in],
                                             checkpoints=middles)

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun([grad[0].op, grad[1].op])

    observed_peak = cpu_peak()
    # Normal usage: 2*n activations + default ygrad node + 2 gradient
    # nodes. Dropping a1/b1 temporarily saves 2 units and moves the peak
    # lower down the chain, past the point where the addition node's
    # activation is needed (another -1).
    expected_peak = (2 * (length - 1) + 1) * 10**6
    util.report_memory(observed_peak, expected_peak)

    # Two independent chains mean variable node scheduling, so leave
    # slack before failing the assertion.
    if not REMOVE_ASSERTS:
        assert (observed_peak - expected_peak) < 4.1e6, \
            "Difference too large."
def test_chain_rewrite_save_last():
  """Take chain of length 5, save last node. This saves no memory and is
  an edge case that should raise an exception in the rewriter."""

  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 5

  a0, a1, a2, a3, a4 = make_chain_tanh(n)
  try:
    grad = memory_saving_gradients.gradients([a4], [a0], checkpoints=[a4])[0]
  except Exception:
    return  # expected path: rewriter rejects checkpointing the output node
  else:
    if not REMOVE_ASSERTS:
      # Fix: the original `assert "message"` asserted on a non-empty
      # string, which is always truthy and could never fail.
      assert False, "Should've been 'no checkpoints nodes found' exception"
# Ejemplo n.º 8
# 0
def test_chain_rewrite_save_last():
  """Take chain of length 5, save last node. This saves no memory and is
  an edge case that should raise an exception in the rewriter."""

  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 5

  a0, a1, a2, a3, a4 = make_chain_tanh(n)
  try:
    grad = memory_saving_gradients.gradients([a4], [a0], checkpoints=[a4])[0]
  except Exception:
    return  # expected path: rewriter rejects checkpointing the output node
  else:
    if not REMOVE_ASSERTS:
      # Fix: asserting on a non-empty string literal is always truthy and
      # could never fail; assert False so a missing exception is reported.
      assert False, "Should've been 'no checkpoints nodes found' exception"
def test_chain_rewrite_save_first():
    """Take chain of length 5, save first node."""

    tf.reset_default_graph()
    dev = tf.device('/cpu:0')
    dev.__enter__()

    length = 5

    a0, a1, a2, a3, a4 = make_chain_tanh_constant(length)
    grad = memory_saving_gradients.gradients([a4], [a0],
                                             checkpoints=[a1, a3])[0]
    # Budget drops by 2 units since two activations are recomputed.
    expected_peak = (length - 1) * 10**6

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)

    observed_peak = cpu_peak()
    util.report_memory(observed_peak, expected_peak)

    if not REMOVE_ASSERTS:
        assert (observed_peak - expected_peak) < 1.1e6, \
            "Difference too large."
# Ejemplo n.º 10
# 0
def test_chain_rewrite_save_one_before_last():
    """Take chain of length 5, checkpoint a single interior node (a2)."""

    tf.reset_default_graph()
    tf_dev = tf.device('/cpu:0')
    tf_dev.__enter__()

    n = 5

    a0, a1, a2, a3, a4 = make_chain_tanh_constant(n)
    # Only a2 is kept; the remaining activations are recomputed on demand.
    grad = memory_saving_gradients.gradients([a4], [a0], checkpoints=[a2])[0]
    # Budget: n+1 units minus the 2 activations that get recomputed.
    expected_peak = (n + 1 - 2) * 10**6

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)

    peak_memory = cpu_peak()
    util.report_memory(peak_memory, expected_peak)

    if not REMOVE_ASSERTS:
        assert (peak_memory - expected_peak) < 1.1e6, "Difference too large."
# Ejemplo n.º 11
# 0
 def gradients_collection(ys, xs, grad_ys=None, **kwargs):
     return memory_saving_gradients.gradients(ys,
                                              xs,
                                              grad_ys,
                                              checkpoints='collection',
                                              **kwargs)
# Ejemplo n.º 12
# 0
def main():
    """Fine-tune GPT-2: build the graph, restore a checkpoint, train.

    Reads CLI flags from the module-level `parser`. Periodically generates
    samples, evaluates on a validation set, and saves a checkpoint when
    the smoothed validation loss improves. Stops at args.stop_after steps,
    on KeyboardInterrupt, or after too many missed validation checkpoints.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # 774M forces memory saving gradients and (for adam) restricts
    # training to the transformer layers.
    if args.model_name == '774M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        # Next-token language-model loss.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        # this line is to hopefully reduce memory usage (found on Twitter: https://twitter.com/BasedBlue/status/1169601983046672385?s=20)
        edgeindex = -1 * args.layers_to_train
        train_vars = all_vars[edgeindex:]
        print("Training", args.layers_to_train, "raw layers out of", len(all_vars))

        train_vars = [v for v in train_vars if '/h' in v.name] if args.only_train_transformer_layers else train_vars
        print("Training", len(train_vars), "net layers out of", len(all_vars))

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'adafactor':
            opt = AdafactorOptimizer(learning_rate=args.learning_rate)
        else:
            # Fix: exit() accepts a single argument; the original passed two
            # positional args, raising TypeError instead of exiting cleanly.
            exit('Bad optimizer: %s' % args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write a model checkpoint plus the step counter for resuming.
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate args.sample_num completions seeded with one token
            # drawn from the dataset, and write them to SAMPLE_DIR.
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            ret = [data_sampler.sample(1024) for _ in range(args.batch_size)]
            # print (enc.decode(ret[0]))
            return ret

        # (decayed sum, decayed weight) pairs for smoothed loss averages.
        avg_loss = (0.0, 0.0)
        bval_loss = (0.0, 0.0)
        start_time = time.time()
        best_val_loss = 99
        missed_val_checkpoints = 0

        try:
            while counter < args.stop_after:
                if counter % args.sample_every == 0:
                    generate_samples()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.98 + v_loss,
                            avg_loss[1] * 0.98 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                if args.val_every > 0 and counter % args.val_every == 0:
                    valbatch = [val_data_sampler.sample(1024) for _ in range(args.batch_size)]
                    valacc = sess.run(loss, feed_dict={context: valbatch})
                    bval_loss = (bval_loss[0] * 0.9 + valacc, bval_loss[1] * 0.9 + 1.0)
                    av_val_loss = bval_loss[0] / bval_loss[1]
                    av_train_loss = avg_loss[0] / avg_loss[1]
                    print(
                        '[{counter} | {time:2.2f}] VAL_loss={loss:2.4f} VAL_avg={avg:2.4f} best={best:2.4f}'
                        .format(
                            counter=counter,
                            time=time.time() - start_time,
                            loss=valacc,
                            avg=av_val_loss,
                            best=best_val_loss))
                    if counter >= args.save_every and counter % args.save_every == 0: # check for validation checkpoints every save_every iterations.
                        if av_val_loss < best_val_loss and av_val_loss > av_train_loss: # got a good one from validation, save a checkpoint (every save_every) -- but don't save before val loss goes above train loss
                            save()
                            best_val_loss = av_val_loss
                            missed_val_checkpoints = 0
                        else: # missed a validation checkpoint. tolerate up to 20 of these.
                            if av_val_loss > av_train_loss: # don't count a missed checkpoint while val loss is under training loss
                                missed_val_checkpoints += 1
                    if missed_val_checkpoints > 19: # missed too many save opportunities, stop training
                        counter = args.stop_after + 1
                        print('stopping training due to val loss not improving.')

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
 def grads(ys, xs, grad_ys=None, **kwargs):
   return memory_saving_gradients.gradients(ys, xs, grad_ys,
                                            checkpoints='speed', **kwargs)
# Ejemplo n.º 14
# 0
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):
    """Build train/test/save/restore closures around `f_loss` and return
    them attached to a static holder object.

    Args:
        sess: tf.Session shared by all returned closures.
        hps: hyperparameters; reads hps.gradient_checkpointing,
            hps.optimizer, hps.direct_iterator and hps.restore_path.
        feeds: dict with 'x' and 'y' placeholders, used when iterators are
            Python callables rather than graph-side.
        train_iterator: graph-side iterator or a Python callable returning
            (x, y) batches for training.
        test_iterator: same, for testing.
        data_init: dict with 'x' and 'y' used for data-dependent init.
        lr: learning-rate placeholder fed at each train step.
        f_loss: callable (iterator, is_training, reuse=...) -> (loss, stats)
            that builds the loss graph.

    Returns:
        The holder object `m` with sess/feeds/lr fields and
        train/test/polyak_swap/save/save_ema/restore helpers.
    """

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    # Optionally trade compute for memory by recomputing activations
    # during the backward pass.
    if hps.gradient_checkpointing == 1:
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    optimizer = {'adam': optim.adam, 'adamax': optim.adamax,
                 'adam2': optim.adam2}[hps.optimizer]

    train_op, polyak_swap_op, ema = optimizer(
        all_params, gs, alpha=lr, hps=hps)
    # Direct iterators feed data inside the graph; otherwise a batch is
    # pulled from the Python iterator and fed via placeholders.
    if hps.direct_iterator:
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:
        def _train(_lr):
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train], {feeds['x']: _x,
                                                      feeds['y']: _y, lr: _lr})[1]
        m.train = _train

    m.polyak_swap = lambda: sess.run(polyak_swap_op)

    # === Testing
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:
        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x,
                                         feeds['y']: _y})
        m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    # Separate saver for the moving-average shadow variables tracked by
    # the optimizer's `ema` object.
    saver_ema = tf.train.Saver(ema.variables_to_restore())
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        # Data-dependent initialization (get_variable_ddi / actnorm) run
        # once on the data_init batch before normal training.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {feeds['x']: data_init['x'],
                                feeds['y']: data_init['y']})
    # Horovod: broadcast rank-0 variables so all workers start in sync.
    sess.run(hvd.broadcast_global_variables(0))

    return m
# Ejemplo n.º 15
# 0
def model_fn(features, labels, mode, params):
    """TPUEstimator model_fn: PREDICT samples images, TRAIN minimizes
    f_loss with optional regularization, EVAL reports mean f_loss.

    Args:
        features: dict with 'y' (passed to model.sample / model.f_loss)
            and, for TRAIN/EVAL, 'real_images'.
        labels: unused; deleted immediately.
        mode: one of tf.estimator.ModeKeys.
        params: dict carrying the configuration object under 'cfg'.

    Returns:
        tpu_estimator.TPUEstimatorSpec appropriate for `mode`.

    Raises:
        ValueError: if mode is none of PREDICT/TRAIN/EVAL.
    """

    del labels

    cfg = params['cfg']
    model = models.model(cfg)
    y = features['y']

    if mode == tf.estimator.ModeKeys.PREDICT:
        ###########
        # PREDICT #
        ###########
        predictions = {'generated_images': model.sample(y, temp=0.75)}
        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              predictions=predictions)

    # NOTE(review): is_training is computed but never used below.
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    real_images = features['real_images']

    f_loss, eps = model.f_loss(real_images, y)

    if mode == tf.estimator.ModeKeys.TRAIN:
        #########
        # TRAIN #
        #########

        f_loss = tf.reduce_mean(f_loss)

        with tf.variable_scope('Regularization'):
            # Penalize 'invconv' kernels whose v * v^T determinant drifts
            # away from 1.
            for v in tf.trainable_variables():
                if 'invconv' in v.name:
                    det = tf.matrix_determinant(v * tf.transpose(v))
                    f_loss += tf.square(det - 1)

            # L2 weight decay on everything except actnorm parameters.
            if cfg.use_l2_regularization:
                for v in tf.trainable_variables():
                    if 'actnorm' not in v.name:
                        f_loss += cfg.l2_regularization_factor * tf.nn.l2_loss(
                            v)

        if not cfg.use_tpu and cfg.report_histograms:
            for v in tf.trainable_variables():
                tf.summary.histogram(v.name.replace(':', '_'), v)

        global_step = tf.train.get_or_create_global_step()
        # Linear learning-rate ramp-up over the first 2000 steps.
        rate = tf.minimum(tf.cast(global_step, tf.float32) / 2000.0, 1.0)
        #lr = int(real_images.get_shape()[0]) * cfg.lr
        lr = cfg.lr * rate
        #from AMSGrad import AMSGrad
        optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                           beta1=cfg.beta1,
                                           epsilon=cfg.adam_eps)

        tf.summary.scalar('lr', lr)

        if cfg.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            with tf.variable_scope('TrainOps'):
                # Optionally trade compute for memory during backprop.
                if cfg.memory_saving_gradients:
                    from memory_saving_gradients import gradients
                    gs = gradients(f_loss, tf.trainable_variables())
                else:
                    gs = tf.gradients(f_loss, tf.trainable_variables())
                # Clip gradients element-wise to [-100, 100] if enabled.
                if cfg.use_gradient_clipping:
                    gs = [tf.clip_by_value(g, -100., 100.) for g in gs]
                grads_and_vars = list(zip(gs, tf.trainable_variables()))
                train_op = optimizer.apply_gradients(grads_and_vars)
                increment_step = tf.assign_add(
                    tf.train.get_or_create_global_step(), 1)
                joint_op = tf.group([train_op, increment_step])

            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  loss=f_loss,
                                                  train_op=joint_op)

    elif mode == tf.estimator.ModeKeys.EVAL:
        ########
        # EVAL #
        ########
        def _eval_metric_fn(f_loss):
            # Mean of per-example f_loss as the single eval metric.
            return {'f_loss': tf.metrics.mean(f_loss)}

        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              loss=tf.reduce_mean(f_loss),
                                              eval_metrics=(_eval_metric_fn,
                                                            [f_loss]))

    raise ValueError('Invalid mode provided to model_fn')
# Ejemplo n.º 16
# 0
def train_main(dataset,
               model_name='1250M',
               seed=None,
               msg=True,
               batch_size=16,
               learning_rate=0.00002,
               sample_length=512,
               sample_num=1,
               sample_every=100,
               run_name='run1',
               restore_from='latest',
               save_every=1000,
               combine=50000):
    """Distributed (Horovod) GPT-2 fine-tuning loop.

    Args:
        dataset: path passed to load_dataset for the training corpus.
        model_name: model directory name under 'models'.
        seed: numpy/TF random seed (None leaves it nondeterministic).
        msg: if True, use memory_saving_gradients for the backward pass.
        batch_size: per-worker batch size.
        learning_rate: initial Adam learning rate (exponentially decayed).
        sample_length: generated sample length; must not exceed n_ctx.
        sample_num: number of samples generated per sampling round.
        sample_every: generate samples every N steps (rank 0 only).
        run_name: subdirectory under CHECKPOINT_DIR / SAMPLE_DIR.
        restore_from: 'latest', 'fresh', or an explicit checkpoint path.
        save_every: save a checkpoint every N steps (rank 0 only).
        combine: passed to load_dataset to merge small input files.
    """

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
        print('n_ctx: ', hparams.n_ctx, 'n_head: ', hparams.n_head, 'n_embd: ',
              hparams.n_embd, 'n_layer: ', hparams.n_layer)

    # Clamp/validate the sample length against the model's context window.
    if sample_length is None:
        sample_length = hparams.n_ctx
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    #device_map = { 0:2, 0:3, 1:2, 1:3 }
    #config.gpu_options.visible_device_list = str(device_map[hvd.rank()])
    # Horovod: pin each worker process to its local GPU.
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    global_step = tf.Variable(0, trainable=False)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        # Next-token language-model loss.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.9,
                                           top_k=40)

        #global_step = tf.Variable(0, trainable=False)
        counter = 1

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        #opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # l4rz 11/10/2019
        # Decay the LR by 0.999 every 200 steps (staircase).
        decayed_lr = tf.train.exponential_decay(learning_rate,
                                                global_step,
                                                200,
                                                0.999,
                                                staircase=True)
        opt = tf.train.AdamOptimizer(decayed_lr)
        #opt = tf.train.GradientDescentOptimizer(decayed_lr)
        # Horovod: wrap the optimizer so gradients are averaged across workers.
        opt = hvd.DistributedOptimizer(opt)
        # this is original horovod
        #train_op = opt.minimize(loss, var_list=train_vars)
        # this is ours
        if (msg):
            print('Using memory saving gradients')
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            train_op = opt.apply_gradients(opt_grads, global_step=global_step)
        else:
            print('Not using memory saving gradients')
            #train_op = opt.minimize(loss, var_list=train_vars)
            # l4rz 11/10
            train_op = opt.minimize(loss,
                                    var_list=train_vars,
                                    global_step=global_step)
        # [1,2]<stderr>:TypeError: apply_gradients() missing 1 required positional argument: 'grads_and_vars'
        #summary_loss = tf.summary.scalar('loss', train_op)

        #_, lv = sess.run((train_op, loss), feed_dict={context: batch})

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        print('Running hvd.broadcast_global_variables')
        bcast = hvd.broadcast_global_variables(0)
        print('Done')

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        print('Running global_variables_initializer')
        sess.run(tf.global_variables_initializer())
        print('Done')

        # Resolve which checkpoint to restore from.
        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        # comment out when running for 1st time
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        # uncomment when running for first time INIT THE MODEL
        #print('tf.global_variables_initializer()')
        #sess.run(tf.global_variables_initializer())

        # Sync the restored weights from rank 0 to every worker.
        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write a model checkpoint plus the step counter for resuming.
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate sample_num completions seeded with one dataset token
            # and write them to SAMPLE_DIR.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # (decayed sum, decayed weight) pair for the smoothed loss average.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                # Only rank 0 saves, samples, and logs.
                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f} lr={lr:.2e}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1],
                                lr=decayed_lr.eval()))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
0
def main():
    """Fine-tune a GPT-2 model on a dataset.

    Configuration comes entirely from the module-level argparse ``parser``.
    Periodically saves checkpoints, generates samples, and (optionally)
    evaluates validation loss.  Trains until KeyboardInterrupt, then saves.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # The 355M model needs memory-saving tricks to fine-tune on a typical
    # GPU, so force them on for that model.
    if args.model_name == '355M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The layout optimizer can interfere with the memory-saving gradient
    # graph rewrite, so it is disabled.
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        # Optionally corrupt the input tokens (args.noise) for regularization.
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        # Next-token prediction: logits at position t are scored against
        # the (uncorrupted) token at position t+1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        # '/h' selects the transformer blocks; other variables (embeddings,
        # etc.) are frozen when only_train_transformer_layers is set.
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer:', args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            # apply_gradients() here yields the accumulated loss tensor,
            # which is what gets logged as 'loss'.
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            # Treat restore_from as an explicit checkpoint directory.
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                # No separate validation set: validate on the training data.
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint and persist the current step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate args.sample_num samples and write them to SAMPLE_DIR."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Only the last sample is echoed to stdout; all are written below.
            print(text.encode('utf8'))
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the scalar mean back through the summary op so the
            # logged value is the mean, not a single batch's loss.
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # (numerator, denominator) of an exponential moving average of loss.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    # Zero the accumulators, run several forward/backward
                    # passes, then apply the summed gradients once.
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
Ejemplo n.º 18
0
 def gradients_auto(ys, xs, grad_ys=None, **kwargs):
     """Memory-saving gradients with automatic checkpoint selection.

     Thin wrapper that forwards all arguments to
     memory_saving_gradients.gradients with checkpoints='memory'.
     """
     result = memory_saving_gradients.gradients(
         ys, xs, grad_ys, checkpoints='memory', **kwargs)
     return result
Ejemplo n.º 19
0
def main():
    """Fine-tune GPT-2 on paired 'from'/'to' data (Q/A style).

    Training examples are built by concatenating corresponding entries of
    the dataset's 'from' and 'to' parts (see name_parts).  Runs until
    counter == args.stop_after or Ctrl-C; always saves on exit (finally).
    """
    args = parser.parse_args()
    try:
        # Record the log directory for external tooling; best-effort only.
        logdir = os.path.join(CHECKPOINT_DIR, args.run_name)
        with open('logdir.txt', 'w') as z:
            z.write(logdir)
    except:
        pass
    # NOTE(review): uses module-level `model_name`, not args.model_name —
    # confirm that global exists and is consistent with the CLI argument.
    enc = get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # 345M needs memory-saving tricks to fine-tune on a typical GPU.
    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The layout optimizer can interfere with the memory-saving gradient
    # graph rewrite, so it is disabled.
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        # Next-token prediction loss.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        # '/h' selects the transformer blocks when freezing everything else.
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars
        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        #print('Loading dataset...')
        #chunks = load_dataset(enc, args.dataset, args.combine)
        #data_sampler = Sampler(chunks)
        print('Loading train dataset...')
        from_name, ques_name, to_name = name_parts(args.dataset)

        trn_chunks_from = load_dataset(
            enc, from_name, args.combine)  #if args.dataset else chunks
        #trn_chunks_ques = load_dataset(enc, ques_name, args.combine) if args.dataset else chunks
        trn_chunks_to = load_dataset(
            enc, to_name, args.combine)  #if args.dataset else chunks

        skip_delimeter = True
        char = '\t'
        trn_data_sampler_from = SamplerVal(trn_chunks_from,
                                           enc,
                                           char=char,
                                           skip_delimeter=skip_delimeter)
        #trn_data_sampler_ques = SamplerVal(trn_chunks_ques, enc, char=char, skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to,
                                         enc,
                                         char=char,
                                         skip_delimeter=skip_delimeter)

        # Concatenate each from/to pair into one training example,
        # truncated to HIDDEN_SIZE - 1 tokens.
        len_v = 0
        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            v = (
                #enc.encode('\nQ: ') +
                trn_data_sampler_from.get(i) +
                #enc.encode('. \nA: ') +
                trn_data_sampler_to.get(i)  #  +
                #enc.encode('. ')
            )

            v = v[:HIDDEN_SIZE - 1]
            len_v += len(v)
            #data_sampler.extend(v) ##
            data_sampler.append(v)
            pass

        # Small corpora get duplicated until big enough and wrapped in a
        # Sampler; otherwise data_sampler stays a plain Python list.
        if len_v < HIDDEN_SIZE:
            mult = HIDDEN_SIZE // len_v + 1
            for i in range(mult):
                x = data_sampler[:]
                data_sampler.extend(x)
            data_sampler = Sampler([np.array(data_sampler)])

        #if not args.train_special and len_v >= HIDDEN_SIZE:
        #    data_sampler = Sampler([np.array(data_sampler)])

        # Dead branch: the trailing `and False` disables validation loading.
        if args.val_every > 0 and False:
            val_chunks = load_dataset(
                enc, args.val_dataset,
                args.combine) if args.val_dataset else chunks
        if not isinstance(data_sampler, list):
            print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            # NOTE(review): val_chunks is never loaded (the branch above is
            # disabled by `and False`), so val_every > 0 raises NameError
            # here — confirm whether validation is intentionally disabled.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Checkpoint the model, persist the counter, and copy the
            encoder files next to the checkpoint (first save only)."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

            #print(model_name, 'mn')
            GPT2_DIR_X = model_name
            cd = CHECKPOINT_DIR + "/" + args.run_name
            if not os.path.isfile(cd + '/' + 'encoder.json'):
                os.system("cp " + GPT2_DIR_X + '/' + 'encoder.json ' + cd +
                          '/.')
                os.system('cp ' + GPT2_DIR_X + "/" + 'vocab.bpe ' + cd + '/.')

        def generate_samples():
            """Generate args.sample_num samples from a random training
            prompt and write them to SAMPLE_DIR."""
            print('Generating samples...')
            #context_tokens = data_sampler.sample(1)
            #context_tokens = data_sampler[0]
            # NOTE(review): random.randint's upper bound is inclusive, so
            # this can index one past the last element — likely wants
            # total_size - 1.
            context_tokens = trn_data_sampler_from.get(
                random.randint(0, trn_data_sampler_from.total_size))
            #print(enc.decode(context_tokens), len(context_tokens))
            #print(args.batch_size * [context_tokens])

            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Compute mean validation loss and log it to TensorBoard."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            """Return a single-example batch drawn from data_sampler.

            NOTE(review): the random index bound is args.batch_size
            (inclusive), not len(data_sampler) - 1, and this assumes
            data_sampler is still a plain list (indexable) — both look
            suspect; confirm against the data-loading branch above.
            """
            #z = [data_sampler.sample(1024) for _ in range(args.batch_size)]
            #print(len(data_sampler))
            #print(len(data_sampler[0]))
            z = [data_sampler[random.randint(0, args.batch_size)]]
            #print(enc.decode(z[0]))
            #print(z[1],'\n1' ,z[2],'\n2' ,z[3] ,len(data_sampler[0]))
            #exit()
            return z

        # (numerator, denominator) of an exponential moving average of loss.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while counter != args.stop_after:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                    pass
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('\ninterrupted')

        finally:
            # Always checkpoint on exit, whether normal, interrupted,
            # or due to an exception.
            save()
Ejemplo n.º 20
0
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator,
                      data_init, lr, f_loss):
    """Assemble train/test/save/restore machinery for a model defined by
    `f_loss` and return a namespace object `m` whose attributes are
    callables (m.train, m.test, m.save, m.restore, m.polyak_swap, ...).
    """

    # == Create class with static fields and methods
    class m(object):
        pass

    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    # Gradient checkpointing trades recomputation for memory when enabled.
    if hps.gradient_checkpointing == 1:
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    # Dispatch table: raises KeyError for an unknown hps.optimizer.
    optimizer = {
        'adam': optim.adam,
        'adamax': optim.adamax,
        'adam2': optim.adam2
    }[hps.optimizer]

    train_op, polyak_swap_op, ema = optimizer(all_params,
                                              gs,
                                              alpha=lr,
                                              hps=hps)
    if hps.direct_iterator:
        # Data is fed by an in-graph iterator; only the learning rate is fed.
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:

        def _train(_lr):
            # Pull one batch from the Python-side iterator and feed it.
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train], {
                feeds['x']: _x,
                feeds['y']: _y,
                lr: _lr
            })[1]

        m.train = _train

    # Swap the current parameters with their EMA (Polyak) counterparts.
    m.polyak_swap = lambda: sess.run(polyak_swap_op)

    # === Testing
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:

        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x, feeds['y']: _y})

        m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        # Data-dependent initialization (actnorm) using the init batch.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {
            feeds['x']: data_init['x'],
            feeds['y']: data_init['y']
        })
    # sess.run(hvd.broadcast_global_variables(0))

    return m
Ejemplo n.º 21
0
def train_main(dataset,
               valset,
               model_name='774M',
               seed=None,
               batch_size=1,
               batch_length=1024,
               sample_length=1023,
               sample_num=1,
               sample_every=100,
               run_name='run1',
               restore_from='latest',
               stop_after=None,
               learning_rate=0.001,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-08,
               save_every=1000,
               layers_to_train=144):
    """Fine-tune GPT-2 with Adafactor, checkpointing only when the
    validation moving average improves.

    NOTE(review): stop_after defaults to None, but the loop below does
    `while counter < stop_after` — on Python 3 that raises TypeError.
    A numeric stop_after appears to be required; confirm with callers.
    """
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # Layout optimizer disabled; memory_saving_gradients is always used below.
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        # Next-token prediction loss.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=1.0,
                                           top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        #this line is to hopefully reduce memory usage (found on Twitter: https://twitter.com/BasedBlue/status/1169601983046672385?s=20)
        train_vars = all_vars[-layers_to_train:]
        print("Training", layers_to_train, "layers out of", len(all_vars))

        decay_rate = adafactor_decay_rate_adam(beta2)
        opt = AdafactorOptimizer(learning_rate=learning_rate,
                                 decay_rate=decay_rate,
                                 beta1=beta1,
                                 name="Adafactor")
        # Gradient checkpointing is unconditional here (no flag).
        opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, dataset)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        print('Loading valset...')
        val_chunks = load_dataset(enc, valset)
        val_data_sampler = Sampler(val_chunks)
        print('valset has', val_data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint and persist the current step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate sample_num samples and write them to SAMPLE_DIR."""
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Only the last sample is echoed to stdout; all are written below.
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # (numerator, denominator) pairs for exponential moving averages.
        avg_loss = (0.0, 0.0)
        val_loss = (0.0, 0.0)
        start_time = time.time()
        best_val_loss = 99
        missed_val_checkpoints = 0

        try:
            # NOTE(review): requires a numeric stop_after (None -> TypeError).
            while counter < stop_after:
                #if counter % save_every == 0:
                #    save()
                if counter % sample_every == 0:
                    generate_samples()

                batch = [
                    data_sampler.sample(batch_length)
                    for _ in range(batch_size)
                ]

                _, lv = sess.run((opt_apply, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=lv,
                            avg=avg_loss[0] / avg_loss[1]))

                # Evaluate one validation batch every 5 steps and track an
                # exponential moving average of the validation loss.
                if counter % 5 == 0:
                    valbatch = [
                        val_data_sampler.sample(batch_length)
                        for _ in range(batch_size)
                    ]
                    valacc = sess.run(loss, feed_dict={context: valbatch})
                    val_loss = (val_loss[0] * 0.99 + valacc,
                                val_loss[1] * 0.99 + 1.0)
                    av_val_loss = val_loss[0] / val_loss[1]
                    print(
                        '[{counter} | {time:2.2f}] VAL_loss={loss:2.4f} VAL_avg={avg:2.4f} best={best:2.4f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=valacc,
                                avg=av_val_loss,
                                best=best_val_loss))
                    if counter >= save_every and counter % save_every == 0:  # check for validation checkpoints every save_every iterations.
                        if av_val_loss < best_val_loss:  # got a good one from validation, save a checkpoint (every save_every)
                            save()
                            best_val_loss = av_val_loss
                            missed_val_checkpoints = 0
                        else:  # missed a validation checkpoint. tolerate like 10 of these.
                            missed_val_checkpoints += 1
                    if missed_val_checkpoints > 9:  # missed too many save opportunities, stop training
                        counter = stop_after + 1
                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
 def gradients_collection(ys, xs, grad_ys=None, **kwargs):
   """Memory-saving gradients using tensors in the 'checkpoints' graph collection."""
   grad_fn = memory_saving_gradients.gradients
   return grad_fn(ys, xs, grad_ys, checkpoints='collection', **kwargs)
Ejemplo n.º 23
0
        def compute_gradients(optimizer,
                              loss,
                              var_list=None,
                              gate_gradients=Optimizer.GATE_OP,
                              aggregation_method=None,
                              colocate_gradients_with_ops=False,
                              grad_loss=None):
            """Drop-in replacement for Optimizer.compute_gradients using
            memory_saving_gradients (checkpoints='speed') instead of tf.gradients.

            Mirrors tensorflow.python.training.optimizer.Optimizer.compute_gradients:
            same arguments, same list of (gradient, variable) pairs returned, and
            the same RuntimeError/ValueError conditions. Closes over `self`
            (for get_tensors_to_checkpoint) from the enclosing scope.
            """
            # Eager/callable-loss path: differentiate with a GradientTape.
            if callable(loss):
                from tensorflow.python.eager import backprop
                with backprop.GradientTape() as tape:
                    if var_list is not None:
                        tape.watch(var_list)
                    loss_value = loss()

                    # Scale loss if using a "mean" loss reduction and multiple towers.
                    # Have to be careful to call distribute_lib.get_loss_reduction()
                    # *after* loss() is evaluated, so we know what loss reduction it uses.
                    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
                    if (distribute_lib.get_loss_reduction() ==
                            variable_scope.VariableAggregation.MEAN):
                        num_towers = distribution_strategy_context.get_distribution_strategy(
                        ).num_towers
                        if num_towers > 1:
                            loss_value *= (1. / num_towers)

                if var_list is None:
                    var_list = tape.watched_variables()
                # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
                # to be executed.
                with ops.control_dependencies([loss_value]):
                    grads = tape.gradient(loss_value, var_list, grad_loss)
                return list(zip(grads, var_list))

            # Non-callable/Tensor loss case
            if context.executing_eagerly():
                raise RuntimeError(
                    "`loss` passed to Optimizer.compute_gradients should "
                    "be a function when eager execution is enabled.")

            # Scale loss if using a "mean" loss reduction and multiple towers.
            if (distribute_lib.get_loss_reduction() ==
                    variable_scope.VariableAggregation.MEAN):
                num_towers = distribution_strategy_context.get_distribution_strategy(
                ).num_towers
                if num_towers > 1:
                    loss *= (1. / num_towers)

            if gate_gradients not in [
                    Optimizer.GATE_NONE, Optimizer.GATE_OP,
                    Optimizer.GATE_GRAPH
            ]:
                raise ValueError(
                    "gate_gradients must be one of: Optimizer.GATE_NONE, "
                    "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                    gate_gradients)
            optimizer._assert_valid_dtypes([loss])
            if grad_loss is not None:
                optimizer._assert_valid_dtypes([grad_loss])
            # Default to all trainable variables (plus trainable resource vars).
            if var_list is None:
                var_list = (variables.trainable_variables() +
                            ops.get_collection(
                                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
            else:
                var_list = nest.flatten(var_list)
            # pylint: disable=protected-access
            var_list += ops.get_collection(
                ops.GraphKeys._STREAMING_MODEL_PORTS)
            # pylint: enable=protected-access
            from tensorflow.python.training.optimizer import _get_processor
            processors = [_get_processor(v) for v in var_list]
            if not var_list:
                raise ValueError("No variables to optimize.")
            var_refs = [p.target() for p in processors]
            # original gradients computation
            # grads = tf.gradients(
            #     loss, var_refs, grad_ys=grad_loss,
            #     gate_gradients=(gate_gradients == Optimizer.GATE_OP),
            #     aggregation_method=aggregation_method,
            #     colocate_gradients_with_ops=colocate_gradients_with_ops)
            # using gradient check-pointing
            from memory_saving_gradients import gradients
            # setting outputs of different networks
            # NOTE(review): tensors_to_checkpoint is computed but never passed to
            # gradients() below — confirm whether `checkpoints` should use it.
            tensors_to_checkpoint = self.get_tensors_to_checkpoint()

            # just specifying memory as parameter fails
            grads = gradients(
                loss,
                var_refs,
                grad_ys=grad_loss,
                gate_gradients=(gate_gradients == Optimizer.GATE_OP),
                aggregation_method=aggregation_method,
                colocate_gradients_with_ops=colocate_gradients_with_ops,
                checkpoints='speed')

            if gate_gradients == Optimizer.GATE_GRAPH:
                grads = control_flow_ops.tuple(grads)
            grads_and_vars = list(zip(grads, var_list))
            optimizer._assert_valid_dtypes([
                v for g, v in grads_and_vars
                if g is not None and v.dtype != dtypes.resource
            ])
            return grads_and_vars
Ejemplo n.º 24
0
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, prune_config_flag):
  """Creates an optimizer training op.

  Builds AdamWeightDecay with linear LR decay plus optional linear warmup,
  computes (by default memory-checkpointed) gradients clipped to global norm
  1.0, and optionally attaches pruning mask updates. Returns a grouped
  train op that also increments the global step.
  """
  global_step = tf.train.get_or_create_global_step()

  # Linear decay from init_lr down to 0 over num_train_steps.
  learning_rate = tf.train.polynomial_decay(
      tf.constant(value=init_lr, shape=[], dtype=tf.float32),
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

  # Linear warmup: while global_step < num_warmup_steps the rate is
  # `global_step / num_warmup_steps * init_lr`.
  if num_warmup_steps:
    step_i32 = tf.cast(global_step, tf.int32)
    warmup_i32 = tf.constant(num_warmup_steps, dtype=tf.int32)
    warmup_fraction = (tf.cast(step_i32, tf.float32) /
                       tf.cast(warmup_i32, tf.float32))
    warmup_lr = init_lr * warmup_fraction
    in_warmup = tf.cast(step_i32 < warmup_i32, tf.float32)
    learning_rate = (1.0 - in_warmup) * learning_rate + in_warmup * warmup_lr

  # AdamW matches how the model was pre-trained (note the Adam m/v variables
  # are NOT loaded from init_checkpoint).
  optimizer = AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

  if use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  # Gradient checkpointing trades recompute for memory unless disabled via env.
  # (Set memory_saving_gradients.DEBUG_LOGGING = True to trace the rewrite.)
  tvars = tf.trainable_variables()
  if os.getenv('DISABLE_GRAD_CHECKPOINT'):
    grads = tf.gradients(loss, tvars)
  else:
    grads = memory_saving_gradients.gradients(loss, tvars, checkpoints='memory')

  # Same clipping as pre-training.
  grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)

  train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

  # Optional pruning mask updates.
  if prune_config_flag:
    tf.logging.info(f'Pruning with configs {prune_config_flag}')
    pruner = Pruning(get_pruning_hparams().parse(prune_config_flag),
                     global_step=global_step)
    mask_update_op = pruner.conditional_mask_update_op()
    pruner.add_pruning_summaries()
  else:
    tf.logging.info('No pruning config provided, skipping pruning')
    mask_update_op = tf.no_op()

  # AdamWeightDecayOptimizer does not bump global_step inside
  # `apply_gradients`, so do it explicitly here.
  step_increment = global_step.assign(global_step + 1)
  return tf.group(train_op, mask_update_op, [step_increment])
Ejemplo n.º 25
0
def main():
    
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.batch_size=args.batch_size
    hparams.seq_len=args.seq_len
    
    ##data_path
    args.train_data_path=args.data_dir+args.dataset+'/train.txt'
    args.eval_data_path=args.data_dir+args.dataset+'/dev.txt'
    args.test_data_path=args.data_dir+args.dataset+'/test.txt'
    args.eval_data_path=args.test_data_path                          ###Test mode only!
    args.gpt_save_path=args.gpt_save_dir+args.dataset+'/'
    args.dis_save_path=args.dis_save_dir+args.dataset+'/'
    
    args.gpt_sample_dir2=args.gpt_sample_dir+args.dataset+'/'
    args.dis_sample_dir2=args.dis_sample_dir+args.dataset+'/'
    
    args.log_path=args.log_dir+args.dataset+'/'
    maketree(args.gpt_save_dir)
    maketree(args.dis_save_dir)
    maketree(args.gpt_save_path)
    maketree(args.dis_save_path)
    maketree(args.gpt_sample_dir)
    maketree(args.dis_sample_dir)
    maketree(args.gpt_sample_dir2)
    maketree(args.dis_sample_dir2)
    
    maketree(args.log_dir)
    maketree(args.log_path)
    
    
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        scope_discri='distri'
        
        def get_dis_logit_and_prob_single_step(context, scope):
            """Score whole token sequences with the small discriminator.

            Embeds tokens with a per-scope 32-dim embedding table and feeds
            them to dis(). Returns (logit, prob) with
            prob = sigmoid(logit + 1e-7).
            NOTE(review): dis() is always called with scope=scope_discri while
            the embedding lives under `scope` — confirm this sharing is intended.
            """
            with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
                context=tf.reshape(context, [-1, args.seq_len])
                emb=tf.get_variable(name='emb', initializer=tf.random.normal([hparams.n_vocab, 32], 0, 0.02))
                context_emb=tf.nn.embedding_lookup(emb, context)
                logit=dis(context_emb, scope=scope_discri)
                prob=tf.sigmoid(logit+1e-7)
            return logit, prob
        
        def get_dis_logit_and_prob(context, context_len, scope):
            """Sequence-level discriminator score as a min over prefix scores.

            Every prefix context[:, :i+1] (remainder padded with
            <|endoftext|>) is scored in one batched call; positions at or past
            context_len are pushed up by +1e3 so they cannot be the minimum.
            Returns (min score, exp of it, full per-prefix score matrix).
            NOTE(review): the single-step logit is treated as a log-probability
            (exp'd below) — confirm dis() output scale.
            """
            ##Pay attention to context_len here. temporary changes!!!!!!!!!!!!!!!!!!!
            context_mask=(1-tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32))*1e3
            context_mask2=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32)
            ones=tf.ones(shape=[tf.shape(context_len)[0], args.seq_len], dtype=tf.int32)*enc.encoder['<|endoftext|>']
            # One row group per prefix length: [prefix_count * batch, seq_len].
            input_tensor_list=[]
            for i in range(1, args.seq_len):
                input_tensor_list.append(tf.concat([context[:, :i+1], ones[:,i+1:]], axis=1))
            input_tensor=tf.concat(input_tensor_list, axis=0)
            log_prob, _=get_dis_logit_and_prob_single_step(input_tensor, scope=scope)
            # Back to [batch, seq_len-1]: one score per prefix length.
            log_prob=tf.transpose(tf.reshape(log_prob, [args.seq_len-1, -1]))
            log_prob+=tf.cast(context_mask, tf.float32)
            log_prob_min=tf.reduce_min(log_prob, axis=1)
            prob_min=tf.exp(log_prob_min)
            return log_prob_min, prob_min, log_prob
        ##Build discriminator
        
        def build_dis_layer(scope):
            """Build placeholders, losses, train op and saver for one discriminator layer.

            Positive (real) and negative (generated) batches get sigmoid
            cross-entropy against labels 1/0; the combined loss weights the
            positive term by args.pos_loss_weight. Returns a 13-tuple of the
            placeholders, losses, ops and saver for this layer.
            """
            context_pos_discri = tf.placeholder(tf.int32, [None, args.seq_len])
            context_pos_discri_len = tf.placeholder(tf.int32, [None])
            context_neg_discri = tf.placeholder(tf.int32, [None, args.seq_len])
            context_neg_discri_len = tf.placeholder(tf.int32, [None])
            
            label_pos_discri=tf.ones([tf.shape(context_pos_discri_len)[0]], dtype=tf.float32)
            label_neg_discri=tf.zeros([tf.shape(context_neg_discri_len)[0]], dtype=tf.float32)
            logit_pos_discri, prob_pos_discri, mask=get_dis_logit_and_prob(context_pos_discri, context_pos_discri_len, scope=scope)
            logit_neg_discri, _, _=get_dis_logit_and_prob(context_neg_discri, context_neg_discri_len, scope=scope)
        
            loss_pre_pos_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_pos_discri, logits=logit_pos_discri)
            loss_pos_discri=tf.reduce_mean(loss_pre_pos_discri)
            loss_pre_neg_discri=tf.nn.sigmoid_cross_entropy_with_logits(labels=label_neg_discri, logits=logit_neg_discri)
            loss_neg_discri=tf.reduce_mean(loss_pre_neg_discri)
            # Weighted combination: positive loss scaled by args.pos_loss_weight.
            loss_discri=(loss_pos_discri*args.pos_loss_weight+loss_neg_discri)/(1+args.pos_loss_weight)
        
            # NOTE(review): train_var_list_discri and var_list_discri are the
            # same expression — likely one was meant to filter trainables only.
            train_var_list_discri=[x for x in tf.global_variables() if scope in  x.name]
            train_op_discri=tf.train.AdamOptimizer().minimize(loss_discri, var_list=train_var_list_discri)
            var_list_discri=[x for x in tf.global_variables() if scope in  x.name]
            initializer_discri=tf.variables_initializer(var_list_discri)
            saver_discri=tf.train.Saver(var_list=var_list_discri, max_to_keep=1)
            print('discri: {} build succeed!'.format(scope))
            return context_pos_discri,context_pos_discri_len, context_neg_discri,context_neg_discri_len, loss_pos_discri, loss_neg_discri, loss_discri, train_op_discri, initializer_discri, saver_discri, prob_pos_discri, mask, logit_pos_discri
        
        class dis_class:
            """Cascade of discriminator layers under scope prefix `scope + i`."""
            def __init__(self, layer_num=1, scope=scope_discri):
                """Build layer_num discriminator layers via build_dis_layer."""
                # One dict of placeholders/ops/saver per layer, in order.
                self.model=[]
                # Per-layer score buffer (not used in the methods below).
                self.dis=np.zeros([layer_num], dtype=np.float32)
                print(layer_num)
                for i in range(layer_num):
                    layer={'scope': scope+str(i)}
                    layer['context_pos_discri'],layer['context_pos_discri_len'], layer['context_neg_discri'],layer['context_neg_discri_len'], layer['loss_pos_discri'], layer['loss_neg_discri'], layer['loss_discri'], layer['train_op_discri'], layer['initializer_discri'], layer['saver_discri'], layer['prob_pos_discri'], layer['mask'], layer['logit_pos_discri'] = build_dis_layer(scope+str(i))
                    self.model.append(layer)
            def prob(self, context, context_len, layer=-1):
                """Product of the first `layer` layers' sequence probabilities (-1 = all)."""
                if layer==-1:
                    layer=len(self.model)
                prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32)
                for i in range(layer):
                    item=self.model[i]
                    scope=item['scope']
                    _, prob, _=get_dis_logit_and_prob(context, context_len, scope=scope)
                    prob_final*=prob
                return prob_final
            def log_prob_step(self, context, layer=-1):
                """Per-layer single-step log scores, concatenated to [batch, layer]."""
                if layer==-1:
                    layer=len(self.model)
                # NOTE(review): prob_final is never used in this method.
                prob_final=tf.ones(tf.shape(context)[0], dtype=tf.float32)
                log_prob_list=[]
                for i in range(layer):
                    item=self.model[i]
                    scope=item['scope']
                    log_prob, prob=get_dis_logit_and_prob_single_step(context, scope=scope)
                    log_prob_list.append(tf.expand_dims(log_prob, 1))
                log_prob_final=tf.concat(log_prob_list, axis=1)
                return log_prob_final
        
        Dis=dis_class(layer_num=args.layer_num)
        
        context = tf.placeholder(tf.int32, [None, None])
        context_len=tf.placeholder(tf.int32, [None])
        context_mask=tf.sequence_mask(context_len-1, args.seq_len-1, dtype=tf.float32)
        context_in=context
        output = model.model(hparams=hparams, X=context_in)
        loss_tensor = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1])*context_mask
        
        loss=tf.reduce_sum(loss_tensor, axis=1)/(tf.reduce_sum(context_mask, axis=1)+1e-7)
        loss_sen=tf.reduce_sum(loss)
        loss=tf.reduce_mean(loss)
        
        
        if args.val_every > 0:
            def transform_np(x, lift=args.exponential_param):
                """One-sided quadratic (numpy): 0 for x <= 0.5, lift*(2*(x-0.5))**2 above."""
                shifted = x - 0.5
                rectified = shifted + np.abs(shifted)  # == 2*max(shifted, 0)
                return lift * rectified ** 2
            def transform(x, lift=args.exponential_param):
                """One-sided quadratic (tensorflow): 0 for x <= 0.5, lift*(2*(x-0.5))**2 above."""
                shifted = x - 0.5
                rectified = shifted + tf.abs(shifted)  # == 2*max(shifted, 0)
                return lift * rectified ** 2
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, args.seq_len])
            val_context_len=tf.placeholder(tf.int32, [args.batch_size])
            NLL_bias=tf.placeholder(tf.float32, [])
            val_context_mask=tf.sequence_mask(val_context_len-1, args.seq_len-1, dtype=tf.float32)
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss_tensor =tf.nn.sparse_softmax_cross_entropy_with_logits(labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])*val_context_mask
            val_context_prob_cut=Dis.prob(val_context, val_context_len)
            val_NLL_cut=tf.log(val_context_prob_cut+1e-7)
            
            val_loss=tf.reduce_sum(val_loss_tensor, axis=1)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)
            val_loss_cut=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)/(tf.reduce_sum(val_context_mask, axis=1)+1e-7)-val_NLL_cut/tf.cast(val_context_len, tf.float32)
            
            val_loss_sum=tf.reduce_sum(val_loss_tensor, axis=1)
            val_loss_cut_sum=(tf.reduce_sum(val_loss_tensor, axis=1)+NLL_bias)-val_NLL_cut
            
            val_loss_mean=tf.reduce_mean(val_loss)
            val_loss_cut_mean=tf.reduce_mean(val_loss_cut)
            val_loss_summary = tf.summary.scalar('val_loss', val_loss_mean)


        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.seq_len,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            start_token=enc.encoder['<|endoftext|>'])

        start_token=enc.encoder['<|endoftext|>']

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer:', args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        data_list, data_len = load_dataset(enc, args.train_data_path, args.seq_len)
        data_sampler = Sampler(data_list, data_len )
        if args.val_every > 0:
            val_data_list, val_data_len = load_dataset(enc, args.eval_data_path, args.seq_len)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_data_list, val_data_len, seed=1)
            val_batches = [val_data_sampler.sample(args.batch_size) for _ in range(args.val_batch_count)]

        counter = 0
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Checkpoint model variables and persist the step counter.

            Writes the checkpoint under CHECKPOINT_DIR/run_name/model-<counter>
            and the current counter to the `counter` file so runs can resume.
            """
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
        
        
        def train_step_discri(layer_id=0, mask_train_epoch=0):
            """Run one training step for discriminator layer `layer_id`.

            Positives come from the dataset, negatives from the generator
            (filtered through the previous layers). Returns the combined loss.
            NOTE(review): mask_train_epoch is accepted but unused.
            """
            pos_samples, pos_samples_len=data_sampler.sample(args.batch_size)
            neg_samples=generate_negative_sample(layer_id=layer_id)
            neg_samples_len=get_array_len(neg_samples)
            _, loss=sess.run([Dis.model[layer_id]['train_op_discri'], Dis.model[layer_id]['loss_discri']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples,Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len, Dis.model[layer_id]['context_neg_discri']: neg_samples, Dis.model[layer_id]['context_neg_discri_len']: neg_samples_len})
            return loss
        
        def generate_negative_samples(layer_id, generate_num=args.batch_size):
            """Collect at least `generate_num` negative samples, then truncate.

            Calls generate_negative_sample (one batch per call) until enough
            rows are accumulated, concatenates the batches, and returns exactly
            `generate_num` rows. Removes the dead timing variables (t/t1/t2),
            the unused `samples`/`samples_mem` lists, and the no-op
            arange fancy-index copy of the original.
            """
            batches = []
            collected = 0
            while collected < generate_num:
                batch = generate_negative_sample(layer_id=layer_id)
                batches.append(batch)
                collected += len(batch)
            return np.concatenate(batches, axis=0)[:generate_num]
        
        def get_array_len(sample_array):
            """Return, per row, the index of the first <|endoftext|> after position 0.

            Rows with no terminator get len(row) - 1 (the original loop's
            fall-through value). Fixes a bug where the loop index `i` leaked
            across rows: a row of length < 2 appended the previous row's index,
            or raised NameError if it was the first row.
            """
            eot = enc.encoder['<|endoftext|>']
            lens = []
            for item in sample_array:
                end = max(len(item) - 1, 0)  # fall-through default
                for i in range(1, len(item)):
                    if item[i] == eot:
                        end = i
                        break
                lens.append(end)
            return np.array(lens).astype(np.int32)
        
        def generate_discri_sample3(layer_id=-1, sample_size=10000, save_path='/mnt/cephfs_new_wj/mlnlp/miaoning/Experiment/gpt-2-sep/samples/discri/sample2.txt'):
            """Generate `sample_size` filtered text samples and write them to disk.

            Each generated row is decoded and the text between the first
            <|endoftext|> and the first newline is kept.
            NOTE(review): the default save_path is a hard-coded absolute path —
            pass an explicit path on other machines.
            """
            samples=[]
            while len(samples)<sample_size:
                sample_id=generate_negative_sample(layer_id)
                for i in range(len(sample_id)):
                    sample_tem=enc.decode(sample_id[i]).split('<|endoftext|>')[1].split('\n')[0]
                    samples.append(sample_tem)
                print(len(samples))
            with open(save_path, 'w') as g:
                g.write('\n'.join(samples))
        
        
        def eval_discri_NLL(layer_id=0):
            """Evaluate discriminator layer `layer_id` on the validation batches.

            Returns (mean positive loss, mean negative loss) where positives
            are validation batches and negatives are freshly generated samples.
            """
            losses_pos=[]
            losses_neg=[]
            for batch in tqdm.tqdm(val_batches):
                pos_samples, pos_samples_len=batch
                neg_samples=generate_negative_sample(layer_id=layer_id)
                neg_samples_len=get_array_len(neg_samples)
                loss_pos, mask=sess.run([Dis.model[layer_id]['loss_pos_discri'], Dis.model[layer_id]['mask']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples, Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len})
                #print(mask)
                loss_neg=sess.run(Dis.model[layer_id]['loss_neg_discri'], feed_dict={Dis.model[layer_id]['context_neg_discri']: neg_samples, Dis.model[layer_id]['context_neg_discri_len']: neg_samples_len})
                losses_pos.append(loss_pos)
                losses_neg.append(loss_neg)
            return np.mean(losses_pos), np.mean(losses_neg)
        
        def get_discri_quantile(layer_id=0, quantile=0.85):
            """Estimate the `quantile` threshold of positive-sample logits.

            Returns the logit below which a fraction (1 - quantile) of positive
            validation logits fall.
            NOTE(review): the `break` means only the FIRST validation batch is
            used; the pickle dumps and prints look like leftover debugging.
            """
            logits_list=[]
            for batch in tqdm.tqdm(val_batches):
                pos_samples, pos_samples_len=batch
                logits, mask=sess.run([Dis.model[layer_id]['logit_pos_discri'], Dis.model[layer_id]['mask']], feed_dict={Dis.model[layer_id]['context_pos_discri']: pos_samples, Dis.model[layer_id]['context_pos_discri_len']: pos_samples_len})
                print(np.min(mask, axis=1)[:10])
                print(logits[:10])
                with open('mask.pkl', 'wb') as g:
                    pkl.dump(mask, g)
                logits_list.extend(list(logits))
                break
            with open('logits.pkl', 'wb') as g:
                pkl.dump(sorted(logits_list), g)
            #print(sorted(logits_list))
            print('finish')
            return sorted(logits_list)[int(len(logits_list)*(1-quantile))]
        
        def train_discri(train_step, eval_every, train_layer_list=list(range(len(Dis.model)))):
            """Train each discriminator layer with periodic eval and early stopping.

            Every `eval_every` epochs the layer is evaluated; improvements save
            a per-layer checkpoint, while non-improvements after epoch 200
            count toward an early-stop budget of 4 misses. Returns the best
            (lowest) eval loss seen.
            NOTE(review): eval_loss_old is unbound if train_layer_list is empty
            or train_step == 0; the np.mean of train_losses is computed but
            discarded.
            """
            #sess.run(initializer_discri)
            print('Start Discri training')
            train_losses=[]
            for layer_id in train_layer_list:
                flag=0  # consecutive non-improving evals past epoch 200
                for epoch in range(train_step):
                    if epoch % eval_every==0:
                        train_losses=np.mean(train_losses)
                        train_losses=[]
                    
                        eval_NLL_pos, eval_NLL_neg=eval_discri_NLL(layer_id)
                        eval_loss=(eval_NLL_pos*args.pos_loss_weight+eval_NLL_neg)/(args.pos_loss_weight+1)
                        print('layer_id:{} discri eval loss:{}'.format(layer_id, eval_loss))
                        print('layer_id:{} discri NLL pos: {}, discri NLL neg: {}'.format(layer_id, eval_NLL_pos, eval_NLL_neg))
                        print(epoch)
                        if epoch==0:
                            eval_loss_old=eval_loss
                        else:
                            print(eval_loss, eval_loss_old)
                            if eval_loss<eval_loss_old:
                                eval_loss_old=eval_loss
                                save_path=args.dis_save_path+str(layer_id)+'/'
                                if not os.path.isdir(save_path):
                                    os.mkdir(save_path)
                                Dis.model[layer_id]['saver_discri'].save(sess, save_path+'a')
                                print('model discri saved!')
                                flag=0
                            else:
                                if epoch>=200:
                                    flag+=1
                            if flag>=4:
                                break
                    train_loss=train_step_discri(layer_id)
                    print('layer_id:{} discri train loss:{}'.format(layer_id, train_loss))
                    train_losses.append(train_loss)
            return eval_loss_old
        
        tf_sample_0 = sample_link.sample_sequence(
                    hparams=hparams,
                    length=args.seq_len,
                    context=context,
                    batch_size=args.batch_size,
                    temperature=1.0,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    start_token=enc.encoder['<|endoftext|>'])
        tf_sample_dict={}
        
        def generate_negative_sample(layer_id=0):
            """Draw one batch of negative samples from the generator.

            layer_id == 0 samples from the plain LM graph (tf_sample_0);
            layer_id > 0 samples with rejection filtering through the first
            layer_id discriminator layers (layer_id == -1 means all layers).
            Filtered sampling graphs are built lazily and cached in
            tf_sample_dict. After sampling, every token following the second
            <|endoftext|> is overwritten with <|endoftext|>.

            Fix: the seed/run/post-mask logic was duplicated verbatim in both
            branches of the original; it is factored out here.
            """
            if layer_id == 0:
                tf_sample = tf_sample_0
            else:
                if layer_id == -1:
                    layer_id = len(Dis.model)
                if layer_id in tf_sample_dict:
                    tf_sample = tf_sample_dict[layer_id]
                else:
                    # Build and cache the filtered sampling graph for this depth.
                    tf_sample = sample_link.sample_sequence_ISMC_threshold(
                        Dis=Dis,
                        layer=layer_id,
                        hparams=hparams,
                        length=args.seq_len,
                        context=context,
                        batch_size=args.batch_size,
                        temperature=1.0,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        start_token=enc.encoder['<|endoftext|>'])
                    tf_sample_dict[layer_id] = tf_sample

            # Seed each sequence with the first token of a real training sample.
            seed = data_sampler.sample(args.batch_size)[0][:, 0:1]
            out = sess.run(
                    tf_sample,
                    feed_dict={context: seed})[:, :args.seq_len]
            # Blank out everything after the second <|endoftext|>.
            for i in range(len(out)):
                seen = 0
                for j in range(len(out[i])):
                    if seen == 2:
                        out[i][j] = start_token
                        continue
                    if out[i][j] == start_token:
                        seen += 1
            return out

        def validation():
            """Compute mean LM loss over the fixed validation batches, log it
            to the summary writer, and return it."""
            print('Calculating validation loss...')
            start_time=time.time()
            losses = []
            # (removed unused `rates = []` local)
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss_mean, feed_dict={val_context: batch[0], val_context_len: batch[1]}))
            v_val_loss = np.mean(losses)
            # Re-feed the scalar through the summary op so TensorBoard sees
            # the mean over all batches, not a single-batch value.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss_mean: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))
            return v_val_loss

        def validation_cut(NLL_bias_0=0):
            """Like validation() but evaluates the biased/cut loss
            val_loss_cut_mean with the given NLL bias.  Returns the mean loss
            (no summary is written)."""
            print('Calculating validation loss...')
            # Fix: the original read `start_time` from the enclosing scope
            # (stale or undefined); define it locally like validation() does.
            start_time = time.time()
            losses = []
            # (removed unused `rates = []` local)
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss_cut_mean, feed_dict={val_context: batch[0], val_context_len: batch[1], NLL_bias:NLL_bias_0}))
            v_val_loss = np.mean(losses)
            print(
                '[{counter} | {time:2.2f}] validation cut loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))
            return v_val_loss

        def sample_batch():
            """Draw args.batch_size independent 1024-token chunks from the
            training data sampler."""
            chunks = []
            for _ in range(args.batch_size):
                chunks.append(data_sampler.sample(1024))
            return chunks
        
        def train_gpt():
            """Finetune GPT-2 with early stopping on validation loss.

            Saves a checkpoint each time the validation loss reaches a new
            minimum; breaks out of the loop the first time it fails to
            improve.  Uses free variables from the enclosing scope (sess,
            saver, args, data_sampler, opt_* ops, summaries, summary_log).
            """
            val_loss_old=10000.0  # best validation loss so far (large sentinel)
            avg_loss = (0.0, 0.0)  # (decayed loss sum, decayed count)
            start_time = time.time()
            counter=0
            while True:
                #pretraining
                # Periodic save/sample hooks are disabled (pass); saving is
                # driven by validation improvement below instead.
                if counter % args.save_every == 0:
                    pass
                    #save()
                if counter % args.sample_every == 0:
                    pass
                    #generate_samples()
                # Validate every val_every steps (also at counter 0 and 1).
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    val_loss_1=validation()
                    print(str(counter //args.val_every))
                    if val_loss_1>=val_loss_old:
                        # No improvement: early-stop.
                        print('pre-training ends!')
                        break
                    else:
                        val_loss_old=val_loss_1
                        saver.save(sess, args.gpt_save_path+'a')
                        print('save succeed!')

                if args.accumulate_gradients > 1:
                    # Accumulate gradients over several micro-batches, then
                    # apply them in one step.
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch, batch_len=data_sampler.sample(args.batch_size)
                        sess.run(
                            opt_compute, feed_dict={context: batch, context_len:batch_len})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch, batch_len=data_sampler.sample(args.batch_size)
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch, context_len:batch_len})

                summary_log.add_summary(v_summary, counter)

                # Exponentially decayed running average (decay 0.9);
                # printed value is avg_loss[0] / avg_loss[1].
                avg_loss = (avg_loss[0] * 0.9 + v_loss,
                            avg_loss[1] * 0.9 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        class log_writer:
            """Append-only line logger.

            Creating an instance truncates the target file; each call appends
            one line, optionally echoing it to stdout.
            """
            def __init__(self, path):
                # Remember the target path and start from an empty file.
                self.path = path
                open(path, 'w').close()
            def __call__(self, string, verbose=False):
                # Re-open per call so every line hits disk immediately.
                with open(self.path, 'a') as handle:
                    handle.write(string + '\n')
                if verbose:
                    print(string)
        
        # Pipeline: (1) optionally finetune GPT-2, (2) restore the best
        # checkpoint, (3) optionally evaluate the finetuning baseline and its
        # reverse-ppl, (4) per-layer "tailoring": train/restore the ratio
        # estimator, pick its acceptance quantile, and evaluate.
        # NOTE(review): train_discri, get_discri_quantile,
        # generate_discri_sample3, train.file_f and start_token come from
        # outside this block — confirm they are defined before this runs.
        try:
            if args.finetune:
                #Finetune GPT-2
                train_gpt() 
            if True:
                #Restore Finetuned model
                save_path=tf.train.latest_checkpoint(args.gpt_save_path)
                saver.restore(sess, save_path)
                print('Load gpt2 succeeded!')
            if args.evaluate_finetune:
                #Evaluate finetuning baseline
                print(validation())
            if args.evaluate_finetune:
                #Calculate reverse-ppl for finetuning baseline
                sample_path=args.gpt_sample_dir2+'sample.txt'
                generate_discri_sample3(layer_id=0, sample_size=3000, save_path=sample_path)
                rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path)
                Log_writer=log_writer(args.log_path+'finetune')
                Log_writer('finetuning_rev_ppl: {}'.format(rev_ppl), verbose=True)
            ##Begin tailoring
            if True:
                Log_writer=log_writer(args.log_path+'discri')
                for layer in range(args.layer_num):
                    print(layer)
                    if args.train_tailor:
                        #Train ratio estimator
                        train_discri(500, 10, [layer])
                    if True:
                        #Restore ratio estimator
                        # Restore every layer up to and including the current
                        # one, each from its own checkpoint directory.
                        for layer_id in range(layer+1):
                            save_path=args.dis_save_path+str(layer_id)+'/'
                            print(save_path)
                            save_path=tf.train.latest_checkpoint(save_path)
                            print(save_path)
                            Dis.model[layer_id]['saver_discri'].restore(sess, save_path)
                    if False:
                        #Save quantile for analysis
                        # Dead code kept from an analysis run; note the
                        # pkl.load result is discarded.
                        with open(args.dis_sample_dir2+'quantile.pkl', 'rb') as f:
                            pkl.load(f)
                        print('Load dis model succeeded!')
                    if True:
                        # First layer uses a looser acceptance quantile.
                        if layer==0:
                            quantile=0.85
                        else:
                            quantile=0.9
                        Dis.dis[layer]=get_discri_quantile(layer, quantile)
                        with open(args.dis_sample_dir2+'quantile.pkl', 'wb') as g:
                            pkl.dump(Dis.dis, g)
                        print(Dis.dis)
                    if args.evaluate_tailor:
                        #Generate sample for ERS and calculate reverse-ppl
                        sample_path=args.dis_sample_dir2+'_sample_layer_'+str(layer)
                        generate_discri_sample3(layer_id=layer+1, sample_size=3000, save_path=sample_path)
                        rev_ppl=train.file_f(train_data_path=sample_path, val_data_path=args.eval_data_path)
                        Log_writer('layer: {}, dis_rev_ppl: {}'.format(layer, rev_ppl), verbose=True)
        except KeyboardInterrupt:
            print('interrupted')
Ejemplo n.º 26
0
def finetune(sess,
             dataset,
             steps=-1,
             model_name='124M',
             model_dir='models',
             combine=50000,
             batch_size=1,
             learning_rate=0.0001,
             accumulate_gradients=5,
             restore_from='latest',
             run_name='run1',
             checkpoint_dir='checkpoint',
             sample_every=100,
             sample_length=1023,
             sample_num=1,
             multi_gpu=False,
             save_every=1000,
             print_every=1,
             max_checkpoints=1,
             use_memory_saving_gradients=False,
             only_train_transformer_layers=False,
             optimizer='adam',
             overwrite=False,
             val_dataset=None,
             val_batch_size=2,
             val_batch_count=40,
             val_every=0):
    """Finetunes the model on the given dataset.

    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
    See that file for parameter definitions.

    Builds the training graph in the caller-supplied session, restores
    weights according to `restore_from` ('latest' | 'fresh' | a path), then
    loops until `steps` is reached or KeyboardInterrupt (which saves).  In
    this variant periodic saving is disabled: when `val_every` > 0 a
    checkpoint is written only when validation loss improves.
    """

    # assert model_name not in ['774M', '1558M'] or multi_gpu, "Currently, a modern single GPU cannot finetune the 774M GPT-2 model or larger."

    SAMPLE_DIR = 'samples'

    checkpoint_path = os.path.join(checkpoint_dir, run_name)

    def maketree(path):
        # mkdir -p: the bare except deliberately ignores "already exists"
        # (NOTE: it also hides permission errors).
        try:
            os.makedirs(path)
        except:
            pass

    maketree(checkpoint_path)
    # Snapshot of pre-existing checkpoint files, used by the `overwrite`
    # cleanup further down.
    files = [f for f in os.listdir(checkpoint_path)]
    # Copy the model's vocab/config next to the checkpoints so the run
    # directory is self-contained.
    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
        try:
            shutil.copyfile(os.path.join(model_dir, model_name, file),
                            os.path.join(checkpoint_path, file))
        except FileNotFoundError as fnf_error:
            print(
                "You need to download the GPT-2 model first via download_gpt2()"
            )
            raise (fnf_error)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # For models larger than 124M, force memory-saving gradients, train only
    # the transformer layers, and disable gradient accumulation.
    if model_name not in ['117M', '124M']:
        use_memory_saving_gradients = True
        only_train_transformer_layers = True
        accumulate_gradients = 1

    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    gpus = []

    if multi_gpu:
        gpus = get_available_gpus()

    # Standard LM objective: predict token t+1 from tokens <= t.
    output = model.model(hparams=hparams, X=context, gpus=gpus)
    loss = tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    # validation code
    if val_every > 0:
        val_context = tf.placeholder(tf.int32, [val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context,
                                 reuse=True)  # added reuse=True
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:,
                                   1:], logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(hparams=hparams,
                                       length=sample_length,
                                       context=context,
                                       batch_size=batch_size,
                                       temperature=1.0,
                                       top_k=40)

    all_vars = [
        v for v in tf.compat.v1.trainable_variables() if 'model' in v.name
    ]
    # '/h' matches the transformer blocks (h0, h1, ...).
    train_vars = [v for v in all_vars if '/h' in v.name
                  ] if only_train_transformer_layers else all_vars

    if optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer == 'sgd':
        opt = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    if accumulate_gradients > 1:
        if use_memory_saving_gradients:
            exit(
                "Memory saving gradients are not implemented for gradient accumulation yet."
            )
        opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # NOTE(review): here the summary is taken from opt_apply, i.e. the
        # value returned by AccumulatingOptimizer.apply_gradients() —
        # presumably the accumulated loss; confirm against that class.
        summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
    else:
        if use_memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

    summary_log = tf.compat.v1.summary.FileWriter(checkpoint_path)

    saver = tf.compat.v1.train.Saver(var_list=all_vars,
                                     max_to_keep=max_checkpoints)
    sess.run(tf.compat.v1.global_variables_initializer())

    if restore_from == 'latest':
        ckpt = tf.train.latest_checkpoint(checkpoint_path)
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(model_dir, model_name))
    elif restore_from == 'fresh':
        ckpt = tf.train.latest_checkpoint(os.path.join(model_dir, model_name))
    else:
        # Any other value is treated as a checkpoint directory path.
        ckpt = tf.train.latest_checkpoint(restore_from)
    print('Loading checkpoint', ckpt)
    saver.restore(sess, ckpt)

    print('Loading dataset...')
    chunks = load_dataset(enc, dataset, combine)
    data_sampler = Sampler(chunks)

    # validation code
    if val_every > 0:
        # Fall back to the training chunks when no validation set is given.
        if val_dataset:
            val_chunks = load_dataset(enc, val_dataset, combine)
        else:
            val_chunks = chunks

    print('dataset has', data_sampler.total_size, 'tokens')
    print('Training...')

    # validation code
    if val_every > 0:
        # Sample from validation set once with fixed seed to make
        # it deterministic during training as well as across runs.
        val_data_sampler = Sampler(val_chunks, seed=1)
        val_batches = [[
            val_data_sampler.sample(1024) for _ in range(val_batch_size)
        ] for _ in range(val_batch_count)]

    counter = 1
    counter_path = os.path.join(checkpoint_path, 'counter')
    if os.path.exists(counter_path) and restore_from == 'latest':
        # Load the step number if we're resuming a run
        # Add 1 so we don't immediately try to save again
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        # Write model-{step} plus the counter file used for resuming.
        maketree(checkpoint_path)
        print('Saving',
              os.path.join(checkpoint_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(checkpoint_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def generate_samples():
        # Generate sample_num continuations of a 1-token context and write
        # them to samples/<run_name>/samples-<step>.
        context_tokens = data_sampler.sample(1)
        all_text = []
        index = 0
        while index < sample_num:
            out = sess.run(tf_sample,
                           feed_dict={context: batch_size * [context_tokens]})
            for i in range(min(sample_num - index, batch_size)):
                text = enc.decode(out[i])
                text = '======== SAMPLE {} ========\n{}\n'.format(
                    index + 1, text)
                all_text.append(text)
                index += 1
        # NOTE: prints only the last sample; raises NameError if
        # sample_num < 1 (the loop above never runs).
        print(text)
        maketree(os.path.join(SAMPLE_DIR, run_name))
        with open(
                os.path.join(SAMPLE_DIR, run_name,
                             'samples-{}').format(counter), 'w') as fp:
            fp.write('\n'.join(all_text))

    # validation code
    def validation():
        # Mean validation loss over the fixed val_batches; logged to
        # TensorBoard and returned so the caller can compare runs.
        print('Calculating validation loss...')
        losses = []
        for batch in tqdm(val_batches):
            losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
        v_val_loss = np.mean(losses)
        v_summary = sess.run(val_loss_summary,
                             feed_dict={val_loss: v_val_loss})
        summary_log.add_summary(v_summary, counter)
        summary_log.flush()
        print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.format(
            counter=counter, time=time.time() - start_time, loss=v_val_loss))
        return v_val_loss

    def sample_batch():
        # One training batch: batch_size independent 1024-token chunks.
        return [data_sampler.sample(1024) for _ in range(batch_size)]

    if overwrite and restore_from == 'latest':
        # Wipe previous checkpoints/events from this run dir, then save the
        # freshly restored weights as the new baseline.
        for file in files:
            if file.startswith('model') or file.startswith('events'):
                os.remove(os.path.join(checkpoint_path, file))
        save()

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    #Trying out a change to finetune that saves only when validation loss decreases
    # steps may arrive as a string from a CLI; the default -1 (truthy) means
    # "run until interrupted" since the stop check below requires steps > 0.
    if steps:
        steps = int(steps)

    try:
        while True:
            if steps > 0 and counter == (counter_base + steps):
                #save()
                return
            # if (counter - 1) % save_every == 0 and counter > 1:
            #     save()
            if (counter - 1) % sample_every == 0 and counter > 1:
                generate_samples()

            # validation code
            # Save on the very first step of a fresh run, then only when
            # validation loss improves on the previous best.
            if val_every > 0 and counter == 1:
                v_val_loss = validation()
                save()
            elif val_every > 0 and counter == counter_base:
                v_val_loss = validation()
            elif val_every > 0 and (counter % val_every == 0):
                new_v_val_loss = validation()
                if new_v_val_loss < v_val_loss:
                    v_val_loss = new_v_val_loss
                    save()

            if accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(accumulate_gradients):
                    sess.run(opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summary_loss),
                    feed_dict={context: sample_batch()})

            summary_log.add_summary(v_summary, counter)

            if (counter % print_every == 0) or counter == 1:
                # Exponentially decayed running average (decay 0.99).
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

            counter += 1
    except KeyboardInterrupt:
        print('interrupted')
        save()
Ejemplo n.º 27
0
def train(dataset,
          model_in_path,
          model_out_path,
          model_name='117M',
          steps=1000,
          combine=50000,
          batch_size=1,
          learning_rate=0.00002,
          accumulate_gradients=1,
          memory_saving_gradients=False,
          only_train_transformer_layers=False,
          optimizer='adam',
          noise=0.0,
          top_k=40,
          top_p=0.0,
          restore_from='latest',
          sample_every=100,
          sample_length=1023,
          sample_num=1,
          save_every=1000,
          val_dataset=None):
    """Finetune GPT-2 for `steps` steps, reading weights from model_in_path
    and writing checkpoints (and the step counter) to model_out_path.

    NOTE(review): the `memory_saving_gradients` parameter shadows the module
    of the same name, so the `memory_saving_gradients.gradients(...)` call
    below raises AttributeError ('bool' object has no attribute 'gradients')
    whenever the flag is True and accumulate_gradients == 1 — confirm and
    rename the parameter (e.g. use_memory_saving_gradients).
    """
    # Reset the TF computation graph
    tf.reset_default_graph()
    # Get the checkpoint and sample directories
    #checkpoint_dir = os.path.dirname(model_path)
    #sample_dir = checkpoint_dir
    #run_name = os.path.basename(model_path)
    # Load the encoder
    enc = get_encoder(model_in_path)
    hparams = model.default_hparams()
    with open(os.path.join(model_in_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # Size matters
    if model_name == '345M':
        memory_saving_gradients = True
        if optimizer == 'adam':
            only_train_transformer_layers = True

    # Configure TF
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    # Start the session
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        # Optionally corrupt the inputs with noise (regularization); the
        # loss labels still use the clean context.
        context_in = randomize(context, hparams, noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=1.0,
                                           top_k=top_k,
                                           top_p=top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        # '/h' matches the transformer blocks (h0, h1, ...).
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if only_train_transformer_layers else all_vars

        if optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        elif optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        else:
            # NOTE(review): exit() accepts a single argument; calling it with
            # two raises TypeError instead of printing the message — confirm.
            exit('Bad optimizer:', optimizer)

        if accumulate_gradients > 1:
            if memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if memory_saving_gradients:
                # NOTE(review): `memory_saving_gradients` is the boolean
                # parameter here (the module is shadowed) — this line raises
                # AttributeError when reached; see the docstring.
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            #os.path.join(checkpoint_dir, run_name)
            model_out_path)

        saver = tf.train.Saver(var_list=all_vars, max_to_keep=1)
        sess.run(tf.global_variables_initializer())

        # NOTE(review): the 'latest' and 'fresh' branches below resolve to
        # the same model_in_path checkpoint — likely vestigial from the
        # original checkpoint_dir/run_name layout; confirm intent.
        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                #os.path.join(checkpoint_dir, run_name)
                model_in_path)
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    model_in_path)  #os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                model_in_path)  #os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')
        counter = 1
        counter_path = os.path.join(
            model_in_path,
            'counter')  #os.path.join(checkpoint_dir, run_name, 'counter')
        if restore_from == 'latest' and os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write model-{step} and the counter file to model_out_path.
            # NOTE(review): maketree is not defined in this function; assumed
            # defined elsewhere in the module — confirm.
            #maketree(os.path.join(checkpoint_dir, run_name))
            maketree(model_out_path)
            print(
                'Saving',
                #os.path.join(checkpoint_dir, run_name, 'model-{}').format(counter)
                os.path.join(model_out_path, 'model-{}').format(counter))
            saver.save(
                sess,
                #os.path.join(checkpoint_dir, run_name, 'model'),
                os.path.join(model_out_path, 'model'),
                global_step=counter)
            with open(os.path.join(model_out_path, 'counter'), 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate sample_num continuations of a 1-token context and
            # write them to model_out_path/samples-<step>.
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # NOTE: prints only the last sample; NameError if sample_num < 1.
            print(text)
            #maketree(os.path.join(sample_dir, run_name))
            maketree(model_out_path)
            with open(
                    os.path.join(model_out_path, 'samples-{}').format(counter),
                    'w') as fp:
                fp.write('\n'.join(all_text))

        def sample_batch():
            # One training batch: batch_size independent 1024-token chunks.
            return [data_sampler.sample(1024) for _ in range(batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        # Run exactly `steps` more steps from the (possibly resumed) counter.
        stop = steps + counter

        try:
            while counter < stop + 1:
                if counter % save_every == 0:
                    save()
                '''
                if counter % sample_every == 0:
                    generate_samples()
                '''

                if accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                # Exponentially decayed running average (decay 0.99).
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
            print('done!')
            save()
        except KeyboardInterrupt:
            print('interrupted')
            save()
Ejemplo n.º 28
0
def main():
    """Train a class-conditional PixelCNN++ model across multiple GPUs.

    Builds the data pipeline, replicates the model on ``args.nr_gpu`` devices
    (optionally reusing a pre-built graph via ``GraphTemplate``), averages the
    per-GPU losses/gradients on gpu:0, and runs the train/eval/sample loop.
    Relies on module-level ``args``, ``rng``, ``nn``, ``model_spec`` and
    ``gradients``.
    """
    # initialize data loaders for train/test splits
    if args.data_set == 'imagenet' and args.class_conditional:
        # NOTE: raising a bare string is a TypeError in Python 3; use a real
        # exception class so the intended message actually reaches the user.
        raise ValueError("We currently don't have labels for the small imagenet data set")
    if args.data_set == 'cifar':
        import data.cifar10_data as cifar10_data
        DataLoader = cifar10_data.DataLoader
    elif args.data_set == 'imagenet':
        import data.imagenet_data as imagenet_data
        DataLoader = imagenet_data.DataLoader
    else:
        raise ValueError("unsupported dataset")
    train_data = DataLoader(args.data_dir, 'train', args.batch_size * args.nr_gpu, rng=rng, shuffle=True, return_labels=args.class_conditional)
    test_data = DataLoader(args.data_dir, 'test', args.batch_size * args.nr_gpu, shuffle=False, return_labels=args.class_conditional)
    obs_shape = train_data.get_observation_size() # e.g. a tuple (32,32,3)
    assert len(obs_shape) == 3, 'assumed right now'

    # data place holders
    x_init = tf.placeholder(tf.float32, shape=(args.init_batch_size,) + obs_shape)
    xs = [tf.placeholder(tf.float32, shape=(args.batch_size, ) + obs_shape) for i in range(args.nr_gpu)]

    # if the model is class-conditional we'll set up label placeholders + one-hot encodings 'h' to condition on
    if args.class_conditional:
        num_labels = train_data.get_num_labels()
        y_init = tf.placeholder(tf.int32, shape=(args.init_batch_size,))
        h_init = tf.one_hot(y_init, num_labels)
        # fixed label pattern used when drawing samples (cycles through classes)
        y_sample = np.split(np.mod(np.arange(args.batch_size*args.nr_gpu), num_labels), args.nr_gpu)
        h_sample = [tf.one_hot(tf.Variable(y_sample[i], trainable=False), num_labels) for i in range(args.nr_gpu)]
        ys = [tf.placeholder(tf.int32, shape=(args.batch_size,)) for i in range(args.nr_gpu)]
        hs = [tf.one_hot(ys[i], num_labels) for i in range(args.nr_gpu)]
    else:
        h_init = None
        h_sample = [None] * args.nr_gpu
        hs = h_sample

    # create the model
    model_opt = { 'nr_resnet': args.nr_resnet, 'nr_filters': args.nr_filters, 'nr_logistic_mix': args.nr_logistic_mix, 'resnet_nonlinearity': args.resnet_nonlinearity}
    model = tf.make_template('model', model_spec)

    # run once for data dependent initialization of parameters
    data_dependent_init = model(x_init, h_init, init=True, dropout_p=args.dropout_p, **model_opt)

    # keep track of moving average
    all_params = tf.trainable_variables()
    ema = tf.train.ExponentialMovingAverage(decay=args.polyak_decay)
    maintain_averages_op = tf.group(ema.apply(all_params))
    ema_params = [ema.average(p) for p in all_params]

    # get loss gradients over multiple GPUs + sampling
    grads = []
    loss_gen = []
    loss_gen_test = []
    new_x_gen = []
    for i in range(args.nr_gpu):
        with tf.device('/gpu:%d' % i):
            if args.graph_cloning and i>0:
                # already defined the graph once, use it again via template rather than redefining again
                in_ = [xs[i]] + tf.global_variables()
                res = gpu_template.apply(in_)
                loss_train, loss_test, sx = res[:3]
                grad = res[3:]

                loss_gen.append(loss_train)
                loss_gen_test.append(loss_test)
                new_x_gen.append(sx)
                grads.append(grad)

            else:
                # train
                out = model(xs[i], hs[i], ema=None, dropout_p=args.dropout_p, **model_opt)
                loss_gen.append(nn.discretized_mix_logistic_loss(tf.stop_gradient(xs[i]), out))

                # gradients
                grads.append(gradients(loss_gen[i], all_params))

                # test (uses EMA weights, no dropout)
                out = model(xs[i], hs[i], ema=ema, dropout_p=0., **model_opt)
                loss_gen_test.append(nn.discretized_mix_logistic_loss(xs[i], out))

                # sample
                out = model(xs[i], h_sample[i], ema=ema, dropout_p=0, **model_opt)
                new_x_gen.append(nn.sample_from_discretized_mix_logistic(out, args.nr_logistic_mix))

                if args.graph_cloning:
                    # capture the graph built for GPU 0 so later GPUs can reuse it
                    in_ = [xs[0]] + tf.global_variables()
                    out_ = [loss_gen[0], loss_gen_test[0], new_x_gen[0]] + grads[0]
                    gpu_template = GraphTemplate(in_, outputs=out_)

    # add losses and gradients together and get training updates
    tf_lr = tf.placeholder(tf.float32, shape=[])
    with tf.device('/gpu:0'):
        for i in range(1,args.nr_gpu):
            loss_gen[0] += loss_gen[i]
            loss_gen_test[0] += loss_gen_test[i]
            for j in range(len(grads[0])):
                grads[0][j] += grads[i][j]
        # training op
        optimizer = tf.group(nn.adam_updates(all_params, grads[0], lr=tf_lr, mom1=0.95, mom2=0.9995), maintain_averages_op)

    # convert loss to bits/dim
    bits_per_dim = loss_gen[0]/(args.nr_gpu*np.log(2.)*np.prod(obs_shape)*args.batch_size)
    bits_per_dim_test = loss_gen_test[0]/(args.nr_gpu*np.log(2.)*np.prod(obs_shape)*args.batch_size)

    # sample from the model, one pixel at a time (autoregressive decoding)
    def sample_from_model(sess):
        x_gen = [np.zeros((args.batch_size,) + obs_shape, dtype=np.float32) for i in range(args.nr_gpu)]
        for yi in range(obs_shape[0]):
            for xi in range(obs_shape[1]):
                new_x_gen_np = sess.run(new_x_gen, {xs[i]: x_gen[i] for i in range(args.nr_gpu)})
                for i in range(args.nr_gpu):
                    x_gen[i][:,yi,xi,:] = new_x_gen_np[i][:,yi,xi,:]
        return np.concatenate(x_gen, axis=0)

    # turn numpy inputs into feed_dict for use with tensorflow
    def make_feed_dict(data, init=False):
        if type(data) is tuple:
            x,y = data
        else:
            x = data
            y = None
        x = np.cast[np.float32]((x - 127.5) / 127.5) # input to pixelCNN is scaled from uint8 [0,255] to float in range [-1,1]
        if init:
            feed_dict = {x_init: x}
            if y is not None:
                feed_dict.update({y_init: y})
        else:
            x = np.split(x, args.nr_gpu)
            feed_dict = {xs[i]: x[i] for i in range(args.nr_gpu)}
            if y is not None:
                y = np.split(y, args.nr_gpu)
                feed_dict.update({ys[i]: y[i] for i in range(args.nr_gpu)})
        return feed_dict

    # //////////// perform training //////////////
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    print('starting training')
    test_bpd = []
    lr = args.learning_rate
    saver = tf.train.Saver()
    with tf.Session() as sess:
        for epoch in range(args.max_epochs):
            begin = time.time()

            # init
            if epoch == 0:
                feed_dict = make_feed_dict(train_data.next(args.init_batch_size), init=True) # manually retrieve exactly init_batch_size examples
                train_data.reset()  # rewind the iterator back to 0 to do one full epoch
                print('initializing the model...')
                sess.run(tf.global_variables_initializer())
                sess.run(data_dependent_init, feed_dict)

            # train for one epoch
            train_losses = []
            counter = 0
            for d in train_data:
                counter+=1
                feed_dict = make_feed_dict(d)
                # forward/backward/update model on each gpu
                lr *= args.lr_decay
                feed_dict.update({ tf_lr: lr })
                l,_ = sess.run([bits_per_dim, optimizer], feed_dict)
                print(counter, l)
                train_losses.append(l)
                # regression-test shortcut: bail out after 50 iterations
                if counter>50:
                    if l>6.5:
                        assert False, "Test failed, expected loss 6.28537 at iteration 50"
                    else:
                        print("Test passed, loss %f (expected %f)"%(l, 6.28537))
                        sys.exit()
            train_loss_gen = np.mean(train_losses)

            # compute likelihood over test data
            test_losses = []
            for d in test_data:
                feed_dict = make_feed_dict(d)
                l = sess.run(bits_per_dim_test, feed_dict)
                test_losses.append(l)
            test_loss_gen = np.mean(test_losses)
            test_bpd.append(test_loss_gen)

            # log progress to console
            print("Iteration %d, time = %ds, train bits_per_dim = %.4f, test bits_per_dim = %.4f" % (epoch, time.time()-begin, train_loss_gen, test_loss_gen))
            sys.stdout.flush()

            if epoch % args.save_interval == 0:

                # generate samples from the model
                sample_x = []
                for i in range(args.num_samples):
                    sample_x.append(sample_from_model(sess))
                sample_x = np.concatenate(sample_x,axis=0)
                np.savez(os.path.join(args.save_dir,'%s_sample%d.npz' % (args.data_set, epoch)), sample_x)

                # save params
                saver.save(sess, args.save_dir + '/params_' + args.data_set + '.ckpt')
                np.savez(args.save_dir + '/test_bpd_' + args.data_set + '.npz', test_bpd=np.array(test_bpd))
# ---- Example 29 ----
def main():
    """Fine-tune a GPT-2 style language model.

    Parses CLI args, configures hparams/dtype, builds the training graph
    (optionally with memory-saving gradients or gradient accumulation),
    restores a checkpoint via tflex, then runs the training loop with
    periodic saving, sampling and validation. Relies on module-level
    ``parser``, ``encoder``, ``model``, ``sample``, ``tflex``, ``tflex_sgdr``,
    ``memory_saving_gradients``, ``AccumulatingOptimizer``, ``randomize``,
    ``load_dataset``, ``Sampler``, ``TextSampler``, ``maketree``,
    ``timestamp``, ``CHECKPOINT_DIR`` and ``SAMPLE_DIR``.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout
    # epsilon is the logit mask value used during sampling; it must stay
    # representable in the chosen dtype (half precision saturates near 6.5e4).
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        print('Unknown dtype', args.dtype)
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    # optional CLI overrides of the loaded hyperparameters
    if args.n_ctx >= 0:
        hparams.n_ctx=args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd=args.n_embd
    if args.n_head >= 0:
        hparams.n_head=args.n_head
    if args.n_layer >= 0:
        hparams.n_layer=args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
      args.sample_ctx = hparams.n_ctx

    # the 345M model does not fit in memory without these workarounds
    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        # training graph: next-token cross-entropy over (optionally noised) context
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)


        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
        print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(args.learning_rate, global_step,
                    warmup_steps=args.learning_rate_warmup, initial_period_steps=args.learning_rate_period, learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable('learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
          """Refresh the learn-rate variable (no-op under cosine schedule); return its current value."""
          if not args.learning_rate_cos:
            if step is None:
              step = global_step.eval(session=sess)
            if rate is None:
              rate = args.learning_rate
            if callable(rate):
              rate = rate(step)
            lr.load(rate, session=sess)
          return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
          """Interactively prompt for a new learning rate."""
          print("Current learn rate: %0.8f" % update_lr())
          print("New learn rate?")
          rate = input('')
          # bail out early on empty/invalid input; the original fell through
          # and formatted a non-float with %0.8f, crashing the command.
          if not rate:
            print("Empty input; not changing anything.")
            return
          try:
            rate = float(rate)
          except ValueError:
            print("Invalid input; must be a float")
            return
          print("Setting learn rate to %0.8f" % rate)
          args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(learning_rate=lr, hparams=ada_hparams)
        else:
            # exit() accepts a single argument; the original passed two,
            # raising a TypeError instead of printing the message.
            exit('Bad optimizer: %s' % args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        def make_sampler(dataset, enc, seed, combine):
          """Build a token sampler for a directory/.npz dataset or a plain text file."""
          if os.path.isdir(dataset) or dataset.endswith('.npz'):
            chunks = load_dataset(enc, dataset, combine)
            data_sampler = Sampler(chunks, seed=seed)
            print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
          else:
            data_sampler = TextSampler(dataset, enc, seed=seed)
          return data_sampler

        print('Loading dataset...')
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed, combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc, seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'

            return tarfile_name


        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            checkpoint_folder = os.path.join('checkpoint', run_name)

            if copy_folder:
                shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)

                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)

                shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            """Save a checkpoint and persist the current step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            t0 = time.time()
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        @tflex.register_command
        def generate_samples():
            """Generate args.sample_num text samples and write them to SAMPLE_DIR."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            """Compute mean loss over the fixed validation batches and log it."""
            if args.val_every <= 0:
              return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            return time.time() - start_time

        def say(msg):
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            """Sample one training batch of args.sample_ctx tokens per example."""
            r = []
            times = []
            for _ in range(args.batch_size):
                start = time.time()
                sample = data_sampler.sample(args.sample_ctx)
                end = time.time()
                elapsed = (end - start)
                r += [sample]
                times += [elapsed]
            total = sum(times)
            avg = total / len(times)
            return r

        prev_time = time.time()
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        # main training loop: save/sample/validate on schedule, then one
        # optimizer step (accumulated or direct), then logging/bookkeeping.
        last_saved_time = elapsed()
        while True:
            try:
                now = elapsed()
                if args.save_time > 0 and (((now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    # NOTE(review): this adds a new cast op to the graph every
                    # iteration, slowly growing the graph — consider hoisting.
                    v_loss = tf.to_float(v_loss).eval()

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                # exponentially decayed running average of the loss
                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                    .format(
                        stamp=timestamp(),
                        counter=counter,
                        time=now - start_time,
                        delta=now - prev_time,
                        ops=args.sample_ctx * args.batch_size / (now - prev_time),
                        rate=v_rate,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                        step=current_step,
                        ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                    )
                if tflex.should_quit():
                  break

                prev_time = now
                if args.debug_print_all_vars:
                    print('all variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.all_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_all_vars = False

                if args.debug_print_trainable_vars:
                    print('trainable variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.trainable_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    import pdb
                    pdb.set_trace()
                else:
                    break
 def gradients_auto(ys, xs, grad_ys=None, **kwargs):
   """Drop-in replacement for tf.gradients with automatic checkpointing.

   Delegates to memory_saving_gradients.gradients using
   checkpoints='memory', forwarding any extra keyword arguments.
   """
   return memory_saving_gradients.gradients(
       ys, xs, grad_ys, checkpoints='memory', **kwargs)
# ---- Example 31 ----
 def compute_gradients(loss, var_list, checkpoints='collection'):
     """Compute memory-saving gradients of ``loss`` w.r.t. ``var_list``.

     Returns a list of (gradient, variable) pairs, matching the format of
     tf.train.Optimizer.compute_gradients.
     See https://github.com/openai/gradient-checkpointing
     """
     from memory_saving_gradients import gradients
     gradient_list = gradients(loss, var_list, checkpoints=checkpoints)
     return list(zip(gradient_list, var_list))
# ---- Example 32 ----
def main():
    """Fine-tune GPT-2 on a from/question/to triplet dataset.

    Builds the training and (optionally) validation graphs, restores model
    weights from a checkpoint, and prepares the data samplers used by the
    training loop defined further down in this function.
    """
    args = parser.parse_args()
    enc = get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # The 345M model is forced to use checkpointed gradients and to train
    # only the transformer layers.
    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # NOTE(review): disabling the layout optimizer here presumably avoids
    # interference with memory-saving gradients -- confirm.
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    acc_total = 0
    acc_over_time = []       # validation accuracy recorded at each validation
    loss_avg_over_time = []  # running average loss recorded at each validation

    if args.val_every > 0:

        # val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
        val_context = tf.placeholder(np.int32, [1, None])

        val_output = model.model(hparams=hparams, X=val_context)
        # Next-token language-modeling loss on the validation context.
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:,
                                   1:], logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        # One token per call; with top_k=1 sampling is effectively greedy, so
        # the unusual temperature value has little practical effect.
        tf_sample_val = sample.sample_sequence(
            hparams=hparams,
            length=1,  #args.sample_length,
            context=val_context,
            batch_size=1,  #args.batch_size,
            temperature=10.001,
            top_k=1)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        # '/h' matches the transformer block scopes (model/h0, model/h1, ...).
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars
        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(
                opt=tf.train.AdamOptimizer(learning_rate=args.learning_rate),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            # NOTE(review): this assumes opt.apply_gradients() returns the
            # mean loss tensor -- confirm against AccumulatingOptimizer.
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join(model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading train dataset...')
        from_name, ques_name, to_name = name_parts(
            args.dataset)  #'../data/train.from')

        # NOTE(review): the `else chunks` fallbacks reference a name whose
        # definition is commented out below -- verify these branches are
        # unreachable (i.e. args.val_dataset is always set).
        trn_chunks_from = load_dataset(
            enc, from_name, args.combine) if args.val_dataset else chunks
        trn_chunks_ques = load_dataset(
            enc, ques_name, args.combine) if args.val_dataset else chunks
        trn_chunks_to = load_dataset(
            enc, to_name, args.combine) if args.val_dataset else chunks

        skip_delimeter = True
        trn_data_sampler_from = SamplerVal(trn_chunks_from,
                                           enc,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_ques = SamplerVal(trn_chunks_ques,
                                           enc,
                                           skip_delimeter=skip_delimeter)
        trn_data_sampler_to = SamplerVal(trn_chunks_to,
                                         enc,
                                         skip_delimeter=skip_delimeter)

        # Each training example is "<from> <question>. <to>"; examples that
        # leave no room for generation are skipped.
        data_sampler = []
        for i in range(trn_data_sampler_from.total_size):
            v = (
                trn_data_sampler_from.get(i) + trn_data_sampler_ques.get(i) +
                enc.encode('. ') + trn_data_sampler_to.get(i)  # +
                #enc.encode('<|endoftext|>')
            )
            # v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v) )]
            if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                continue
            v = v[:HIDDEN_SIZE - 1]
            data_sampler.append(v)
            pass

        #chunks = load_dataset(enc, args.dataset, args.combine)
        if not args.train_special:
            data_sampler = Sampler([np.array(data_sampler)])

        if args.val_every > 0:
            print('Loading validation dataset...')
            #val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks

            from_name, ques_name, to_name = name_parts(args.val_dataset)

            val_chunks_from = load_dataset(
                enc, from_name, args.combine) if args.val_dataset else chunks
            val_chunks_ques = load_dataset(
                enc, ques_name, args.combine) if args.val_dataset else chunks
            val_chunks_to = load_dataset(
                enc, to_name, args.combine) if args.val_dataset else chunks

        if not args.train_special:
            print('train dataset has', data_sampler.total_size, 'tokens')
        else:
            print('train dataset has', len(data_sampler), 'tokens')
        print('Training...')

        if args.val_every > 0:

            val_data_sampler_from = SamplerVal(val_chunks_from, enc)
            val_data_sampler_ques = SamplerVal(val_chunks_ques, enc)
            val_data_sampler_to = SamplerVal(val_chunks_to, enc)

            # -1 means "use the entire validation set".
            if args.val_batch_count == -1:
                args.val_batch_count = val_data_sampler_from.total_size

            # Validation prompts omit the answer ("to") part; the model must
            # generate it during validation.
            val_batches = []
            for i in range(args.val_batch_count):
                v = (val_data_sampler_from.get(i) +
                     val_data_sampler_ques.get(i) + enc.encode('. ')
                     )  #+ val_data_sampler_to.get(i)

                #v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v) )]
                if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                    continue
                v = v[:HIDDEN_SIZE]
                val_batches.append(v)
                pass

        # NOTE(review): val_batches is undefined here when val_every <= 0 --
        # this print (and the loop below) would raise NameError; confirm
        # val_every is always positive in practice.
        print('val dataset has', len(val_batches), 'tokens')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        txt_file_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                     args.run_name + '.summary.txt')
        def save_summary(message=None):
            """Append a progress summary (or an explicit *message*) to the
            run's summary file and echo it to stdout."""
            if message is not None:
                txt = message
            else:
                fmt = '{valid:2.2f}'
                pieces = []
                # First write for this run: record a header with the args.
                if not os.path.exists(txt_file_path):
                    pieces.append('Summary for ' + args.run_name + '\n')
                    pieces.append(str(datetime.datetime.now()) + '\n\n')
                    pieces.append(json.dumps(vars(args)) + '\n')
                    pieces.append('-----\n')
                pieces.append(str(datetime.datetime.now()) + '\n')
                pieces.append('acc: ' + ', '.join(
                    fmt.format(valid=i) for i in acc_over_time) + '\n')
                pieces.append('loss: ' + ', '.join(
                    fmt.format(valid=i) for i in loss_avg_over_time) + '\n')
                pieces.append('counter: ' + str(counter) + '\n')
                pieces.append('time elapsed: ' +
                              str(time.time() - start_time) + '\n')
                pieces.append('-----\n')
                txt = ''.join(pieces)
            print(txt)
            with open(txt_file_path, 'a') as f:
                f.write(txt + '\n')

        def save():
            """Write a model checkpoint and persist the step counter.

            No-op in test mode.
            """
            if args.test:
                return

            run_dir = os.path.join(CHECKPOINT_DIR, args.run_name)
            maketree(run_dir)
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        # Dead code: the former sample generator, disabled by turning it into
        # a no-op string literal; kept for reference.
        '''
        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))
        '''

        def print_status(word=None,
                         acc_total_in=0,
                         size=0,
                         v_loss_in=0.0,
                         shorten=False):
            """Print one progress line: run name, step, elapsed time, current
            and average loss, and accuracy so far."""
            v_loss = v_loss_in
            acc_out = 0
            acc_total = 0
            loss_out = 0.0
            if word is None:
                word = 'progress'
            if acc_total_in != 0 and size != 0:
                acc_out = acc_total_in / size * 100
                acc_total = size
                pass
            # Report zero loss until any loss has been accumulated.
            if avg_loss[1] == 0.0 or avg_loss[0] == 0.0:
                loss_out = 0.0
                v_loss = 0.0
                pass
            # NOTE(review): `or True` makes this branch unconditional when
            # reached; the isnan guard is effectively disabled -- confirm
            # whether that is intended.
            elif not np.isnan(
                    avg_loss[0]) or True:  # and not np.isnan(avg_loss[1]):
                loss_out = avg_loss[0] / avg_loss[1]
            print(word + ' [' + args.run_name + ']' +
                  ' [{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_loss,
                         avg=loss_out),
                  'acc=' + str(acc_out),
                  end=' ')
            print('total=' + str(acc_total), end=' ')
            if len(acc_over_time) > 0 and not shorten:
                print('last-acc=' + str(acc_over_time[-1]))
            else:
                print()
            pass

        def sample_batch(counter=0, randomize=False, pad_start=False):
            """Return one batch of token sequences for training.

            In the default mode, draws random HIDDEN_SIZE windows from the
            Sampler.  In train_special mode, returns the example at index
            *counter*, optionally left-padded with space tokens and with a
            random number (1-4) of trailing tokens dropped.
            """
            #print(enc.encode('<|endoftext|>'), 'eot')
            #print(data_sampler.sample(1024))
            if not args.train_special:
                return [
                    data_sampler.sample(HIDDEN_SIZE)[0]
                    for _ in range(args.batch_size)
                ]
            else:
                num = 0
                z = []
                # Retry (up to 5 times) until the batch is non-empty and fits.
                while (len(z) > HIDDEN_SIZE or len(z) == 0) and num <= 5:
                    if randomize:
                        r = random.randint(1, 4)
                    else:
                        r = 0
                    #print('train special', r)
                    if pad_start:
                        pad = HIDDEN_SIZE - (len(data_sampler[counter]) - r)
                    else:
                        pad = 0

                    if randomize:
                        z = [[enc.encode(' ')[0]
                              for _ in range(pad)] + data_sampler[counter][:-r]
                             for _ in range(args.batch_size)]

                    if not randomize:
                        z = [data_sampler[counter]]
                    #print(enc.decode(z[0]))
                    num += 1
                    if num == 5:
                        # NOTE(review): this slices the list of sequences, not
                        # a token list -- truncating z[0] may have been the
                        # intent; confirm before relying on this fallback.
                        z = z[len(z) - HIDDEN_SIZE:]
                        print('cannot get sample_batch')
                        break

                return z

        def validation_by_sample():
            """Validate by generating answers token-by-token.

            Optionally computes validation loss over val_batches, then for
            each validation prompt generates GENERATE_SIZE tokens and counts
            a hit when the (normalized) generated text matches the reference
            answer.  Returns the number of correct samples (also stored in
            the global acc_total).
            """
            print('Generating validation...')
            global acc_total
            if args.val_with_loss:
                losses = []
                for batch in tqdm.tqdm(val_batches):
                    batch = np.reshape(batch, [1, -1])
                    v = sess.run(val_loss, feed_dict={val_context: batch})
                    #print(v, 'v')
                    losses.append(v)
                v_val_loss = np.mean(losses)
                # Feed the scalar back through val_loss to emit its summary.
                v_summary = sess.run(val_loss_summary,
                                     feed_dict={val_loss: v_val_loss})
                summary_log.add_summary(v_summary, counter)
                summary_log.flush()
                print(
                    '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                    format(counter=counter,
                           time=time.time() - start_time,
                           loss=v_val_loss))
            acc_total = 0
            generated = 0
            for _ in range(len(val_batches)):

                val_batches_in = val_batches[generated]
                val_batches_in = val_batches_in[:1024]
                context_tokens = np.reshape(val_batches_in, [1, -1])
                #print(val_batches_in)
                text_in = enc.decode(val_batches_in)
                #print(text_in)

                # Autoregressive generation: one token per run, fed back in.
                #print(context_tokens, 'ct1')
                for x in range(GENERATE_SIZE):

                    out = sess.run(tf_sample_val,
                                   feed_dict={val_context: context_tokens})
                    #print(out[0][-x:])
                    #print(enc.decode(out[0][-x:]))
                    context_tokens = out

                compare = enc.decode(
                    val_data_sampler_to.get(generated))  # + ' <|endoftext|>'
                compare = ' '.join(compare.split(' '))

                generated += 1

                text = enc.decode(out[0])

                # Strip the prompt so only the generated continuation remains.
                text_returned = ''
                text_original = ''
                if text.startswith(text_in):
                    text_returned = text[len(text_in):]
                    #print('-',text_returned,'-')

                if args.train_special:
                    text_original = text
                    text = text_returned

                if text.strip().endswith('.'):  ## remove trailing period
                    text = text.strip()[:-1]

                if text.strip().endswith('<|endoftext|>'):
                    text = text.strip()[:-len('<|endoftext|>')]

                # Drop a trailing partial special token and empty fragments.
                t_vals = text.split(' ')
                if '<' in t_vals[-1] or '>' in t_vals[-1]:
                    t_vals = t_vals[:-1]

                num = 0
                while t_vals[-1] == '' and num < 10:
                    t_vals = t_vals[:-1]
                    num += 1

                #print(t_vals)
                t_vals = [i for i in t_vals if i != '']
                #print(t_vals)

                text = ' '.join(t_vals)

                # Normalize the reference answer the same way.
                if compare.strip().endswith('.'):
                    compare = compare.strip()[:-1]

                if compare.strip().endswith('<|endoftext|>'):
                    compare = compare.strip()[:-len('<|endoftext|>')]

                notification = ''
                len_bar = 40
                if text.strip().lower().endswith(compare.strip().lower()):
                    acc_total += 1
                    notification = 'vv CORRECT vv'
                    len_bar = 40 - len(notification)
                elif text_returned.strip().lower().startswith(
                        compare.strip().lower()):
                    acc_total += 1
                    notification = 'vv CORRECT_INITIAL vv'
                    len_bar = 40 - len(notification)

                print(notification + "=" * len_bar + " SAMPLE " +
                      str(generated) + " " + "=" * len_bar + notification)
                if args.train_special:
                    print(text_original)
                else:
                    print(text)
                print_status('old values',
                             acc_total_in=acc_total,
                             size=generated)
            print("=" * 80)
            return acc_total
            pass

        # avg_loss is an exponentially decayed (numerator, denominator) pair;
        # the running average loss is avg_loss[0] / avg_loss[1].
        avg_loss = (0.0, 0.0)
        start_time = time.time()
        count_success = 0
        count_success_with_skips = 0
        acc = 0.0

        try:
            # Test mode: evaluate once on the test split and exit.
            if args.test:
                v_loss = 0.0
                dataset = re.sub('train', 'test', args.dataset)
                print(dataset)
                from_name, ques_name, to_name = name_parts(dataset)

                test_chunks_from = load_dataset(enc, from_name, args.combine)
                test_chunks_ques = load_dataset(enc, ques_name, args.combine)
                test_chunks_to = load_dataset(enc, to_name, args.combine)

                # Rebind the validation samplers to the test split so
                # validation_by_sample() evaluates on it.
                val_data_sampler_from = SamplerVal(test_chunks_from, enc)
                val_data_sampler_ques = SamplerVal(test_chunks_ques, enc)
                val_data_sampler_to = SamplerVal(test_chunks_to, enc)

                if args.val_batch_count == -1:
                    args.val_batch_count = val_data_sampler_from.total_size

                val_batches = []
                for i in range(args.val_batch_count):
                    v = (val_data_sampler_from.get(i) +
                         val_data_sampler_ques.get(i) + enc.encode('. ')
                         )  # + val_data_sampler_to.get(i)

                    # v += [enc.encode(' ')[0] for _ in range(HIDDEN_SIZE - len(v) )]
                    if len(v) >= HIDDEN_SIZE - GENERATE_SIZE:
                        continue
                    val_batches.append(v)

                acc_total = validation_by_sample()
                acc = acc_total / len(val_batches) * 100

                print(acc, 'test accuracy')
                save_summary('Accuracy with test set ' + str(acc) + '\n')
                exit()

            while counter != args.stop_after:
                #model_summary()

                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    #generate_samples()
                    pass
                if args.val_every > 0 and (counter % args.val_every
                                           == 0):  # or counter == 1):
                    acc_total = validation_by_sample()
                    acc = acc_total / len(val_batches) * 100

                    acc_over_time.append(acc)
                    if avg_loss[1] > 0.0:
                        loss_avg_over_time.append(avg_loss[0] / avg_loss[1])
                    else:
                        loss_avg_over_time.append(0)

                # Cycle through examples in order (train_special mode uses
                # counter_in as a dataset index).
                counter_in = counter % len(val_batches)
                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch(counter_in)})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch(counter_in)})

                summary_log.add_summary(v_summary, counter)

                # Update the decayed loss average unless it has gone NaN.
                #if True:
                if not np.isnan(avg_loss[0]) and not np.isnan(avg_loss[1]):

                    avg_loss = (avg_loss[0] * 0.99 + v_loss,
                                avg_loss[1] * 0.99 + 1.0)

                # Early-stop once validation accuracy hits 100% repeatedly.
                if counter % args.val_every == 1:
                    if float(acc) == 100.0:
                        #save()

                        print('validation accuracy 100',
                              time.time() - start_time)
                        count_success += 1
                        count_success_with_skips += 1
                        if count_success >= 2 or count_success_with_skips >= 4:
                            #save_summary()

                            exit()
                    else:
                        count_success = 0

                print_status(acc_total_in=acc_total,
                             size=len(val_batches),
                             v_loss_in=v_loss,
                             shorten=True)

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
        finally:
            # Always persist weights and a summary on the way out.
            save()
            save_summary()
            print('save weights/summary and exit.')
Ejemplo n.º 33
0
def abstract_model_xy(sess, hps, feeds, train_iterators, test_iterators, data_inits, lr, f_loss):
    """Build a two-domain (A/B) model object with train/test/save helpers.

    Returns a class ``m`` used as a plain namespace holding the session,
    feeds, per-domain optimizers, train/test closures, and saver helpers.
    ``f_loss`` builds the loss graph for both domains.
    """

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    if hps.joint_train:
        (loss_train_A, stats_train_A, eps_flatten_A, loss_train_B, stats_train_B, eps_flatten_B) \
            = f_loss(train_iterators, is_training=True)
    else:
        (loss_train_A, stats_train_A, loss_train_B, stats_train_B) \
            = f_loss(train_iterators, is_training=True)

    all_params = tf.trainable_variables()

    # Get train data op
    def get_train_data():
        x_A, y_A = train_iterators['A']()
        x_B, y_B = train_iterators['B']()
        return x_A, y_A, x_B, y_B
    m.get_train_data = get_train_data

    # A
    with tf.variable_scope('optim_A'):
        # Domain-A parameters are identified by an 'A/' scope prefix.
        params_A = [param for param in all_params if 'A/' in param.name]
        if hps.gradient_checkpointing == 1:
            from memory_saving_gradients import gradients
            gs_A = gradients(loss_train_A, params_A)
        else:
            gs_A = tf.gradients(loss_train_A, params_A)
        m.optimizer_A = optim.Optimizer()
        train_op_A, polyak_swap_op_A, ema_A = m.optimizer_A.adamax(
            params_A, gs_A, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_A = lambda _lr: sess.run(
                [train_op_A, stats_train_A], {lr: _lr})[1]
        else:
            def _train_A(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_A, stats_train_A], {feeds['x_A']: _x_A,
                                                              feeds['y_A']: _y_A,
                                                              feeds['x_B']: _x_B,
                                                              feeds['y_B']: _y_B,
                                                              lr: _lr})[1]
            m.train_A = _train_A
        m.polyak_swap_A = lambda: sess.run(polyak_swap_op_A)
    # B
    with tf.variable_scope('optim_B'):
        params_B = [param for param in all_params if 'B/' in param.name]
        if hps.gradient_checkpointing == 1:
            from memory_saving_gradients import gradients
            gs_B = gradients(loss_train_B, params_B)
        else:
            gs_B = tf.gradients(loss_train_B, params_B)
        m.optimizer_B = optim.Optimizer()
        train_op_B, polyak_swap_op_B, ema_B = m.optimizer_B.adamax(
            params_B, gs_B, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_B = lambda _lr: sess.run(
                [train_op_B, stats_train_B], {lr: _lr})[1]
        else:
            def _train_B(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_B, stats_train_B], {feeds['x_A']: _x_A,
                                                              feeds['y_A']: _y_A,
                                                              feeds['x_B']: _x_B,
                                                              feeds['y_B']: _y_B,
                                                              lr: _lr})[1]
            m.train_B = _train_B
        m.polyak_swap_B = lambda: sess.run(polyak_swap_op_B)

    # Train both domains in a single session.run call; returns their stats.
    def _train(_lr, _x_A, _y_A, _x_B, _y_B):
        return sess.run([train_op_A, train_op_B, stats_train_A, stats_train_B],
                        {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                         feeds['x_B']: _x_B, feeds['y_B']: _y_B,
                         lr: _lr})[-2:]
    m.train = _train

    # === Testing
    loss_test_A, stats_test_A, loss_test_B, stats_test_B = f_loss(
        test_iterators, False, reuse=True)
    if hps.direct_iterator:
        m.test_A = lambda: sess.run(stats_test_A)
        m.test_B = lambda: sess.run(stats_test_B)
    else:
        # Get test data op
        def get_test_data():
            x_A, y_A = test_iterators['A']()
            x_B, y_B = test_iterators['B']()
            return x_A, y_A, x_B, y_B
        m.get_test_data = get_test_data

        def _test_A(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_A, {feeds['x_A']: _x_A,
                                           feeds['y_A']: _y_A,
                                           feeds['x_B']: _x_B,
                                           feeds['y_B']: _y_B})

        def _test_B(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_B, {feeds['x_A']: _x_A,
                                           feeds['y_A']: _y_A,
                                           feeds['x_B']: _x_B,
                                           feeds['y_B']: _y_B})
        m.test_A = _test_A
        m.test_B = _test_B

    # === Saving and restoring
    with tf.variable_scope('saver_A'):
        saver_A = tf.train.Saver()
        # Separate saver for the EMA (polyak-averaged) weights.
        saver_ema_A = tf.train.Saver(ema_A.variables_to_restore())
        m.save_ema_A = lambda path_A: saver_ema_A.save(
            sess, path_A, write_meta_graph=False)
        m.save_A = lambda path_A: saver_A.save(
            sess, path_A, write_meta_graph=False)
        m.restore_A = lambda path_A: saver_A.restore(sess, path_A)

    with tf.variable_scope('saver_B'):
        saver_B = tf.train.Saver()
        saver_ema_B = tf.train.Saver(ema_B.variables_to_restore())
        m.save_ema_B = lambda path_B: saver_ema_B.save(
            sess, path_B, write_meta_graph=False)
        m.save_B = lambda path_B: saver_B.save(
            sess, path_B, write_meta_graph=False)
        m.restore_B = lambda path_B: saver_B.restore(sess, path_B)
        print("After saver")

    # === Initialize the parameters
    if hps.restore_path_A != '':
        m.restore_A(hps.restore_path_A)
    if hps.restore_path_B != '':
        m.restore_B(hps.restore_path_B)
    if hps.restore_path_A == '' and hps.restore_path_B == '':
        # Data-dependent initialization (actnorm etc.) run on one batch.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, False, reuse=True, init=True)

        all_params = tf.global_variables()
        params_A = [param for param in all_params if 'A/' in param.name]
        params_B = [param for param in all_params if 'B/' in param.name]
        sess.run(tf.variables_initializer(params_A))
        sess.run(tf.variables_initializer(params_B))
        feeds_dict = {feeds['x_A']: data_inits['A']['x'],
                      feeds['y_A']: data_inits['A']['y'],
                      feeds['x_B']: data_inits['B']['x'],
                      feeds['y_B']: data_inits['B']['y']}
        sess.run(results_init, feeds_dict)
    # Horovod: start all workers from identical weights.
    sess.run(hvd.broadcast_global_variables(0))

    return m
Ejemplo n.º 34
0
def main():
    """Fine-tune GPT-2 on a text dataset.

    Builds the training graph (optionally with memory-saving or
    twremat-rematerialized gradients), restores a checkpoint, then loops:
    one train step per iteration with periodic checkpoint saving, sample
    generation, and validation-loss evaluation. Runs until interrupted
    with Ctrl-C, at which point a final checkpoint is saved.
    """
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session() as sess:
        # Fully static shape required to make memory accounting in
        # twremat accurate.
        train_context = tf.placeholder(tf.int32, [args.batch_size, 1024])
        train_context_in = randomize(train_context, hparams, args.noise)
        train_output = model.model(hparams=hparams, X=train_context_in)
        # Shifted next-token cross-entropy: logits at position t are
        # scored against the token at position t + 1.
        train_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=train_context[:, 1:],
                logits=train_output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        sample_context = tf.placeholder(tf.int32, [args.batch_size, None])
        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=sample_context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        # '/h' matches only the transformer-block scopes (model/h0, ...),
        # excluding the embedding matrices.
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            print('Using Adam optimizer', file=sys.stderr)
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            print('Using SGD optimizer', file=sys.stderr)
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            # BUG FIX: exit() accepts a single argument; the previous
            # two-argument call raised TypeError instead of exiting
            # cleanly with the message.
            exit('Bad optimizer: %s' % args.optimizer)

        if args.memory_saving_gradients:
            if tf.VERSION >= '2':
                exit(
                    'Memory saving gradients are not supported in tensorflow 2.x'
                )
            import memory_saving_gradients
            opt_grads = memory_saving_gradients.gradients(
                train_loss, train_vars)
        elif args.twremat:
            import tfremat
            opt_grads = tf.gradients(train_loss, train_vars)
            # Rewrite the graph to recompute activations under the given
            # memory limit instead of keeping them all resident.
            (train_loss, opt_grads) = tfremat.tf_remat(
                (train_loss, opt_grads), memlimit=args.twremat_memlimit)
        else:
            opt_grads = tf.gradients(train_loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', train_loss)

        # if args.twremat:
        #     import tfremat
        #     # Applying tfremat to opt_apply has more accurate
        #     # accounting but is a bit iffier since side effecting ops
        #     # have more restrictions for correctness. If in doubt
        #     # revert back to version using opt_grads above.
        #     (opt_apply, train_loss, summary_loss) = (
        #         tfremat.tf_remat((opt_apply, train_loss, summary_loss), memlimit=args.twremat_memlimit))

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                # No separate validation set: validate on training data.
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a model checkpoint and persist the step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Generate args.sample_num samples and write them to disk."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(tf_sample,
                               feed_dict={
                                   sample_context:
                                   args.batch_size * [context_tokens]
                               })
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # BUG FIX: guard the echo of the last sample so that
            # sample_num <= 0 no longer raises NameError on 'text'.
            if all_text:
                print(all_text[-1])
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average val_loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Re-feed the mean through the summary op so TensorBoard sees
            # the averaged value rather than a single batch's loss.
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            """Draw one training batch of 1024-token sequences."""
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # Exponential moving average of the loss, stored as
        # (weighted sum, weight total) so the average is their ratio.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        # print('Evaluating grads..')
        # tf2.profiler.experimental.start('logdir')
        # sess.run((opt_apply, train_loss, summaries), feed_dict={train_context: sample_batch()})
        # tf2.profiler.experimental.stop()
        # print('Succeeded')
        # exit()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, train_loss, summaries),
                    feed_dict={train_context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            # Ctrl-C is the intended way to stop training; save before exit.
            print('interrupted')
            save()
 def grads(ys, xs, grad_ys=None, **kwargs):
     """Drop-in replacement for tf.gradients.

     Delegates to memory_saving_gradients with checkpoints='speed',
     which trades recomputation for lower peak memory while picking
     checkpoint nodes automatically for speed.
     """
     return memory_saving_gradients.gradients(
         ys, xs, grad_ys, checkpoints='speed', **kwargs)
# Ejemplo n.º 36
# 0
def main():
    args = parser.parse_args()
    folder_id = get_id(args.gdir)
    #xmpp = SendMsgBot(jid, password, to, "Starting GPT-2")
    #xmpp.register_plugin('xep_0030') # Service Discovery
    #xmpp.register_plugin('xep_0199') # XMPP Ping
    #xmpp.connect()
    #threading = Thread(target=xmpp.process, daemon=True).start()
    download_checkpoint(folder_id)
    #send_m('checkpoint downloaded')
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        # if args.optimizer == 'adam':
            # args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer:', args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(os.path.join(CHECKPOINT_DIR, args.run_name))
        saver = tf.train.Saver(var_list=all_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        #send_m('Loading  ' + str(ckpt))
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        #send_m('Loading dataset...')
        #chunks = load_dataset(enc, args.dataset, args.combine)
        ds_path = f'{CHECKPOINT_DIR}//run1//{args.dataset}'
        chunks = load_dataset(enc, ds_path, args.combine)
        data_sampler = Sampler(chunks)
        print(f'{ds_path} has', data_sampler.total_size, 'tokens')
        if args.val_every > 0:
            val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks
        if args.enc:
            print(colored(f'Trying writing Data.npz encoded from this dataset to {args.enc}', 'red'))
            np.savez_compressed(args.enc, *chunks)
            upload_npz(args.enc, folder_id)
        #send_m(f'{args.dataset} has ' + str(data_sampler.total_size) + ' tokens' + '     Start training...')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            save_gdisk(counter, folder_id)

        def generate_samples():
            print('Generating samples...')
            #send_m('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            #send_m(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                    .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()
        last_time = time.time()
        cur_counter, min_loss = 1, 2.0
        print(colored(f'Model  >>> {args.gdir}\nLearning rate is {args.learning_rate}', 'blue'))
        print(colored(f'model optimizer >>> {args.optimizer}\nRestricted to train only transformer layer={args.only_train_transformer_layers}', 'blue'))
        #send_m(f'Model  >>> {args.model_name}\nLearning rate is {args.learning_rate}')
        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                    if check_quota():
                        return() # exit train
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run((opt_apply, loss, summaries), feed_dict={context: sample_batch()})
                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)
                a_loss = avg_loss[0] / avg_loss[1]
                time_all = int((time.time() - start_time) / 60)
                time_iter = time.time() - last_time
                stats = f'[{counter} | {cur_counter} | {time_all}m | {time_iter:2.2f}s] loss={v_loss:2.2f} avg={a_loss:2.2f}'
                if not(cur_counter % 50):
                    print(colored(stats, 'red' if a_loss > min_loss else 'yellow'))
                    if a_loss < min_loss:
                        min_loss = a_loss
                    #send_m(stats)
                last_time = time.time()
                counter += 1
                cur_counter += 1
        except Exception as e:
            #send_m('Stoped  ' + str(e.__class__))
            print('Stoped', e.__class__)