Example #1
0
import sys

def fit(config):
    word_to_context, vocabulary, \
            train_generator, valid_generator, test_generator = \
                prepare_data(config)

    print("TODO: add a class for the unknown word and add examples of "
            "unknown words to the training data", file=sys.stderr)

    # Violate encapsulation a bit by reaching into the generator for its
    # label encoder.
    label_encoder = train_generator.label_encoder
    target_map = dict(zip(
        label_encoder.classes_, range(len(label_encoder.classes_))))

    n_classes = len(label_encoder.classes_)
    print('n_classes %d' % n_classes)

    # We don't know the number of context embeddings in advance, so we
    # set them at runtime based on the size of the vocabulary.
    config.n_context_embeddings = len(vocabulary)
    print('n_context_embeddings %d' % config.n_context_embeddings)

    graph = build_model(config, n_classes)

    config.logger('model has %d parameters' % graph.count_params())

    config.logger('building callbacks')

    callbacks = build_callbacks(config,
            valid_generator,
            n_samples=config.n_val_callback_samples,
            dictionary=build_retriever(),
            target_map=target_map)

    pbar = build_pbar(train_generator.word_to_context)
    y = []
    for i, (word, contexts) in enumerate(train_generator.word_to_context.items()):
        pbar.update(i + 1)
        # LabelEncoder.transform expects an array-like, not a bare string.
        y_word = label_encoder.transform([word])[0]
        y.extend([y_word] * len(contexts))
    pbar.finish()
    class_weight = balanced_class_weights(
            y, n_classes,
            class_weight_exponent=config.class_weight_exponent)

    verbose = 2 if 'background' in config.mode else 1


    graph.fit_generator(train_generator.generate(train=True),
            samples_per_epoch=config.samples_per_epoch,
            nb_worker=config.n_worker,
            nb_epoch=config.n_epoch,
            validation_data=valid_generator.generate(),
            nb_val_samples=config.n_val_samples,
            callbacks=callbacks,
            class_weight=class_weight,
            verbose=verbose)
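Every example on this page calls the same balanced_class_weights(y, n_classes, class_weight_exponent) helper, whose source is not shown. Judging from the call sites, it takes integer class labels, the number of classes, and an exponent that interpolates between uniform weighting (exponent 0) and fully inverse-frequency weighting (exponent 1), and it returns a per-class weight mapping that Keras accepts as class_weight and that Examples #6 and #7 index by class label. A minimal sketch under those assumptions, not the project's actual implementation:

import numpy as np

def balanced_class_weights(y, n_classes, class_weight_exponent=1.0):
    # Count how often each class occurs, then weight each class by the
    # inverse of its frequency, softened by the exponent.
    y = np.asarray(y, dtype=int)
    counts = np.bincount(y, minlength=n_classes).astype(float)
    counts[counts == 0] = 1.0  # guard against empty classes
    weights = (len(y) / counts) ** class_weight_exponent
    weights *= n_classes / weights.sum()  # normalize to mean 1
    return dict(enumerate(weights))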
Example #2
0
 def class_weights(self, class_weight_exponent, target="multiclass_correction_target"):
     return balanced_class_weights(self.hdf5_file[target], 2, class_weight_exponent)
Example #3
0
 def class_weights(self, class_weight_exponent):
     return balanced_class_weights(
             self.target['binary_target'],
             2,
             class_weight_exponent)
Example #4
0
 def class_weights(self,
                   class_weight_exponent,
                   target='multiclass_correction_target'):
     return balanced_class_weights(self.hdf5_file[target], 2,
                                   class_weight_exponent)
Example #5
0
 def class_weights(self, class_weight_exponent):
     return balanced_class_weights(self.target['binary_target'], 2,
                                   class_weight_exponent)
Example #6
0
    def generate_next(self, i, train=False):
        non_word = self.non_words[i]
        correction = self.corrections[i]
        candidates = self.retriever[non_word]

        if self.use_correct_word_as_non_word_example:
            # Add the correct word to the list of candidates, even if it's there
            # already, so we can teach the model that the true correction for a
            # known word is the word itself.
            candidates.append(correction)

        if len(candidates) <= 1:
            return None

        candidates = candidates[:self.max_candidates]

        targets = np.zeros((len(candidates),2), dtype=np.int64)
        for j, candidate in enumerate(candidates):
            if candidate == correction:
                targets[j, 1] = 1
            else:
                targets[j, 0] = 1

        class_weights = balanced_class_weights(targets[:, 1], n_classes=2,
                class_weight_exponent=self.sample_weight_exponent)
        sample_weights = np.zeros(len(candidates))
        for k, candidate in enumerate(candidates):
            sample_weights[k] = class_weights[targets[k, 1]]

        # Build a character matrix of the non-word inputs, first marking
        # the non-words with start- and end-of-token characters.
        non_word_inputs = ['^' + non_word + '$' for _ in candidates]
        if self.use_correct_word_as_non_word_example:
            # The correction was appended to the candidates above; pair it
            # with itself so the model learns that the correction for a
            # known word is the word itself.
            non_word_inputs[-1] = '^' + correction + '$'
        non_word_matrix, non_word_kept = spelling.preprocess.build_char_matrix(
                non_word_inputs, width=self.model_input_width)

        # Build a character matrix of the candidate word inputs, first
        # marking the candidates with start- and end-of-token characters.
        candidate_word_inputs = ['^' + candidate + '$' for candidate in candidates]
        candidate_word_matrix, candidate_word_kept = spelling.preprocess.build_char_matrix(
                candidate_word_inputs, width=self.model_input_width)

        mask = non_word_kept & candidate_word_kept
        idx = np.where(mask)[0]

        non_word_matrix = non_word_matrix[mask]
        candidate_word_matrix = candidate_word_matrix[mask]
        targets = targets[mask]
        sample_weights = sample_weights[mask]

        _non_word = np.array([non_word] * len(idx))
        _correct_word = np.array([correction] * len(idx))
        _candidate_word = np.array(candidates)[mask]

        lengths = [len(x) for x in [
                non_word_matrix, candidate_word_matrix, targets,
                sample_weights, _non_word, _correct_word, _candidate_word]]
        if not all(l == lengths[0] for l in lengths):
            print('non_word %s non_word_matrix %d candidate_word_matrix %d targets %d sample_weights %d _non_word %d _correct_word %d _candidate_word %d' % 
                    (non_word,
                    len(non_word_matrix),
                    len(candidate_word_matrix),
                    len(targets),
                    len(sample_weights),
                    len(_non_word),
                    len(_correct_word),
                    len(_candidate_word)))

        assert \
                len(non_word_matrix) == \
                len(candidate_word_matrix) == \
                len(targets) == \
                len(sample_weights) == \
                len(_non_word) == \
                len(_correct_word) == \
                len(_candidate_word)

        def log_normalize(x):
            return 1 + np.log(1 + x)


        # The matrices, targets, and sample weights were already masked
        # above; applying the full-length boolean mask to them again would
        # raise an IndexError.
        data_dict = {
                'non_word': _non_word,
                'correct_word': _correct_word,
                'candidate_word': _candidate_word,
                'non_word_input': non_word_matrix,
                'candidate_word_input': candidate_word_matrix,
                'binary_correction_target': targets,
                'candidate_rank_first': log_normalize(np.arange(len(candidates))[mask]),
                'candidate_rank_last': log_normalize(np.arange(len(candidates))[mask])
                }

        sample_weight_dict = {
                'binary_correction_target': sample_weights,
                'candidate_rank_first': 10 * sample_weights,
                'candidate_rank_last': 2 * sample_weights
                }

        non_words = None
        candidate_words = None

        for name in self.distance_targets:
            if non_words is None:
                non_words = data_dict['non_word'].tolist()
                candidate_words = data_dict['candidate_word'].tolist()

            target_values = []
            for nw, cw in zip(non_words, candidate_words):
                target_values.append(
                        spelling.features.distance(nw, cw, name))
            data_dict[name+'_first'] = 10 * log_normalize(np.array(target_values))
            data_dict[name+'_last'] = 2 * log_normalize(np.array(target_values))

            # The commented-out alternative weights every example equally,
            # so the model learns to predict the edit distance equally well
            # for all examples; here we reuse the balanced sample weights.
            #sample_weight_dict[name] = np.ones(len(idx))
            sample_weight_dict[name] = sample_weights

        return (data_dict, sample_weight_dict)
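Example #6 also leans on spelling.preprocess.build_char_matrix, which is not shown either; from its use it appears to encode each word as a fixed-width row of character indices and to return a boolean mask of the rows that fit within the requested width. A rough sketch under that assumption, with hypothetical zero padding and ord() codes standing in for the project's real character indexing:

import numpy as np

def build_char_matrix(words, width):
    # Encode each word as `width` character codes, left-aligned and
    # zero-padded; words longer than `width` are marked as not kept.
    matrix = np.zeros((len(words), width), dtype=np.int64)
    kept = np.zeros(len(words), dtype=bool)
    for i, word in enumerate(words):
        if len(word) <= width:
            matrix[i, :len(word)] = [ord(ch) for ch in word]
            kept[i] = True
    return matrix, kept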
Example #7
0
    def build_next(self, samples):
        correct_words = []
        non_words = []
        non_word_char_inputs = []
        real_words = []
        real_word_char_inputs = []
        modified_contexts = []
        context_inputs = []
        context_inputs_01 = []
        context_inputs_02 = []
        context_inputs_03 = []
        context_inputs_04 = []
        context_inputs_05 = []
        targets = []
        sample_weights = []

        for word, context in samples:
            # TODO: the generator sometimes creates real words.
            non_word = self.non_word_generator.transform([word])[0]
            # TODO: does it matter whether the candidate list contains
            # the non-word?
            candidates = self.retriever[non_word]
            # Ensure there is always an example in the mini-batch with
            # target 1.
            if word not in candidates:
                candidates.append(word)

            # Add the candidates to the batch.
            correct_words.extend([word] * len(candidates))
            non_words.extend([non_word] * len(candidates))
            real_words.extend(candidates)

            if self.use_real_word_examples:
                # Add to the batch one more example, consisting of the
                # real word itself as an example non-word, and the real
                # word as the candidate and true correction.
                correct_words.append(word)
                non_words.append(word)
                real_words.append(word)
                candidates.append(word)

            candidate_targets = []
            for candidate in candidates:
                candidate_targets.append(1 if candidate == word else 0)
                # Substitute the candidate for the center word of the
                # context window.
                candidate_context = list(context)
                candidate_context[len(candidate_context) // 2] = candidate
                modified_contexts.append(candidate_context)

            targets.extend(candidate_targets)

            candidate_targets = np_utils.to_categorical(candidate_targets, 2)
            class_weights = balanced_class_weights(
                candidate_targets[:, 1].astype(int),
                n_classes=2,
                class_weight_exponent=self.sample_weight_exponent)
            for k in range(len(candidates)):
                # to_categorical returns floats, so cast the label back to
                # int before using it as a class-weight key.
                sample_weights.append(
                        class_weights[int(candidate_targets[k, 1])])


        context_inputs = self.context_input_transformer.transform(
            modified_contexts)

        for i,ctx_input in enumerate(context_inputs):
            context_inputs_01.append([ctx_input[0]])
            context_inputs_02.append([ctx_input[1]])
            context_inputs_03.append([ctx_input[2]])
            context_inputs_04.append([ctx_input[3]])
            context_inputs_05.append([ctx_input[4]])
       
        #print("non_words", non_words)
        #print("non_words[0]", non_words[0])
        #print("non_words[-1]", non_words[-1])
        #print("real_words", real_words)
        #print("correct_words", correct_words)

        non_word_char_inputs = self.char_input_transformer.transform(non_words)
        real_word_char_inputs = self.char_input_transformer.transform(real_words)

        # This transformer expects each example to be a list.  (Just this transformer?)
        real_word_inputs = self.real_word_input_transformer.transform(
            [[real_word] for real_word in real_words])

        targets = np_utils.to_categorical(targets, 2)

        data_dict = {
                'correct_word': np.array(correct_words),
                'non_word': np.array(non_words),
                'candidate_word': np.array(real_words),
                self.real_word_char_input_name: np.array(real_word_char_inputs),
                self.non_word_char_input_name: np.array(non_word_char_inputs),
                self.real_word_input_name: np.array(real_word_inputs),
                self.context_input_name: context_inputs,
                '%s_%02d' % (self.context_input_name,1): np.array(context_inputs_01),
                '%s_%02d' % (self.context_input_name,2): np.array(context_inputs_02),
                '%s_%02d' % (self.context_input_name,3): np.array(context_inputs_03),
                '%s_%02d' % (self.context_input_name,4): np.array(context_inputs_04),
                '%s_%02d' % (self.context_input_name,5): np.array(context_inputs_05),
                self.target_name: targets
                }

        sample_weight_dict = {
                self.target_name: np.array(sample_weights)
                }

        return data_dict, sample_weight_dict
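Examples #6 and #7 repeat the same per-candidate weighting pattern: one-hot binary targets, one call to balanced_class_weights, then a Python loop that copies the matching class weight onto every candidate. Assuming the dict-returning sketch above, that loop collapses into a single vectorized lookup:

import numpy as np

def per_sample_weights(binary_labels, class_weights):
    # Replace the per-candidate loop with one fancy-indexing step:
    # look up each example's class weight by its 0/1 label.
    labels = np.asarray(binary_labels, dtype=int)
    lookup = np.array([class_weights[c] for c in sorted(class_weights)])
    return lookup[labels]

# e.g. in Example #6: sample_weights = per_sample_weights(targets[:, 1], class_weights)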
Example #8
0
 def class_weights(self, class_weight_exponent, target):
     return balanced_class_weights(
             self.hdf5_file[target],
             2,
             class_weight_exponent)
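Examples #2 through #5 and #8 are thin wrappers with the same shape: read stored binary targets (from an HDF5 dataset or an in-memory target dict), pass them to balanced_class_weights with n_classes=2, and hand the result to Keras. A hypothetical call site, mirroring how Example #1 wires the weights into training; generator and model names here are placeholders:

class_weight = generator.class_weights(class_weight_exponent=0.75)
model.fit_generator(generator.generate(train=True),
        samples_per_epoch=20000,
        nb_epoch=10,
        class_weight=class_weight)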