Example #1
def _main(_):
    # Data
    batch_size = config.batch_size
    memory_size = config.memory_size
    terminating_learning_rate = config.terminating_learning_rate
    data = prepare_data(FLAGS.data_path)
    vocab_size = data["vocab_size"]
    print('vocab_size = {}'.format(vocab_size))

    inputs = tf.placeholder(tf.int32, [None, memory_size], name="inputs")
    targets = tf.placeholder(tf.int32, [None], name="targets")

    # Model architecture
    initializer = tf.random_normal_initializer(stddev=config.initialize_stddev)
    with tf.variable_scope("model", initializer=initializer):
        memnet = tx.modules.MemNetRNNLike(raw_memory_dim=vocab_size,
                                          hparams=config.memnet)
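        # The language-model setting has no separate question input, so every
        # example uses the same constant-valued query vector of size
        # config.dim, filled with config.query_constant.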
        queries = tf.fill([tf.shape(inputs)[0], config.dim],
                          config.query_constant)
        logits = memnet(inputs, queries)

    # Losses & train ops
    mle_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                              logits=logits)
    mle_loss = tf.reduce_sum(mle_loss)

    # Use global_step to store the epoch count (so a restored checkpoint
    # resumes at the right epoch); the learning rate is fed through a
    # placeholder and decayed manually below.
    lr = config.opt["optimizer"]["kwargs"]["learning_rate"]
    learning_rate = tf.placeholder(tf.float32, [], name="learning_rate")
    global_step = tf.Variable(0, dtype=tf.int32, name="global_step")
    increment_global_step = tf.assign_add(global_step, 1)
    train_op = tx.core.get_train_op(mle_loss,
                                    learning_rate=learning_rate,
                                    global_step=global_step,
                                    increment_global_step=False,
                                    hparams=config.opt)

    def _run_epoch(sess, data_iter, epoch, is_train=False):
        loss = 0.
        iters = 0

        fetches = {"mle_loss": mle_loss}
        if is_train:
            fetches["train_op"] = train_op

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train else tf.estimator.ModeKeys.EVAL)

        for x, y in data_iter:
            batch_size = x.shape[0]
            feed_dict = {
                inputs: x,
                targets: y,
                learning_rate: lr,
                tx.global_mode(): mode,
            }

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            iters += batch_size

        ppl = np.exp(loss / iters)
        return ppl

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        try:
            saver.restore(sess, "ckpt/model.ckpt")
            print('restored checkpoint.')
        except Exception:
            print('restore checkpoint failed.')

        last_valid_ppl = None
        heuristic_lr_decay = (hasattr(config, 'heuristic_lr_decay')
                              and config.heuristic_lr_decay)
        while True:
            if lr < terminating_learning_rate:
                break

            epoch = sess.run(global_step)
            if epoch >= config.num_epochs:
                print('Too many epochs!')
                break

            print('epoch: {} learning_rate: {:.6f}'.format(epoch, lr))

            # Train
            train_data_iter = ptb_iterator(data["train_text_id"], batch_size,
                                           memory_size)
            train_ppl = _run_epoch(sess, train_data_iter, epoch, is_train=True)
            print("Train Perplexity: {:.3f}".format(train_ppl))
            sess.run(increment_global_step)

            # checkpoint
            if epoch % 5 == 0:
                try:
                    saver.save(sess, "ckpt/model.ckpt")
                    print("saved checkpoint.")
                except Exception:
                    print("save checkpoint failed.")

            # Valid
            valid_data_iter = ptb_iterator(data["valid_text_id"], batch_size,
                                           memory_size)
            valid_ppl = _run_epoch(sess, valid_data_iter, epoch)
            print("Valid Perplexity: {:.3f}".format(valid_ppl))

            # Learning rate decay
            if last_valid_ppl:
                if heuristic_lr_decay:
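                    # Shrink lr in proportion to how far the validation
                    # perplexity overshoots the smoothed estimate, and track
                    # last_valid_ppl as an exponential moving average.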
                    if valid_ppl > last_valid_ppl * config.heuristic_threshold:
                        lr /= 1. + (valid_ppl / last_valid_ppl \
                                    - config.heuristic_threshold) \
                                   * config.heuristic_rate
                    last_valid_ppl = last_valid_ppl \
                                     * (1 - config.heuristic_smooth_rate) \
                                     + valid_ppl * config.heuristic_smooth_rate
                else:
                    if valid_ppl > last_valid_ppl:
                        lr /= config.learning_rate_anneal_factor
                    last_valid_ppl = valid_ppl
            else:
                last_valid_ppl = valid_ppl
            print("last_valid_ppl: {:.6f}".format(last_valid_ppl))

        epoch = sess.run(global_step)
        print('Terminated after epoch {}'.format(epoch))

        # Test
        test_data_iter = ptb_iterator(data["test_text_id"], 1, memory_size)
        test_ppl = _run_epoch(sess, test_data_iter, 0)
        print("Test Perplexity: {:.3f}".format(test_ppl))
Example #2
def _main(_):
    # Data
    batch_size = config.batch_size
    num_steps = config.num_steps
    data = prepare_data(FLAGS.data_path)
    vocab_size = data["vocab_size"]

    inputs = tf.placeholder(tf.int32, [batch_size, num_steps])
    targets = tf.placeholder(tf.int32, [batch_size, num_steps])

    # Model architecture
    initializer = tf.random_uniform_initializer(-config.init_scale,
                                                config.init_scale)
    with tf.variable_scope("model", initializer=initializer):
        embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                           hparams=config.emb)
        emb_inputs = embedder(inputs)
        if config.keep_prob < 1:
            emb_inputs = tf.nn.dropout(
                emb_inputs, tx.utils.switch_dropout(config.keep_prob))

        decoder = tx.modules.BasicRNNDecoder(vocab_size=vocab_size,
                                             hparams={"rnn_cell": config.cell})
        initial_state = decoder.zero_state(batch_size, tf.float32)
        outputs, final_state, seq_lengths = decoder(
            decoding_strategy="train_greedy",
            impute_finished=True,
            inputs=emb_inputs,
            sequence_length=[num_steps] * batch_size,
            initial_state=initial_state)

    # Losses & train ops
    mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=targets, logits=outputs.logits, sequence_length=seq_lengths)

    # Use global_step to pass the epoch number to the optimizer, for lr decay
    global_step = tf.placeholder(tf.int32)
    train_op = tx.core.get_train_op(mle_loss,
                                    global_step=global_step,
                                    increment_global_step=False,
                                    hparams=config.opt)

    def _run_epoch(sess, data_iter, epoch, is_train=False, verbose=False):
        start_time = time.time()
        loss = 0.
        iters = 0
        state = sess.run(initial_state)

        fetches = {
            "mle_loss": mle_loss,
            "final_state": final_state,
        }
        if is_train:
            fetches["train_op"] = train_op
            epoch_size = (len(data["train_text_id"]) // batch_size - 1)\
                // num_steps

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train else tf.estimator.ModeKeys.EVAL)

        for step, (x, y) in enumerate(data_iter):
            feed_dict = {
                inputs: x,
                targets: y,
                global_step: epoch,
                tx.global_mode(): mode,
            }
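            # Carry the previous batch's final LSTM state over as this
            # batch's initial state (truncated BPTT across the epoch).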
            for i, (c, h) in enumerate(initial_state):
                feed_dict[c] = state[i].c
                feed_dict[h] = state[i].h

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            state = rets["final_state"]
            iters += num_steps

            ppl = np.exp(loss / iters)
            if verbose and is_train and step % (epoch_size // 10) == 10:
                print("%.3f perplexity: %.3f speed: %.0f wps" %
                      ((step + 1) * 1.0 / epoch_size, ppl, iters * batch_size /
                       (time.time() - start_time)))

        ppl = np.exp(loss / iters)
        return ppl

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        for epoch in range(config.num_epochs):
            # Train
            train_data_iter = ptb_iterator(data["train_text_id"],
                                           config.batch_size, num_steps)
            train_ppl = _run_epoch(sess,
                                   train_data_iter,
                                   epoch,
                                   is_train=True,
                                   verbose=True)
            print("Epoch: %d Train Perplexity: %.3f" % (epoch, train_ppl))
            # Valid
            valid_data_iter = ptb_iterator(data["valid_text_id"],
                                           config.batch_size, num_steps)
            valid_ppl = _run_epoch(sess, valid_data_iter, epoch)
            print("Epoch: %d Valid Perplexity: %.3f" % (epoch, valid_ppl))
        # Test
        test_data_iter = ptb_iterator(data["test_text_id"], batch_size,
                                      num_steps)
        test_ppl = _run_epoch(sess, test_data_iter, 0)
        print("Test Perplexity: %.3f" % (test_ppl))
Example #3
def _main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    # 1. initialize Horovod
    hvd.init()

    # Data
    batch_size = config.batch_size
    num_steps = config.num_steps
    data = prepare_data(FLAGS.data_path)
    vocab_size = data["vocab_size"]

    inputs = tf.placeholder(tf.int32, [None, num_steps], name='inputs')
    targets = tf.placeholder(tf.int32, [None, num_steps], name='targets')

    # Model architecture
    initializer = tf.random_uniform_initializer(-config.init_scale,
                                                config.init_scale)
    with tf.variable_scope("model", initializer=initializer):
        embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                           hparams=config.emb)
        emb_inputs = embedder(inputs)
        if config.keep_prob < 1:
            emb_inputs = tf.nn.dropout(
                emb_inputs, tx.utils.switch_dropout(config.keep_prob))

        decoder = tx.modules.BasicRNNDecoder(vocab_size=vocab_size,
                                             hparams={"rnn_cell": config.cell})

        # In distributed training, this _batch_size equals
        # batch_size // hvd.size(), because each mini-batch is split
        # across multiple GPUs.

        _batch_size = tf.shape(inputs)[0]
        initial_state = decoder.zero_state(_batch_size, tf.float32)
        seq_length = tf.broadcast_to([num_steps], (_batch_size, ))
        outputs, final_state, seq_lengths = decoder(
            decoding_strategy="train_greedy",
            impute_finished=True,
            inputs=emb_inputs,
            sequence_length=seq_length,
            initial_state=initial_state)
    # Losses & train ops
    mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=targets, logits=outputs.logits, sequence_length=seq_lengths)

    # Use global_step to pass the epoch number to the optimizer, for lr decay
    global_step = tf.placeholder(tf.int32)

    opt = tx.core.get_optimizer(global_step=global_step, hparams=config.opt)

    # 2. wrap the optimizer
    opt = hvd.DistributedOptimizer(opt)

    train_op = tx.core.get_train_op(loss=mle_loss,
                                    optimizer=opt,
                                    global_step=global_step,
                                    learning_rate=None,
                                    increment_global_step=False,
                                    hparams=config.opt)

    def _run_epoch(sess, data_iter, epoch, is_train=False, verbose=False):
        start_time = time.time()
        loss = 0.
        iters = 0

        fetches = {
            "mle_loss": mle_loss,
            "final_state": final_state,
        }
        if is_train:
            fetches["train_op"] = train_op
            epoch_size = (len(data["train_text_id"]) // batch_size - 1)\
                // num_steps

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train else tf.estimator.ModeKeys.EVAL)

        for step, (x, y) in enumerate(data_iter):
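            # Materialize the zero state on the first batch only, since its
            # batch dimension depends on the actual inputs fed in.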
            if step == 0:
                state = sess.run(initial_state, feed_dict={inputs: x})

            feed_dict = {
                inputs: x,
                targets: y,
                global_step: epoch,
                tx.global_mode(): mode,
            }
            for i, (c, h) in enumerate(initial_state):
                feed_dict[c] = state[i].c
                feed_dict[h] = state[i].h

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            state = rets["final_state"]
            iters += num_steps

            ppl = np.exp(loss / iters)
            if (verbose and is_train and hvd.rank() == 0
                    and (step + 1) % (epoch_size // 10) == 0):
                tf.logging.info(
                    "%.3f perplexity: %.3f speed: %.0f wps" %
                    ((step + 1) * 1.0 / epoch_size, ppl, iters * batch_size /
                     (time.time() - start_time)))
        _elapsed_time = time.time() - start_time
        tf.logging.info("epoch time elapsed: %f" % (_elapsed_time))
        ppl = np.exp(loss / iters)
        return ppl, _elapsed_time

    # 3. broadcast global variables from the rank-0 process
    bcast = hvd.broadcast_global_variables(0)

    # 4. pin each process to its own visible GPU
    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # 5. run the broadcast_global_variables node before training
        bcast.run()

        _times = []
        for epoch in range(config.num_epochs):
            # Train
            train_data_iter = ptb_iterator(data["train_text_id"],
                                           config.batch_size,
                                           num_steps,
                                           is_train=True)
            train_ppl, train_time = _run_epoch(sess,
                                               train_data_iter,
                                               epoch,
                                               is_train=True,
                                               verbose=True)
            _times.append(train_time)
            tf.logging.info("Epoch: %d Train Perplexity: %.3f" %
                            (epoch, train_ppl))
            # Valid in the main process
            if hvd.rank() == 0:
                valid_data_iter = ptb_iterator(data["valid_text_id"],
                                               config.batch_size, num_steps)
                valid_ppl, _ = _run_epoch(sess, valid_data_iter, epoch)
                tf.logging.info("Epoch: %d Valid Perplexity: %.3f" %
                                (epoch, valid_ppl))

        tf.logging.info('train times: %s' % (_times))
        tf.logging.info('average train time/epoch %f' %
                        np.mean(np.array(_times)))
        # Test in the main process
        if hvd.rank() == 0:
            test_data_iter = ptb_iterator(data["test_text_id"], batch_size,
                                          num_steps)
            test_ppl, _ = _run_epoch(sess, test_data_iter, 0)
            tf.logging.info("Test Perplexity: %.3f" % (test_ppl))