# Shared imports assumed by the examples below. Project-specific names (DataHandler,
# DataPolicy, BaseTrainer, BertLSTMModel, SimpleTextLogisticModel, WeightEstimatorModel,
# LayeredLearner, ComplicatedTextModel, SimpleTextModel, ElmoTqaEnwikiAnalyzer,
# BertTqaEnwikiAnalyzer, QrelReader, NumpyArrayPipelineReader, PipeEnum, ElmoEmbedder,
# BertModel) come from the surrounding codebase and its dependencies.
import math
from random import shuffle

import numpy as np
import torch
from torch import Tensor
from torch.nn import Parameter
from scipy.special import expit
from sklearn.model_selection import train_test_split

    def load_test(self):
        tqa_matrix = np.load("bert_tqa.npy")
        sample_matrix = np.load("bert_sample.npy")
        label_matrix = np.load("bert_label.npy")

        tqas_train, tqas_test, \
        samples_train, samples_test, \
        labels_train, labels_test = train_test_split(tqa_matrix, sample_matrix, label_matrix,
                                                     test_size=0.20, random_state=422)


        # labels_train = np.where(labels_train == 0, -1, 1)
        # labels_test = np.where(labels_test == 0, -1, 1)


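        # BertLSTMModel is passed as a class; BaseTrainer apparently instantiates it.
        # The model emits raw logits, so BCEWithLogitsLoss applies the sigmoid internally
        # and evaluation below pushes predictions through expit before thresholding.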
        model = BertLSTMModel
        loss_function = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.train_handler = DataHandler(predictors={'p1' : tqas_train, 'p2' : samples_train}, response=labels_train, policy=DataPolicy.ALL_DATA)
        self.trainer = BaseTrainer(data_handler=self.train_handler, model=model, loss_function=loss_function, lr=0.001)

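        # Alternate short training bursts (n=100) with full evaluations on the test and
        # train splits, for 200 rounds.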
        for i in range(200):
            self.trainer.model.is_train = True
            self.trainer.train(weight_decay=0.0000, n=100)
            self.trainer.model.is_train = False

            self.train_handler2 = DataHandler(predictors={'p1' : tqas_test, 'p2' : samples_test}, response=labels_test, policy=DataPolicy.ALL_DATA)


            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)
            results = np.squeeze(results)

            acc = (np.squeeze(results) == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            print("Acc Test: {}".format(acc))

            self.train_handler2 = DataHandler(predictors={'p1' : tqas_train, 'p2' : samples_train}, response=labels_train, policy=DataPolicy.ALL_DATA)


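            # Confusion-matrix counts; multiplying boolean masks acts as an elementwise AND.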
            tp = (((results == 1) * (labels_test == 1)).sum())
            tn = ((results == 0) * (labels_test == 0)).sum()
            fn = ((results == 0) * (labels_test == 1)).sum()
            fp = ((results == 1) * (labels_test == 0)).sum()
            print("tp: {}, tn: {}, fn: {}, fp: {}, recall: {}, precision: {}".format(tp, tn, fn, fp, tp / (tp + fn), tp / (tp + fp)))


            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)

            acc = (np.squeeze(results) == np.squeeze(labels_train)).sum() / labels_train.shape[0]
            print("Acc Train: {}".format(acc))
Example #2
    def train(self):
        print("Training")
        torch.manual_seed(10)
        self.model = ComplicatedTextModel
        ys = Tensor(self.labels_train)
        print(ys.shape)

        loss_function = torch.nn.BCELoss(reduction='mean')
        predictors = {"p1" : self.p1s_train, "p2" : self.samples_train, "c" : self.contexts_train}
        self.train_handler = DataHandler(predictors=predictors, response=self.labels_train, policy=DataPolicy.ALL_DATA)
        self.trainer = BaseTrainer(data_handler=self.train_handler, model=ComplicatedTextModel, loss_function=loss_function, lr=0.001)
        # self.trainer = SimpleTrainer(self.samples_train, ys.view(-1, 1), self.model, 0.001)
        # self.trainer.train_text(self.tqas_train, self.contexts_train)
        self.trainer.train(1200, weight_decay=0.000)
Example #3
    def prepare_test_train(self):
        np.random.seed(10)
        dq = np.load("save/divided_by_query.npy", allow_pickle=True)  # object array; newer numpy requires allow_pickle
        contexts, tqas, samples, labels = dq

        n_queries = contexts.shape[0]
        q_indices = list(range(n_queries))

        shuffle(q_indices)
        train_i = q_indices[20:]
        test_i = q_indices[:20]

        self.test_contexts, self.test_tqas, self.test_samples, self.test_labels = \
            contexts[test_i], tqas[test_i], samples[test_i], labels[test_i]
        self.train_contexts, self.train_tqas, self.train_samples, self.train_labels = \
            contexts[train_i], tqas[train_i], samples[train_i], labels[train_i]


        loss_function = torch.nn.BCELoss(reduction='mean')

        data_handlers = []


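        # First stage: one DataHandler per training query, so LayeredLearner can fit an
        # independent SimpleTextLogisticModel on each query's samples.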
        for i in range(self.train_contexts.shape[0]):
            predictors = {"xs" : self.train_samples[i] }
            response = self.train_labels[i]
            data_handler = DataHandler(predictors=predictors, response=response, policy=DataPolicy.ALL_DATA)
            data_handlers.append(data_handler)


            # trainer = BaseTrainer(data_handler=data_handler, model=SimpleTextLogisticModel, loss_function=loss_function)
        lt = LayeredLearner(data_handlers=data_handlers, model=SimpleTextLogisticModel, loss_function=loss_function, trainer_class=BaseTrainer, weight_decay=0.05)


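        # Collect each base layer's learned weight vector; stacked, they become the
        # regression targets for the second-stage WeightEstimatorModel.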
        weights = []
        for layer in lt.base_layers:
            w = layer.trainer.model.linear.weight.detach().numpy()
            # bias = layer.trainer.model.linear.bias.detach().numpy()
            # weights.append(np.concatenate([w, np.expand_dims(bias, 1)], 1))
            weights.append(w)

        # weights = [i.trainer.model.linear.weight.detach().numpy() for i in lt.base_layers]
        weights = np.concatenate(weights, 0)
        # ss = StandardScaler(with_mean=False).fit(weights)
        # weights = ss.transform(weights)

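        # Per-query features: the first column of the context and tqa matrices, plus the
        # mean of the sample embeddings.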
        contexts = [np.expand_dims(i.T[0], 1) for i in self.train_contexts]
        tqas = [np.expand_dims(i.T[0], 1) for i in self.train_tqas]
        squished_samples = [i.mean(1) for i in self.train_samples]

        # print(contexts[0].shape)
        mse = torch.nn.MSELoss(reduction='mean')


        # def my_loss(y_pred, y_actual):
        #     # print(y_pred.shape)
        #     # print(y_actual.shape)
        #     return (((y_pred - y_actual) ** 2).sum(1) ** 0.5).mean()

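        # Second stage: regress from the query features onto the stacked per-query weights.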
        predictors = {"p1" : np.hstack(tqas), "c" : np.hstack(contexts), "squished_samples" : np.vstack(squished_samples)}
        data_handler = DataHandler(predictors=predictors, response=weights, model_parameters={"n_weights" : weights.shape[1]}, response_shape=(weights.shape))

        trainer = BaseTrainer(data_handler=data_handler, model=WeightEstimatorModel, loss_function=mse, lr=0.01)
        trainer.train(n=6000, weight_decay=0.0)

        total_acc = 0.0
        total_counter = 0
        trainer.model.is_training = False


        # predictors = {"p1" : tqas, "c" : contexts[0], "xs" : samples}
        # data_handler = DataHandler(predictors=predictors, response=weights, model_parameters={"n_weights" : weights.shape[1]})

        # m = WeightEstimatorModel(data_handler)
        # best_weights = trainer.model(data_handler).detach().numpy().T

        # print(((best_weights - weights.T) ** 2).sum(0) ** 0.5)

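        # Evaluation: predict a weight vector for each held-out query, install it in a
        # fresh SimpleTextLogisticModel, and score that model's label predictions.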
        for i in range(self.test_contexts.shape[0]):
        # for i in range(self.train_contexts.shape[0]):
            context = np.expand_dims(self.test_contexts[i].T[0], 1)
            tqa = np.expand_dims(self.test_tqas[i].T[0], 1)
            samples = self.test_samples[i]
            labels = self.test_labels[i]

            # context = np.expand_dims(self.train_contexts[i].T[0], 1)
            # tqa = np.expand_dims(self.train_tqas[i].T[0], 1)
            # samples = self.train_samples[i]
            # labels = self.train_labels[i]

            squished_samples = samples.mean(1)
            predictors = {"p1" : tqa, "c" : context, "xs" : samples, "squished_samples" : squished_samples}
            data_handler = DataHandler(predictors=predictors, response=weights, model_parameters={"n_weights" : weights.shape[1]})
            # m = WeightEstimatorModel(data_handler)
            m = trainer.model
            best_weights = m(data_handler).detach().numpy().T

            # bias = best_weights[-1]
            # best_weights = best_weights[0:best_weights.shape[0] - 1]

            lm = SimpleTextLogisticModel(data_handler)
            # lm.linear.bias = Parameter(Tensor(bias.T))

            lm.linear.weight = Parameter(Tensor(best_weights.T))

            results = lm(data_handler).detach().numpy()
            results = np.where(results > 0.5, 1.0, 0.0)
            acc = (labels == np.squeeze(results)).sum() / (labels.shape[0])
            total_acc += acc
            total_counter += 1

        print("Final ACC: {}".format(total_acc / total_counter))
Example #4
    def prepare_test_train2(self):
        np.random.seed(10)
        dq = np.load("save/divided_by_query.npy", allow_pickle=True)  # object array; newer numpy requires allow_pickle
        contexts, p1s, samples, labels = dq

        n_queries = contexts.shape[0]
        q_indices = list(range(n_queries))

        shuffle(q_indices)
        train_i = q_indices[10:]
        test_i = q_indices[:10]

        self.test_contexts, self.test_tqas, self.test_samples, self.test_labels = \
            contexts[test_i], p1s[test_i], samples[test_i], labels[test_i]
        self.train_contexts, self.train_tqas, self.train_samples, self.train_labels = \
            contexts[train_i], p1s[train_i], samples[train_i], labels[train_i]


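        # Flatten the per-query arrays into single training matrices (each .T puts
        # individual samples on the first axis).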
        samples = []
        p1s = []
        labels = []
        contexts = []
        for i in range(self.train_samples.shape[0]):
            samples.append(self.train_samples[i].T)
            p1s.append(self.train_tqas[i].T)
            labels.append(self.train_labels[i])
            contexts.append(self.train_contexts[i].T)

        samples = np.concatenate(samples)
        labels = np.concatenate(labels)
        p1s = np.concatenate(p1s)
        contexts = np.concatenate(contexts)

        print("Training")
        torch.manual_seed(10)
        self.model = SimpleTextModel

        loss_function = torch.nn.BCELoss(reduction='mean')
        predictors = {"p1" : p1s, "p2" : samples, "c" : contexts}
        self.train_handler = DataHandler(predictors=predictors, response=labels, policy=DataPolicy.ALL_DATA)
        self.trainer = BaseTrainer(data_handler=self.train_handler, model=SimpleTextModel, loss_function=loss_function)
        # self.trainer = SimpleTrainer(self.samples_train, ys.view(-1, 1), self.model, 0.001)
        # self.trainer.train_text(self.tqas_train, self.contexts_train)
        self.trainer.train()


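        # Flatten the held-out queries the same way for evaluation.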
        samples = []
        p1s = []
        labels = []
        contexts = []
        for i in range(self.test_samples.shape[0]):
            samples.append(self.test_samples[i].T)
            p1s.append(self.test_tqas[i].T)
            labels.append(self.test_labels[i])
            contexts.append(self.test_contexts[i].T)

        samples = np.concatenate(samples)
        labels = np.concatenate(labels)
        p1s = np.concatenate(p1s)
        contexts = np.concatenate(contexts)

        predictors = {"p1" : p1s, "p2" : samples, "c" : contexts}
        handler = DataHandler(predictors=predictors, response=labels, policy=DataPolicy.ALL_DATA)

        results = self.trainer.model(handler).detach().numpy()
        results = np.where(results > 0.5, 1.0, 0.0)
        acc = (np.squeeze(results) == labels).sum() / labels.shape[0]
        print(acc)
Example #5
class ElmoManager(object):
    def __init__(self, loc, load_bert=True):
        if load_bert:
            self.max_words = 200
            self.analyzer = ElmoTqaEnwikiAnalyzer(loc, self.max_words)
            self.analyzer.elmotize()
            self.model = ElmoEmbedder()
            self.embeddings = {}
            self.max_sentences = 4
            self.embedding_size = 1024

    def get_embedding(self, p1):
        # return [self.model.embed_sentence(i) for i in p1["tokens"]]
        sentences = [
            i[0] for i in self.model.embed_sentences(p1["tokens"][0:self.max_sentences])
        ]
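        # embed_sentences yields one (layers, words, 1024) array per sentence; i[0] above
        # keeps the first ELMo layer. Averaging over words gives one vector per sentence.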
        for idx in range(len(sentences)):
            sentence = sentences[idx]
            # if sentence.shape[0] < self.max_words:
            #     word_diff = self.max_words - sentence.shape[0]
            #     zshape = (word_diff, sentence.shape[1])
            #     sentence = np.concatenate([sentence, np.zeros(zshape)], 0)
            sentences[idx] = sentence.mean(0)

        sentences = np.asarray(sentences)

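        # Zero-pad so every passage contributes exactly max_sentences sentence vectors.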
        if sentences.shape[0] < self.max_sentences:
            sentence_diff = self.max_sentences - sentences.shape[0]
            # zshape = (sentence_diff, self.max_words, self.embedding_size)
            zshape = (sentence_diff, self.embedding_size)
            sentences = np.concatenate([sentences, np.zeros(zshape)], 0)

        # return np.asarray(sentences)
        return sentences

    def run_test(self):
        t = len(self.analyzer.data)
        tqa_matrix = []
        sample_matrix = []
        label_matrix = []
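        # Pair each query's TQA embedding with every positive (enwiki, label 1) and
        # negative (label 0) passage embedding.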
        for idx, example in enumerate(self.analyzer.data):
            print("{} for {}".format(idx, t))
            # if idx > 1:
            #     break
            try:
                qid = example["qid"]
                qd = self.analyzer.bert_data[qid]
                tqa = qd["tqa"][0]
                tqa_embedding = self.get_embedding(tqa)

                for p in qd["enwiki"]:
                    embedding = self.get_embedding(p)
                    tqa_matrix.append(tqa_embedding)
                    label_matrix.append(1)
                    sample_matrix.append(embedding)

                for p in qd["negatives"]:
                    embedding = self.get_embedding(p)
                    tqa_matrix.append(tqa_embedding)
                    label_matrix.append(0)
                    sample_matrix.append(embedding)
            except RuntimeError:
                print("Error")
            except KeyError:
                print("Key Error")

        tqa_matrix = np.asarray(tqa_matrix)
        sample_matrix = np.asarray(sample_matrix)
        label_matrix = np.asarray(label_matrix)

        np.save("elmo_tqa.npy", tqa_matrix)
        np.save("elmo_sample.npy", sample_matrix)
        np.save("elmo_label.npy", label_matrix)

    def load_test(self):
        tqa_matrix = np.load("elmo_tqa.npy")
        sample_matrix = np.load("elmo_sample.npy")
        label_matrix = np.load("elmo_label.npy")

        tqas_train, tqas_test, \
        samples_train, samples_test, \
        labels_train, labels_test = train_test_split(tqa_matrix, sample_matrix, label_matrix,
                                                     test_size=0.05, random_state=422)

        # labels_train = np.where(labels_train == 0, -1, 1)
        # labels_test = np.where(labels_test == 0, -1, 1)

        model = BertLSTMModel
        loss_function = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.train_handler = DataHandler(predictors={
            'p1': tqas_train,
            'p2': samples_train
        },
                                         response=labels_train,
                                         policy=DataPolicy.ALL_DATA)
        self.trainer = BaseTrainer(data_handler=self.train_handler,
                                   model=model,
                                   loss_function=loss_function,
                                   lr=0.001)

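        # Ten rounds of short training bursts (n=5), evaluating on both splits each round.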
        for i in range(10):
            self.trainer.model.is_train = True
            self.trainer.train(weight_decay=0.0000, n=5)
            self.trainer.model.is_train = False

            self.train_handler2 = DataHandler(predictors={
                'p1': tqas_test,
                'p2': samples_test
            },
                                              response=labels_test,
                                              policy=DataPolicy.ALL_DATA)

            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)
            results = np.squeeze(results)

            acc = (np.squeeze(results)
                   == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            print("Acc Test: {}".format(acc))

            self.train_handler2 = DataHandler(predictors={
                'p1': tqas_train,
                'p2': samples_train
            },
                                              response=labels_train,
                                              policy=DataPolicy.ALL_DATA)

            tp = (((results == 1) * (labels_test == 1)).sum())
            tn = ((results == 0) * (labels_test == 0)).sum()
            fn = ((results == 0) * (labels_test == 1)).sum()
            fp = ((results == 1) * (labels_test == 0)).sum()
            print("tp: {}, tn: {}, fn: {}, fp: {}, recall: {}, precision: {}".
                  format(tp, tn, fn, fp, tp / (tp + fn), tp / (tp + fp)))

            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)

            acc = (np.squeeze(results)
                   == np.squeeze(labels_train)).sum() / labels_train.shape[0]
            print("Acc Train: {}".format(acc))

            # results = np.where(results > 0.5, 1, 1)
            #
            # acc = (np.squeeze(results) == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            # print(acc)
            #
            # results = np.where(results > 0.5, 0, 0)
            #
            # acc = (np.squeeze(results) == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            # print(acc)

        torch.save(self.trainer.model, "elmo_model")
Example #6
class BertManager(object):
    def __init__(self, loc, load_bert=True):
        if load_bert:
            self.analyzer = BertTqaEnwikiAnalyzer(loc)
            self.analyzer.bertanize()
            self.model = BertModel.from_pretrained('bert-base-uncased')
            self.model.eval()
            self.embeddings = {}

    def get_bert_logits(self, p1, p2):
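        # Segment ids follow BERT's sentence-pair convention: 0 for p1 tokens, 1 for p2
        # tokens (np.int64 replaces the removed np.long alias).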
        segments = np.concatenate([np.zeros(len(p1), dtype=np.int64), np.ones(len(p2), dtype=np.int64)])
        segments = torch.tensor([segments])
        token_tensor = torch.tensor([p1 + p2])
        with torch.no_grad():
            pred = self.model(token_tensor, segments)
        return pred

    def get_embedding(self, p1):
        # print(p1)
        # print(self.analyzer.tokenizer.convert_ids_to_tokens(p1))
        token_tensors = [i for i in p1["tokens"]]
        token_tensors = torch.tensor(token_tensors)

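        # BertModel returns (sequence_output, pooled_output); [1] is the pooled [CLS]
        # vector per sentence, zero-padded below to four sentence vectors.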
        with torch.no_grad():
            # return self.model(torch.tensor([p1]))[1].shape
            t =  self.model(token_tensors)[1].detach().numpy()
            sdiff = 4 - len(p1["tokens"])
            if sdiff > 0:
                t = np.concatenate([t, np.zeros((sdiff, t.shape[1]))])
            return t

    def get_label(self, p1, p2):
        try:
            logits = self.get_bert_logits(p1, p2)
            return 1 if logits[0][0] > logits[0][1] else 0
        except Exception as e:
            print("Error: {}".format(e))
            return -1


    def get_accuracy(self):
        total_acc = 0.0
        counter = 0
        t = len(self.analyzer.data)
        for idx, example in enumerate(self.analyzer.data):
            print("{} for {}".format(idx, t))
            qid = example["qid"]
            qd = self.analyzer.bert_data[qid]
            tqa = qd["tqa"][0]

            for p in qd["enwiki"]:
                prediction = self.get_label(tqa, p)
                if prediction == -1:
                    continue
                elif prediction == 1:
                    total_acc += 1
                counter += 1.0


            for p in qd["negatives"]:
                prediction = self.get_label(tqa, p)
                if prediction == -1:
                    continue
                elif prediction == 0:
                    total_acc += 1
                counter += 1.0

            if counter != 0:
                print("Running ACC: {}".format(total_acc / counter))


        print("Final ACC: {}".format(total_acc / counter))


    def run_test(self):
        t = len(self.analyzer.data)
        tqa_matrix = []
        sample_matrix = []
        label_matrix = []
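        # Same (tqa, passage, label) pairing as ElmoManager.run_test above.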
        for idx, example in enumerate(self.analyzer.data):
            print("{} for {}".format(idx, t))
            try:
                qid = example["qid"]
                qd = self.analyzer.bert_data[qid]
                tqa = qd["tqa"][0]
                tqa_embedding = self.get_embedding(tqa)

                for p in qd["enwiki"]:
                    embedding = self.get_embedding(p)
                    tqa_matrix.append(tqa_embedding)

                    # rel = p["rel"]

                    label_matrix.append(1)

                    sample_matrix.append(embedding)

                for p in qd["negatives"]:
                    embedding = self.get_embedding(p)
                    tqa_matrix.append(tqa_embedding)
                    label_matrix.append(0)
                    sample_matrix.append(embedding)
            except RuntimeError:
                print("Error")
            except KeyError:
                print("Key Error")

        tqa_matrix = np.asarray(tqa_matrix)
        sample_matrix = np.asarray(sample_matrix)
        label_matrix = np.asarray(label_matrix)

        np.save("bert_tqa.npy", tqa_matrix)
        np.save("bert_sample.npy", sample_matrix)
        np.save("bert_label.npy", label_matrix)

    def sigmoid(self, x):
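        # Scalar sigmoid; load_test below uses the vectorized scipy.special.expit instead.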
        return 1 / (1 + math.exp(-x))

    def load_test(self):
        tqa_matrix = np.load("bert_tqa.npy")
        sample_matrix = np.load("bert_sample.npy")
        label_matrix = np.load("bert_label.npy")

        tqas_train, tqas_test, \
        samples_train, samples_test, \
        labels_train, labels_test = train_test_split(tqa_matrix, sample_matrix, label_matrix,
                                                     test_size=0.20, random_state=422)


        # labels_train = np.where(labels_train == 0, -1, 1)
        # labels_test = np.where(labels_test == 0, -1, 1)


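        # Same logits-plus-BCEWithLogitsLoss setup as the standalone load_test above.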
        model = BertLSTMModel
        loss_function = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.train_handler = DataHandler(predictors={'p1' : tqas_train, 'p2' : samples_train}, response=labels_train, policy=DataPolicy.ALL_DATA)
        self.trainer = BaseTrainer(data_handler=self.train_handler, model=model, loss_function=loss_function, lr=0.001)

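        # 200 rounds of 100-step training bursts with evaluation after each burst.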
        for i in range(200):
            self.trainer.model.is_train = True
            self.trainer.train(weight_decay=0.0000, n=100)
            self.trainer.model.is_train = False

            self.train_handler2 = DataHandler(predictors={'p1' : tqas_test, 'p2' : samples_test}, response=labels_test, policy=DataPolicy.ALL_DATA)


            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)
            results = np.squeeze(results)

            acc = (np.squeeze(results) == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            print("Acc Test: {}".format(acc))

            self.train_handler2 = DataHandler(predictors={'p1' : tqas_train, 'p2' : samples_train}, response=labels_train, policy=DataPolicy.ALL_DATA)


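            # Confusion-matrix counts on the test split, as in the standalone load_test.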
            tp = (((results == 1) * (labels_test == 1)).sum())
            tn = ((results == 0) * (labels_test == 0)).sum()
            fn = ((results == 0) * (labels_test == 1)).sum()
            fp = ((results == 1) * (labels_test == 0)).sum()
            print("tp: {}, tn: {}, fn: {}, fp: {}, recall: {}, precision: {}".format(tp, tn, fn, fp, tp / (tp + fn), tp / (tp + fp)))


            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)

            acc = (np.squeeze(results) == np.squeeze(labels_train)).sum() / labels_train.shape[0]
            print("Acc Train: {}".format(acc))
Example #7
class ElmoRedux(object):
    def __init__(self, loc, array_loc: str, qrel_loc: str):
        self.query_parser = QrelReader(qrel_loc)
        print("Retrieving vectors")
        input = {PipeEnum.IN_LOC.value: array_loc}
        self.vector_reader = NumpyArrayPipelineReader(**input)
        self.vector_reader.run()

        self.labels = []
        self.positives = []
        self.negatives = []

    def run_test(self):
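        # NOTE: this method references self.analyzer and self.get_embedding, neither of
        # which is set up in ElmoRedux.__init__; it appears to assume the ElmoManager setup.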
        t = len(self.analyzer.data)
        tqa_matrix = []
        sample_matrix = []
        label_matrix = []
        for idx, example in enumerate(self.analyzer.data):
            print("{} for {}".format(idx, t))
            # if idx > 1:
            #     break
            try:
                qid = example["qid"]
                qd = self.analyzer.bert_data[qid]
                tqa = qd["tqa"][0]
                tqa_embedding = self.get_embedding(tqa)

                for p in qd["enwiki"]:
                    embedding = self.get_embedding(p)
                    tqa_matrix.append(tqa_embedding)
                    label_matrix.append(1)
                    sample_matrix.append(embedding)

                for p in qd["negatives"]:
                    embedding = self.get_embedding(p)
                    tqa_matrix.append(tqa_embedding)
                    label_matrix.append(0)
                    sample_matrix.append(embedding)
            except RuntimeError:
                print("Error")
            except KeyError:
                print("Key Error")

        tqa_matrix = np.asarray(tqa_matrix)
        sample_matrix = np.asarray(sample_matrix)
        label_matrix = np.asarray(label_matrix)

        np.save("elmo_tqa.npy", tqa_matrix)
        np.save("elmo_sample.npy", sample_matrix)
        np.save("elmo_label.npy", label_matrix)

    def load_test(self):
        tqa_matrix = np.load("elmo_tqa.npy")
        sample_matrix = np.load("elmo_sample.npy")
        label_matrix = np.load("elmo_label.npy")

        tqas_train, tqas_test, \
        samples_train, samples_test, \
        labels_train, labels_test = train_test_split(tqa_matrix, sample_matrix, label_matrix,
                                                     test_size=0.05, random_state=422)

        # labels_train = np.where(labels_train == 0, -1, 1)
        # labels_test = np.where(labels_test == 0, -1, 1)

        model = BertLSTMModel
        loss_function = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.train_handler = DataHandler(predictors={
            'p1': tqas_train,
            'p2': samples_train
        },
                                         response=labels_train,
                                         policy=DataPolicy.ALL_DATA)
        self.trainer = BaseTrainer(data_handler=self.train_handler,
                                   model=model,
                                   loss_function=loss_function,
                                   lr=0.001)

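        # Identical train/evaluate cycle to ElmoManager.load_test above.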
        for i in range(10):
            self.trainer.model.is_train = True
            self.trainer.train(weight_decay=0.0000, n=5)
            self.trainer.model.is_train = False

            self.train_handler2 = DataHandler(predictors={
                'p1': tqas_test,
                'p2': samples_test
            },
                                              response=labels_test,
                                              policy=DataPolicy.ALL_DATA)

            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)
            results = np.squeeze(results)

            acc = (np.squeeze(results)
                   == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            print("Acc Test: {}".format(acc))

            self.train_handler2 = DataHandler(predictors={
                'p1': tqas_train,
                'p2': samples_train
            },
                                              response=labels_train,
                                              policy=DataPolicy.ALL_DATA)

            tp = (((results == 1) * (labels_test == 1)).sum())
            tn = ((results == 0) * (labels_test == 0)).sum()
            fn = ((results == 0) * (labels_test == 1)).sum()
            fp = ((results == 1) * (labels_test == 0)).sum()
            print("tp: {}, tn: {}, fn: {}, fp: {}, recall: {}, precision: {}".
                  format(tp, tn, fn, fp, tp / (tp + fn), tp / (tp + fp)))

            results = self.trainer.model(self.train_handler2).detach().numpy()
            results = expit(results)
            results = np.where(results >= 0.5, 1, 0)

            acc = (np.squeeze(results)
                   == np.squeeze(labels_train)).sum() / labels_train.shape[0]
            print("Acc Train: {}".format(acc))

            # results = np.where(results > 0.5, 1, 1)
            #
            # acc = (np.squeeze(results) == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            # print(acc)
            #
            # results = np.where(results > 0.5, 0, 0)
            #
            # acc = (np.squeeze(results) == np.squeeze(labels_test)).sum() / labels_test.shape[0]
            # print(acc)

        torch.save(self.trainer.model, "elmo_model")