def _create_cc(self, x, W, gy, hint, fwd_W, e=Engine()):
    # Create primitive descriptor
    cc_d = create_backward_desc(ip_backdata.desc, x, W, gy)
    cc_pd = ip_backdata.primitive_desc(cc_d, e, hint)
    # Transform inputs: reuse the forward weights, bring gy into nc format
    self.W = fwd_W
    self.gy = array(gy, m.memory.nc, e)
    gx = linear_bd_op(cc_pd, self.gy, self.W, self.dag)
    self._hint = hint
    self.outputs = gx,
def _create_cc(self, x, gy, hint, e=Engine()):
    if x.ndim == 2:
        fmt = m.memory.nc
    else:
        fmt = m.memory.nchw
    x = array(x, fmt, e)
    gy = array(gy, fmt, e)
    diff_pd = gy.memory.get_primitive_desc()
    outputs = CC.reorder_if_must(x, diff_pd, e, self.dag)
    if len(outputs) == 2:
        x, self.itm_arr = outputs[:2]
    else:
        x = outputs[0]
    mem_pd = x.memory.get_primitive_desc()
    cc_d = eltwise_backward.desc(eltwise_relu, diff_pd.desc(),
                                 mem_pd.desc(), 0.0, 0.0)
    cc_pd = eltwise_backward.primitive_desc(cc_d, e, hint)
    # Compute in place: gx aliases gy
    gx = gy
    self.dag.push_back(eltwise_backward.eltwise_backward(
        cc_pd, at(x.memory), at(gy.memory), gx.memory))
    self.x = x
    self.gy = gy
    self._hint = hint
    self.outputs = gx,
def __init__(self, inputs, pos=(0, 0), e=Engine()):
    x = inputs[0]
    super(ReLUForward, self).__init__()
    if self.new:
        self._create_cc(x, e)
    else:
        self._reuse_cc(x)
def __init__(self, inputs, grad_outputs, hint, pos=(0, 0), e=Engine()):
    x = inputs[0]
    gy = grad_outputs[0]
    super(ReLUBackward, self).__init__()
    if self.new:
        self._create_cc(x, gy, hint, e)
    else:
        self._reuse_cc(x, gy)
def __init__(self, inputs, pos=(0, 0), e=Engine()):
    super(LinearForward, self).__init__()
    x = inputs[0]
    W = inputs[1]
    b = inputs[2] if len(inputs) == 3 else None
    self.argc = len(inputs)
    if self.new:
        self._create_cc(x, W, b, e)
    else:
        self._reuse_cc(x, W, b, e)
def __init__(self, inputs, grad_outputs, hint, pos=(0, 0), e=Engine()):
    super(LinearBackwardWeighs, self).__init__()
    x = inputs[0]
    gy = grad_outputs[0]
    self.argc = len(inputs)
    if self.new:
        W = inputs[1]
        b = inputs[2] if self.argc == 3 else None
        self._create_cc(x, W, b, gy, hint, e)
    else:
        self._reuse_cc(x, gy)
def warray(w):
    if w.ndim == 1:
        fmt = m.memory.x
    elif w.ndim == 2:
        fmt = m.memory.oi
    elif w.ndim == 4:
        fmt = m.memory.oihw
    else:
        raise NotImplementedError
    if w.dtype != numpy.float32:
        raise NotImplementedError
    e = Engine()
    return mdarray(w, fmt, e)
def __init__(self, inputs, stride=1, pad=0, outsize=None, cover_all=False,
             hint=dummy_hint, pos=(0, 0), e=Engine()):
    x, gy = inputs[:2]
    if self.new:
        self._create_cc(x, gy, stride, pad, outsize, cover_all, hint, e)
    else:
        self._reuse_cc(x, gy)
def __init__(self, inputs, grad_outputs, hint, fwd_W, pos=(0, 0),
             e=Engine()):
    super(LinearBackwardData, self).__init__()
    W = inputs[1]
    gy = grad_outputs[0]
    self.argc = len(inputs)
    if self.new:
        x = inputs[0]
        self._create_cc(x, W, gy, hint, fwd_W, e)
    else:
        self._reuse_cc(W, gy)
def __init__(self, inputs, stride=1, pad=0, cover_all=False, pos=(0, 0),
             e=Engine()):
    x = inputs[0]
    W = inputs[1]
    b = inputs[2] if len(inputs) == 3 else None
    if self.new:
        self._create_cc(x, W, b, stride, pad, cover_all, e)
        self.num_inputs = len(inputs)
    else:
        self._reuse_cc(x, W, b)
def create_dummy_hint():
    """Create a dummy hint.

    Creating a convolution backward primitive requires a forward
    primitive as a hint, even though the actual implementations make no
    use of it. A dummy hint is a workaround for this situation; a future
    interface may require no hint at all.
    """
    x_md = m.desc((128, 3, 227, 227), m.memory.f32, m.memory.any)
    W_md = m.desc((96, 3, 11, 11), m.memory.f32, m.memory.any)
    o_md = m.desc((128, 96, 55, 55), m.memory.f32, m.memory.any)
    dummy_d = conv_forward.desc(forward, convolution_direct, x_md, W_md,
                                o_md, (4, 4), (0, 0), (0, 0), zero)
    return conv_forward.primitive_desc(dummy_d, Engine())
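# Hedged usage sketch: the dummy hint fills the ``hint`` slot that backward
# primitive_desc constructors expect (compare the ``hint=dummy_hint`` default
# above). ``conv_backweights`` and ``cc_d`` are illustrative names only, not
# verified against this module's imports.
def _backward_pd_with_dummy_hint(cc_d, e=Engine()):
    # Per the docstring above, the hint's contents are never consulted, so
    # any forward primitive_desc, including the dummy one, satisfies the slot.
    return conv_backweights.primitive_desc(cc_d, e, create_dummy_hint())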
def _create_cc(self, x, W, b, e=Engine()):
    y_d = m.desc((x.shape[0], W.shape[0]), m.memory.f32, m.memory.any)
    # Create primitive_desc from ``any`` formats, letting MKL-DNN choose
    # the layouts
    cc_d = create_forward_desc(ip_forward.desc, y_d, x, W, b)
    cc_pd = ip_forward.primitive_desc(cc_d, e)
    # Transform inputs
    self.x = array(x, _x_format(x.ndim), e)
    w_mpd = cc_pd.weights_primitive_desc()
    self.usr_w = array(W, _W_format(W.ndim), e)
    outputs = CC.reorder_if_must(self.usr_w, w_mpd, e, self.dag)
    if len(outputs) == 2:
        self.W, self.itm_arr = outputs[:2]
    else:
        self.W = outputs[0]
    if b is not None:
        self.b = array(b, m.memory.x, e)
        y = linear_f_op(cc_pd, self.x, self.W, self.b, self.dag)
    else:
        y = linear_f_op(cc_pd, self.x, self.W, self.dag)
    self._hint = cc_pd
    self.outputs = y,
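# Hedged note on CC.reorder_if_must, inferred from its call sites above: it
# returns a 1-tuple when the array already matches the target primitive_desc,
# and a 2-tuple (reordered array, intermediate buffer) when a reorder
# primitive was pushed onto ``dag``. The hypothetical helper below folds both
# cases into one; the intermediate must stay alive until the dag executes,
# which is why the call sites stash it on ``self``.
def _reorder_to(arr, target_pd, e, dag):
    outputs = CC.reorder_if_must(arr, target_pd, e, dag)
    if len(outputs) == 2:
        return outputs[0], outputs[1]  # (reordered, keep-alive intermediate)
    return outputs[0], None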
def _create_cc(self, x, e=Engine()):
    if x.ndim == 2:
        fmt = m.memory.nc
    elif x.ndim == 4:
        fmt = m.memory.nchw
    else:
        # Guard against ``fmt`` being unbound for unsupported ranks
        raise NotImplementedError
    x = array(x, fmt, e)
    mem_pd = x.memory.get_primitive_desc()
    cc_d = eltwise_forward.desc(forward, eltwise_relu,
                                mem_pd.desc(), 0.0, 0.0)
    cc_pd = eltwise_forward.primitive_desc(cc_d, e)
    y = mdarray(cc_pd.dst_primitive_desc())
    self.x = x
    self.dag.push_back(eltwise_forward.eltwise_forward(
        cc_pd, at(x.memory), y.memory))
    self._hint = cc_pd
    self.outputs = y,
def w_tensor(W):
    """Convert the input to a weight tensor of MKL-DNN.

    Parameters
    ----------
    W : object
        An object supporting the buffer protocol.
    """
    if W.ndim == 1:
        fmt = m.memory.x
    elif W.ndim == 2:
        fmt = m.memory.oi
    elif W.ndim == 4:
        fmt = m.memory.oihw
    else:
        raise NotImplementedError
    if W.dtype != numpy.float32:
        raise NotImplementedError
    return mdarray(W, fmt, Engine())
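# Hedged usage sketch for w_tensor (and the equivalent warray above): a 2-D
# float32 array maps to the ``oi`` weight format, a 4-D one to ``oihw``. The
# shape below is arbitrary; ``numpy`` is assumed to be imported at module
# level, as the dtype checks above already rely on it.
def _w_tensor_example():
    W = numpy.random.rand(96, 256).astype(numpy.float32)
    return w_tensor(W)  # 2-D weights are laid out as ``oi``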