Exemplo n.º 1
0
def create_test_batch(roles, picks, bans, vocab):
    """Build a padded batch that contains a single example.

    The example is produced by ``create_example`` from the given roles, picks
    and bans, merged into an empty ``recDotDefaultDict`` via
    ``batching_dicts`` and finally run through ``padding``.
    """
    empty_batch = recDotDefaultDict()
    single_example = create_example(roles, picks, bans, vocab)
    # Merge the lone example into the batch container
    # (list of dictionaries -> dictionary of lists).
    merged = batching_dicts(empty_batch, single_example)
    return padding(merged)
Exemplo n.º 2
0
  def article2entries(self, article):
    """Turn an article into a list holding at most one entry.

    An empty list is returned when the article has no category or no
    contexts, or when its category maps to the same id as _UNK. Otherwise
    the single entry carries the title, the whitespace-split description,
    the category and up to self.max_contexts contexts with their links.
    """
    if not article.category or not article.contexts:
      return []

    result = recDotDefaultDict()
    result.title.raw = article.title

    desc_tokens = article.desc.split()
    result.desc.raw = desc_tokens
    result.desc.word = self.vocab.word.sent2ids(desc_tokens)

    category_vocab = self.vocab.category
    result.category.raw = article.category
    result.category.label = category_vocab.token2id(article.category)
    # Articles whose category is unknown to the vocabulary are dropped.
    if result.category.label == category_vocab.token2id(_UNK):
      return []

    result.contexts.raw = []
    result.contexts.word = []
    result.contexts.char = []
    result.contexts.link = []
    for sentence, span in article.contexts[:self.max_contexts]:
      tokens = sentence.split()
      # The raw tokens are stored before any masking is applied.
      result.contexts.raw.append(tokens)

      encoded_tokens = mask_span(tokens, span) if self.mask_link else tokens
      result.contexts.word.append(self.vocab.word.sent2ids(encoded_tokens))
      result.contexts.char.append(self.vocab.char.sent2ids(encoded_tokens))
      result.contexts.link.append(span)
    return [result]
Exemplo n.º 3
0
def read_log(fpath, max_num_card, num_next_candidates_samples, is_sente):
  """Parse one game log into a list of per-turn records.

  A valid log has exactly 90 stripped lines: 30 turns of three lines each
  (the state counts, the candidate ids, and a line whose first token is the
  chosen action).

  Args:
    fpath: Path of the log file.
    max_num_card: Forwarded to ``state_to_onehot`` for both the current and
      the next state.
    num_next_candidates_samples: Number of random candidate triples sampled
      for every turn.
    is_sente: Truthy values are stored as ``[1, 0]``, falsy ones as ``[0, 1]``.

  Returns:
    ``None`` when the file does not exist or does not have 90 lines.
    Otherwise a list of 30 ``recDotDefaultDict`` records; only the last one
    has ``is_end_state`` set to True.
  """
  if not os.path.exists(fpath):
    return None
  # Use a context manager so the file handle is always released.
  with open(fpath) as f:
    logs = [line.strip() for line in f]
  if len(logs) != 90:  # 30 turn * 3
    return None
  n_examples = len(logs) // 3

  data = []
  for i in range(n_examples):
    t = 3 * i
    state = [int(x) for x in logs[t].split()]
    candidates = [int(x) for x in logs[t + 1].split()]
    action = int(logs[t + 2].split()[0])
    # The next state is the current one with the chosen card count incremented.
    new_state = list(state)
    new_state[action] += 1

    d = recDotDefaultDict()

    # Convert state format from [num_card0, num_card1, ....] to
    # [card0_0, card0_1, ... card0_max_num_card, card1_0, ...].
    # Each element must be 0 or 1.
    d.state = state_to_onehot(state, max_num_card)
    d.next_state = state_to_onehot(new_state, max_num_card)
    # Since the next candidates are unknown by all rights, the next expected
    # Q-values are approximated by sampling.
    d.next_candidates = [random.sample(range(160), 3)
                         for _ in range(num_next_candidates_samples)]
    d.candidates = candidates
    d.action = action
    # Temporary values: the reward is a placeholder and the end-state flag is
    # corrected for the final record below.
    d.reward = 0
    d.is_end_state = False
    d.is_sente = [1, 0] if is_sente else [0, 1]
    d.current_num_cards = sum(state)
    data.append(d)
  data[-1].is_end_state = True
  return data
Exemplo n.º 4
0
    def setup_placeholder(self):
        """Create the model's input placeholders.

        Returns:
            A recDotDefaultDict of placeholders. ``text.word`` and
            ``text.char`` are None when the encoder's ``wbase`` / ``cbase``
            flag is falsy.
        """
        ph = recDotDefaultDict()

        # [batch_size, max_num_sent, max_num_word]
        if self.encoder.wbase:
            ph.text.word = tf.placeholder(
                tf.int32, shape=[None, None, None], name='text.word')
        else:
            ph.text.word = None

        # [batch_size, max_num_sent, max_num_word, max_num_char]
        if self.encoder.cbase:
            ph.text.char = tf.placeholder(
                tf.int32, shape=[None, None, None, None], name='text.char')
        else:
            ph.text.char = None

        # [batch_size, 2]
        ph.query = tf.placeholder(tf.int32, shape=[None, 2], name='query')

        # One target tensor per triple direction; the trailing dimensions are
        # [max_sequence_len, max_mention_width] (leading dimension is
        # presumably the batch - TODO confirm).
        for direction in ['subjective', 'objective']:
            target_shape = [None, None, self.max_mention_width]
            ph.target[direction] = tf.placeholder(
                tf.int32, shape=target_shape, name='target.%s' % direction)

        # [batch_size, max_num_mentions, 2]
        ph.mentions = tf.placeholder(
            tf.int32, shape=[None, None, 2], name='mentions')
        ph.num_mentions = tf.placeholder(
            tf.int32, shape=[None], name='num_mentions')
        ph.loss_weights_by_label = tf.placeholder(
            tf.float32,
            shape=[None, self.vocab.rel.size],
            name='loss_weights_by_label')
        return ph
Exemplo n.º 5
0
 def tensorize(self, data):
     """Merge the examples in ``data`` into one batch and pad it.

     Every example is folded into a ``recDotDefaultDict`` with
     ``batching_dicts`` (list of dictionaries -> dictionary of lists); the
     merged batch is then returned after ``self.padding``.
     """
     merged = recDotDefaultDict()
     for example in data:
         merged = batching_dicts(merged, example)
     return self.padding(merged)
Exemplo n.º 6
0
    def __init__(self, sess, config, encoder, activation=tf.nn.relu):
        """Build the category-classification graph on top of a shared encoder.

        Args:
            sess: Session object; forwarded to the base class and stored.
            config: Configuration; ``config.dropout_rate`` is read here.
            encoder: Shared encoder providing ``is_training``, ``vocab``,
                ``wbase``/``cbase`` flags, ``word_encoder``, ``encode`` and
                ``get_batched_mention_emb``.
            activation: Activation function stored on the instance.
        """
        super(CategoryClassification, self).__init__(sess, config)
        self.sess = sess
        self.encoder = encoder
        self.activation = activation
        self.is_training = encoder.is_training
        # keep_prob is 1.0 when not training, 1 - dropout_rate when training.
        self.keep_prob = 1.0 - tf.to_float(
            self.is_training) * config.dropout_rate
        self.vocab = encoder.vocab

        with tf.name_scope('Placeholder'):
            self.ph = recDotDefaultDict()
            # [batch_size, max_num_context, max_num_words]
            self.ph.text.word = tf.placeholder(
                tf.int32, name='contexts.word',
                shape=[None, None, None]) if self.encoder.wbase else None
            self.ph.text.char = tf.placeholder(
                tf.int32, name='contexts.char',
                shape=[None, None, None, None]) if self.encoder.cbase else None

            self.ph.link = tf.placeholder(tf.int32,
                                          name='link',
                                          shape=[None, None, 2])
            # The target placeholder previously reused the name 'link', which
            # TensorFlow silently uniquified; give it its own name.
            self.ph.target = tf.placeholder(tf.int32,
                                            name='target',
                                            shape=[None])

            # NOTE(review): this requires self.ph.text.word to be a tensor,
            # i.e. encoder.wbase must be truthy - confirm with callers.
            self.sentence_length = tf.count_nonzero(self.ph.text.word, axis=-1)
            self.num_contexts = tf.cast(
                tf.count_nonzero(self.sentence_length, axis=-1), tf.float32)

        with tf.name_scope('Encoder'):
            word_repls = encoder.word_encoder.word_encode(self.ph.text.word)
            char_repls = encoder.word_encoder.char_encode(self.ph.text.char)

            text_emb, text_outputs, _ = encoder.encode(
                [word_repls, char_repls], self.sentence_length)
            mention_starts, mention_ends = tf.unstack(self.ph.link, axis=-1)

            mention_repls, _ = encoder.get_batched_mention_emb(
                text_emb, text_outputs, mention_starts,
                mention_ends)  # [batch_size, max_n_contexts, mention_size]
            # Collapse the first two dimensions of the encoder outputs.
            self.adv_outputs = tf.reshape(
                text_outputs, [
                    shape(text_outputs, 0) * shape(text_outputs, 1),
                    shape(text_outputs, 2),
                    shape(text_outputs, 3)
                ]
            )  # [batch_size * max_n_contexts, max_sentence_length, output_size]

        with tf.variable_scope('Inference'):
            self.outputs = self.inference(mention_repls)
            self.predictions = tf.argmax(self.outputs, axis=-1)

        with tf.name_scope("Loss"):
            self.losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.outputs, labels=self.ph.target)

            self.loss = tf.reduce_mean(self.losses)
Exemplo n.º 7
0
    def formatize_and_print(self_class, flat_batches, predictions, vocab=None):
        '''Convert predictions into triples and mentions, printing each example.

        Args:
        - flat_batches: Examples; each provides query, triples, mentions.raw,
          mentions.flat_position and text.flat.
        - predictions: A list of a tuple (relations, mention_starts,
          mention_ends), which contains the predicted relations (both of
          subj, obj) and mention spans. Each element of the list corresponds
          to each example.
        - vocab: Vocabulary whose ``rel`` member maps relation ids to tokens
          and names. Although it defaults to None, it is dereferenced
          whenever a valid mention is predicted.

        Returns:
        A pair (triples, mentions); each has a 'gold' and a 'prediction' list
        with one element per example.
        '''
        triples = recDotDict({'gold': [], 'prediction': []})
        mentions = recDotDict({'gold': [], 'prediction': []})

        for i, (b, p) in enumerate(zip(flat_batches, predictions)):
            query = b.query
            gold_triples = b.triples
            predicted_triples = recDotDefaultDict()
            predicted_triples.subjective = []
            predicted_triples.objective = []

            # Distinct names here: the original reused `p`, shadowing the
            # prediction that is unpacked below.
            gold_mentions = [
                recDotDict({
                    'raw': raw,
                    'flat_position': position
                }) for raw, position in zip(b.mentions.raw,
                                            b.mentions.flat_position)
            ]

            predicted_mentions = []
            for (subj_rel_id, obj_rel_id), (mention_start,
                                            mention_end) in zip(*p):
                # Skip padded spans and spans that run past the text.
                # NOTE(review): mention_end is used as an inclusive index
                # below, so `<` may be intended instead of `<=` - confirm.
                if not (mention_end <= len(b.text.flat) and
                        (mention_start, mention_end) != (PAD_ID, PAD_ID)):
                    continue
                mention = recDotDict()
                mention.raw = ' '.join(
                    b.text.flat[mention_start:mention_end + 1])
                mention.flat_position = (mention_start, mention_end)
                predicted_mentions.append(mention)

                if subj_rel_id != vocab.rel.UNK_ID:
                    rel = dotDict({
                        'raw': vocab.rel.id2token(subj_rel_id),
                        'name': vocab.rel.id2name(subj_rel_id),
                    })
                    predicted_triples.subjective.append([query, rel, mention])
                if obj_rel_id != vocab.rel.UNK_ID:
                    rel = dotDict({
                        'raw': vocab.rel.id2token(obj_rel_id),
                        'name': vocab.rel.id2name(obj_rel_id),
                    })
                    predicted_triples.objective.append([mention, rel, query])
            triples.gold.append(gold_triples)
            triples.prediction.append(predicted_triples)
            mentions.gold.append(gold_mentions)
            mentions.prediction.append(predicted_mentions)
            _id = BOLD + '<%04d>' % (i) + RESET
            print(_id)
            self_class.print_example(
                b, vocab, prediction=[predicted_triples, predicted_mentions])
            print('')
        return triples, mentions
Exemplo n.º 8
0
  def add_replay(self, fpath):
    """Read a replay log and batch each player's records.

    Returns None when ``self.read_log`` yields a falsy result; otherwise a
    two-element list with the merged records of the first and the second
    player, in that order.
    """
    parsed = self.read_log(fpath)
    if not parsed:
      return
    _, first_player_log, second_player_log = parsed

    # Rewards are currently left as recorded (no N-step TD propagation from
    # the last state).
    data = []
    for player_log in (first_player_log, second_player_log):
      merged = recDotDefaultDict()
      for step in player_log:
        batching_dicts(merged, step)
      data.append(merged)
    return data
Exemplo n.º 9
0
 def setup_placeholder(self, config):
     """Create the pick/ban/role/win placeholders under 'Placeholder'.

     ``config`` is accepted but not read here. Returns the placeholders as
     a recDotDefaultDict.
     """
     slot_shape = [None, 10]
     with tf.name_scope('Placeholder'):
         ph = recDotDefaultDict()
         # Encoder inputs: each has ten columns per example.
         for field, ph_name in (('picks', 'pick'),
                                ('bans', 'ban'),
                                ('roles', 'role')):
             setattr(ph, field,
                     tf.placeholder(tf.int32, name=ph_name, shape=slot_shape))
         ph.win = tf.placeholder(tf.int32, name='win', shape=[None])
     return ph
Exemplo n.º 10
0
def create_example_from_jline(jline, vocab):
    """Convert one parsed match record into an example.

    The example holds the game version (first two dot-separated components),
    the queue id, the win flag (0 when the blue team's ``win`` is 'Win',
    otherwise 1), the bans (blue team first, then red), the team index of
    each participant, and the picks and roles of every participant.
    """
    example = recDotDefaultDict()

    major_minor = jline.gameVersion.split('.')[:2]
    example.gameVersion = '.'.join(major_minor)
    example.queueID = int(jline.queueId)

    # 100 for blue side. 200 for red side.
    if jline.teams[0].teamId == 100:
        blue = 0
    else:
        blue = 1
    red = 1 - blue

    example.win = 0 if jline.teams[blue].win == 'Win' else 1

    champion = vocab.champion
    bans_blue_then_red = [jline.teams[blue].bans, jline.teams[red].bans]
    example.bans.ids = [
        champion.key2id(int(ban.championId))
        for team_bans in bans_blue_then_red
        for ban in team_bans
    ]
    example.bans.raw = [
        champion.key2token(int(ban.championId))
        for team_bans in bans_blue_then_red
        for ban in team_bans
    ]

    # Open question: are participants always listed starting with blue?
    example.teams = [
        0 if player.teamId == TEAMID.blue else 1
        for player in jline.participants
    ]

    example.picks.ids = [
        champion.key2id(int(player.championId))
        for player in jline.participants
    ]
    example.picks.raw = [
        champion.key2token(int(player.championId))
        for player in jline.participants
    ]
    example.roles.raw = [
        reformat_role(player.timeline.role, player.timeline.lane)
        for player in jline.participants
    ]
    example.roles.ids = [
        vocab.role.token2id(role) for role in example.roles.raw
    ]
    return example
Exemplo n.º 11
0
def create_example(roles, picks, bans, vocab):
    """Assemble an example from raw role, pick and ban tokens.

    Each group keeps the raw tokens and their vocabulary ids; ``win`` is
    always set to 0.
    """
    role_to_id = vocab.role.token2id
    example = recDotDefaultDict()

    example.roles.raw = roles
    example.roles.ids = [role_to_id(token) for token in roles]

    # Picks and bans share the champion vocabulary.
    example.picks.raw = picks
    example.picks.ids = [vocab.champion.token2id(token) for token in picks]
    example.bans.raw = bans
    example.bans.ids = [vocab.champion.token2id(token) for token in bans]

    example.win = 0
    return example
Exemplo n.º 12
0
        def qid2entity(qid, article):
            """Describe the entity ``qid`` as it appears in ``article``.

            The result holds the surface string taken from the article, the
            stored (sentence index, (begin, end)) position and the same span
            expressed as offsets into the flattened text.
            """
            assert qid in article.link
            sent_idx, (first, last) = article.link[qid]

            # Number of words in all sentences preceding the entity's one.
            words_before = sum(
                len(sentence) for sentence in article.text[:sent_idx])

            entity = recDotDefaultDict()
            # Use the wording actually found in the article, not the
            # entity's canonical name.
            surface_tokens = article.text[sent_idx][first:last + 1]
            entity.raw = ' '.join(surface_tokens)
            entity.position = article.link[qid]
            entity.flat_position = (first + words_before, last + words_before)
            return entity
Exemplo n.º 13
0
 def get_batches(self, batch_size, current_epoch, 
                 random_sample=True, is_training=True):
   """Yield batches of replays for the given epoch.

   The data is read from the sub-directory of self.dataset_path named after
   the zero-padded epoch number. One batch of ``batch_size`` randomly
   sampled replays is yielded per iteration; when ``random_sample`` is
   falsy nothing is yielded.
   """
   epoch_dir = os.path.join(self.dataset_path, '%03d' % current_epoch)
   replay_pool = self.read_data(epoch_dir)
   for _ in range(self.iterations_per_epoch):
     batch = recDotDefaultDict()
     batch.is_training = is_training
     if not random_sample:
       continue
     for replay in random.sample(replay_pool, batch_size):
       batching_dicts(batch, replay)
     yield batch
Exemplo n.º 14
0
        def triple2entry(triple, article, label):
            """Build one labelled entry from a (subject, relation, object) triple.

            The relation is represented by the whitespace-split name of its
            property; subject and object come from ``qid2position``. When
            self.mask_link is set, both spans are masked in the text before
            it is converted to ids.
            """
            subj_qid, rel_pid, obj_qid = triple
            rel_tokens = self.properties[rel_pid].name.split()

            entry = recDotDefaultDict()
            entry.qid = article.qid

            entry.rel.raw = rel_tokens  # 1D tensor of str.
            entry.rel.word = self.vocab.word.sent2ids(rel_tokens)  # 1D tensor of int.
            entry.rel.char = self.vocab.char.sent2ids(rel_tokens)  # 2D tensor of int.
            entry.subj = qid2position(subj_qid, article)
            entry.obj = qid2position(obj_qid, article)
            entry.label = label  # 1 or 0.

            # The raw text is kept unmasked; only the id tensors see the mask.
            entry.text.raw = article.text
            text_for_ids = article.text
            if self.mask_link:
                for entity in (entry.subj, entry.obj):
                    text_for_ids = mask_span(text_for_ids, entity.position)
            entry.text.word = self.vocab.word.sent2ids(text_for_ids)
            entry.text.char = self.vocab.char.sent2ids(text_for_ids)

            return entry
Exemplo n.º 15
0
    def debug(self, model=None):
        """Run a quick manual check and terminate the process.

        Without a model, one is created, its variables are written out and
        the process exits. With a model, a single hand-made batch is run
        through ``model.step`` and printed, then the first batch from the
        dataset is printed field by field before exiting.
        """
        if not model:
            model = self.create_model(self.config)
            self.output_variables_as_text(model)
            exit(1)

        # One synthetic example: 160 groups of [1, 0, 0, 0], flattened.
        dummy_state = [common.flatten([[1, 0, 0, 0] for _ in range(160)])]
        batch = common.recDotDefaultDict()
        batch.state = dummy_state
        batch.is_sente = [[1, 0] for _ in dummy_state]
        batch.current_num_cards = [[1, 1] for _ in dummy_state]
        batch.is_training = False
        print(model.step(batch, 0))

        dataset_batches = self.dataset.get_batches(self.config.batch_size,
                                                   0,
                                                   is_training=True)
        for raw_batch in dataset_batches:
            flat_batch = common.flatten_recdict(raw_batch)
            for key in flat_batch:
                print(key, flat_batch[key])
            # Only the first batch is inspected.
            exit(1)
Exemplo n.º 16
0
    def setup_placeholders(self):
        """Create encoder and decoder placeholders under 'Placeholder'.

        ``text.char`` is None when the encoder's ``cbase`` flag is falsy.
        Returns the placeholders as a recDotDefaultDict.
        """
        with tf.name_scope('Placeholder'):
            ph = recDotDefaultDict()

            # Encoder inputs.
            # [batch_size, n_max_contexts, n_max_word]
            ph.text.word = tf.placeholder(
                tf.int32, shape=[None, None, None], name='text.word')

            # [batch_size, n_max_contexts, n_max_word, n_max_char]
            if self.encoder.cbase:
                ph.text.char = tf.placeholder(
                    tf.int32, shape=[None, None, None, None], name='text.char')
            else:
                ph.text.char = None

            # [batch_size, n_max_contexts, 2]
            ph.link = tf.placeholder(
                tf.int32, shape=[None, None, 2], name='link.position')

            # Decoder input.
            ph.target = tf.placeholder(
                tf.int32, shape=[None, None], name='descriptions')
        return ph
Exemplo n.º 17
0
    def article2entries(self, article):
        """Turn an article into a one-element list of entries.

        The description goes through the decoder word vocabulary, while up
        to self.max_contexts contexts are whitespace-split and encoded with
        the encoder vocabularies (after optional link masking).
        """
        decoder_words = self.vocab.decoder.word

        result = recDotDefaultDict()
        result.title.raw = article.title
        result.desc.raw = decoder_words.tokenizer(article.desc)
        result.desc.word = decoder_words.sent2ids(article.desc)

        result.contexts.raw = []
        result.contexts.word = []
        result.contexts.char = []
        result.contexts.link = []
        for sentence, span in article.contexts[:self.max_contexts]:
            tokens = sentence.split()
            # The raw tokens are recorded before any masking.
            result.contexts.raw.append(tokens)

            encoded_tokens = mask_span(tokens, span) if self.mask_link else tokens
            word_ids = self.vocab.encoder.word.sent2ids(encoded_tokens)
            result.contexts.word.append(word_ids)
            char_ids = self.vocab.encoder.char.sent2ids(encoded_tokens)
            result.contexts.char.append(char_ids)
            result.contexts.link.append(span)
        return [result]
Exemplo n.º 18
0
    def __init__(self, sess, config, manager, encoder, activation=tf.nn.relu):
        """Build the coreference graph: placeholders, embeddings, mention
        candidates, antecedent scoring and the loss.

        Args:
            sess: Session object; forwarded to the base class and stored.
            config: Configuration object; the hyper-parameters read from it
                are copied onto the instance below.
            manager: Provides ``vocab`` and ``tasks`` (stored as
                ``self.other_tasks``).
            encoder: Shared encoder providing ``is_training``, the
                ``wbase``/``cbase`` flags, ``word_encoder`` and ``encode``.
            activation: Accepted but not used in this constructor.
        """
        super().__init__(sess, config)
        self.sess = sess
        self.config = config
        self.encoder = encoder
        self.vocab = manager.vocab
        self.other_tasks = manager.tasks

        self.is_training = encoder.is_training
        # keep_prob is 1.0 when not training, 1 - dropout_rate when training.
        self.keep_prob = 1.0 - tf.to_float(
            self.is_training) * config.dropout_rate
        self.feature_size = config.f_embedding_size
        self.max_training_sentences = config.max_training_sentences
        self.mention_ratio = config.mention_ratio
        self.max_antecedents = config.max_antecedents
        self.use_metadata = config.use_metadata
        self.ffnn_depth = config.ffnn_depth
        self.ffnn_size = config.ffnn_size

        self.max_mention_width = config.max_mention_width
        self.use_width_feature = config.use_width_feature
        self.use_distance_feature = config.use_distance_feature

        # Placeholders
        with tf.name_scope('Placeholder'):
            self.ph = recDotDefaultDict()
            # Input document
            self.ph.text.word = tf.placeholder(
                tf.int32, name='text.word',
                shape=[None, None]) if self.encoder.wbase else None
            self.ph.text.char = tf.placeholder(
                tf.int32, name='text.char',
                shape=[None, None, None]) if self.encoder.cbase else None

            # TODO: When an example processed by truncate_example is fed into w_sentences, the word count from before truncation is counted for some reason. For now, feed sentence_length directly.

            #self.sentence_length = tf.count_nonzero(self.ph.text.word, axis=1, dtype=tf.int32)
            self.ph.sentence_length = tf.placeholder(tf.int32, shape=[None])

            # Clusters
            self.ph.gold_starts = tf.placeholder(tf.int32,
                                                 name='gold_starts',
                                                 shape=[None])
            self.ph.gold_ends = tf.placeholder(tf.int32,
                                               name='gold_ends',
                                               shape=[None])
            self.ph.cluster_ids = tf.placeholder(tf.int32,
                                                 name='cluster_ids',
                                                 shape=[None])
            # Metadata
            self.ph.speaker_ids = tf.placeholder(tf.int32,
                                                 shape=[None],
                                                 name="speaker_ids")
            self.ph.genre = tf.placeholder(tf.int32, shape=[], name="genre")

        # Embeddings
        with tf.variable_scope('Embeddings'):
            self.same_speaker_emb = self.initialize_embeddings(
                'same_speaker', [2, self.feature_size])  # True or False.
            self.genre_emb = self.initialize_embeddings(
                'genre', [self.vocab.genre.size, self.feature_size])
            self.mention_width_emb = self.initialize_embeddings(
                "mention_width", [self.max_mention_width, self.feature_size])
            self.mention_distance_emb = self.initialize_embeddings(
                "mention_distance", [10, self.feature_size])

        # Encode the document with the shared encoder.
        word_repls = encoder.word_encoder.word_encode(self.ph.text.word)
        char_repls = encoder.word_encoder.char_encode(self.ph.text.char)
        text_emb, text_outputs, state = encoder.encode(
            [word_repls, char_repls], self.ph.sentence_length)
        self.adv_outputs = text_outputs  # for adversarial MTL, it must have the shape [batch_size]

        with tf.name_scope('candidates_and_mentions'):
            flattened_text_emb, flattened_text_outputs, flattened_sentence_indices = self.flatten_doc_to_sent(
                text_emb, text_outputs, self.ph.sentence_length)

            candidate_starts, candidate_ends, candidate_mention_scores, mention_starts, mention_ends, mention_scores, mention_emb = self.get_mentions(
                flattened_text_emb, flattened_text_outputs,
                flattened_sentence_indices)

            with tf.name_scope('keep_mention_embs'):
                self.pred_mention_emb = self.get_mention_emb(
                    flattened_text_emb, flattened_text_outputs, mention_starts,
                    mention_ends, False)

                # Add dummy to gold mentions for the case that an example has no gold mentions and self.get_mention_emb(...) causes an error.
                dummy = tf.constant([0], dtype=tf.int32)
                self.gold_mention_emb = self.get_mention_emb(
                    flattened_text_emb,
                    flattened_text_outputs,
                    # self.ph.gold_starts,
                    # self.ph.gold_ends,
                    tf.concat([self.ph.gold_starts, dummy], axis=0),
                    tf.concat([self.ph.gold_ends, dummy], axis=0),
                    False)

        with tf.name_scope('antecedents'):
            antecedents, antecedent_scores, antecedent_labels = self.get_antecedents(
                mention_scores, mention_starts, mention_ends, mention_emb,
                self.ph.speaker_ids, self.ph.genre, self.ph.gold_starts,
                self.ph.gold_ends, self.ph.cluster_ids)

        # Tensors exposed to callers as the model's outputs.
        self.outputs = [
            candidate_starts, candidate_ends, candidate_mention_scores,
            mention_starts, mention_ends, antecedents, antecedent_scores
        ]

        with tf.name_scope('loss'):
            loss = self.softmax_loss(antecedent_scores,
                                     antecedent_labels)  # [num_mentions]
            self.loss = tf.reduce_sum(loss)  # []

        with tf.name_scope("Summary"):
            # Scalar placeholder named 'coref_loss'; presumably fed with a
            # loss value for summaries - TODO confirm against its users.
            self.summary_loss = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='coref_loss')
Exemplo n.º 19
0
    def article2entries(self, article):
        """Turn an article into a list with at most one training entry.

        The entry holds the article's text (raw, flat, word ids, char ids),
        the query entity, the gold mentions, the triples and a per-word,
        per-width target table of relation ids for each triple direction.

        Returns:
            ``[]`` when the article has fewer than self.min_triples triples
            in total, otherwise ``[entry]``.
        """
        def qid2entity(qid, article):
            """Describe the entity ``qid`` as it appears in ``article``."""
            assert qid in article.link
            s_id, (begin, end) = article.link[qid]

            # The offset is the number of words in previous sentences.
            offset = sum(len(sent) for sent in article.text[:s_id])
            entity = recDotDefaultDict()
            # Replace entity's name with the actual representation in the article.
            entity.raw = ' '.join(article.text[s_id][begin:end + 1])
            entity.position = article.link[qid]
            entity.flat_position = (begin + offset, end + offset)
            return entity

        entry = recDotDefaultDict()
        entry.qid = article.qid

        entry.text.raw = article.text
        entry.text.flat = article.flat_text
        entry.text.word = [self.vocab.word.sent2ids(s) for s in article.text]
        entry.text.char = [self.vocab.char.sent2ids(s) for s in article.text]

        entry.query = qid2entity(article.qid, article)

        # Articles which contain triples less than self.min_triples are discarded since they can be incorrect.
        if len(article.triples.subjective.ids) + len(
                article.triples.objective.ids) < self.min_triples:
            return []
        entry.mentions.raw = []
        entry.mentions.flat_position = []

        for t_type in ['subjective', 'objective']:
            entry.triples[t_type] = []
            # target[t_type][begin][width - 1] is the relation id of the
            # mention starting at `begin`; UNK_ID everywhere else.
            entry.target[t_type] = [[
                self.vocab.rel.UNK_ID for j in range(self.max_mention_width)
            ] for i in range(article.num_words)]

            for triple in article.triples[t_type].ids:  # triple = [subj, rel, obj]
                is_subjective = triple[0] == article.qid
                # Normalize to (query, relation, mention) order.
                _, rel_pid, mention_qid = triple if is_subjective else reversed(
                    triple)
                # TODO: What if the same mention has different relations with the query?
                mention = qid2entity(mention_qid, article)
                entry.mentions.raw.append(mention.raw)
                entry.mentions.flat_position.append(mention.flat_position)

                rel = dotDict({
                    'raw': rel_pid,
                    'name': self.vocab.rel.token2name(rel_pid)
                })

                begin, end = mention.flat_position
                # Mentions wider than max_mention_width have no target slot.
                if end - begin < self.max_mention_width:
                    entry.target[t_type][begin][
                        end - begin] = self.vocab.rel.token2id(rel_pid)

                triple = [entry.query, rel, mention
                          ] if is_subjective else [mention, rel, entry.query]
                entry.triples[t_type].append(triple)

        # TODO: For now this experiments focus only on subjective relations.
        entry.triples.objective = []
        # Every relation label currently gets the same loss weight.
        entry.loss_weights_by_label = [1.0 for _ in range(self.vocab.rel.size)]

        entry.num_mentions = len(entry.mentions.flat_position)
        return [entry]
Exemplo n.º 20
0
    def __init__(self, sess, config, encoder, activation=tf.nn.relu):
        """Build the link-prediction graph on top of a shared encoder.

        The graph encodes the text, extracts subject and object mention
        embeddings, composes the relation words with a CNN, scores the
        (subject, relation, object) combination and defines the loss.
        """
        super(GraphLinkPrediction, self).__init__(sess, config)
        self.sess = sess
        self.encoder = encoder
        self.activation = activation

        self.is_training = encoder.is_training
        # keep_prob is 1.0 when not training, 1 - dropout_rate when training.
        dropout = tf.to_float(self.is_training) * config.dropout_rate
        self.keep_prob = 1.0 - dropout
        self.ffnn_size = config.ffnn_size
        self.cnn_filter_widths = config.cnn.filter_widths
        self.cnn_filter_size = config.cnn.filter_size

        # Inputs. Word/char placeholders exist only when the encoder uses
        # the corresponding representation.
        with tf.name_scope('Placeholder'):
            self.ph = recDotDefaultDict()
            if self.encoder.wbase:
                self.ph.text.word = tf.placeholder(
                    tf.int32, shape=[None, None], name='text.word')
            else:
                self.ph.text.word = None
            if self.encoder.cbase:
                self.ph.text.char = tf.placeholder(
                    tf.int32, shape=[None, None, None], name='text.char')
            else:
                self.ph.text.char = None

            self.ph.subj = tf.placeholder(
                tf.int32, shape=[None, 2], name='subj.position')
            self.ph.obj = tf.placeholder(
                tf.int32, shape=[None, 2], name='obj.position')

            self.ph.rel = dotDict()
            if self.encoder.wbase:
                self.ph.rel.word = tf.placeholder(
                    tf.int32, shape=[None, None], name='rel.word')
            else:
                self.ph.rel.word = None
            if self.encoder.cbase:
                self.ph.rel.char = tf.placeholder(
                    tf.int32, shape=[None, None, None], name='rel.char')
            else:
                self.ph.rel.char = None

            self.ph.target = tf.placeholder(
                tf.int32, shape=[None], name='target')
            self.sentence_length = tf.count_nonzero(self.ph.text.word, axis=1)

        with tf.name_scope('Encoder'):
            text_inputs = [self.ph.text.word, self.ph.text.char]
            text_emb, encoder_outputs, encoder_state = self.encoder.encode(
                text_inputs, self.sentence_length)
            self.encoder_outputs = encoder_outputs

        with tf.variable_scope('Subject'):
            subj_starts, subj_ends = tf.unstack(self.ph.subj, axis=1)
            subj_outputs, _ = self.encoder.get_batched_mention_emb(
                text_emb, encoder_outputs, subj_starts, subj_ends)

        with tf.variable_scope('Object'):
            obj_starts, obj_ends = tf.unstack(self.ph.obj, axis=1)
            obj_outputs, _ = self.encoder.get_batched_mention_emb(
                text_emb, encoder_outputs, obj_starts, obj_ends)

        with tf.variable_scope('Relation'):
            # Gradients are stopped so that the words used as relation
            # labels do not bias learning.
            rel_inputs = [self.ph.rel.word, self.ph.rel.char]
            rel_words_emb = tf.stop_gradient(
                self.encoder.word_encoder.encode(rel_inputs))
            with tf.name_scope("compose_words"):
                rel_outputs = cnn(rel_words_emb, self.cnn_filter_widths,
                                  self.cnn_filter_size)

        with tf.variable_scope('Inference'):
            # [batch_size, 1]
            score_outputs = self.inference(subj_outputs, rel_outputs,
                                           obj_outputs)
            flat_scores = tf.reshape(score_outputs,
                                     [shape(score_outputs, 0)])
            self.outputs = tf.round(flat_scores)  # [batch_size]

        with tf.name_scope("Loss"):
            self.losses = self.cross_entropy(score_outputs, self.ph.target)
            self.loss = tf.reduce_mean(self.losses)
Exemplo n.º 21
0
 def qid2position(qid, article):
     assert qid in article.link
     begin, end = article.link[qid]
     entity = recDotDefaultDict()
     entity.raw = article.text[begin:end + 1]
     entity.position = (begin, end)