    def train(self, data, words, params):

        start_time = time.time()
        evaluate_all(self, words)

        counter = 0

        try:
            for eidx in range(params.epochs):

                kf = self.get_minibatches_idx(len(data),
                                              params.batchsize,
                                              shuffle=True)
                uidx = 0

                for _, train_index in kf:

                    uidx += 1
                    batch = [data[t] for t in train_index]

                    for i in batch:
                        i[0].populate_embeddings(words, True)
                        i[1].populate_embeddings(words, True)

                    (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x,
                     p2mask) = self.getpairs(batch, params)

                    cost = self.train_function(g1x, g2x, p1x, p2x, g1mask,
                                               g2mask, p1mask, p2mask)

                    if np.isnan(cost) or np.isinf(cost):
                        print('NaN detected')

                    if utils.check_if_quarter(uidx, len(kf)):
                        if params.save:
                            counter += 1
                            self.save_params(
                                params.outfile + str(counter) + '.pickle',
                                words)
                        if params.evaluate:
                            evaluate_all(self, words)

                    for i in batch:
                        i[0].representation = None
                        i[1].representation = None
                        i[0].unpopulate_embeddings()
                        i[1].unpopulate_embeddings()

                if params.save:
                    counter += 1
                    self.save_params(params.outfile + str(counter) + '.pickle',
                                     words)

                if params.evaluate:
                    evaluate_all(self, words)

                print('Epoch', eidx + 1, 'Cost', cost)

        except KeyboardInterrupt:
            print "Training interupted"

        end_time = time.time()
        print "total time:", (end_time - start_time)
    def train(self, data, words, params):

        start_time = time.time()
        evaluate_all(self, words, params)

        old_v = 0
        v = 0  # keep v defined in case params.save is set without params.evaluate
        try:

            for eidx in range(params.epochs):

                kf = self.get_minibatches_idx(len(data), params.batchsize, shuffle=True)
                lkf = len(kf)
                uidx = 0

                while len(kf) > 0:

                    megabatch = []
                    idxs = []
                    idx = 0
                    for _ in range(params.mb_batchsize):
                        if len(kf) > 0:
                            arr = [data[t] for t in kf[0][1]]
                            curr_idxs = [idx + j for j in range(len(kf[0][1]))]
                            kf.pop(0)
                            megabatch.extend(arr)
                            idxs.append(curr_idxs)
                            idx += len(curr_idxs)
                    uidx += len(idxs)

                    for i in megabatch:
                        if params.wordtype == "words":
                            if params.scramble > 0:
                                n = np.random.binomial(1, params.scramble, 1)[0]
                                if n > 0:
                                    i[0].populate_embeddings_scramble(words)
                                    i[1].populate_embeddings_scramble(words)
                                else:
                                    i[0].populate_embeddings(words, True)
                                    i[1].populate_embeddings(words, True)
                            else:
                                i[0].populate_embeddings(words, True)
                                i[1].populate_embeddings(words, True)
                        else:
                            i[0].populate_embeddings_ngrams(words, 3, True)
                            i[1].populate_embeddings_ngrams(words, 3, True)

                    (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x, p2mask), _ = \
                        self.get_pairs(megabatch, params)
                    cost = 0
                    for i in idxs:
                        cost += self.train_function(g1x[i], g2x[i], p1x[i],
                                                    p2x[i], g1mask[i],
                                                    g2mask[i], p1mask[i],
                                                    p2mask[i])
                    cost = cost / len(idxs)

                    if np.isnan(cost) or np.isinf(cost):
                        print('NaN detected')

                    if utils.check_if_quarter(uidx - len(idxs), uidx, lkf):
                        if params.evaluate:
                            v = evaluate_all(self, words, params)
                        if params.save:
                            if v > old_v:
                                old_v = v
                                self.save_params(params.outfile + '.pickle', words)

                    for i in megabatch:
                        i[0].representation = None
                        i[1].representation = None
                        i[0].unpopulate_embeddings()
                        i[1].unpopulate_embeddings()

                if params.evaluate:
                    v = evaluate_all(self, words, params)

                if params.save:
                    if v > old_v:
                        old_v = v
                        self.save_params(params.outfile + '.pickle', words)

                print('Epoch', eidx + 1, 'Cost', cost)

        except KeyboardInterrupt:
            print("Training interupted")

        end_time = time.time()
        print("total time:", (end_time - start_time))
Example #3
    def train(self, data, words, params):

        start_time = time.time()
        evaluate_all(self, words, params)

        old_v = 0
        v = 0  # keep v defined in case params.save is set without params.evaluate
        try:

            for eidx in range(params.epochs):

                kf = self.get_minibatches_idx(len(data),
                                              params.batchsize,
                                              shuffle=True)
                lkf = len(kf)
                uidx = 0

                while len(kf) > 0:

                    megabatch = []
                    idxs = []
                    idx = 0
                    for _ in range(params.mb_batchsize):
                        if len(kf) > 0:
                            arr = [data[t] for t in kf[0][1]]
                            curr_idxs = [idx + j for j in range(len(kf[0][1]))]
                            kf.pop(0)
                            megabatch.extend(arr)
                            idxs.append(curr_idxs)
                            idx += len(curr_idxs)
                    uidx += len(idxs)

                    for i in megabatch:
                        if params.wordtype == "words":
                            if params.scramble > 0:
                                n = np.random.binomial(1, params.scramble,
                                                       1)[0]
                                if n > 0:
                                    i[0].populate_embeddings_scramble(words)
                                    i[1].populate_embeddings_scramble(words)
                                else:
                                    i[0].populate_embeddings(words, True)
                                    i[1].populate_embeddings(words, True)
                            else:
                                i[0].populate_embeddings(words, True)
                                i[1].populate_embeddings(words, True)
                        else:
                            i[0].populate_embeddings_ngrams(words, 3, True)
                            i[1].populate_embeddings_ngrams(words, 3, True)

                    (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x,
                     p2mask) = self.get_pairs(megabatch, params)

                    cost = 0
                    for i in idxs:
                        cost += self.train_function(g1x[i], g2x[i], p1x[i],
                                                    p2x[i], g1mask[i],
                                                    g2mask[i], p1mask[i],
                                                    p2mask[i])

                    cost = cost / len(idxs)

                    if np.isnan(cost) or np.isinf(cost):
                        print('NaN detected')

                    if utils.check_if_quarter(uidx - len(idxs), uidx, lkf):
                        if params.evaluate:
                            v = evaluate_all(self, words, params)
                        if params.save:
                            if v > old_v:
                                old_v = v
                                self.save_params(params.outfile + '.pickle',
                                                 words)

                    for i in megabatch:
                        i[0].representation = None
                        i[1].representation = None
                        i[0].unpopulate_embeddings()
                        i[1].unpopulate_embeddings()

                if params.evaluate:
                    v = evaluate_all(self, words, params)

                if params.save:
                    if v > old_v:
                        old_v = v
                        self.save_params(params.outfile + '.pickle', words)

                print('Epoch', eidx + 1, 'Cost', cost)

        except KeyboardInterrupt:
            print "Training interupted"

        end_time = time.time()
        print "total time:", (end_time - start_time)
    def train(self, data, ngram_words, word_words, params):

        start_time = time.time()
        evaluate_all(self, ngram_words, word_words, params)

        old_v = 0
        v = 0  # keep v defined in case params.save is set without params.evaluate
        try:

            for eidx in range(params.epochs):

                kf = self.get_minibatches_idx(len(data),
                                              params.batchsize,
                                              shuffle=True)
                lkf = len(kf)
                uidx = 0

                while len(kf) > 0:

                    megabatch = []
                    idxs = []
                    idx = 0
                    for _ in range(params.mb_batchsize):
                        if len(kf) > 0:
                            arr = [data[t] for t in kf[0][1]]
                            curr_idxs = [idx + j for j in range(len(kf[0][1]))]
                            kf.pop(0)
                            megabatch.extend(arr)
                            idxs.append(curr_idxs)
                            idx += len(curr_idxs)
                    uidx += len(idxs)

                    megabatch2 = []
                    for i in megabatch:
                        example = (i[0], copy.deepcopy(i[0]), i[1],
                                   copy.deepcopy(i[1]))
                        if params.combination_type == "ngram-word-lstm":
                            example = (i[0], copy.deepcopy(i[0]),
                                       copy.deepcopy(i[0]), i[1],
                                       copy.deepcopy(i[1]),
                                       copy.deepcopy(i[1]))
                        if params.combination_type == "ngram-word" or params.combination_type == "ngram-lstm":
                            example[0].populate_embeddings_ngrams(
                                ngram_words, 3, True)
                            example[1].populate_embeddings(word_words, True)
                            example[2].populate_embeddings_ngrams(
                                ngram_words, 3, True)
                            example[3].populate_embeddings(word_words, True)
                        elif params.combination_type == "word-lstm":
                            example[0].populate_embeddings(word_words, True)
                            example[1].populate_embeddings(word_words, True)
                            example[2].populate_embeddings(word_words, True)
                            example[3].populate_embeddings(word_words, True)
                        elif params.combination_type == "ngram-word-lstm":
                            example[0].populate_embeddings_ngrams(
                                ngram_words, 3, True)
                            example[1].populate_embeddings(word_words, True)
                            example[2].populate_embeddings(word_words, True)
                            example[3].populate_embeddings_ngrams(
                                ngram_words, 3, True)
                            example[4].populate_embeddings(word_words, True)
                            example[5].populate_embeddings(word_words, True)
                        megabatch2.append(example)
                    megabatch = megabatch2

                    if params.combination_type != "ngram-word-lstm":
                        (g1nx, g1nmask, g1wx, g1wmask, g2nx, g2nmask, g2wx, g2wmask,
                        p1nx, p1nmask, p1wx, p1wmask, p2nx, p2nmask, p2wx, p2wmask) \
                            = self.get_pairs(megabatch, params)
                    else:
                        (g1nx, g1nmask, g1wx, g1wmask, g1lx, g1lmask, g2nx, g2nmask, g2wx, g2wmask, g2lx, g2lmask,
                        p1nx, p1nmask, p1wx, p1wmask, p1lx, p1lmask, p2nx, p2nmask, p2wx, p2wmask, p2lx, p2lmask) \
                            = self.get_pairs(megabatch, params)

                    cost = 0
                    for i in idxs:
                        if params.combination_type != "ngram-word-lstm":
                            cost += self.train_function(
                                g1nx[i], g2nx[i], p1nx[i], p2nx[i], g1nmask[i],
                                g2nmask[i], p1nmask[i], p2nmask[i], g1wx[i],
                                g2wx[i], p1wx[i], p2wx[i], g1wmask[i],
                                g2wmask[i], p1wmask[i], p2wmask[i])
                        else:
                            cost += self.train_function(
                                g1nx[i], g2nx[i], p1nx[i], p2nx[i], g1nmask[i],
                                g2nmask[i], p1nmask[i], p2nmask[i], g1wx[i],
                                g2wx[i], p1wx[i], p2wx[i], g1wmask[i],
                                g2wmask[i], p1wmask[i], p2wmask[i], g1lx[i],
                                g2lx[i], p1lx[i], p2lx[i], g1lmask[i],
                                g2lmask[i], p1lmask[i], p2lmask[i])

                    cost = cost / len(idxs)

                    if np.isnan(cost) or np.isinf(cost):
                        print('NaN detected')

                    if utils.check_if_quarter(uidx - len(idxs), uidx, lkf):
                        if params.evaluate:
                            v = evaluate_all(self, ngram_words, word_words,
                                             params)
                        if params.save:
                            if v > old_v:
                                old_v = v
                                self.save_params(params.outfile + '.pickle',
                                                 (ngram_words, word_words))

                    for i in megabatch:
                        i[0].representation = None
                        i[1].representation = None
                        i[2].representation = None
                        i[3].representation = None
                        if params.combination_type == "ngram-word-lstm":
                            i[4].representation = None
                            i[5].representation = None
                        i[0].unpopulate_embeddings()
                        i[1].unpopulate_embeddings()
                        i[2].unpopulate_embeddings()
                        i[3].unpopulate_embeddings()
                        if params.combination_type == "ngram-word-lstm":
                            i[4].representation = None
                            i[5].representation = None

                if params.evaluate:
                    v = evaluate_all(self, ngram_words, word_words, params)

                if params.save:
                    if v > old_v:
                        old_v = v
                        self.save_params(params.outfile + '.pickle',
                                         (ngram_words, word_words))

                print('Epoch', eidx + 1, 'Cost', cost)

        except KeyboardInterrupt:
            print("Training interupted")

        end_time = time.time()
        print("total time:", (end_time - start_time))