コード例 #1
0
def train(args):
    """Train a character-level RNN language model.

    Loads batches via TextLoader, optionally resumes from a checkpoint in
    ``args.init_from`` (verifying that the saved configuration and vocabulary
    match the current run), then trains for ``args.num_epochs`` epochs with a
    per-epoch exponentially decayed learning rate, writing TensorBoard
    summaries and periodic checkpoints under ``args.save_dir``.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        # ('rb': pickle data is binary -- text mode breaks on Python 3)
        with file_io.FileIO(os.path.join(args.init_from, 'config.pkl'),
                            'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with file_io.FileIO(os.path.join(args.init_from, 'chars_vocab.pkl'),
                            'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    # persist the run configuration and vocabulary next to the checkpoints
    # ('wb': pickle emits bytes)
    with file_io.FileIO(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with file_io.FileIO(os.path.join(args.save_dir, 'chars_vocab.pkl'),
                        'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:

        merged_summaries = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(
            os.path.join(args.save_dir, 'log'), sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model weights when resuming
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # decay the learning rate once per epoch
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.state_in)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                # carry the RNN state over between consecutive batches
                feed = {
                    model.data_in: x,
                    model.targets: y,
                    model.state_in: state
                }
                summary, train_loss, state, _ = sess.run([
                    merged_summaries, model.cost, model.state_out,
                    model.train_op
                ], feed)
                summary_writer.add_summary(summary,
                                           e * data_loader.num_batches + b)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                # checkpoint periodically and always on the very last batch
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                    or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
コード例 #2
0
def train(args):
    """Train a character-level RNN with quantized weights.

    Resumes from ``args.init_from`` when given (validating that the saved
    configuration and vocabulary match the current run), trains for
    ``args.num_epochs`` epochs with a per-epoch exponentially decayed
    learning rate, logs TensorBoard summaries, checkpoints periodically,
    and appends the per-epoch mean loss and total run time to a
    per-bit-width result file.
    """
    start_time = datetime.now()

    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length,
                             args.test_flag)
    args.vocab_size = data_loader.vocab_size
    # keep runs with different weight bit-widths apart on disk
    args.save_dir += '_bit_{}'.format(args.w_bit)
    result_file_path = 'result/bit_{}_{}.txt'.format(args.w_bit,
                                                     args.test_flag)

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "chars_vocab.pkl")
        ), "chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        # instrument for tensorboard
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
            os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model weights when resuming
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # decay the learning rate once per epoch
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            loss_list = []
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                # carry the (c, h) LSTM state over between consecutive batches
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h

                # single session run computes summaries, loss, next state,
                # and applies the gradient update
                summ, train_loss, state, _ = sess.run(
                    [summaries, model.cost, model.final_state, model.train_op],
                    feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)

                end = time.time()
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches, e,
                            train_loss, end - start))
                loss_list.append(train_loss)
                # checkpoint periodically and always on the very last batch
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                        or (e == args.num_epochs-1 and
                            b == data_loader.num_batches-1):
                    # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

            print("mean_loss for this epoch:{:.3f}".format(
                sum(loss_list) / float(len(loss_list))))
            with open(result_file_path, 'a') as f:
                print("mean_loss for this epoch:{:.3f}".format(
                    sum(loss_list) / float(len(loss_list))),
                      file=f)

    print("Run time: {}".format(datetime.now() - start_time))
    with open(result_file_path, 'a') as f:
        print("Run time: {}".format(datetime.now() - start_time), file=f)
コード例 #3
0
def main(_):
  """Train a character-level RNN, or export its embeddings.

  Builds three views of the same model sharing variables via variable_scope
  reuse: a training model, a batch-size-1 model used for text sampling, and
  a validation model. In training mode it runs FLAGS.num_epochs epochs,
  periodically evaluating validation perplexity, sampling text, logging
  nearest-neighbor character similarities, and saving whenever validation
  perplexity improves. With FLAGS.export set, it instead dumps the learned
  character embeddings to ``emb.npy`` and skips training.
  """
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size
  # how many characters to probe (valid_size) and from how large a pool
  # (valid_window) for the nearest-neighbor similarity log further below
  valid_size = 50
  valid_window = 100

  with tf.variable_scope('model'):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip, FLAGS.nce_samples)
    # NOTE(review): FLAGS.grad_clip is passed where exponential_decay expects
    # its decay_rate argument -- this looks like a copy/paste bug; confirm
    # against the flag definitions.
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, train_model.global_step,
                                               data_loader.num_batches, FLAGS.grad_clip,
                                               staircase=True)
  # batch-size-1, seq-length-1 model sharing the training weights,
  # used only for generating sample text
  with tf.variable_scope('model', reuse=True):
    simple_model = CharRNN(vocab_size, 1, FLAGS.rnn_size,
                           FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                           1, FLAGS.keep_prob,
                           FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      # export mode: dump the embedding matrix to disk and skip training
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else: # Train
      current_step = 0
      similarity, valid_examples, _ = compute_similarity(train_model, valid_size, valid_window, 6)

      # save hyper-parameters
      cPickle.dump(FLAGS.__flags, open(FLAGS.log_dir + "/hyperparams.pkl", 'wb'))

      # run it!
      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # decay learning rate
        sess.run(tf.assign(train_model.lr, learning_rate))

        # iterate by batch
        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_epochs(sess, x, y, train_model)
          train_loss = res["loss"][0]
          train_perplexity = np.exp(train_loss)
          iterate = e * data_loader.num_batches + b

          # every FLAGS.valid_every optimizer steps: evaluate, sample,
          # and log character similarities
          if current_step != 0 and current_step % FLAGS.valid_every == 0:
            valid_loss = 0

            for vb in range(data_loader.num_valid_batches):
              res, valid_time_batch = run_epochs(sess, data_loader.x_valid[vb], data_loader.y_valid[vb], valid_model, False)
              valid_loss += res["loss"][0]

            valid_loss = valid_loss / data_loader.num_valid_batches
            valid_perplexity = np.exp(valid_loss)

            print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".format(valid_perplexity, valid_time_batch))

            log_str = ""

            # Generate sample continuations from fixed Chinese prompts
            smp1 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"我喜歡做")
            smp2 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"他吃飯時會用")
            smp3 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"人類總要重複同樣的")
            smp4 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"天色暗了,好像快要")

            log_str = log_str + smp1 + "\n"
            log_str = log_str + smp2 + "\n"
            log_str = log_str + smp3 + "\n"
            log_str = log_str + smp4 + "\n"

            # Write a similarity log
            # Note that this is expensive (~20% slowdown if computed every 500 steps)
            sim = similarity.eval()
            for i in range(valid_size):
              valid_word = data_loader.chars[valid_examples[i]]
              top_k = 8 # number of nearest neighbors
              nearest = (-sim[i, :]).argsort()[1:top_k+1]
              log_str = log_str + "Nearest to %s:" % valid_word
              for k in range(top_k):
                close_word = data_loader.chars[nearest[k]]
                log_str = "%s %s," % (log_str, close_word)
              log_str = log_str + "\n"
            print(log_str)
            # Write to log (overwritten each time, keeping only the latest)
            text_file = codecs.open(FLAGS.log_dir + "/similarity.txt", "w", "utf-8")
            text_file.write(log_str)
            text_file.close()

          # print log
          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"\
              .format(e * data_loader.num_batches + b,
                      FLAGS.num_epochs * data_loader.num_batches,
                      e, train_loss, valid_loss, train_perplexity, valid_perplexity,
                      time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

          current_step = tf.train.global_step(sess, train_model.global_step)

        # NOTE(review): `iterate` is bound inside the batch loop above, so
        # this raises NameError if data_loader.num_batches == 0 -- confirm
        # that can't happen.
        if valid_perplexity < best_val_pp:
          best_val_pp = valid_perplexity
          best_val_epoch = iterate

          # save best model
          train_model.save(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)
          print("model saved to {}".format(FLAGS.checkpoint_dir))

        # early_stopping
        if iterate - best_val_epoch > FLAGS.early_stopping:
          print('Total time: {}'.format(time.time() - start))
          break
コード例 #4
0
def train(args):
    """Train the RNN model, resuming seamlessly from a prior checkpoint.

    If ``args.save_dir`` already contains a trained model, its architecture
    arguments override the command line and training resumes from the saved
    global epoch fraction. Checkpoints are written every ``args.save_every``
    batches, on the final batch, and once more on exit (including Ctrl-C).
    """
    # Create the data_loader object, which loads up all of our batches, vocab dictionary, etc.
    # from utils.py (and creates them if they don't already exist).
    # These files go in the data directory.
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    load_model = False
    if not os.path.exists(args.save_dir):
        print("Creating directory %s" % args.save_dir)
        os.mkdir(args.save_dir)
    elif os.path.exists(os.path.join(args.save_dir, 'config.pkl')):
        # Trained model already exists.
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Pickle data is binary: open in 'rb' (the default text mode
            # breaks pickle.load on Python 3).
            with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
                saved_args = pickle.load(f)
                args.rnn_size = saved_args.rnn_size
                args.num_layers = saved_args.num_layers
                args.model = saved_args.model
                print("Found a previous checkpoint. Overwriting model description arguments to:")
                print(" model: {}, rnn_size: {}, num_layers: {}".format(
                    saved_args.model, saved_args.rnn_size, saved_args.num_layers))
                load_model = True

    # Save all arguments to config.pkl in the save directory -- NOT the data directory.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    # Save a tuple of the characters list and the vocab dictionary to chars_vocab.pkl in
    # the save directory -- NOT the data directory.
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)

    # Create the model!
    print("Building the model")
    model = Model(args)

    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(model.save_variables_list())
        if load_model:
            print("Loading saved parameters")
            saver.restore(sess, ckpt.model_checkpoint_path)
        global_epoch_fraction = sess.run(model.global_epoch_fraction)
        global_seconds_elapsed = sess.run(model.global_seconds_elapsed)
        if load_model:
            print("Resuming from global epoch fraction {:.3f}, total trained time: {}, learning rate: {}".format(
                global_epoch_fraction, global_seconds_elapsed, sess.run(model.lr)))
        data_loader.cue_batch_pointer_to_epoch_fraction(global_epoch_fraction)
        initial_batch_step = int((global_epoch_fraction - int(global_epoch_fraction)) * data_loader.total_batch_count)
        epoch_range = (int(global_epoch_fraction), args.num_epochs + int(global_epoch_fraction))
        writer = tf.summary.FileWriter(args.save_dir, graph=tf.get_default_graph())
        outputs = [model.cost, model.final_state, model.train_op, model.summary_op]
        is_lstm = args.model == 'lstm'
        global_step = epoch_range[0] * data_loader.total_batch_count + initial_batch_step
        # Pre-seed the loop variables so the `finally` clause below cannot
        # raise NameError if training is interrupted before the first batch.
        e = epoch_range[0]
        b = initial_batch_step
        try:
            for e in range(*epoch_range):
                # e iterates through the training epochs.
                # Reset the model state, so it does not carry over from the end of the previous epoch.
                state = sess.run(model.initial_state)
                batch_range = (initial_batch_step, data_loader.total_batch_count)
                initial_batch_step = 0
                for b in range(*batch_range):
                    global_step += 1
                    if global_step % args.decay_steps == 0:
                        # Set the model.lr element of the model to track
                        # the appropriately decayed learning rate.
                        current_learning_rate = sess.run(model.lr)
                        current_learning_rate *= args.decay_rate
                        sess.run(tf.assign(model.lr, current_learning_rate))
                        print("Decayed learning rate to {}".format(current_learning_rate))
                    start = time.time()
                    # Pull the next batch inputs (x) and targets (y) from the data loader.
                    x, y = data_loader.next_batch()

                    # feed is a dictionary of variable references and respective values for initialization.
                    # Initialize the model's input data and target data from the batch,
                    # and initialize the model state to the final state from the previous batch, so that
                    # model state is accumulated and carried over between batches.
                    feed = {model.input_data: x, model.targets: y}
                    if is_lstm:
                        # LSTM state is a (c, h) pair per layer.
                        for i, (c, h) in enumerate(model.initial_state):
                            feed[c] = state[i].c
                            feed[h] = state[i].h
                    else:
                        # Non-LSTM cells carry a single tensor per layer.
                        for i, c in enumerate(model.initial_state):
                            feed[c] = state[i]
                    # Run the session! Specifically, tell TensorFlow to compute the graph to calculate
                    # the values of cost, final state, and the training op.
                    # Cost is used to monitor progress.
                    # Final state is used to carry over the state into the next batch.
                    # Training op is not used, but we want it to be calculated, since that calculation
                    # is what updates parameter states (i.e. that is where the training happens).
                    train_loss, state, _, summary = sess.run(outputs, feed)
                    elapsed = time.time() - start
                    global_seconds_elapsed += elapsed
                    writer.add_summary(summary, e * batch_range[1] + b + 1)
                    print("{}/{} (epoch {}/{}), loss = {:.3f}, time/batch = {:.3f}s".format(
                        b, batch_range[1], e, epoch_range[1], train_loss, elapsed))
                    # Every save_every batches, save the model to disk.
                    # By default, only the five most recent checkpoint files are kept.
                    if (e * batch_range[1] + b + 1) % args.save_every == 0 \
                            or (e == epoch_range[1] - 1 and b == batch_range[1] - 1):
                        save_model(sess, saver, model, args.save_dir, global_step,
                                data_loader.total_batch_count, global_seconds_elapsed)
        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            # Always flush summaries and save a final checkpoint, even on
            # interrupt, so no trained progress is lost.
            writer.flush()
            global_step = e * data_loader.total_batch_count + b
            save_model(sess, saver, model, args.save_dir, global_step, data_loader.total_batch_count,
                       global_seconds_elapsed)
コード例 #5
0
def main():
    """Parse arguments, optionally resume from a checkpoint, and train.

    ``args.init_from`` may be either a checkpoint directory (the latest
    checkpoint inside it is used) or a specific checkpoint file. The saved
    config/vocab pickles are validated against the data before training
    for ``args.num_epochs`` epochs, printing progress every
    ``args.display_every`` batches and saving every ``args.save_every``.
    """
    args = parse_args()
    loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = loader.vocab_size
    print("vocab_size = {}".format(args.vocab_size))

    if args.init_from is not None:
        if os.path.isdir(args.init_from):  # init from directory
            assert os.path.exists(args.init_from), \
            "{} is not a directory".format(args.init_from)
            parent_dir = args.init_from
        else:  # init from file
            # a TF checkpoint is a file group; the .index file marks it
            assert os.path.exists("{}.index".format(args.init_from)), \
            "{} is not a checkpoint".format(args.init_from)
            parent_dir = os.path.dirname(args.init_from)

        config_file = os.path.join(parent_dir, 'config.pkl')
        vocab_file = os.path.join(parent_dir, 'vocab.pkl')

        assert os.path.isfile(config_file), \
        "config.pkl does not exist in directory {}".format(parent_dir)
        assert os.path.isfile(vocab_file), \
        "vocab.pkl does not exist in directory {}".format(parent_dir)

        if os.path.isdir(args.init_from):
            checkpoint = tf.train.latest_checkpoint(parent_dir)
            # fixed NameError: the message referenced an undefined bare
            # `init_from` instead of args.init_from
            assert checkpoint, \
            "no checkpoint in directory {}".format(args.init_from)
        else:
            checkpoint = args.init_from

        # reuse the paths computed above instead of rebuilding them.
        # NOTE(review): saved_args is loaded but never used beyond proving
        # the pickle is readable -- confirm whether a compatibility check
        # was intended here.
        with open(config_file, 'rb') as f:
            saved_args = pickle.load(f)
        with open(vocab_file, 'rb') as f:
            saved_vocab = pickle.load(f)
        assert saved_vocab == loader.vocab, \
        "vocab in data directory differs from save"

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    new_config_file = os.path.join(args.save_dir, 'config.pkl')
    new_vocab_file = os.path.join(args.save_dir, 'vocab.pkl')

    # never clobber an existing saved config/vocab
    if not os.path.exists(new_config_file):
        with open(new_config_file, 'wb') as f:
            pickle.dump(args, f)
    if not os.path.exists(new_vocab_file):
        with open(new_vocab_file, 'wb') as f:
            pickle.dump(loader.vocab, f)

    model = Model(args)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())

        if args.init_from is not None:
            try:
                saver.restore(sess, checkpoint)
            except ValueError:
                print("{} is not a valid checkpoint".format(checkpoint))
            print("initializing from {}".format(checkpoint))

        for e in range(args.num_epochs):
            loader.reset_batch_pointer()
            for b in range(loader.num_batches):
                start = time.time()
                x, y, length = loader.next_batch()
                feed = {
                    model.input_data: x,
                    model.targets: y,
                    model.sequence_lengths: length
                }
                train_loss, _ = sess.run([model.cost, model.optimizer], feed)
                end = time.time()
                global_step = e * loader.num_batches + b
                if global_step % args.display_every == 0 and global_step != 0:
                    print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(b, loader.num_batches, e, train_loss, end - start))
                if global_step % args.save_every == 0 and global_step != 0:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=global_step)
                    print("model saved to {}".format(checkpoint_path))
コード例 #6
0
def train(args):
    """Time the training step of the RNN model over the whole dataset.

    Performs the standard resume-compatibility checks, then runs the train
    op for every batch of every epoch WITHOUT feeding new data (the feed
    code is disabled below), timing each step and printing the running and
    final accumulated totals.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "chars_vocab.pkl")
        ), "chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        # ('rb': pickle data is binary -- text mode fails on Python 3)
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)
    totaltime = 0
    with tf.Session() as sess:
        # non-deprecated initializer/variable APIs, consistent with the
        # other training functions in this file (aliases of the old
        # initialize_all_variables / all_variables)
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        # restore model weights when resuming
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # decay the learning rate once per epoch
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in range(data_loader.num_batches):
                # NOTE(review): no feed dict is supplied, so every step runs
                # on whatever inputs the graph provides internally -- this
                # appears to be a pure timing benchmark; confirm.
                start = time.time()
                state, _ = sess.run([model.final_state, model.train_op])
                end = time.time()
                totaltime += (end - start)
                print('Total time is ', totaltime)
        print('Total time for whole dataset is', totaltime)
コード例 #7
0
def train(args):
    """Train a character-level RNN with a three-component recurrent state.

    Resumes from ``args.init_from`` when given (validating that the saved
    configuration and vocabulary match the current run), trains for
    ``args.num_epochs`` epochs with a per-epoch exponentially decayed
    learning rate, logs TensorBoard summaries, and checkpoints every
    ``args.save_every`` batches and on the final batch.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "chars_vocab.pkl")
        ), "chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        # instrument for tensorboard
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
            os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model weights when resuming
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # decay the learning rate once per epoch
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                # carry the three-component recurrent state (c, h, z) over
                # between consecutive batches
                for i, (c, h, z) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h
                    feed[z] = state[i].z

                # BUG FIX: the original ran sess.run twice per batch (once
                # without summaries, once with), executing model.train_op --
                # and thus a gradient update -- twice on the same feed.
                # A single run now computes summaries, loss, next state,
                # and the training step.
                summ, train_loss, state, _ = sess.run(
                    [summaries, model.cost, model.final_state, model.train_op],
                    feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)

                end = time.time()
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches, e,
                            train_loss, end - start))
                # checkpoint periodically and always on the very last batch
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                        or (e == args.num_epochs-1 and
                            b == data_loader.num_batches-1):
                    # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
コード例 #8
0
def train(args):
    """Train the character-level RNN and evaluate it on the eval partition.

    Writes ``config.pkl`` / ``chars_vocab.pkl`` and checkpoints to
    ``args.save_dir``; writes TensorBoard summaries and plain-text loss logs
    (``loss.txt``, ``eval_loss.txt``) under ``args.log_dir``.

    Fixes vs. the previous revision:
      * the final log-writing block used broken tab/space indentation
        (a syntax error under Python 3); it now runs once per epoch;
      * the ``eval_loss`` TensorBoard summary was logged at step
        ``e * num_batches`` (negative for the pre-training eval, where
        ``e == -1``); it now uses the true global step
        ``(e + 1) * num_batches``, consistent with ``eval_loss_list``;
      * the duplicated evaluation loop is factored into helpers.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, partition='train')
    eval_data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, partition='eval')
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from)," %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    # model = Model(args, opt_method="Adam")
    model = Model(args, opt_method="SGD")

    def _evaluate(sess):
        """One full pass over the eval partition; returns the mean eval loss."""
        eval_data_loader.reset_batch_pointer()
        state = sess.run(model.initial_state)
        total_loss = 0.0
        for _ in range(eval_data_loader.num_batches):
            x, y = eval_data_loader.next_batch()
            feed = {model.input_data: x, model.targets: y}
            # carry the recurrent (c, h) state across batches
            for i, (c, h) in enumerate(model.initial_state):
                feed[c] = state[i].c
                feed[h] = state[i].h
            batch_loss, state = sess.run([model.eval_cost, model.final_state], feed)
            total_loss += batch_loss
        return total_loss / eval_data_loader.num_batches

    loss_list = []
    eval_loss_list = []
    gpu_mem_portion = 0.005
    n_core = 16
    session_config = tf.ConfigProto(
        intra_op_parallelism_threads=n_core,
        inter_op_parallelism_threads=n_core,
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_portion))
    with tf.Session(config=session_config) as sess, tf.device("cpu:0"):
        # instrument for tensorboard
        summaries = tf.summary.merge(model.train_summary)
        writer = tf.summary.FileWriter(
                os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)

        def _run_eval(epoch):
            """Evaluate, log to stdout/TensorBoard and record the loss.

            ``epoch == -1`` is the pre-training baseline evaluation; it is
            reported as epoch 0 at global step 0.
            """
            start = time.time()
            eval_loss = _evaluate(sess)
            # fixed: use the true global step (non-negative, matches
            # eval_loss_list) instead of e * num_batches
            global_step = (epoch + 1) * data_loader.num_batches
            summ = tf.Summary(value=[tf.Summary.Value(tag="eval_loss", simple_value=eval_loss), ])
            writer.add_summary(summ, global_step)
            eval_loss_list.append([global_step, eval_loss])
            print("{}/{} (epoch {}), eval_loss = {:.3f}, time/batch = {:.3f}"
                  .format(global_step,
                          args.num_epochs * data_loader.num_batches,
                          max(epoch, 0), eval_loss, time.time() - start))

        # baseline evaluation before any training step
        _run_eval(-1)

        for e in range(args.num_epochs):
            # exponential learning-rate decay per epoch
            sess.run(tf.assign(model.lr,
                               args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                # carry the recurrent (c, h) state across batches
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h

                # instrument for tensorboard
                summ, train_loss, state, _ = sess.run([summaries, model.cost, model.final_state, model.train_op], feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)

                loss_list.append(train_loss)

                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(e * data_loader.num_batches + b,
                              args.num_epochs * data_loader.num_batches,
                              e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                        or (e == args.num_epochs-1 and
                            b == data_loader.num_batches-1):
                    # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

            # do evaluation
            print("start evaluation")
            _run_eval(e)

            # fixed: these two writes had mixed tab/space indentation in the
            # original; they now overwrite the loss logs once per epoch with
            # everything accumulated so far.
            with open(args.log_dir + "/loss.txt", "w") as f:
                np.savetxt(f, np.array(loss_list))
            with open(args.log_dir + "/eval_loss.txt", "w") as f:
                np.savetxt(f, np.array(eval_loss_list))
コード例 #9
0
def train(args):
    """Train the RNN and keep only the best-validation-loss checkpoints.

    Python 2 code (uses ``xrange`` and TF <1.0 APIs such as
    ``tf.initialize_all_variables`` / ``tf.merge_all_summaries``).

    Every ``args.save_every`` steps the model is evaluated on the validation
    partition; a checkpoint is written only when it ranks among the best
    validation losses recorded so far in ``<save_dir>/val_loss.json``, and the
    worst-ranked old checkpoint is deleted to bound disk use.
    """
    print(args)
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, args.training_data_ratio)
    args.vocab_size = data_loader.vocab_size

    # Persist the run configuration and vocabulary next to the checkpoints
    # so sampling can rebuild an identical model later.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    #sess = tf.InteractiveSession()
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()
        summary_writer = tf.train.SummaryWriter('/tmp', sess.graph)

        step = 0
        for e in range(args.num_epochs):
            # Exponential learning-rate decay per epoch.
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            #print("model learning rate is {}".format(model.lr.eval()))
            data_loader.reset_batch_pointer('train')

            state = model.initial_state.eval()
            for b in xrange(data_loader.ntrain):
                start = time.time()
                x, y = data_loader.next_batch('train')

                # Feed the previous batch's final state back in so the RNN
                # state is carried across the whole epoch.
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                step = e * data_loader.ntrain + b
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(step,
                            args.num_epochs * data_loader.ntrain,
                            e, train_loss, end - start))

                if step % args.write_summary_every == 0:
                    # training loss
                    summary_str = sess.run(summary_op, feed_dict=feed)
                    summary_writer.add_summary(summary_str, step)

                if step % args.save_every == 0 or (step + 1) == (args.num_epochs * data_loader.ntrain):
                    # eval validation loss: full pass over the validation
                    # partition, averaging the per-batch cost.
                    data_loader.reset_batch_pointer('validation')
                    validation_state = model.initial_state.eval()
                    val_losses = 0
                    for n in xrange(data_loader.nvalidation):
                        x, y = data_loader.next_batch('validation')
                        val_feed = {model.input_data: x, model.targets: y, model.initial_state: validation_state}
                        validation_loss, validation_state = sess.run([model.cost, model.final_state], val_feed)
                        val_losses += validation_loss

                    validation_loss = val_losses / data_loader.nvalidation
                    print("validation loss is {}".format(validation_loss))

                    # write top 5 validation loss to a json file
                    # NOTE(review): the keys of loss_json are validation-loss
                    # values (floats on insert, strings after a JSON
                    # round-trip); entries map loss -> run metadata including
                    # the checkpoint path to prune.
                    args_dict = vars(args)
                    args_dict['step'] = step
                    val_loss_file = args.save_dir + '/val_loss.json'
                    loss_json = ''
                    save_new_checkpoint = False
                    # Timestamp doubles as the checkpoint's global_step so the
                    # saved filename can be reconstructed for later deletion.
                    time_int = int(time.time())
                    args_dict['checkpoint_path'] = os.path.join(args.save_dir, 'model.ckpt-'+str(time_int))
                    if os.path.exists(val_loss_file):
                        with open(val_loss_file, "r") as text_file:
                            text = text_file.read()
                            if text == '':
                                loss_json = {validation_loss: args_dict}
                                save_new_checkpoint = True
                            else:
                                loss_json = json.loads(text)
                                losses = loss_json.keys()
                                if len(losses) > 3:
                                    # Ranking full: evict the worst (largest)
                                    # recorded loss only if we beat it.
                                    losses.sort(key=lambda x: float(x), reverse=True)
                                    loss = losses[0]
                                    if validation_loss < float(loss):
                                        to_be_remove_ckpt_file_path =  loss_json[loss]['checkpoint_path']
                                        to_be_remove_ckpt_meta_file_path = to_be_remove_ckpt_file_path + '.meta'
                                        print("removed checkpoint {}".format(to_be_remove_ckpt_file_path))
                                        if os.path.exists(to_be_remove_ckpt_file_path):
                                            os.remove(to_be_remove_ckpt_file_path)
                                        if os.path.exists(to_be_remove_ckpt_meta_file_path):
                                            os.remove(to_be_remove_ckpt_meta_file_path)
                                        del(loss_json[loss])
                                        loss_json[validation_loss] = args_dict
                                        save_new_checkpoint = True
                                else:
                                    # Ranking not full yet: always record.
                                    loss_json[validation_loss] = args_dict
                                    save_new_checkpoint = True
                    else:
                       loss_json = {validation_loss: args_dict}
                       save_new_checkpoint = True

                    if save_new_checkpoint:
                        checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step = time_int)
                        print("model saved to {}".format(checkpoint_path + '-' + str(time_int)))

                        # Persist the updated ranking only after the
                        # checkpoint itself was written successfully.
                        with open(val_loss_file, "w") as text_file:
                            json.dump(loss_json, text_file)
コード例 #10
0
 def setUp(self):
     """Build a TextLoader over the bundled test corpus before each test."""
     self.data_loader = TextLoader(
         "tests/test_data", batch_size=2, seq_length=5)
コード例 #11
0
def main():
    """Wire up the default configuration, data loader and model, then train."""
    loader = TextLoader()
    params = Param()
    rnn_model = Model(params)
    train(loader, rnn_model, params)
コード例 #12
0
def train(args):
    """Train the word-RNN while dumping detailed profiling artifacts.

    Besides normal training this revision (the ``#fareed`` sections) writes:
      * ``./profs/rnn.dot``        — Graphviz dump of the TF graph;
      * ``./profs/tensors_sz_32.txt`` — per-operation output-tensor sizes in
        bytes (-1 when the shape/dtype is unknown);
      * ``./profs/mem.txt``        — tf.profiler memory report (batch 7 only);
    and collects step traces via ``profile(run_metadata, b)`` for every batch
    where ``b % 10 == 7``.  Both ``./profs`` directories must already exist.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length,
                             args.input_encoding)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "words_vocab.pkl")
        ), "words_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = cPickle.load(f)
        assert saved_words == data_loader.words, "Data and loaded model disagree on word set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    # Persist config + vocabulary so sampling can rebuild the model later.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.words, data_loader.vocab), f)

    model = Model(args)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(args.log_dir)
    # NOTE(review): gpu_options is built but not passed to the Session below
    # (see the "fareed" comment on the ConfigProto line).
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True)
                    ) as sess:  # fareed gpu_options=gpu_options)) as sess:
        train_writer.add_graph(sess.graph)
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())

        #fareed
        # Dump the graph as Graphviz dot for offline inspection.
        dot_rep = graph_to_dot(sess.graph)
        #s = Source(dot_rep, filename="test.gv", format="PNG")
        with open('./profs/rnn.dot', 'w') as fwr:
            fwr.write(str(dot_rep))

        # Full tracing so run_metadata carries per-op timing/memory stats.
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        # Map every graph operation to the byte size of its first output
        # (-1 when the shape or dtype cannot be determined statically).
        operations_tensors = {}
        operations_names = tf.get_default_graph().get_operations()
        count1 = 0  # ops with unknown size (no dtype size or no shape)
        count2 = 0  # ops with no output tensors at all

        for operation in operations_names:
            operation_name = operation.name
            operations_info = tf.get_default_graph().get_operation_by_name(
                operation_name).values()
            if len(operations_info) > 0:
                if not (operations_info[0].shape.ndims is None):
                    operation_shape = operations_info[0].shape.as_list()
                    operation_dtype_size = operations_info[0].dtype.size
                    if not (operation_dtype_size is None):
                        operation_no_of_elements = 1
                        for dim in operation_shape:
                            # None dims (e.g. batch) are skipped, so the size
                            # is a lower bound for partially-known shapes.
                            if not (dim is None):
                                operation_no_of_elements = operation_no_of_elements * dim
                        total_size = operation_no_of_elements * operation_dtype_size
                        operations_tensors[operation_name] = total_size
                    else:
                        count1 = count1 + 1
                else:
                    count1 = count1 + 1
                    operations_tensors[operation_name] = -1

                #   print('no shape_1: ' + operation_name)
                #  print('no shape_2: ' + str(operations_info))
                #  operation_namee = operation_name + ':0'
                # tensor = tf.get_default_graph().get_tensor_by_name(operation_namee)
                # print('no shape_3:' + str(tf.shape(tensor)))
                # print('no shape:' + str(tensor.get_shape()))

            else:
                # print('no info :' + operation_name)
                # operation_namee = operation.name + ':0'
                count2 = count2 + 1
                operations_tensors[operation_name] = -1

                # try:
                #   tensor = tf.get_default_graph().get_tensor_by_name(operation_namee)
                # print(tensor)
                # print(tf.shape(tensor))
                # except:
                # print('no tensor: ' + operation_namee)
        print(count1)
        print(count2)

        with open('./profs/tensors_sz_32.txt', 'w') as f:
            for tensor, size in operations_tensors.items():
                f.write('"' + tensor + '"::' + str(size) + '\n')
        #end fareed

        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # epoch_pointer lets a restored run resume at the right epoch.
        for e in range(model.epoch_pointer.eval(), args.num_epochs):
            # Exponential learning-rate decay per epoch.
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            speed = 0
            if args.init_from is None:
                assign_op = model.epoch_pointer.assign(e)
                sess.run(assign_op)
            if args.init_from is not None:
                # Resume mid-epoch from the persisted batch pointer, then
                # clear init_from so later epochs start at batch 0.
                data_loader.pointer = model.batch_pointer.eval()
                args.init_from = None
            for b in range(data_loader.pointer, data_loader.num_batches):
                x, y = data_loader.next_batch()
                feed = {
                    model.input_data: x,
                    model.targets: y,
                    model.initial_state: state
                }
                start = time.time()

                # Every 10th batch (phase 7) runs with full tracing and is
                # profiled; batch 7 additionally dumps a memory report.
                if b % 10 == 7:
                    summary, train_loss, state, _, _ = sess.run(
                        [
                            merged, model.cost, model.final_state,
                            model.train_op, model.inc_batch_pointer_op
                        ],
                        feed,
                        run_metadata=run_metadata,
                        options=options)
                    profile(run_metadata, b)
                    if b == 7:
                        options_mem = tf.profiler.ProfileOptionBuilder.time_and_memory(
                        )
                        options_mem["min_bytes"] = 0
                        options_mem["min_micros"] = 0
                        options_mem["output"] = 'file:outfile=./profs/mem.txt'
                        options_mem["select"] = ("bytes", "peak_bytes",
                                                 "output_bytes",
                                                 "residual_bytes")
                        mem = tf.profiler.profile(tf.get_default_graph(),
                                                  run_meta=run_metadata,
                                                  cmd="scope",
                                                  options=options_mem)

                else:
                    # Untraced fast path; only these batches update `speed`
                    # and write TensorBoard summaries.
                    summary, train_loss, state, _, _ = sess.run([
                        merged, model.cost, model.final_state, model.train_op,
                        model.inc_batch_pointer_op
                    ], feed)
                    speed = time.time() - start
                    train_writer.add_summary(summary,
                                             e * data_loader.num_batches + b)

                # NOTE(review): raises ZeroDivisionError if batch_size < 10.
                if (e * data_loader.num_batches + b) % int(
                        args.batch_size / 10) == 0 and b % 10 != 7:
                    print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                        .format(e * data_loader.num_batches + b,
                                args.num_epochs * data_loader.num_batches,
                                e, train_loss, speed))
                """ if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path)) """
        train_writer.close()
コード例 #13
0
ファイル: train.py プロジェクト: Tahlor/word-rnn-tensorflow
def train(args):
    """Train the word-RNN, checkpointing periodically and optionally sampling.

    Attempts to resume from ``args.init_from`` (best-effort: failures are
    reported, not fatal), writes TensorBoard summaries to ``args.log_dir``
    and checkpoints to ``args.save_dir``.

    Fixes vs. the previous revision:
      * the three bare ``except:`` clauses now catch ``Exception`` so
        ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed;
      * the ``end_word_training`` and ``syllable_training`` branches built
        byte-identical feed dicts and are merged;
      * the unused local ``python_path`` is removed.
    """
    tf.reset_default_graph()
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length,
                             args.input_encoding)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        try:
            # check if all necessary files exist
            assert os.path.isdir(
                args.init_from), " %s must be a path" % args.init_from
            assert os.path.isfile(
                os.path.join(args.init_from, "config.pkl")
            ), "config.pkl file does not exist in path %s" % args.init_from
            assert os.path.isfile(
                os.path.join(args.init_from, "words_vocab.pkl")
            ), "words_vocab.pkl.pkl file does not exist in path %s" % args.init_from
            ckpt = tf.train.get_checkpoint_state(args.init_from)
            assert ckpt, "No checkpoint found"
            assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

            # open old config and check if models are compatible
            with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
                saved_model_args = cPickle.load(f)
            need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
            for checkme in need_be_same:
                assert vars(saved_model_args)[checkme] == vars(
                    args
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

            # open saved vocab/dict and check if vocabs/dicts are compatible
            with open(os.path.join(args.init_from, 'words_vocab.pkl'),
                      'rb') as f:
                saved_words, saved_vocab = cPickle.load(f)
            assert saved_words == data_loader.words, "Data and loaded model disagree on word set!"
            assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
        except Exception:  # was bare `except:`; keep resume best-effort
            print("Could not init from old file")

    ## Dump new stuff
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.words, data_loader.vocab), f)

    model = Model(args)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(args.log_dir)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Bundle handed to the sampler when sampling in-process (see the
        # commented-out sample.main call below).
        model_dict = {
            "model": model,
            "words": data_loader.words,
            "vocab": data_loader.vocab,
            "sess": sess
        }
        train_writer.add_graph(sess.graph)

        # Write graph quick
        writer = tf.summary.FileWriter(os.path.join(args.save_dir, "graph"),
                                       sess.graph)
        writer.close()

        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())

        # restore model (best-effort, like the compatibility check above)
        if args.init_from is not None:
            try:
                saver.restore(sess, ckpt.model_checkpoint_path)
            except Exception:  # was bare `except:`
                print("Could not restore")

        # Epoch loop -- epoch_pointer lets a restored run resume mid-training
        for e in range(model.epoch_pointer.eval(), args.num_epochs):
            # Exponential learning-rate decay per epoch.
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            speed = 0
            if args.init_from is None:
                assign_op = model.epoch_pointer.assign(e)
                sess.run(assign_op)
            if args.init_from is not None:
                try:
                    # Resume mid-epoch from the persisted batch pointer, then
                    # clear init_from so later epochs start at batch 0.
                    data_loader.pointer = model.batch_pointer.eval()
                    args.init_from = None
                except Exception:  # was bare `except:`
                    pass

            # Batch step loop
            for b in range(data_loader.pointer, data_loader.num_batches):
                start = time.time()
                x, y, last_words, syllables, topic_words = data_loader.next_batch(
                )

                # Concatenate Inputs
                #x = tf.concat([x[:,:,None],last_words[:,:,None]],2)
                # The end-word and syllable modes used two byte-identical
                # branches (targets = line-final words); merged here.
                if args.end_word_training or args.syllable_training:
                    feed = {
                        model.input_data: x,
                        model.targets: last_words,
                        model.bonus_features: last_words,
                        model.initial_state: state,
                        model.syllables: syllables,
                        model.topic_words: topic_words,
                        model.batch_time: speed
                    }
                else:
                    feed = {
                        model.input_data: x,
                        model.targets: y,
                        model.bonus_features: last_words,
                        model.initial_state: state,
                        model.syllables: syllables,
                        model.topic_words: topic_words,
                        model.batch_time: speed
                    }
                summary, train_loss, state, _, _ = sess.run([
                    merged, model.cost, model.final_state, model.train_op,
                    model.inc_batch_pointer_op
                ], feed)
                train_writer.add_summary(summary,
                                         e * data_loader.num_batches + b)
                speed = time.time() - start
                if (e * data_loader.num_batches + b) % args.batch_size == 0:
                    print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                        .format(e * data_loader.num_batches + b,
                                args.num_epochs * data_loader.num_batches,
                                e, train_loss, speed))
                #if (e * data_loader.num_batches + b) % args.save_every == 0 \
                #if b % 1000 in [1, 100] \
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                    or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

                    #sample.main(save_dir = args.save_dir, output_path = "sample.txt", internal_call = True, model = model_dict)
                    if args.sample:
                        # NOTE(review): str.split() breaks if save_dir
                        # contains spaces; kept as-is to preserve behavior.
                        subprocess.call(
                            "python sample.py -e turtle -o sample.txt -s {}".
                            format(args.save_dir).split(),
                            shell=False)

        train_writer.close()
コード例 #14
0
ファイル: train.py プロジェクト: karantan/word-rnn-tensorflow
    def run(self):
        """Train the word-level RNN described by ``self.config``.

        Loads the corpus, optionally resumes from a previous checkpoint
        (after validating that the saved model is architecturally
        compatible), then runs the epoch/batch training loop, logging
        summaries and periodically saving checkpoints.
        """
        data_loader = TextLoader(
            self.config.data_dir,
            self.config.batch_size,
            self.config.seq_length,
            self.config.input_encoding,
        )
        self.config.vocab_size = data_loader.vocab_size

        # check compatibility if training is continued from previously
        # saved model
        if self.config.init_from is not None:
            # check if all necessary files exist
            assert os.path.isdir(
                self.config.init_from), ('{} must be a path'.format(
                    self.config.init_from))
            assert os.path.isfile(
                os.path.join(self.config.init_from, 'config.pkl')), (
                    'config.pkl file does not exist in path {}'.format(
                        self.config.init_from))
            assert os.path.isfile(
                os.path.join(self.config.init_from, 'words_vocab.pkl')
            ), 'words_vocab.pkl file does not exist in path {}'.format(
                self.config.init_from)
            ckpt = tf.train.get_checkpoint_state(self.config.init_from)
            assert ckpt, 'No checkpoint found'
            assert ckpt.model_checkpoint_path, (
                'No model path found in checkpoint')

            # open old config and check if models are compatible
            with open(os.path.join(self.config.init_from, 'config.pkl'),
                      'rb') as f:
                saved_model_args = cPickle.load(f)
            need_be_same = ['model', 'rnn_size', 'num_layers', 'seq_length']
            for checkme in need_be_same:
                # BUG FIX: compare against the current *config* -- that is
                # what gets pickled below -- not the trainer object itself,
                # which does not carry these attributes directly.
                assert vars(saved_model_args)[checkme] == vars(
                    self.config)[checkme], (
                        'Command line argument and saved model disagree '
                        'on "{}".'.format(checkme))

            # open saved vocab/dict and check if vocabs/dicts are compatible
            with open(os.path.join(self.config.init_from, 'words_vocab.pkl'),
                      'rb') as f:
                saved_words, saved_vocab = cPickle.load(f)
            assert saved_words == data_loader.words, (
                'Data and loaded model disagree on word set!')
            assert saved_vocab == data_loader.vocab, (
                'Data and loaded model disagree on dictionary mappings!')

        # Persist the config and vocabulary so sampling / resuming can
        # reconstruct an identical model later.
        with open(os.path.join(self.config.save_dir, 'config.pkl'), 'wb') as f:
            cPickle.dump(self.config, f)
        with open(os.path.join(self.config.save_dir, 'words_vocab.pkl'),
                  'wb') as f:
            cPickle.dump((data_loader.words, data_loader.vocab), f)

        model = Model(self.config)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(self.config.log_dir)
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.config.gpu_mem)

        with tf.Session(config=tf.ConfigProto(
                gpu_options=gpu_options)) as sess:
            train_writer.add_graph(sess.graph)
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            # restore model
            if self.config.init_from is not None:
                saver.restore(sess, ckpt.model_checkpoint_path)
            # Resume from the epoch recorded in the checkpoint (0 on a
            # fresh run).
            for e in range(model.epoch_pointer.eval(), self.config.num_epochs):
                # Exponential learning-rate decay per epoch.
                sess.run(
                    tf.assign(
                        model.lr,
                        self.config.learning_rate *
                        (self.config.decay_rate**e),
                    ))
                data_loader.reset_batch_pointer()
                state = sess.run(model.initial_state)
                speed = 0
                if self.config.init_from is None:
                    assign_op = model.epoch_pointer.assign(e)
                    sess.run(assign_op)
                if self.config.init_from is not None:
                    # Restore the mid-epoch batch position exactly once,
                    # then fall back to normal epoch bookkeeping.
                    data_loader.pointer = model.batch_pointer.eval()
                    self.config.init_from = None
                for b in range(data_loader.pointer, data_loader.num_batches):
                    start = time.time()
                    x, y = data_loader.next_batch()
                    feed = {
                        model.input_data: x,
                        model.targets: y,
                        model.initial_state: state,
                        model.batch_time: speed,
                    }
                    summary, train_loss, state, _, _ = sess.run([
                        merged,
                        model.cost,
                        model.final_state,
                        model.train_op,
                        model.inc_batch_pointer_op,
                    ], feed)
                    train_writer.add_summary(summary,
                                             e * data_loader.num_batches + b)
                    speed = time.time() - start
                    if ((e * data_loader.num_batches + b) %
                            self.config.batch_size == 0):
                        print(
                            '{}/{} (epoch {}), train_loss = {:.3f}, '
                            'time/batch = {:.3f}'.format(
                                e * data_loader.num_batches + b,
                                self.config.num_epochs *
                                data_loader.num_batches,
                                e,
                                train_loss,
                                speed,
                            ), )
                    # save for the last result
                    if ((e * data_loader.num_batches + b) %
                            self.config.save_every == 0
                            or (e == self.config.num_epochs - 1
                                and b == data_loader.num_batches - 1)):
                        checkpoint_path = os.path.join(
                            self.config.save_dir,
                            'model-{:.3f}.ckpt'.format(train_loss),
                        )
                        saver.save(
                            sess,
                            checkpoint_path,
                            global_step=e * data_loader.num_batches + b,
                        )
                        print('model saved to {}'.format(checkpoint_path))
            train_writer.close()
コード例 #15
0
def train(args):
    """Train a character-level model, resuming automatically from a
    checkpoint found in ``args.save_dir``.

    Saves config/vocab pickles, runs the epoch/batch loop with periodic
    learning-rate decay, writes TensorBoard summaries, and checkpoints
    regularly.  A ``KeyboardInterrupt`` stops training cleanly and still
    saves a final model in the ``finally`` clause.
    """
    import datetime  # local import: needed only for the resume-time report

    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    load_model = False
    if not os.path.exists(args.save_dir):
        print("Creating directory %s" % args.save_dir)
        os.mkdir(args.save_dir)
    elif (os.path.exists(os.path.join(args.save_dir, 'config.pkl'))):
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
                saved_args = pickle.load(f)
                # Architecture arguments must match the checkpoint, so the
                # saved values override whatever was passed this run.
                args.block_size = saved_args.block_size
                args.num_blocks = saved_args.num_blocks
                args.num_layers = saved_args.num_layers
                args.model = saved_args.model
                print(
                    "Found a previous checkpoint. Overwriting model description arguments to:"
                )
                print(
                    " model: {}, block_size: {}, num_blocks: {}, num_layers: {}"
                    .format(saved_args.model, saved_args.block_size,
                            saved_args.num_blocks, saved_args.num_layers))
                load_model = True

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)

    print("Building the model")
    model = Model(args)
    print("Total trainable parameters: {:,d}".format(
        model.trainable_parameter_count()))

    # Silence TensorFlow's INFO-level C++ logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    config = tf.ConfigProto(log_device_placement=False)
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(model.save_variables_list(), max_to_keep=3)
        if (load_model):
            print("Loading saved parameters")
            saver.restore(sess, ckpt.model_checkpoint_path)
        global_epoch_fraction = sess.run(model.global_epoch_fraction)
        global_seconds_elapsed = sess.run(model.global_seconds_elapsed)
        if load_model:
            # BUG FIX: the original referenced the misspelled names
            # ``datatime`` / ``gloabal_seconds_elapsed`` and chained
            # ``.sess.run(...)`` onto the timedelta inside .format(),
            # which raised NameError/AttributeError at runtime.
            print(
                "Resuming from global epoch fraction {:.3f}, "
                "total trained time: {}, learning rate: {}".format(
                    global_epoch_fraction,
                    datetime.timedelta(seconds=float(global_seconds_elapsed)),
                    sess.run(model.lr)))
        if args.set_learning_rate > 0:
            sess.run(tf.assign(model.lr, args.set_learning_rate))
            print("Reset learning rate to {}".format(args.set_learning_rate))

        data_loader.cue_batch_pointer_to_epoch_fraction(global_epoch_fraction)
        # Batch offset within the partially-completed epoch we resume into.
        initial_batch_step = int(
            (global_epoch_fraction - int(global_epoch_fraction)) *
            data_loader.total_batch_count)
        epoch_range = (int(global_epoch_fraction),
                       args.num_epochs + int(global_epoch_fraction))
        writer = tf.summary.FileWriter(args.save_dir,
                                       graph=tf.get_default_graph())
        outputs = [
            model.cost, model.final_state, model.train_op, model.summary_op
        ]
        global_step = epoch_range[
            0] * data_loader.total_batch_count + initial_batch_step
        avg_loss = 0
        avg_steps = 0

        # Pre-bind the loop variables so the ``finally`` clause cannot hit a
        # NameError if an interrupt arrives before the first iteration.
        e = epoch_range[0]
        b = initial_batch_step
        try:
            for e in range(*epoch_range):
                state = sess.run(model.zero_state)
                batch_range = (initial_batch_step,
                               data_loader.total_batch_count)
                # Only the first (resumed) epoch starts mid-way through.
                initial_batch_step = 0
                for b in range(*batch_range):
                    global_step += 1
                    if global_step % args.decay_steps == 0:
                        # Multiplicative learning-rate decay.
                        current_learning_rate = sess.run(model.lr)
                        current_learning_rate *= args.decay_rate
                        sess.run(tf.assign(model.lr, current_learning_rate))
                        print("Decayed learning rate to {}".format(
                            current_learning_rate))
                    start = time.time()

                    x, y = data_loader.next_batch()

                    feed = {model.input_data: x, model.targets: y}
                    model.add_state_to_feed_dict(feed, state)

                    train_loss, state, _, summary = sess.run(outputs, feed)
                    elapsed = time.time() - start
                    global_seconds_elapsed += elapsed
                    writer.add_summary(summary, e * batch_range[1] + b + 1)

                    # Running average of the loss over (at most) 100 steps.
                    if avg_steps < 100:
                        avg_steps += 1
                    avg_loss = 1 / avg_steps * train_loss + (
                        1 - 1 / avg_steps) * avg_loss
                    # BUG FIX: "{:, d}" (with a space) is an invalid format
                    # spec and raised ValueError; "{:,d}" is what is meant.
                    print(
                        "{:,d} / {:,d} (epoch: {:.3f} / {}), loss {:.3f} (avg {:.3f}), {:.3f}s"
                        .format(b, batch_range[1], e + b / batch_range[1],
                                epoch_range[1], train_loss, avg_loss, elapsed))

                    if (e * batch_range[1] + b + 1) % args.save_every == 0 or (
                            e == epoch_range[1] - 1
                            and b == batch_range[1] - 1):
                        save_model(sess, saver, model, args.save_dir,
                                   global_step, data_loader.total_batch_count,
                                   global_seconds_elapsed)
        except KeyboardInterrupt:
            print()
        finally:
            # Always flush summaries and save a final checkpoint, even on
            # interrupt.
            writer.flush()
            global_step = e * data_loader.total_batch_count + b
            save_model(sess, saver, model, args.save_dir, global_step,
                       data_loader.total_batch_count, global_seconds_elapsed)
コード例 #16
0
ファイル: train.py プロジェクト: jrz94/RNN_Branch_prediction
def train(args):
    """Train a word-level language model and evaluate perplexity on a dev
    set after every epoch, checkpointing whenever dev perplexity improves.

    Results are echoed both to stdout and to ``args.output`` inside
    ``args.save_dir``.
    """
    start = time.time()
    save_dir = args.save_dir
    try:
        os.stat(save_dir)
    except OSError:  # narrowed from a bare ``except:``: only "missing dir"
        os.mkdir(save_dir)

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    #load data
    data_loader = TextLoader(args)
    train_data = data_loader.train_data
    dev_data = data_loader.dev_data

    out_file = os.path.join(args.save_dir, args.output)
    fout = codecs.open(out_file, "w", encoding="UTF-8")

    args.word_vocab_size = data_loader.word_vocab_size
    args.out_vocab_size = data_loader.word_vocab_size
    print ("Word vocab size: " + str(data_loader.word_vocab_size) + "\n")
    fout.write("Word vocab size: " + str(data_loader.word_vocab_size) + "\n")

    # Model
    lm_model = WordLM

    print ("Begin training...")
    # If using gpu:
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
    # gpu_config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
    # add parameters to the tf session -> tf.Session(config=gpu_config)
    with tf.Graph().as_default(), tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-args.init_scale, args.init_scale)

        # Build models: one graph with shared variables, a training copy and
        # a (dropout-free) evaluation copy.
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            mtrain = lm_model(args, is_training=True)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mdev = lm_model(args, is_training=False)

        # save only the last model
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        tf.global_variables_initializer().run()
        dev_pp = 10000000.0  # sentinel: "best dev perplexity so far"

        # process each epoch
        e = 0
        decay_counter = 1
        learning_rate = args.learning_rate
        while e < args.num_epochs:
            # Start decaying the learning rate after the fifth epoch.
            if e > 4:
                lr_decay = args.decay_rate ** decay_counter
                learning_rate = args.learning_rate * lr_decay
                decay_counter += 1
            print("Epoch: %d" % (e + 1))
            mtrain.assign_lr(sess, learning_rate)
            print("Learning rate: %.3f" % sess.run(mtrain.lr))

            train_perplexity = run_epoch(sess, mtrain, train_data, data_loader, mtrain.train_op, verbose=True)
            print("Train Perplexity: %.3f" % train_perplexity)

            dev_perplexity = run_epoch(sess, mdev, dev_data, data_loader, tf.no_op())
            print("Valid Perplexity: %.3f" % dev_perplexity)

            # write results to file
            fout.write("Epoch: %d\n" % (e + 1))
            fout.write("Learning rate: %.3f\n" % sess.run(mtrain.lr))
            fout.write("Train Perplexity: %.3f\n" % train_perplexity)
            fout.write("Valid Perplexity: %.3f\n" % dev_perplexity)
            fout.flush()

            # Lower perplexity is better; checkpoint on improvement.
            if dev_pp > dev_perplexity:
                # BUG FIX: message said "highest" but this branch fires on
                # the lowest dev perplexity seen so far.
                print ("Achieved lowest perplexity on dev set, save model.")
                checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=e)
                print ("model saved to {}".format(checkpoint_path))
                dev_pp = dev_perplexity
            e += 1

        print("Training time: %.0f" % (time.time() - start))
        fout.write("Training time: %.0f\n" % (time.time() - start))
    # BUG FIX: the results file was never closed.
    fout.close()
コード例 #17
0
def train(args):
    """Train a character-level model, periodically evaluating on held-out
    validation batches and logging per-iteration losses to ``log.csv``.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        # Modernized from the deprecated tf.initialize_all_variables /
        # tf.all_variables for consistency with the current TF 1.x API.
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        train_loss_iterations = {
            'iteration': [],
            'epoch': [],
            'train_loss': [],
            'val_loss': []
        }

        for e in range(args.num_epochs):
            # Exponential per-epoch learning-rate decay.
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                batch_idx = e * data_loader.num_batches + b
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(batch_idx,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                train_loss_iterations['iteration'].append(batch_idx)
                train_loss_iterations['epoch'].append(e)
                train_loss_iterations['train_loss'].append(train_loss)

                if batch_idx % args.save_every == 0:

                    # evaluate
                    avg_val_loss = 0
                    for x_val, y_val in data_loader.val_batches:
                        feed_val = {
                            model.input_data: x_val,
                            model.targets: y_val
                        }
                        # BUG FIX: the original also ran model.train_op here,
                        # performing gradient updates on validation batches —
                        # that leaks validation data into training and skews
                        # the reported val_loss. Evaluation must be read-only.
                        val_loss, state_val = sess.run(
                            [model.cost, model.final_state], feed_val)
                        avg_val_loss += val_loss / len(data_loader.val_batches)
                    print('val_loss: {:.3f}'.format(avg_val_loss))
                    train_loss_iterations['val_loss'].append(avg_val_loss)

                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
                else:
                    # Keep all columns the same length for the DataFrame.
                    train_loss_iterations['val_loss'].append(None)

            # Rewrite the full log each epoch so it survives interruption.
            pd.DataFrame(data=train_loss_iterations,
                         columns=list(train_loss_iterations.keys())).to_csv(
                             os.path.join(args.save_dir, 'log.csv'))
コード例 #18
0
def train(args):
    """Train a BasicLSTM word-level model.

    Loads the corpus, optionally resumes from ``args.init_from`` (after
    validating the saved model's compatibility), builds the training ops
    and gradient summaries, then runs the epoch/batch loop, logging to
    TensorBoard and checkpointing every ``args.save_every`` steps.
    """
    # Data Preparation
    # ====================================

    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size
    print("Number of sentences: {}".format(data_loader.num_data))
    print("Vocabulary size: {}".format(args.vocab_size))

    # Check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "words_vocab.pkl")
        ), "words_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = pickle.load(f)
        # Architecture-defining arguments must match the checkpoint.
        need_be_same = ["rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = pickle.load(f)
        assert saved_words == data_loader.words, "Data and loaded model disagree on word set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    # Persist config and vocabulary so sampling/resuming can rebuild an
    # identical model later.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.words, data_loader.vocab), f)
    """
    embedding_matrix = get_vocab_embedding(args.save_dir, data_loader.words, args.embedding_file)
    print("Embedding matrix shape:",embedding_matrix.shape)
    """

    # Training
    # ====================================
    with tf.Graph().as_default():
        with tf.Session() as sess:
            model = BasicLSTM(args)

            # Define training procedure: Adam with global-norm gradient
            # clipping at args.grad_clip.
            global_step = tf.Variable(0, name='global_step', trainable=False)
            optimizer = tf.train.AdamOptimizer(args.learning_rate)
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(model.cost, tvars),
                                              args.grad_clip)
            train_op = optimizer.apply_gradients(zip(grads, tvars),
                                                 global_step=global_step)

            # Keep track of gradient values and sparsity
            # (registered on the graph; collected by merge_all() below).
            grad_summaries = []
            for g, v in zip(grads, tvars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)

            # Summary for loss
            loss_summary = tf.summary.scalar("loss", model.cost)

            # Train summaries — merge_all must come after every summary op
            # above has been created.
            merged = tf.summary.merge_all()
            if not os.path.exists(args.log_dir):
                os.makedirs(args.log_dir)
            train_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables())

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            # Restore model (overwrites the fresh initialization above).
            if args.init_from is not None:
                saver.restore(sess, ckpt.model_checkpoint_path)

            # Start training
            print("Start training")
            for epoch in range(args.num_epochs):
                data_loader.reset_batch_pointer()
                # NOTE(review): `state` is fetched but never fed back into
                # the graph below — presumably the model rebuilds its state
                # internally; confirm against BasicLSTM's definition.
                state = sess.run(model.initial_state)
                for i in range(data_loader.num_batches):
                    start = time.time()
                    x_batch, y_batch = data_loader.next_batch()
                    feed_dict = {
                        model.x: x_batch,
                        model.y: y_batch,
                        model.keep_prob: args.keep_prob
                    }
                    _, step, summary, loss, equal = sess.run([
                        train_op, global_step, merged, model.cost, model.equal
                    ], feed_dict)

                    print(
                        "training step {}, epoch {}, batch {}/{}, loss: {:.4f}, accuracy: {:.4f}, time/batch: {:.3f}"
                        .format(step, epoch, i, data_loader.num_batches, loss,
                                np.mean(equal),
                                time.time() - start))
                    train_writer.add_summary(summary, step)

                    current_step = tf.train.global_step(sess, global_step)
                    if current_step % args.save_every == 0 or (
                            epoch == args.num_epochs - 1
                            and i == data_loader.num_batches -
                            1):  #save for the last result
                        checkpoint_path = os.path.join(args.save_dir,
                                                       'model.ckpt')
                        path = saver.save(sess,
                                          checkpoint_path,
                                          global_step=current_step)
                        print("Saved model checkpoint to {}".format(path))

            train_writer.close()
コード例 #19
0
def train(args):
    """Train a (sub)word-level language model with optional character,
    character-ngram, morpheme, or oracle composition.

    Validates the unit/composition combination, builds train and dev model
    copies sharing one variable scope, then runs epochs with
    patience-based learning-rate decay, checkpointing whenever dev
    perplexity improves by at least 0.1.
    """
    start = time.time()
    save_dir = args.save_dir
    try:
        os.stat(save_dir)
    except OSError:  # narrowed from a bare ``except:``: only "missing dir"
        os.mkdir(save_dir)

    # Optional sentence boundary markers enlarge the output vocabulary.
    args.eos = ''
    args.sos = ''
    if args.EOS == "true":
        args.eos = '</s>'
        args.out_vocab_size += 1
    if args.SOS == "true":
        args.sos = '<s>'
        args.out_vocab_size += 1

    data_loader = TextLoader(args)
    train_data = data_loader.train_data
    dev_data = data_loader.dev_data

    # Append mode so repeated/continued runs keep one results log.
    fout = open(os.path.join(args.save_dir, args.output), "a")

    args.word_vocab_size = data_loader.word_vocab_size

    if args.unit != "word":
        args.subword_vocab_size = data_loader.subword_vocab_size
    fout.write(str(args) + "\n")

    # Statistics of words
    fout.write("Word vocab size: " + str(data_loader.word_vocab_size) + "\n")

    # Statistics of sub units
    if args.unit != "word":
        fout.write("Subword vocab size: " +
                   str(data_loader.subword_vocab_size) + "\n")
        if args.composition == "bi-lstm":
            # The bi-LSTM composer needs a fixed number of steps: the
            # longest sub-unit sequence for the chosen unit type.
            if args.unit == "char":
                fout.write("Maximum word length: " +
                           str(data_loader.max_word_len) + "\n")
                args.bilstm_num_steps = data_loader.max_word_len
            elif args.unit == "char-ngram":
                fout.write("Maximum ngram per word: " +
                           str(data_loader.max_ngram_per_word) + "\n")
                args.bilstm_num_steps = data_loader.max_ngram_per_word
            elif args.unit == "morpheme" or args.unit == "oracle":
                fout.write("Maximum morpheme per word: " +
                           str(data_loader.max_morph_per_word) + "\n")
                args.bilstm_num_steps = data_loader.max_morph_per_word
            else:
                sys.exit("Wrong unit.")
        elif args.composition == "addition":
            if args.unit not in ["char-ngram", "morpheme", "oracle"]:
                sys.exit("Wrong composition.")
        else:
            sys.exit("Wrong unit/composition.")
    else:
        if args.composition != "none":
            sys.exit("Wrong composition.")

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Select the model class for the validated unit/composition pair.
    if args.unit == "word":
        lm_model = WordModel
    elif args.composition == "addition":
        lm_model = AdditiveModel
    elif args.composition == "bi-lstm":
        lm_model = BiLSTMModel
    else:
        sys.exit("Unknown unit or composition.")

    print(args)
    print("Begin training...")
    with tf.Graph().as_default(), tf.Session() as sess:
        if args.seed != 0:
            tf.set_random_seed(args.seed)
            np.random.seed(seed=args.seed)

        initializer = tf.random_uniform_initializer(-args.init_scale,
                                                    args.init_scale)

        # Build models: training copy and a dropout-free dev copy sharing
        # the same variables.
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            mtrain = lm_model(args, is_training=True)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mdev = lm_model(args, is_training=False)

        # save only the last model
        saver = tf.train.Saver(tf.all_variables(), max_to_keep=1)
        tf.initialize_all_variables().run()
        dev_pp = 10000000.0  # sentinel: "best dev perplexity so far"

        # print(sess.run(mtrain.embedding))

        # NOTE(review): any existing checkpoint is restored here even when
        # args.cont != 'true', and restored a second time below when
        # continuing — the duplicate restore is redundant but harmless;
        # confirm whether unconditional restore is intended.
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        if args.cont == 'true':  # continue training from a saved model
            ckpt = tf.train.get_checkpoint_state(args.save_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            # Resume epoch counting from the checkpoint filename suffix.
            ckpt_name = ckpt.model_checkpoint_path.split('-')
            e = int(ckpt_name[2]) + 1
        else:
            # process each epoch
            e = 1

        learning_rate = args.learning_rate
        patience = args.patience

        while e <= args.num_epochs:

            print("Epoch: %d" % e)
            mtrain.assign_lr(sess, learning_rate)
            print("Learning rate: %.3f" % sess.run(mtrain.lr))

            train_perplexity = run_epoch(sess,
                                         mtrain,
                                         train_data,
                                         data_loader,
                                         mtrain.train_op,
                                         verbose=True)
            dev_perplexity = run_epoch(sess, mdev, dev_data, data_loader,
                                       tf.no_op())

            print("Train Perplexity: %.3f" % train_perplexity)
            print("Valid Perplexity: %.3f" % dev_perplexity)

            # write results to file
            fout.write("Epoch: %d\n" % e)
            fout.write("Learning rate: %.3f\n" % sess.run(mtrain.lr))
            fout.write("Train Perplexity: %.3f\n" % train_perplexity)
            fout.write("Valid Perplexity: %.3f\n" % dev_perplexity)
            fout.flush()

            # Checkpoint only on a meaningful improvement (>= 0.1 lower
            # perplexity); otherwise count it against patience.
            decrease_lr = False
            diff = dev_pp - dev_perplexity
            if diff >= 0.1:
                print("Achieve highest perplexity on dev set, save model.")
                checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=e)
                print("model saved to {}".format(checkpoint_path))
                dev_pp = dev_perplexity
            else:
                decrease_lr = True

            # After epoch 4: decay the learning rate, either every epoch
            # (patience == 0) or only after `patience` non-improving epochs.
            if e > 4:
                if args.patience != 0:
                    if decrease_lr:
                        patience -= 1
                        if patience == 0:
                            learning_rate *= args.decay_rate
                            patience = args.patience
                    # decrease learning rate
                    else:
                        learning_rate *= args.decay_rate
                # decrease learning rate
                else:
                    learning_rate *= args.decay_rate

            if learning_rate < 0.0001:
                print('Learning rate too small, stop training.')
                break

            e += 1

        print("Training time: %.0f" % (time.time() - start))
        fout.write("Training time: %.0f\n" % (time.time() - start))
    # BUG FIX: the results file was never closed.
    fout.close()
コード例 #20
0
ファイル: train.py プロジェクト: zbn123/tensorflow-1
def train(args):
    """Train a character-level RNN language model.

    Builds a TextLoader from args.data_dir, optionally restores a previous
    checkpoint (args.init_from) after verifying compatibility, then runs
    args.num_epochs epochs with an exponentially decayed learning rate,
    writing TensorBoard summaries and periodic checkpoints to args.save_dir.
    """
    # Load the data; see util.py for the TextLoader details.
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "chars_vocab.pkl")
        ), "chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.latest_checkpoint(args.init_from)
        assert ckpt, "No checkpoint found"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # Persist the run configuration and the character vocabulary so that
    # sampling / resumed training can rebuild the exact same model.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    # Build the model graph from the parsed arguments; see model.py.
    model = Model(args)

    with tf.Session() as sess:
        # instrument for tensorboard
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
            os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        # Initialize all variables before any restore/training.
        sess.run(tf.global_variables_initializer())
        # Saver used below for checkpointing (and for the optional restore).
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt)
        # e indexes the epoch
        for e in range(args.num_epochs):
            # Exponential decay: lr = learning_rate * decay_rate**e
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            # Reset the batch pointer and the RNN state each epoch.
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            # b indexes the batch within the epoch.
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                # Carry the LSTM (c, h) state over between batches.
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h

                # instrument for tensorboard
                # Run one optimization step and collect the summary.
                summ, train_loss, state, _ = sess.run(
                    [summaries, model.cost, model.final_state, model.train_op],
                    feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)

                end = time.time()
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches, e,
                            train_loss, end - start))
                # Checkpoint every save_every steps and at the very last step.
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                        or (e == args.num_epochs-1 and
                            b == data_loader.num_batches-1):
                    # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
Code example #21
0
    def __init__(self, sess, args, flag):
        """Store the session, CLI args and mode flag; build the data loader.

        Fixes the original's mixed tab/space indentation, which raises
        TabError under Python 3.

        Args:
            sess: an open tf.Session used for graph execution.
            args: parsed command-line arguments; must provide data_dir,
                batch_size and seq_length for the TextLoader.
            flag: caller-supplied mode flag, stored as-is.
        """
        self.args = args
        self.flag = flag
        self.sess = sess
        self.data_loader = TextLoader(args.data_dir, args.batch_size,
                                      args.seq_length)
        # NOTE(review): this is a fresh tf.Graph(), not sess.graph —
        # presumably intentional; confirm against callers.
        self.g = tf.Graph()
Code example #22
0
def train(args):
    """Train a classification RNN (Python 2 / pre-TF-1.0 style).

    Normalizes the string flag args.continue_training, builds a TextLoader,
    optionally restores a checkpoint from args.save_dir after compatibility
    checks, then trains for args.num_epochs epochs with exponential
    learning-rate decay, checkpointing every args.save_every steps and at
    the final step.
    """
    # argparse delivers the flag as a string; normalize it to a bool.
    if args.continue_training in ['True', 'true']:
        args.continue_training = True
    else:
        args.continue_training = False

    data_loader = TextLoader(True, args.utils_dir, args.data_path,
                             args.batch_size, args.seq_length, None, None)
    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size

    # When resuming, verify that the saved config/vocab/labels are
    # compatible with the data that was just loaded.
    if args.continue_training:
        assert os.path.isfile(
            os.path.join(args.save_dir, 'config.pkl')
        ), 'config.pkl file does not exist in path %s' % args.save_dir
        assert os.path.isfile(
            os.path.join(args.utils_dir, 'chars_vocab.pkl')
        ), 'chars_vocab.pkl file does not exist in path %s' % args.utils_dir
        assert os.path.isfile(
            os.path.join(args.utils_dir, 'labels.pkl')
        ), 'labels.pkl file does not exist in path %s' % args.utils_dir
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

        with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
            saved_model_args = pickle.load(f)
        need_be_same = ['model', 'rnn_size', 'num_layers', 'seq_length']
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], 'command line argument and saved model disagree on %s' % checkme

        with open(os.path.join(args.utils_dir, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = pickle.load(f)
        with open(os.path.join(args.utils_dir, 'labels.pkl'), 'rb') as f:
            saved_labels = pickle.load(f)
        assert saved_chars == data_loader.chars, 'data and loaded model disagree on character set'
        assert saved_vocab == data_loader.vocab, 'data and loaded model disagree on dictionary mappings'
        assert saved_labels == data_loader.labels, 'data and loaded model disagree on label dictionary mappings'

    # Persist config, vocabulary and label mapping for later sampling/resume.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.utils_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)
    with open(os.path.join(args.utils_dir, 'labels.pkl'), 'wb') as f:
        pickle.dump(data_loader.labels, f)

    model = Model(args)

    with tf.Session() as sess:
        # Pre-TF-1.0 initializer/variable APIs (deprecated in modern TF).
        init = tf.initialize_all_variables()
        sess.run(init)
        saver = tf.train.Saver(tf.all_variables())

        if args.continue_training:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for e in range(args.num_epochs):
            # Exponential decay: lr = learning_rate * decay_rate**e
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()

            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                train_loss, state, _, accuracy = sess.run([
                    model.cost, model.final_state, model.optimizer,
                    model.accuracy
                ],
                                                          feed_dict=feed)
                end = time.time()
                print '{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'\
                    .format(e * data_loader.num_batches + b + 1,
                            args.num_epochs * data_loader.num_batches,
                            e + 1,
                            train_loss,
                            accuracy,
                            end - start)
                # Checkpoint every save_every steps and at the final step.
                if (e*data_loader.num_batches+b+1) % args.save_every == 0 \
                    or (e==args.num_epochs-1 and b==data_loader.num_batches-1):
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b + 1)
                    print 'model saved to {}'.format(checkpoint_path)
Code example #23
0
def main():
    """Sample lines from a trained character RNN and record their probabilities.

    Loads the saved config and vocabulary, restores the latest checkpoint,
    then repeatedly generates lines character-by-character: the first
    character is drawn from the loader's empirical first-char distribution,
    each subsequent one from the model's softmax, until a newline is
    produced. The probability of each generated line is written to
    args.output_file once args.sample_size lines have been collected.

    NOTE(review): Python 2-only constructs are used (dict.has_key, print
    statements, map() consumed as a list).
    """
    args = parse_args()
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)
    loader = TextLoader(saved_args.data_dir, saved_args.batch_size, \
    saved_args.seq_length, isTraining=False)

    saved_args.batch_size = 1  # Set batch size to 1 when sampling
    model = Model(saved_args, training=False)

    lut = {}

    vocab = loader.vocab
    charset = vocab.keys()
    # Characters ordered by their integer id, so that index i in this list
    # corresponds to vocab value i (matches the softmax output layout).
    charset_ordered = sorted(vocab.keys(), key=(lambda key: vocab[key]))

    results = []
    results_len = 0

    # Load first character probabilities
    first_char_probs = loader.first_char_probs
    for c in charset:
        if first_char_probs.has_key(c):
            # Skip the character whose vocab id is 0 (padding position).
            if vocab[c] == 0:
                continue
            else:
                lut[c] = first_char_probs[c]

    total_start = time.time()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        start = time.time()
        while results_len < args.sample_size:
            # Pick first letter according to probability
            first_char = np.random.choice(first_char_probs.keys(),
                                          p=first_char_probs.values())
            # print first_char + '\t' + str(first_char_probs[first_char])
            current_prefix = first_char
            result_prob = first_char_probs[first_char]
            err = False
            while not current_prefix.endswith("\n"):
                # Pick next letter according to probability
                length = len(current_prefix)
                # Get next possible characters' probabilities by NN
                line = np.array(map(vocab.get, current_prefix))
                # Zero-pad the prefix up to the model's fixed sequence length.
                line = np.pad(line, (0, saved_args.seq_length - len(line)),
                              'constant')
                feed = {
                    model.input_data: [line],
                    model.sequence_lengths: [length]
                }
                probs = sess.run([model.probs], feed)
                probs = np.reshape(probs, (-1, saved_args.vocab_size))
                next_char_prob = probs[length - 1]
                # next_char_prob[i] is probability of char in vocab with value i
                next_char = np.random.choice(charset_ordered, p=next_char_prob)
                # print next_char + '\t' + str(next_char_prob[vocab[next_char]])
                current_prefix += next_char
                # Accumulate the line's joint probability.
                result_prob *= next_char_prob[vocab[next_char]]
                if len(current_prefix
                       ) > saved_args.seq_length:  # this shouldn't happen
                    print "Something that shouldn't happen happened"
                    err = True
                    break
            if err:
                continue
            # print str(result_prob) + '\t' + current_prefix
            results.append(str(result_prob) + '\n')
            results_len += 1
            if results_len % args.display_every == 0:
                end = time.time()
                print("Progress: {}/{}; time taken = {}".format(
                    results_len, args.sample_size, end - start))
                start = time.time()

    with open(args.output_file, 'w') as f:
        f.writelines(results)
    total_end = time.time()
    print("Generated {} samples; total time taken = {}".format(
        len(results), total_end - total_start))
Code example #24
0
def cross_validation(args):
    """Run 10-fold cross-validation of the classification RNN.

    Shuffles the loader's data tensor, splits it into 10 folds, and for each
    fold re-initializes the variables, trains on the other nine folds, then
    evaluates accuracy on the held-out fold; finally prints the average
    accuracy across all folds.

    NOTE(review): Python 2 code (print statements, list-returning map,
    true integer division for n_chunks) and pre-TF-1.0 variable APIs.
    """
    data_loader = TextLoader(True, args.utils_dir, args.data_path,
                             args.batch_size, args.seq_length, None, None)
    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size

    # Persist config, vocabulary and labels alongside the checkpoints.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.utils_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)
    with open(os.path.join(args.utils_dir, 'labels.pkl'), 'wb') as f:
        pickle.dump(data_loader.labels, f)

    # Shuffle once, then split into 10 folds.
    data = data_loader.tensor.copy()
    np.random.shuffle(data)
    data_list = np.array_split(data, 10, axis=0)

    model = Model(args)
    accuracy_list = []

    with tf.Session() as sess:
        for n in range(10):
            # Re-initialize all variables for every fold (fresh model).
            init = tf.initialize_all_variables()
            sess.run(init)
            saver = tf.train.Saver(tf.all_variables())

            # Fold n is the test set; the remaining folds form the train set.
            test_data = data_list[n].copy()
            train_data = np.concatenate(map(lambda i: data_list[i],
                                            [j for j in range(10) if j != n]),
                                        axis=0)
            data_loader.tensor = train_data

            for e in range(args.num_epochs):
                # Exponential decay: lr = learning_rate * decay_rate**e
                sess.run(
                    tf.assign(model.lr,
                              args.learning_rate * (args.decay_rate**e)))
                data_loader.reset_batch_pointer()

                for b in range(data_loader.num_batches):
                    start = time.time()
                    x, y = data_loader.next_batch()
                    feed = {model.input_data: x, model.targets: y}
                    train_loss, state, _, accuracy = sess.run([
                        model.cost, model.final_state, model.optimizer,
                        model.accuracy
                    ],
                                                              feed_dict=feed)
                    end = time.time()
                    print '{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'\
                        .format(e * data_loader.num_batches + b + 1,
                                args.num_epochs * data_loader.num_batches,
                                e + 1,
                                train_loss,
                                accuracy,
                                end - start)
                    if (e*data_loader.num_batches+b+1) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b==data_loader.num_batches-1):
                        checkpoint_path = os.path.join(args.save_dir,
                                                       'model.ckpt')
                        saver.save(sess,
                                   checkpoint_path,
                                   global_step=e * data_loader.num_batches +
                                   b + 1)
                        print 'model saved to {}'.format(checkpoint_path)

            # Evaluate on the held-out fold in batch-sized chunks.
            n_chunks = len(test_data) / args.batch_size
            if len(test_data) % args.batch_size:
                n_chunks += 1
            test_data_list = np.array_split(test_data, n_chunks, axis=0)

            correct_total = 0.0
            num_total = 0.0
            for m in range(n_chunks):
                start = time.time()
                # Last column is the label; the rest is the input sequence.
                x = test_data_list[m][:, :-1]
                y = test_data_list[m][:, -1]
                results = model.predict_class(sess, x)
                correct_num = np.sum(results == y)
                end = time.time()

                correct_total += correct_num
                num_total += len(x)

            accuracy_total = correct_total / num_total
            accuracy_list.append(accuracy_total)
            print 'total_num = {}, total_accuracy = {:.6f}'.format(
                int(num_total), accuracy_total)

    accuracy_average = np.average(accuracy_list)
    print 'The average accuracy of cross_validation is {}'.format(
        accuracy_average)
Code example #25
0
def train(args):
    """Train a char-RNN, checkpointing by global iteration count.

    Persists the iteration counter and the per-interval loss history to
    args.save_dir so that training can be resumed (via args.init_from)
    with the step count intact.
    """
    # Build the vocabulary/data loader and record the vocabulary size.
    data_loader = TextLoader(args.batch_size)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "chars_vocab.pkl")
        ), "chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"
        assert os.path.isfile(
            os.path.join(args.init_from, "iterations")
        ), "iterations file does not exist in path %s " % args.init_from

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)

        need_be_same = ["model", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    # Save this run's configuration.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    # Save the character set and the char-to-id mapping.
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    # Build the model graph; see model.py for details.
    model = Model(args)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        iterations = 0
        # restore model and number of iterations
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
            with open(os.path.join(args.save_dir, 'iterations'), 'rb') as f:
                iterations = cPickle.load(f)

        losses = []
        for e in range(args.num_epochs):
            # Exponential decay: lr = learning_rate * decay_rate**e
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            # Reset the batch pointer at the start of each epoch.
            data_loader.reset_batch_pointer()

            for b in range(data_loader.num_batches):
                iterations += 1

                start = time.time()
                # Fetch one batch of inputs and targets.
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                train_loss, _, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()

                # Progress is rewritten in place on one console line ('\r').
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}), train_loss = {:.3f}, iterations = {} time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, iterations, end - start)
                sys.stdout.write(info)
                sys.stdout.flush()

                losses.append(train_loss)

                if (e * data_loader.num_batches + b + 1) % args.save_every == 0\
                    or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=iterations)
                    # Persist the iteration counter and the losses gathered
                    # since the previous checkpoint, then reset the buffer.
                    with open(os.path.join(args.save_dir, "iterations"),
                              'wb') as f:
                        cPickle.dump(iterations, f)
                    with open(
                            os.path.join(args.save_dir,
                                         "losses-" + str(iterations)),
                            'wb') as f:
                        cPickle.dump(losses, f)
                    losses = []
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
            sys.stdout.write('\n')
Code example #26
0
def train(args):
    """Train a RNN language model with per-epoch train/test evaluation.

    Restores from args.init_from if given, iterates epochs over the
    TextLoader's train split, evaluates on the test split, halves the
    learning rate when test loss stops improving, checkpoints on each new
    best test loss, and stops early once the learning rate drops below
    args.min_lr.

    Fixes relative to the original:
      * np.Infinity (alias removed in NumPy 2.0) -> np.inf
      * the running train-loss sum is reset at the start of every epoch
        instead of accumulating across epochs onto an averaged value
      * average losses are divided by the number of batches actually
        processed (the 1-based counters end one past the count)
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),\
            "config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),\
            "chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],\
                "Command line argument and saved model disagree on '%s' "%checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    # Persist config and vocabulary for sampling / resumed training.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)
    model = Model(args)

    with tf.Session() as sess:
        # instrument for tensorboard
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
            os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)

        prev_loss = np.inf  # test loss of the previous epoch
        min_test_loss = np.inf  # best (lowest) test loss seen so far

        with open(args.loss_log, mode='a') as loss_log:
            print("\nData: {}, number of layers: {}, hidden state size: {}".
                  format(args.data_dir, args.num_layers, args.rnn_size),
                  file=loss_log)
        for e in range(args.num_epochs):
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            state = sess.run(model.initial_state)
            # Running sum of this epoch's train loss; must reset every epoch.
            av_train_loss = 0
            b = 1
            start = time.time()

            data_loader.next_epoch()
            for x, y, word_lengths in data_loader.iterbatches('train'):
                feed = {
                    model.input_data: x,
                    model.targets: y,
                    model.word_len: word_lengths
                }
                # Carry the RNN state across batches within the epoch.
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h

                # instrument for tensorboard
                summ, train_loss, state, _ = sess.run(
                    [summaries, model.cost, model.final_state, model.train_op],
                    feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)
                writer.flush()

                av_train_loss += train_loss

                end = time.time()
                print(
                    "{}/{} (file {}/{}, epoch {}/{}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(
                        b - data_loader.num_batches * data_loader.file_number,
                        data_loader.num_batches, data_loader.file_number,
                        data_loader.num_train_files - 1, e,
                        args.num_epochs - 1, train_loss, end - start))
                b += 1

                start = time.time()

            # b is 1-based and ends one past the number of batches processed.
            av_train_loss /= max(b - 1, 1)

            # test
            av_test_loss = 0
            t = 1
            start = time.time()
            for x_test, y_test, word_lengths in data_loader.iterbatches(
                    'test'):
                feed = {
                    model.input_data: x_test,
                    model.targets: y_test,
                    model.word_len: word_lengths
                }
                test_loss = sess.run([model.cost], feed)
                test_loss = test_loss[0]
                av_test_loss += test_loss
                end = time.time()
                print(
                    "{}/{} (file {}/{}, epoch {}/{}), test_loss = {:.3f}, time/batch = {:.3f}"
                    .format(
                        t - data_loader.num_batches *
                        (data_loader.file_number -
                         data_loader.num_train_files), data_loader.num_batches,
                        data_loader.file_number - data_loader.num_train_files -
                        1, data_loader.num_test_files - 1, e,
                        args.num_epochs - 1, test_loss, end - start))
                t += 1
                start = time.time()

            # Same 1-based counter convention as for the train loop above.
            av_test_loss /= max(t - 1, 1)

            with open(args.loss_log, mode='a') as loss_log:
                print(
                    "Epoch {}, learning_rate = {}, train_loss = {}, test_loss = {}"
                    .format(e, args.learning_rate, av_train_loss,
                            av_test_loss),
                    file=loss_log)

            # Halve the base learning rate when test loss stops improving.
            if av_test_loss >= prev_loss:
                args.learning_rate /= 2
                print(args.learning_rate)
            # Checkpoint only on a new best test loss.
            if av_test_loss < min_test_loss:
                checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                saver.save(sess,
                           checkpoint_path,
                           global_step=e * data_loader.num_batches *
                           data_loader.num_train_files)
                print("model saved to {}".format(checkpoint_path))
                min_test_loss = av_test_loss
            prev_loss = av_test_loss

            print("epoch {} ended".format(e))
            if args.learning_rate < args.min_lr:
                print("Minimum reached", args.learning_rate, "<", args.min_lr)
                break
Code example #27
0
def train(args):
    """Train a word-level RNN with resumable epoch/batch pointers.

    The model stores epoch and batch pointers as graph variables, so a
    restored checkpoint resumes mid-epoch at the correct batch. Summaries
    go to args.log_dir and checkpoints to args.save_dir.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length,
                             args.input_encoding)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            args.init_from), " %s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "config.pkl")
        ), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, "words_vocab.pkl")
        ), "words_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = cPickle.load(f)
        assert saved_words == data_loader.words, "Data and loaded model disagree on word set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    # Persist config and vocabulary for sampling / resumed training.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.words, data_loader.vocab), f)

    model = Model(args)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(args.log_dir)
    # Cap the fraction of GPU memory this process may allocate.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        train_writer.add_graph(sess.graph)
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # Resume from the stored epoch pointer (0 on a fresh run).
        for e in range(model.epoch_pointer.eval(), args.num_epochs):
            sess.run(
                tf.assign(model.lr, args.learning_rate * (args.decay_rate**e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            speed = 0
            if args.init_from is None:
                assign_op = model.epoch_pointer.assign(e)
                sess.run(assign_op)
            # On the first resumed epoch, skip batches already trained, then
            # clear init_from so later epochs start from batch 0.
            if args.init_from is not None:
                data_loader.pointer = model.batch_pointer.eval()
                args.init_from = None
            for b in range(data_loader.pointer, data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {
                    model.input_data: x,
                    model.targets: y,
                    model.initial_state: state,
                    model.batch_time: speed
                }
                summary, train_loss, state, _, _ = sess.run([
                    merged, model.cost, model.final_state, model.train_op,
                    model.inc_batch_pointer_op
                ], feed)
                train_writer.add_summary(summary,
                                         e * data_loader.num_batches + b)
                speed = time.time() - start
                # Print progress every batch_size global steps.
                if (e * data_loader.num_batches + b) % args.batch_size == 0:
                    print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                        .format(e * data_loader.num_batches + b,
                                args.num_epochs * data_loader.num_batches,
                                e, train_loss, speed))
                # Checkpoint every save_every steps and at the very last step.
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
        train_writer.close()
Code example #28
0
File: train.py  Project: kanghj/DSM
def train(args):
    """Train a word-level RNN language model, saving checkpoints periodically.

    Writes config.pkl and words_vocab.pkl to args.save_dir, then runs
    args.num_epochs epochs with per-epoch exponential learning-rate decay.
    If args.init_from is set, training resumes from that checkpoint after
    validating that the saved model/vocab are compatible.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    ckpt = None
    # check compatibility if training is continued from a previously saved model
    if args.init_from is not None:
        # check that all necessary files exist
        assert os.path.isdir(args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), \
            "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "words_vocab.pkl")), \
            "words_vocab.pkl file does not exist in path %s" % args.init_from
        # BUG FIX: this lookup was commented out while the restore below was
        # kept, so any run with --init_from crashed with a NameError on `ckpt`.
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check that the models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check that vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = cPickle.load(f)
        assert saved_words == data_loader.words, \
            "Data and loaded model disagree on word set!"
        assert saved_vocab == data_loader.vocab, \
            "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.words, data_loader.vocab), f)

    #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # grab GPU memory on demand, not all up front

    model = Model(args)

    # BUG FIX: `config` was previously built but never passed to the session,
    # so allow_growth had no effect.
    with tf.Session(config=config) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # restore model when resuming from a checkpoint
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # exponentially decay the learning rate each epoch
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                # carry the RNN state across batches within an epoch
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
コード例 #29
0
def train(args):
    """Train a char-level RNN language model, collecting per-batch statistics.

    Writes config.pkl and chars_vocab.pkl to args.save_dir, trains for
    args.num_epochs epochs with per-epoch learning-rate decay, checkpoints
    every args.save_every batches (and at the very last batch), and — when
    args.stats_file_name is non-empty — pickles the collected statistics
    (mini-batch index, epoch, loss, batch time) to that file.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), \
            "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), \
            "chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        # BUG FIX: read in binary mode to match the 'wb' dumps below
        # (text-mode pickle reads break on Python 3)
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, \
            "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, \
            "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    # per-batch training statistics, optionally pickled at the end of training
    stats_data = {'mini_batch': [], 'epochs': [], 'train_losses': [], 'time_per_batch': []}

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # restore model when resuming from a checkpoint
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # exponentially decay the learning rate each epoch
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                # carry the RNN state across batches within an epoch
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()

                # BUG FIX: the global mini-batch index was appended to
                # 'epochs' instead of 'mini_batch', leaving 'mini_batch'
                # empty and interleaving two different series under 'epochs'.
                stats_data['mini_batch'].append(e * data_loader.num_batches + b)
                stats_data['epochs'].append(e)
                stats_data['train_losses'].append(train_loss)
                stats_data['time_per_batch'].append(end - start)

                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))

                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                    or (e == args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

    if args.stats_file_name != '':
        print('saving stats: {0}'.format(args.stats_file_name))
        # BUG FIX: use a context manager so the stats file is flushed and closed
        with open(args.stats_file_name, 'wb') as f:
            cPickle.dump(stats_data, f)
コード例 #30
0
def train(args):
    start = time.time()
    save_dir = args.save_dir
    try:
        os.stat(save_dir)
    except:
        os.mkdir(save_dir)

    args.eos = ''
    args.sos = ''
    if args.EOS == "true":
        args.eos = '</s>'
        args.out_vocab_size += 1
    if args.SOS == "true":
        args.sos = '<s>'
        args.out_vocab_size += 1

    local_test = False
    if local_test:
        # Gozde
        # char, char-ngram, morpheme, word, or oracle
        args.unit = "oracle-db"
        args.composition = "add-bi-lstm"
        args.train_file = "data/train.morph"
        args.dev_file = "data/dev.morph"
        args.batch_size = 12
        # End of test

    data_loader = TextLoader(args)
    train_data = data_loader.train_data
    dev_data = data_loader.dev_data

    fout = open(os.path.join(args.save_dir, args.output), "a")

    args.word_vocab_size = data_loader.word_vocab_size

    if args.unit != "word":
        args.subword_vocab_size = data_loader.subword_vocab_size
    fout.write(str(args) + "\n")

    # Statistics of words
    fout.write("Word vocab size: " + str(data_loader.word_vocab_size) + "\n")

    # Statistics of sub units
    if args.unit != "word":
        fout.write("Subword vocab size: " + str(data_loader.subword_vocab_size) + "\n")
        if args.composition == "bi-lstm":
            if args.unit == "char":
                fout.write("Maximum word length: " + str(data_loader.max_word_len) + "\n")
                args.bilstm_num_steps = data_loader.max_word_len
            elif args.unit == "char-ngram":
                fout.write("Maximum ngram per word: " + str(data_loader.max_ngram_per_word) + "\n")
                args.bilstm_num_steps = data_loader.max_ngram_per_word
            elif args.unit == "morpheme" or args.unit == "oracle":
                fout.write("Maximum morpheme per word: " + str(data_loader.max_morph_per_word) + "\n")
                args.bilstm_num_steps = data_loader.max_morph_per_word
            else:
                sys.exit("Wrong unit.")
        elif args.composition == "add-bi-lstm":
            fout.write("Maximum db per word: " + str(data_loader.max_db_per_word) + "\n")
            fout.write("Maximum morph per db: " + str(data_loader.max_morph_per_db) + "\n")
            args.bilstm_num_steps = data_loader.max_db_per_word
        elif args.composition == "addition":
            if args.unit not in ["char-ngram", "morpheme", "oracle"]:
                sys.exit("Wrong composition.")
        else:
            sys.exit("Wrong unit/composition.")
    else:
        if args.composition != "none":
            sys.exit("Wrong composition.")

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    print(args)
    if args.unit == "word":
        lm_model = WordModel
    elif args.composition == "addition":
        lm_model = AdditiveModel
    elif args.composition == "bi-lstm":
        lm_model = BiLSTMModel
    elif args.composition == "add-bi-lstm":
        lm_model = AddBiLSTMModel
    else:
        sys.exit("Unknown unit or composition.")

    print("Begin training...")

    mtrain = lm_model(args)
    if args.use_cuda:
        mtrain = mtrain.cuda()

    nParams = sum([p.nelement() for p in mtrain.parameters()])
    print('* number of parameters: %d' % nParams)

    optim = Optim(
        args.optimization, args.learning_rate, args.grad_clip,
        lr_decay=args.decay_rate,
        patience=args.patience
    )
    # update all parameters
    optim.set_parameters(mtrain.parameters())

    dev_pp = 10000000.0

    if args.cont == 'true':  # continue training from a saved model
        # get model parameters
        model_path, e = get_last_model_path(args.save_dir)
        saved_model = torch.load(model_path)
        mtrain.load_state_dict(saved_model['state_dict'])
        # get optimizer states
        # not saving learning rate (probably too small so it won't continue training)
        optim.last_ppl = saved_model['last_ppl']
    else:
        # process each epoch
        e = 1

    while e <= args.num_epochs:
        print("Epoch: %d" % e)
        print("Learning rate: %.3f" % optim.lr)

        #  (1) train for one epoch on the training set
        train_perplexity = run_epoch(mtrain, train_data, data_loader,optim, eval=False)
        print("Train Perplexity: %.3f" % train_perplexity)

        #  (2) evaluate on the validation set
        dev_perplexity = run_epoch(mtrain, dev_data, data_loader, optim, eval=True)
        print("Valid Perplexity: %.3f" % dev_perplexity)

        #  (3) update the learning rate
        optim.updateLearningRate(dev_perplexity, e)

        # (4) save results and report
        diff = dev_pp - dev_perplexity
        if diff >= 0.1:
            print("Achieve highest perplexity on dev set, save model.")
            checkpoint = {
                'state_dict': mtrain.state_dict(),
                'last_ppl':optim.last_ppl
            }
            torch.save(checkpoint,
                       '%s/%s-%d.pt' % (save_dir, "model", e))
            dev_pp = dev_perplexity

        # write results to file
        fout.write("Epoch: %d\n" % e)
        fout.write("Learning rate: %.3f\n" % optim.lr)
        fout.write("Train Perplexity: %.3f\n" % train_perplexity)
        fout.write("Valid Perplexity: %.3f\n" % dev_perplexity)
        fout.flush()

        if optim.lr < 0.0001:
            print('Learning rate too small, stop training.')
            break

        e += 1

    print("Training time: %.0f" % (time.time() - start))
    fout.write("Training time: %.0f\n" % (time.time() - start))