Example #1
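The snippets on this page assume TensorFlow 1.x-style graph building (`import tensorflow as tf`) and a `contrib_training` module that provides `train()`, such as `tf.contrib.training` in TensorFlow 1.x.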
def Entrenamiento(construye_grafo, train_dir, num_training_steps=None):
    """Runs the training loop; the same pattern as run_training in Example #2, with fixed settings."""
    summary_frequency = 10  # steps between summaries being logged and written to disk
    save_checkpoint_secs = 60  # seconds between checkpoint saves
    checkpoints_to_keep = 10  # most recent checkpoints retained in train_dir
    keep_checkpoint_every_n_hours = 1
    master = ''  # URL of the TensorFlow master
    task = 0  # this worker's task number; task 0 is the chief
    num_ps_tasks = 0  # number of parameter-server tasks

    with tf.Graph().as_default():
        with tf.device(tf.compat.v1.train.replica_device_setter(num_ps_tasks)):
            tf.compat.v1.logging.set_verbosity('ERROR')
            construye_grafo()

            global_step = tf.compat.v1.train.get_or_create_global_step()
            loss = tf.compat.v1.get_collection('loss')[0]
            perplexity = tf.compat.v1.get_collection('metrics/perplexity')[0]
            accuracy = tf.compat.v1.get_collection('metrics/accuracy')[0]
            train_op = tf.compat.v1.get_collection('train_op')[0]

            logging_dict = {
                'Global Step': global_step,
                'Loss': loss,
                'Perplexity': perplexity,
                'Accuracy': accuracy
            }

            hooks = [
                tf.estimator.NanTensorHook(loss),
                tf.estimator.LoggingTensorHook(logging_dict,
                                               every_n_iter=summary_frequency),
                tf.estimator.StepCounterHook(output_dir=train_dir,
                                             every_n_steps=summary_frequency)
            ]

            if num_training_steps:
                hooks.append(tf.estimator.StopAtStepHook(num_training_steps))

            scaffold = tf.compat.v1.train.Scaffold(
                saver=tf.compat.v1.train.Saver(
                    max_to_keep=checkpoints_to_keep,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
                ))
            tf.compat.v1.logging.set_verbosity('INFO')
            tf.compat.v1.logging.info('Starting training loop...')
            contrib_training.train(train_op=train_op,
                                   logdir=train_dir,
                                   scaffold=scaffold,
                                   hooks=hooks,
                                   save_checkpoint_secs=save_checkpoint_secs,
                                   save_summaries_steps=summary_frequency,
                                   master=master,
                                   is_chief=task == 0)
            tf.compat.v1.logging.info('Training complete.')
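Example #2 below is the same loop with English names and with the hard-coded settings exposed as keyword arguments.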
Example #2
def run_training(build_graph_fn, train_dir, num_training_steps=None,
                 summary_frequency=10, save_checkpoint_secs=60,
                 checkpoints_to_keep=10, keep_checkpoint_every_n_hours=1,
                 master='', task=0, num_ps_tasks=0):
    """Runs the training loop.

  Args:
    build_graph_fn: A function that builds the graph ops.
    train_dir: The path to the directory where checkpoints and summary events
        will be written to.
    num_training_steps: The number of steps to train for before exiting.
    summary_frequency: The number of steps between each summary. A summary is
        when graph values from the last step are logged to the console and
        written to disk.
    save_checkpoint_secs: The frequency at which to save checkpoints, in
        seconds.
    checkpoints_to_keep: The number of most recent checkpoints to keep in
       `train_dir`. Keeps all if set to 0.
    keep_checkpoint_every_n_hours: Keep a checkpoint every N hours, even if it
        results in more checkpoints than checkpoints_to_keep.
    master: URL of the Tensorflow master.
    task: Task number for this worker.
    num_ps_tasks: Number of parameter server tasks.
  """
    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
            build_graph_fn()

            global_step = tf.train.get_or_create_global_step()
            loss = tf.get_collection('loss')[0]
            perplexity = tf.get_collection('metrics/perplexity')[0]
            accuracy = tf.get_collection('metrics/accuracy')[0]
            train_op = tf.get_collection('train_op')[0]

            logging_dict = {
                'Global Step': global_step,
                'Loss': loss,
                'Perplexity': perplexity,
                'Accuracy': accuracy
            }
            hooks = [
                tf.train.NanTensorHook(loss),
                tf.train.LoggingTensorHook(
                    logging_dict, every_n_iter=summary_frequency),
                tf.train.StepCounterHook(
                    output_dir=train_dir, every_n_steps=summary_frequency)
            ]
            if num_training_steps:
                hooks.append(tf.train.StopAtStepHook(num_training_steps))

            scaffold = tf.train.Scaffold(
                saver=tf.train.Saver(
                    max_to_keep=checkpoints_to_keep,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours))

            tf.logging.info('Starting training loop...')
            contrib_training.train(
                train_op=train_op,
                logdir=train_dir,
                scaffold=scaffold,
                hooks=hooks,
                save_checkpoint_secs=save_checkpoint_secs,
                save_summaries_steps=summary_frequency,
                master=master,
                is_chief=task == 0)
            tf.logging.info('Training complete.')
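The loop above only reads tensors back out of named graph collections, so any build_graph_fn that registers 'loss', 'metrics/perplexity', 'metrics/accuracy', and 'train_op' (and creates a global step) will work. Below is a minimal sketch of such a function; the toy model, shapes, and optimizer are illustrative assumptions, not part of the original examples, and it assumes TensorFlow 1.x-style `import tensorflow as tf` (or `tensorflow.compat.v1`).

def build_toy_graph():
    """Registers 'loss', 'metrics/perplexity', 'metrics/accuracy', and 'train_op'."""
    # Toy data and model: a single dense layer over random features.
    inputs = tf.random.normal([8, 4])
    labels = tf.zeros([8], dtype=tf.int32)
    logits = tf.layers.dense(inputs, 2)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    perplexity = tf.exp(loss)
    accuracy = tf.reduce_mean(tf.cast(
        tf.equal(tf.argmax(logits, axis=1, output_type=tf.int32), labels), tf.float32))
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)
    # run_training pulls each of these tensors back out of the graph by collection name.
    tf.add_to_collection('loss', loss)
    tf.add_to_collection('metrics/perplexity', perplexity)
    tf.add_to_collection('metrics/accuracy', accuracy)
    tf.add_to_collection('train_op', train_op)

With a graph like this, a hypothetical call such as run_training(build_toy_graph, '/tmp/toy_train', num_training_steps=100) would run a short smoke test.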
Example #3
def train(train_dir,
          config,
          dataset_fn,
          checkpoints_to_keep=5,
          keep_checkpoint_every_n_hours=1,
          num_steps=None,
          master='',
          num_sync_workers=0,
          num_ps_tasks=0,
          task=0):
    """Train loop."""
    tf.gfile.MakeDirs(train_dir)
    is_chief = (task == 0)
    if is_chief:
        _trial_summary(config.hparams, config.train_examples_path
                       or config.tfds_name, train_dir)
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(num_ps_tasks,
                                               merge_devices=True)):

            model = config.model
            model.build(config.hparams,
                        config.data_converter.output_depth,
                        is_training=True)

            optimizer = model.train(**_get_input_tensors(dataset_fn(), config))

            hooks = []
            if num_sync_workers:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, num_sync_workers)
                hooks.append(optimizer.make_session_run_hook(is_chief))

            grads, var_list = zip(*optimizer.compute_gradients(model.loss))
            global_norm = tf.global_norm(grads)
            tf.summary.scalar('global_norm', global_norm)

            if config.hparams.clip_mode == 'value':
                g = config.hparams.grad_clip
                clipped_grads = [
                    tf.clip_by_value(grad, -g, g) for grad in grads
                ]
            elif config.hparams.clip_mode == 'global_norm':
                clipped_grads = tf.cond(
                    global_norm < config.hparams.grad_norm_clip_to_zero,
                    lambda: tf.clip_by_global_norm(  # pylint:disable=g-long-lambda
                        grads,
                        config.hparams.grad_clip,
                        use_norm=global_norm)[0],
                    lambda: [tf.zeros(tf.shape(g)) for g in grads])
            else:
                raise ValueError('Unknown clip_mode: {}'.format(
                    config.hparams.clip_mode))
            train_op = optimizer.apply_gradients(zip(clipped_grads, var_list),
                                                 global_step=model.global_step,
                                                 name='train_step')

            logging_dict = {
                'global_step': model.global_step,
                'loss': model.loss
            }

            hooks.append(
                tf.train.LoggingTensorHook(logging_dict, every_n_iter=100))
            if num_steps:
                hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

            scaffold = tf.train.Scaffold(saver=tf.train.Saver(
                max_to_keep=checkpoints_to_keep,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours))
            contrib_training.train(train_op=train_op,
                                   logdir=train_dir,
                                   scaffold=scaffold,
                                   hooks=hooks,
                                   save_checkpoint_secs=60,
                                   master=master,
                                   is_chief=is_chief)
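The config object here is expected to expose model, hparams, data_converter, and train_examples_path/tfds_name, which matches Magenta's MusicVAE training setup. A hedged sketch of invoking this train() follows; the config name, dataset callable, and output path are illustrative assumptions, not taken from the original snippet.

config = configs.CONFIG_MAP['cat-mel_2bar_small']  # assumes magenta.models.music_vae.configs
train(
    train_dir='/tmp/music_vae/train',  # hypothetical output directory
    config=config,
    dataset_fn=make_training_dataset,  # hypothetical callable returning a tf.data.Dataset of training examples
    num_steps=200000,
    checkpoints_to_keep=5)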