コード例 #1
0
    def train_dtRNN(self):
        """Train the sentence-embedding model on SICK, testing and checkpointing each epoch.

        Loads the pre-processed SICK dataset through ``self.load_dataset`` and then,
        for each of ``self.epochs`` epochs:

        * shuffles the parallel dataset lists together,
        * trains (``self.training``) on the first ``self.n`` examples,
        * tests (``self.testing``) on the remaining examples, writing logs to a
          ``txt`` file under ``SentEmbd/`` in the logs directory,
        * saves ``self.sent_embd``'s parameters to a ``pkl`` file under
          ``SentEmbd/`` in the save-states directory. The file name encodes the
          run settings and the formatted accuracy.
        """
        print("Loading Pre-processed SICK dataset")

        sick_path = path_utils.get_sick_path()
        self.load_dataset(sick_path)

        print("Load Complete")

        for epoch_val in range(self.epochs):
            epoch_num = epoch_val + 1  # 1-based epoch number for messages/file names

            # Shuffle the parallel lists together so each sentence pair keeps its
            # relatedness score. sick_text is still passed so the shuffle call is
            # unchanged, but its shuffled copy is not used below.
            sent_tree_set1, sent_tree_set2, relatedness_scores, _ = shuffle(
                self.sent_tree_set1, self.sent_tree_set2,
                self.relatedness_scores, self.sick_text)

            # The first self.n shuffled examples form the training portion...
            self.training(sent_tree_set1[:self.n], sent_tree_set2[:self.n],
                          relatedness_scores[:self.n], epoch_val)

            # One timestamp shared by this epoch's log and checkpoint names.
            z = str(datetime.datetime.now())
            log_name = utils.get_file_name("txt",
                                           sentEmbdType=self.SentEmbd_type,
                                           epochNum=epoch_num,
                                           ntp=self.n,
                                           hiddenDims=self.hid_dim,
                                           timestamp=z)
            logs_path = path_utils.get_logs_path('SentEmbd/' + log_name)

            print("Testing")

            # ...and the remaining examples are held out for testing.
            acc = self.testing(sent_tree_set1[self.n:],
                               sent_tree_set2[self.n:],
                               relatedness_scores[self.n:], logs_path)
            # NOTE(review): with no presentation type, '.3' switches to scientific
            # notation once the rounded value reaches 100 (100.0 -> '1e+02').
            # If acc is a percentage, '.3g' or '.1f' may be what was intended.
            # Confirm before changing, because it alters checkpoint file names.
            acc = "{0:.3}".format(acc) + "%"

            print("Accuracy after epoch %d is %s" % (epoch_num, acc))

            state_name = utils.get_file_name(
                "pkl",
                sentEmbdType=self.SentEmbd_type,
                epochNum=epoch_num,
                ntp=self.n,
                dep_len=56,
                word_vector_size=200,
                dim=self.hid_dim,
                accuracy=acc,
                timestamp=z,
            )
            save_path = path_utils.get_save_states_path('SentEmbd/' + state_name)

            # NOTE(review): this passes the total epoch count (self.epochs), not
            # the current epoch_num. Confirm which value save_params expects.
            self.sent_embd.save_params(save_path, self.epochs)
コード例 #2
0
ファイル: test.py プロジェクト: virajshastri97/EruditeX
def test_get_state_file_name():
    """Check that ``utils.get_file_name`` builds the expected state file name.

    As pinned by ``required``, the keyword arguments appear sorted by name as
    ``key:value`` pairs joined by ``__``. Spaces in values become underscores,
    and the extension is appended.
    """
    from Helpers import utils

    filename = utils.get_file_name(extension='pkl',
                                   first_name='mehmood shakeel deshmukh',
                                   username='meshde',
                                   age=22)

    # The username passed above must match the one embedded in ``required``.
    # A masked placeholder there would make this comparison always fail.
    required = 'age:22__first_name:mehmood_shakeel_deshmukh__username:meshde.pkl'
    assert filename == required
コード例 #3
0
def train_extraction_module(inp_dim, hid_dim, epochs,
    initialization='glorot_normal', optimization='adam', threshold=0.5,
    compressed_dataset=False, train_size=0.75, debug=False):
    """Train an ``AnsSelect`` answer-extraction model on the bAbI dataset.

    The dataset is split into training and testing portions once, before
    training starts, so held-out examples are never trained on. Each epoch
    trains on the training portion, computes the F1 score on the testing
    portion, prints it, and saves the model parameters under a file name
    built from the run settings.

    Args:
        inp_dim: input dimension, passed to ``AnsSelect``.
        hid_dim: hidden dimension, passed to ``AnsSelect``.
        epochs: number of training epochs.
        initialization: initialization name, passed to ``AnsSelect``.
        optimization: optimization name, passed to ``AnsSelect``.
        threshold: forwarded to the test routine (presumably a decision
            threshold on model scores; confirm against the test helpers).
        compressed_dataset: passed to ``AnswerExtract.get_babi_dataset``; also
            selects the compressed train/test routines.
        train_size: forwarded as ``train_size`` to ``train_test_split``.
        debug: if true, only the first 100 examples of the dataset are used.
    """
    model = AnsSelect(
        inp_dim=inp_dim,
        hid_dim=hid_dim,
        initialization=initialization,
        optimization=optimization,
    )

    dataset = AnswerExtract.get_babi_dataset(compressed_dataset)
    if debug:
        dataset = dataset[:100]

    # Choose the train/test routines once rather than branching every epoch.
    if compressed_dataset:
        train_fn, test_fn = train_compressed_dataset, test_compressed_dataset
    else:
        train_fn, test_fn = train_normal_dataset, test_normal_dataset

    # Split once. Re-splitting inside the epoch loop let examples evaluated in
    # one epoch be trained on in another (train/test leakage), inflating F1.
    training_data, testing_data = train_test_split(
        dataset,
        train_size=train_size,
        shuffle=True,
    )

    for epoch in tqdm(range(epochs), unit='epoch', desc='Epochs'):
        train_fn(model, training_data)
        y_true, y_pred = test_fn(model, testing_data, threshold)

        score = f1_score(y_true, y_pred)
        print("Epoch: {} F1-score: {}".format(epoch, score))

        filename = utils.get_file_name(
            epoch_count=epoch,
            f1_score=score,
            initialization=initialization,
            optimization=optimization,
            inp_dim=inp_dim,
            hid_dim=hid_dim,
            threshold=threshold,
            extension='pkl',
        )
        model.save_params(filename, epochs=epoch)
コード例 #4
0
ファイル: test.py プロジェクト: virajshastri97/EruditeX
def test_configurations():
    """Round-trip a state file name through ``create_config``/``get_config``.

    Stores ``state_name`` with ``create_config(..., 'test.cfg')`` and reads it
    back with ``get_config('test.cfg')``. It then checks two things: the
    ``'state'`` entry equals the original name, and the remaining entries,
    passed to ``get_file_name`` with extension ``pkl``, rebuild the same name.
    """
    from Helpers.deployment_utils import create_config, get_config
    from Helpers.utils import get_file_name

    cfg_path = 'test.cfg'
    state_name = 'age:22__name:mehmood__time:10:12:30__username:meshde.pkl'

    # NOTE(review): nothing here removes cfg_path afterwards. If create_config
    # writes a file, it is left behind in the working directory; confirm.
    create_config(state_name, cfg_path)
    loaded = get_config(cfg_path)

    assert 'state' in loaded
    assert loaded['state'] == state_name

    # Without 'state', the loaded entries should be exactly the key/value
    # pairs encoded in the file name.
    del loaded['state']
    rebuilt = get_file_name(extension='pkl', **loaded)

    assert rebuilt == state_name