Code example #1
import numpy as np
import cntk.contrib.crosstalk as cstk

def _rnn_getter(f, attr):
    # Read bidirectional RNN weights out of a CNTK function into the
    # framework-neutral cstk.RnnArgs layout.
    if not attr.bidirectional:
        raise NotImplementedError()
    # CNTK keeps all cuDNN RNN weights in a single flat parameter
    use_cudnn = len(f.parameters) == 1
    if use_cudnn:
        gates = _get_rnn_gates(attr.op_type)
        # Split the flat cuDNN buffer into per-direction input weights,
        # recurrent weights, and two bias vectors per direction
        fw_Wt, fw_Ht, bw_Wt, bw_Ht, fw_b1, fw_b2, bw_b1, bw_b2 = np.split(
            f.parameters[0].value.reshape(-1), _get_cudnn_rnn_splitter(attr))

        def _unpack(wt):
            # cuDNN stores weights transposed relative to RnnArgs
            return _adjust_lstm_gate_order(
                wt.reshape(gates * attr.hidden_dim, -1).transpose())

        return cstk.RnnArgs(fw_W=_unpack(fw_Wt),
                            fw_H=_unpack(fw_Ht),
                            fw_b=_adjust_lstm_gate_order(fw_b1 + fw_b2),
                            bw_W=_unpack(bw_Wt),
                            bw_H=_unpack(bw_Ht),
                            bw_b=_adjust_lstm_gate_order(bw_b1 + bw_b2))
    else:
        # Non-cuDNN path: pull named parameters out of the BiRNN graph
        param = _get_birnn_param(f)
        return cstk.RnnArgs(fw_W=_parameter_getter(param.fw_W),
                            fw_H=_parameter_getter(param.fw_H),
                            fw_b=_parameter_getter(param.fw_b),
                            bw_W=_parameter_getter(param.bw_W),
                            bw_H=_parameter_getter(param.bw_H),
                            bw_b=_parameter_getter(param.bw_b))
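
The gate-count and split-offset helpers referenced above are not part of this excerpt. Below is a minimal sketch of what they plausibly compute, assuming a single-layer bidirectional cell and the weights-then-biases packing implied by the np.split destructuring above; the layout and bodies are assumptions for illustration, not the library's verified implementation.

def _get_rnn_gates(op_type):
    # Assumed gate counts per cell type: LSTM has 4 gates, GRU 3,
    # a plain RNN 1
    return {'lstm': 4, 'gru': 3, 'rnn': 1}[op_type]

def _get_cudnn_rnn_splitter(attr):
    # Cumulative offsets for np.split over the flat cuDNN buffer:
    # fw_W, fw_H, bw_W, bw_H, then four bias vectors (assumed layout)
    g = _get_rnn_gates(attr.op_type) * attr.hidden_dim
    sizes = [g * attr.input_dim, g * attr.hidden_dim,
             g * attr.input_dim, g * attr.hidden_dim,
             g, g, g, g]
    return np.cumsum(sizes)[:-1]
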
Code example #2
def _get(scope, attr):
    # TensorFlow-side getter; `sess` is a tf.Session, presumably captured
    # from an enclosing scope in the original module (not shown here).
    if not attr.bidirectional:
        raise NotImplementedError()
    fw_M, fw_b, bw_M, bw_b = _rnn_trainable_in_scope(scope)
    # The stacked kernel holds input and recurrent weights; split it
    # at input_dim into W (input) and H (recurrent)
    fw_W, fw_H = np.split(_trainable_getter(sess)(fw_M), [attr.input_dim])
    fw_b = _adjust_forget_bias(_trainable_getter(sess)(fw_b),
                               attr.hidden_dim, attr.forget_bias)
    bw_W, bw_H = np.split(_trainable_getter(sess)(bw_M), [attr.input_dim])
    bw_b = _adjust_forget_bias(_trainable_getter(sess)(bw_b),
                               attr.hidden_dim, attr.forget_bias)
    return cstk.RnnArgs(fw_W=fw_W, fw_H=fw_H, fw_b=fw_b,
                        bw_W=bw_W, bw_H=bw_H, bw_b=bw_b)
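
`_adjust_forget_bias` is also not shown. TensorFlow's LSTM cells add `forget_bias` at run time instead of storing it in the bias variable, so an exporter has to fold it into the saved vector. A sketch, assuming TF's `i, j, f, o` gate order (so the forget-gate slice is the third hidden_dim block):

def _adjust_forget_bias(b, hidden_dim, forget_bias):
    # Fold TF's run-time forget_bias into the exported bias vector;
    # the forget-gate slice is assumed to be the third hidden_dim block
    b = b.copy()
    b[2 * hidden_dim:3 * hidden_dim] += forget_bias
    return b
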
Code example #3
def _get_birnn_param(f):
    if f.root_function.op_name != 'Splice':
        raise NotImplementedError()
    # Assume the forward/backward cells are the first/second inputs to Splice
    fw = f.root_function.inputs[0].owner
    bw = f.root_function.inputs[1].owner
    return cstk.RnnArgs(fw_W=find_func_param(fw, name='W'),
                        fw_H=find_func_param(fw, name='H'),
                        fw_b=find_func_param(fw, name='b'),
                        bw_W=find_func_param(bw, name='W'),
                        bw_H=find_func_param(bw, name='H'),
                        bw_b=find_func_param(bw, name='b'))
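
`find_func_param` is left out of the excerpt as well. Given how it is called, it presumably looks up one of a cell's parameters by name; a sketch against CNTK's public `Function.parameters` attribute (the body is an assumption):

def find_func_param(func, name):
    # Sketch: return the unique parameter of the cell named 'W', 'H' or 'b'
    matches = [p for p in func.parameters if p.name == name]
    if len(matches) != 1:
        raise NotImplementedError('expected one parameter named %r' % name)
    return matches[0]
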
Code example #4
def test_cntk_cudnn():
    # Prefer a TensorFlow-produced baseline if TF is installed,
    # otherwise generate the baseline with CNTK itself
    try:
        import tensorflow
        has_tensorflow = True
    except ImportError:
        has_tensorflow = False

    if has_tensorflow:
        tf_baseline_lstm()
    else:
        cntk_baseline_lstm()

    import cntk as C
    import cntk.contrib.crosstalk.crosstalk_cntk as crct
    ci = crct.instance

    # in_dim, dim, workdir and data_cntk are module-level test fixtures
    input_var = C.sequence.input_variable(shape=(in_dim,))
    data = {input_var: data_cntk}
    ci.set_data(data)
    ci.set_workdir(workdir)

    # -1 lets CNTK infer the first dimension of the flat cuDNN parameter
    W = C.parameter((-1, dim), init=C.glorot_uniform())
    cudnn_fwbw = C.optimized_rnnstack(input_var, W, dim, 1,
                                      bidirectional=True,
                                      recurrent_op='lstm')
    # Register the RNN parameters and its output with crosstalk
    ci.watch(cudnn_fwbw, 'cntk_birnn_cudnn', var_type=cstk.RnnAttr,
             attr=cstk.RnnAttr(bidirectional=True,
                               op_type='lstm',
                               input_dim=in_dim,
                               hidden_dim=dim,
                               forget_bias=0))
    ci.watch(cudnn_fwbw, 'cntk_birnn_cudnn_out')

    # Load the baseline parameters saved under 'birnn' and check that the
    # cuDNN RNN now produces matching output
    ci.assign('cntk_birnn_cudnn', load=True, load_name='birnn')
    assert ci.compare('cntk_birnn_cudnn_out', compare_name='birnn_out',
                      rtol=1e-4, atol=1e-6)

    # Round-trip: save the parameters to workdir, reload them, and
    # verify the output still matches
    ci.fetch('cntk_birnn_cudnn', save=True)
    ci.assign('cntk_birnn_cudnn', load=True)
    assert ci.compare('cntk_birnn_cudnn_out', compare_name='birnn_out',
                      rtol=1e-4, atol=1e-6)

    # Test assign with explicit values: an LSTM has 4 gates, so each
    # weight block has num_gates * dim columns
    num_gates = 4
    ci.assign('cntk_birnn_cudnn', value=cstk.RnnArgs(
        fw_W=np.random.random((in_dim, num_gates * dim)).astype(np.float32),
        fw_H=np.random.random((dim, num_gates * dim)).astype(np.float32),
        fw_b=np.random.random((num_gates * dim,)).astype(np.float32),
        bw_W=np.random.random((in_dim, num_gates * dim)).astype(np.float32),
        bw_H=np.random.random((dim, num_gates * dim)).astype(np.float32),
        bw_b=np.random.random((num_gates * dim,)).astype(np.float32)))

    ci.reset()
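
The three assign variants in this test cover crosstalk's parameter-exchange paths: assign with load_name reads a baseline that another framework saved under a different name, fetch(save=True) followed by assign(load=True) round-trips the parameters through files in workdir, and assign with value= sets the cuDNN parameter directly from numpy arrays in RnnArgs form. The final ci.reset() clears the watched variables so the next test starts from a clean instance.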