Example 1
def computation(source, source_length):
    placeholders = {
        "source": source,
        "source_length": source_length,
    }
    predictions = beamsearch.create_inference_graph(model_fns=model_fn,
                                                    features=placeholders,
                                                    decode_length=args.decode_length,
                                                    beam_size=args.beam_size,
                                                    top_beams=args.top_beams,
                                                    decode_alpha=args.decode_alpha,
                                                    bosId=nmt_config.bosId,
                                                    eosId=nmt_config.eosId)
    return predictions[0], predictions[1]


def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
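    # Parameter priority (low -> high): defaults -> params saved with each model -> command-line overrides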
    model_cls_list = [transformer.Transformer for _ in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.models[i], model_cls_list[i].get_name(),
                      params_list[i]) for i in range(len(args.models))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.models):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(
                        model_cls_list[i].get_name()):  #ignore global_step
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)
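        # model_var_lists[i] now maps variable names to numpy values from the i-th checkpoint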

        # Build models
        model_fns = []

        for i in range(len(args.models)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: beamsearch.create_inference_graph(model_fns, f, params),
            placeholders)
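        # One replica of the inference graph is built per entry in params.device_list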

        # Create assign ops
        assign_ops_all = []
        assign_placeholders_all = []
        assign_values_all = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.models)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            assign_placeholders, assign_ops, assign_values = set_variables(
                un_init_var_list, model_var_lists[i], name + "_%d" % i)

            assign_placeholders_all.append(assign_placeholders)
            assign_ops_all.append(assign_ops)
            assign_values_all.append(assign_values)

        #assign_op = tf.group(*assign_ops)
        results = []

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
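            # Checkpoint values are fed through placeholders into the per-variable assign ops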
            for i in range(len(args.models)):
                for p, assign_op, v in zip(assign_placeholders_all[i],
                                           assign_ops_all[i],
                                           assign_values_all[i]):
                    sess.run(assign_op, {p: v})
            sess.run(tf.tables_initializer())

            while True:
                try:
                    feats = sess.run(features)
                    ops, feed_dict = shard_features(feats, placeholders,
                                                    predictions)
                    results.append(sess.run(ops, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))
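        # outputs/scores are now flat lists with one entry per (sorted) input sentence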

        restored_inputs = []
        restored_outputs = []
        restored_scores = []
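        # Map the decoded results back to the original input order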

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if isinstance(idx, six.integer_types):
                            symbol = vocab[idx]
                        else:
                            symbol = idx

                        if symbol == params.eos:
                            break
                        decoded.append(symbol)

                    decoded = str.join(" ", decoded)

                    if not args.log:
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "src[%d]: %s \n trans[%.4f]: %s \n"
                        source = restored_inputs[count]
                        values = (count, source, score, decoded)
                        outfile.write(pattern % values)

                count += 1
Example 3
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = transformer.Transformer
    args.model = model_cls.get_name()
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    #tf.set_random_seed(params.seed)

    # Build Graph
    with tf.Graph().as_default():
        # Build input queue
        features = dataset.get_training_input(params.input, params)

        # features, init_op = cache.cache_features(features, params.update_cycle)
        # Add pre_trained_embedding:
        if params.use_pretrained_embedding:
            _, src_embs = dataset.get_pre_embeddings(params.embeddings[0])
            _, trg_embs = dataset.get_pre_embeddings(params.embeddings[1])
            features['src_embs'] = src_embs
            features['trg_embs'] = trg_embs
            print('Loaded Embeddings!', src_embs.shape, trg_embs.shape)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params, args.model)

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)
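        # Average the per-device losses from the data-parallel replicas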

        # Create global step
        global_step = tf.train.get_or_create_global_step()
        initial_global_step = global_step.assign(0)
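        # Run later to reset the step counter when params.renew_lr is set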

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        if params.learning_rate_minimum:
            lr_min = float(params.learning_rate_minimum)
            learning_rate = tf.maximum(learning_rate, tf.to_float(lr_min))

        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(loss, opt, global_step, params)

        restore_op = restore_variables(args.output)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
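        # Register the saver so the monitored session and hooks reuse it instead of creating a default one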

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            #tf.train.StopAtStepHook(num_steps=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook({
                "step": global_step,
                "loss": loss,
            },
                                       every_n_iter=params.print_steps),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: beamsearch.create_inference_graph(
                        [model.get_inference_func()], f, params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_steps_begin=params.eval_steps_begin,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Bypass hook calls
            return step_context.run_with_hooks(ops)

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            #sess.run(features['source'].eval())
            #sess.run(features['target'].eval())
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)
            if params.renew_lr:
                sess.run(initial_global_step)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
Example 4
def main(args):
  tf.logging.set_verbosity(tf.logging.INFO)

  vocabulary = make_vocab(args.vocab_file)

  nmt_config = modeling.NmtConfig.from_json_file(args.nmt_config_file)

  tf.logging.info("Checkpoint Vocab Size: %d", nmt_config.vocab_size)
  tf.logging.info("True Vocab Size: %d", len(vocabulary))

  assert nmt_config.vocab_size == len(vocabulary)
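  # The special tokens (pad/eos/unk/bos) below are placed at the ids defined in the config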

  vocabulary[nmt_config.padId] = nmt_config.pad.encode()
  vocabulary[nmt_config.eosId] = nmt_config.eos.encode()
  vocabulary[nmt_config.unkId] = nmt_config.unk.encode()
  vocabulary[nmt_config.bosId] = nmt_config.bos.encode()

  # Build Graph
  with tf.Graph().as_default():
    # Read input file
    sorted_keys, sorted_inputs = sort_input_file(args.source_input_file)
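    # Pad the input list so its length is a multiple of the decode batch size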
    while len(sorted_inputs) % args.decode_batch_size != 0:
      sorted_inputs.append(nmt_config.pad)

    tf.logging.info("Total Sentence Size: %d", len(sorted_keys))

    # Build input queue
    with tf.device('/CPU:0'):
      features = get_inference_input(inputs=sorted_inputs,
                                     vocabulary=vocabulary,
                                     max_seq_length=args.max_seq_length,
                                     decode_length=args.decode_length,
                                     decode_batch_size=args.decode_batch_size,
                                     eos=nmt_config.eos.encode(),
                                     unkId=nmt_config.unkId,
                                     use_tpu=args.use_tpu)

    # Create placeholders
    if args.use_tpu:
      placeholders = {
          "source": tf.placeholder(tf.int32, [args.decode_batch_size, args.max_seq_length], "source_0"),
          "source_length": tf.placeholder(tf.int32, [args.decode_batch_size], "source_length_0")
      }
    else:
      placeholders = {
          "source": tf.placeholder(tf.int32, [args.decode_batch_size, None], "source_0"),
          "source_length": tf.placeholder(tf.int32, [args.decode_batch_size], "source_length_0")
      }

    model = modeling.NmtModel(config=nmt_config)

    model_fn = model.get_inference_func()

    if args.use_tpu:
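      # On TPU, wrap the inference graph in `computation` and shard each batch across 8 cores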
      def computation(source, source_length):
        placeholders = {
            "source": source,
            "source_length": source_length,
        }
        predictions = beamsearch.create_inference_graph(model_fns=model_fn,
                                                        features=placeholders,
                                                        decode_length=args.decode_length,
                                                        beam_size=args.beam_size,
                                                        top_beams=args.top_beams,
                                                        decode_alpha=args.decode_alpha,
                                                        bosId=nmt_config.bosId,
                                                        eosId=nmt_config.eosId)
        return predictions[0], predictions[1]

      ops = tf.compat.v1.tpu.batch_parallel(computation,
                                            [placeholders["source"],
                                             placeholders["source_length"]],
                                            num_shards=8)
    else:
      predictions = beamsearch.create_inference_graph(model_fns=model_fn,
                                                      features=placeholders,
                                                      decode_length=args.decode_length,
                                                      beam_size=args.beam_size,
                                                      top_beams=args.top_beams,
                                                      decode_alpha=args.decode_alpha,
                                                      bosId=nmt_config.bosId,
                                                      eosId=nmt_config.eosId)
      ops = (predictions[0], predictions[1])

    init_vars = tf.train.list_variables(args.init_checkpoint)
    tvars = tf.trainable_variables()
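    # Build an assignment map so init_from_checkpoint restores only variables present in both the graph and the checkpoint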

    name_to_variable = {}
    assignment_map = {}

    for var in tvars:
      name = var.name
      m = re.match("^(.*):\\d+$", name)
      if m is not None:
        name = m.group(1)
      name_to_variable[name] = var

    for x in init_vars:
      (name, var) = (x[0], x[1])
      if name in name_to_variable:
        assignment_map[name] = name
    tf.train.init_from_checkpoint(args.init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    total_size = 0
    for var in tvars:
      tf.logging.info("  name = %s, shape = %s", var.name, var.shape)
      total_size += reduce(lambda x, y: x * y, var.get_shape().as_list())
    tf.logging.info("  total variable parameters: %d", total_size)

    results = []

    target = ''
    config = None
    if args.use_tpu:
      tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(args.tpu_name)
      target = tpu_cluster_resolver.get_master()
    else:
      config = tf.ConfigProto(allow_soft_placement=True)

    with tf.Session(target=target, config=config) as sess:
      if args.use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      sess.run(tf.global_variables_initializer())
      sess.run(tf.tables_initializer())
      while True:
        try:
          feats = sess.run(features)
          feed_dict = {}
          for name in feats:
            feed_dict[placeholders[name]] = feats[name]
          results.append(sess.run(ops, feed_dict=feed_dict))
          tf.logging.log(tf.logging.INFO, "Finished batch %d" % len(results))
        except tf.errors.OutOfRangeError:
          break

      if args.use_tpu:
        sess.run(tf.contrib.tpu.shutdown_system())

    target_dir, _ = os.path.split(args.target_output_file)
    tf.gfile.MakeDirs(target_dir)

    outputs = []

    for result in results:
      for item in result[0]:
        outputs.append(item.tolist())

    origin_outputs = []
    for index in range(len(sorted_keys)):
      origin_outputs.append(outputs[sorted_keys[index]])

    with tf.gfile.Open(args.target_output_file, "w") as outfile:
      for beam_group in origin_outputs:
        for output in beam_group:
          decoded = []
          for idx in output:
            symbol = vocabulary[idx]
            if symbol == nmt_config.eos.encode():
              break
            decoded.append(symbol)

          decoded = str.join(" ", decoded)

          if not args.output_bpe:
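            # "@@ " marks BPE subword boundaries; stripping it merges subwords back into words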
            decoded = decoded.replace("@@ ", "")

          outfile.write("%s\n" % decoded)


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    vocabulary = make_vocab(FLAGS.vocab_file)

    nmt_config = modeling.NmtConfig.from_json_file(FLAGS.nmt_config_file,
                                                   vocab_size=len(vocabulary))

    vocabulary[0] = nmt_config.eos.encode()
    vocabulary[1] = nmt_config.unk.encode()
    vocabulary[2] = nmt_config.bos.encode()

    # Build Graph
    with tf.Graph().as_default():
        # Read input file
        sorted_keys, sorted_inputs = sort_input_file(FLAGS.source_input_file)
        # Build input queue
        features = get_inference_input(
            inputs=sorted_inputs,
            vocabulary=vocabulary,
            decode_batch_size=FLAGS.decode_batch_size,
            eos=nmt_config.eos.encode(),
            unkId=nmt_config.unkId)

        # Create placeholders
        placeholders = {
            "source": tf.placeholder(tf.int32, [None, None], "source_0"),
            "source_length": tf.placeholder(tf.int32, [None],
                                            "source_length_0")
        }

        model = modeling.NmtModel(config=nmt_config)

        model_fn = model.get_inference_func()

        predictions = beamsearch.create_inference_graph(
            model_fns=model_fn,
            features=placeholders,
            decode_length=FLAGS.decode_length,
            beam_size=FLAGS.beam_size,
            top_beams=1,
            decode_alpha=FLAGS.decode_alpha,
            bosId=nmt_config.bosId,
            eosId=nmt_config.eosId)

        # Create assign ops
        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if FLAGS.init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, FLAGS.init_checkpoint)
            tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                          assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        total_size = 0
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
            total_size += reduce(lambda x, y: x * y, var.get_shape().as_list())
        tf.logging.info("  total variable parameters: %d", total_size)

        results = []

        # Create session
        tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone,
            project=FLAGS.gcp_project).get_master()

        with tf.Session(tpu_cluster) as sess:
            sess.run(tf.contrib.tpu.initialize_system())
            # Restore variables
            sess.run(tf.global_variables_initializer())
            sess.run(tf.tables_initializer())

            while True:
                try:
                    feats = sess.run(features)
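                    # predictions[0] holds the decoded token ids, predictions[1] the corresponding scores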
                    ops = (predictions[0], predictions[1])
                    feed_dict = {}
                    for name in feats:
                        feed_dict[placeholders[name]] = feats[name]
                    results.append(sess.run(ops, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break
            sess.run(tf.contrib.tpu.shutdown_system())

        # Convert to plain text
        outputs = []

        for result in results:
            for item in result[0]:
                for subitem in item.tolist():
                    outputs.append(subitem)

        restored_outputs = []
        for index in range(len(sorted_inputs)):
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with tf.gfile.Open(FLAGS.target_output_file, "w") as outfile:
            for output in restored_outputs:
                decoded = []
                for idx in output:
                    if isinstance(idx, six.integer_types):
                        symbol = vocabulary[idx]
                    else:
                        symbol = idx

                    if symbol == nmt_config.eos.encode():
                        break
                    decoded.append(symbol)

                decoded = str.join(" ", decoded)

                decoded = decoded.replace("@@ ", "")

                outfile.write("%s\n" % decoded)