예제 #1
0
def create_model(X):
    """Build a small stacked-GRU network on top of ``X``.

    ``X`` is passed through two ``Recurrence`` layers in turn, each driven by
    its own ``GRU(shape=20)`` step function, and the result is fed to a
    ``Dense(shape=1)`` layer.

    Args:
        X: Input the first recurrence is applied to (presumably a sequence
            input variable -- confirm against the caller).

    Returns:
        The output of the final ``Dense`` layer.
    """
    hidden = X
    # Each pass builds a fresh GRU recurrence and applies it immediately, so
    # layers are created and wired in the same order as spelling them out.
    for _ in range(2):
        hidden = Recurrence(step_function=GRU(shape=20))(hidden)

    return Dense(shape=1)(hidden)
예제 #2
0
def test_recurrence():
    """Check the forward pass of Recurrence() and RecurrenceFrom() over a GRU.

    Both variants wrap a ``GRU(5, init=0.05)`` and start from a state of 0.1:
    the first receives it as a constant ``initial_state``, the second as a
    data input. Given the same input they must both reproduce the same
    expected values to 6 decimals.
    """
    input_axis = Axis('inputAxis')
    state_axis = Axis('stateAxis')
    InputSeq = SequenceOver[input_axis]
    StateSeq = SequenceOver[state_axis]

    # Input and expected output shared by both checks below. Every row of the
    # expected matrix repeats a single value five times.
    inputs = np.arange(0, 25, dtype=np.float32).reshape((1, 5, 5))
    expected_rows = (0.239151, 0.338713, 0.367456, 0.375577, 0.377891)
    expected = [[value] * 5 for value in expected_rows]

    # ------------------------------------------------------------------
    # Check 1: Recurrence() -- the initial state is a constant
    # ------------------------------------------------------------------
    # Note: random init of the GRU parameters cannot be used here because the
    # random numbers would depend on which tests ran before this one. A
    # constant init is used instead (which is not realistic).
    # TODO: Find out how to reset the random generator, then remove the
    # constant init.
    R = Recurrence(GRU(5, init=0.05), go_backwards=False, initial_state=0.1)

    @Function
    @Signature(InputSeq[Tensor[5]])
    def F(x):
        return R(x)

    result = F(inputs)
    np.testing.assert_array_almost_equal(
        result[0],
        expected,
        decimal=6,
        err_msg='Error in Recurrence(GRU()) forward')

    # ------------------------------------------------------------------
    # Check 2: RecurrenceFrom() -- the initial state is a data input
    # ------------------------------------------------------------------
    RF = RecurrenceFrom(GRU(5, init=0.05), go_backwards=False)

    @Function
    @Signature(s=StateSeq[Tensor[5]], x=InputSeq[Tensor[5]])
    def FF(s, x):
        return RF(s, x)

    # Same value as the constant initial state of check 1, so the result is
    # expected to be the same as well.
    initial_state = np.full((1, 5, 5), 0.1)
    result = FF(initial_state, inputs)
    np.testing.assert_array_almost_equal(
        result[0],
        expected,
        decimal=6,
        err_msg='Error in RecurrenceFrom(GRU()) forward')
예제 #3
0
def attention_model(context_memory,
                    query_memory,
                    init_status,
                    hidden_dim,
                    att_dim,
                    max_steps=5,
                    init=glorot_uniform()):
    """Create the multi-step attention model for ReasoNet.

    For each of ``max_steps`` steps the current status is scored against the
    context and query memories, the weighted memories are reduced with
    ``sequence.reduce_sum``, spliced together and fed to a GRU that produces
    the next status. From the new status a termination probability and an
    answer attention over the context memory are computed.

    Stop weights are accumulated in log space: the weight of a step is
    ``exp(sum of log(1 - p_k) over earlier steps + log(p_step))``. The last
    step omits the ``log(p_step)`` term, so it only carries the accumulated
    ``log(1 - p_k)`` terms.

    Args:
        context_memory: Context memory.
        query_memory: Query memory.
        init_status: Initial status fed to the first step.
        hidden_dim: Dimension of the hidden state; the GRU is created with
            shape ``(hidden_dim * 2,)``.
        att_dim: Dimension passed to every ``attention_score`` layer.
        max_steps: Maximum number of steps to revisit the context memory.
        init: Accepted for interface compatibility; not used in this body.

    Returns:
        ``combine`` (named ``'Attention_func'``) over the per-step outputs,
        interleaved as ``[answer_0, stop_0, answer_1, stop_1, ...]``, followed
        by the sum over steps of ``answer_i * stop_i``.
    """
    # Layers are created once, up front, and reused on every step.
    controller = GRU((hidden_dim * 2, ), name='control_status')
    score_context = attention_score(att_dim, name='context_attention')
    score_query = attention_score(att_dim, name='query_attention')
    score_answer = attention_score(att_dim, name='candidate_attention')
    halt_gate = termination_gate(name='terminate_prob')

    status = init_status
    outputs = []
    # Running sum of log(1 - termination_prob) over the steps taken so far.
    log_not_stopped = 0
    for step in range(max_steps):
        context_weight = score_context(status, context_memory)
        query_weight = score_query(status, query_memory)
        context_summary = sequence.reduce_sum(
            times(context_weight, context_memory), name='C-Att')
        query_summary = sequence.reduce_sum(
            times(query_weight, query_memory), name='Q-Att')
        attended = ops.splice(query_summary, context_summary, name='att-sp')

        status = controller(status, attended).output
        halt_prob = halt_gate(status)
        answer_weight = score_answer(status, context_memory)

        if step == max_steps - 1:
            # Last step: no further chance to stop, so only the accumulated
            # "did not stop earlier" term is used.
            log_stop = log_not_stopped
        else:
            log_stop = log_not_stopped + ops.log(halt_prob, name='log_stop')

        stop_weight = sequence.broadcast_as(
            ops.exp(log_stop, name='exp_log_stop'),
            answer_weight,
            name='Stop_{0}'.format(step))
        outputs.append(answer_weight)
        outputs.append(stop_weight)

        log_not_stopped += ops.log(1 - halt_prob, name='log_non_stop')

    # Sum of the answer attentions, each weighted by its stop weight.
    final_ans = None
    for answer_weight, stop_weight in zip(outputs[0::2], outputs[1::2]):
        weighted = answer_weight * stop_weight
        if final_ans is None:
            final_ans = weighted
        else:
            final_ans += weighted

    return combine(outputs + [final_ans], name='Attention_func')