예제 #1
0
    def dropout(cntk_layer, inputs):
        '''
        Set up a CNTK dropout op for the given layer definition.

        Args:
            cntk_layer (:class:`~cntk.contrib.crosstalkcaffe.unimodel.cntkmodel.CntkLayersDefinition`):
                the layer definition of the dropout op; only its ``op_name`` is read here
            inputs (list): a list containing :class:`~cntk.ops.functions.Function` or
                :class:`~cntk.input` objects; only the first element is used

        Return:
            :func:`~cntk.ops.functions.Function`: the instantiated cntk dropout op
        '''
        # Only the first input feeds the op; it is passed through
        # ops.sanitize_input before being handed to ops.dropout.
        # NOTE(review): no dropout rate is passed, so ops.dropout's default
        # rate applies — confirm this is intended for converted layers.
        return ops.dropout(ops.sanitize_input(inputs[0]), name=cntk_layer.op_name)
예제 #2
0
    def dropout(cntk_layer, inputs):
        '''
        Set up a CNTK dropout op for the given layer definition.

        Args:
            cntk_layer (:class:`~cntk.contrib.crosstalkcaffe.unimodel.cntkmodel.CntkLayersDefinition`):
                the layer definition of the dropout op; only its ``op_name`` is read here
            inputs (list): a list containing :class:`~cntk.ops.functions.Function` or
                :class:`~cntk.input` objects; only the first element is used

        Return:
            :func:`~cntk.ops.functions.Function`: the instantiated cntk dropout op
        '''
        # Only the first input feeds the op; it is passed through
        # ops.sanitize_input before being handed to ops.dropout.
        # NOTE(review): no dropout rate is passed, so ops.dropout's default
        # rate applies — confirm this is intended for converted layers.
        return ops.dropout(ops.sanitize_input(inputs[0]), name=cntk_layer.op_name)
예제 #3
0
def Dropout(prob):
    '''
    Build a ``Block`` named ``'Dropout'`` whose expression applies
    ``dropout`` with ``dropout_rate=prob`` to a single placeholder argument.

    Args:
        prob: value forwarded to ``dropout`` as ``dropout_rate``

    Returns:
        the ``Block`` wrapping the dropout expression
    '''
    # Symbolic argument of the block expression, named 'dropout_arg'.
    arg = Placeholder(name='dropout_arg')
    return Block(dropout(arg, dropout_rate=prob), 'Dropout')
예제 #4
0
파일: layers.py 프로젝트: shadrack4292/CNTK
def Dropout(prob):
    '''
    Build a ``Block`` named ``'Dropout'`` whose expression applies
    ``dropout`` with ``dropout_rate=prob`` to a single placeholder argument.

    Args:
        prob: value forwarded to ``dropout`` as ``dropout_rate``

    Returns:
        the ``Block`` wrapping the dropout expression
    '''
    # Symbolic argument of the block expression, named 'dropout_arg'.
    arg = Placeholder(name='dropout_arg')
    return Block(dropout(arg, dropout_rate=prob), 'Dropout')
예제 #5
0
def create_model(params: model_params):
    """Build the ReasoNet graph.

    Declares the query / context / entity-mask sequence inputs, embeds the
    query and context, encodes both with ``ops.optimized_rnnstack`` GRU
    stacks, and passes the encodings plus an initial state derived from the
    query encoding to ``attention_model``.

    Note: graph objects are created in a fixed order below; keep that order
    when editing, since CNTK assigns uids by creation order.

    Args:
        params (:class:`model_params`): model hyper-parameters. This function
            reads ``vocab_dim``, ``embedding_dim``, ``embedding_init``,
            ``dropout_rate``, ``hidden_dim``, ``attention_dim`` and
            ``max_rl_steps``; ``init`` is only logged.

    Returns:
        Whatever ``attention_model`` returns for the assembled graph.
    """
    logger.log(
        "Create model: dropout_rate: {0}, init:{1}, embedding_init: {2}".format(
            params.dropout_rate, params.init, params.embedding_init))

    vocab_dim = params.vocab_dim
    hidden = params.hidden_dim
    embedding_shape = (vocab_dim, params.embedding_dim)

    # Query and context live on two distinct dynamic axes; the entity mask
    # shares the context axis so it lines up position-by-position.
    query_axis = Axis('sourceAxis')
    context_axis = Axis('contextAxis')
    query_sequence = sequence.input(shape=vocab_dim,
                                    is_sparse=True,
                                    sequence_axis=query_axis,
                                    name='query')
    context_sequence = sequence.input(shape=vocab_dim,
                                      is_sparse=True,
                                      sequence_axis=context_axis,
                                      name='context')
    entity_ids_mask = sequence.input(shape=(1, ),
                                     is_sparse=False,
                                     sequence_axis=context_axis,
                                     name='entity_ids_mask')

    # Fall back to create_random_matrix when no initial embedding is given.
    embedding_init = (create_random_matrix(vocab_dim, params.embedding_dim)
                      if params.embedding_init is None
                      else params.embedding_init)
    # Learnable embedding whose value is overwritten with embedding_init.
    embedding = parameter(shape=embedding_shape, init=None)
    embedding.value = embedding_init
    # Constant copy of the same initial values, used for entity positions.
    embedding_matrix = constant(embedding_init, shape=embedding_shape)

    def _embed(tokens, label):
        # The node name goes on the outermost op: the dropout node when a
        # rate is configured, otherwise the projection itself.
        if params.dropout_rate is None:
            return times(tokens, embedding, name=label)
        return ops.dropout(times(tokens, embedding),
                           params.dropout_rate,
                           name=label)

    query_embedding = _embed(query_sequence, 'query_embedding')
    context_embedding = _embed(context_sequence, 'context_embedding')

    def _gru_weights():
        # NOTE(review): both GRU weight tensors are named 'gru_params', so
        # name-based lookups may be ambiguous — confirm this is intended.
        return Parameter(_INFERRED + _as_tuple(hidden),
                         init=glorot_uniform(),
                         name='gru_params')

    context_gru_w = _gru_weights()
    query_gru_w = _gru_weights()

    # Entity positions (mask set) take their vector from the constant matrix,
    # so each entity vector stays a fixed random identifier with no learned
    # semantics; all other context positions use the learnable embedding.
    entity_embedding = ops.times(context_sequence,
                                 embedding_matrix,
                                 name='constant_entity_embedding')
    full_context_embedding = ops.element_select(entity_ids_mask,
                                                entity_embedding,
                                                context_embedding)

    def _encode(embedded, weights, label):
        # Positional args: hidden size, 1 layer, bidirectional flag True.
        return ops.optimized_rnnstack(embedded, weights, hidden, 1, True,
                                      recurrent_op='gru', name=label)

    context_memory = _encode(full_context_embedding, context_gru_w,
                             'context_mem')
    query_memory = _encode(query_embedding, query_gru_w, 'query_mem')

    # Initial state: channels [0, hidden) of the last query step ('fwd')
    # spliced with channels [hidden, 2*hidden) of the first query step
    # ('bwd'), i.e. last forward status and first backward status.
    qfwd = ops.slice(sequence.last(query_memory), -1, 0, hidden, name='fwd')
    qbwd = ops.slice(sequence.first(query_memory), -1, hidden, hidden * 2,
                     name='bwd')
    init_status = ops.splice(qfwd, qbwd, name='Init_Status')

    return attention_model(context_memory,
                           query_memory,
                           init_status,
                           hidden,
                           params.attention_dim,
                           max_steps=params.max_rl_steps)