Example #1
    def __init__(self, data_dir='data/mdbt'):
        Tracker.__init__(self)
        # data profile
        self.data_dir = data_dir
        self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
        self.word_vectors_url = os.path.join(
            self.data_dir, 'word-vectors/paragram_300_sl999.txt')
        self.training_url = os.path.join(self.data_dir, 'data/train.json')
        self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
        self.testing_url = os.path.join(self.data_dir, 'data/test.json')
        self.model_url = os.path.join(self.data_dir, 'models/model-1')
        self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
        self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
        self.kb_url = os.path.join(self.data_dir, 'data/')  # not used
        self.train_model_url = os.path.join(self.data_dir,
                                            'train_models/model-1')
        self.train_graph_url = os.path.join(self.data_dir,
                                            'train_graph/graph-1')

        print('Configuring MDBT model...')
        self.word_vectors = load_word_vectors(self.word_vectors_url)

        # Load the ontology and extract the feature vectors
        self.ontology, self.ontology_vectors, self.slots = load_ontology(
            self.ontology_url, self.word_vectors)

        # Load and process the test data used by the tracker
        self.dialogues, self.actual_dialogues = load_woz_data(
            self.testing_url, self.word_vectors, self.ontology)
        self.no_dialogues = len(self.dialogues)

        self.model_variables = model_definition(self.ontology_vectors,
                                                len(self.ontology),
                                                self.slots,
                                                num_hidden=None,
                                                bidir=True,
                                                net_type=None,
                                                test=True,
                                                dev='cpu')
        self.state = init_state()
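        # TensorFlow 1.x session configuration: grow GPU memory on demand and
        # allow ops without a GPU kernel to fall back to the CPU.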
        _config = tf.ConfigProto()
        _config.gpu_options.allow_growth = True
        _config.allow_soft_placement = True
        self.sess = tf.Session(config=_config)
        self.param_restored = False
        self.det_dic = {}
        for domain, dic in REF_USR_DA.items():
            for key, value in dic.items():
                assert '-' not in key
                self.det_dic[key.lower()] = key + '-' + domain
                self.det_dic[value.lower()] = key + '-' + domain
        self.value_dict = json.load(
            open(os.path.join(self.data_dir, '../multiwoz/value_dict.json')))
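A minimal usage sketch for the constructor above. The class name MDBTTracker is an assumption about the surrounding module and may differ in the actual code base:

# Hypothetical usage; the class name MDBTTracker is an assumption.
tracker = MDBTTracker(data_dir='data/mdbt')  # loads word vectors and the ontology, builds the TF graph
print(tracker.no_dialogues)                  # number of test dialogues loaded in the constructor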
Example #2
    def train(self):
        """
            Train the model.
            Model saved to
        """
        num_hid, bidir, net_type, n2p, batch_size, model_url, graph_url, dev = \
                None, True, None, None, None, None, None, None
        global train_batch_size, MODEL_URL, GRAPH_URL, device, TRAIN_MODEL_URL, TRAIN_GRAPH_URL

        if batch_size:
            train_batch_size = batch_size
            print("Setting up the batch size to {}.........................".
                  format(batch_size))
        if model_url:
            TRAIN_MODEL_URL = model_url
            print("Setting up the model url to {}.........................".
                  format(TRAIN_MODEL_URL))
        if graph_url:
            TRAIN_GRAPH_URL = graph_url
            print("Setting up the graph url to {}.........................".
                  format(TRAIN_GRAPH_URL))

        if dev:
            device = dev
            print(
                "Setting up the device to {}.........................".format(
                    device))

        # 1 Load and process the input data including the ontology
        # Load the word embeddings
        word_vectors = load_word_vectors(self.word_vectors_url)

        # Load the ontology and extract the feature vectors
        ontology, ontology_vectors, slots = load_ontology(
            self.ontology_url, word_vectors)

        # Load and process the training data
        dialogues, _ = load_woz_data(self.training_url, word_vectors, ontology)
        no_dialogues = len(dialogues)

        # Load and process the validation data
        val_dialogues, _ = load_woz_data(self.validation_url, word_vectors,
                                         ontology)

        # Generate the validation batch data
        val_data = generate_batch(val_dialogues, 0, len(val_dialogues),
                                  len(ontology))
        val_iterations = int(len(val_dialogues) / train_batch_size)

        # 2 Initialise and set up the model graph
        # Initialise the model
        graph = tf.Graph()
        with graph.as_default():
            model_variables = model_definition(ontology_vectors,
                                               len(ontology),
                                               slots,
                                               num_hidden=num_hid,
                                               bidir=bidir,
                                               net_type=net_type,
                                               dev=device)
            (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels,
             domain_labels, domain_accuracy, slot_accuracy, value_accuracy,
             value_f1, train_step, keep_prob, _, _, _) = model_variables
            [precision, recall, value_f1] = value_f1
            saver = tf.train.Saver()
            if device == 'gpu':
                config = tf.ConfigProto(allow_soft_placement=True)
                config.gpu_options.allow_growth = True
            else:
                config = tf.ConfigProto(device_count={'GPU': 0})

            sess = tf.Session(config=config)
            if os.path.exists(TRAIN_MODEL_URL + ".index"):
                saver.restore(sess, TRAIN_MODEL_URL)
                print("Loading from an existing model {} ....................".
                      format(TRAIN_MODEL_URL))
            else:
                if not os.path.exists(TRAIN_MODEL_URL):
                    os.makedirs('/'.join(TRAIN_MODEL_URL.split('/')[:-1]))
                    os.makedirs('/'.join(TRAIN_GRAPH_URL.split('/')[:-1]))
                init = tf.global_variables_initializer()
                sess.run(init)
                print("Creating new model parameters...")
            merged = tf.summary.merge_all()
            val_accuracy = tf.summary.scalar('validation_accuracy',
                                             value_accuracy)
            val_f1 = tf.summary.scalar('validation_f1_score', value_f1)
            train_writer = tf.summary.FileWriter(TRAIN_GRAPH_URL, graph)
            train_writer.flush()

        # 3 Perform an epoch of training
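        # no_epochs, batches_per_eval and start_batch are assumed to be
        # module-level configuration values defined outside this snippet.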
        last_update = -1
        best_f_score = -1
        for epoch in range(no_epochs):

            batch_size = train_batch_size
            sys.stdout.flush()
            iterations = math.ceil(no_dialogues / train_batch_size)
            start_time = time.time()
            val_i = 0
            shuffle(dialogues)
            for batch_id in range(iterations):

                if batch_id == iterations - 1 and no_dialogues % train_batch_size != 0:
                    batch_size = no_dialogues % train_batch_size

                (batch_user, batch_sys, batch_labels, batch_domain_labels,
                 batch_user_uttr_len, batch_sys_uttr_len,
                 batch_no_turns) = generate_batch(dialogues, batch_id,
                                                  batch_size, len(ontology))

                [_, summary, da, sa, va, vf, pr, re] = sess.run(
                    [
                        train_step, merged, domain_accuracy, slot_accuracy,
                        value_accuracy, value_f1, precision, recall
                    ],
                    feed_dict={
                        user: batch_user,
                        sys_res: batch_sys,
                        labels: batch_labels,
                        domain_labels: batch_domain_labels,
                        user_uttr_len: batch_user_uttr_len,
                        sys_uttr_len: batch_sys_uttr_len,
                        no_turns: batch_no_turns,
                        keep_prob: 0.5
                    })

                print(
                    "Domain accuracy {:.2f}, slot accuracy {:.2f}, value accuracy {:.2f}, "
                    "f1_score {:.2f}, precision {:.2f}, recall {:.2f} for batch {}".format(
                        da, sa, va, vf, pr, re, batch_id + iterations * epoch))

                train_writer.add_summary(
                    summary, start_batch + batch_id + iterations * epoch)

                # ================================ VALIDATION ==============================================

                if batch_id % batches_per_eval == 0 or batch_id == 0:
                    if batch_id == 0:
                        print("Batch", "0", "to", batch_id, "took",
                              round(time.time() - start_time, 2), "seconds.")

                    else:
                        print("Batch",
                              batch_id + iterations * epoch - batches_per_eval,
                              "to", batch_id + iterations * epoch, "took",
                              round(time.time() - start_time, 3), "seconds.")
                        start_time = time.time()

                    stime = time.time()
                    _, _, v_acc, f1_score, sm1, sm2 = evaluate_model(
                        sess, model_variables, val_data,
                        [val_accuracy, val_f1], batch_id, val_i)
                    val_i += 1
                    val_i %= val_iterations
                    train_writer.add_summary(
                        sm1, start_batch + batch_id + iterations * epoch)
                    train_writer.add_summary(
                        sm2, start_batch + batch_id + iterations * epoch)
                    current_metric = f1_score
                    print(" Validation metric:", round(current_metric,
                                                       5), " eval took",
                          round(time.time() - stime, 2), "last update at:",
                          last_update, "/", iterations)

                    # and if we got a new high score for validation f-score, we need to save the parameters:
                    if current_metric > best_f_score:
                        last_update = batch_id + iterations * epoch + 1
                        print(
                            "\n ====================== New best validation metric:",
                            round(current_metric,
                                  4), " - saving these parameters. Batch is:",
                            last_update, "/", iterations,
                            "---------------- ===========  \n")

                        best_f_score = current_metric

                        saver.save(sess, TRAIN_MODEL_URL)

            print("The best parameters achieved a validation metric of",
                  round(best_f_score, 4))
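A minimal sketch of how the training routine above might be driven, assuming both methods belong to the same tracker class; the class name MDBTTracker is an assumption, and the TensorBoard log directory depends on where TRAIN_GRAPH_URL points:

# Hypothetical usage; class name and paths are assumptions.
tracker = MDBTTracker(data_dir='data/mdbt')
tracker.train()  # checkpoints to TRAIN_MODEL_URL whenever the validation F1 improves
# Training/validation summaries can then be inspected with TensorBoard, e.g.:
#   tensorboard --logdir <directory containing TRAIN_GRAPH_URL>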