Ejemplo n.º 1
0
def get_default_retrieval_decoder(
        type_context: TypeContext,
        rnn_hidden_size: int,
        examples: ExamplesStore,
        replacer: Replacer,
        parser: StringParser,
        unparser: AstUnparser) -> TreeDecoder:
    """Assemble a TreeRNNDecoder that picks actions via a retrieval selector.

    The decoder is wired from four parts, built in this order:

    1. a ``TorchDeepEmbed`` type vectorizer with one entry per type in
       ``type_context`` and width ``rnn_hidden_size``;
    2. a ``TreeRNNCell`` whose two size arguments are both
       ``rnn_hidden_size``;
    3. a latent store produced by ``make_latent_store_from_examples`` from
       ``examples``, ``replacer``, ``parser`` and ``unparser``;
    4. a ``RetrievalActionSelector`` over that store.

    Args:
        type_context: Supplies the type count and is handed to both the
            action selector and the decoder.
        rnn_hidden_size: Size used for the type embedding, the RNN cell and
            the latent store.
        examples: Example store the latent store is built from.
        replacer: Forwarded unchanged to the latent-store builder.
        parser: Forwarded unchanged to the latent-store builder.
        unparser: Forwarded unchanged to the latent-store builder.

    Returns:
        The assembled ``TreeRNNDecoder``.
    """
    num_types = type_context.get_type_count()
    type_embedder = vectorizers.TorchDeepEmbed(num_types, rnn_hidden_size)

    cell = TreeRNNCell(rnn_hidden_size, rnn_hidden_size)

    store = make_latent_store_from_examples(
        examples, rnn_hidden_size, replacer, parser, unparser)

    # NOTE(review): 0.25 is passed positionally; its meaning is defined by
    # RetrievalActionSelector's signature, which is not visible here.
    selector = RetrievalActionSelector(store, type_context, 0.25)

    return TreeRNNDecoder(cell, selector, type_embedder, type_context)
Ejemplo n.º 2
0
def get_default_nonretrieval_decoder(
        type_context: TypeContext,
        rnn_hidden_size: int) -> TreeDecoder:
    """Assemble a TreeRNNDecoder that uses a SimpleActionSelector.

    Objects are embedded at width ``rnn_hidden_size``; types are embedded at
    half that width, which is also the first size argument of the LSTM tree
    cell. The action selector is sized from the cell's ``output_size`` and
    delegates object choice to the default object selector.

    Args:
        type_context: Supplies the object and type counts and is handed to
            the object selector, the action selector and the decoder.
        rnn_hidden_size: Width of the object embedding and second size
            argument of the RNN cell.

    Returns:
        The assembled ``TreeRNNDecoder``.
    """
    num_objects = type_context.get_object_count()
    object_embedder = vectorizers.TorchDeepEmbed(num_objects, rnn_hidden_size)

    # Type embeddings are half as wide as the hidden state.
    half_hidden = int(rnn_hidden_size / 2)
    num_types = type_context.get_type_count()
    type_embedder = vectorizers.TorchDeepEmbed(num_types, half_hidden)

    # Alternative cells previously tried at this spot, each constructed with
    # (rnn_hidden_size, rnn_hidden_size): TreeCellOnlyAttn, TreeRNNCellGRU.
    cell = TreeRNNCellLSTM(half_hidden, rnn_hidden_size)

    selector_input_size = cell.output_size
    object_selector = objectselector.get_default_object_selector(
        type_context, object_embedder)
    selector = SimpleActionSelector(
        selector_input_size, object_selector, type_context)

    return TreeRNNDecoder(cell, selector, type_embedder, type_context)
Ejemplo n.º 3
0
 def __init__(self, type_context: TypeContext):
     """Precompute, for every type, a tensor of its implementations' indices.

     ``self._type_to_impl_tensor`` is a list with one slot per type in
     ``type_context``, addressed by the type's ``ind``. Each slot filled
     here holds a ``torch.LongTensor`` of the ``ind`` values of that type's
     implementations; any slot whose type is not yielded by
     ``get_all_types()`` stays ``None``.

     Args:
         type_context: Source of the type count, the types themselves and
             each type's implementations.
     """
     slots = [None] * type_context.get_type_count()
     # Bind the attribute up front; the loop below fills the same list.
     self._type_to_impl_tensor = slots
     for typ in type_context.get_all_types():
         impl_indices = [
             impl.ind for impl in type_context.get_implementations(typ)
         ]
         slots[typ.ind] = torch.LongTensor(impl_indices)