def _rnn_getter(f, attr):
    if not attr.bidirectional:
        raise NotImplementedError()
    # CNTK stores a single flat parameter blob when the RNN runs on cuDNN
    use_cudnn = (len(f.parameters) == 1)
    if use_cudnn:
        gates = _get_rnn_gates(attr.op_type)
        # unpack the flat cuDNN blob into weight and bias segments
        fw_Wt, fw_Ht, bw_Wt, bw_Ht, fw_b1, fw_b2, bw_b1, bw_b2 = np.split(
            f.parameters[0].value.reshape(-1),
            _get_cudnn_rnn_splitter(attr))
        # reshape each segment to (gates * hidden_dim, -1), transpose to
        # (input/hidden, gates * hidden_dim), and reorder the gates
        return cstk.RnnArgs(
            fw_W=_adjust_lstm_gate_order(
                fw_Wt.reshape(gates * attr.hidden_dim, -1).transpose()),
            fw_H=_adjust_lstm_gate_order(
                fw_Ht.reshape(gates * attr.hidden_dim, -1).transpose()),
            fw_b=_adjust_lstm_gate_order(fw_b1 + fw_b2),  # cuDNN keeps two bias vectors; merge them
            bw_W=_adjust_lstm_gate_order(
                bw_Wt.reshape(gates * attr.hidden_dim, -1).transpose()),
            bw_H=_adjust_lstm_gate_order(
                bw_Ht.reshape(gates * attr.hidden_dim, -1).transpose()),
            bw_b=_adjust_lstm_gate_order(bw_b1 + bw_b2))
    else:
        # non-cuDNN path: the parameters live on the forward/backward cells
        param = _get_birnn_param(f)
        return cstk.RnnArgs(fw_W=_parameter_getter(param.fw_W),
                            fw_H=_parameter_getter(param.fw_H),
                            fw_b=_parameter_getter(param.fw_b),
                            bw_W=_parameter_getter(param.bw_W),
                            bw_H=_parameter_getter(param.bw_H),
                            bw_b=_parameter_getter(param.bw_b))
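# _get_cudnn_rnn_splitter is defined elsewhere in the module; a minimal sketch
# of the split offsets it presumably returns, assuming the standard cuDNN
# flat-blob layout unpacked above (fw_W | fw_H | bw_W | bw_H | four bias
# vectors). The name and segment sizes here are assumptions, not the module's
# actual code.
def _get_cudnn_rnn_splitter_sketch(attr):
    gates = _get_rnn_gates(attr.op_type)            # e.g. 4 for LSTM
    sizes = [
        gates * attr.hidden_dim * attr.input_dim,   # fw_Wt: input weights
        gates * attr.hidden_dim * attr.hidden_dim,  # fw_Ht: recurrent weights
        gates * attr.hidden_dim * attr.input_dim,   # bw_Wt
        gates * attr.hidden_dim * attr.hidden_dim,  # bw_Ht
    ] + [gates * attr.hidden_dim] * 4               # fw_b1, fw_b2, bw_b1, bw_b2
    # np.split takes cumulative offsets, so drop the final total
    return list(np.cumsum(sizes)[:-1])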
def _get(scope, attr):
    # TensorFlow-side counterpart of _rnn_getter; `sess` is the active
    # tf.Session, captured from the enclosing getter factory.
    if not attr.bidirectional:
        raise NotImplementedError()
    fw_M, fw_b, bw_M, bw_b = _rnn_trainable_in_scope(scope)
    # each merged kernel M stacks input weights on top of recurrent weights,
    # so splitting at input_dim yields W (input) and H (recurrent)
    fw_W, fw_H = np.split(_trainable_getter(sess)(fw_M), [attr.input_dim])
    fw_b = _adjust_forget_bias(_trainable_getter(sess)(fw_b),
                               attr.hidden_dim, attr.forget_bias)
    bw_W, bw_H = np.split(_trainable_getter(sess)(bw_M), [attr.input_dim])
    bw_b = _adjust_forget_bias(_trainable_getter(sess)(bw_b),
                               attr.hidden_dim, attr.forget_bias)
    return cstk.RnnArgs(fw_W=fw_W, fw_H=fw_H, fw_b=fw_b,
                        bw_W=bw_W, bw_H=bw_H, bw_b=bw_b)
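# _adjust_forget_bias is defined elsewhere; a minimal sketch of the adjustment
# it presumably makes. TensorFlow's LSTM cells add `forget_bias` at run time
# rather than storing it in the bias variable, so the exported bias must fold
# it in explicitly. The [i, j, f, o] gate layout below matches
# tf.nn.rnn_cell.BasicLSTMCell and is an assumption, as is the function name.
def _adjust_forget_bias_sketch(b, hidden_dim, forget_bias):
    b = b.copy()
    # the forget-gate slice is the third hidden_dim-sized block
    b[2 * hidden_dim : 3 * hidden_dim] += forget_bias
    return b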
def _get_birnn_param(f):
    if f.root_function.op_name != 'Splice':
        raise NotImplementedError()
    # assuming forward/backward cell first/second input to Splice
    fw = f.root_function.inputs[0].owner
    bw = f.root_function.inputs[1].owner
    return cstk.RnnArgs(fw_W=find_func_param(fw, name='W'),
                        fw_H=find_func_param(fw, name='H'),
                        fw_b=find_func_param(fw, name='b'),
                        bw_W=find_func_param(bw, name='W'),
                        bw_H=find_func_param(bw, name='H'),
                        bw_b=find_func_param(bw, name='b'))
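# find_func_param is a helper defined elsewhere in the module; a minimal
# sketch of the lookup it presumably performs over the CNTK function graph.
# Treat the exact signature and error handling as assumptions.
def find_func_param_sketch(func, name=None):
    params = [p for p in func.parameters if p.name == name]
    if len(params) != 1:
        raise Exception('expected exactly one parameter named "{}", found {}'.format(
            name, len(params)))
    return params[0]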
def test_cntk_cudnn():
    try:
        import tensorflow
        has_tensorflow = True
    except ImportError:
        has_tensorflow = False

    # produce the baseline parameters/outputs with TensorFlow if available,
    # otherwise with a non-cuDNN CNTK LSTM
    if has_tensorflow:
        tf_baseline_lstm()
    else:
        cntk_baseline_lstm()

    import cntk as C
    import cntk.contrib.crosstalk.crosstalk_cntk as crct
    ci = crct.instance

    input_var = C.sequence.input_variable(shape=(in_dim,))
    data = {input_var: data_cntk}
    ci.set_data(data)
    ci.set_workdir(workdir)

    W = C.parameter((-1, dim,), init=C.glorot_uniform())
    cudnn_fwbw = C.optimized_rnnstack(input_var, W, dim, 1,
                                      bidirectional=True, recurrent_op='lstm')
    ci.watch(cudnn_fwbw, 'cntk_birnn_cudnn', var_type=cstk.RnnAttr,
             attr=cstk.RnnAttr(bidirectional=True,
                               op_type='lstm',
                               input_dim=in_dim,
                               hidden_dim=dim,
                               forget_bias=0))
    ci.watch(cudnn_fwbw, 'cntk_birnn_cudnn_out')

    # load the baseline parameters and check the outputs match
    ci.assign('cntk_birnn_cudnn', load=True, load_name='birnn')
    assert ci.compare('cntk_birnn_cudnn_out', compare_name='birnn_out',
                      rtol=1e-4, atol=1e-6)

    # round-trip: save the cuDNN parameters and load them back
    ci.fetch('cntk_birnn_cudnn', save=True)
    ci.assign('cntk_birnn_cudnn', load=True)
    assert ci.compare('cntk_birnn_cudnn_out', compare_name='birnn_out',
                      rtol=1e-4, atol=1e-6)

    # test assign with explicitly provided values
    num_gates = 4
    ci.assign('cntk_birnn_cudnn', value=cstk.RnnArgs(
        fw_W=np.random.random((in_dim, num_gates * dim)).astype(np.float32),
        fw_H=np.random.random((dim, num_gates * dim)).astype(np.float32),
        fw_b=np.random.random((num_gates * dim,)).astype(np.float32),
        bw_W=np.random.random((in_dim, num_gates * dim)).astype(np.float32),
        bw_H=np.random.random((dim, num_gates * dim)).astype(np.float32),
        bw_b=np.random.random((num_gates * dim,)).astype(np.float32)))

    ci.reset()
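# The test above relies on module-level fixtures defined elsewhere in the test
# file (in_dim, dim, data_cntk, workdir, plus the tf_baseline_lstm /
# cntk_baseline_lstm baselines). A plausible sketch of the data fixtures only;
# the concrete values here are assumptions.
import tempfile
import numpy as np

in_dim = 3
dim = 7
# CNTK sequence input: a list of (sequence_length, in_dim) float32 arrays,
# one array per sequence in the batch
data_cntk = [np.random.random((5, in_dim)).astype(np.float32) for _ in range(2)]
workdir = tempfile.mkdtemp()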