コード例 #1
0
    def __call__(self, session):
        """Print the model's top-5 next-token predictions for each test case.

        Every test case in the module-level ``all_test_cases`` is fed to the
        model one token at a time, threading the recurrent and attention
        state from step to step.  After the final token, up to five ranked
        predictions are printed as ``<token> ; <score>``.

        Args:
            session: TensorFlow session used to run the evaluation ops.
        """
        # prediction_op is used as a pair: [0] -> scores, [1] -> token ids.
        evals = get_evals([self.prediction_op[0], self.prediction_op[1]],
                          self.model)
        for testcase in all_test_cases:
            if not testcase:
                # Nothing to predict from; without this guard ``results``
                # below would be unbound or left over from the previous case.
                continue
            state, att_states, att_counts = get_initial_state(self.model)
            for token in testcase:
                # map_token yields a (token_id, mask) pair for the token.
                token_id, mask = map_token(self.map, token)
                att_mask = attention_masks(self.attns, np.array([mask]), 1)
                data = (np.array([[token_id]]), np.array([[1]]),
                        np.array([att_mask]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, data, state,
                                                att_states, att_counts)

                results = session.run(evals, feed_dict)
                results, state, att_states, att_counts, _, _ = extract_results(
                    results, evals, 2, self.model)

            probs = results[0]
            predict_ids = results[1]

            # Bound by the number of returned candidates to avoid IndexError.
            for rank in range(min(5, predict_ids.shape[1])):
                print("%s ; %f" % (self.inv_map[predict_ids[0, rank]].replace(
                    "\n", "<newline>"), probs[0, rank]))

            print("\n\n§§§§§§§§§§§§§\n\n")
コード例 #2
0
    def __call__(self, session):
        """Print each test-file token next to the model's prediction for it.

        Test files are loaded with ``pyreader.get_data`` and each one is fed
        to the model token by token.  For every step a line
        ``<current token> ; <predicted token>`` is collected (newlines shown
        as ``<newline>``), and the lines are printed once per test case.

        Args:
            session: TensorFlow session used to run the prediction op.
        """
        test_file_containers = pyreader.get_data(self.data_path,
                                                 self.test_files, 1, self.map)
        data = list(
            zip((list(flatmap(identity_map, c.inputs))
                 for c in test_file_containers),
                (list(flatmap(identity_map, c.masks))
                 for c in test_file_containers)))
        # The evaluated ops do not depend on the test case: build them once.
        evals = get_evals([self.prediction_op], self.model)

        for testcase, var_mask in data:
            print("----------Test case----------\n")
            state, att_states, att_counts = get_initial_state(self.model)
            predicted_tokens = []

            # var_mask is zipped only to keep the original pairing/truncation;
            # this evaluator feeds a constant attention mask of 0.
            for token, _ in zip(testcase, var_mask):
                step_data = (np.array([[token]]), np.array([[1]]),
                             np.array([0]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, step_data, state,
                                                att_states, att_counts)
                results = session.run(evals, feed_dict=feed_dict)
                prediction, state, att_states, att_counts, _, _ = \
                    extract_results(results, evals, 1, self.model)
                predicted_token = self.inverse_map[prediction[0][0]].replace(
                    "\n", "<newline>")
                current_token = self.inverse_map[token].replace(
                    "\n", "<newline>")
                predicted_tokens.append(current_token + " ; " +
                                        predicted_token)

            print("\n".join(predicted_tokens))
コード例 #3
0
ファイル: hooks.py プロジェクト: Hamchin/pycodesuggest
    def __call__(self, sess, epoch, iteration, model, loss, _):
        """Every ``self.at_every_epoch`` epochs, print sampled continuations.

        Runs only on iteration 0 of a qualifying epoch.  Each test case in
        ``self.test_cases`` is fed token by token as a seed.  Every argmax
        prediction made from the last seed token onward is appended to the
        output (``self.sample_length + 1`` tokens in total), and all but the
        final one are fed back in.  The result is printed space-separated.

        Args:
            sess: TensorFlow session.
            epoch: Current epoch number.
            iteration: Iteration within the epoch.
            model, loss, _: Unused; ``self.model`` is used instead.
        """
        if iteration != 0 or epoch % self.at_every_epoch != 0:
            return

        print("Generating %d samples from model" % len(self.test_cases))
        evals = get_evals([self.prediction_op], self.model)
        # The attention mask is the same for every step, so compute it once.
        att_mask = attention_masks(self.attns, [0], 1)
        for testcase in self.test_cases:
            if not testcase:
                # No seed token to start from; output[0] below would raise.
                continue
            output = list(testcase)
            state, att_states, att_counts = get_initial_state(self.model)

            for i in range(len(testcase) + self.sample_length):
                # TODO: Need to determine whether a generated token is a variable?
                # Run it through the parser?
                data = (np.array([[self.map[output[i]]]]), np.array([[1]]),
                        np.array([att_mask]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, data, state,
                                                att_states, att_counts)

                results = sess.run(evals, feed_dict)
                results, state, att_states, att_counts, _, _ = extract_results(
                    results, evals, 1, self.model)

                output_token = self.inverse_map[results[0][0]]

                # Predictions made from the last seed token on are samples.
                if i >= len(testcase) - 1:
                    output.append(output_token)

            print(" ".join(output))
            print("\n\n§§§§§§§§§§§§§§§§\n\n")
コード例 #4
0
ファイル: hooks.py プロジェクト: hedgefair/pycodesuggest
    def __call__(self, sess, epoch, iteration, model, loss, _):
        """Every ``self.at_every_epoch`` epochs, print sampled continuations.

        Runs only on iteration 0 of a qualifying epoch.  Each test case in
        ``self.test_cases`` is fed token by token as a seed.  Every argmax
        prediction made from the last seed token onward is appended to the
        output (``self.sample_length + 1`` tokens in total), and all but the
        final one are fed back in.  The result is printed space-separated.

        Args:
            sess: TensorFlow session.
            epoch: Current epoch number.
            iteration: Iteration within the epoch.
            model, loss, _: Unused; ``self.model`` is used instead.
        """
        if iteration != 0 or epoch % self.at_every_epoch != 0:
            return

        print("Generating %d samples from model" % len(self.test_cases))
        evals = get_evals([self.prediction_op], self.model)
        # The attention mask is the same for every step, so compute it once.
        att_mask = attention_masks(self.attns, [0], 1)
        for testcase in self.test_cases:
            if not testcase:
                # No seed token to start from; output[0] below would raise.
                continue
            output = list(testcase)
            state, att_states, att_counts = get_initial_state(self.model)

            for i in range(len(testcase) + self.sample_length):
                # TODO: Need to determine whether a generated token is a variable?
                # Run it through the parser?
                data = (np.array([[self.map[output[i]]]]), np.array([[1]]), np.array([att_mask]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, data, state, att_states, att_counts)

                results = sess.run(evals, feed_dict)
                results, state, att_states, att_counts, _, _ = extract_results(results, evals, 1, self.model)

                output_token = self.inverse_map[results[0][0]]

                # Predictions made from the last seed token on are samples.
                if i >= len(testcase) - 1:
                    output.append(output_token)

            print(" ".join(output))
            print("\n\n§§§§§§§§§§§§§§§§\n\n")
コード例 #5
0
ファイル: evaluation.py プロジェクト: hedgefair/pycodesuggest
        def run_network(token_id, state, att_states, att_counts):
            """Advance the model by a single token.

            Feeds ``token_id`` with an attention mask built from ``[0]`` and
            returns ``(results, state, att_states, att_counts)``.  The last
            three are the state values produced by ``extract_results`` for
            this step.
            """
            step_inputs = (
                np.array([[token_id]]),
                np.array([[1]]),
                np.array([attention_masks(self.attns, [0], 1)]),
                np.array([1]),
            )
            feed = construct_feed_dict(self.model, step_inputs, state,
                                       att_states, att_counts)
            fetched = session.run(evals, feed)
            (results, next_state, next_att_states, next_att_counts,
             _, _) = extract_results(fetched, evals, 2, self.model)
            return results, next_state, next_att_states, next_att_counts
コード例 #6
0
ファイル: evaluation.py プロジェクト: Hamchin/pycodesuggest
        def run_network(token_id, state, att_states, att_ids, att_counts):
            """Advance the model by a single token.

            Feeds ``token_id`` with an attention mask built from ``[0]`` and
            returns ``(results, state, att_states, att_counts)``.

            NOTE(review): the fourth value from ``extract_results`` appears
            to be the updated attention ids (compare other callers that
            unpack it as ``att_ids``).  It is discarded here, so callers keep
            passing the ``att_ids`` they started with -- confirm intended.
            """
            step_inputs = (
                np.array([[token_id]]),
                np.array([[1]]),
                np.array([attention_masks(self.attns, [0], 1)]),
                np.array([1]),
                np.array([1]),
            )
            feed, _ = construct_feed_dict(self.model, step_inputs, state,
                                          att_states, att_ids, att_counts)
            fetched = session.run(evals, feed)
            (results, next_state, next_att_states, _, _,
             next_att_counts, _) = extract_results(fetched, evals, 2,
                                                   self.model)
            return results, next_state, next_att_states, next_att_counts
コード例 #7
0
ファイル: evaluation.py プロジェクト: hedgefair/pycodesuggest
    def profile(self, session):
        """Run one fully traced step per batch and dump a Chrome trace.

        For each batch from ``self.batcher`` only the first sequence step is
        executed, with ``FULL_TRACE`` enabled.  The step stats of the last
        traced run are written to ``timeline.ctf.json`` in the working
        directory (viewable in chrome://tracing).

        Args:
            session: TensorFlow session used to run ``self.model.cost``.

        Raises:
            ValueError: if the batcher produced no steps, so nothing was traced.
        """
        evals = [self.model.cost]
        # The trace options are identical for every run; build them once.
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = None
        for batch in self.batcher:
            state, att_states, att_ids, att_counts = get_initial_state(self.model)

            for seq_batch in self.batcher.sequence_iterator(batch):
                feed_dict = construct_feed_dict(self.model, seq_batch, state, att_states, att_ids, att_counts)
                run_metadata = tf.RunMetadata()
                session.run(evals, feed_dict=feed_dict,
                            options=run_options,
                            run_metadata=run_metadata)
                # Only the first step of each batch is profiled.
                break

        if run_metadata is None:
            raise ValueError("batcher yielded no sequence batches to profile")

        from tensorflow.python.client import timeline
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        # The context manager guarantees the trace is flushed and closed.
        with open('timeline.ctf.json', 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
コード例 #8
0
ファイル: evaluation.py プロジェクト: Hamchin/pycodesuggest
    def profile(self, session):
        """Run one fully traced step per batch and dump a Chrome trace.

        For each batch from ``self.batcher`` only the first sequence step is
        executed, with ``FULL_TRACE`` enabled.  The step stats of the last
        traced run are written to ``timeline.ctf.json`` in the working
        directory (viewable in chrome://tracing).

        Args:
            session: TensorFlow session used to run ``self.model.cost``.

        Raises:
            ValueError: if the batcher produced no steps, so nothing was traced.
        """
        evals = [self.model.cost]
        # The trace options are identical for every run; build them once.
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = None
        for batch in self.batcher:
            state, att_states, att_ids, att_counts = get_initial_state(self.model)

            for seq_batch in self.batcher.sequence_iterator(batch):
                feed_dict = construct_feed_dict(self.model, seq_batch, state, att_states, att_ids, att_counts)
                run_metadata = tf.RunMetadata()
                session.run(evals, feed_dict=feed_dict,
                            options=run_options,
                            run_metadata=run_metadata)
                # Only the first step of each batch is profiled.
                break

        if run_metadata is None:
            raise ValueError("batcher yielded no sequence batches to profile")

        from tensorflow.python.client import timeline
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        # The context manager guarantees the trace is flushed and closed.
        with open('timeline.ctf.json', 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
コード例 #9
0
ファイル: evaluation.py プロジェクト: hedgefair/pycodesuggest
    def __call__(self, session):
        """Print each test-file token next to the model's prediction for it.

        Test files are loaded with ``pyreader.get_data`` and each one is fed
        to the model token by token.  For every step a line
        ``<current token> ; <predicted token>`` is collected (newlines shown
        as ``<newline>``), and the lines are printed once per test case.

        Args:
            session: TensorFlow session used to run the prediction op.
        """
        test_file_containers = pyreader.get_data(self.data_path, self.test_files, 1, self.map)
        data = list(zip((list(flatmap(identity_map, c.inputs)) for c in test_file_containers),
                        (list(flatmap(identity_map, c.masks)) for c in test_file_containers)))
        # The evaluated ops do not depend on the test case: build them once.
        evals = get_evals([self.prediction_op], self.model)

        for testcase, var_mask in data:
            print("----------Test case----------\n")
            state, att_states, att_counts = get_initial_state(self.model)
            predicted_tokens = []

            # var_mask is zipped only to keep the original pairing/truncation;
            # this evaluator feeds a constant attention mask of 0.
            for token, _ in zip(testcase, var_mask):
                step_data = (np.array([[token]]), np.array([[1]]), np.array([0]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, step_data, state, att_states, att_counts)
                results = session.run(evals, feed_dict=feed_dict)
                prediction, state, att_states, att_counts, _, _ = \
                    extract_results(results, evals, 1, self.model)
                predicted_token = self.inverse_map[prediction[0][0]].replace("\n", "<newline>")
                current_token = self.inverse_map[token].replace("\n", "<newline>")
                predicted_tokens.append(current_token + " ; " + predicted_token)

            print("\n".join(predicted_tokens))
コード例 #10
0
ファイル: evaluation.py プロジェクト: hedgefair/pycodesuggest
    def __call__(self, session):
        """Print the model's top-5 next-token predictions for each test case.

        Every test case in the module-level ``all_test_cases`` is fed to the
        model one token at a time, threading the recurrent and attention
        state from step to step.  After the final token, up to five ranked
        predictions are printed as ``<token> ; <score>``.

        Args:
            session: TensorFlow session used to run the evaluation ops.
        """
        # prediction_op is used as a pair: [0] -> scores, [1] -> token ids.
        evals = get_evals([self.prediction_op[0], self.prediction_op[1]], self.model)
        for testcase in all_test_cases:
            if not testcase:
                # Nothing to predict from; without this guard ``results``
                # below would be unbound or left over from the previous case.
                continue
            state, att_states, att_counts = get_initial_state(self.model)
            for token in testcase:
                # map_token yields a (token_id, mask) pair for the token.
                token_id, mask = map_token(self.map, token)
                att_mask = attention_masks(self.attns, np.array([mask]), 1)
                data = (np.array([[token_id]]), np.array([[1]]), np.array([att_mask]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, data, state, att_states, att_counts)

                results = session.run(evals, feed_dict)
                results, state, att_states, att_counts, _, _ = extract_results(results, evals, 2, self.model)

            probs = results[0]
            predict_ids = results[1]

            # Bound by the number of returned candidates to avoid IndexError.
            for rank in range(min(5, predict_ids.shape[1])):
                print("%s ; %f" % (self.inv_map[predict_ids[0, rank]].replace("\n", "<newline>"), probs[0, rank]))

            print("\n\n§§§§§§§§§§§§§\n\n")
コード例 #11
0
ファイル: hooks.py プロジェクト: Hamchin/pycodesuggest
 def __call__(self, sess, epoch, iteration, logits, loss, processed):
     """Every ``self.at_every_epoch`` epochs, report next-token accuracy.

     Runs only on iteration 0 of a qualifying epoch.  For each batch of
     ``self.batcher`` the first sequence chunk is fed to the model.  The
     argmax predictions are compared position by position with the batch's
     first target row, skipping positions whose target id is 0.  The
     accuracy is recorded via ``self.update_summary`` and printed.

     Args:
         sess: TensorFlow session.
         epoch: Current epoch number.
         iteration: Iteration within the epoch; only 0 triggers evaluation.
         logits: Logits tensor whose row-wise argmax gives predicted ids.
         loss, processed: Unused.

     Raises:
         IndexError: if ``sequence_iterator`` yields nothing for a batch.
     """
     if iteration != 0 or epoch % self.at_every_epoch != 0:
         return

     # Build the argmax op once per evaluation instead of once per batch;
     # every tf.* call adds new nodes to the graph.
     predict_op = tf.arg_max(tf.nn.softmax(logits), 1)
     total = 0
     correct = 0
     for values in self.batcher:
         # NOTE(review): total counts the full sequence length while only the
         # first chunk below is evaluated; confirm sequences fit one chunk.
         total += values.actual_lengths[0]
         # Take only the first chunk without materialising the rest.
         first_chunk = next(iter(self.batcher.sequence_iterator(values)), None)
         if first_chunk is None:
             raise IndexError("sequence_iterator produced no data for a batch")
         state, att_states, att_ids, att_counts = get_initial_state(
             self.model)
         feed_dict, _ = construct_feed_dict(
             self.model, first_chunk, state, att_states, att_ids, att_counts)
         predicted_ids = sess.run(predict_op, feed_dict=feed_dict)
         truth_ids = values.targets[0]
         # Compare aligned (truth, prediction) pairs, skipping positions
         # whose truth id is 0 (presumably padding).  Filtering each
         # sequence on its own would shift them out of alignment whenever
         # the model predicted id 0.
         correct += sum(1 for t, p in zip(truth_ids, predicted_ids)
                        if t != 0 and self.vocab[t] == self.vocab[p])
     acc = float(correct) / total if total else 0.0
     self.update_summary(sess, iteration, ACCURACY_TRACE_TAG, acc)
     print("Epoch " + str(epoch) + "\tAcc " + str(acc) + "\tCorrect " +
           str(correct) + "\tTotal " + str(total))
コード例 #12
0
ファイル: evaluation.py プロジェクト: hedgefair/pycodesuggest
    def __call__(self, session):
        """Plot a "lagged" attention heatmap for each test file.

        Test files are loaded with ``pyreader.get_data`` and each is fed to
        the model token by token.  Every step prints
        ``<current token> ; <predicted token>``.  It also records the
        attention weights over the most recently seen tokens, right-aligned
        in a row of at most ``self.max_display`` columns and annotated with
        those tokens.  On the y axis, ``**`` marks a token equal to the
        previous step's prediction.

        The heatmap is saved to ``./out/lagged_attention.png``; with several
        test files each one overwrites the previous image.

        Args:
            session: TensorFlow session used to run the prediction op.
        """
        test_file_containers = pyreader.get_data(self.data_path, self.test_files, 1, self.map)
        data = list(zip((list(flatmap(identity_map, c.inputs)) for c in test_file_containers),
                        (list(flatmap(identity_map, c.masks)) for c in test_file_containers)))
        # The evaluated ops do not depend on the test case: build them once.
        evals = get_evals([self.model.predict], self.model)

        for testcase, var_mask in data:
            if not testcase:
                # Nothing to feed or plot.
                continue
            print("----------Test case----------\n")
            state, att_states, att_counts = get_initial_state(self.model)

            accumulated_tokens = []
            plot_data = np.zeros([len(testcase), self.max_display])
            annotations = np.empty([len(testcase), self.max_display], dtype=object)
            y_labels = []
            predicted_labels = []

            for i, (token, mask) in enumerate(zip(testcase, var_mask)):
                att_mask = attention_masks(self.attns, np.array([mask]), 1)
                step_data = (np.array([[token]]), np.array([[1]]), np.array([att_mask]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, step_data, state, att_states, att_counts)
                results = session.run(evals, feed_dict=feed_dict)
                prediction, state, att_states, att_counts, att_vec, _ = \
                    extract_results(results, evals, 1, self.model)
                current_token = self.inverse_map[token]
                predicted_token = self.inverse_map[np.argmax(prediction)]
                current_label = current_token.replace("\n", "<newline>")
                predicted_label = predicted_token.replace("\n", "<newline>")

                if len(accumulated_tokens) > self.max_attention:
                    accumulated_tokens.pop(0)

                # Pair the newest `take` attention weights with the newest
                # `take` tokens.  Capping at max_display keeps the row inside
                # the plot.  Slicing the labels the same way keeps them as
                # long as the weights, so neither assignment below can raise.
                m = att_vec[0].shape[1]
                take = min(m, len(accumulated_tokens), self.max_display)
                alphas = att_vec[0][0, m-take:]
                recent_tokens = accumulated_tokens[len(accumulated_tokens)-take:]
                labels = np.array([clean_token(t) for t in recent_tokens])

                y_labels.append(current_label)
                predicted_labels.append(predicted_label)

                print("%s ; %s" % (current_label, predicted_label))

                # Right-align the weights and blank the unused leading cells.
                begin = self.max_display - take
                plot_data[i, begin:] = alphas
                annotations[i, begin:] = labels
                annotations[i, :begin] = ""

                accumulated_tokens.append(current_token)

            # Mark tokens the model predicted at the previous step.
            for i in range(1, len(y_labels)):
                if y_labels[i] == predicted_labels[i-1]:
                    y_labels[i] = "** " + y_labels[i]

            x_labels = [""] * self.max_display

            # Style first, then a fresh figure: sns.heatmap would otherwise
            # draw onto the current axes and overlay every test case.
            sns.set(font_scale=1.2)
            sns.set_style({"savefig.dpi": 100})
            fig, ax = plt.subplots()
            sns.heatmap(plot_data, ax=ax, cmap=plt.cm.Blues, linewidths=.1, annot=annotations, fmt="", vmin=0, vmax=1,
                        cbar=False, xticklabels=x_labels, yticklabels=y_labels, annot_kws={"size": 10})
            plt.yticks(rotation=0)

            # Height grows with the number of rows; clamp so that test cases
            # shorter than 3 tokens do not produce a zero-height figure.
            fig.set_size_inches(int(self.max_display)*1.3, max(1, int(len(plot_data)/3)))

            fig.savefig('./out/lagged_attention.png')
            # Release the figure; pyplot otherwise keeps every one alive.
            plt.close(fig)
            print("Generated file lagged_attention.png")
コード例 #13
0
ファイル: evaluation.py プロジェクト: hedgefair/pycodesuggest
    def __call__(self, session):
        """Plot attention over previously seen masked tokens for each test case.

        Test cases come from ``self.test_cases``/``self.var_masks`` plus any
        files in ``self.test_files``, and each is fed to the model token by
        token.  Tokens whose mask is set become column labels for the
        attention weights, keeping only the newest ``self.max_attention``.
        A second panel shows the two ``lambda_vec`` columns, labelled "LM"
        and "Att".  On the y axis, ``**`` marks a token equal to the previous
        step's prediction and ``(*)`` marks a masked token.

        The figure is saved to ``./out/attention2.png``; with several test
        cases each one overwrites the previous image.

        Args:
            session: TensorFlow session used to run the prediction op.

        Raises:
            ValueError: if a test case and its variable mask differ in length.
        """
        test_file_containers = []
        if self.test_files:
            test_file_containers = pyreader.get_data(self.data_path, self.test_files, 1, self.map)

        data = list(zip(self.test_cases, self.var_masks)) + \
               list(zip((list(flatmap(identity_map, c.inputs)) for c in test_file_containers),
                        (list(flatmap(identity_map, c.masks)) for c in test_file_containers)))
        # The evaluated ops do not depend on the test case: build them once.
        evals = get_evals([self.model.predict], self.model)

        for testcase, var_mask in data:
            if len(testcase) != len(var_mask):
                raise ValueError("Length of testcase does not match corresponding variable mask: %s" % testcase)
            if not testcase:
                # Nothing to feed or plot.
                continue

            print("----------Test case----------\n")
            state, att_states, att_ids, att_counts = get_initial_state(self.model)

            prev_mask = False
            prev_token = None
            attns = []
            plot_data = np.zeros([len(testcase), self.max_attention])
            lambda_data = np.zeros([len(testcase), 2])
            annotations = np.empty([len(testcase), self.max_attention], dtype=object)
            y_labels = []
            predicted_token = ""

            for i, (token, mask) in enumerate(zip(testcase, var_mask)):
                att_mask = attention_masks(self.attns, np.array([mask]), 1)
                step_data = (np.array([[token]]), np.array([[1]]), np.array([att_mask]), np.array([[1]]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, step_data, state, att_states, att_ids, att_counts)
                results = session.run(evals, feed_dict=feed_dict)
                prediction, state, att_states, att_ids, alpha_states, att_counts, lambda_vec = \
                    extract_results(results, evals, 1, self.model)
                predicted = np.argmax(prediction)

                # A masked token becomes an attention label one step later;
                # keep only the newest max_attention labels.
                if prev_mask:
                    if len(attns) >= self.max_attention:
                        attns = attns[1:]
                    attns.append(prev_token)

                prev_mask = mask
                prev_token = self.inverse_map[token]

                # Weights are scaled by the "Att" lambda only when it is below 0.1.
                plot_data[i, :] = alpha_states[0][0] * (lambda_vec[0, 1] if lambda_vec[0, 1] < 0.1 else 1)
                lambda_data[i, :] = lambda_vec
                annotations[i, :] = [""] * (self.max_attention-len(attns)) + attns

                current_token = self.inverse_map[token]
                current_token = "%s%s%s" % ("** " if current_token == predicted_token else "", "(*)" if mask else "", current_token)
                y_labels.append(current_token)
                predicted_token = self.inverse_map[predicted]

            # Apply the seaborn style before creating the figure so every test
            # case, including the first, is drawn with the same style.
            sns.set(font_scale=1.2)
            sns.set_style({"savefig.dpi": 100})
            fig, (ax_data, ax_lambda) = plt.subplots(1, 2, gridspec_kw={
                'width_ratios': [self.max_attention, 2]
            })

            blank_x_labels = [""] * self.max_attention
            blank_y_labels = [""] * len(testcase)
            lambda_x_labels = ["LM", "Att"]
            ax_data = sns.heatmap(plot_data, ax=ax_data, cmap=plt.cm.Blues, linewidths=.1, annot=annotations,
                                  fmt="", vmin=0, vmax=1, cbar=False, xticklabels=blank_x_labels, yticklabels=y_labels,
                                  annot_kws={"size": 9})
            ax_lambda = sns.heatmap(lambda_data, ax=ax_lambda, cmap=plt.cm.Blues, linewidths=.1, annot=False,
                                    fmt="", vmin=0, vmax=1, cbar=False, xticklabels=lambda_x_labels, yticklabels=blank_y_labels,
                                    annot_kws={"size": 9})

            ax_data.set_yticklabels(ax_data.yaxis.get_majorticklabels(), rotation=0)
            ax_lambda.xaxis.tick_top()
            # Clamp the height so test cases shorter than 3 tokens do not
            # produce a zero-height figure.
            fig.set_size_inches(int(self.max_attention)*1.3, max(1, int(len(plot_data)/3)))

            fig.savefig('./out/attention2.png')
            # Release the figure; pyplot otherwise keeps every one alive.
            plt.close(fig)
            print("Generated file attention2.png")
コード例 #14
0
    def __call__(self, session):
        """Plot a "lagged" attention heatmap for each test file.

        Test files are loaded with ``pyreader.get_data`` and each is fed to
        the model token by token.  Every step prints
        ``<current token> ; <predicted token>``.  It also records the
        attention weights over the most recently seen tokens, right-aligned
        in a row of at most ``self.max_display`` columns and annotated with
        those tokens.  On the y axis, ``**`` marks a token equal to the
        previous step's prediction.

        The heatmap is saved to ``./out/lagged_attention.png``; with several
        test files each one overwrites the previous image.

        Args:
            session: TensorFlow session used to run the prediction op.
        """
        test_file_containers = pyreader.get_data(self.data_path,
                                                 self.test_files, 1, self.map)
        data = list(
            zip((list(flatmap(identity_map, c.inputs))
                 for c in test_file_containers),
                (list(flatmap(identity_map, c.masks))
                 for c in test_file_containers)))
        # The evaluated ops do not depend on the test case: build them once.
        evals = get_evals([self.model.predict], self.model)

        for testcase, var_mask in data:
            if not testcase:
                # Nothing to feed or plot.
                continue
            print("----------Test case----------\n")
            state, att_states, att_counts = get_initial_state(self.model)

            accumulated_tokens = []
            plot_data = np.zeros([len(testcase), self.max_display])
            annotations = np.empty([len(testcase), self.max_display],
                                   dtype=object)
            y_labels = []
            predicted_labels = []

            for i, (token, mask) in enumerate(zip(testcase, var_mask)):
                att_mask = attention_masks(self.attns, np.array([mask]), 1)
                step_data = (np.array([[token]]), np.array([[1]]),
                             np.array([att_mask]), np.array([1]))
                feed_dict = construct_feed_dict(self.model, step_data, state,
                                                att_states, att_counts)
                results = session.run(evals, feed_dict=feed_dict)
                prediction, state, att_states, att_counts, att_vec, _ = \
                    extract_results(results, evals, 1, self.model)
                current_token = self.inverse_map[token]
                predicted_token = self.inverse_map[np.argmax(prediction)]
                current_label = current_token.replace("\n", "<newline>")
                predicted_label = predicted_token.replace("\n", "<newline>")

                if len(accumulated_tokens) > self.max_attention:
                    accumulated_tokens.pop(0)

                # Pair the newest `take` attention weights with the newest
                # `take` tokens.  Capping at max_display keeps the row inside
                # the plot.  Slicing the labels the same way keeps them as
                # long as the weights, so neither assignment below can raise.
                m = att_vec[0].shape[1]
                take = min(m, len(accumulated_tokens), self.max_display)
                alphas = att_vec[0][0, m - take:]
                recent_tokens = accumulated_tokens[len(accumulated_tokens) -
                                                   take:]
                labels = np.array([clean_token(t) for t in recent_tokens])

                y_labels.append(current_label)
                predicted_labels.append(predicted_label)

                print("%s ; %s" % (current_label, predicted_label))

                # Right-align the weights and blank the unused leading cells.
                begin = self.max_display - take
                plot_data[i, begin:] = alphas
                annotations[i, begin:] = labels
                annotations[i, :begin] = ""

                accumulated_tokens.append(current_token)

            # Mark tokens the model predicted at the previous step.
            for i in range(1, len(y_labels)):
                if y_labels[i] == predicted_labels[i - 1]:
                    y_labels[i] = "** " + y_labels[i]

            x_labels = [""] * self.max_display

            # Style first, then a fresh figure: sns.heatmap would otherwise
            # draw onto the current axes and overlay every test case.
            sns.set(font_scale=1.2)
            sns.set_style({"savefig.dpi": 100})
            fig, ax = plt.subplots()
            sns.heatmap(plot_data,
                        ax=ax,
                        cmap=plt.cm.Blues,
                        linewidths=.1,
                        annot=annotations,
                        fmt="",
                        vmin=0,
                        vmax=1,
                        cbar=False,
                        xticklabels=x_labels,
                        yticklabels=y_labels,
                        annot_kws={"size": 10})
            plt.yticks(rotation=0)

            # Height grows with the number of rows; clamp so that test cases
            # shorter than 3 tokens do not produce a zero-height figure.
            fig.set_size_inches(int(self.max_display) * 1.3,
                                max(1, int(len(plot_data) / 3)))

            fig.savefig('./out/lagged_attention.png')
            # Release the figure; pyplot otherwise keeps every one alive.
            plt.close(fig)
            print("Generated file lagged_attention.png")
コード例 #15
0
    def __call__(self, session):
        """Plot attention over previously seen masked tokens for each test case.

        Test cases come from ``self.test_cases``/``self.var_masks`` plus any
        files in ``self.test_files``, and each is fed to the model token by
        token.  Tokens whose mask is set become column labels for the
        attention weights, keeping only the newest ``self.max_attention``.
        A second panel shows the two ``lambda_vec`` columns, labelled "LM"
        and "Att".  On the y axis, ``**`` marks a token equal to the previous
        step's prediction and ``(*)`` marks a masked token.

        The figure is saved to ``./out/attention2.png``; with several test
        cases each one overwrites the previous image.

        Args:
            session: TensorFlow session used to run the prediction op.

        Raises:
            ValueError: if a test case and its variable mask differ in length.
        """
        test_file_containers = []
        if self.test_files:
            test_file_containers = pyreader.get_data(self.data_path,
                                                     self.test_files, 1,
                                                     self.map)

        data = list(zip(self.test_cases, self.var_masks)) + \
               list(zip((list(flatmap(identity_map, c.inputs)) for c in test_file_containers),
                        (list(flatmap(identity_map, c.masks)) for c in test_file_containers)))
        # The evaluated ops do not depend on the test case: build them once.
        evals = get_evals([self.model.predict], self.model)

        for testcase, var_mask in data:
            if len(testcase) != len(var_mask):
                raise ValueError(
                    "Length of testcase does not match corresponding variable mask: %s"
                    % testcase)
            if not testcase:
                # Nothing to feed or plot.
                continue

            print("----------Test case----------\n")
            state, att_states, att_ids, att_counts = get_initial_state(
                self.model)

            prev_mask = False
            prev_token = None
            attns = []
            plot_data = np.zeros([len(testcase), self.max_attention])
            lambda_data = np.zeros([len(testcase), 2])
            annotations = np.empty([len(testcase), self.max_attention],
                                   dtype=object)
            y_labels = []
            predicted_token = ""

            for i, (token, mask) in enumerate(zip(testcase, var_mask)):
                att_mask = attention_masks(self.attns, np.array([mask]), 1)
                step_data = (np.array([[token]]), np.array([[1]]),
                             np.array([att_mask]), np.array([[1]]),
                             np.array([1]))
                feed_dict = construct_feed_dict(self.model, step_data, state,
                                                att_states, att_ids,
                                                att_counts)
                results = session.run(evals, feed_dict=feed_dict)
                prediction, state, att_states, att_ids, alpha_states, att_counts, lambda_vec = \
                    extract_results(results, evals, 1, self.model)
                predicted = np.argmax(prediction)

                # A masked token becomes an attention label one step later;
                # keep only the newest max_attention labels.
                if prev_mask:
                    if len(attns) >= self.max_attention:
                        attns = attns[1:]
                    attns.append(prev_token)

                prev_mask = mask
                prev_token = self.inverse_map[token]

                # Weights are scaled by the "Att" lambda only when it is
                # below 0.1.
                plot_data[i, :] = alpha_states[0][0] * (
                    lambda_vec[0, 1] if lambda_vec[0, 1] < 0.1 else 1)
                lambda_data[i, :] = lambda_vec
                annotations[i, :] = (
                    [""] * (self.max_attention - len(attns)) + attns)

                current_token = self.inverse_map[token]
                current_token = "%s%s%s" % (
                    "** " if current_token == predicted_token else "",
                    "(*)" if mask else "", current_token)
                y_labels.append(current_token)
                predicted_token = self.inverse_map[predicted]

            # Apply the seaborn style before creating the figure so every test
            # case, including the first, is drawn with the same style.
            sns.set(font_scale=1.2)
            sns.set_style({"savefig.dpi": 100})
            fig, (ax_data, ax_lambda) = plt.subplots(
                1, 2, gridspec_kw={'width_ratios': [self.max_attention, 2]})

            blank_x_labels = [""] * self.max_attention
            blank_y_labels = [""] * len(testcase)
            lambda_x_labels = ["LM", "Att"]
            ax_data = sns.heatmap(plot_data,
                                  ax=ax_data,
                                  cmap=plt.cm.Blues,
                                  linewidths=.1,
                                  annot=annotations,
                                  fmt="",
                                  vmin=0,
                                  vmax=1,
                                  cbar=False,
                                  xticklabels=blank_x_labels,
                                  yticklabels=y_labels,
                                  annot_kws={"size": 9})
            ax_lambda = sns.heatmap(lambda_data,
                                    ax=ax_lambda,
                                    cmap=plt.cm.Blues,
                                    linewidths=.1,
                                    annot=False,
                                    fmt="",
                                    vmin=0,
                                    vmax=1,
                                    cbar=False,
                                    xticklabels=lambda_x_labels,
                                    yticklabels=blank_y_labels,
                                    annot_kws={"size": 9})

            ax_data.set_yticklabels(ax_data.yaxis.get_majorticklabels(),
                                    rotation=0)
            ax_lambda.xaxis.tick_top()
            # Clamp the height so test cases shorter than 3 tokens do not
            # produce a zero-height figure.
            fig.set_size_inches(int(self.max_attention) * 1.3,
                                max(1, int(len(plot_data) / 3)))

            fig.savefig('./out/attention2.png')
            # Release the figure; pyplot otherwise keeps every one alive.
            plt.close(fig)
            print("Generated file attention2.png")
コード例 #16
0
ファイル: test.py プロジェクト: Hamchin/pycodesuggest
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # config.model_path at model directories you trust.
    with open(os.path.join(config.model_path, "config.pkl"),
              "rb") as config_file:
        model_config_dict = pickle.load(config_file)
    model_config_dict["batch_size"] = config.batch_size
    # Saved configs without an "attention" entry fall back to config.attention.
    if "attention" not in model_config_dict:
        model_config_dict["attention"] = config.attention
    model_config = FlagWrapper(model_config_dict)
    with tf.Graph().as_default(), tf.Session() as session:
        # One-token, batch-of-one copy of the config for the generator model.
        generator_config = copy_flags(model_config)
        generator_config.seq_length = 1
        generator_config.batch_size = 1
        with tf.variable_scope("model", reuse=None):
            model = create_model(model_config, False)
        # reuse=True: the generator reuses the variables created just above.
        with tf.variable_scope("model", reuse=True):
            generator_model = create_model(generator_config, False)
        init = tf.initialize_all_variables()
        session.run(init)
        load_model(session, config.model_path)
        state, att_states, att_ids, att_counts = get_initial_state(model)
        feed_dict, identifier_usage = construct_feed_dict(
            model, data, state, att_states, att_ids, att_counts)
        predicted = session.run(tf.arg_max(tf.nn.softmax(model.logits), 1),
                                feed_dict=feed_dict)

        def replace(word):
            """Return *word*, rendering a bare newline as a visible ``\\n``."""
            return "\\n" if word == "\n" else word

        # Id 0 is dropped from both sequences (presumably padding).
        # NOTE(review): filtering them independently can misalign the pairs
        # below if the model ever predicts id 0 at a real position.
        truth = [vocab[tok_id] for tok_id in inputs[0] if tok_id != 0]
        predicted = [vocab[tok_id] for tok_id in predicted if tok_id != 0]
        # Predictions are shifted right by one so each input token is printed
        # next to the prediction from the preceding step.
        for truth_tok, pred_tok in zip(truth, [''] + predicted):
            print("%s\t%s" % (replace(truth_tok), replace(pred_tok)))