Example #1
def main(settings):
    """
    Translates a source language file (or STDIN) into a target language file
    (or STDOUT).
    """
    # Create the TensorFlow session.
    tf_config = tf.ConfigProto()
    tf_config.allow_soft_placement = True
    session = tf.Session(config=tf_config)

    # Load config file for each model.
    configs = []
    for model in settings.models:
        config = load_config_from_json_file(model)
        setattr(config, 'reload', model)
        configs.append(config)

    # Create the model graphs.
    logging.debug("Loading models\n")
    models = []
    for i, config in enumerate(configs):
        with tf.variable_scope("model%d" % i) as scope:
            if config.model_type == "transformer":
                model = TransformerModel(config)
            else:
                model = rnn_model.RNNModel(config)
            model.sampling_utils = SamplingUtils(settings)
            models.append(model)

    # Add smoothing variables (if the models were trained with smoothing).
    #FIXME Assumes either all models were trained with smoothing or none were.
    if configs[0].exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(configs[0].exponential_smoothing)

    # Restore the model variables.
    for i, config in enumerate(configs):
        with tf.variable_scope("model%d" % i) as scope:
            _ = model_loader.init_or_restore_variables(config, session,
                                                       ensemble_scope=scope)

    # Swap-in the smoothed versions of the variables.
    if configs[0].exponential_smoothing > 0.0:
        session.run(fetches=smoothing.swap_ops)

    # TODO Ensembling is currently only supported for RNNs, so if
    # TODO len(models) > 1 then check models are all rnn

    # Translate the source file.
    inference.translate_file(input_file=settings.input,
                             output_file=settings.output,
                             session=session,
                             models=models,
                             configs=configs,
                             beam_size=settings.beam_size,
                             nbest=settings.n_best,
                             minibatch_size=settings.minibatch_size,
                             maxibatch_size=settings.maxibatch_size,
                             normalization_alpha=settings.normalization_alpha)
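A hedged invocation sketch for this entry point. The attribute names below are the ones read above; the concrete values and the use of argparse.Namespace are assumptions for illustration (in practice the settings object is built by Nematus's translate command-line parser, and SamplingUtils may read further sampling-related attributes not shown here):

    import argparse
    import sys

    settings = argparse.Namespace(
        models=['model.best'],   # hypothetical checkpoint path(s)
        input=sys.stdin,         # source sentences (or an open file)
        output=sys.stdout,       # translations (or an open file)
        beam_size=5,
        n_best=False,
        minibatch_size=80,
        maxibatch_size=20,
        normalization_alpha=0.6)
    main(settings)
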
Example #2
def main(settings):
    """
    Translates a source language file (or STDIN) into a target language file
    (or STDOUT).
    """
    # Create the TensorFlow session.
    g = tf.Graph()
    with g.as_default():
        tf_config = tf.compat.v1.ConfigProto()
        tf_config.allow_soft_placement = True
        session = tf.compat.v1.Session(config=tf_config)

        # Load config file for each model.
        configs = []
        for model in settings.models:
            config = load_config_from_json_file(model)
            setattr(config, 'reload', model)
            setattr(config, 'translation_maxlen', settings.translation_maxlen)
            configs.append(config)

        # Create the model graphs.
        logging.debug("Loading models\n")
        models = []
        for i, config in enumerate(configs):
            with tf.compat.v1.variable_scope("model%d" % i) as scope:
                if config.model_type == "transformer":
                    model = TransformerModel(
                        config, consts_config_str=settings.config_str)
                else:
                    model = rnn_model.RNNModel(config)
                model.sampling_utils = SamplingUtils(settings)
                models.append(model)
        # Add smoothing variables (if the models were trained with smoothing).
        # FIXME Assumes either all models were trained with smoothing or none were.
        if configs[0].exponential_smoothing > 0.0:
            smoothing = ExponentialSmoothing(configs[0].exponential_smoothing)

        # Restore the model variables.
        for i, config in enumerate(configs):
            with tf.compat.v1.variable_scope("model%d" % i) as scope:
                _ = model_loader.init_or_restore_variables(
                    config, session, ensemble_scope=scope)

        # Swap-in the smoothed versions of the variables.
        if configs[0].exponential_smoothing > 0.0:
            session.run(fetches=smoothing.swap_ops)

        max_translation_len = settings.translation_maxlen

        # Create a BeamSearchSampler / RandomSampler.
        if settings.translation_strategy == 'beam_search':
            sampler = BeamSearchSampler(models, configs, settings.beam_size)
        else:
            assert settings.translation_strategy == 'sampling'
            sampler = RandomSampler(models, configs, settings.beam_size)

        # Warn about the change from neg log probs to log probs for the RNN.
        if settings.n_best:
            model_types = [config.model_type for config in configs]
            if 'rnn' in model_types:
                logging.warning(
                    'n-best scores for RNN models have changed from '
                    'positive to negative (as of commit 95793196...). '
                    'If you are using the scores for reranking etc, then '
                    'you may need to update your scripts.')

        # Translate the source file.
        translate_utils.translate_file(
            input_file=settings.input,
            output_file=settings.output,
            session=session,
            sampler=sampler,
            config=configs[0],
            max_translation_len=max_translation_len,
            normalization_alpha=settings.normalization_alpha,
            consts_config_str=settings.config_str,
            nbest=settings.n_best,
            minibatch_size=settings.minibatch_size,
            maxibatch_size=settings.maxibatch_size)
Example #3
File: train.py  Project: ykyaol7/nematus
def train(config, sess):
    assert ((config.prior_model is not None and
             tf.train.checkpoint_exists(os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU

    num_gpus = len(tf_utils.get_available_gpus())
    num_replicas = max(1, num_gpus)

    if config.loss_function == 'MRT':
        assert config.gradient_aggregation_steps == 1
        assert config.max_sentences_per_device == 0, "MRT mode does not support sentence-based split"
        if config.max_tokens_per_device != 0:
            assert (config.samplesN * config.maxlen <= config.max_tokens_per_device), \
                "need to make sure candidates of a sentence can be fed into the model"
        else:
            assert num_replicas == 1, "MRT mode without max_tokens_per_device does not support multiple devices"
            assert (config.samplesN * config.maxlen <= config.token_batch_size), \
                "need to make sure candidates of a sentence can be fed into the model"
        # Illustrative arithmetic (hypothetical numbers): with samplesN=100
        # candidates per sentence and maxlen=50, a single sentence can expand
        # to 100 * 50 = 5000 tokens, so the per-device (or per-batch) token
        # budget must be at least 5000 for the asserts above to hold.



    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i>0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init, trainable=False)

    if config.learning_schedule == "constant":
        schedule = learning_schedule.ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = learning_schedule.TransformerSchedule(
            global_step=global_step,
            dim=config.state_size,
            warmup_steps=config.warmup_steps)
    elif config.learning_schedule == "warmup-plateau-decay":
        schedule = learning_schedule.WarmupPlateauDecaySchedule(
            global_step=global_step,
            peak_learning_rate=config.learning_rate,
            warmup_steps=config.warmup_steps,
            plateau_steps=config.plateau_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=schedule.learning_rate,
                                           beta1=config.adam_beta1,
                                           beta2=config.adam_beta2,
                                           epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    if config.exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(config.exponential_smoothing)

    saver, progress = model_loader.init_or_restore_variables(
        config, sess, train=True)

    global_step.load(progress.uidx, sess)

    if config.sample_freq:
        random_sampler = RandomSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=1)

    if config.beam_freq or config.valid_script is not None:
        beam_search_sampler = BeamSearchSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=config.beam_size)

    #save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    # set epoch = 1 if print per-token-probability
    if config.print_per_token_pro:
        config.max_epochs = progress.eidx+1
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error('Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'.format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info('Minibatch with zero sample under length {0}'.format(config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and ((progress.uidx % config.summary_freq == 0) or (config.finish_after and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            output = updater.update(
                sess, x_in, x_mask_in, y_in, y_mask_in, num_to_target,
                write_summary_for_this_batch)

            if not config.print_per_token_pro:
                total_loss += output
            else:
                # Append per-token probabilities to the output file.
                with open(config.print_per_token_pro, 'a') as f:
                    for pro in output:
                        f.write(str(pro) + '\n')

            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            # Update the smoothed version of the model variables.
            # To reduce the performance overhead, we only do this once every
            # N steps (the smoothing factor is adjusted accordingly).
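            # (As a rough sketch of the usual scheme, which ExponentialSmoothing
            # is assumed to implement: a shadow copy s of each variable v is
            # kept and updated as s <- d*s + (1-d)*v with decay d; when the
            # update runs only every N steps, d is adjusted so the effective
            # averaging horizon stays roughly the same.)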
            if config.exponential_smoothing > 0.0 and progress.uidx % smoothing.update_frequency == 0:
                sess.run(fetches=smoothing.update_ops)

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info('{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'.format(disp_time, progress.eidx, progress.uidx, total_loss/n_words, n_words/duration, n_sents/duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, random_sampler, x_small, x_mask_small,
                    config.translation_maxlen, 0.0)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss[0][0], num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, beam_search_sampler, x_small, x_mask_small,
                    config.translation_maxlen, config.normalization_alpha)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost/len(sample))
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                if config.exponential_smoothing > 0.0:
                    sess.run(fetches=smoothing.swap_ops)
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                    sess.run(fetches=smoothing.swap_ops)
                else:
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                if (len(progress.history_errs) == 0 or
                    valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    if config.exponential_smoothing > 0.0:
                        sess.run(fetches=smoothing.swap_ops)
                        score = validate_with_script(sess, beam_search_sampler)
                        sess.run(fetches=smoothing.swap_ops)
                    else:
                        score = validate_with_script(sess, beam_search_sampler)
                    need_to_save = (score is not None and
                        (len(progress.valid_script_scores) == 0 or
                         score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop=True
                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
Example #4
def calc_scores(source_file, target_file, scorer_settings, configs):
    """Calculates sentence pair scores using each of the specified models.

    By default (when scorer_settings.normalization_alpha is 0.0), the score
    is the sentence-level cross entropy, otherwise it's a normalized version.

    Args:
        source_file: file object for file containing source sentences.
        target_file: file object for file containing target sentences.
        scorer_settings: a ScorerSettings object.
        configs: a list of Namespace objects specifying the model configs.

    Returns:
        A list of lists of floats. The outer list contains one list for each
        model (in the same order given by configs). The inner list contains
        one score for each sentence pair.
    """
    scores = []
    for config in configs:
        g = tf.Graph()
        with g.as_default():
            tf_config = tf.ConfigProto()
            tf_config.allow_soft_placement = True
            with tf.Session(config=tf_config) as sess:

                logging.info('Building model...')

                # Create the model graph.
                if config.model_type == 'transformer':
                    model = transformer.Transformer(config)
                else:
                    model = rnn_model.RNNModel(config)

                # Add smoothing variables (if the model was trained with
                # smoothing).
                if config.exponential_smoothing > 0.0:
                    smoothing = ExponentialSmoothing(
                        config.exponential_smoothing)

                # Restore the model variables.
                saver = model_loader.init_or_restore_variables(config, sess)

                # Swap-in the smoothed versions of the variables (if present).
                if config.exponential_smoothing > 0.0:
                    sess.run(fetches=smoothing.swap_ops)

                text_iterator = TextIterator(
                    source=source_file.name,
                    target=target_file.name,
                    source_dicts=config.source_dicts,
                    target_dict=config.target_dict,
                    model_type=config.model_type,
                    batch_size=scorer_settings.minibatch_size,
                    maxlen=float('inf'),
                    source_vocab_sizes=config.source_vocab_sizes,
                    target_vocab_size=config.target_vocab_size,
                    use_factor=(config.factors > 1),
                    sort_by_length=False)

                ce_vals, _ = train.calc_cross_entropy_per_sentence(
                    sess,
                    model,
                    config,
                    text_iterator,
                    normalization_alpha=scorer_settings.normalization_alpha)

                scores.append(ce_vals)
    return scores
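A hedged usage sketch for calc_scores. The scorer_settings fields are the ones read in the body above; the file names, values, and the use of argparse.Namespace in place of a real ScorerSettings object are assumptions for illustration:

    import argparse

    # Model configs, one per model, loaded as in Examples #1 and #2.
    configs = [load_config_from_json_file(m)
               for m in ['model1', 'model2']]  # hypothetical model paths
    scorer_settings = argparse.Namespace(minibatch_size=80,
                                         normalization_alpha=0.0)
    with open('newstest.src') as source_file, \
         open('newstest.ref') as target_file:
        scores = calc_scores(source_file, target_file, scorer_settings, configs)
    # scores[m][i] is the (possibly length-normalized) cross-entropy of
    # sentence pair i under model m.
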
Example #5
def init_or_restore_variables(config, sess, ensemble_scope=None, train=False):
    # Add variables and ops for exponential smoothing, if enabled (unless
    # training, as they will already have been added).
    if not train and config.exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(config.exponential_smoothing)

    # Construct a mapping between saved variable names and names in the current
    # scope. There are two reasons why names might be different:
    #
    #   1. This model is part of an ensemble, in which case a model-specific
    #       name scope will be active.
    #
    #   2. The saved model is from an old version of Nematus (before deep model
    #        support was added) and uses a different variable naming scheme
    #        for the GRUs.
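    #
    # Illustrative (hypothetical) examples of the mapping:
    #
    #   ensemble case:  a variable named "model0/decoder/W:0" in the current
    #       graph is looked up in the checkpoint as "decoder/W";
    #   Adam quirk:     "model0/model0/decoder/W/Adam:0" has the ensemble
    #       scope repeated, so it is stripped twice (see the comment near the
    #       ensemble_scope handling below).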

    accum_regex = re.compile(r'^accum\d+$')

    def is_excluded_variable(name):
        # Exclude gradient accumulation variables.
        if accum_regex.match(name):
            return True
        if name == 'accumulated_loss':
            return True
        return False

    variables = slim.get_variables_to_restore()
    var_map = {}
    for v in variables:
        name = v.name.split(':')[0]
        if ensemble_scope is None:
            saved_name = name
        elif v.name.startswith(ensemble_scope.name + "/"):
            saved_name = name[len(ensemble_scope.name)+1:]
            # The ensemble scope is repeated for Adam variables. See
            # https://github.com/tensorflow/tensorflow/issues/8120
            if saved_name.startswith(ensemble_scope.name + "/"):
                saved_name = saved_name[len(ensemble_scope.name)+1:]
        else: # v belongs to a different model in the ensemble.
            continue
        if is_excluded_variable(saved_name):
            continue
        if config.model_version == 0.1:
            # Backwards compatibility with the old variable naming scheme.
            saved_name = _revert_variable_name(saved_name, 0.1)
        var_map[saved_name] = v
    saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename is not None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                os.path.basename(config.saveto)):
                logging.error("Mismatching model filename found in the same directory while reloading from the latest checkpoint")
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' + os.path.abspath(checkpoint_dir))
    elif config.reload is not None:
        reload_filename = config.reload
    if reload_filename is None and config.prior_model is not None:
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        progress.valid_script_scores = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop or
                    progress.eidx > config.max_epochs or
                    progress.uidx >= config.finish_after):
                    logging.warning('Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.' % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model is not None:
        load_prior(config, sess, saver)

    init_op = tf.global_variables_initializer()

    # initialize or restore model
    if reload_filename is None:
        logging.info('Initializing model parameters from scratch...')
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' + os.path.abspath(reload_filename))
        # Initialize all variables even though most will be overwritten by
        # the subsequent saver.restore() call. This is to allow for variables
        # that are not saved to the checkpoint. Currently that is just the
        # gradient accumulation variables, which are unusual in that they
        # persist across multiple sessions during training (and therefore need
        # to be variables) but are regularly reset to zero.
        sess.run(init_op)
        saver.restore(sess, os.path.abspath(reload_filename))
    logging.info('Done')

    # For everything apart from training, use the smoothed version of the
    # parameters (if available).
    if not train and config.exponential_smoothing > 0.0:
        logging.info('Using smoothed model parameters')
        sess.run(fetches=smoothing.swap_ops)

    if train:
        return saver, progress
    else:
        return saver
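
Usage summary, drawn from the examples above. For inference (Examples #1 and #2) the function is called once per ensemble member inside a model-specific variable scope:

    with tf.variable_scope("model0") as scope:
        saver = model_loader.init_or_restore_variables(config, session,
                                                       ensemble_scope=scope)

For training (Example #3) it is called once with train=True, and the returned TrainingProgress object drives the outer training loop:

    saver, progress = model_loader.init_or_restore_variables(config, sess,
                                                             train=True)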