Example #1
def init_or_restore_variables(config, sess, ensemble_scope=None, train=False):
    # Construct a mapping between saved variable names and names in the current
    # scope. There are two reasons why names might be different:
    #
    #   1. This model is part of an ensemble, in which case a model-specific
    #       name scope will be active.
    #
    #   2. The saved model is from an old version of Nematus (before deep model
    #        support was added) and uses a different variable naming scheme
    #        for the GRUs.
    variables = slim.get_variables_to_restore()
    var_map = {}
    for v in variables:
        name = v.name.split(':')[0]
        if ensemble_scope is None:
            saved_name = name
        elif v.name.startswith(ensemble_scope.name + "/"):
            saved_name = name[len(ensemble_scope.name) + 1:]
            # The ensemble scope is repeated for Adam variables. See
            # https://github.com/tensorflow/tensorflow/issues/8120
            if saved_name.startswith(ensemble_scope.name + "/"):
                saved_name = saved_name[len(ensemble_scope.name) + 1:]
        else:  # v belongs to a different model in the ensemble.
            continue
        if config.model_version == 0.1:
            # Backwards compatibility with the old variable naming scheme.
            saved_name = compat.revert_variable_name(saved_name, 0.1)
        var_map[saved_name] = v
    saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename is not None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                    os.path.basename(config.saveto)):
                logging.error(
                    "Mismatching model filename found in the same directory while reloading from the latest checkpoint"
                )
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' +
                         os.path.abspath(checkpoint_dir))
    elif config.reload is not None:
        reload_filename = config.reload
    if reload_filename is None and config.prior_model is not None:
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        progress.valid_script_scores = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop or progress.eidx > config.max_epochs
                        or progress.uidx >= config.finish_after):
                    logging.warning(
                        'Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.'
                        % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model is not None:
        load_prior(config, sess, saver)

    # initialize or restore model
    if reload_filename is None:
        logging.info('Initializing model parameters from scratch...')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' +
                     os.path.abspath(reload_filename))
        saver.restore(sess, os.path.abspath(reload_filename))

    logging.info('Done')

    if train:
        return saver, progress
    else:
        return saver
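
The var_map construction above strips the ensemble scope from each saved variable name (and strips it a second time for Adam slot variables, per the linked TensorFlow issue) so that a checkpoint written by a single, unscoped model can be restored into a scoped ensemble member. Below is a minimal standalone sketch of that mapping step, using a hypothetical strip_ensemble_scope helper and made-up variable names; the real code inlines this loop and passes a scope object rather than a plain string.

def strip_ensemble_scope(var_name, scope_name):
    """Map a scoped variable name back to the name stored in the checkpoint.

    Returns None if the variable belongs to a different ensemble member.
    """
    name = var_name.split(':')[0]
    if scope_name is None:
        return name
    prefix = scope_name + "/"
    if not var_name.startswith(prefix):
        return None  # belongs to another model in the ensemble
    saved_name = name[len(prefix):]
    # Adam slot variables repeat the scope, so strip it once more.
    if saved_name.startswith(prefix):
        saved_name = saved_name[len(prefix):]
    return saved_name

print(strip_ensemble_scope("model0/decoder/W:0", "model0"))              # decoder/W
print(strip_ensemble_scope("model0/model0/decoder/W/Adam:0", "model0"))  # decoder/W/Adam
print(strip_ensemble_scope("model1/decoder/W:0", "model0"))              # None
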
Example #2
def create_model(config, sess, ensemble_scope=None, train=False):
    logging.info('Building model...')
    # StandardModel is a class in model.py
    model = StandardModel(config)

    # Construct a mapping between saved variable names and names in the current
    # scope. There are two reasons why names might be different:
    #
    #   1. This model is part of an ensemble, in which case a model-specific
    #       name scope will be active.
    #
    #   2. The saved model is from an old version of Nematus (before deep model
    #        support was added) and uses a different variable naming scheme
    #        for the GRUs.

    # slim (tf.contrib.slim) is another TF library that integrates with native
    # TensorFlow and is meant to make defining, training, and evaluating
    # complex models easier.

    # The variables obtained here are the ones created above when building the
    # StandardModel.
    variables = slim.get_variables_to_restore()
    var_map = {}
    for v in variables:
        name = v.name.split(':')[0]
        if ensemble_scope is None:
            saved_name = name
        elif v.name.startswith(ensemble_scope):
            saved_name = name[len(ensemble_scope):]
        else:  # v belongs to a different model in the ensemble.
            continue
        if config.model_version == 0.1:
            # Backwards compatibility with the old variable naming scheme.
            saved_name = compat.revert_variable_name(saved_name, 0.1)
        var_map[saved_name] = v
    saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename is not None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                    os.path.basename(config.saveto)):
                logging.error(
                    "Mismatching model filename found in the same directory while reloading from the latest checkpoint"
                )
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' +
                         os.path.abspath(checkpoint_dir))
    elif config.reload is not None:
        reload_filename = config.reload
    if reload_filename is None and config.prior_model is not None:
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop or progress.eidx > config.max_epochs
                        or progress.uidx >= config.finish_after):
                    logging.warning(
                        'Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.'
                        % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model is not None:
        load_prior(config, sess, saver)

    # initialize or restore model
    if reload_filename is None:
        # If there is no checkpoint or prior model to load, initialize from
        # scratch (i.e. training for the first time).
        logging.info('Initializing model parameters from scratch...')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' +
                     os.path.abspath(reload_filename))
        saver.restore(sess, os.path.abspath(reload_filename))
        if train:
            # The global step is currently recorded in two places:
            #   1. model.t, a tf.Variable read and updated by the optimizer
            #   2. progress.uidx, a Python integer updated by train()
            # We reset model.t to the value recorded in progress to allow the
            # value to be controlled by the user (either explicitly by
            # configuring the value in the progress file or implicitly by using
            # --no_reload_training_progress).
            model.reset_global_step(progress.uidx, sess)

    logging.info('Done')

    if train:
        return model, saver, progress
    else:
        return model, saver
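
For context, here is a hedged sketch of how create_model might be driven inside a TF 1.x session. The SimpleNamespace config is hypothetical and only carries the fields referenced above; the real Nematus config object comes from the command-line parser, and StandardModel needs many more fields than are shown here.

import tensorflow as tf
from types import SimpleNamespace

# Hypothetical, minimal config for illustration only.
config = SimpleNamespace(
    reload='latest_checkpoint',  # or an explicit checkpoint path, or None
    saveto='models/model',       # checkpoint prefix; its directory is scanned
    prior_model=None,
    model_version=0.2,           # anything other than 0.1 skips the legacy renaming
    reload_training_progress=True,
    max_epochs=50,
    finish_after=1000000,
)

with tf.Session() as sess:
    # Training: the training-progress object is returned as well.
    model, saver, progress = create_model(config, sess, train=True)

    # Inference with an ensemble: build each member under its own scope.
    # Note that this version compares ensemble_scope as a raw string prefix,
    # so it would need to include the trailing slash, e.g.:
    # with tf.variable_scope("model0"):
    #     model0, saver0 = create_model(config, sess, ensemble_scope="model0/")
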