Example #1
    def __init__(self, type_content_data):
        super(TreeDecodeModel, self).__init__(type_content_data)
        self.direct_descedent_lstm = YLSTMCell(0)
        if tree_decode_way == tree_decode_2d:
            self.two_dimen_lstm = Y2DLSTMCell(1)
        elif tree_decode_way == tree_decode_embed:
            self.one_dimen_lstm = YLSTMCell(2)

        number_of_tokens = self.type_content_data[all_token_summary][
            TokenHitNum]
        self.linear_token_output_w = tf.Variable(
            random_uniform_variable_initializer(256, 566,
                                                [number_of_tokens, num_units]))
        self.one_hot_token_embedding = tf.Variable(
            random_uniform_variable_initializer(256, 56,
                                                [number_of_tokens, num_units]))
        self.one_hot_token_cell_embedding = tf.Variable(
            random_uniform_variable_initializer(25, 56,
                                                [number_of_tokens, num_units]))
        self.token_embedder = TokenAtomEmbed(self.type_content_data,
                                             self.one_hot_token_embedding)
        self.token_cell_embedder = TokenAtomEmbed(
            self.type_content_data, self.one_hot_token_cell_embedding)

        self.encode_tree = EncodeOneAST(self.type_content_data,
                                        self.token_embedder,
                                        self.token_cell_embedder)
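
These snippets never define random_uniform_variable_initializer. Given calls like random_uniform_variable_initializer(256, 566, [number_of_tokens, num_units]), one plausible reading is that the two leading integers seed a stateless uniform draw for a variable's initial value. A minimal sketch under that assumption (the value bounds are also guesses):

import tensorflow as tf

def random_uniform_variable_initializer(seed_a, seed_b, shape,
                                        minval=-0.1, maxval=0.1):
  # Assumed semantics: the two integers form a stateless RNG seed so that
  # every variable receives a reproducible initial value.
  return tf.random.stateless_uniform(
      [*shape], seed=[seed_a, seed_b], minval=minval, maxval=maxval)
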
Example #2
    def __init__(self, type_content_data):
        super(SequenceDecodeModel, self).__init__(type_content_data)
        self.token_lstm = YLSTMCell(0)
        number_of_tokens = self.type_content_data[all_token_summary][
            TokenHitNum]
        self.linear_token_output_w = tf.Variable(
            random_uniform_variable_initializer(256, 566,
                                                [number_of_tokens, num_units]))
        self.one_hot_token_embedding = tf.Variable(
            random_uniform_variable_initializer(256, 56,
                                                [number_of_tokens, num_units]))
        self.token_embedder = TokenAtomEmbed(self.type_content_data,
                                             self.one_hot_token_embedding)
        self.token_attention = None
        if decode_attention_way > decode_no_attention:
            self.token_attention = YAttention(10)

        self.dup_token_lstm, self.dup_one_hot_token_embedding, self.dup_token_embedder, self.dup_token_pointer = None, None, None, None
        #     if use_dup_model:
        #       self.dup_token_lstm = YLSTMCell(125)
        #       self.dup_one_hot_token_embedding = tf.Variable(random_uniform_variable_initializer(25, 56, [number_of_tokens, num_units]))
        #       self.dup_token_embedder = TokenAtomEmbed(self.type_content_data, self.dup_one_hot_token_embedding)
        #       self.dup_token_pointer = PointerNetwork(222)

        self.token_decoder = TokenDecoder(self.type_content_data,
                                          self.metrics_index,
                                          self.linear_token_output_w,
                                          self.token_attention,
                                          self.dup_token_pointer)
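
YLSTMCell is constructed with a single integer (0, 9, 125, ...) that looks like a per-instance index, and is later called as output, (new_cell, new_h) = lstm(inputs, (cell, h)) (see tree_iterate_body in Example #6). A minimal sketch of that assumed interface, backed here by a Keras LSTMCell; the exact role of the integer is a guess:

import tensorflow as tf

num_units = 128  # assumed; the snippets import this constant from elsewhere

class YLSTMCell:
  def __init__(self, index):
    self.index = index  # assumption: a unique id used for naming/seeding
    self.cell = tf.keras.layers.LSTMCell(num_units)

  def __call__(self, inputs, state):
    cell, h = state
    # Keras orders LSTM state as [h, cell]; re-order to the (cell, h)
    # convention these examples use.
    output, (new_h, new_cell) = self.cell(inputs, [h, cell])
    return output, (new_cell, new_h)
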
Example #3

  def __init__(self, type_content_data):
    super(StatementDupModel, self).__init__(type_content_data)
    number_of_tokens = self.type_content_data[all_token_summary][TokenHitNum]
    self.dup_skeleton_forward_cell_h = tf.Variable(random_uniform_variable_initializer(155, 572, [1, 2, num_units]))
    self.dup_skeleton_backward_cell_h = tf.Variable(random_uniform_variable_initializer(152, 572, [1, 2, num_units]))
    if atom_decode_mode == token_decode:
      self.one_dup_hot_token_embedding = tf.Variable(random_uniform_variable_initializer(252, 226, [number_of_tokens, num_units]))
      self.dup_token_embedder = TokenAtomEmbed(self.type_content_data, self.one_dup_hot_token_embedding)
      self.dup_token_lstm = YLSTMCell(9)
      self.dup_token_pointer = PointerNetwork(655)
      self.integrate_computer = None
      if compute_memory_in_memory_mode:
        self.integrate_computer = Y2DirectLSTMCell(105)
      if compute_token_memory:
        self.dup_mem_nn = NTMOneDirection(800)
        self.dup_forward_token_lstm = YLSTMCell(10)
        self.dup_backward_token_lstm = YLSTMCell(11)
      self.dup_token_decoder = DupTokenDecoder(type_content_data, self.metrics_index, self.dup_token_pointer)
    else:
      assert False, "Wrong atom_decode_mode"
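
The dup_skeleton_*_cell_h variables above pack one initial (cell, h) pair per skeleton id into a [1, 2, num_units] tensor; itearate_tokens in Example #5 unpacks them with tf.expand_dims. A tiny demonstration of that layout, reusing the stateless-seed assumption from earlier:

import tensorflow as tf

num_units = 128  # assumed constant

skeleton_cell_h = tf.Variable(
    tf.random.stateless_uniform([1, 2, num_units], seed=[155, 572]))
skt_use_id = 0
ini_cell = tf.expand_dims(skeleton_cell_h[skt_use_id][0], 0)  # [1, num_units]
ini_h = tf.expand_dims(skeleton_cell_h[skt_use_id][1], 0)     # [1, num_units]
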
Example #4

    def __init__(self, type_content_data):
        super(LinearDupModel, self).__init__(type_content_data)
        number_of_tokens = self.type_content_data[all_token_summary][
            TokenHitNum]

        self.dup_token_lstm = YLSTMCell(125)
        self.dup_one_hot_token_embedding = tf.Variable(
            random_uniform_variable_initializer(25, 56,
                                                [number_of_tokens, num_units]))
        self.dup_token_embedder = TokenAtomEmbed(
            self.type_content_data, self.dup_one_hot_token_embedding)
        self.dup_token_pointer = PointerNetwork(222)

        self.integrate_computer = None
        if compute_memory_in_memory_mode:
            self.integrate_computer = Y2DirectLSTMCell(105)

        self.dup_token_decoder = DupTokenDecoder(self.type_content_data,
                                                 self.metrics_index,
                                                 self.dup_token_pointer)
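
PointerNetwork's internals are never shown in these examples; its role next to DupTokenDecoder is to score earlier token positions so repeated tokens can be copied. A minimal, illustrative sketch of additive pointer scoring in the style of Vinyals et al. (2015); all weight names below are hypothetical:

import tensorflow as tf

num_units = 128  # assumed constant

w1 = tf.Variable(tf.random.stateless_uniform([num_units, num_units], seed=[2, 22]))
w2 = tf.Variable(tf.random.stateless_uniform([num_units, num_units], seed=[3, 33]))
v = tf.Variable(tf.random.stateless_uniform([num_units, 1], seed=[4, 44]))

def pointer_logits(decoder_h, encoder_hs):
  # decoder_h: [1, num_units]; encoder_hs: [n, num_units].
  # Returns one unnormalized score per encoder position.
  features = tf.tanh(tf.matmul(encoder_hs, w1) + tf.matmul(decoder_h, w2))
  return tf.squeeze(tf.matmul(features, v), axis=-1)  # [n]
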
Example #5

class StatementDupModel(StatementDecodeModel):
  
  def __init__(self, type_content_data):
    super(StatementDupModel, self).__init__(type_content_data)
    number_of_tokens = self.type_content_data[all_token_summary][TokenHitNum]
    self.dup_skeleton_forward_cell_h = tf.Variable(random_uniform_variable_initializer(155, 572, [1, 2, num_units]))
    self.dup_skeleton_backward_cell_h = tf.Variable(random_uniform_variable_initializer(152, 572, [1, 2, num_units]))
    if atom_decode_mode == token_decode:
      self.one_dup_hot_token_embedding = tf.Variable(random_uniform_variable_initializer(252, 226, [number_of_tokens, num_units]))
      self.dup_token_embedder = TokenAtomEmbed(self.type_content_data, self.one_dup_hot_token_embedding)
      self.dup_token_lstm = YLSTMCell(9)
      self.dup_token_pointer = PointerNetwork(655)
      self.integrate_computer = None
      if compute_memory_in_memory_mode:
        self.integrate_computer = Y2DirectLSTMCell(105)
      if compute_token_memory:
        self.dup_mem_nn = NTMOneDirection(800)
        self.dup_forward_token_lstm = YLSTMCell(10)
        self.dup_backward_token_lstm = YLSTMCell(11)
      self.dup_token_decoder = DupTokenDecoder(type_content_data, self.metrics_index, self.dup_token_pointer) 
    else:
      assert False, "Wrong atom_decode_mode"
    
  def create_in_use_tensors_meta(self):
    result = super(StatementDupModel, self).create_in_use_tensors_meta() + [
        ("dup_loop_forward_cells", tf.TensorShape([None, num_units])),
        ("dup_loop_forward_hs", tf.TensorShape([None, num_units])),
        ("dup_loop_backward_cells", tf.TensorShape([None, num_units])),
        ("dup_loop_backward_hs", tf.TensorShape([None, num_units])),
    ]
    return result
  
  def set_up_field_when_calling(self, one_example, training):
    self.token_info_tensor = one_example[0]
    self.token_info_start_tensor = one_example[1]
    self.token_info_end_tensor = one_example[2]
    self.token_base_model_accuracy = one_example[3]
    self.token_base_model_mrr = one_example[4]
    self.training = training
    
  def stmt_iterate_body(self, i, i_len, *stmt_metrics_tuple):
    stmt_metrics = list(stmt_metrics_tuple)
     
    stmt_start = self.token_info_start_tensor[i]
    stmt_end = self.token_info_end_tensor[i]
     
    '''
    This step skips statements that have no type-content tokens (i.e., only a
    skeleton token).
    '''
    r_stmt_start = stmt_start
    itearate_tokens_continue = tf.cast(stmt_end >= r_stmt_start, int_type)
    f_res = tf.while_loop(
        self.itearate_tokens_cond, self.itearate_tokens_body,
        [tf.constant(0, int_type), itearate_tokens_continue, stmt_start,
         stmt_end, *stmt_metrics],
        shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),
                          tf.TensorShape(()), tf.TensorShape(()),
                          *self.metrics_shape],
        parallel_iterations=1)
    stmt_metrics = f_res[4:]
    return (i + 1, i_len, *stmt_metrics)
  
  def token_iterate_body(self, i, i_len, ini_i, *stmt_metrics_tuple):
    stmt_metrics = list(stmt_metrics_tuple)
    oracle_type_content_en = self.token_info_tensor[0][i]
    oracle_type_content_var = self.token_info_tensor[1][i]
    oracle_type_content_var_relative = self.token_info_tensor[2][i]
    conserved_memory_length = self.token_info_tensor[3][i]
    token_kind = self.token_info_tensor[4][i]
    base_model_accuracy = self.token_base_model_accuracy[i]
    base_model_mrr = self.token_base_model_mrr[i]
    if atom_decode_mode == token_decode:
      stmt_metrics = self.dup_token_decoder.decode_one_token(stmt_metrics, self.training, oracle_type_content_en, oracle_type_content_var, oracle_type_content_var_relative, token_kind, base_model_accuracy, base_model_mrr)
      if compute_token_memory:
        stmt_metrics = one_lstm_step("dup_", stmt_metrics, self.metrics_index, oracle_type_content_en, self.dup_token_lstm, self.dup_token_embedder)
      else:
        stmt_metrics = one_lstm_step_and_update_memory("dup_", stmt_metrics, self.metrics_index, oracle_type_content_en, oracle_type_content_var, conserved_memory_length, self.dup_token_lstm, self.dup_token_embedder, self.integrate_computer)
    else:
      assert False
    return (i + 1, i_len, ini_i, *stmt_metrics)
  
  def itearate_tokens(self, stmt_start, stmt_end, stmt_metrics):
    
#     b_stmt_start = self.token_info_start_tensor[i]
#     stmt_end = self.token_info_end_tensor[i]
    
#     stmt_start = b_stmt_start
    
    skt_use_id = 0
    
    '''
    iterate tokens
    '''
    dup_ini_f_cell = tf.expand_dims(self.dup_skeleton_forward_cell_h[skt_use_id][0], 0)
    dup_ini_f_h = tf.expand_dims(self.dup_skeleton_forward_cell_h[skt_use_id][1], 0)
    dup_ini_b_cell = tf.expand_dims(self.dup_skeleton_backward_cell_h[skt_use_id][0], 0)
    dup_ini_b_h = tf.expand_dims(self.dup_skeleton_backward_cell_h[skt_use_id][1], 0)
      
    '''
    leaf info here is the same thing as variable info
    '''
    info_length = stmt_end - stmt_start + 1
    token_info = tf.slice(self.token_info_tensor[0], [stmt_start], [info_length])
    leaf_info = tf.slice(self.token_info_tensor[1], [stmt_start], [info_length])
    
    f_res = tf.while_loop(
        self.token_iterate_cond, self.token_iterate_body,
        [stmt_start, stmt_end, stmt_start, *stmt_metrics],
        shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),
                          tf.TensorShape(()), *self.metrics_shape],
        parallel_iterations=1)
    stmt_metrics = list(f_res[3:])
     
    if compute_token_memory:
      '''
      compute token memory and compute repetition
      '''
      dup_embeds = tf.zeros([0, num_units], float_type)
      e_res = tf.while_loop(
          self.token_embed_cond, self.token_embed_body,
          [stmt_start, stmt_end, dup_embeds, *stmt_metrics],
          shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),
                            tf.TensorShape([None, num_units]),
                            *self.metrics_shape],
          parallel_iterations=1)
      dup_embeds = e_res[2]
      
      stmt_metrics[self.metrics_index["dup_loop_forward_cells"]] = dup_ini_f_cell
      stmt_metrics[self.metrics_index["dup_loop_forward_hs"]] = dup_ini_f_h
      f_res = tf.while_loop(
          self.forward_loop_cond, self.forward_loop_body,
          [0, stmt_end - stmt_start, dup_embeds, *stmt_metrics],
          shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),
                            tf.TensorShape([None, num_units]),
                            *self.metrics_shape],
          parallel_iterations=1)
      stmt_metrics = list(f_res[3:])
      stmt_metrics[self.metrics_index["dup_loop_backward_cells"]] = dup_ini_b_cell
      stmt_metrics[self.metrics_index["dup_loop_backward_hs"]] = dup_ini_b_h
      f_res = tf.while_loop(
          self.backward_loop_cond, self.backward_loop_body,
          [0, stmt_end - stmt_start, dup_embeds, *stmt_metrics],
          shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),
                            tf.TensorShape([None, num_units]),
                            *self.metrics_shape],
          parallel_iterations=1)
      stmt_metrics = list(f_res[3:])
      
      (dup_discrete_memory_vars, dup_discrete_memory_tokens,
       dup_discrete_forward_memory_cell, dup_discrete_forward_memory_h
       ) = self.dup_mem_nn.compute_variables_in_statement(
           leaf_info, token_info,
           stmt_metrics[self.metrics_index["dup_memory_acc_cell"]],
           stmt_metrics[self.metrics_index["dup_memory_acc_h"]],
           stmt_metrics[self.metrics_index["dup_loop_forward_cells"]],
           stmt_metrics[self.metrics_index["dup_loop_forward_hs"]],
           stmt_metrics[self.metrics_index["dup_loop_backward_cells"]],
           stmt_metrics[self.metrics_index["dup_loop_backward_hs"]])
      (stmt_metrics[self.metrics_index["dup_memory_en"]],
       stmt_metrics[self.metrics_index["dup_memory_acc_cell"]],
       stmt_metrics[self.metrics_index["dup_memory_acc_h"]]
       ) = self.dup_mem_nn.update_memory_with_variables_in_statement(
           stmt_metrics[self.metrics_index["dup_memory_en"]],
           stmt_metrics[self.metrics_index["dup_memory_acc_cell"]],
           stmt_metrics[self.metrics_index["dup_memory_acc_h"]],
           dup_discrete_memory_vars, dup_discrete_memory_tokens,
           dup_discrete_forward_memory_cell, dup_discrete_forward_memory_h)
        
      assert token_memory_mode == only_memory_mode
       
    return stmt_metrics
  
  '''
  build memory
  '''
  
  def token_embed_body(self, i, i_len, dup_embeds, *stmt_metrics_tuple):
    stmt_metrics = list(stmt_metrics_tuple)
    oracle_type_content_en = self.token_info_tensor[0][i]
    dup_e_emebd = self.dup_token_embedder.compute_h(oracle_type_content_en)
    if compute_token_memory:
      oracle_type_content_var = self.token_info_tensor[1][i]
      mem_hs = stmt_metrics[self.metrics_index["dup_memory_acc_h"]]
      use_mem = tf.cast(tf.logical_and(tf.greater(oracle_type_content_var, 0), tf.less(oracle_type_content_var, tf.shape(mem_hs)[0])), int_type)
      r_var = oracle_type_content_var * use_mem
      dup_e_emebd = tf.stack([dup_e_emebd, [mem_hs[r_var]]])[use_mem]
    dup_embeds = tf.concat([dup_embeds, dup_e_emebd], axis=0)
    return (i + 1, i_len, dup_embeds, *stmt_metrics_tuple)
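
The tf.stack([...])[flag] expression in token_embed_body (and again with out_use_en and next_cell in Example #6) is a branch-free select: stacking the two candidates and indexing with a 0/1 integer picks one of them, which keeps the computation graph free of conditionals. A self-contained illustration:

import tensorflow as tf

a = tf.zeros([1, 4])
b = tf.ones([1, 4])
flag = tf.constant(1)            # e.g. use_mem above: 1 means "use memory"
picked = tf.stack([a, b])[flag]  # selects b; both candidates must share a shape
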
Example #6
class TreeDecodeModel(BasicDecodeModel):
    '''
    The parameter start_nodes must all be at the same level and be contiguous,
    while stop_nodes have no such constraints.
    '''
    def __init__(self, type_content_data):
        super(TreeDecodeModel, self).__init__(type_content_data)
        self.direct_descedent_lstm = YLSTMCell(0)
        if tree_decode_way == tree_decode_2d:
            self.two_dimen_lstm = Y2DLSTMCell(1)
        elif tree_decode_way == tree_decode_embed:
            self.one_dimen_lstm = YLSTMCell(2)

        number_of_tokens = self.type_content_data[all_token_summary][
            TokenHitNum]
        self.linear_token_output_w = tf.Variable(
            random_uniform_variable_initializer(256, 566,
                                                [number_of_tokens, num_units]))
        self.one_hot_token_embedding = tf.Variable(
            random_uniform_variable_initializer(256, 56,
                                                [number_of_tokens, num_units]))
        self.one_hot_token_cell_embedding = tf.Variable(
            random_uniform_variable_initializer(25, 56,
                                                [number_of_tokens, num_units]))
        self.token_embedder = TokenAtomEmbed(self.type_content_data,
                                             self.one_hot_token_embedding)
        self.token_cell_embedder = TokenAtomEmbed(
            self.type_content_data, self.one_hot_token_cell_embedding)

        self.encode_tree = EncodeOneAST(self.type_content_data,
                                        self.token_embedder,
                                        self.token_cell_embedder)

    def create_in_use_tensors_meta(self):
        result = [("token_accumulated_cell", tf.TensorShape([None,
                                                             num_units])),
                  ("token_accumulated_h", tf.TensorShape([None, num_units]))]
        return result

    def __call__(self, one_example, training=True):
        (post_order_node_type_content_en_tensor,
         post_order_node_child_start_tensor,
         post_order_node_child_end_tensor,
         post_order_node_children_tensor) = (one_example[0], one_example[1],
                                             one_example[2], one_example[3])
        (self.pre_post_order_node_type_content_en_tensor,
         self.pre_post_order_node_state_tensor,
         self.pre_post_order_node_post_order_index_tensor,
         self.pre_post_order_node_parent_grammar_index,
         self.pre_post_order_node_kind) = (one_example[4], one_example[5],
                                           one_example[6], one_example[7],
                                           one_example[8])
        _, self.encoded_h, self.encoded_children_cell, self.encoded_children_h = self.encode_tree.get_encoded_embeds(
            post_order_node_type_content_en_tensor,
            post_order_node_child_start_tensor,
            post_order_node_child_end_tensor, post_order_node_children_tensor)
        self.training = training
        ini_metrics = list(
            create_empty_tensorflow_tensors(self.metrics_meta,
                                            self.contingent_parameters,
                                            self.metrics_contingent_index))
        f_res = tf.while_loop(
            self.tree_iterate_cond,
            self.tree_iterate_body, [
                0,
                tf.shape(self.pre_post_order_node_type_content_en_tensor)[-1],
                *ini_metrics
            ],
            shape_invariants=[
                tf.TensorShape(()),
                tf.TensorShape(()), *self.metrics_shape
            ],
            parallel_iterations=1)
        f_res = list(f_res[2:2 + len(self.statistical_metrics_meta)])
        return f_res

    def tree_iterate_cond(self, i, i_len, *_):
        return tf.less(i, i_len)

    def tree_iterate_body(self, i, i_len, *stmt_metrics_tuple):
        stmt_metrics = list(stmt_metrics_tuple)

        en = self.pre_post_order_node_type_content_en_tensor[i]
        state = self.pre_post_order_node_state_tensor[i]
        post_order_index = self.pre_post_order_node_post_order_index_tensor[i]
        grammar_idx = self.pre_post_order_node_parent_grammar_index[i]
        kind = self.pre_post_order_node_kind[i]

        #     if token_accuracy_mode == consider_all_token_accuracy:
        #       t_valid_bool = tf.constant(True, bool_type)
        #     elif token_accuracy_mode == only_consider_token_kind_accuracy:
        #       t_valid_bool = is_in_token_kind_range(kind)
        #     else:
        #       assert False
        #     t_valid_float = tf.cast(t_valid_bool, float_type)
        #     t_valid_int = tf.cast(t_valid_bool, int_type)
        t_valid_float, t_valid_int = is_token_in_consideration(
            en, -1, kind,
            self.type_content_data[all_token_summary][TokenHitNum])
        en_valid_float, en_valid_int = is_en_valid(
            en, self.type_content_data[all_token_summary][TokenHitNum])
        #     en_valid_bool = tf.logical_and(tf.greater(en, 2), tf.less(en, self.type_content_data[all_token_summary][TokenHitNum]))
        #     en_valid_float = tf.cast(en_valid_bool, float_type)
        #     en_valid_int = tf.cast(en_valid_bool, int_type)
        out_use_en = tf.stack([UNK_en, en])[en_valid_int]

        non_leaf_post_bool = tf.equal(state, 2)
        non_leaf_post = tf.cast(non_leaf_post_bool, int_type)
        node_acc_valid = 1 - non_leaf_post
        ''' pre remove '''
        before_remove_length = tf.shape(
            stmt_metrics[self.metrics_index["token_accumulated_cell"]])[0]
        stmt_metrics[self.metrics_index["token_accumulated_cell"]] = tf.slice(
            stmt_metrics[self.metrics_index["token_accumulated_cell"]], [0, 0],
            [before_remove_length - non_leaf_post, num_units])
        stmt_metrics[self.metrics_index["token_accumulated_h"]] = tf.slice(
            stmt_metrics[self.metrics_index["token_accumulated_h"]], [0, 0],
            [before_remove_length - non_leaf_post, num_units])

        cell = tf.convert_to_tensor(
            [stmt_metrics[self.metrics_index["token_accumulated_cell"]][-1]])
        h = tf.convert_to_tensor(
            [stmt_metrics[self.metrics_index["token_accumulated_h"]][-1]])
        p_a_h = h

        if (not self.training) and tree_decode_with_grammar:
            start_idx = self.type_content_data[all_token_grammar_start][
                grammar_idx]
            end_idx = self.type_content_data[all_token_grammar_end][
                grammar_idx]
            ens_range = tf.slice(self.type_content_data[all_token_grammar_ids],
                                 [start_idx], [end_idx - start_idx + 1])
            o_mrr_of_this_node, o_accurate_of_this_node, o_loss_of_this_node = compute_loss_and_accurate_from_linear_with_computed_embeddings_in_limited_range(
                self.training, self.linear_token_output_w, ens_range,
                out_use_en, p_a_h)
        else:
            o_mrr_of_this_node, o_accurate_of_this_node, o_loss_of_this_node = compute_loss_and_accurate_from_linear_with_computed_embeddings(
                self.training, self.linear_token_output_w, out_use_en, p_a_h)

        mrr_of_this_node = tf.stack([0.0, o_mrr_of_this_node])[node_acc_valid]
        accurate_of_this_node = tf.stack(
            [tf.zeros([len(top_ks)], float_type),
             o_accurate_of_this_node])[node_acc_valid]
        loss_of_this_node = tf.stack([0.0,
                                      o_loss_of_this_node])[node_acc_valid]

        count_of_this_node = tf.stack([0, 1])[node_acc_valid]
        r_count = count_of_this_node * t_valid_int
        if ignore_unk_when_computing_accuracy:
            r_count = r_count * en_valid_int

        stmt_metrics[self.metrics_index["token_loss"]] = stmt_metrics[
            self.
            metrics_index["token_loss"]] + loss_of_this_node * en_valid_float
        stmt_metrics[self.metrics_index["token_accurate"]] = stmt_metrics[
            self.metrics_index[
                "token_accurate"]] + accurate_of_this_node * en_valid_float * t_valid_float
        stmt_metrics[
            self.metrics_index["token_mrr"]] = stmt_metrics[self.metrics_index[
                "token_mrr"]] + mrr_of_this_node * en_valid_float * t_valid_float
        stmt_metrics[self.metrics_index["token_count"]] = stmt_metrics[
            self.metrics_index["token_count"]] + r_count
        stmt_metrics[self.metrics_index["all_loss"]] = stmt_metrics[
            self.
            metrics_index["all_loss"]] + loss_of_this_node * en_valid_float
        stmt_metrics[self.metrics_index["all_accurate"]] = stmt_metrics[
            self.metrics_index[
                "all_accurate"]] + accurate_of_this_node * en_valid_float * t_valid_float
        stmt_metrics[
            self.metrics_index["all_mrr"]] = stmt_metrics[self.metrics_index[
                "all_mrr"]] + mrr_of_this_node * en_valid_float * t_valid_float
        stmt_metrics[self.metrics_index["all_count"]] = stmt_metrics[
            self.metrics_index["all_count"]] + r_count

        stmt_metrics[
            self.metrics_index["token_accurate_each_noavg"]] = stmt_metrics[
                self.metrics_index["token_accurate_each_noavg"]].write(
                    stmt_metrics[self.metrics_index[
                        "token_accurate_each_noavg"]].size(),
                    accurate_of_this_node * en_valid_float * t_valid_float)
        stmt_metrics[
            self.metrics_index["token_mrr_each_noavg"]] = stmt_metrics[
                self.metrics_index["token_mrr_each_noavg"]].write(
                    stmt_metrics[
                        self.metrics_index["token_mrr_each_noavg"]].size(),
                    mrr_of_this_node * en_valid_float * t_valid_float)

        stmt_metrics[
            self.metrics_index["token_count_each_int_noavg"]] = stmt_metrics[
                self.metrics_index["token_count_each_int_noavg"]].write(
                    stmt_metrics[self.metrics_index[
                        "token_count_each_int_noavg"]].size(), r_count)
        ''' infer next cell/h '''
        #       p_op = tf.print("loss_of_this_node:", loss_of_this_node, "accurate_of_this_node:", accurate_of_this_node)
        #       with tf.control_dependencies([p_op]):
        en_h = self.token_embedder.compute_h(en)
        _, (next_cell1, next_h1) = self.direct_descedent_lstm(en_h, (cell, h))
        if tree_decode_way == tree_decode_2d:
            next_cell2, next_h2 = self.two_dimen_lstm(
                en_h, cell, h, [self.encoded_children_cell[post_order_index]],
                [self.encoded_children_h[post_order_index]])
        elif tree_decode_way == tree_decode_embed:
            _, (next_cell2, next_h2) = self.one_dimen_lstm(
                tf.expand_dims(self.encoded_h[post_order_index], axis=0),
                (cell, h))
        else:
            print("Unrecognized tree decode mode!")
            assert False
        next_cell = tf.stack([next_cell1, next_cell2])[non_leaf_post]
        next_h = tf.stack([next_h1, next_h2])[non_leaf_post]
        ''' update accumulated cell/h '''
        ''' post remove '''
        should_remove = tf.cast(tf.greater_equal(state, 1), int_type)
        after_remove_length = tf.shape(
            stmt_metrics[self.metrics_index["token_accumulated_cell"]])[0]
        stmt_metrics[self.metrics_index["token_accumulated_cell"]] = tf.slice(
            stmt_metrics[self.metrics_index["token_accumulated_cell"]], [0, 0],
            [after_remove_length - should_remove, num_units])
        stmt_metrics[self.metrics_index["token_accumulated_h"]] = tf.slice(
            stmt_metrics[self.metrics_index["token_accumulated_h"]], [0, 0],
            [after_remove_length - should_remove, num_units])
        ''' concatenate newly inferred '''
        stmt_metrics[self.metrics_index["token_accumulated_cell"]] = tf.concat(
            [
                stmt_metrics[self.metrics_index["token_accumulated_cell"]],
                next_cell
            ],
            axis=0)
        stmt_metrics[self.metrics_index["token_accumulated_h"]] = tf.concat(
            [stmt_metrics[self.metrics_index["token_accumulated_h"]], next_h],
            axis=0)

        return (i + 1, i_len, *stmt_metrics)
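
Every loop in these models is an explicit tf.while_loop: tensors whose leading dimension grows across iterations (like token_accumulated_cell above) must be declared with a [None, ...] shape invariant, and parallel_iterations=1 forces sequential execution because each step depends on the previous one. A stripped-down version of the pattern, with illustrative shapes:

import tensorflow as tf

def cond(i, i_len, acc):
  return tf.less(i, i_len)

def body(i, i_len, acc):
  row = tf.fill([1, 4], tf.cast(i, tf.float32))
  # acc grows by one row per iteration, hence the [None, 4] invariant below.
  return i + 1, i_len, tf.concat([acc, row], axis=0)

_, _, acc = tf.while_loop(
    cond, body, [tf.constant(0), tf.constant(5), tf.zeros([0, 4])],
    shape_invariants=[tf.TensorShape(()), tf.TensorShape(()),
                      tf.TensorShape([None, 4])],
    parallel_iterations=1)
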
Example #7
  def __init__(self, type_content_data):
    super(StatementDecodeModel, self).__init__(type_content_data)
    
    self.ini_metrics = list(create_empty_tensorflow_tensors(self.metrics_meta, self.contingent_parameters, self.metrics_contingent_index))
    
    self.skeleton_forward_cell_h = tf.Variable(random_uniform_variable_initializer(255, 572, [1, 2, num_units]))
    self.skeleton_backward_cell_h = tf.Variable(random_uniform_variable_initializer(252, 572, [1, 2, num_units]))
    
    if compute_token_memory:
      self.mem_nn = NTMOneDirection(500)
      self.forward_token_lstm = YLSTMCell(3)
      self.backward_token_lstm = YLSTMCell(4)
      if compose_tokens_of_a_statement:
        if compose_mode == compose_one_way_lstm and (compose_one_way_lstm_mode == one_way_two_way_compose or compose_one_way_lstm_mode == one_way_three_way_compose):
          self.tokens_merger = Y2DirectLSTMCell(150)
        else:
          self.tokens_merger = EmbedMerger(150)
        if compose_mode == compose_one_way_lstm and compose_one_way_lstm_mode == one_way_three_way_compose:
          self.compose_lstm_cell = Y3DirectLSTMCell(5)
        else:
          self.compose_lstm_cell = YLSTMCell(5)
        if compose_mode == compose_bi_way_lstm:
          self.bi_way_merger = Y2DirectLSTMCell(155)
          self.bi_way_ini_cell = tf.Variable(random_uniform_variable_initializer(855, 55, [1, num_units]))
          self.bi_way_ini_h = tf.Variable(random_uniform_variable_initializer(850, 50, [1, num_units]))
          self.bi_way_lstm = YLSTMCell(52)
     
    self.token_attention = None
    if decode_attention_way > decode_no_attention:
      self.token_attention = YAttention(10)
     
    self.token_lstm = YLSTMCell(0)
    r_token_embedder_mode = token_embedder_mode
     
    number_of_tokens = self.type_content_data[all_token_summary][TokenHitNum]
    number_of_subwords = self.type_content_data[all_token_summary][SwordHitNum]
     
    if atom_decode_mode == token_decode:
      self.linear_token_output_w = tf.Variable(random_uniform_variable_initializer(256, 566, [number_of_tokens, num_units]))
      self.dup_token_embedder, self.dup_token_lstm, self.dup_token_pointer = None, None, None
#       if use_dup_model:
#         self.one_dup_hot_token_embedding = tf.Variable(random_uniform_variable_initializer(252, 226, [number_of_tokens, num_units]))
#         self.dup_token_embedder = TokenAtomEmbed(self.type_content_data, self.one_dup_hot_token_embedding)
#         self.dup_token_lstm = YLSTMCell(9)
#         self.dup_token_pointer = PointerNetwork(655)
#         self.dup_skeleton_forward_cell_h = tf.Variable(random_uniform_variable_initializer(155, 572, [number_of_skeletons, 2, num_units]))
#         self.dup_skeleton_backward_cell_h = tf.Variable(random_uniform_variable_initializer(152, 572, [number_of_skeletons, 2, num_units]))
#         if compute_token_memory:
#           self.dup_mem_nn = NTMOneDirection(800)
#           self.dup_forward_token_lstm = YLSTMCell(10)
#           self.dup_backward_token_lstm = YLSTMCell(11)
      self.token_decoder = TokenDecoder(type_content_data, self.metrics_index, self.linear_token_output_w, self.token_attention, self.dup_token_pointer)
       
    elif atom_decode_mode == sword_decode:
      self.linear_sword_output_w = tf.Variable(random_uniform_variable_initializer(256, 566, [number_of_subwords, num_units]))
      self.sword_lstm = YLSTMCell(12)
      r_token_embedder_mode = swords_compose_mode
      
    else:
      assert False, "Wrong atom_decode_mode"
       
    if r_token_embedder_mode == token_only_mode:
      self.one_hot_token_embedding = tf.Variable(random_uniform_variable_initializer(256, 56, [number_of_tokens, num_units]))
      self.token_embedder = TokenAtomEmbed(self.type_content_data, self.one_hot_token_embedding)
    elif r_token_embedder_mode == swords_compose_mode:
      self.one_hot_sword_embedding = tf.Variable(random_uniform_variable_initializer(256, 56, [number_of_subwords, num_units]))
      self.sword_embedder = SwordAtomEmbed(self.type_content_data, self.one_hot_sword_embedding)
      self.token_embedder = BiLSTMEmbed(self.type_content_data, self.one_hot_sword_embedding)
    else:
      assert False, "Wrong token_embedder_mode"
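
TokenAtomEmbed itself is not shown in any of these examples. From its uses (compute_h(en) returns a [1, num_units] row that is concatenated into dup_embeds or fed to an LSTM), a minimal sketch could be a plain embedding-table lookup; only the names below come from the snippets, the body is an assumption:

import tensorflow as tf

class TokenAtomEmbed:
  def __init__(self, type_content_data, embedding):
    self.type_content_data = type_content_data
    self.embedding = embedding  # [number_of_tokens, num_units] variable

  def compute_h(self, en):
    # Fetch the embedding row for token id en, keeping a leading batch dim.
    return tf.nn.embedding_lookup(self.embedding, [en])  # [1, num_units]
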