def create_train_model(model_creator, hparams, scope=None, single_cell_fn=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()
    with graph.as_default():
        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)

        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            tgt_vocab_table,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder)

        # Note: one can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        with tf.device(model_helper.get_device_str(hparams.base_gpu)):
            # model_creator: constructor that builds the model object
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  single_cell_fn=single_cell_fn)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
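# A minimal usage sketch (hypothetical driver, not part of the original code):
# each model owns its own tf.Graph, so it needs a dedicated tf.Session, and the
# vocab lookup tables plus the iterator must be initialized before training.
# The `iterator.initializer` field is assumed from iterator_utils.get_iterator.
train_model = create_train_model(model_creator, hparams)
train_sess = tf.Session(graph=train_model.graph)
with train_model.graph.as_default():
    train_sess.run(tf.global_variables_initializer())
    train_sess.run(tf.tables_initializer())
train_sess.run(train_model.iterator.initializer,
               feed_dict={train_model.skip_count_placeholder: 0})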
def create_infer_model(model_creator, hparams, scope=None, single_cell_fn=None):
    """Create inference model."""
    graph = tf.Graph()
    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default():
        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
        # Build the reverse lookup table (index -> string) for decoding.
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            source_reverse=hparams.source_reverse,
            src_max_len=hparams.src_max_len_infer)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            single_cell_fn=single_cell_fn)

    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      iterator=iterator)
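# Hypothetical usage sketch: inference feeds raw source lines through
# src_placeholder, so the iterator is re-initialized for each batch of inputs.
# Field names follow the InferModel namedtuple above; weight restoration
# (e.g. via tf.train.Saver) is elided.
infer_model = create_infer_model(model_creator, hparams)
infer_sess = tf.Session(graph=infer_model.graph)
with infer_model.graph.as_default():
    infer_sess.run(tf.tables_initializer())
infer_sess.run(infer_model.iterator.initializer,
               feed_dict={infer_model.src_placeholder: ["a source sentence"]})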
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
    """Create eval graph, model, src/tgt file placeholders, and iterator."""
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()

    with graph.as_default():
        tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            tgt_vocab_table,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)

        model = model_creator(hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              target_vocab_table=tgt_vocab_table,
                              scope=scope,
                              single_cell_fn=single_cell_fn)

    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     iterator=iterator)
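# Hypothetical usage sketch: evaluation feeds dev file paths through the two
# scalar string placeholders, so one eval graph can score any src/tgt pair
# (the file paths below are placeholders, not real data files).
eval_model = create_eval_model(model_creator, hparams)
eval_sess = tf.Session(graph=eval_model.graph)
with eval_model.graph.as_default():
    eval_sess.run(tf.tables_initializer())
eval_sess.run(eval_model.iterator.initializer,
              feed_dict={eval_model.src_file_placeholder: "dev.src",
                         eval_model.tgt_file_placeholder: "dev.tgt"})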
def scaled_dot_product_attention(q, k, v, mask=None):
    # Standard Transformer attention: softmax(QK^T / sqrt(d_k)) V.
    matmul_qk = tf.matmul(q, k, transpose_b=True)            # (..., q_len, kv_len)
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(d_k)

    # Add the mask to the scaled tensor: masked positions get a large negative
    # value so their softmax weights go to ~0.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # Softmax is normalized on the last axis (kv_len) so the scores sum to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., q_len, kv_len)
    # (..., q_len, kv_len) x (..., kv_len, d_v) -> (..., q_len, d_v)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights, scaled_attention_logits


# Quick sanity check: run a batch of dev videos through the CNN, apply
# self-attention over the frame features, and inspect which positions each
# query attends to most strongly.
base_path = "/home/panxie/Documents/sign-language/nslt/Data"
src_file = base_path + "/phoenix2014T.dev.sign"
tgt_file = base_path + "/phoenix2014T.dev.de"
tgt_vocab_table = create_tgt_vocab_table(base_path + "/phoenix2014T.vocab.de")
train_dataset = dataset.get_train_dataset(src_file, tgt_file, tgt_vocab_table)

for data in train_dataset.take(1):
    src_inputs, tgt_in, tgt_out, src_path, src_len, tgt_len = data
    bs, t, h, w, c = src_inputs.shape
    print(src_inputs.shape, src_path)
    src_inputs = tf.reshape(src_inputs, (bs * t, h, w, c))
    cnn_output = resnet_model(src_inputs, training=False)
    cnn_output = tf.reshape(cnn_output, (bs, t, -1))
    attention_out, atten_weights, atten_logits = scaled_dot_product_attention(
        cnn_output, cnn_output, cnn_output, mask=None)
    for i in range(100):
        print(tf.nn.top_k(atten_logits[0, i, :], k=10).indices)
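# A hypothetical helper (not in the original code) showing the mask convention
# the attention above assumes: 1.0 marks padded positions to hide, 0.0 marks
# valid ones, so `mask * -1e9` pushes padded logits toward -inf before softmax.
def create_padding_mask(seq_len, max_len):
    valid = tf.sequence_mask(seq_len, maxlen=max_len, dtype=tf.float32)  # (bs, max_len)
    mask = 1.0 - valid                 # 1.0 where padded, 0.0 where valid
    return mask[:, tf.newaxis, :]      # broadcast over the query axis: (bs, 1, kv_len)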
config = FLAGS
config.output_dir = "./output_dir/checkpoints_alexnet_ctc"
config.best_output = "./output_dir/checkpoints_alexnet_ctc/best_bleu"
for arg in vars(config):
    logger.info("{}, {}".format(arg, getattr(config, arg)))

tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(config.tgt_vocab_file,
                                                         "./",
                                                         sos="<s>",
                                                         eos="</s>",
                                                         unk=vocab_utils.UNK)
# Use the checked vocab file returned by check_vocab (it may have been rewritten
# with <s>/</s>/unk entries added).
tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
word2idx, idx2word = vocab_utils.create_tgt_dict(tgt_vocab_file)

# model = Model(rnn_units=config.rnn_units, tgt_vocab_size=tgt_vocab_size, tgt_emb_size=config.tgt_emb_size)
model = CTCModel(input_shape=config.input_shape,
                 tgt_vocab_size=tgt_vocab_size,
                 dropout=config.dropout,
                 rnn_units=config.rnn_units)

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    config.learning_rate,
    decay_steps=config.decay_steps,
    decay_rate=0.96,
    staircase=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
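# A minimal training-step sketch. Hypothetical: CTCModel's call signature and
# the dense-label layout are assumptions; only word2idx["<blank>"] is known to
# exist from the surrounding code.
@tf.function
def train_step(inputs, labels, input_len, label_len):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)  # assumed shape (bs, time, vocab)
        loss = tf.reduce_mean(tf.nn.ctc_loss(
            labels=labels,
            logits=logits,
            label_length=label_len,
            logit_length=input_len,
            logits_time_major=False,
            blank_index=word2idx["<blank>"]))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss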
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

config = FLAGS
for arg in vars(config):
    logger.info("{}, {}".format(arg, getattr(config, arg)))

tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(config.tgt_vocab_file,
                                                         out_dir="./",
                                                         sos="<s>",
                                                         eos="</s>",
                                                         unk=vocab_utils.UNK)
tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
word2idx, idx2word = vocab_utils.create_tgt_dict(tgt_vocab_file)
print(word2idx["<blank>"])
print(len(word2idx))

if config.cnn_architecture == "resnet":
    cnn_model_path = config.resnet_weight_path
else:
    cnn_model_path = config.alexnet_weight_path

model = SFNet(input_shape=config.input_shape,
              tgt_vocab_size=tgt_vocab_size,
              rnn_units=config.rnn_units,
              cnn_arch=config.cnn_architecture,
              cnn_model_path=cnn_model_path,
              dropout=config.dropout)
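# Hypothetical checkpointing sketch (the optimizer here and max_to_keep are
# assumptions; config.output_dir is taken from the flags set earlier). Tracks
# model and optimizer state so training can resume from the latest checkpoint.
optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, config.output_dir, max_to_keep=5)
if manager.latest_checkpoint:
    ckpt.restore(manager.latest_checkpoint)
    logger.info("Restored from {}".format(manager.latest_checkpoint))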