def __init__(self, hparams, vocab_size):
  super(SymbolBottomSimple, self).__init__()
  self._hparams = hparams_lib.copy_hparams(hparams)
  hidden_dim = self._hparams.hidden_size
  var_name = "embedding_weights"
  self._embedding_space = tf.get_variable(
      var_name, [vocab_size, hidden_dim],
      initializer=tf.random_normal_initializer(0.0, hidden_dim**-0.5))
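# Usage sketch (not in the original class; the method name `bottom` and the
# sqrt(hidden_size) rescaling are assumptions, following the usual
# tensor2tensor convention for symbol modalities):
#
#   def bottom(self, ids):
#     # Look up [batch, length] int ids in the [vocab_size, hidden_dim] table.
#     emb = tf.gather(self._embedding_space, ids)
#     return emb * self._hparams.hidden_size**0.5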
def __init__(self, mode, hparams):
  self._model_hparams = hparams_lib.copy_hparams(hparams)
  self._input_fn = None
  # _input_fn_init is expected to populate self._input_fn for the given mode.
  self._input_fn_init(mode, hparams)
  self._hparams = None
  self._encoders = None
  self._iterator = self._input_fn.make_one_shot_iterator()
def set_mode(self, mode):
  """Set hparams with the given mode."""
  tf.logging.info("Setting BaseModel mode to '%s'", mode)
  hparams = hparams_lib.copy_hparams(self._original_hparams)
  hparams.add_hparam("mode", mode)
  # When not in training mode, set all forms of dropout to zero.
  if mode != tf.estimator.ModeKeys.TRAIN:
    for key in hparams.values():
      if key.endswith("dropout") or key == "label_smoothing":
        tf.logging.info("Setting hparams.%s to 0.0", key)
        setattr(hparams, key, 0.0)
  self._hparams = hparams
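# Effect sketch (hedged; `attention_dropout` is an illustrative hparam name,
# but any key ending in "dropout" is zeroed the same way): switching to EVAL
# copies the original hparams, records the mode, and zeroes dropout and label
# smoothing, so evaluation runs deterministically:
#
#   model.set_mode(tf.estimator.ModeKeys.EVAL)
#   assert model._hparams.mode == tf.estimator.ModeKeys.EVAL
#   assert model._hparams.attention_dropout == 0.0
#   assert model._hparams.label_smoothing == 0.0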
def __init__(self,
             hparams,
             mode=tf.estimator.ModeKeys.TRAIN,
             problem_hparams=None,
             **kwargs):
  super(PretrainModel, self).__init__(**kwargs)
  # setup hparams
  self._problem_hparams = problem_hparams
  self._original_hparams = hparams_lib.copy_hparams(hparams)
  self.set_mode(mode)
  self.symbol_bottom_inputs = modalities.SymbolBottomSimple(
      self._original_hparams,
      self._problem_hparams.vocab_size["inputs"],
  )
  self.predict_dense = common_layers.DenseReluDense(
      self._hparams.filter_size,
      self._hparams.hidden_size,
  )
def __init__(self,
             hparams,
             mode=tf.estimator.ModeKeys.TRAIN,
             problem_hparams=None,
             **kwargs):
  super(BaseModel, self).__init__(**kwargs)
  # setup hparams
  self._problem_hparams = problem_hparams
  self._original_hparams = hparams_lib.copy_hparams(hparams)
  self.set_mode(mode)
  self.symbol_bottom_inputs = modalities.SymbolBottomSimple(
      self._original_hparams,
      self._problem_hparams.vocab_size["inputs"],
  )
  self.symbol_bottom_targets = modalities.SymbolBottomSimple(
      self._original_hparams,
      self._problem_hparams.vocab_size["targets"],
  )
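# Design note: BaseModel embeds both feature streams with separate
# SymbolBottomSimple tables (one for "inputs", one for "targets"), whereas
# PretrainModel above keeps only the input embedding and instead attaches a
# DenseReluDense head (filter_size -> hidden_size) as its prediction output.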
def train(
    train_generator,
    dev_data_generator_fn,
    hparams,
):
  hparams = hparams_lib.copy_hparams(hparams)
  generator_hparams = train_generator.get_generator_hparams()
  model_cls = registry.model(hparams.model)
  model = model_cls(
      hparams,
      tf.estimator.ModeKeys.TRAIN,
      generator_hparams,
  )
  optimizer = model.optimizer()

  # Create a checkpoint and restore it if one exists in hparams.model_dir.
  step_counter = tf.train.get_or_create_global_step()
  checkpoint = tf.train.Checkpoint(
      model=model, optimizer=optimizer, optimizer_step=step_counter)
  checkpoint_manager = tf.contrib.checkpoint.CheckpointManager(
      checkpoint,
      directory=hparams.model_dir,
      max_to_keep=hparams.checkpoint_max_to_keep,
  )
  # Restore variables on creation if a checkpoint exists.
  checkpoint.restore(checkpoint_manager.latest_checkpoint)

  # Resume from the restored global step; range() needs a Python int.
  start_step = int(step_counter.numpy())
  start_timer = time.time()
  for step_index in range(start_step, hparams.train_steps):
    features = train_generator.get_next()
    with tf.GradientTape() as tape:
      logits, loss = model(features, training=True)
    tf.contrib.summary.scalar("loss", loss)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables), global_step=step_counter)

    if not step_index % hparams.monitor_steps:
      end_timer = time.time()
      tf.logging.warn("* Loss for step {}: {}".format(step_index, loss))
      tf.logging.warn("* Running {} batches took {} secs".format(
          hparams.monitor_steps, end_timer - start_timer))
      start_timer = time.time()

    if not step_index % hparams.eval_steps:
      # Run a full pass over the dev set and report the mean loss.
      dev_data_generator = dev_data_generator_fn()
      eval_losses = []
      while True:
        try:
          eval_features = dev_data_generator.get_next()
          _, eval_loss = model(eval_features, training=False)
          eval_losses.append(eval_loss)
        except tf.errors.OutOfRangeError:
          break
      eval_mean_loss = tf.reduce_mean(eval_losses)
      tf.contrib.summary.scalar("eval_loss", eval_mean_loss)
      tf.logging.warn("* Eval for step {}, loss: {}".format(
          step_index, eval_mean_loss))
      checkpoint_manager.save()
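# Driver sketch (hedged): assumes eager execution is enabled and that
# `make_train_generator` / `make_dev_generator` are hypothetical factories
# yielding feature dicts the model accepts; the hparams set name is also an
# assumption. The hparams object must carry the fields train() reads: model,
# model_dir, train_steps, monitor_steps, eval_steps, checkpoint_max_to_keep.
#
#   tf.enable_eager_execution()
#   hparams = registry.hparams("base_model_hparams")  # hypothetical set name
#   train(
#       train_generator=make_train_generator(hparams),            # hypothetical
#       dev_data_generator_fn=lambda: make_dev_generator(hparams),  # hypothetical
#       hparams=hparams,
#   )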