Example 1
    def _translate(self, process_id, input_item, get_sampler, sess):
        """
        Actual translation (model sampling).
        """

        # unpack input item attributes
        k = input_item.k
        x = input_item.batch
        alpha = input_item.normalization_alpha
        #max_ratio = input_item.max_ratio

        # prepare_data pads source and target together, so pass a dummy
        # single-token target batch alongside the real source batch
        y_dummy = numpy.zeros(shape=(len(x), 1))
        x, x_mask, _, _ = util.prepare_data(x, y_dummy,
                                            self._options[0].factors,
                                            maxlen=None)

        sample = translate_utils.translate_batch(
            session=sess,
            sampler=get_sampler(k),
            x=x,
            x_mask=x_mask,
            max_translation_len=self._options[0].translation_maxlen,
            normalization_alpha=alpha)

        return sample
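
util.prepare_data itself is not shown here. As a rough sketch of the joint source/target padding it presumably performs (shapes follow the docstring in Example 2; the helper name and token layout below are our own, not part of the library):

import numpy as np

def pad_source_batch(seqs, n_factors):
    # seqs: list of sentences; each sentence is a list of tokens, each token
    # a tuple of n_factors vocabulary ids (hypothetical layout)
    lengths = [len(s) for s in seqs]
    max_len = max(lengths) + 1  # leave one position for the 0 (<EOS>) marker
    x = np.zeros((n_factors, max_len, len(seqs)), dtype='int64')
    x_mask = np.zeros((max_len, len(seqs)), dtype='float32')
    for i, s in enumerate(seqs):
        x[:, :lengths[i], i] = np.array(s, dtype='int64').T
        x_mask[:lengths[i] + 1, i] = 1.0  # mask covers the tokens plus <EOS>
    return x, x_mask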
Example 2
def full_sampler(replica, sampler, sess, config, x, x_mask, y, y_mask):
    """generate candidate sentences used for Minimum Risk Training

    Args:
        replica: inference models to do sampling
        x: (factor, len, batch_size)
        x_mask: (len, batch_size)
        y: (len, batch_size)
        y_mask: (len, batch_size)
    Returns:
        x, x_mask, y, y_mask are four lists containing the corresponding content of
        source-candidate sentence pairs, with shape:
        x: (factor, len, batch_size*sampleN)
        x_mask: (len, batch_size*sampleN)
        y: (len, batch_size*sampleN)
        y_mask: (len, batch_size*sampleN)

        y is a list of the corresponding references; index is
        a list of number indicating the starting point of different source sentences.
    """

    sampleN = config.samplesN

    # set maximum number of tokens of sampled candidates
    dynamic_max_len = int(config.max_len_a * x_mask.shape[0] +
                          config.max_len_b)
    max_translation_len = min(config.translation_maxlen, dynamic_max_len)
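    # e.g. max_len_a = 1.5, max_len_b = 5 and a 20-token source give
    # min(config.translation_maxlen, 35)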

    if config.sample_way == 'beam_search':

        # split the minibatch into sub-batches and run sampling on each
        # sub-batch separately
        if config.max_sentences_of_sampling > 0:
            # the number of splits equals batch_size divided by the maximum
            # number of sentences accepted for sampling on one device
            num_split = math.ceil(x_mask.shape[1] /
                                  config.max_sentences_of_sampling)
            # split the numpy array into a list of numpy array
            split_x = np.array_split(x, num_split, 2)
            split_x_mask = np.array_split(x_mask, num_split, 1)
            sample_and_score = []
            # feed sub-batch into model to generate samples
            for i in range(len(split_x)):
                sample_and_score += translate_utils.translate_batch(
                    sess, sampler, split_x[i], split_x_mask[i],
                    max_translation_len, config.normalization_alpha)
        else:
            sample_and_score = translate_utils.translate_batch(
                sess, sampler, x, x_mask, max_translation_len,
                config.normalization_alpha)

        # sample_and_score: the outer list has batch_size entries; each inner
        # list holds sampleN (sample, cost) tuples

        # collect the sampled sequences
        samples = []
        for i, ss in enumerate(sample_and_score):
            samples.append([])
            for (sample_seq, cost) in ss:
                samples[i].append(sample_seq.tolist())
        # samples: list with shape (batch_size, sampleN, len), uneven
        # beam search sampling, no need to remove duplicate samples.

        # cumulative offsets of each sentence's candidates (trivially spaced
        # by sampleN in beam search mode)
        index = [[0]]
        for i in range(len(samples)):
            index[0].append(index[0][i] + sampleN)

        # align x and x_mask with the sampleN candidates of each sentence
        # (the return statement below needs x_new and x_mask_new on this
        # path as well)
        x_new = np.repeat(x, sampleN, axis=2)
        x_mask_new = np.repeat(x_mask, sampleN, axis=1)

    elif config.sample_way == 'randomly_sample':

        samples = []
        for i in range(x_mask.shape[1]):
            samples.append([])

        if config.max_sentences_of_sampling > 0:
            num_split = math.ceil(x_mask.shape[1] /
                                  config.max_sentences_of_sampling)
            split_x = np.array_split(x, num_split, 2)
            split_x_mask = np.array_split(x_mask, num_split, 1)
            # use normalization_alpha = 0 for random sampling (length
            # normalisation has no effect on which sentences are sampled)
            sample = translate_utils.translate_batch(sess, sampler, split_x[0],
                                                     split_x_mask[0],
                                                     max_translation_len, 0.0)
            for i in range(1, len(split_x)):
                tmp = translate_utils.translate_batch(sess, sampler,
                                                      split_x[i],
                                                      split_x_mask[i],
                                                      max_translation_len, 0.0)
                # translate_batch returns a Python list, so concatenate the
                # lists instead of calling np.concatenate on ragged data
                sample = sample + tmp
        else:
            sample = translate_utils.translate_batch(sess, sampler, x, x_mask,
                                                     max_translation_len, 0.0)
        # sample: nested list (batch_size, sampleN); each element is a tuple
        # (numpy array of a sampled sentence, its score)
        for i in range(len(samples)):
            for ss in sample[i]:
                samples[i].append(ss[0].tolist())
            # samples: list with shape (batch_size, sampleN, len), uneven

        # remove duplicate samples
        for i in range(len(samples)):
            samples[i].sort()
            samples[i] = [s for s, _ in itertools.groupby(samples[i])]
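        # e.g. [[1, 2], [1, 2], [3]] -> [[1, 2], [3]]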

        # expand x and x_mask so that each source sentence is aligned with its
        # remaining candidates (duplicates were removed above)
        index = []
        for i in range(len(samples)):
            index.append(len(samples[i]))
        x_new = np.repeat(x, index, axis=2)
        x_mask_new = np.repeat(x_mask, index, axis=1)

        # record the number of remaining candidate samples of each source
        # sentence in 'index', for the subsequent normalisation of the
        # distribution and the calculation of the expected risk.
        index = [[0]]
        for i in range(len(samples)):
            index[0].append(index[0][i] + len(samples[i]))
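        # e.g. two source sentences keeping 4 and 2 candidates after
        # deduplication give index = [[0, 4, 6]]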

    else:
        assert False, 'unknown sampling method: {}'.format(config.sample_way)

    # optionally add the reference to the candidate sentences:

    # transpose from time-major (len, batch) to batch-major (batch, len)
    y = list(map(list, zip(*y)))
    # y: batch_size X len
    y_mask = list(map(list, zip(*y_mask)))
    if config.mrt_reference:
        for i in range(len(samples)):
            # strip the padding from the reference
            length = int(sum(y_mask[i]))
            y[i] = y[i][:length]
            # make sure the reference is among the candidates: append it and
            # drop one sampled candidate so the per-sentence count is unchanged
            if y[i] not in samples[i]:
                samples[i].append(y[i])
                samples[i].pop(-2)

    # add padding: there is no dedicated padding token, so padded positions are
    # filled with 0 (<EOS>) and masked out to avoid generating loss

    # combine samples from different batches (drop the outermost dimension)
    samples = [s for batch_samples in samples for s in batch_samples]
    # samples: list with shape (batch_size*sampleN, len), uneven
    n_samples = len(samples)
    lengths_y = [len(s) for s in samples]
    maxlen_y = np.max(lengths_y) + 1
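    # the +1 reserves one position for the trailing 0 (<EOS>), which the
    # mask below also covers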

    y_new = np.zeros((maxlen_y, n_samples)).astype('int64')
    y_mask_new = np.zeros((maxlen_y, n_samples)).astype('float32')

    for idx, s_y in enumerate(samples):
        y_new[:lengths_y[idx], idx] = s_y
        y_mask_new[:lengths_y[idx] + 1, idx] = 1.

    return (x_new.tolist(), x_mask_new.tolist(), y_new.tolist(),
            y_mask_new.tolist(), y, index)
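
A standalone illustration of the candidate bookkeeping above (variable names are ours): np.repeat aligns each source sentence with its surviving candidates, and the cumulative offsets play the role of 'index':

import numpy as np

x = np.arange(12).reshape(1, 3, 4)    # (factor=1, len=3, batch_size=4)
counts = [3, 1, 2, 2]                 # candidates kept per source sentence
x_new = np.repeat(x, counts, axis=2)  # (1, 3, 8): one column per candidate
offsets = [0]
for c in counts:
    offsets.append(offsets[-1] + c)   # [0, 3, 4, 6, 8], the role of 'index'
print(x_new.shape, offsets)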
Example 3
def train(config, sess):
    assert ((config.prior_model is not None and
             tf.train.checkpoint_exists(os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU

    num_gpus = len(tf_utils.get_available_gpus())
    num_replicas = max(1, num_gpus)
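    # with no visible GPU this falls back to a single CPU replica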

    if config.loss_function == 'MRT':
        assert config.gradient_aggregation_steps == 1
        assert config.max_sentences_per_device == 0, \
            "MRT mode does not support sentence-based splitting"
        if config.max_tokens_per_device != 0:
            assert config.samplesN * config.maxlen <= config.max_tokens_per_device, \
                "all candidates of one sentence must fit into the model at once"
        else:
            assert num_replicas == 1, \
                "MRT mode without token-based splitting supports only one device"
            assert config.samplesN * config.maxlen <= config.token_batch_size, \
                "all candidates of one sentence must fit into the model at once"

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i>0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init, trainable=False)

    if config.learning_schedule == "constant":
        schedule = learning_schedule.ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = learning_schedule.TransformerSchedule(
            global_step=global_step,
            dim=config.state_size,
            warmup_steps=config.warmup_steps)
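        # (presumably the inverse-square-root warmup schedule of Vaswani et
        # al. (2017): lr = dim**-0.5 * min(step**-0.5,
        # step * warmup_steps**-1.5))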
    elif config.learning_schedule == "warmup-plateau-decay":
        schedule = learning_schedule.WarmupPlateauDecaySchedule(
            global_step=global_step,
            peak_learning_rate=config.learning_rate,
            warmup_steps=config.warmup_steps,
            plateau_steps=config.plateau_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=schedule.learning_rate,
                                           beta1=config.adam_beta1,
                                           beta2=config.adam_beta2,
                                           epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    if config.exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(config.exponential_smoothing)

    saver, progress = model_loader.init_or_restore_variables(
        config, sess, train=True)

    global_step.load(progress.uidx, sess)

    if config.sample_freq:
        random_sampler = RandomSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=1)

    if config.beam_freq or config.valid_script is not None:
        beam_search_sampler = BeamSearchSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=config.beam_size)

    #save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    # when printing per-token probabilities, run exactly one more epoch
    if config.print_per_token_pro:
        config.max_epochs = progress.eidx + 1
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error('Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'.format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info('Minibatch with zero sample under length {0}'.format(config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and ((progress.uidx % config.summary_freq == 0) or (config.finish_after and progress.uidx % config.finish_after == 0))
            (factors, seqLen, batch_size) = x_in.shape

            output = updater.update(
                sess, x_in, x_mask_in, y_in, y_mask_in, num_to_target,
                write_summary_for_this_batch)

            if not config.print_per_token_pro:
                total_loss += output
            else:
                # append the per-token probabilities to the given file
                with open(config.print_per_token_pro, 'a') as f:
                    for pro in output:
                        f.write(str(pro) + '\n')

            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            # Update the smoothed version of the model variables.
            # To reduce the performance overhead, we only do this once every
            # N steps (the smoothing factor is adjusted accordingly).
            if config.exponential_smoothing > 0.0 and progress.uidx % smoothing.update_frequency == 0:
                sess.run(fetches=smoothing.update_ops)

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info('{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'.format(disp_time, progress.eidx, progress.uidx, total_loss/n_words, n_words/duration, n_sents/duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, random_sampler, x_small, x_mask_small,
                    config.translation_maxlen, 0.0)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss[0][0], num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:,:10]
                samples = translate_utils.translate_batch(
                    sess, beam_search_sampler, x_small, x_mask_small,
                    config.translation_maxlen, config.normalization_alpha)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, len(sample), cost/len(sample))
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                if config.exponential_smoothing > 0.0:
                    sess.run(fetches=smoothing.swap_ops)
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                    sess.run(fetches=smoothing.swap_ops)
                else:
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                if (len(progress.history_errs) == 0 or
                    valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    if config.exponential_smoothing > 0.0:
                        sess.run(fetches=smoothing.swap_ops)
                        score = validate_with_script(sess, beam_search_sampler)
                        sess.run(fetches=smoothing.swap_ops)
                    else:
                        score = validate_with_script(sess, beam_search_sampler)
                    need_to_save = (score is not None and
                        (len(progress.valid_script_scores) == 0 or
                         score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop=True
                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
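
The ExponentialSmoothing helper is not shown above. A minimal numpy sketch of the idea, assuming one shadow copy per variable and a standard EMA update (the update rule and swap mechanics are our assumptions; the real class runs these as TensorFlow ops):

import numpy as np

def smoothed_update(shadow, live, alpha):
    # one EMA step: shadow <- alpha * shadow + (1 - alpha) * live
    return alpha * shadow + (1.0 - alpha) * live

live = np.array([0.5, -1.2])           # current model weights
shadow = live.copy()                   # smoothed copy
for _ in range(10):                    # pretend training steps
    live = live + np.random.randn(2) * 0.01
    shadow = smoothed_update(shadow, live, alpha=0.999)
# around validation the live and shadow weights are swapped, the model is
# validated, then they are swapped back (cf. the smoothing.swap_ops calls)
live, shadow = shadow, live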