Example 1
    def __init__(self,
                 transcripts_paths,
                 survey_paths,
                 stats_path,
                 price_tracker_model,
                 liwc_path,
                 max_examples=None):
        transcripts = self._read_transcripts(transcripts_paths, max_examples)
        self.dataset = utils.filter_rejected_chats(transcripts)

        dialogue_scores = self._read_surveys(survey_paths)

        self.price_tracker = PriceTracker(price_tracker_model)

        self.liwc = LIWC.from_pkl(liwc_path)

        # group chats depending on whether the seller or the buyer wins
        #self.buyer_wins, self.seller_wins = self.group_outcomes_and_roles()

        self.stats_path = stats_path
        if not os.path.exists(self.stats_path):
            os.makedirs(self.stats_path)

        self.examples = [
            Dialogue.from_dict(raw, dialogue_scores.get(raw['uuid'], {}),
                               self.price_tracker) for raw in self.dataset
        ]
Example 2
def get_system(name, args, schema=None, timed=False, model_path=None):
    from core.price_tracker import PriceTracker
    lexicon = PriceTracker(args.price_tracker_model)

    if name == 'rulebased':
        from rulebased_system import RulebasedSystem
        from model.generator import Templates, Generator
        from model.manager import Manager
        templates = Templates.from_pickle(args.templates)
        generator = Generator(templates)
        manager = Manager.from_pickle(args.policy)
        return RulebasedSystem(lexicon, generator, manager, timed)
    elif name == 'hybrid':
        from hybrid_system import HybridSystem
        from neural_system import PytorchNeuralSystem
        from model.generator import Templates, Generator
        templates = Templates.from_pickle(args.templates)
        manager = PytorchNeuralSystem(args, schema, lexicon, model_path, timed)
        generator = Generator(templates)
        return HybridSystem(lexicon, generator, manager, timed)
    elif name == 'cmd':
        from cmd_system import CmdSystem
        return CmdSystem()
    elif name == 'pt-neural':
        from neural_system import PytorchNeuralSystem
        assert model_path
        return PytorchNeuralSystem(args, schema, lexicon, model_path, timed)
    else:
        raise ValueError('Unknown system %s' % name)
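For orientation, a sketch of how this factory might be invoked from a launch script. The pickle paths and the note about new_session are illustrative assumptions, not taken from the code above:

import argparse

# Illustrative only: the file paths below are placeholders.
args = argparse.Namespace(price_tracker_model='price_tracker.pkl',
                          templates='templates.pkl',
                          policy='model.pkl')
system = get_system('rulebased', args)
# system.new_session(agent_index, kb) would then start a dialogue session,
# with kb taken from a loaded Scenario (assumed cocoa API).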
Example 3
def get_system(name, args, schema=None, timed=False, model_path=None):
    lexicon = PriceTracker(args.price_tracker_model)
    if name == 'rulebased':
        templates = Templates.from_pickle(args.templates)
        generator = Generator(templates)
        manager = Manager.from_pickle(args.policy)
        return RulebasedSystem(lexicon, generator, manager, timed)
    #elif name == 'config-rulebased':
    #    configs = read_json(args.rulebased_configs)
    #    return ConfigurableRulebasedSystem(configs, lexicon, timed_session=timed, policy=args.config_search_policy, max_chats_per_config=args.chats_per_config, db=args.trials_db, templates=templates)
    elif name == 'cmd':
        return CmdSystem()
    elif name.startswith('ranker'):
        # TODO: hack
        #retriever1 = Retriever(args.index+'-1', context_size=args.retriever_context_len, num_candidates=args.num_candidates)
        #retriever2 = Retriever(args.index+'-2', context_size=args.retriever_context_len, num_candidates=args.num_candidates)
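        # NOTE: the 'ranker-ir1' / 'ranker-ir2' branches below refer to retriever1 and
        # retriever2, which are only defined in the commented-out lines above.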
        retriever = Retriever(args.index, context_size=args.retriever_context_len, num_candidates=args.num_candidates)
        if name == 'ranker-ir':
            return IRRankerSystem(schema, lexicon, retriever)
        elif name == 'ranker-ir1':
            return IRRankerSystem(schema, lexicon, retriever1)
        elif name == 'ranker-ir2':
            return IRRankerSystem(schema, lexicon, retriever2)
        elif name == 'ranker-neural':
            return NeuralRankerSystem(schema, lexicon, retriever, model_path, args.mappings)
        else:
            raise ValueError
    elif name in ('neural-gen', 'neural-sel'):
        assert model_path
        return NeuralSystem(schema, lexicon, model_path, args.mappings, args.decoding, index=args.index, num_candidates=args.num_candidates, retriever_context_len=args.retriever_context_len, timed_session=timed)
    else:
        raise ValueError('Unknown system %s' % name)
Example 4
def get_data_generator(args, model_args, schema, test=False):
    from cocoa.core.scenario_db import ScenarioDB
    from cocoa.core.dataset import read_dataset
    from cocoa.core.util import read_json

    from core.scenario import Scenario
    from core.price_tracker import PriceTracker
    from .preprocess import DataGenerator, Preprocessor
    import os.path

    # TODO: move this to dataset
    dataset = read_dataset(args, Scenario)

    mappings_path = model_args.mappings

    lexicon = PriceTracker(model_args.price_tracker_model)

    preprocessor = Preprocessor(schema,
                                lexicon,
                                model_args.entity_encoding_form,
                                model_args.entity_decoding_form,
                                model_args.entity_target_form,
                                model=model_args.model)

    if test:
        model_args.dropout = 0
        train, dev, test = None, None, dataset.test_examples
    else:
        train, dev, test = dataset.train_examples, dataset.test_examples, None
    data_generator = DataGenerator(train,
                                   dev,
                                   test,
                                   preprocessor,
                                   schema,
                                   mappings_path,
                                   cache=args.cache,
                                   ignore_cache=args.ignore_cache,
                                   num_context=model_args.num_context,
                                   batch_size=args.batch_size,
                                   model=model_args.model,
                                   dia_num=args.dia_num,
                                   state_length=args.state_length)

    return data_generator
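Note on the branches above: with test=True, dropout is forced to 0 and only dataset.test_examples is passed on, so the resulting DataGenerator presumably builds evaluation batches only; train and dev stay None.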
Example 5
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--transcripts',
                        nargs='*',
                        help='JSON transcripts to extract templates')
    parser.add_argument('--max-examples', default=-1, type=int)
    parser.add_argument('--templates', help='Path to load templates')
    parser.add_argument('--policy', help='Path to load model')
    parser.add_argument('--schema-path', help='Path to schema')
    parser.add_argument(
        '--agent', help='Only consider examples with the given type of agent')
    add_price_tracker_arguments(parser)
    args = parser.parse_args()

    lexicon = PriceTracker(args.price_tracker_model)
    #templates = Templates.from_pickle(args.templates)
    templates = Templates()
    manager = Manager.from_pickle(args.policy)
    analyzer = Analyzer(lexicon)

    # TODO: skip examples
    examples = read_examples(args.transcripts, args.max_examples, Scenario)
    agent = args.agent
    if agent is not None:
        examples = [e for e in examples if agent in e.agents.values()]
    analyzer.example_stats(examples, agent=agent)
    #import sys; sys.exit()

    parsed_dialogues = []
    for example in examples:
Example 6
def to_float_price(entity):
    return float('{:.2f}'.format(PriceTracker.get_price(entity)))
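A minimal illustrative call, assuming a cocoa-style Entity namedtuple whose canonical field carries the numeric price and that PriceTracker.get_price returns that canonical value (the module path and constructor below are assumptions):

from cocoa.core.entity import Entity, CanonicalEntity  # assumed module path

price_entity = Entity(surface='around $9.999',
                      canonical=CanonicalEntity(value=9.999, type='price'))
print to_float_price(price_entity)  # expected to print 10.0 (value formatted to two decimals)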
Example 7
    parser.add_argument('--transcripts',
                        nargs='*',
                        help='JSON transcripts to extract templates')
    parser.add_argument('--price-tracker-model')
    parser.add_argument('--max-examples', default=-1, type=int)
    parser.add_argument('--output', help='Path to save templates')
    parser.add_argument('--output-transcripts',
                        help='Path to JSON examples with templates')
    parser.add_argument('--templates', help='Path to load templates')
    parser.add_argument('--debug', default=False, action='store_true')
    args = parser.parse_args()

    if args.templates:
        templates = Templates.from_pickle(args.templates)
    else:
        price_tracker = PriceTracker(args.price_tracker_model)
        template_extractor = TemplateExtractor(price_tracker)
        template_extractor.extract_templates(args.transcripts,
                                             args.max_examples)
        write_pickle(template_extractor.templates, args.output)
        templates = Templates(template_extractor.templates)

    t = templates.templates
    response_tags = set(t.response_tag.values)
    tag_counts = []
    for tag in response_tags:
        tag_counts.append(
            (tag, t[t.response_tag == tag].shape[0] / float(t.shape[0])))
    tag_counts = sorted(tag_counts, key=lambda x: x[1], reverse=True)
    for x in tag_counts:
        print x
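The tag-frequency loop above can be written more compactly with pandas; a minimal equivalent sketch over the same templates DataFrame t:

# Fraction of templates per response_tag, already sorted in descending order.
tag_freq = t['response_tag'].value_counts(normalize=True)
for tag, frac in tag_freq.iteritems():
    print tag, frac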
Example 8
def get_data_generator(args, model_args, mappings, schema):
    from cocoa.core.scenario_db import ScenarioDB
    from cocoa.core.dataset import read_dataset, EvalExample
    from cocoa.core.util import read_json

    from core.scenario import Scenario
    from core.price_tracker import PriceTracker
    from core.slot_detector import SlotDetector
    from retriever import Retriever
    from preprocess import DataGenerator, LMDataGenerator, EvalDataGenerator, Preprocessor
    import os.path

    # TODO: move this to dataset
    if args.eval:
        dataset = []
        for path in args.eval_examples_paths:
            dataset.extend(
                [EvalExample.from_dict(schema, e) for e in read_json(path)])
    else:
        dataset = read_dataset(args, Scenario)
    lexicon = PriceTracker(model_args.price_tracker_model)
    slot_detector = SlotDetector(slot_scores_path=model_args.slot_scores)

    # Model config tells data generator which batcher to use
    model_config = {}
    if args.retrieve or model_args.model in ('ir', 'selector'):
        model_config['retrieve'] = True
    if args.predict_price:
        model_config['price'] = True

    # For retrieval-based models only: whether to add the ground-truth response to the candidate set
    if model_args.model in ('selector', 'ir'):
        if 'loss' in args.eval_modes and 'generation' in args.eval_modes:
            print '"loss" requires ground truth reponse to be added to the candidate set. Please evaluate "loss" and "generation" separately.'
            raise ValueError
        if (not args.test) or args.eval_modes == ['loss']:
            add_ground_truth = True
        else:
            add_ground_truth = False
        print 'Ground truth response {} be added to the candidate set.'.format(
            'will' if add_ground_truth else 'will not')
    else:
        add_ground_truth = False

    # TODO: hacky
    if args.model == 'lm':
        DataGenerator = LMDataGenerator

    if args.retrieve or args.model in ('selector', 'ir'):
        retriever = Retriever(args.index,
                              context_size=args.retriever_context_len,
                              num_candidates=args.num_candidates)
    else:
        retriever = None

    preprocessor = Preprocessor(schema,
                                lexicon,
                                model_args.entity_encoding_form,
                                model_args.entity_decoding_form,
                                model_args.entity_target_form,
                                slot_filling=model_args.slot_filling,
                                slot_detector=slot_detector)

    trie_path = os.path.join(model_args.mappings, 'trie.pkl')

    if args.eval:
        data_generator = EvalDataGenerator(dataset, preprocessor, mappings,
                                           model_args.num_context)
    else:
        if args.test:
            model_args.dropout = 0
            train, dev, test = None, None, dataset.test_examples
        else:
            train, dev, test = dataset.train_examples, dataset.test_examples, None
        data_generator = DataGenerator(train,
                                       dev,
                                       test,
                                       preprocessor,
                                       schema,
                                       mappings,
                                       retriever=retriever,
                                       cache=args.cache,
                                       ignore_cache=args.ignore_cache,
                                       candidates_path=args.candidates_path,
                                       num_context=model_args.num_context,
                                       trie_path=trie_path,
                                       batch_size=args.batch_size,
                                       model_config=model_config,
                                       add_ground_truth=add_ground_truth)

    return data_generator
Example 9
class StrategyAnalyzer(object):
    sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

    def __init__(self,
                 transcripts_paths,
                 survey_paths,
                 stats_path,
                 price_tracker_model,
                 liwc_path,
                 max_examples=None):
        transcripts = self._read_transcripts(transcripts_paths, max_examples)
        self.dataset = utils.filter_rejected_chats(transcripts)

        dialogue_scores = self._read_surveys(survey_paths)

        self.price_tracker = PriceTracker(price_tracker_model)

        self.liwc = LIWC.from_pkl(liwc_path)

        # group chats depending on whether the seller or the buyer wins
        #self.buyer_wins, self.seller_wins = self.group_outcomes_and_roles()

        self.stats_path = stats_path
        if not os.path.exists(self.stats_path):
            os.makedirs(self.stats_path)

        self.examples = [
            Dialogue.from_dict(raw, dialogue_scores.get(raw['uuid'], {}),
                               self.price_tracker) for raw in self.dataset
        ]

    def _read_transcripts(self, transcripts_paths, max_examples):
        transcripts = []
        for transcripts_path in transcripts_paths:
            transcripts.extend(read_json(transcripts_path))
        if max_examples is not None:
            transcripts = transcripts[:max_examples]
        return transcripts

    def _read_surveys(self, survey_paths):
        dialogue_scores = {}
        for path in survey_paths:
            dialogue_scores.update(read_json(path)[1])
        return dialogue_scores

    def label_dialogues(self, labels=('speech_act', 'stage', 'liwc')):
        for dialogue in self.examples:
            if 'speech_act' in labels:
                dialogue.extract_keywords()
                dialogue.label_speech_acts()
            if 'stage' in labels:
                dialogue.label_stage()
            if 'liwc' in labels:
                dialogue.label_liwc(self.liwc)

    def summarize_tags(self):
        tags = defaultdict(lambda: defaultdict(int))
        for dialogue in self.examples:
            for turn in dialogue.turns:
                agent_name = dialogue.agents[turn.agent]
                for tag in turn.tags:
                    tags[agent_name][tag] += 1
        for system, labels in tags.iteritems():
            print system.upper()
            for k, v in labels.iteritems():
                print k, v

    def create_dataframe(self):
        data = []
        for dialogue in ifilter(lambda x: x.has_deal(), self.examples):
            for turn in dialogue.turns:
                for u in turn.iter_utterances():
                    row = {
                        'post_id': dialogue.post_id,
                        'chat_id': dialogue.chat_id,
                        'scenario_id': dialogue.scenario_id,
                        'buyer_target': dialogue.buyer_target,
                        'listing_price': dialogue.listing_price,
                        'margin_seller': dialogue.margins['seller'],
                        'margin_buyer': dialogue.margins['buyer'],
                        'stage': u.stage,
                        'role': turn.role,
                        'num_tokens': u.num_tokens(),
                    }
                    for a in u.speech_acts:
                        row['act_{}'.format(a[0].name)] = 1
                    for cat, word_count in u.categories.iteritems():
                        row['cat_{}'.format(cat)] = sum(word_count.values())
                    for q in dialogue.eval_questions:
                        for r in ('buyer', 'seller'):
                            key = 'eval_{question}_{role}'.format(question=q,
                                                                  role=r)
                            try:
                                row[key] = dialogue.eval_scores[r][q]
                            except KeyError:
                                row[key] = -1
                    data.append(row)
        df = pd.DataFrame(data).fillna(0)
        return df

    def summarize_liwc(self, k=10):
        categories = defaultdict(lambda: defaultdict(int))
        for dialogue in ifilter(lambda x: x.has_deal(), self.examples):
            for u in dialogue.iter_utterances():
                for cat, word_count in u.categories.iteritems():
                    for w, count in word_count.iteritems():
                        categories[cat][w] += count

        cat_freq = {
            c: sum(word_counts.values())
            for c, word_counts in categories.iteritems()
        }
        cat_freq = sorted(cat_freq.items(), key=lambda x: x[1], reverse=True)

        def topk(word_counts, k=10):
            wc = sorted(word_counts.items(), key=lambda x: x[1],
                        reverse=True)[:k]
            return [x[0] for x in wc]

        for cat, count in cat_freq:
            print cat, count, topk(categories[cat], k)

    def html_visualize(self,
                       output,
                       img_path,
                       css_file=None,
                       mpld3_plugin=None):
        examples = [ex for ex in self.examples if self.has_deal(ex)]
        examples.sort(
            key=lambda d: (d.scenario_id, d.outcome['offer']['price']))

        include_scripts = []
        include_scripts.append(
            '<script type="text/javascript" src="http://d3js.org/d3.v3.min.js"></script>'
        )
        include_scripts.append(
            '<script type="text/javascript" src="https://mpld3.github.io/js/mpld3.v0.2.js"></script>'
        )
        if mpld3_plugin:
            with open(mpld3_plugin, 'r') as fin:
                mpld3_script = fin.read()
                include_scripts.append(
                    '<script type="text/javascript">{}</script>'.format(
                        mpld3_script))

        css_style = """
            table {
                table-layout: fixed;
                width: 600px;
                border-collapse: collapse;
                }

            tr:nth-child(n) {
                border: solid thin;
                }

            .fig {
                height: 500px;
            }
            """
        if css_file:
            with open(css_file, 'r') as fin:
                css_style = '{}\n{}'.format(css_style, fin.read())
        style = [
            '<style type="text/css">', css_style, Dialogue.css, '</style>'
        ]

        header = ['<head>'] + style + include_scripts + ['</head>']

        plot_divs = []
        plot_scripts = []
        plot_scripts.append('<script type="text/javascript">')
        for d in examples:
            var_name = 'json_{}'.format(d.chat_id)
            json_str = json.dumps(d.fig_dict())
            div_name = 'fig_{}'.format(d.chat_id)
            plot_divs.append('<div class="fig" id="{div_name}"></div>'.format(
                div_name=div_name))
            plot_scripts.append('var {var_name} = {json_str};'.format(
                var_name=var_name, json_str=json_str))
            plot_scripts.append(
                '!function(mpld3) {{ mpld3.draw_figure("{div_name}", {var_name}); }}(mpld3);'
                .format(div_name=div_name, var_name=var_name))
        plot_scripts.append('</script>')

        body = ['<body>']
        for d, plot_div in izip(examples, plot_divs):
            body.extend(
                NegotiationHTMLVisualizer.render_scenario(None,
                                                          img_path=img_path,
                                                          kbs=d.kbs,
                                                          uuid=d.scenario_id))
            body.append('<p>Final deal: {}</p>'.format(
                d.outcome['offer']['price']))
            body.append(plot_div)
        body.extend(plot_scripts)
        body.append('</body>')

        html_lines = ['<html>'] + header + body + ['</html>']

        outfile = open(output, 'w')
        for line in html_lines:
            outfile.write(line.encode('utf8') + "\n")
        print 'Write to', output
        outfile.close()

    @classmethod
    def get_price_trend(cls, price_tracker, chat, agent=None):
        def _normalize_price(seen_price):
            return (float(seller_target) - float(seen_price)) / (
                float(seller_target) - float(buyer_target))
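        # Under this scaling the seller's target maps to 0 and the buyer's target maps
        # to 1, so prices lying between the two targets normalize into [0, 1].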

        scenario = NegotiationScenario.from_dict(None, chat['scenario'])
        # print chat['scenario']
        kbs = scenario.kbs
        roles = {
            kbs[0].facts['personal']['Role']: 0,
            kbs[1].facts['personal']['Role']: 1
        }

        buyer_target = kbs[roles[utils.BUYER]].facts['personal']['Target']
        seller_target = kbs[roles[utils.SELLER]].facts['personal']['Target']

        prices = []
        for e in chat['events']:
            if e['action'] == 'message':
                if agent is not None and e['agent'] != agent:
                    continue
                raw_tokens = tokenize(e['data'])
                # link entity
                linked_tokens = price_tracker.link_entity(raw_tokens,
                                                          kb=kbs[e['agent']])
                for token in linked_tokens:
                    if isinstance(token, Entity):
                        try:
                            replaced = PriceScaler.unscale_price(
                                kbs[e['agent']], token)
                        except OverflowError:
                            print "Raw tokens: ", raw_tokens
                            print "Overflow error: {:s}".format(token)
                            print kbs[e['agent']].facts
                            print "-------"
                            continue
                        norm_price = _normalize_price(replaced.canonical.value)
                        if 0. <= norm_price <= 2.:
                            # if the number is greater than the list price or significantly lower than the buyer's
                            # target it's probably not a price
                            prices.append(norm_price)
                # do some stuff here
            elif e['action'] == 'offer':
                norm_price = _normalize_price(e['data']['price'])
                if 0. <= norm_price <= 2.:
                    prices.append(norm_price)
                # prices.append(e['data']['price'])

        # print "Chat: {:s}".format(chat['uuid'])
        # print "Trend:", prices

        return prices

    @classmethod
    def split_turn(cls, turn):
        # a single turn can consist of multiple sentences
        return cls.sent_tokenizer.tokenize(turn)

    def get_speech_acts(self, ex):
        stats = {0: [], 1: []}
        kbs = ex.kbs
        for e in ex.events:
            if e.action != 'message':
                continue

            sentences = self.split_turn(e.data.lower())

            for s in sentences:
                tokens = tokenize(s)
                linked_tokens = self.price_tracker.link_entity(tokens,
                                                               kb=kbs[e.agent])
                act = SpeechActAnalyzer.get_speech_act(s, linked_tokens)
                stats[e.agent].append(act)

        return stats

    @classmethod
    def valid_price(cls, price):
        return MIN_PRICE <= price <= MAX_PRICE

    @classmethod
    def valid_margin(cls, margin):
        return MIN_MARGIN <= margin <= MAX_MARGIN

    def get_first_price(self, ex):
        agents = {1: None, 0: None}
        for e in ex.events:
            if e.action == 'message':
                for sent_tokens in e.tokens:
                    for token in sent_tokens:
                        if agents[1] and agents[0]:
                            return agents
                        # Return at the first mention
                        if is_entity(token):
                            price = token.canonical.value
                            agents[e.agent] = (e.role, price)
                            return agents
        return agents

    @classmethod
    def get_margin(cls, ex, price, agent, role, remove_outlier=True):
        agent_target = ex.scenario.kbs[agent].facts["personal"]["Target"]
        partner_target = ex.scenario.kbs[1 - agent].facts["personal"]["Target"]
        midpoint = (agent_target + partner_target) / 2.
        norm_factor = np.abs(midpoint - agent_target)
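        # Worked example with illustrative numbers (not taken from the data): a seller
        # target of 150 and a buyer target of 100 give midpoint 125 and norm_factor 25;
        # a final price of 140 then yields a seller margin of (140 - 125) / 25 = 0.6
        # and a buyer margin of (125 - 140) / 25 = -0.6.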
        if role == utils.SELLER:
            margin = (price - midpoint) / norm_factor
        else:
            margin = (midpoint - price) / norm_factor
        if remove_outlier and not cls.valid_margin(margin):
            return None
        return margin

    @classmethod
    def print_ex(cls, ex):
        print '===================='
        for e in ex.events:
            print e.role.upper(), e.data
        print '===================='

    def get_basic_stats(self, ex):
        stats = {0: None, 1: None}
        for agent in (0, 1):
            num_turns = ex.num_turns()
            num_tokens = ex.num_tokens()
            stats[agent] = {
                'role': ex.kbs[agent].facts['personal']['Role'],
                'num_turns': num_turns,
                'num_tokens_per_turn': num_tokens / float(num_turns),
            }
        return stats

    def is_good_negotiator(self, final_margin):
        if final_margin > 0.8 and final_margin <= 1:
            return 1
        elif final_margin < -0.8 and final_margin >= -1:
            # mirrors the positive branch: margins in [-1, -0.8) mark a bad negotiator
            return -1
        else:
            return 0

    @classmethod
    def has_deal(cls, ex):
        if ex.outcome is None or ex.outcome['reward'] == 0 or ex.outcome.get(
                'offer', None) is None or ex.outcome['offer']['price'] is None:
            return False
        return True

    def plot_speech_acts(self, output='figures/speech_acts'):
        data = defaultdict(list)
        for ex in ifilter(self.has_deal, self.examples):
            stats = self.get_speech_acts(ex)
            final_price = ex.outcome['offer']['price']
            for agent, acts in stats.iteritems():
                role = ex.agent_to_role[agent]
                final_margin = self.get_margin(ex, final_price, agent, role)
                label = self.is_good_negotiator(final_margin)
                for act in acts:
                    data['role'].append(role)
                    data['label'].append(label)
                    data['final_margin'].append(final_margin)
                    data['act'].append(act)

        for role in ('seller', 'buyer'):
            print role.upper()
            print '=' * 40
            good_seller_act = [
                a for r, l, m, a in izip(data['role'], data['label'],
                                         data['final_margin'], data['act'])
                if r == role and l == 1
            ]
            bad_seller_act = [
                a for r, l, m, a in izip(data['role'], data['label'],
                                         data['final_margin'], data['act'])
                if r == role and l == -1
            ]
            sum_act = lambda a, l: np.mean([1 if a in x else 0 for x in l])
            print len(good_seller_act), len(bad_seller_act)
            print '{:<20} {:<10} {:<10}'.format('ACT', 'GOOD', 'BAD')
            print '-' * 40
            for act in SpeechActs.ACTS:
                print '{:<20} {:<10.4f} {:<10.4f}'.format(
                    act, sum_act(act, good_seller_act),
                    sum_act(act, bad_seller_act))

        return

    def plot_basic_stats(self, output='figures/basic_stats'):
        data = {
            'role': [],
            'final_margin': [],
            'num_turns': [],
            'num_tokens_per_turn': [],
            'label': []
        }
        for ex in ifilter(self.has_deal, self.examples):
            stats = self.get_basic_stats(ex)
            final_price = ex.outcome['offer']['price']
            for agent, agent_stats in stats.iteritems():
                role = agent_stats['role']
                final_margin = self.get_margin(ex, final_price, agent, role)
                label = self.is_good_negotiator(final_margin)
                for k, v in agent_stats.iteritems():
                    data[k].append(v)
                data['label'].append(label)
                data['final_margin'].append(final_margin)
        fig = plt.figure()
        df = pd.DataFrame(data)
        #g = sns.lmplot(x='num_tokens_per_turn', y='final_margin', col='role', row='label', data=dataframe, scatter_kws={'alpha':0.5})
        #g.savefig(output)
        for role in ('buyer', 'seller'):
            d1 = df.num_tokens_per_turn[(df['label'] == 1)
                                        & (df['role'] == role)]
            d2 = df.num_tokens_per_turn[(df.label == -1) & (df.role == role)]
            sns.distplot(d1, label='good')
            sns.distplot(d2, label='bad')
            plt.legend()
            plt.savefig('%s_%s.png' % (output, role))
            plt.clf()

    def plot_opening_vs_result(self, output='figures/opening_vs_result.png'):
        data = {'role': [], 'init_margin': [], 'final_margin': []}
        for ex in ifilter(self.has_deal, self.examples):
            final_price = ex.outcome['offer']['price']
            init_prices = self.get_first_price(ex)
            for agent, p in init_prices.iteritems():
                if p is None:
                    continue
                role, price = p
                init_margin = self.get_margin(ex, price, agent, role)
                final_margin = self.get_margin(ex, final_price, agent, role)
                if init_margin is None or final_margin is None:
                    continue
                # NOTE: sometimes one is saying a price is not okay, i.e. negative mention
                # TODO: detect negative vs positive mention
                if init_margin == -1 and init_margin < final_margin:
                    continue
                #if init_margin < final_margin:
                #    print role, (price, init_margin), (final_price, final_margin)
                #    self.print_ex(ex)
                #    import sys; sys.exit()
                for k, v in izip(('role', 'init_margin', 'final_margin'),
                                 (role, init_margin, final_margin)):
                    data[k].append(v)
        dataframe = pd.DataFrame(data)
        fig = plt.figure()
        g = sns.lmplot(x='init_margin',
                       y='final_margin',
                       col='role',
                       data=dataframe,
                       scatter_kws={'alpha': 0.5})
        g.savefig(output)

    def group_outcomes_and_roles(self):
        buyer_wins = []
        seller_wins = []
        ties = 0
        total_chats = 0
        for ex in self.dataset:
            roles = {
                0: ex["scenario"]["kbs"][0]["personal"]["Role"],
                1: ex["scenario"]["kbs"][1]["personal"]["Role"]
            }
            winner = utils.get_winner(ex)
            if winner is None:
                continue
            total_chats += 1
            if winner == -1:
                buyer_wins.append(ex)
                seller_wins.append(ex)
                ties += 1
            elif roles[winner] == utils.BUYER:
                buyer_wins.append(ex)
            elif roles[winner] == utils.SELLER:
                seller_wins.append(ex)

        print "# of ties: {:d}".format(ties)
        print "Total chats with outcomes: {:d}".format(total_chats)
        return buyer_wins, seller_wins

    def plot_length_vs_margin(self, out_name='turns_vs_margin.png'):
        labels = ['buyer wins', 'seller wins']
        plt.figure(figsize=(10, 6))

        for (chats, lbl) in zip([self.buyer_wins, self.seller_wins], labels):
            margins = defaultdict(list)
            for ex in chats:
                turns = utils.get_turns_per_agent(ex)
                total_turns = turns[0] + turns[1]
                margin = utils.get_margin(ex)
                if margin > MAX_MARGIN or margin < 0.:
                    continue

                margins[total_turns].append(margin)

            sorted_keys = list(sorted(margins.keys()))

            turns = []
            means = []
            errors = []
            for k in sorted_keys:
                if len(margins[k]) >= THRESHOLD:
                    turns.append(k)
                    means.append(np.mean(margins[k]))
                    errors.append(stats.sem(margins[k]))

            plt.errorbar(turns, means, yerr=errors, label=lbl, fmt='--o')

        plt.legend()
        plt.xlabel('# of turns in dialogue')
        plt.ylabel('Margin of victory')

        save_path = os.path.join(self.stats_path, out_name)
        plt.savefig(save_path)

    def plot_margin_histograms(self):
        for (lbl, group) in zip(['buyer_wins', 'seller_wins'],
                                [self.buyer_wins, self.seller_wins]):
            margins = []
            for ex in group:
                winner = utils.get_winner(ex)
                if winner is None:
                    continue
                margin = utils.get_margin(ex)
                if 0 <= margin <= MAX_MARGIN:
                    margins.append(margin)

            b = np.linspace(0, MAX_MARGIN, num=int(MAX_MARGIN / 0.2) + 2)
            print b
            hist, bins = np.histogram(margins, bins=b)

            width = np.diff(bins)
            center = (bins[:-1] + bins[1:]) / 2

            fig, ax = plt.subplots(figsize=(8, 3))
            ax.bar(center, hist, align='center', width=width)
            ax.set_xticks(bins)

            save_path = os.path.join(
                self.stats_path, '{:s}_wins_margins_histogram.png'.format(lbl))
            plt.savefig(save_path)

    def plot_length_histograms(self):
        lengths = []
        for ex in self.dataset:
            winner = utils.get_winner(ex)
            if winner is None:
                continue
            turns = utils.get_turns_per_agent(ex)
            total_turns = turns[0] + turns[1]
            lengths.append(total_turns)

        hist, bins = np.histogram(lengths)

        width = np.diff(bins)
        center = (bins[:-1] + bins[1:]) / 2

        fig, ax = plt.subplots(figsize=(8, 3))
        ax.bar(center, hist, align='center', width=width)
        ax.set_xticks(bins)

        save_path = os.path.join(self.stats_path, 'turns_histogram.png')
        plt.savefig(save_path)

    def plot_price_trends(self, top_n=10):
        labels = ['buyer_wins', 'seller_wins']
        for (group, lbl) in zip([self.buyer_wins, self.seller_wins], labels):
            plt.figure(figsize=(10, 6))
            trends = []
            for chat in group:
                winner = utils.get_winner(chat)
                margin = utils.get_margin(chat)
                if margin > 1.0 or margin < 0.:
                    continue
                if winner is None:
                    continue

                # print "Winner: Agent {:d}\tWin margin: {:.2f}".format(winner, margin)
                if winner == -1 or winner == 0:
                    trend = self.get_price_trend(self.price_tracker,
                                                 chat,
                                                 agent=0)
                    if len(trend) > 1:
                        trends.append((margin, chat, trend))
                if winner == -1 or winner == 1:
                    trend = self.get_price_trend(self.price_tracker,
                                                 chat,
                                                 agent=1)
                    if len(trend) > 1:
                        trends.append((margin, chat, trend))

                # print ""

            sorted_trends = sorted(trends, key=lambda x: x[0], reverse=True)
            for (idx, (margin, chat,
                       trend)) in enumerate(sorted_trends[:top_n]):
                print '{:s}: Chat {:s}\tMargin: {:.2f}'.format(
                    lbl, chat['uuid'], margin)
                print 'Trend: ', trend
                print chat['scenario']['kbs']
                print ""
                plt.plot(trend, label='Margin={:.2f}'.format(margin))
            plt.legend()
            plt.xlabel('N-th price mentioned in chat')
            plt.ylabel('Value of mentioned price')
            out_path = os.path.join(self.stats_path,
                                    '{:s}_trend.png'.format(lbl))
            plt.savefig(out_path)

    def _get_price_mentions(self, chat, agent=None):
        scenario = NegotiationScenario.from_dict(None, chat['scenario'])
        # print chat['scenario']
        kbs = scenario.kbs

        prices = 0
        for e in chat['events']:
            if agent is not None and e['agent'] != agent:
                continue
            if e['action'] == 'message':
                raw_tokens = tokenize(e['data'])
                # link entity
                linked_tokens = self.price_tracker.link_entity(
                    raw_tokens, kb=kbs[e['agent']])
                for token in linked_tokens:
                    if isinstance(token,
                                  Entity) and token.canonical.type == 'price':
                        prices += 1

        return prices

    def plot_speech_acts_old(self):
        labels = ['buyer_wins', 'seller_wins']
        for (group, lbl) in zip([self.buyer_wins, self.seller_wins], labels):
            plt.figure(figsize=(10, 6))
            speech_act_counts = dict(
                (act, defaultdict(list)) for act in SpeechActs.ACTS)
            for chat in group:
                winner = utils.get_winner(chat)
                margin = utils.get_margin(chat)
                if margin > MAX_MARGIN or margin < 0.:
                    continue
                if winner is None:
                    continue

                margin = round_partial(
                    margin
                )  # round the margin to the nearest 0.1 to reduce noise

                if winner == -1 or winner == 0:
                    speech_acts = self.get_speech_acts(chat, agent=0)
                    # print "Chat {:s}\tWinner: {:d}".format(chat['uuid'], winner)
                    # print speech_acts
                    for act in SpeechActs.ACTS:
                        frac = float(speech_acts.count(act)) / float(
                            len(speech_acts))
                        speech_act_counts[act][margin].append(frac)
                if winner == -1 or winner == 1:
                    speech_acts = self.get_speech_acts(chat, agent=1)
                    # print "Chat {:s}\tWinner: {:d}".format(chat['uuid'], winner)
                    # print speech_acts
                    for act in SpeechActs.ACTS:
                        frac = float(speech_acts.count(act)) / float(
                            len(speech_acts))
                        speech_act_counts[act][margin].append(frac)

            for act in SpeechActs.ACTS:
                counts = speech_act_counts[act]
                margins = []
                fracs = []
                errors = []
                bin_totals = 0.
                for m in sorted(counts.keys()):
                    if len(counts[m]) > THRESHOLD:
                        bin_totals += len(counts[m])
                        margins.append(m)
                        fracs.append(np.mean(counts[m]))
                        errors.append(stats.sem(counts[m]))
                print bin_totals / float(len(margins))

                plt.errorbar(margins, fracs, yerr=errors, label=act, fmt='--o')

            plt.xlabel('Margin of victory')
            plt.ylabel('Fraction of speech act occurrences')
            plt.title('Speech act frequency vs. margin of victory')
            plt.legend()
            save_path = os.path.join(self.stats_path,
                                     '{:s}_speech_acts.png'.format(lbl))
            plt.savefig(save_path)

    def plot_speech_acts_by_role(self):
        labels = utils.ROLES
        for lbl in labels:
            plt.figure(figsize=(10, 6))
            speech_act_counts = dict(
                (act, defaultdict(list)) for act in SpeechActs.ACTS)
            for chat in self.dataset:
                if utils.get_winner(chat) is None:
                    # skip chats with no outcomes
                    continue
                speech_acts = self.get_speech_acts(chat, role=lbl)
                agent = 1 if chat['scenario']['kbs'][1]['personal'][
                    'Role'] == lbl else 0
                margin = utils.get_margin(chat, agent=agent)
                if margin > MAX_MARGIN:
                    continue
                margin = round_partial(margin)
                for act in SpeechActs.ACTS:
                    frac = float(speech_acts.count(act)) / float(
                        len(speech_acts))
                    speech_act_counts[act][margin].append(frac)

            for act in SpeechActs.ACTS:
                counts = speech_act_counts[act]
                margins = []
                fracs = []
                errors = []
                for m in sorted(counts.keys()):
                    if len(counts[m]) > THRESHOLD:
                        margins.append(m)
                        fracs.append(np.mean(counts[m]))
                        errors.append(stats.sem(counts[m]))

                plt.errorbar(margins, fracs, yerr=errors, label=act, fmt='--o')

            plt.xlabel('Margin of victory')
            plt.ylabel('Fraction of speech act occurrences')
            plt.title('Speech act frequency vs. margin of victory')
            plt.legend()
            save_path = os.path.join(self.stats_path,
                                     '{:s}_speech_acts.png'.format(lbl))
            plt.savefig(save_path)
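Taken together, the class is typically driven by a small script; a minimal sketch, where every path below is a placeholder rather than a file shipped with the code:

# Illustrative driver for StrategyAnalyzer; all paths are placeholders.
analyzer = StrategyAnalyzer(transcripts_paths=['transcripts.json'],
                            survey_paths=['surveys.json'],
                            stats_path='stats/',
                            price_tracker_model='price_tracker.pkl',
                            liwc_path='liwc.pkl',
                            max_examples=100)
analyzer.label_dialogues()        # speech acts, stages, LIWC categories
analyzer.summarize_tags()
df = analyzer.create_dataframe()  # one row per utterance of chats that ended in a deal
df.to_csv('stats/utterances.csv', index=False)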