Пример #1
0
  def predict(self):
    """Translate one evaluation sample with the inference model.

    Builds the model in inference mode (``is_train`` False) under the
    "model" name scope, tries to restore weights from
    ``flags_obj.init_weight_path`` via ``_load_weights_if_possible``, then
    runs ``model.predict`` on ``_SINGLE_SAMPLE`` elements of the eval
    pipeline and hands every decoded output to
    ``translate.translate_from_input``.
    """
    params = self.params
    flags_obj = self.flags_obj

    with tf.name_scope("model"):
      model = transformer.create_model(params, False)
      self._load_weights_if_possible(model, flags_obj.init_weight_path)
      model.summary()
    subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)

    # Keep only the source side of each (inputs, targets) pair.
    inputs_only = data_pipeline.eval_input_fn(params).map(lambda x, y: x)
    # model.predict yields a pair here; the second element is ignored.
    val_outputs, _ = model.predict(inputs_only.take(_SINGLE_SAMPLE))
    for val_output in val_outputs:
      translate.translate_from_input(val_output, subtokenizer)
Пример #2
0
    def train(self):
        """Fit the transformer and write its weights and full model to disk.

        Creates the model in training mode, compiles it with the optimizer
        from ``_create_optimizer``, tries to restore weights from
        ``flags_obj.init_weight_path``, then fits on the mapped train
        pipeline while validating on the mapped eval pipeline. Training
        resumes from ``flags_obj.init_epoch`` (0 when unset). When fitting
        finishes, the history is logged and the weights and model are saved
        as HDF5 files inside the log directory.
        """
        params = self.params
        flags_obj = self.flags_obj

        model = transformer.create_model(params, True)
        optimizer = self._create_optimizer()
        model.compile(optimizer, target_tensors=[])
        model.summary()
        self._load_weights_if_possible(model, flags_obj.init_weight_path)

        log_dir = _get_log_dir_or_default(flags_obj)
        _ensure_dir(log_dir)

        transform = data_pipeline.map_data_for_transformer_fn

        def _prepare(dataset):
            # Shared preprocessing for both the train and validation pipelines.
            return dataset.map(
                transform, num_parallel_calls=params["num_parallel_calls"])

        train_ds = _prepare(data_pipeline.train_input_fn(params))
        valid_ds = _prepare(data_pipeline.eval_input_fn(params))

        start_epoch = flags_obj.init_epoch or 0
        # Callbacks receive the global step the run resumes from.
        callbacks = self._create_callbacks(
            log_dir, start_epoch * flags_obj.steps_per_epoch, params)

        history = model.fit(
            train_ds,
            initial_epoch=start_epoch,
            epochs=flags_obj.train_epochs,
            steps_per_epoch=flags_obj.steps_per_epoch,
            validation_data=valid_ds,
            validation_steps=flags_obj.validation_steps,
            callbacks=callbacks,
        )
        tf.compat.v1.logging.info(
            "\nTrain history: {}".format(history.history))

        weights_file = os.path.join(log_dir, "saves-model-weights.hdf5")
        model_file = os.path.join(log_dir, "saves-model.hdf5")
        model.save_weights(weights_file)
        model.save(model_file)
Пример #3
0
    def predict(self):
        """Translate one evaluation sample and log how long inference took.

        Switches ``self.params['train']`` to False, builds the model in
        inference mode under the "model" name scope and tries to restore the
        latest checkpoint found in ``flags_obj.model_dir`` via
        ``_load_weights_if_possible``. It then runs ``model.predict`` on
        ``_SINGLE_SAMPLE`` elements of the eval pipeline and passes each
        decoded output to ``translate.translate_from_input``.

        Side effects:
            Mutates ``self.params`` in place (``params['train'] = False``).
            NOTE(review): this persists on the instance; presumably the data
            pipeline / model construction read this key — confirm.
        """
        import time

        self.params['train'] = False
        params = self.params
        flags_obj = self.flags_obj

        with tf.name_scope("model"):
            model = transformer.create_model(params, is_train=False)
            # Use the local binding consistently (was self.flags_obj here).
            self._load_weights_if_possible(
                model, tf.train.latest_checkpoint(flags_obj.model_dir))
            model.summary()
        subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)
        # Log instead of a stray debug print, matching train().
        tf.compat.v1.logging.info("Predict params: {}".format(params))

        ds = data_pipeline.eval_input_fn(params)
        # Keep only the source side of each (inputs, targets) pair.
        ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)

        # perf_counter is monotonic, so it is safe for measuring durations.
        start = time.perf_counter()
        # model.predict yields a pair here; the second element is ignored.
        val_outputs, _ = model.predict(ds)
        for val_output in val_outputs:
            translate.translate_from_input(val_output, subtokenizer)
        elapsed = time.perf_counter() - start
        tf.compat.v1.logging.info(
            "Prediction + translation time: {:.3f}s".format(elapsed))