Code example #1
File: score.py  Project: nd1511/nematus
def score_model(source_file, target_file, scorer_settings, options):
    scores = []
    for option in options:
        g = tf.Graph()
        with g.as_default():
            tf_config = tf.ConfigProto()
            tf_config.allow_soft_placement = True
            with tf.Session(config=tf_config) as sess:
                logging.info('Building model...')
                model = StandardModel(option)
                saver = nmt.init_or_restore_variables(option, sess)

                text_iterator = TextIterator(
                    source=source_file.name,
                    target=target_file.name,
                    source_dicts=option.source_dicts,
                    target_dict=option.target_dict,
                    batch_size=scorer_settings.b,
                    maxlen=float('inf'),
                    source_vocab_sizes=option.source_vocab_sizes,
                    target_vocab_size=option.target_vocab_size,
                    use_factor=(option.factors > 1),
                    sort_by_length=False)

                losses = nmt.calc_loss_per_sentence(
                    option,
                    sess,
                    text_iterator,
                    model,
                    normalization_alpha=scorer_settings.normalization_alpha)

                scores.append(losses)
    return scores
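
For context, score_model builds a separate graph and session per model option and returns one list of per-sentence losses per model. Below is a minimal sketch of a caller; load_config and the attribute names on scorer_settings (b, normalization_alpha) are stand-ins for illustration, not part of the snippet above.

from types import SimpleNamespace

# Hypothetical settings object; score_model above only reads .b (batch size)
# and .normalization_alpha.
scorer_settings = SimpleNamespace(b=80, normalization_alpha=0.0)

# load_config is a placeholder for however per-model option objects are built.
options = [load_config('model1.json'), load_config('model2.json')]

# File names are made up; score_model reads .name from the open file objects.
with open('newstest.src') as source_file, open('newstest.ref') as target_file:
    scores = score_model(source_file, target_file, scorer_settings, options)

# scores[i][j] is then the loss of sentence j under model i.
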
Code example #2
def theano_to_tensorflow_model(in_path, out_path):
    saved_model = np.load(in_path)
    config = theano_to_tensorflow_config(in_path)
    th2tf = construct_parameter_map(config)

    with tf.Session() as sess:
        logging.info('Building model...')
        model = StandardModel(config)
        saver = nmt.init_or_restore_variables(config, sess)
        seen = set()
        assign_ops = []
        for key in saved_model.keys():
            # ignore adam parameters
            if key.startswith('adam'):
                continue
            tf_name = th2tf[key]
            if tf_name is not None:
                assert tf_name not in seen
                seen.add(tf_name)
                tf_var = tf.get_default_graph().get_tensor_by_name(tf_name)
                if (sess.run(tf.shape(tf_var)) !=
                        saved_model[key].shape).any():
                    print "mismatch for", tf_name, key, saved_model[
                        key].shape, sess.run(tf.shape(tf_var))
                assign_ops.append(tf.assign(tf_var, saved_model[key]))
            else:
                print "Not saving", key, "because no TF equivalent"
        sess.run(assign_ops)
        saver.save(sess, save_path=out_path)

        print "The following TF variables were not assigned (excluding Adam vars):"
        print "You should see only 'beta1_power', 'beta2_power' and 'time' variable listed"
        for tf_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            if tf_var.name not in seen and 'Adam' not in tf_var.name:
                print(tf_var.name)
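
The function above copies Theano-era .npz parameters into the matching TensorFlow variables and writes a checkpoint; the seen set plus the final loop over GLOBAL_VARIABLES is what surfaces any TF variable with no Theano counterpart. A hypothetical invocation (paths are made up):

# Convert a Theano model.npz into a TensorFlow checkpoint prefix.
theano_to_tensorflow_model('models/model.npz', 'models/tf/model')
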
Code example #3
File: translate.py  Project: hfxunlp/nematus
    def _load_models(self, process_id, sess):
        """
        Loads models and returns them
        """
        logging.debug("Process '%s' - Loading models\n" % (process_id))

        import tensorflow as tf
        models = []
        for i, options in enumerate(self._options):
            with tf.variable_scope("model%d" % i) as scope:
                model = StandardModel(options)
                saver = init_or_restore_variables(options, sess,
                                                  ensemble_scope=scope)
                models.append(model)

        logging.info("NOTE: Length of translations is capped to {}".format(self._options[0].translation_maxlen))
        return models
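
Each ensemble member is built under its own "model%d" variable scope, so several StandardModel instances can share one graph and session without name collisions. A standalone illustration of that scoping pattern (TF 1.x, toy variables only):

import tensorflow as tf

# Variables created inside different scopes get distinct name prefixes.
with tf.variable_scope("model0"):
    w0 = tf.get_variable("W", shape=[2, 2])
with tf.variable_scope("model1"):
    w1 = tf.get_variable("W", shape=[2, 2])

print(w0.name)  # model0/W:0
print(w1.name)  # model1/W:0
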
Code example #4
File: nmt.py  Project: aquaktus/nematus
def validate_helper(config, sess):
    logging.info('Building model...')
    model = StandardModel(config)
    saver = init_or_restore_variables(config, sess)
    valid_text_iterator = TextIterator(
        source=config.valid_source_dataset,
        target=config.valid_target_dataset,
        source_dicts=config.source_dicts,
        target_dict=config.target_dict,
        batch_size=config.valid_batch_size,
        maxlen=config.maxlen,
        source_vocab_sizes=config.source_vocab_sizes,
        target_vocab_size=config.target_vocab_size,
        shuffle_each_epoch=False,
        sort_by_length=False,  #TODO
        use_factor=(config.factors > 1),
        maxibatch_size=config.maxibatch_size)
    costs = validate(config, sess, valid_text_iterator, model)
    lines = open(config.valid_target_dataset).readlines()
    for cost, line in zip(costs, lines):
        logging.info("{0} {1}".format(cost, line.strip()))
Code example #5
File: nmt.py  Project: aquaktus/nematus
def train(config, sess):
    assert (config.prior_model != None and
            tf.train.checkpoint_exists(os.path.abspath(config.prior_model))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    logging.info('Building model...')
    model = StandardModel(config)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [],
                                  initializer=init,
                                  trainable=False)

    if config.summaryFreq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, model, optimizer, global_step, writer)

    saver, progress = init_or_restore_variables(config, sess, train=True)

    global_step.load(progress.uidx, sess)

    #save model options
    config_as_dict = OrderedDict(sorted(vars(config).items()))
    json.dump(config_as_dict, open('%s.json' % config.saveto, 'w'), indent=2)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            print("")
            print("")
            print("")
            print("########## Source Sents ############")
            print(source_sents)
            print("")
            print("")
            print("")
            print("########## Target Sents ############")
            print(target_sents)
            if len(source_sents[0][0]) != config.factors:
                logging.error(
                    'Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'
                    .format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue
            write_summary_for_this_batch = config.summaryFreq and (
                (progress.uidx % config.summaryFreq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.dispFreq and progress.uidx % config.dispFreq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sampleFreq and progress.uidx % config.sampleFreq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beamFreq and progress.uidx % config.beamFreq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model.beam_search(sess, x_small, x_mask_small,
                                            config.beam_size)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost / len(sample))
                        logging.info(msg)

            if config.validFreq and progress.uidx % config.validFreq == 0:
                costs = validate(config, sess, valid_text_iterator, model)
                # validation loss is mean of normalized sentence log probs
                valid_loss = sum(costs) / len(costs)
                if (len(progress.history_errs) == 0
                        or valid_loss < min(progress.history_errs)):
                    progress.history_errs.append(valid_loss)
                    progress.bad_counter = 0
                    saver.save(sess, save_path=config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_loss)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, model, config,
                                                 valid_text_iterator)
                    need_to_save = (
                        score is not None
                        and (len(progress.valid_script_scores) == 0
                             or score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        save_path = config.saveto + ".best-valid-script"
                        saver.save(sess, save_path=save_path)
                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.saveFreq and progress.uidx % config.saveFreq == 0:
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
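
The training loop above interleaves one optimizer update per minibatch with periodic side tasks (display, sampling, beam search, validation, checkpointing), each gated by a progress.uidx % freq == 0 test. A stripped-down sketch of that scheduling pattern; the task names are illustrative only:

def run_periodic_tasks(uidx, tasks):
    # tasks: list of (frequency, callback) pairs; a frequency of 0 or None
    # disables the task, exactly like the config.*Freq checks above.
    for freq, callback in tasks:
        if freq and uidx % freq == 0:
            callback()

# Example wiring (callback names are made up):
# run_periodic_tasks(progress.uidx, [
#     (config.dispFreq, report_speed),
#     (config.sampleFreq, greedy_sample),
#     (config.validFreq, run_validation),
#     (config.saveFreq, checkpoint),
# ])
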
Code example #6
File: nmt.py  Project: aquaktus/nematus
    config.source_dicts = config.dictionaries[:-1]
    config.source_vocab_sizes = vocab_sizes[:-1]
    config.target_dict = config.dictionaries[-1]
    config.target_vocab_size = vocab_sizes[-1]

    # set the model version
    config.model_version = 0.2
    config.theano_compat = False

    return config
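
The slicing above relies on the convention that config.dictionaries lists one vocabulary per source factor followed by the target vocabulary, so the last entry is always the target dictionary. A worked illustration with made-up file names and sizes:

dictionaries = ['vocab.src.f0.json', 'vocab.src.f1.json', 'vocab.trg.json']
vocab_sizes = [40000, 1000, 40000]

source_dicts = dictionaries[:-1]        # ['vocab.src.f0.json', 'vocab.src.f1.json']
source_vocab_sizes = vocab_sizes[:-1]   # [40000, 1000]
target_dict = dictionaries[-1]          # 'vocab.trg.json'
target_vocab_size = vocab_sizes[-1]     # 40000
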


if __name__ == "__main__":

    # set up logging
    level = logging.INFO
    logging.basicConfig(level=level, format='%(levelname)s: %(message)s')

    config = parse_args()
    logging.info(config)
    with tf.Session() as sess:
        if config.translate_valid:
            logging.info('Building model...')
            model = StandardModel(config)
            saver = init_or_restore_variables(config, sess)
            translate(sess, model, config)
        elif config.run_validation:
            validate_helper(config, sess)
        else:
            train(config, sess)