Example #1
0
            # NOTE(review): this fragment starts mid-statement — presumably the
            # tail of a `with open(<embedding_file>, 'rb') as f:` block in the
            # unseen enclosing function. TODO confirm against the full source.
            # Load the pretrained word-embedding matrix from a pickle file.
            embeddings = np.array(pickle.load(f))
        # Vocabulary size is derived from the embedding matrix's first axis
        # (rows = vocabulary entries).
        FLAGS.vocab_size = embeddings.shape[0]
        # Only hand the pretrained matrix to the model when init_dict is set;
        # otherwise the model presumably initializes embeddings itself.
        if FLAGS.init_dict:
            pretrained_word_embeddings = embeddings
        else:
            pretrained_word_embeddings = None

        with sess.as_default():
            # TFRecord files for the train/validation splits, located under
            # the configured data directory.
            train_record_file = './%s/train.tfrecords' % (FLAGS.data_path)
            valid_record_file = './%s/valid.tfrecords' % (FLAGS.data_path)

            # Shared record parser for both datasets.
            parser = get_record_parser(FLAGS)
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            # NOTE(review): "begain" is a typo for "began" in the log message
            # below; left untouched here because it is a runtime string.
            print("Create training dataset begain... | %s " % time_str)
            # Build the batched datasets. The final boolean argument differs
            # between train (False) and valid (True); its meaning (likely a
            # shuffle/repeat or is-test flag) is defined in get_batch_dataset —
            # TODO confirm.
            train_dataset = get_batch_dataset(train_record_file, parser,
                                              FLAGS.batch_size,
                                              FLAGS.num_threads,
                                              FLAGS.capacity, False)
            valid_dataset = get_batch_dataset(valid_record_file, parser,
                                              FLAGS.batch_size,
                                              FLAGS.num_threads,
                                              FLAGS.capacity, True)
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print("Create training dataset end... | %s " % time_str)

            # TF1 feedable-iterator pattern: a string-handle placeholder lets
            # one graph switch between the train and valid iterators at
            # sess.run time by feeding the corresponding handle.
            handle = tf.placeholder(tf.string, shape=[])
            iterator = tf.data.Iterator.from_string_handle(
                handle, train_dataset.output_types,
                train_dataset.output_shapes)
            # Training iterates indefinitely (one-shot); validation is
            # re-initializable so it can be rewound each evaluation pass.
            train_iterator = train_dataset.make_one_shot_iterator()
            valid_iterator = valid_dataset.make_initializable_iterator()
            train_handle = sess.run(train_iterator.string_handle())
Example #2
0
        # NOTE(review): fragment begins inside an unseen `if` — presumably a
        # GPU-availability check in the enclosing function. Pins execution to
        # GPU 0. TODO confirm against the full source.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # Checkpoint directory.
    out_dir = os.path.abspath(os.path.join(os.path.curdir, FLAGS.log_root))
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints/"))

    # Place the graph on the device named by FLAGS.gpu (e.g. "gpu:0").
    with tf.device("/%s" % FLAGS.gpu):
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            # Build the batched test dataset from its TFRecord file; the final
            # True mirrors the validation path elsewhere (likely a no-shuffle /
            # is-test flag defined in get_batch_dataset — TODO confirm).
            parser = get_record_parser(FLAGS)
            test_dataset = get_batch_dataset(FLAGS.test_record_file, parser,
                                             FLAGS.batch_size,
                                             FLAGS.num_threads, FLAGS.capacity,
                                             True)
            test_iterator = test_dataset.make_initializable_iterator()
            # Initialize the iterator up front so the test set can be consumed
            # once the model is built.
            sess.run(test_iterator.initializer)

            # TF1 feedable-iterator pattern: the string handle fed through the
            # placeholder selects which concrete iterator the graph reads from.
            test_handle = sess.run(test_iterator.string_handle())
            handle = tf.placeholder(tf.string, shape=[])
            iterator = tf.data.Iterator.from_string_handle(
                handle, test_dataset.output_types, test_dataset.output_shapes)

            # NOTE(review): this rebinds the name `model` (presumably a class
            # or factory imported/passed in the unseen surrounding code) to the
            # constructed instance, shadowing the callable. Works once, but a
            # second construction in this scope would fail — consider renaming
            # the instance (e.g. `net = model(...)`) in the full source.
            model = model(iterator, FLAGS, FLAGS.embed_dim, FLAGS.vocab_size,
                          FLAGS.char_embed_dim, FLAGS.char_vocab_size,
                          FLAGS.rnn_dim, FLAGS.max_turn,
                          FLAGS.max_utterance_len, FLAGS.max_word_len)

            # Non-trainable step counter, conventionally advanced by the
            # optimizer's apply/minimize op.
            global_step = tf.Variable(0, name="global_step", trainable=False)