def _create_cc(self, x, gy, hint, e=Engine()):
    """Build the ReLU backward computation and queue it on ``self.dag``.

    ``x`` and ``gy`` are wrapped with ``array`` in ``nc`` layout when ``x`` is
    2-D and ``nchw`` layout otherwise.  ``x`` is then reordered (if needed) to
    the layout of ``gy``, and an ``eltwise_relu`` backward primitive
    (alpha=0.0, beta=0.0) is appended to ``self.dag``.

    The gradient is written into ``gy``'s own memory: the returned ``gx`` is
    the same object as ``gy``.

    Args:
        x: Forward input.
        gy: Upstream gradient (assumed to have the same shape as ``x``).
        hint: Passed as the hint argument of
            ``eltwise_backward.primitive_desc``.
        e: MKL-DNN engine.  Note the default ``Engine()`` is created once,
            when the function is defined, and shared across calls.
    """
    layout = m.memory.nc if x.ndim == 2 else m.memory.nchw
    x = array(x, layout, e)
    gy = array(gy, layout, e)

    gy_mpd = gy.memory.get_primitive_desc()
    # Bring x into gy's layout so both operands of the primitive agree.
    reordered = CC.reorder_if_must(x, gy_mpd, e, self.dag)
    if len(reordered) != 2:
        x = reordered[0]
    else:
        # Keep the intermediate array referenced so it is not collected
        # while the dag still uses its memory.
        x, self.itm_arr = reordered[:2]
    x_mpd = x.memory.get_primitive_desc()

    bwd_desc = eltwise_backward.desc(
        eltwise_relu, gy_mpd.desc(), x_mpd.desc(), 0.0, 0.0)
    bwd_pd = eltwise_backward.primitive_desc(bwd_desc, e, hint)

    # In-place backward: the destination is gy's memory.
    gx = gy
    self.dag.push_back(eltwise_backward.eltwise_backward(
        bwd_pd, at(x.memory), at(gy.memory), gx.memory))

    self.x = x
    self.gy = gy
    self._hint = hint
    self.outputs = gx,
def _create_cc(self, x, gy, hint, y, ws, ksize, stride, pad, cover_all, e):
    """Build the pooling backward primitive and queue it on ``self.dag_``.

    The kernel size, stride and padding may each be a scalar or a pair; they
    are normalized with ``_pair`` before being handed to MKL-DNN.  The
    bottom/right padding is derived so that the output size matches
    ``conv.get_conv_outsize``.

    Args:
        x: Forward input (4-D; its last two dims are treated as height and
            width).
        gy: Upstream gradient.
        hint: Passed as the hint argument of
            ``pooling_backward.primitive_desc``.
        y: Forward output.  For max pooling its layout is used for ``gy``.
        ws: Forward workspace, used only for max pooling.
        ksize, stride, pad: Pooling window size, stride and top/left
            padding (scalar or pair).
        cover_all: Forwarded to ``conv.get_conv_outsize``.
        e: MKL-DNN engine.
    """
    self.ksize = ksize
    self.stride = stride
    self.pad = pad
    self.cover_all = cover_all
    self.x = array(x, m.memory.nchw, e)
    gy = array(gy, m.memory.nchw, e)
    if self.alg_kind is pooling_max:
        # gy is reordered to y's layout below, so describe it that way.
        gy_md = y.memory.get_primitive_desc().desc()
    else:
        gy_md = gy.memory.get_primitive_desc().desc()
    # Let MKL-DNN pick the layout of the gradient w.r.t. x.
    gx_md = m.desc(x.shape, m.memory.f32, m.memory.any)

    _, _, h, w = x.shape
    sy, sx = _pair(stride)
    kh, kw = _pair(ksize)
    p_upper, p_left = _pair(pad)
    yh = conv.get_conv_outsize(h, kh, sy, p_upper, cover_all=cover_all)
    assert yh > 0, 'Height in the output should be positive.'
    yw = conv.get_conv_outsize(w, kw, sx, p_left, cover_all=cover_all)
    assert yw > 0, 'Width in the output should be positive.'
    # Bottom/right padding chosen so MKL-DNN reproduces the output size.
    p_down = sy * (yh - 1) + kh - h - p_upper
    p_right = sx * (yw - 1) + kw - w - p_left

    # Pass the normalized pairs: the raw stride/ksize may be scalars.
    cc_d = pooling_backward.desc(
        self.alg_kind, gx_md, gy_md, (sy, sx), (kh, kw),
        (p_upper, p_left), (p_down, p_right), zero)
    cc_pd = pooling_backward.primitive_desc(cc_d, e, hint)

    gx = mdarray(cc_pd.diff_src_primitive_desc())

    if self.alg_kind is pooling_max:
        # For max pooling reorder gy to y's layout if needed.
        outputs = reorder_if_must(
            gy, y.memory.get_primitive_desc(), e, self.dag_)
        if len(outputs) == 2:
            self.reordered_gy, self.itm_arr = outputs[:2]
        else:
            self.reordered_gy = outputs[0]
        self.dag_.push_back(
            pooling_backward.pooling_backward(
                cc_pd, at(self.reordered_gy.memory), at(ws.memory),
                gx.memory))
    else:
        # There is no workspace for average pooling.
        self.dag_.push_back(
            pooling_backward.pooling_backward(
                cc_pd, at(gy.memory), gx.memory))

    self._hint = hint
    self.gy = gy
    self.outputs = gx,
def mkl_sum(xs, func=None):
    """Return the element-wise sum of ``xs`` computed by an MKL-DNN ``sum``.

    Each input is wrapped with ``array(x, _x_format(x.ndim), e)``.  The layout
    with the largest ``m.get_fmt`` value is chosen as the common layout, and
    every input is reordered to it if needed.  The reorders and the sum are
    submitted on a fresh ``Stream``, which is waited on before returning.

    Args:
        xs: Non-empty sequence of arrays to add (every scale is 1.0).
        func: Optional function node.  If it has a ``hint`` attribute and
            ``ComputeComplex.get_bd_cc`` finds a backward compute complex for
            it, that complex's ``gy`` array is used as the destination
            (currently used only for gradient accumulation).

    Returns:
        The ``mdarray`` holding the sum.

    Raises:
        ValueError: If ``xs`` is empty.
    """
    if len(xs) == 0:
        raise ValueError('mkl_sum requires at least one input')

    e = Engine()
    # Every array whose memory is referenced by a queued primitive must stay
    # alive until ``s.wait()`` returns, so references are kept here.
    xs_arrays = ()  # inputs wrapped as mdarray
    xarrays = ()    # inputs in the common layout
    itm_arrs = ()   # intermediates produced by reorders (all of them)
    xs_mpdl = m.mpd_list()
    xs_pl = ()
    scales = m.vectord()
    pl = primitive_list()

    # Wrap the inputs and pick the common layout.
    xmpd_best = None
    for x in xs:
        xarray = array(x, _x_format(x.ndim), e)
        xmpd = xarray.memory.get_primitive_desc()
        if xmpd_best is None or m.get_fmt(xmpd) > m.get_fmt(xmpd_best):
            xmpd_best = xmpd
        xs_arrays += (xarray,)

    # Reorder every input to the common layout and collect sum operands.
    for x in xs_arrays:
        outputs = reorder_if_must(x, xmpd_best, e, pl)
        if len(outputs) == 2:
            xarray, itm_arr = outputs[:2]
            # Keep every intermediate alive, not only the last one.
            itm_arrs += (itm_arr,)
        else:
            xarray = outputs[0]
        xarrays += (xarray,)
        scales.push_back(1.0)
        xs_mpdl.push_back(xarray.memory.get_primitive_desc())
        xs_pl += (at(xarray.memory),)

    cc_pd = sum.primitive_desc(scales, xs_mpdl)
    y = None
    if func is not None and hasattr(func, 'hint'):
        # This is only used for grad accumulate currently.
        cc = ComputeComplex.get_bd_cc(
            func.hint, pos=(func.rank, func.fanout))
        if cc is not None:
            y = cc.gy
    if y is None:
        y = mdarray(cc_pd.dst_primitive_desc())

    pl.push_back(sum.sum(cc_pd, xs_pl, y.memory))
    s = Stream()
    s.submit(pl)
    s.wait()
    return y
def _create_cc(self, x, W, b, e=Engine()):
    """Build the inner-product (linear) forward computation.

    The output descriptor uses format ``any`` so MKL-DNN can choose the
    layouts.  The user weights are reordered (if needed) to the layout the
    primitive wants, and the work is queued on ``self.dag`` via
    ``linear_f_op``.

    Args:
        x: Input; ``x.shape[0]`` is used as the first output dimension.
        W: Weights; ``W.shape[0]`` is used as the second output dimension.
        b: Optional bias (``None`` to omit).
        e: MKL-DNN engine.  Note the default ``Engine()`` is created once,
            when the function is defined, and shared across calls.
    """
    out_desc = m.desc((x.shape[0], W.shape[0]), m.memory.f32, m.memory.any)
    fwd_desc = create_forward_desc(ip_forward.desc, out_desc, x, W, b)
    cc_pd = ip_forward.primitive_desc(fwd_desc, e)

    self.x = array(x, _x_format(x.ndim), e)
    weights_mpd = cc_pd.weights_primitive_desc()
    self.usr_w = array(W, _W_format(W.ndim), e)
    # Reorder the user weights into the primitive's preferred layout.
    reordered = CC.reorder_if_must(self.usr_w, weights_mpd, e, self.dag)
    if len(reordered) != 2:
        self.W = reordered[0]
    else:
        # Keep the intermediate referenced so it outlives the dag.
        self.W, self.itm_arr = reordered[:2]

    if b is None:
        y = linear_f_op(cc_pd, self.x, self.W, self.dag)
    else:
        self.b = array(b, m.memory.x, e)
        y = linear_f_op(cc_pd, self.x, self.W, self.b, self.dag)

    self._hint = cc_pd
    self.outputs = y,
def _create_cc(self, inputs, fwd_x, gy, hint, flags, eps, mean, var, e):
    """Build the batch-normalization backward primitive on ``self.dag_``.

    ``gy`` is reordered (if needed) to the layout of the forward input
    ``fwd_x``.  The mean and variance are inputs to the primitive in both
    variants, so they are wrapped regardless of ``flags``.  When
    ``flags & use_scale_shift`` is set, ``gamma`` and ``beta`` are packed as
    a 2-row weight array and a weight-gradient buffer is allocated;
    otherwise the weight gradient is ``None``.

    Args:
        inputs: Sequence whose first three items are (x, gamma, beta).
        fwd_x: The forward input as already prepared by the forward pass.
        gy: Upstream gradient.
        hint: Passed as the hint argument of
            ``bn_backward.primitive_desc``.
        flags: Batch-normalization flags (e.g. ``use_scale_shift``).
        eps: Epsilon used in the descriptor.
        mean, var: Statistics used by the forward pass.
        e: MKL-DNN engine.
    """
    self.train = configuration.config.train
    self.flags = flags
    self.eps = eps
    _, gamma, beta = inputs[:3]
    self.x = fwd_x
    x_mpd = self.x.memory.get_primitive_desc()
    x_md = x_mpd.desc()

    gy = array(gy, m.memory.nchw, e)
    # Keep the pre-reorder gradient referenced.
    self.gy_src = gy
    outputs = reorder_if_must(gy, x_mpd, e, self.dag_)
    if len(outputs) == 2:
        gy, self.itm_arr = outputs[:2]
    else:
        gy = outputs[0]
    gy_md = gy.memory.get_primitive_desc().desc()

    cc_d = bn_backward.desc(backward, gy_md, x_md, eps, flags)
    cc_pd = bn_backward.primitive_desc(cc_d, e, hint)

    gx = mdarray(self.x.memory.get_primitive_desc(), gy.memory)

    # Both primitive variants read mean and var, so wrap them unconditionally.
    self.mean = array(mean, m.memory.x, e)
    self.var = array(var, m.memory.x, e)

    if flags & use_scale_shift:
        w = numpy.concatenate((gamma, beta), axis=0).reshape((2, -1))
        self.w = array(w, m.memory.nc, e)
        self.gw = mdarray(cc_pd.diff_weights_primitive_desc())
        bwd_p = bn_backward.batch_normalization_backward(
            cc_pd, at(self.x.memory), at(self.mean.memory),
            at(self.var.memory), at(gy.memory), at(self.w.memory),
            gx.memory, self.gw.memory)
    else:
        # No scale/shift weights, hence no weight gradient.
        self.gw = None
        bwd_p = bn_backward.batch_normalization_backward(
            cc_pd, at(self.x.memory), at(self.mean.memory),
            at(self.var.memory), at(gy.memory), gx.memory)

    self.dag_.push_back(bwd_p)
    self._hint = hint
    self.gy = gy
    self.outputs = gx, self.gw
def _create_cc(self, inputs, e):
    """Build an MKL-DNN ``sum`` of the first two inputs on ``self.dag_``.

    ``x1`` is reordered (if needed) to the layout of ``x0``.  Both are then
    summed with scale 1.0 into a new ``mdarray`` whose layout comes from the
    sum primitive descriptor.

    Args:
        inputs: Sequence whose first two items are the addends.
        e: MKL-DNN engine.
    """
    x0, x1 = inputs[:2]
    self.x0 = x0
    self.x1 = x1

    outputs = reorder_if_must(x1, x0.memory.get_primitive_desc(), e,
                              self.dag_)
    self.x1_reordered = outputs[0]
    if len(outputs) == 2:
        # Keep the reorder intermediate referenced (as sibling code does)
        # so it is not collected while the dag still uses its memory.
        self.itm_arr = outputs[1]

    scales = m.vectord()
    scales.push_back(1.0)
    scales.push_back(1.0)
    xs_mpdl = m.mpd_list()
    xs_mpdl.push_back(x0.memory.get_primitive_desc())
    xs_mpdl.push_back(self.x1_reordered.memory.get_primitive_desc())
    cc_pd = sum.primitive_desc(scales, xs_mpdl)

    xs_pl = (at(x0.memory), at(self.x1_reordered.memory))
    y = mdarray(cc_pd.dst_primitive_desc())
    self.dag_.push_back(sum.sum(cc_pd, xs_pl, y.memory))
    self.outputs = y,
def _create_cc(self, x, W, b, stride, pad, cover_all, e):
    """Build the convolution forward computation on ``self.dag``.

    Re-runs the base-class ``__init__``, builds the descriptor (output format
    ``any``) from the geometry given by ``conv.conv_geometry``, reorders the
    ``oihw`` user weights into the primitive's layout if needed, and queues
    the convolution through ``conv_f_op``.

    When a weight reorder was queued, a ``WeightReorderOptimization`` record
    is stored in ``self.weight_reorder_opt``.  Its ``reorder`` field is the
    index of the last entry of ``self.dag``, presumably the reorder just
    pushed.  Otherwise ``self.weight_reorder_opt`` is ``None``.

    Args:
        x: Input in ``nchw`` layout.
        W: Weights in ``oihw`` layout.
        b: Optional bias (``None`` to omit).
        stride, pad, cover_all: Forwarded to ``conv.conv_geometry``.
        e: MKL-DNN engine.
    """
    super(ConvolutionForward, self).__init__()

    geom = conv.conv_geometry(x.shape, W.shape, stride, pad, cover_all)
    out_desc = m.desc(geom.out_shape, m.memory.f32, m.memory.any)
    fwd_desc = create_forward_desc(
        conv_forward.desc, out_desc, (x, W, b), geom.geometry)
    cc_pd = conv_forward.primitive_desc(fwd_desc, e)

    weights_mpd = cc_pd.weights_primitive_desc()
    self.usr_w = array(W, m.memory.oihw, e)
    reordered = CC.reorder_if_must(self.usr_w, weights_mpd, e, self.dag)
    if len(reordered) != 2:
        self.W = reordered[0]
    else:
        self.W, self.itm_arr = reordered[:2]

    if self.usr_w is self.W:
        # No reorder was needed.
        self.weight_reorder_opt = None
    else:
        # Record where the weight reorder sits so it can be optimized later.
        record = WeightReorderOptimization()
        record.reorder = self.dag.size() - 1
        record.optimized = False
        self.weight_reorder_opt = record

    self.x = array(x, m.memory.nchw, e)
    if b is None:
        y = conv_f_op(cc_pd, self.x, self.W, self.dag)
    else:
        self.b = array(b, m.memory.x, e)
        y = conv_f_op(cc_pd, self.x, self.W, self.b, self.dag)

    self._hint = cc_pd
    self.outputs = y,
def _create_cc(self, inputs, eps, mean, var, e):
    """Build the batch-normalization forward primitive on ``self.dag_``.

    ``x`` is wrapped in ``nchw`` layout, then reordered (if needed) to the
    layout returned by ``m.get_desired_format(x.shape[1])``.  ``gamma`` and
    ``beta`` are packed into a 2-row scale/shift weight array.

    * If ``mean`` is ``None``, the descriptor uses ``forward_training`` and
      fresh ``mean``/``var`` buffers are allocated as outputs (except in the
      inference path below, where only ``y`` is produced).
    * Otherwise ``use_global_stats`` is set, the descriptor uses
      ``forward_scoring``, and ``mean``/``var`` are inputs.

    Args:
        inputs: Sequence whose first three items are (x, gamma, beta).
        eps: Epsilon used in the descriptor.
        mean, var: Global statistics, or ``None`` to compute batch stats.
        e: MKL-DNN engine.

    Sets ``self.outputs`` to ``(y, flags, mean, var)``.
    """
    self.eps = eps
    self.mean = None
    self.var = None
    self.w = None
    self.train = configuration.config.train
    x, gamma, beta = inputs[:3]

    fmt_desired = m.get_desired_format(x.shape[1])
    x = array(x, m.memory.nchw, e)
    assert x.dtype == numpy.dtype('float32')
    x_desired_md = m.desc(x.shape, m.memory.f32, fmt_desired)
    x_desired_mpd = m.primitive_desc(x_desired_md, e)
    outputs = reorder_if_must(x, x_desired_mpd, e, self.dag_)
    # Keep the pre-reorder input referenced.
    self.x_src = x
    if len(outputs) == 2:
        self.x, self.itm_arr = outputs[:2]
    else:
        self.x = outputs[0]

    w = numpy.concatenate((gamma, beta), axis=0).reshape((2, -1))
    self.numpy_w = w
    self.w = array(w, m.memory.nc, e)
    scale_shift = True
    self.flags = use_scale_shift
    if mean is None:
        fwd_prop_kind = forward_training
        global_stats = False
    else:
        fwd_prop_kind = forward_scoring
        self.flags |= use_global_stats
        global_stats = True
        self.mean = array(mean, m.memory.x, e)
        self.var = array(var, m.memory.x, e)

    x_md = self.x.memory.get_primitive_desc().desc()
    cc_d = bn_forward.desc(fwd_prop_kind, x_md, eps, self.flags)
    cc_pd = bn_forward.primitive_desc(cc_d, e)
    y = mdarray(cc_pd.dst_primitive_desc())
    # TODO: reorder the weights to cc_pd.weights_primitive_desc() if needed.

    # Batch statistics are outputs whenever global stats are not used,
    # independent of whether scale/shift is enabled.
    if not global_stats:
        self.mean = mdarray(cc_pd.mean_primitive_desc())
        self.var = mdarray(cc_pd.variance_primitive_desc())

    if not self.train and not global_stats:
        # Inference with batch statistics: only y is produced.
        # NOTE(review): cc_pd was built with forward_training here; confirm
        # this output-only constructor is valid for that prop kind.
        if scale_shift:
            bnf = bn_forward.batch_normalization_forward(
                cc_pd, at(self.x.memory), at(self.w.memory), y.memory)
        else:
            bnf = bn_forward.batch_normalization_forward(
                cc_pd, at(self.x.memory), y.memory)
    elif global_stats:
        # mean/var are inputs, so they must be wrapped with at() in both
        # variants; raw memories would match the training overload instead.
        if scale_shift:
            bnf = bn_forward.batch_normalization_forward(
                cc_pd, at(self.x.memory), at(self.mean.memory),
                at(self.var.memory), at(self.w.memory), y.memory)
        else:
            bnf = bn_forward.batch_normalization_forward(
                cc_pd, at(self.x.memory), at(self.mean.memory),
                at(self.var.memory), y.memory)
    else:
        # Training: mean/var are outputs written alongside y.
        if scale_shift:
            bnf = bn_forward.batch_normalization_forward(
                cc_pd, at(self.x.memory), at(self.w.memory), y.memory,
                self.mean.memory, self.var.memory)
        else:
            bnf = bn_forward.batch_normalization_forward(
                cc_pd, at(self.x.memory), y.memory,
                self.mean.memory, self.var.memory)

    self.dag_.push_back(bnf)
    self._hint = cc_pd
    self.outputs = y, self.flags, self.mean, self.var