Example #1
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       single_cell_fn=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)

    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default():

        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)

        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            tgt_vocab_table,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder)

        # Note: One can set model_device_fn to `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        with tf.device(model_helper.get_device_str(hparams.base_gpu)):
            # model_creator: constructs the actual model for the given mode
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  single_cell_fn=single_cell_fn)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
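
A minimal usage sketch (not from the source) of how a TrainModel built this way is typically driven in TF1: the lookup table and the iterator need explicit initialization, and the skip-count placeholder is fed when (re)initializing the iterator. The `iterator.initializer` field and the per-step `model.train(sess)` call are assumptions mirroring the TensorFlow NMT reference code.

train_model = create_train_model(model_creator, hparams)
with train_model.graph.as_default():
    sess = tf.Session(graph=train_model.graph)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())  # the vocab lookup table needs explicit init
    sess.run(train_model.iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})
    while True:
        try:
            _ = train_model.model.train(sess)  # hypothetical per-step train call
        except tf.errors.OutOfRangeError:
            break  # one pass over the data finished; re-initialize the iterator to continue
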
Example #2
def create_infer_model(model_creator,
                       hparams,
                       scope=None,
                       single_cell_fn=None):
    """Create inference model."""
    graph = tf.Graph()

    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default():

        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
        # Build the reverse vocab table (index -> token string) for decoding
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)

        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            source_reverse=hparams.source_reverse,
            src_max_len=hparams.src_max_len_infer)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            single_cell_fn=single_cell_fn)

    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      iterator=iterator)
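
A minimal inference sketch (an assumption, not from the source): raw source strings are fed into `src_placeholder` before the infer iterator is initialized. The `iterator.initializer` field and the `model.decode(sess)` call are assumed, following the NMT reference code.

infer_model = create_infer_model(model_creator, hparams)
with infer_model.graph.as_default():
    sess = tf.Session(graph=infer_model.graph)
    # a trained checkpoint would normally be restored into this graph here
    sess.run(tf.tables_initializer())
    src_data = ["..."]  # hypothetical raw source lines (sentences or file paths)
    sess.run(infer_model.iterator.initializer,
             feed_dict={infer_model.src_placeholder: src_data})
    # translations = infer_model.model.decode(sess)  # hypothetical decode call
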
Example #3
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()

    with graph.as_default():

        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)

        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            tgt_vocab_table,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)

        model = model_creator(hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              target_vocab_table=tgt_vocab_table,
                              scope=scope,
                              single_cell_fn=single_cell_fn)

    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     iterator=iterator)
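
A minimal evaluation sketch (assumption): the dev-set file paths are fed into the two string placeholders before initializing the eval iterator. `hparams.dev_prefix` is assumed by analogy with `hparams.train_prefix` in Example #1.

eval_model = create_eval_model(model_creator, hparams)
with eval_model.graph.as_default():
    sess = tf.Session(graph=eval_model.graph)
    sess.run(tf.tables_initializer())
    sess.run(eval_model.iterator.initializer,
             feed_dict={
                 eval_model.src_file_placeholder: "%s.%s" % (hparams.dev_prefix, hparams.src),
                 eval_model.tgt_file_placeholder: "%s.%s" % (hparams.dev_prefix, hparams.tgt),
             })
    # eval_model.model.eval(sess) would then return the loss on one dev batch (hypothetical call)
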
Example #4
    # apply the mask: push masked positions toward -inf so they get ~0 weight after softmax
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    # softmax over the last axis (kv_len) so the weights for each query position sum to 1
    attention_weights = tf.nn.softmax(scaled_attention_logits,
                                      axis=-1)  # (..., q_len, kv_len)
    output = tf.matmul(attention_weights,
                       v)  # (..., q_len, kv_len) x (..., kv_len, d_model) -> (..., q_len, d_model)
    return output, attention_weights, scaled_attention_logits


# print(os.getcwd())
base_path = "/home/panxie/Documents/sign-language/nslt/Data"
src_file = base_path + "/phoenix2014T.dev.sign"
tgt_file = base_path + "/phoenix2014T.dev.de"
tgt_vocab_table = create_tgt_vocab_table(base_path + "/phoenix2014T.vocab.de")
train_dataset = dataset.get_train_dataset(src_file, tgt_file, tgt_vocab_table)
cnt = 0
for data in train_dataset.take(1):
    cnt += 1
    src_inputs, tgt_in, tgt_out, src_path, src_len, tgt_len = data
    bs, t, h, w, c = src_inputs.shape
    print(src_inputs.shape, src_path)
    src_inputs = tf.reshape(src_inputs, (bs * t, h, w, c))
    cnn_output = resnet_model(src_inputs, training=False)
    cnn_output = tf.reshape(cnn_output, (bs, t, -1))
    attention_out, atten_weights, atten_logits = scaled_dot_product_attention(
        cnn_output, cnn_output, cnn_output, mask=None)
    for i in range(100):
        # print(atten_logits[0, i, :])
        print(tf.nn.top_k(atten_logits[0, i, :], k=10).indices)
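
Only the tail of `scaled_dot_product_attention` appears above. For context, a minimal self-contained sketch of the whole function in the standard Transformer formulation (an assumption, not copied from the source) could look like this:

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: (..., q_len, d_k), k: (..., kv_len, d_k), v: (..., kv_len, d_v)
    matmul_qk = tf.matmul(q, k, transpose_b=True)            # (..., q_len, kv_len)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)   # scale by sqrt(d_k)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)             # masked positions get ~0 weight
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)                 # (..., q_len, d_v)
    return output, attention_weights, scaled_attention_logits
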
Example #5
config = FLAGS
FLAGS.output_dir = "./output_dir/checkpoints_alexnet_ctc"
FLAGS.best_output = "./output_dir/checkpoints_alexnet_ctc/best_bleu"

for arg in vars(FLAGS):
    logger.info("{}, {}".format(arg, getattr(FLAGS, arg)))


tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(config.tgt_vocab_file,
                                                         "./",
                                                         sos="<s>",
                                                         eos="</s>",
                                                         unk=vocab_utils.UNK)

tgt_vocab_table = vocab_utils.create_tgt_vocab_table(config.tgt_vocab_file)
word2idx, idx2word = vocab_utils.create_tgt_dict(tgt_vocab_file)

# model = Model(rnn_units=config.rnn_units, tgt_vocab_size=tgt_vocab_size, tgt_emb_size=config.tgt_emb_size)
model = CTCModel(input_shape=config.input_shape, tgt_vocab_size=tgt_vocab_size, dropout=config.dropout,
                 rnn_units=FLAGS.rnn_units)


lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    config.learning_rate,
    decay_steps=config.decay_steps,
    decay_rate=0.96,
    staircase=True)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
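
The schedule above is a callable that maps the training step to a learning rate; with staircase=True it multiplies the rate by 0.96 once every decay_steps steps. A tiny probe of that behavior, using hypothetical values in place of config.learning_rate and config.decay_steps:

probe_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    1e-3, decay_steps=1000, decay_rate=0.96, staircase=True)  # hypothetical values
for step in (0, 999, 1000, 5000):
    # staircase=True -> lr = 1e-3 * 0.96 ** (step // 1000)
    print(step, float(probe_schedule(step)))
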
Example #6
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

config = FLAGS
for arg in vars(config):
    logger.info("{}, {}".format(arg, getattr(config, arg)))

tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(config.tgt_vocab_file,
                                                         out_dir="./",
                                                         sos="<s>",
                                                         eos="</s>",
                                                         unk=vocab_utils.UNK)

tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
word2idx, idx2word = vocab_utils.create_tgt_dict(tgt_vocab_file)
print(word2idx["<blank>"])
print(len(word2idx))

# model = Model(rnn_units=config.rnn_units, tgt_vocab_size=tgt_vocab_size, tgt_emb_size=config.tgt_emb_size)
if config.cnn_architecture == "resnet":
    cnn_model_path = config.resnet_weight_path
else:
    cnn_model_path = config.alexnet_weight_path
model = SFNet(input_shape=config.input_shape,
              tgt_vocab_size=tgt_vocab_size,
              rnn_units=config.rnn_units,
              cnn_arch=config.cnn_architecture,
              cnn_model_path=cnn_model_path,
              dropout=config.dropout)
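
Assuming create_tgt_dict returns plain Python dicts mapping token -> index and index -> token (as the prints in this example suggest), a hypothetical helper for turning predicted index sequences back into text might look like this:

def indices_to_sentence(indices, idx2word, blank="<blank>", eos="</s>"):
    # drop CTC blanks and stop at the end-of-sentence token
    words = []
    for idx in indices:
        token = idx2word.get(int(idx), vocab_utils.UNK)
        if token == eos:
            break
        if token != blank:
            words.append(token)
    return " ".join(words)
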