Code Example #1
File: transformer.py Project: bariluz93/nematus
    def __init__(self, config, consts_config_str, embedding_layer, training,
                 name):
        # Set attributes
        self.consts_config_str = consts_config_str
        self.config = config
        self.embedding_layer = embedding_layer
        self.training = training
        self.name = name

        # Track layers
        self.encoder_stack = dict()
        self.is_final_layer = False

        # Create nodes
        self._build_graph()
        _, _, self.num_to_source, self.num_to_target = util.load_dictionaries(
            config)
        self.USE_DEBIASED, _, self.COLLECT_EMBEDDING_TABLE, _ = get_basic_configurations(
            self.consts_config_str)
        if self.USE_DEBIASED:
            debiasManager = DebiasManager.get_manager_instance(
                self.consts_config_str)
            self.embedding_matrix = tf.cast(
                tf.convert_to_tensor(debiasManager.debias_embedding_table()),
                dtype=tf.float32)
        else:
            self.embedding_matrix = None
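Note: the constructor above swaps in an externally debiased embedding table when USE_DEBIASED is set. A minimal sketch of how such a table might be consumed at lookup time (the wiring and the get_embeddings name are assumptions, not the nematus API):

import tensorflow as tf

def embed_tokens(token_ids, embedding_layer, embedding_matrix=None):
    # Prefer the debiased table built in __init__, if any; otherwise fall
    # back to the model's trained embedding layer.
    if embedding_matrix is not None:
        return tf.nn.embedding_lookup(embedding_matrix, token_ids)
    return embedding_layer.get_embeddings(token_ids)  # hypothetical method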
Code Example #2
File: server_translator.py Project: tuyu95/nematus
    def _load_model_options(self):
        """
        Loads config options for each model.
        """

        self._options = []
        for model in self._models:
            config = load_config_from_json_file(model)
            setattr(config, 'reload', model)
            self._options.append(config)

        _, _, _, self._num_to_target = util.load_dictionaries(self._options[0])
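Note: load_config_from_json_file reads the options that were saved next to the model at training time. A rough equivalent, under the assumption that the options live in a "<model>.json" sidecar (the real loader may differ):

import argparse
import json

def load_config_from_json_file(model_path):
    # Parse the sidecar file and expose its keys as attributes.
    with open(model_path + '.json', 'r', encoding='utf-8') as f:
        return argparse.Namespace(**json.load(f))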
Code Example #3
    def _load_model_options(self):
        """
        Loads config options for each model.
        """

        self._options = []
        for model in self._models:
            config = util.load_config(model)
            # backward compatibility
            fill_options(config)
            config['reload'] = model
            self._options.append(argparse.Namespace(**config))

        _, _, _, self._num_to_target = util.load_dictionaries(self._options[0])
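Note: Example #3 is the older variant of the same method: it loads a raw options dict and back-fills keys that newer nematus versions expect. A sketch of the idea behind fill_options, with an illustrative (assumed) subset of defaults:

def fill_options(config_dict):
    # Older checkpoints may lack options that were added later,
    # so fill in defaults in place.
    defaults = {'factors': 1, 'translation_maxlen': 200}  # illustrative only
    for key, value in defaults.items():
        config_dict.setdefault(key, value)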
Code Example #4
def translate_file(input_file,
                   output_file,
                   session,
                   models,
                   configs,
                   beam_size=12,
                   nbest=False,
                   minibatch_size=80,
                   maxibatch_size=20,
                   normalization_alpha=1.0):
    """Translates a source file using a translation model (or ensemble).

    Args:
        input_file: file object from which source sentences will be read.
        output_file: file object to which translations will be written.
        session: TensorFlow session.
        models: list of model objects to use for beam search.
        configs: model configs.
        beam_size: beam width.
        nbest: if True, produce n-best output with scores; otherwise 1-best.
        minibatch_size: minibatch size in sentences.
        maxibatch_size: number of minibatches to read and sort, pre-translation.
        normalization_alpha: alpha parameter for length normalization.
    """
    def translate_maxibatch(maxibatch, model_set, num_to_target,
                            num_prev_translated):
        """Translates an individual maxibatch.

        Args:
            maxibatch: a list of sentences.
            model_set: an InferenceModelSet object.
            num_to_target: dictionary mapping target vocabulary IDs to strings.
            num_prev_translated: the number of previously translated sentences.
        """

        # Sort the maxibatch by length and split into minibatches.
        try:
            minibatches, idxs = util.read_all_lines(configs[0], maxibatch,
                                                    minibatch_size)
        except exception.Error as x:
            logging.error(x.msg)
            sys.exit(1)

        # Translate the minibatches and store the resulting beam (i.e.
        # translations and scores) for each sentence.
        beams = []
        for x in minibatches:
            y_dummy = numpy.zeros(shape=(len(x), 1))
            x, x_mask, _, _ = util.prepare_data(x,
                                                y_dummy,
                                                configs[0].factors,
                                                maxlen=None)
            sample = model_set.decode(session=session,
                                      x=x,
                                      x_mask=x_mask,
                                      beam_size=beam_size,
                                      normalization_alpha=normalization_alpha)
            beams.extend(sample)
            num_translated = num_prev_translated + len(beams)
            logging.info('Translated {} sents'.format(num_translated))

        # Put beams into the same order as the input maxibatch.
        tmp = numpy.array(beams, dtype=object)
        ordered_beams = tmp[idxs.argsort()]

        # Write the translations to the output file.
        for i, beam in enumerate(ordered_beams):
            if nbest:
                num = num_prev_translated + i
                for sent, cost in beam:
                    translation = util.seq2words(sent, num_to_target)
                    line = "{} ||| {} ||| {}\n".format(num, translation,
                                                       str(cost))
                    output_file.write(line)
            else:
                best_hypo, cost = beam[0]
                line = util.seq2words(best_hypo, num_to_target) + '\n'
                output_file.write(line)

    _, _, _, num_to_target = util.load_dictionaries(configs[0])
    model_set = InferenceModelSet(models, configs)

    logging.info("NOTE: Length of translations is capped to {}".format(
        configs[0].translation_maxlen))

    start_time = time.time()

    num_translated = 0
    maxibatch = []
    while True:
        line = input_file.readline()
        if line == "":
            if len(maxibatch) > 0:
                translate_maxibatch(maxibatch, model_set, num_to_target,
                                    num_translated)
                num_translated += len(maxibatch)
            break
        maxibatch.append(line)
        if len(maxibatch) == (maxibatch_size * minibatch_size):
            translate_maxibatch(maxibatch, model_set, num_to_target,
                                num_translated)
            num_translated += len(maxibatch)
            maxibatch = []

    duration = time.time() - start_time
    logging.info('Translated {} sents in {} sec. Speed {} sents/sec'.format(
        num_translated, duration, num_translated / duration))
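Note: the maxibatch mechanism reads maxibatch_size * minibatch_size lines at a time so sentences can be length-sorted before batching, which cuts padding waste; the original order is then restored with idxs.argsort(). A standalone sketch of that pattern (names here are illustrative, not the nematus API):

import numpy

def sort_and_split(lines, minibatch_size):
    # Sort by token count so each minibatch holds similar-length sentences.
    idxs = numpy.argsort([len(line.split()) for line in lines])
    sorted_lines = [lines[i] for i in idxs]
    minibatches = [sorted_lines[i:i + minibatch_size]
                   for i in range(0, len(sorted_lines), minibatch_size)]
    return minibatches, idxs

# After translating in sorted order, results[idxs.argsort()] restores the
# original line order, mirroring tmp[idxs.argsort()] above.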
Code Example #5
File: train.py Project: ykyaol7/nematus
def train(config, sess):
    assert (config.prior_model is not None
            and tf.train.checkpoint_exists(os.path.abspath(config.prior_model))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU

    num_gpus = len(tf_utils.get_available_gpus())
    num_replicas = max(1, num_gpus)

    if config.loss_function == 'MRT':
        assert config.gradient_aggregation_steps == 1
        assert config.max_sentences_per_device == 0, "MRT mode does not support sentence-based split"
        if config.max_tokens_per_device != 0:
            assert config.samplesN * config.maxlen <= config.max_tokens_per_device, \
                "need to make sure candidates of a sentence can be fed into the model"
        else:
            assert num_replicas == 1, "MRT mode does not support sentence-based split"
            assert config.samplesN * config.maxlen <= config.token_batch_size, \
                "need to make sure candidates of a sentence can be fed into the model"

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i>0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init, trainable=False)

    if config.learning_schedule == "constant":
        schedule = learning_schedule.ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = learning_schedule.TransformerSchedule(
            global_step=global_step,
            dim=config.state_size,
            warmup_steps=config.warmup_steps)
    elif config.learning_schedule == "warmup-plateau-decay":
        schedule = learning_schedule.WarmupPlateauDecaySchedule(
            global_step=global_step,
            peak_learning_rate=config.learning_rate,
            warmup_steps=config.warmup_steps,
            plateau_steps=config.plateau_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=schedule.learning_rate,
                                           beta1=config.adam_beta1,
                                           beta2=config.adam_beta2,
                                           epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    if config.exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(config.exponential_smoothing)

    saver, progress = model_loader.init_or_restore_variables(
        config, sess, train=True)

    global_step.load(progress.uidx, sess)

    if config.sample_freq:
        random_sampler = RandomSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=1)

    if config.beam_freq or config.valid_script is not None:
        beam_search_sampler = BeamSearchSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=config.beam_size)

    #save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    # limit training to one more epoch when printing per-token probabilities
    if config.print_per_token_pro:
        config.max_epochs = progress.eidx+1
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error('Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'.format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info('Minibatch with zero sample under length {0}'.format(config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and ((progress.uidx % config.summary_freq == 0) or (config.finish_after and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            output = updater.update(
                sess, x_in, x_mask_in, y_in, y_mask_in, num_to_target,
                write_summary_for_this_batch)

            if not config.print_per_token_pro:
                total_loss += output
            else:
                # append per-token probabilities to the output file
                with open(config.print_per_token_pro, 'a') as f:
                    for pro in output:
                        f.write(str(pro) + '\n')

            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            # Update the smoothed version of the model variables.
            # To reduce the performance overhead, we only do this once every
            # N steps (the smoothing factor is adjusted accordingly).
            if config.exponential_smoothing > 0.0 and progress.uidx % smoothing.update_frequency == 0:
                sess.run(fetches=smoothing.update_ops)

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info('{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'.format(disp_time, progress.eidx, progress.uidx, total_loss/n_words, n_words/duration, n_sents/duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, random_sampler, x_small, x_mask_small,
                    config.translation_maxlen, 0.0)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss[0][0], num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, beam_search_sampler, x_small, x_mask_small,
                    config.translation_maxlen, config.normalization_alpha)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost/len(sample))
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                if config.exponential_smoothing > 0.0:
                    sess.run(fetches=smoothing.swap_ops)
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                    sess.run(fetches=smoothing.swap_ops)
                else:
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                if (len(progress.history_errs) == 0 or
                    valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    if config.exponential_smoothing > 0.0:
                        sess.run(fetches=smoothing.swap_ops)
                        score = validate_with_script(sess, beam_search_sampler)
                        sess.run(fetches=smoothing.swap_ops)
                    else:
                        score = validate_with_script(sess, beam_search_sampler)
                    need_to_save = (score is not None and
                        (len(progress.valid_script_scores) == 0 or
                         score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
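Note: the "transformer" learning-rate schedule selected above follows Vaswani et al. (2017): linear warm-up for warmup_steps updates, then inverse-square-root decay. A minimal sketch (the learning_schedule module may differ in detail):

def transformer_lr(step, dim, warmup_steps):
    # lr = dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    step = max(step, 1)  # avoid division by zero at step 0
    return dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)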
Code Example #6
def translate_file(input_file,
                   output_file,
                   session,
                   sampler,
                   config,
                   max_translation_len,
                   normalization_alpha,
                   nbest=False,
                   minibatch_size=80,
                   maxibatch_size=20,
                   strategy='biased_beam_search',
                   mask='0',
                   extended_translations=None):
    """Translates a source file using a RandomSampler or BeamSearchSampler.

    Args:
        input_file: file object from which source sentences will be read.
        output_file: file object to which translations will be written.
        session: TensorFlow session.
        sampler: BeamSearchSampler or RandomSampler object.
        config: model config.
        max_translation_len: integer specifying maximum translation length.
        normalization_alpha: float specifying alpha parameter for length
            normalization.
        nbest: if True, produce n-best output with scores; otherwise 1-best.
        minibatch_size: minibatch size in sentences.
        maxibatch_size: number of minibatches to read and sort, pre-translation.
        strategy: decoding strategy; 'biased_beam_search' enables the
            prefix-reuse path, any other value the plain path.
        mask: number of trailing tokens to drop from each hypothesis, or
            'dynamic' to keep only the prefix agreed with the extended
            translations.
        extended_translations: '+'-separated paths of files holding candidate
            translations for the 'dynamic' mask.
    """
    def translate_maxibatch(maxibatch,
                            num_to_target,
                            num_prev_translated,
                            mask=0):
        """Translates an individual maxibatch.

        Args:
            maxibatch: a list of sentences.
            num_to_target: dictionary mapping target vocabulary IDs to strings.
            num_prev_translated: the number of previously translated sentences.
            mask: number of trailing tokens to drop from each hypothesis.
        """

        # Sort the maxibatch by length and split into minibatches.
        try:
            minibatches, idxs = util.read_all_lines(config, maxibatch,
                                                    minibatch_size)
        except exception.Error as x:
            logging.error(x.msg)
            sys.exit(1)

        # Translate the minibatches and store the resulting beam (i.e.
        # translations and scores) for each sentence.
        beams = []
        for x in minibatches:
            y_dummy = numpy.zeros(shape=(len(x), 1))
            x, x_mask, _, _ = util.prepare_data(x,
                                                y_dummy,
                                                config.factors,
                                                maxlen=None)
            sample = translate_batch(session, sampler, x, x_mask,
                                     max_translation_len, normalization_alpha)
            beams.extend(sample)
            num_translated = num_prev_translated + len(beams)
            logging.info('Translated {} sents'.format(num_translated))

        # Put beams into the same order as the input maxibatch.
        tmp = numpy.array(beams, dtype=object)
        ordered_beams = tmp[idxs.argsort()]

        # Write the translations to the output file.
        for i, beam in enumerate(ordered_beams):
            if nbest:
                num = num_prev_translated + i
                for sent, cost in beam:
                    translation = util.seq2words(sent, num_to_target)
                    line = "{} ||| {} ||| {}\n".format(num, translation,
                                                       str(cost))
                    output_file.write(line)
            else:
                best_hypo, cost = beam[0]
                eos_idx = (list(best_hypo).index(0)
                           if 0 in best_hypo else len(best_hypo))
                best_hypo = best_hypo[:eos_idx]
                best_hypo = (best_hypo[:len(best_hypo) - mask]
                             if len(best_hypo) > mask else [])
                best_hypo = list(best_hypo) + [0]
                line = util.seq2words(best_hypo, num_to_target) + '\n'
                output_file.write(line)

    def translate_maxibatch_slt(maxibatch,
                                num_to_target,
                                target_to_num,
                                num_prev_translated,
                                last_translation,
                                mask='0',
                                trans=[],
                                last_line=''):
        """Translates an individual maxibatch.

        Args:
            maxibatch: a list of sentences.
            num_to_target: dictionary mapping target vocabulary IDs to strings.
            target_to_num: dictionary mapping target strings to vocabulary IDs.
            num_prev_translated: the number of previously translated sentences.
            last_translation: token IDs of the previous translation, reused to
                bias the new search.
        """

        # Sort the maxibatch by length and split into minibatches.
        try:
            minibatches, idxs = util.read_all_lines(config, maxibatch,
                                                    minibatch_size)
        except exception.Error as x:
            logging.error(x.msg)
            sys.exit(1)

        # Translate the minibatches and store the resulting beam (i.e.
        # translations and scores) for each sentence.
        beams = []
        for x in minibatches:
            y_dummy = numpy.zeros(shape=(len(x), 1))
            x, x_mask, _, _ = util.prepare_data(x,
                                                y_dummy,
                                                config.factors,
                                                maxlen=None)
            sample = translate_batch_slt(session,
                                         sampler,
                                         x,
                                         x_mask,
                                         max_translation_len,
                                         normalization_alpha,
                                         last_translation=last_translation)

            beams.extend(sample)
            num_translated = num_prev_translated + len(beams)
            logging.info('Translated {} sents'.format(num_translated))

        # Put beams into the same order as the input maxibatch.
        tmp = numpy.array(beams, dtype=object)
        ordered_beams = tmp[idxs.argsort()]
        # Write the translations to the output file.
        for i, beam in enumerate(ordered_beams):
            if nbest:
                num = num_prev_translated + i
                for sent, cost in beam:
                    translation = util.seq2words(sent, num_to_target)
                    line = "{} ||| {} ||| {}\n".format(num, translation,
                                                       str(cost))
                    output_file.write(line)
            else:
                best_hypo, cost = beam[0]
                eos_idx = (list(best_hypo).index(0)
                           if 0 in best_hypo else len(best_hypo))
                best_hypo = best_hypo[:eos_idx]
                if mask != 'dynamic':
                    mask = int(mask)
                    best_hypo = (best_hypo[:len(best_hypo) - mask]
                                 if len(best_hypo) > mask else [])
                    best_hypo = list(best_hypo) + [0]
                    line = util.seq2words(best_hypo, num_to_target) + '\n'
                else:
                    line = util.seq2words(best_hypo, num_to_target, join=False)
                    items = util.LCP(trans, line)
                    best_hypo = [target_to_num[item] for item in items] + [0]
                    line = util.seq2words(best_hypo, num_to_target) + '\n'
                output_file.write(line)
                return [1] + list(best_hypo[:-1]), line

    prefix_lines = input_file.readlines()
    input_file.seek(0, 0)
    prefix_mask = []
    last_prefix = ''
    for prefix_line in prefix_lines:
        if not last_prefix:
            last_prefix = prefix_line
            continue
        if last_prefix[:-1] in prefix_line:
            prefix_mask.append(0)
        else:
            prefix_mask.append(1)
        last_prefix = prefix_line
    prefix_mask.append(1)

    _, target_to_num, _, num_to_target = util.load_dictionaries(config)

    logging.info("NOTE: Length of translations is capped to {}".format(
        max_translation_len))

    start_time = time.time()

    num_translated = 0
    maxibatch = []

    extensions = []
    if extended_translations:
        for extension_file in extended_translations.split('+'):
            with open(extension_file, 'r', encoding='utf-8') as f:
                extensions.append(f.readlines())

    # a very naive implementation of biased beam search that does not allow minibatch translation
    if strategy == 'biased_beam_search':
        assert minibatch_size == 1
        assert maxibatch_size == 1
        last_translation = [1]
        last_line = ""
        sent_id = 0
        while True:
            line = input_file.readline()
            if line == "":
                if len(maxibatch) > 0:
                    if (not last_line or len(last_line) > len(line)
                            or last_line != line[:len(last_line)]):
                        last_translation = [1]
                    translate_maxibatch_slt(maxibatch, num_to_target,
                                            target_to_num, num_translated,
                                            last_translation)
                    num_translated += len(maxibatch)
                break
            maxibatch.append(line)
            if len(maxibatch) == (maxibatch_size * minibatch_size):
                if (not last_line or len(last_line) > len(line)
                        or last_line[:-1] != line[:len(last_line) - 1]):
                    last_translation = [1]
                    last_textline = '\n'
                if not prefix_mask[sent_id]:
                    last_translation, last_textline = translate_maxibatch_slt(
                        maxibatch,
                        num_to_target,
                        target_to_num,
                        num_translated,
                        last_translation,
                        mask=mask,
                        trans=[extension[sent_id] for extension in extensions],
                        last_line=last_textline)
                else:
                    last_translation, _ = translate_maxibatch_slt(
                        maxibatch,
                        num_to_target,
                        target_to_num,
                        num_translated,
                        last_translation,
                        mask='0')
                num_translated += len(maxibatch)
                maxibatch = []
            last_line = line
            sent_id += 1
    else:
        sent_id = 0
        while True:
            line = input_file.readline()
            if line == "":
                if len(maxibatch) > 0:
                    translate_maxibatch(maxibatch, num_to_target,
                                        num_translated)
                    num_translated += len(maxibatch)
                break
            maxibatch.append(line)
            if len(maxibatch) == (maxibatch_size * minibatch_size):
                # translate_maxibatch takes an integer mask; the 'dynamic'
                # strategy is only handled by the biased_beam_search path.
                mask_n = int(mask) if mask != 'dynamic' else 0
                if not prefix_mask[sent_id]:
                    translate_maxibatch(maxibatch, num_to_target,
                                        num_translated, mask=mask_n)
                else:
                    translate_maxibatch(maxibatch, num_to_target,
                                        num_translated, mask=0)
                num_translated += len(maxibatch)
                maxibatch = []
            sent_id += 1

    duration = time.time() - start_time
    logging.info('Translated {} sents in {} sec. Speed {} sents/sec'.format(
        num_translated, duration, num_translated / duration))

    output_file.flush()
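Note: with mask='dynamic', the code keeps only the hypothesis prefix that agrees with the extended translations via util.LCP, which is not shown on this page. A sketch of the presumed longest-common-prefix contract (an assumption):

def lcp(references, hypothesis):
    # Keep tokens, left to right, on which every sequence agrees.
    sequences = [ref.split() for ref in references] + [list(hypothesis)]
    prefix = []
    for tokens in zip(*sequences):
        if len(set(tokens)) != 1:
            break
        prefix.append(tokens[0])
    return prefix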
Code Example #7
def train(config, sess):
    assert (config.prior_model is not None
            and tf.train.checkpoint_exists(os.path.abspath(config.prior_model))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU

    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [],
                                  initializer=init,
                                  trainable=False)

    if config.learning_schedule == "constant":
        schedule = ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = TransformerSchedule(global_step=global_step,
                                       dim=config.state_size,
                                       warmup_steps=config.warmup_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=schedule.learning_rate,
            beta1=config.adam_beta1,
            beta2=config.adam_beta2,
            epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    saver, progress = model_loader.init_or_restore_variables(config,
                                                             sess,
                                                             train=True)

    global_step.load(progress.uidx, sess)

    # Use an InferenceModelSet to abstract over model types for sampling and
    # beam search. Multi-GPU sampling and beam search are not currently
    # supported, so we just use the first replica.
    model_set = inference.InferenceModelSet([replicas[0]], [config])

    #save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error(
                    'Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'
                    .format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and (
                (progress.uidx % config.summary_freq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model_set.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model_set.beam_search(
                    sess,
                    x_small,
                    x_mask_small,
                    config.beam_size,
                    normalization_alpha=config.normalization_alpha)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost / len(sample))
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                valid_ce = validate(sess, replicas[0], config,
                                    valid_text_iterator)
                if (len(progress.history_errs) == 0
                        or valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, replicas[0], config)
                    need_to_save = (
                        score is not None
                        and (len(progress.valid_script_scores) == 0
                             or score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
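Note: the validation block implements patience-based early stopping: whenever valid_ce fails to improve on the best value seen so far, bad_counter grows, and training stops once it exceeds config.patience. A compact restatement (illustrative only):

def early_stop_update(history_errs, valid_ce, bad_counter, patience):
    # Returns the new bad_counter and whether training should stop.
    history_errs.append(valid_ce)
    if valid_ce <= min(history_errs):
        return 0, False  # new best: reset the patience counter
    bad_counter += 1
    return bad_counter, bad_counter > patience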
Code Example #8
File: translate_utils.py Project: bariluz93/nematus
def translate_file(input_file,
                   output_file,
                   session,
                   sampler,
                   config,
                   max_translation_len,
                   normalization_alpha,
                   consts_config_str,
                   nbest=False,
                   minibatch_size=80,
                   maxibatch_size=20):
    """Translates a source file using a RandomSampler or BeamSearchSampler.

    Args:
        input_file: file object from which source sentences will be read.
        output_file: file object to which translations will be written.
        session: TensorFlow session.
        sampler: BeamSearchSampler or RandomSampler object.
        config: model config.
        max_translation_len: integer specifying maximum translation length.
        normalization_alpha: float specifying alpha parameter for length
            normalization.
        nbest: if True, produce n-best output with scores; otherwise 1-best.
        minibatch_size: minibatch size in sentences.
        maxibatch_size: number of minibatches to read and sort, pre-translation.
    """
    def translate_maxibatch(maxibatch, num_to_target, num_prev_translated):
        """Translates an individual maxibatch.

        Args:
            maxibatch: a list of sentences.
            num_to_target: dictionary mapping target vocabulary IDs to strings.
            num_prev_translated: the number of previously translated sentences.
        """
        # Sort the maxibatch by length and split into minibatches.
        try:
            minibatches, idxs = util.read_all_lines(config, maxibatch,
                                                    minibatch_size)
        except exception.Error as x:
            logging.error(x.msg)
            sys.exit(1)

        # Translate the minibatches and store the resulting beam (i.e.
        # translations and scores) for each sentence.
        beams = []
        for x in minibatches:
            y_dummy = numpy.zeros(shape=(len(x), 1))
            x, x_mask, _, _ = util.prepare_data(x,
                                                y_dummy,
                                                config.factors,
                                                maxlen=None)
            sample = translate_batch(session, sampler, x, x_mask,
                                     max_translation_len, normalization_alpha)
            beams.extend(sample)
            num_translated = num_prev_translated + len(beams)
            logging.info('Translated {} sents'.format(num_translated))

        # Put beams into the same order as the input maxibatch.
        tmp = numpy.array(beams, dtype=object)
        ordered_beams = tmp[idxs.argsort()]

        # Write the translations to the output file.
        for i, beam in enumerate(ordered_beams):
            if nbest:
                num = num_prev_translated + i
                for sent, cost in beam:
                    translation = util.seq2words(sent, num_to_target)
                    line = "{} ||| {} ||| {}\n".format(num, translation,
                                                       str(cost))
                    output_file.write(line)
            else:
                best_hypo, cost = beam[0]
                line = util.seq2words(best_hypo, num_to_target) + '\n'
                output_file.write(line)

    _, _, COLLECT_EMBEDDING_TABLE, _ = get_basic_configurations(
        consts_config_str)
    _, _, _, num_to_target = util.load_dictionaries(config)

    logging.info("NOTE: Length of translations is capped to {}".format(
        max_translation_len))

    start_time = time.time()

    num_translated = 0
    maxibatch = []
    line_num = 0
    while True:
        if COLLECT_EMBEDDING_TABLE and line_num > 1:
            break

        line = input_file.readline()
        line_num += 1
        if line == "":
            if len(maxibatch) > 0:
                translate_maxibatch(maxibatch, num_to_target, num_translated)
                num_translated += len(maxibatch)
            break
        maxibatch.append(line)
        if len(maxibatch) == (maxibatch_size * minibatch_size):
            translate_maxibatch(maxibatch, num_to_target, num_translated)
            num_translated += len(maxibatch)
            maxibatch = []

    duration = time.time() - start_time
    logging.info('Translated {} sents in {} sec. Speed {} sents/sec'.format(
        num_translated, duration, num_translated / duration))
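Note: a hypothetical invocation of the variant above; the session, sampler, config and consts_config_str objects are assumed to have been built elsewhere:

with open('input.txt', 'r', encoding='utf-8') as fin, \
     open('output.txt', 'w', encoding='utf-8') as fout:
    translate_file(fin, fout, session, sampler, config,
                   max_translation_len=config.translation_maxlen,
                   normalization_alpha=1.0,
                   consts_config_str=consts_config_str)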
Code Example #9
File: inference.py Project: rsennrich/nematus
def translate_file(input_file, output_file, session, models, configs,
                   beam_size=12, nbest=False, minibatch_size=80,
                   maxibatch_size=20, normalization_alpha=1.0):
    """Translates a source file using a translation model (or ensemble).

    Args:
        input_file: file object from which source sentences will be read.
        output_file: file object to which translations will be written.
        session: TensorFlow session.
        models: list of model objects to use for beam search.
        configs: model configs.
        beam_size: beam width.
        nbest: if True, produce n-best output with scores; otherwise 1-best.
        minibatch_size: minibatch size in sentences.
        maxibatch_size: number of minibatches to read and sort, pre-translation.
        normalization_alpha: alpha parameter for length normalization.
    """

    def translate_maxibatch(maxibatch, model_set, num_to_target,
                            num_prev_translated):
        """Translates an individual maxibatch.

        Args:
            maxibatch: a list of sentences.
            model_set: an InferenceModelSet object.
            num_to_target: dictionary mapping target vocabulary IDs to strings.
            num_prev_translated: the number of previously translated sentences.
        """

        # Sort the maxibatch by length and split into minibatches.
        try:
            minibatches, idxs = util.read_all_lines(configs[0], maxibatch,
                                                    minibatch_size)
        except exception.Error as x:
            logging.error(x.msg)
            sys.exit(1)

        # Translate the minibatches and store the resulting beam (i.e.
        # translations and scores) for each sentence.
        beams = []
        for x in minibatches:
            y_dummy = numpy.zeros(shape=(len(x),1))
            x, x_mask, _, _ = util.prepare_data(x, y_dummy, configs[0].factors,
                                                maxlen=None)
            sample = model_set.beam_search(
                session=session,
                x=x,
                x_mask=x_mask,
                beam_size=beam_size,
                normalization_alpha=normalization_alpha)
            beams.extend(sample)
            num_translated = num_prev_translated + len(beams)
            logging.info('Translated {} sents'.format(num_translated))

        # Put beams into the same order as the input maxibatch.
        tmp = numpy.array(beams, dtype=object)
        ordered_beams = tmp[idxs.argsort()]

        # Write the translations to the output file.
        for i, beam in enumerate(ordered_beams):
            if nbest:
                num = num_prev_translated + i
                for sent, cost in beam:
                    translation = util.seq2words(sent, num_to_target)
                    line = "{} ||| {} ||| {}\n".format(num, translation,
                                                       str(cost))
                    output_file.write(line)
            else:
                best_hypo, cost = beam[0]
                line = util.seq2words(best_hypo, num_to_target) + '\n'
                output_file.write(line)

    _, _, _, num_to_target = util.load_dictionaries(configs[0])
    model_set = InferenceModelSet(models, configs)

    logging.info("NOTE: Length of translations is capped to {}".format(
        configs[0].translation_maxlen))

    start_time = time.time()

    num_translated = 0
    maxibatch = []
    while True:
        line = input_file.readline()
        if line == "":
            if len(maxibatch) > 0:
                translate_maxibatch(maxibatch, model_set, num_to_target,
                                    num_translated)
                num_translated += len(maxibatch)
            break
        maxibatch.append(line)
        if len(maxibatch) == (maxibatch_size * minibatch_size):
            translate_maxibatch(maxibatch, model_set, num_to_target,
                                num_translated)
            num_translated += len(maxibatch)
            maxibatch = []

    duration = time.time() - start_time
    logging.info('Translated {} sents in {} sec. Speed {} sents/sec'.format(
        num_translated, duration, num_translated/duration))
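Note: Example #9 is the upstream rsennrich/nematus version. The normalization_alpha it forwards to beam search controls length normalization of hypothesis scores; a common form is shown below (treat the exact formula used by nematus as an assumption):

def normalized_cost(cost, length, alpha):
    # alpha = 0 disables normalization; alpha = 1 gives per-token cost.
    return cost / (length ** alpha) if alpha > 0.0 else cost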
Code Example #10
File: train.py Project: byccln/nematus_glove
def train(config, sess):
    ####################################################
    assert (config.prior_model is not None
            and tf.train.checkpoint_exists(os.path.abspath(config.prior_model))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU

    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [],
                                  initializer=init,
                                  trainable=False)

    if config.learning_schedule == "constant":
        schedule = ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = TransformerSchedule(global_step=global_step,
                                       dim=config.state_size,
                                       warmup_steps=config.warmup_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=schedule.learning_rate,
            beta1=config.adam_beta1,
            beta2=config.adam_beta2,
            epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    saver, progress = model_loader.init_or_restore_variables(config,
                                                             sess,
                                                             train=True)

    ############################################################
    # added: pre-training
    if config.pretrain:
        logging.info("Start pre-training")
        # pre-training hyperparameters
        pre_batch_size = 1000
        epochs = 20
        pre_learning_rate = 0.001
        pre_optimizer = tf.train.GradientDescentOptimizer(
            pre_learning_rate).minimize(replicas[0].loss_pre_train)
        # load pre-training data and the associated dictionaries
        gvocab, gvectors = util.pre_load_data(config.pretrain_vocab,
                                              config.pretrain_vectors)
        pre_vocab_list = list(gvocab.keys())
        # oversampling
        pre_train_list = []
        with open('/media/ntfs-3/EXP/MULTI/mix/zh-en/data3/glove/vocab.txt',
                  'r',
                  encoding='utf-8') as f:
            for line in f:
                k, v = line.strip().split()
                pre_train_list.extend([k] * int(v))
        utf8_dict = json.load(
            open(config.source_dicts[0], 'r', encoding='utf-8'))
        embedding_list = []
        # start training
        for i in range(epochs):
            logging.info("epoch:{}".format(i))
            if i == epochs - 1:
                source_x, source_y, _vocab = util.get_data(pre_vocab_list,
                                                           pre_batch_size,
                                                           gvocab,
                                                           gvectors,
                                                           utf8_dict,
                                                           shuffle=False)
            else:
                source_x, source_y, _vocab = util.get_data(pre_train_list,
                                                           pre_batch_size,
                                                           gvocab,
                                                           gvectors,
                                                           utf8_dict,
                                                           shuffle=True)
            for idx, [s_x, s_y] in enumerate(zip(source_x, source_y)):
                assert len(s_x) == len(s_y), "{}, {}".format(
                    len(s_x), len(s_y))
                sx, sy = util.pre_prepare_data(s_x, s_y)
                feed_dict = {}
                feed_dict[replicas[0].pre_inputs.x] = sx
                feed_dict[replicas[0].pre_inputs.y] = sy
                _, loss, embedding = sess.run([
                    pre_optimizer, replicas[0].loss_pre_train,
                    replicas[0].pre_embedding
                ],
                                              feed_dict=feed_dict)
                if idx % 100 == 0:
                    logging.info("loss:{}".format(loss))
                if i == epochs - 1:
                    embedding_list.append(embedding)
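        # The last epoch iterates over the vocabulary without shuffling, so
        # the collected embeddings line up row by row with pre_vocab_list.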
        assert _vocab == pre_vocab_list
        emb = numpy.concatenate(embedding_list)
        numpy.save("pre_emb/pre_emb.npy", emb)
        with open("pre_emb/vocab", "w", encoding="utf-8") as f:
            f.write("\n".join(pre_vocab_list))
        # t-SNE visualization
        tsne = util.get_tsne(emb, "pre_emb/tsne.npy")
        gtsne = numpy.load(config.pretrain_tsne)
        #util.plot_tsne(_vocab, tsne, gvocab, gtsne, top=20)
        #exit(0)
    ##################################################################################

    global_step.load(progress.uidx, sess)

    # Use an InferenceModelSet to abstract over model types for sampling and
    # beam search. Multi-GPU sampling and beam search are not currently
    # supported, so we just use the first replica.
    model_set = inference.InferenceModelSet([replicas[0]], [config])

    #save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for pre_source_sents, source_sents, target_sents in text_iterator:
            #if len(source_sents[0][0]) != config.factors:
            #    logging.error('Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'.format(config.factors, len(source_sents[0][0])))
            #    sys.exit(1)

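            # Unlike upstream nematus (code example #11 below), prepare_data
            # here also consumes the pre-training source sentences and
            # returns an extra batch px_in.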
            px_in, x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents,
                target_sents,
                config.factors,
                pre_source_sents,
                maxlen=None)

            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue

            write_summary_for_this_batch = config.summary_freq and (
                (progress.uidx % config.summary_freq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            (factors, seqLen, uLen, batch_size) = x_in.shape

            loss = updater.update(sess, px_in, x_in, x_mask_in, y_in,
                                  y_mask_in, write_summary_for_this_batch)

            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small, x_mask_small, y_small = \
                    x_in[:, :, :, :10], x_mask_in[:, :, :10], y_in[:, :10]
                samples = model_set.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    #source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    #logging.info('SOURCE: {}'.format(source))
                    #logging.info('SOURCE: {}'.format(xx))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small, x_mask_small, y_small = \
                    x_in[:, :, :, :10], x_mask_in[:, :, :10], y_in[:, :10]
                samples = model_set.beam_search(
                    sess,
                    x_small,
                    x_mask_small,
                    config.beam_size,
                    normalization_alpha=config.normalization_alpha)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    #source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    #logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
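                        # NB: len(sample) below is the character length of the
                        # detokenized string, not the number of tokens.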
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost / len(sample))
                        logging.info(msg)

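            # Periodic validation: checkpoint whenever the validation
            # cross-entropy reaches a new minimum, and trigger early stopping
            # after more than config.patience consecutive non-improving
            # validations.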
            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                valid_ce = validate(sess, replicas[0], config,
                                    valid_text_iterator)
                if (len(progress.history_errs) == 0
                        or valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, replicas[0], config)
                    need_to_save = (
                        score is not None
                        and (len(progress.valid_script_scores) == 0
                             or score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
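The pre-training block above saves the learned embedding matrix to
pre_emb/pre_emb.npy, with rows aligned to the newline-separated word list it
writes to pre_emb/vocab. A minimal sketch of loading the two files back for
inspection (only the file names are taken from the code above; the helper
below is illustrative):

import numpy

# Rows of the embedding matrix are aligned with the words in pre_emb/vocab,
# in the order the training script wrote them.
emb = numpy.load("pre_emb/pre_emb.npy")
with open("pre_emb/vocab", encoding="utf-8") as f:
    vocab = f.read().splitlines()
assert len(vocab) == emb.shape[0]

word_to_vec = {w: emb[i] for i, w in enumerate(vocab)}

def cosine(u, v):
    # Cosine similarity between two embedding rows.
    return float(numpy.dot(u, v) /
                 (numpy.linalg.norm(u) * numpy.linalg.norm(v)))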
コード例 #11
0
ファイル: train.py プロジェクト: rsennrich/nematus
def train(config, sess):
    assert ((config.prior_model is not None
             and tf.train.checkpoint_exists(os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU

    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i>0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init, trainable=False)

    if config.learning_schedule == "constant":
        schedule = ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = TransformerSchedule(global_step=global_step,
                                       dim=config.state_size,
                                       warmup_steps=config.warmup_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=schedule.learning_rate,
                                           beta1=config.adam_beta1,
                                           beta2=config.adam_beta2,
                                           epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    saver, progress = model_loader.init_or_restore_variables(
        config, sess, train=True)

    global_step.load(progress.uidx, sess)

    # Use an InferenceModelSet to abstract over model types for sampling and
    # beam search. Multi-GPU sampling and beam search are not currently
    # supported, so we just use the first replica.
    model_set = inference.InferenceModelSet([replicas[0]], [config])

    #save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
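            # Each source token is a tuple of config.factors features (e.g.
            # the surface form plus optional linguistic annotations), so the
            # corpus and the configuration must agree on the factor count.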
            if len(source_sents[0][0]) != config.factors:
                logging.error('Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'.format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info('Minibatch with zero sample under length {0}'.format(config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and ((progress.uidx % config.summary_freq == 0) or (config.finish_after and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info('{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'.format(disp_time, progress.eidx, progress.uidx, total_loss/n_words, n_words/duration, n_sents/duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small, x_mask_small, y_small = x_in[:, :, :10], x_mask_in[:, :10], y_in[:, :10]
                samples = model_set.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small, x_mask_small, y_small = x_in[:, :, :10], x_mask_in[:, :10], y_in[:,:10]
                samples = model_set.beam_search(sess, x_small, x_mask_small,
                                               config.beam_size,
                                               normalization_alpha=config.normalization_alpha)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost/len(sample))
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                valid_ce = validate(sess, replicas[0], config,
                                    valid_text_iterator)
                if (len(progress.history_errs) == 0 or
                    valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, replicas[0], config)
                    need_to_save = (score is not None and
                        (len(progress.valid_script_scores) == 0 or
                         score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
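Both examples select between a constant learning rate and the "transformer"
schedule. A self-contained sketch of the warmup curve, assuming
TransformerSchedule implements the formula from Vaswani et al. (2017)
("Attention Is All You Need"):

def transformer_lr(step, dim, warmup_steps):
    # lr(step) = dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5):
    # linear warmup for warmup_steps updates, then 1/sqrt(step) decay.
    step = max(step, 1)  # guard against division by zero at step 0
    return dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The peak learning rate is reached at step == warmup_steps:
print(transformer_lr(4000, dim=512, warmup_steps=4000))  # ~7.0e-4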