Ejemplo n.º 1
0
    def train(corpus, bigrams_dims, unigrams_dims, lstm_units, hidden_units,
              epochs, batch_size, train_data_file, dev_data_file,
              model_save_file, droprate, unk_params, alpha, beta):
        """Train a segmentation network, checkpointing on dev F-score.

        Builds a ``Network`` sized from ``corpus``.  Each epoch shuffles the
        gold training data and walks it in mini-batches.  For every batch it:
        runs ``Segmenter.exploration`` on each example, applies random UNKing
        to the unigram and forward-bigram ids, and takes one trainer update
        on the summed negative log-likelihood of the gold span labels.

        The dev set is evaluated roughly four times per epoch and after the
        last batch.  Whenever its F-score beats the best seen so far, the
        network is saved to ``model_save_file``.

        Args:
            corpus: Feature mapper.  Provides vocabulary/label sizes,
                ``unigrams_freq_list``, ``bigrams_freq_list`` and
                ``gold_data_from_file``.  It is also handed to ``Segmenter``.
            bigrams_dims, unigrams_dims: Embedding sizes passed to ``Network``.
            lstm_units, hidden_units: Layer sizes passed to ``Network``.
            epochs: Number of passes over the training data.
            batch_size: Examples per parameter update (a positive integer).
            train_data_file: Path read by ``corpus.gold_data_from_file``.
            dev_data_file: Path read by ``SegSentence.load_sentence_file``.
            model_save_file: Destination for the best-on-dev network.
            droprate: Dropout rate passed to ``Network``.
            unk_params: UNKing constant z.  An id with training frequency f
                is replaced by 0 with probability z / (z + f).
            alpha, beta: Forwarded to ``Segmenter.exploration`` (presumably
                exploration hyper-parameters -- confirm there).

        Returns:
            The network as it stands after the final epoch.  This is not
            necessarily the best-on-dev checkpoint that was written to disk.
        """

        def random_unk(ids, freq_list):
            """Return a copy of ``ids`` with random UNKing applied.

            Ids <= 2 are never replaced (presumably reserved special symbols
            -- confirm against the corpus vocabulary).  Any other id is
            replaced by 0 with probability z / (z + freq_list[id]).
            """
            unked = list(ids)
            for i, idx in enumerate(unked):
                if idx <= 2:
                    continue
                # float() guards against Python 2 integer division, which
                # would make drop_prob 0 whenever unk_params is an int.
                drop_prob = float(unk_params) / (unk_params + freq_list[idx])
                if np.random.random() < drop_prob:
                    unked[i] = 0
            return unked

        start_time = time.time()

        bigrams_size = corpus.total_bigrams()
        unigrams_size = corpus.total_unigrams()

        network = Network(
            bigrams_size=bigrams_size,
            unigrams_size=unigrams_size,
            bigrams_dims=bigrams_dims,
            unigrams_dims=unigrams_dims,
            lstm_units=lstm_units,
            hidden_units=hidden_units,
            label_size=corpus.total_labels(),
            span_nums=corpus.total_span_nums(),
            droprate=droprate,
        )

        network.init_params()

        print('Hidden units : {} ,per LSTM units : {}'.format(
            hidden_units,
            lstm_units,
        ))

        print('Embeddings: bigrams = {}, unigrams = {}'.format(
            (bigrams_size, bigrams_dims), (unigrams_size, unigrams_dims)))

        print('Dropout rate : {}'.format(droprate))
        # NOTE(review): this message is not checked against Network's
        # actual initialiser; confirm it stays accurate.
        print('Parameters initialized in [-0.01,0.01]')
        print('Random UNKing parameter z = {}'.format(unk_params))

        training_data = corpus.gold_data_from_file(train_data_file)
        # Ceiling division: a final partial batch still counts as a batch.
        num_batched = -(-len(training_data) // batch_size)
        print('Loaded {} training sentences ({} batches of size {})!'.format(
            len(training_data),
            num_batched,
            batch_size,
        ))

        # Validate about four times per epoch (ceiling again).
        parse_every = -(-num_batched // 4)

        dev_sentences = SegSentence.load_sentence_file(dev_data_file)
        print('Loaded {} validation sentences!'.format(len(dev_sentences)))

        best_acc = FScore()
        for epoch in xrange(1, epochs + 1):
            print('............ epoch {} ............'.format(epoch))

            total_cost = 0.0
            total_states = 0
            training_acc = FScore()

            np.random.shuffle(training_data)

            for b in xrange(num_batched):
                batch = training_data[(b * batch_size):(b + 1) * batch_size]

                explore = [
                    Segmenter.exploration(example,
                                          corpus,
                                          network,
                                          alpha=alpha,
                                          beta=beta) for example in batch
                ]
                for (_, acc) in explore:
                    training_acc += acc

                batch = [example for (example, _) in explore]

                dynet.renew_cg()
                network.prep_params()

                errors = []
                for example in batch:
                    # UNK private copies.  Writing zeros back into ``example``
                    # could leak into training_data (if exploration reuses the
                    # same lists) and make the dropout permanent and
                    # cumulative across epochs.
                    unigrams = random_unk(example['unigrams'],
                                          corpus.unigrams_freq_list)
                    fwd_bigrams = random_unk(example['fwd_bigrams'],
                                             corpus.bigrams_freq_list)

                    fwd, back = network.evaluate_recurrent(
                        fwd_bigrams,
                        unigrams,
                    )

                    label_data = example['label_data']
                    for (left, right), correct in label_data.items():
                        scores = network.evaluate_labels(
                            fwd, back, left, right)

                        probs = dynet.softmax(scores)
                        loss = -dynet.log(dynet.pick(probs, correct))
                        errors.append(loss)
                    total_states += len(label_data)

                # A batch with no labelled spans has nothing to backpropagate;
                # skip it rather than building dynet.esum([]).
                if errors:
                    batch_error = dynet.esum(errors)
                    total_cost += batch_error.scalar_value()
                    batch_error.backward()
                    network.trainer.update()

                mean_cost = total_cost / total_states if total_states else 0.0

                print(
                    '\rBatch {}  Mean Cost {:.4f}  [Train: {}]'.format(
                        b,
                        mean_cost,
                        training_acc,
                    ),
                    end='',
                )
                sys.stdout.flush()

                if ((b + 1) % parse_every) == 0 or b == (num_batched - 1):
                    dev_acc = Segmenter.evaluate_corpus(
                        dev_sentences,
                        corpus,
                        network,
                    )
                    print(' [Val: {}]'.format(dev_acc))

                    if dev_acc.fscore() > best_acc.fscore():
                        best_acc = dev_acc
                        network.save(model_save_file)
                        print('    [saved model : {}]'.format(model_save_file))

            current_time = time.time()
            runmins = (current_time - start_time) / 60
            print(' Elapsed time: {:.2f}m'.format(runmins))

        return network