def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [transformer.Transformer for _ in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.models[i], model_cls_list[i].get_name(),
                      params_list[i])
        for i in range(len(args.models))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.models):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Ignore variables outside the model scope (e.g. global_step)
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                values[name] = reader.get_tensor(name)

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.models)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)

        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        # Create placeholders, one set per device
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source": tf.placeholder(tf.int32, [None, None],
                                         "source_%d" % i),
                "source_length": tf.placeholder(tf.int32, [None],
                                                "source_length_%d" % i)
            })

        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: beamsearch.create_inference_graph(model_fns, f, params),
            placeholders)

        # Create assign ops
        assign_ops_all = []
        assign_placeholders_all = []
        assign_values_all = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.models)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            assign_placeholders, assign_ops, assign_values = set_variables(
                un_init_var_list, model_var_lists[i], name + "_%d" % i)

            assign_placeholders_all.append(assign_placeholders)
            assign_ops_all.append(assign_ops)
            assign_values_all.append(assign_values)

        results = []

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            for i in range(len(args.models)):
                for p, assign_op, v in zip(assign_placeholders_all[i],
                                           assign_ops_all[i],
                                           assign_values_all[i]):
                    sess.run(assign_op, {p: v})

            sess.run(tf.tables_initializer())

            while True:
                try:
                    feats = sess.run(features)
                    ops, feed_dict = shard_features(feats, placeholders,
                                                    predictions)
                    results.append(sess.run(ops, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        # Undo the length-based sorting of the inputs
        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if isinstance(idx, six.integer_types):
                            symbol = vocab[idx]
                        else:
                            symbol = idx
                        if symbol == params.eos:
                            break
                        decoded.append(symbol)

                    decoded = str.join(" ", decoded)

                    if not args.log:
                        # Only keep the top beam
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "src[%d]: %s \n trans[%.4f]: %s \n"
                        source = restored_inputs[count]
                        values = (count, source, score, decoded)
                        outfile.write(pattern % values)

                count += 1
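
# NOTE: `shard_features` is called in the decoding loop above but is not
# defined in this file. The following is a minimal sketch, assuming it splits
# each batch evenly across the per-device placeholders and returns the
# prediction ops together with a feed dict; the round-up slicing policy is an
# assumption, not the repository's actual implementation.
def shard_features(feats, placeholders, predictions):
    num_shards = len(placeholders)
    batch_size = len(feats["source"])
    # Round up so every example lands in some shard
    shard_size = (batch_size + num_shards - 1) // num_shards
    feed_dict = {}
    for i, placeholder in enumerate(placeholders):
        lo = i * shard_size
        hi = min(lo + shard_size, batch_size)
        for name in placeholder:
            feed_dict[placeholder[name]] = feats[name][lo:hi]
    return predictions, feed_dict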
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = transformer.Transformer
    args.model = model_cls.get_name()
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high): default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        # Build input queue
        features = dataset.get_training_input(params.input, params)

        # Add pre-trained embeddings
        if params.use_pretrained_embedding:
            _, src_embs = dataset.get_pre_embeddings(params.embeddings[0])
            _, trg_embs = dataset.get_pre_embeddings(params.embeddings[1])
            features['src_embs'] = src_embs
            features['trg_embs'] = trg_embs
            print('Loaded Embeddings!', src_embs.shape, trg_embs.shape)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params, args.model)

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features,
            params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Create global step
        global_step = tf.train.get_or_create_global_step()
        initial_global_step = global_step.assign(0)

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size

        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)

        if params.learning_rate_minimum:
            lr_min = float(params.learning_rate_minimum)
            learning_rate = tf.maximum(learning_rate, tf.to_float(lr_min))

        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(loss, opt, global_step, params)
        restore_op = restore_variables(args.output)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                },
                every_n_iter=params.print_steps),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: beamsearch.create_inference_graph(
                        [model.get_inference_func()], f, params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_steps_begin=params.eval_steps_begin,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Bypass hook calls
            return step_context.run_with_hooks(ops)

        # Create session; the default CheckpointSaverHook is disabled because
        # a custom one is registered in train_hooks above
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=params.output, hooks=train_hooks,
                save_checkpoint_secs=None, config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            if params.renew_lr:
                sess.run(initial_global_step)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
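
# NOTE: `get_learning_rate_decay` is referenced above but defined elsewhere.
# A sketch of one plausible implementation, assuming the "noam" schedule from
# the Transformer paper, lr * d_model^-0.5 * min(step^-0.5,
# step * warmup_steps^-1.5); `params.hidden_size` and `params.warmup_steps`
# are assumed parameter names, not confirmed by this file.
def get_learning_rate_decay(learning_rate, global_step, params):
    step = tf.to_float(global_step) + 1.0
    # Linear warmup followed by inverse-square-root decay
    decay = (params.hidden_size ** -0.5) * tf.minimum(
        step ** -0.5, step * (params.warmup_steps ** -1.5))
    return learning_rate * decay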
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)

    vocabulary = make_vocab(args.vocab_file)

    nmt_config = modeling.NmtConfig.from_json_file(args.nmt_config_file)

    tf.logging.info("Checkpoint Vocab Size: %d", nmt_config.vocab_size)
    tf.logging.info("True Vocab Size: %d", len(vocabulary))
    assert nmt_config.vocab_size == len(vocabulary)

    vocabulary[nmt_config.padId] = nmt_config.pad.encode()
    vocabulary[nmt_config.eosId] = nmt_config.eos.encode()
    vocabulary[nmt_config.unkId] = nmt_config.unk.encode()
    vocabulary[nmt_config.bosId] = nmt_config.bos.encode()

    # Build Graph
    with tf.Graph().as_default():
        # Read input file
        sorted_keys, sorted_inputs = sort_input_file(args.source_input_file)

        # Pad the input so it splits evenly into fixed-size batches
        while len(sorted_inputs) % args.decode_batch_size != 0:
            sorted_inputs.append(nmt_config.pad)

        tf.logging.info("Total Sentence Size: %d", len(sorted_keys))

        # Build input queue
        with tf.device('/CPU:0'):
            features = get_inference_input(
                inputs=sorted_inputs,
                vocabulary=vocabulary,
                max_seq_length=args.max_seq_length,
                decode_length=args.decode_length,
                decode_batch_size=args.decode_batch_size,
                eos=nmt_config.eos.encode(),
                unkId=nmt_config.unkId,
                use_tpu=args.use_tpu)

        # Create placeholders (TPU requires fully static shapes)
        if args.use_tpu:
            placeholders = {
                "source": tf.placeholder(
                    tf.int32, [args.decode_batch_size, args.max_seq_length],
                    "source_0"),
                "source_length": tf.placeholder(
                    tf.int32, [args.decode_batch_size], "source_length_0")
            }
        else:
            placeholders = {
                "source": tf.placeholder(
                    tf.int32, [args.decode_batch_size, None], "source_0"),
                "source_length": tf.placeholder(
                    tf.int32, [args.decode_batch_size], "source_length_0")
            }

        model = modeling.NmtModel(config=nmt_config)
        model_fn = model.get_inference_func()

        if args.use_tpu:
            def computation(source, source_length):
                placeholders = {
                    "source": source,
                    "source_length": source_length,
                }
                predictions = beamsearch.create_inference_graph(
                    model_fns=model_fn,
                    features=placeholders,
                    decode_length=args.decode_length,
                    beam_size=args.beam_size,
                    top_beams=args.top_beams,
                    decode_alpha=args.decode_alpha,
                    bosId=nmt_config.bosId,
                    eosId=nmt_config.eosId)
                return predictions[0], predictions[1]

            ops = tf.compat.v1.tpu.batch_parallel(
                computation,
                [placeholders["source"], placeholders["source_length"]],
                num_shards=8)
        else:
            predictions = beamsearch.create_inference_graph(
                model_fns=model_fn,
                features=placeholders,
                decode_length=args.decode_length,
                beam_size=args.beam_size,
                top_beams=args.top_beams,
                decode_alpha=args.decode_alpha,
                bosId=nmt_config.bosId,
                eosId=nmt_config.eosId)
            ops = (predictions[0], predictions[1])

        # Map checkpoint variable names to graph variables
        init_vars = tf.train.list_variables(args.init_checkpoint)
        tvars = tf.trainable_variables()
        name_to_variable = {}
        assignment_map = {}

        for var in tvars:
            name = var.name
            m = re.match("^(.*):\\d+$", name)
            if m is not None:
                name = m.group(1)
            name_to_variable[name] = var

        for name, _ in init_vars:
            if name in name_to_variable:
                assignment_map[name] = name

        tf.train.init_from_checkpoint(args.init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        total_size = 0
        for var in tvars:
            tf.logging.info("  name = %s, shape = %s", var.name, var.shape)
            total_size += reduce(lambda x, y: x * y,
                                 var.get_shape().as_list())
        tf.logging.info("  total variable parameters: %d", total_size)

        results = []
        target = ''
        config = None

        if args.use_tpu:
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                args.tpu_name)
            target = tpu_cluster_resolver.get_master()
        else:
            config = tf.ConfigProto(allow_soft_placement=True)

        with tf.Session(target=target, config=config) as sess:
            if args.use_tpu:
                sess.run(tf.contrib.tpu.initialize_system())

            sess.run(tf.global_variables_initializer())
            sess.run(tf.tables_initializer())

            while True:
                try:
                    feats = sess.run(features)
                    feed_dict = {placeholders[name]: feats[name]
                                 for name in feats}
                    results.append(sess.run(ops, feed_dict=feed_dict))
                    tf.logging.log(tf.logging.INFO,
                                   "Finished batch %d" % len(results))
                except tf.errors.OutOfRangeError:
                    break

            if args.use_tpu:
                sess.run(tf.contrib.tpu.shutdown_system())

        target_dir, _ = os.path.split(args.target_output_file)
        tf.gfile.MakeDirs(target_dir)

        # Convert ids back to plain text
        outputs = []
        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())

        # Undo the length-based sorting of the inputs
        origin_outputs = []
        for index in range(len(sorted_keys)):
            origin_outputs.append(outputs[sorted_keys[index]])

        with tf.gfile.Open(args.target_output_file, "w") as outfile:
            for beam_group in origin_outputs:
                for output in beam_group:
                    decoded = []
                    for idx in output:
                        symbol = vocabulary[idx]
                        if symbol == nmt_config.eos.encode():
                            break
                        decoded.append(symbol)

                    decoded = str.join(" ", decoded)
                    if not args.output_bpe:
                        decoded = decoded.replace("@@ ", "")
                    outfile.write("%s\n" % decoded)
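
# NOTE: `sort_input_file` is used by all of the decoding entry points in this
# file but is not shown here. A minimal sketch, assuming it orders sentences
# by token count so batches are length-homogeneous, and returns, for each
# original position, the index of that sentence in the sorted list (which is
# how `sorted_keys` is consumed above); the exact sort order is an assumption.
def sort_input_file(filename):
    with tf.gfile.Open(filename) as f:
        inputs = [line.strip() for line in f]
    # Permutation that sorts sentences by length
    order = sorted(range(len(inputs)), key=lambda i: len(inputs[i].split()))
    sorted_inputs = [inputs[i] for i in order]
    sorted_keys = [0] * len(inputs)
    for new_pos, old_pos in enumerate(order):
        sorted_keys[old_pos] = new_pos
    return sorted_keys, sorted_inputs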
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    vocabulary = make_vocab(FLAGS.vocab_file)

    nmt_config = modeling.NmtConfig.from_json_file(FLAGS.nmt_config_file,
                                                   vocab_size=len(vocabulary))

    vocabulary[0] = nmt_config.eos.encode()
    vocabulary[1] = nmt_config.unk.encode()
    vocabulary[2] = nmt_config.bos.encode()

    # Build Graph
    with tf.Graph().as_default():
        # Read input file
        sorted_keys, sorted_inputs = sort_input_file(FLAGS.source_input_file)

        # Build input queue
        features = get_inference_input(
            inputs=sorted_inputs,
            vocabulary=vocabulary,
            decode_batch_size=FLAGS.decode_batch_size,
            eos=nmt_config.eos.encode(),
            unkId=nmt_config.unkId)

        # Create placeholders
        placeholders = {
            "source": tf.placeholder(tf.int32, [None, None], "source_0"),
            "source_length": tf.placeholder(tf.int32, [None],
                                            "source_length_0")
        }

        model = modeling.NmtModel(config=nmt_config)
        model_fn = model.get_inference_func()

        predictions = beamsearch.create_inference_graph(
            model_fns=model_fn,
            features=placeholders,
            decode_length=FLAGS.decode_length,
            beam_size=FLAGS.beam_size,
            top_beams=1,
            decode_alpha=FLAGS.decode_alpha,
            bosId=nmt_config.bosId,
            eosId=nmt_config.eosId)

        # Create assign ops
        tvars = tf.trainable_variables()
        initialized_variable_names = {}

        if FLAGS.init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, FLAGS.init_checkpoint)
            tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                          assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        total_size = 0
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
            total_size += reduce(lambda x, y: x * y,
                                 var.get_shape().as_list())
        tf.logging.info("  total variable parameters: %d", total_size)

        results = []

        # Create session
        tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone,
            project=FLAGS.gcp_project).get_master()

        with tf.Session(tpu_cluster) as sess:
            sess.run(tf.contrib.tpu.initialize_system())

            # Restore variables
            sess.run(tf.global_variables_initializer())
            sess.run(tf.tables_initializer())

            ops = (predictions[0], predictions[1])

            while True:
                try:
                    feats = sess.run(features)
                    feed_dict = {placeholders[name]: feats[name]
                                 for name in feats}
                    results.append(sess.run(ops, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

            sess.run(tf.contrib.tpu.shutdown_system())

        # Convert to plain text
        outputs = []
        for result in results:
            for item in result[0]:
                for subitem in item.tolist():
                    outputs.append(subitem)

        # Undo the length-based sorting of the inputs
        restored_outputs = []
        for index in range(len(sorted_inputs)):
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with tf.gfile.Open(FLAGS.target_output_file, "w") as outfile:
            for output in restored_outputs:
                decoded = []
                for idx in output:
                    if isinstance(idx, six.integer_types):
                        symbol = vocabulary[idx]
                    else:
                        symbol = idx
                    if symbol == nmt_config.eos.encode():
                        break
                    decoded.append(symbol)

                decoded = str.join(" ", decoded)
                decoded = decoded.replace("@@ ", "")
                outfile.write("%s\n" % decoded)
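
# NOTE: `make_vocab` is called above but defined elsewhere. A minimal sketch,
# assuming one token per line in the vocabulary file and byte-string entries
# (to match the `.encode()` comparisons and index assignments above); the
# exact file format is an assumption.
def make_vocab(vocab_file):
    vocabulary = []
    with tf.gfile.Open(vocab_file) as f:
        for line in f:
            vocabulary.append(line.strip().encode())
    return vocabulary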