def predict(self):
    """Run inference over the eval dataset and print the translations.

    Builds the transformer in inference mode, loads initial weights from
    ``flags_obj.init_weight_path`` (if any), then decodes a limited number
    of eval samples and emits each translation via the project's
    ``translate`` helper.
    """
    params = self.params
    flags_obj = self.flags_obj

    with tf.name_scope("model"):
        # Build in inference mode (is_train=False).
        model = transformer.create_model(params, False)
        self._load_weights_if_possible(model, flags_obj.init_weight_path)
        model.summary()
    subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)

    # Drop the targets (keep features only) and cap the number of samples.
    ds = data_pipeline.eval_input_fn(params)
    ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)

    val_outputs, _ = model.predict(ds)
    for output in val_outputs:
        translate.translate_from_input(output, subtokenizer)
def train(self):
    """Train the transformer and persist both weights and the full model.

    Compiles the model, optionally restores initial weights, fits over the
    training dataset with validation, logs the fit history, and saves
    ``saves-model-weights.hdf5`` / ``saves-model.hdf5`` under the log dir.
    """
    params = self.params
    flags_obj = self.flags_obj

    # Build in training mode (is_train=True).
    model = transformer.create_model(params, True)
    optimizer = self._create_optimizer()
    model.compile(optimizer, target_tensors=[])
    model.summary()
    self._load_weights_if_possible(model, flags_obj.init_weight_path)

    current_log_dir = _get_log_dir_or_default(flags_obj)
    _ensure_dir(current_log_dir)

    # Both splits go through the same feature/label mapping.
    map_fn = data_pipeline.map_data_for_transformer_fn
    train_dataset = data_pipeline.train_input_fn(params).map(
        map_fn, num_parallel_calls=params["num_parallel_calls"])
    valid_dataset = data_pipeline.eval_input_fn(params).map(
        map_fn, num_parallel_calls=params["num_parallel_calls"])

    # Resume support: an explicit init_epoch shifts the global step counter.
    init_epoch = flags_obj.init_epoch or 0
    init_steps = init_epoch * flags_obj.steps_per_epoch
    callbacks = self._create_callbacks(current_log_dir, init_steps, params)

    history = model.fit(
        train_dataset,
        initial_epoch=init_epoch,
        epochs=flags_obj.train_epochs,
        steps_per_epoch=flags_obj.steps_per_epoch,
        validation_data=valid_dataset,
        validation_steps=flags_obj.validation_steps,
        callbacks=callbacks)
    tf.compat.v1.logging.info("\nTrain history: {}".format(history.history))

    model.save_weights(
        os.path.join(current_log_dir, "saves-model-weights.hdf5"))
    model.save(os.path.join(current_log_dir, "saves-model.hdf5"))
def predict(self):
    """Restore the latest checkpoint and print translations for eval data.

    Side effect: sets ``self.params['train'] = False`` so downstream
    builders run in inference mode — callers should not expect ``params``
    to be unchanged afterwards.

    Fixes over the previous version: the leftover debug statements
    (``print(params)``, inline ``import time`` mid-function, and the bare
    ``print('\\n\\n\\n', elapsed)``) are replaced with structured
    ``tf.compat.v1.logging`` calls, matching the logging style used by
    ``train``.
    """
    import time  # local import: only needed for the timing log below

    self.params['train'] = False
    params = self.params
    flags_obj = self.flags_obj

    with tf.name_scope("model"):
        model = transformer.create_model(params, is_train=False)
        # Restore from the newest checkpoint in model_dir; may be None if
        # no checkpoint exists — presumably _load_weights_if_possible
        # no-ops in that case (TODO confirm).
        self._load_weights_if_possible(
            model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
        model.summary()
    subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)

    tf.compat.v1.logging.info("Prediction params: {}".format(params))

    # Drop the targets (keep features only) and cap the number of samples.
    ds = data_pipeline.eval_input_fn(params)
    ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)

    start = time.time()
    val_outputs, _ = model.predict(ds)
    for output in val_outputs:
        translate.translate_from_input(output, subtokenizer)
    tf.compat.v1.logging.info(
        "Prediction took {:.3f} seconds".format(time.time() - start))