def forward(self, is_train, req, in_data, out_data, aux):
    # `batch` is a module-level step counter and `writer` a summary writer
    # (e.g. mxboard's SummaryWriter) defined elsewhere in the module.
    global batch
    batch += 1
    x = in_data[0]
    gamma = in_data[1]
    beta = in_data[2]
    moving_mean = in_data[3]
    moving_var = in_data[4]
    new_gamma = in_data[5]
    new_beta = in_data[6]
    y_shift_bit = in_data[7]
    last_shift_bit = in_data[8]

    # Periodically log the input distribution.
    if batch % 20 == 0:
        writer.add_histogram(tag='batch1_input', values=x,
                             bins=np.arange(-10, 10), global_step=batch)

    y = out_data[0]
    if is_train:
        # Batch statistics over the N, H, W axes.
        mean = nd.mean(x, axis=(0, 2, 3))
        var = nd.array(np.var(x.asnumpy(), axis=(0, 2, 3)))
        # Fold the statistics into an affine scale/shift and quantize both
        # with a shared shift bit.
        quan_gamma = gamma / nd.sqrt(var + self.eps)
        quan_beta = beta - mean * gamma / nd.sqrt(var + self.eps)
        quan_gamma = quan_gamma * (2 ** last_shift_bit)
        quan_gamma, quan_beta, gamma_shift_bit = self.int_quantize_double(
            quan_gamma, quan_beta)
        # Normalize with identity gamma/beta; the quantized parameters are
        # applied and stored in backward().
        y = nd.BatchNorm(x,
                         gamma=nd.ones(shape=moving_var.shape),
                         beta=nd.zeros(shape=moving_mean.shape),
                         moving_mean=nd.zeros(shape=moving_mean.shape),
                         moving_var=nd.ones(shape=moving_var.shape),
                         eps=1e-5, momentum=self.momentum,
                         fix_gamma=True, name=self.name)
        y, y_shift_bit = self.int_quantize(y)
    else:
        # Inference: use the stored quantized beta and rescale the output by
        # the stored shift bit.
        y = nd.BatchNorm(x,
                         gamma=nd.ones(shape=moving_var.shape),
                         beta=new_beta,
                         moving_mean=nd.zeros(shape=moving_mean.shape),
                         moving_var=nd.ones(shape=moving_var.shape),
                         eps=1e-5, momentum=self.momentum,
                         fix_gamma=True, name=self.name)
        # y, y_shift_bit = self.int_quantize(y)
        y = y * (2 ** y_shift_bit)
    self.assign(out_data[0], req[0], mx.nd.array(y))
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    dx = in_grad[0]
    dgamma = in_grad[1]
    dbeta = in_grad[2]
    x = in_data[0]
    gamma = in_data[1]
    beta = in_data[2]
    mean = in_data[3]
    var = in_data[4]
    new_gamma = in_data[5]
    new_beta = in_data[6]
    y_shift_bit = in_data[7]
    last_shift_bit = in_data[8]
    y = out_data[0]
    dy = out_grad[0]

    # Recompute the folded, quantized affine parameters from the current batch.
    mean = nd.mean(x, axis=(0, 2, 3))
    var = nd.array(np.var(x.asnumpy(), axis=(0, 2, 3)))
    quan_gamma = gamma / nd.sqrt(var + self.eps)
    quan_beta = beta - mean * gamma / nd.sqrt(var + self.eps)
    # quan_gamma = nd.clip(nd.floor(nd.log2(quan_gamma)), a_min=-3, a_max=0)
    # quan_gamma = 2**(quan_gamma)
    quan_gamma = quan_gamma * (2 ** last_shift_bit)
    # quan_beta, beta_shift_bit = self.int_quantize(quan_beta)
    quan_gamma, quan_beta, gamma_shift_bit = self.int_quantize_double(
        quan_gamma, quan_beta)

    # Replay the forward pass under autograd to obtain the gradients.
    x.attach_grad(), quan_gamma.attach_grad(), quan_beta.attach_grad()
    with autograd.record():
        y = nd.BatchNorm(x, gamma=quan_gamma, beta=quan_beta,
                         moving_mean=nd.zeros(shape=mean.shape),
                         moving_var=nd.ones(shape=var.shape),
                         eps=self.eps, momentum=self.momentum,
                         fix_gamma=False, name=self.name)
        y, y_shift_bit = self.int_quantize(y)
    dx, dgamma, dbeta = autograd.grad(y, [x, quan_gamma, quan_beta], dy,
                                      retain_graph=True)

    # Rescale the gradients back to the unquantized parameter scale and store
    # the quantized parameters and shift bit for the inference path.
    self.assign(in_grad[0], req[0], dx / 2 ** y_shift_bit)
    self.assign(in_grad[1], req[0], dgamma / 2 ** (gamma_shift_bit + last_shift_bit))
    self.assign(in_grad[2], req[0], dbeta / 2 ** gamma_shift_bit)
    self.assign(in_data[5], req[0], quan_gamma)
    self.assign(in_data[6], req[0], quan_beta)
    self.assign(in_data[7], req[0], y_shift_bit)
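# The helpers self.int_quantize / self.int_quantize_double used above are not
# shown in this snippet. The following is a minimal sketch of one plausible
# implementation, assuming a signed 8-bit target range and power-of-two
# scaling; the *_sketch names are made up and do not come from the original code.
import numpy as np
from mxnet import nd

def int_quantize_sketch(x, num_bits=8):
    """Scale x by 2**shift_bit so it fits in [-2**(b-1), 2**(b-1)-1], then round."""
    max_abs = nd.max(nd.abs(x)).asscalar() + 1e-12
    shift_bit = int(np.floor(np.log2((2 ** (num_bits - 1) - 1) / max_abs)))
    return nd.round(x * (2 ** shift_bit)), shift_bit

def int_quantize_double_sketch(a, b, num_bits=8):
    """Quantize two tensors with a single shared shift bit."""
    max_abs = max(nd.max(nd.abs(a)).asscalar(),
                  nd.max(nd.abs(b)).asscalar()) + 1e-12
    shift_bit = int(np.floor(np.log2((2 ** (num_bits - 1) - 1) / max_abs)))
    return (nd.round(a * (2 ** shift_bit)),
            nd.round(b * (2 ** shift_bit)),
            shift_bit)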
def forward(self, x):
    # return fn.BatchNormFn(self.running_mean.data(x.context),
    #                       self.running_var.data(x.context),
    #                       self.momentum, self.eps)(x, self.bn_weight.data(x.context),
    #                                                self.bn_bias.data(x.context))
    if autograd.is_training():
        return nd.BatchNorm(x,
                            gamma=self.bn_weight.data(x.context),
                            beta=self.bn_bias.data(x.context),
                            moving_mean=self.running_mean.data(x.context),
                            moving_var=self.running_var.data(x.context),
                            eps=self.eps, momentum=self.momentum,
                            use_global_stats=False)
    else:
        return nd.BatchNorm(x,
                            gamma=self.bn_weight.data(x.context),
                            beta=self.bn_bias.data(x.context),
                            moving_mean=self.running_mean.data(x.context),
                            moving_var=self.running_var.data(x.context),
                            eps=self.eps, momentum=self.momentum,
                            use_global_stats=True)
def merge_batchnorm(self, bnorm, ctx=None):
    if not isinstance(bnorm, nn.BatchNorm):
        raise RuntimeError('Cannot merge_batchnorm with type %s.' % type(bnorm))
    kwargs = bnorm._kwargs.copy()
    del kwargs['axis']
    gamma = bnorm.gamma.data(ctx=ctx)
    beta = bnorm.beta.data(ctx=ctx)
    moving_mean = bnorm.running_mean.data(ctx=ctx)
    moving_var = bnorm.running_var.data(ctx=ctx)
    # Scale the weights by gamma / sqrt(var + eps): beta and the mean are
    # zeroed so only the multiplicative part is folded in.
    wmod = nd.BatchNorm(data=self.weight.data(ctx=ctx),
                        gamma=gamma,
                        beta=beta.zeros_like(),
                        moving_mean=moving_mean.zeros_like(),
                        moving_var=moving_var,
                        axis=self._weight_axis, **kwargs)
    self.weight.set_data(wmod)
    if self.bias is not None:
        # Fold the additive part into the bias using the full statistics.
        bmod = nd.BatchNorm(data=self.bias.data(ctx=ctx),
                            gamma=gamma,
                            beta=beta,
                            moving_mean=moving_mean,
                            moving_var=moving_var,
                            axis=self._bias_axis, **kwargs)
        self.bias.set_data(bmod)
    else:
        raise NotImplementedError(
            'Adding bias to previously bias-less linear layers during '
            'BatchNorm-merging is not yet supported.')
    return True
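# merge_batchnorm works because, per output channel i,
#     BN(w_i.x + b_i) = gamma_i * (w_i.x + b_i - mean_i) / sqrt(var_i + eps) + beta_i,
# so the weights only need the scale gamma_i / sqrt(var_i + eps) (beta and mean
# zeroed in the first call above), while the bias absorbs the full mean/beta
# correction. A quick equivalence check, written as a minimal sketch with a
# Dense layer and made-up shapes (not code from this repository):
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

dense = nn.Dense(4, in_units=3, use_bias=True)
bn = nn.BatchNorm(in_channels=4)
dense.initialize()
bn.initialize()
x = nd.random.normal(shape=(2, 3))

w, b = dense.weight.data(), dense.bias.data()
gamma, beta = bn.gamma.data(), bn.beta.data()
mean, var = bn.running_mean.data(), bn.running_var.data()
scale = gamma / nd.sqrt(var + 1e-5)           # 1e-5 is gluon's default epsilon
w_folded = w * scale.reshape((-1, 1))
b_folded = (b - mean) * scale + beta

y_ref = bn(dense(x))                          # inference mode: uses running stats
y_folded = nd.dot(x, w_folded, transpose_b=True) + b_folded
print(nd.max(nd.abs(y_ref - y_folded)))       # should be ~0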
def forward(self, is_train, req, in_data, out_data, aux):
    x = in_data[0]
    gamma = in_data[1]
    beta = in_data[2]
    moving_mean = in_data[3]
    moving_var = in_data[4]
    y = out_data[0]

    if is_train:
        # Batch statistics over the N, H, W axes, plus running-average update.
        mean = nd.mean(x, axis=(0, 2, 3))
        var = nd.array(np.var(x.asnumpy(), axis=(0, 2, 3)))
        moving_mean = moving_mean * self.momentum + mean * (1 - self.momentum)
        moving_var = moving_var * self.momentum + var * (1 - self.momentum)
        self.assign(in_data[3], req[0], moving_mean)
        self.assign(in_data[4], req[0], moving_var)
    else:
        mean = moving_mean
        var = moving_var

    # Fold the normalization into a quantized scale/shift and apply it through
    # BatchNorm with zero mean and unit variance.
    quan_gamma = self.quantize(gamma / nd.sqrt(var + self.eps))
    quan_beta = self.quantize(beta - mean * gamma / nd.sqrt(var + self.eps))
    y = nd.BatchNorm(x, gamma=quan_gamma, beta=quan_beta,
                     moving_mean=nd.zeros(shape=moving_mean.shape),
                     moving_var=nd.ones(shape=moving_var.shape),
                     eps=self.eps, momentum=self.momentum,
                     fix_gamma=self.fix_gamma, name=self.name)
    self.assign(out_data[0], req[0], mx.nd.array(y))
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    dx = in_grad[0]
    dgamma = in_grad[1]
    dbeta = in_grad[2]
    x = in_data[0]
    gamma = in_data[1]
    beta = in_data[2]
    y = out_data[0]
    dy = out_grad[0]

    mean = nd.mean(x, axis=(0, 2, 3))
    var = nd.array(np.var(x.asnumpy(), axis=(0, 2, 3)))
    quan_gamma = gamma
    quan_beta = beta

    # Replay the forward pass under autograd to obtain the gradients.
    x.attach_grad(), gamma.attach_grad(), beta.attach_grad()
    with autograd.record():
        y = nd.BatchNorm(x, gamma=quan_gamma, beta=quan_beta,
                         moving_mean=mean, moving_var=var,
                         eps=self.eps, momentum=self.momentum,
                         fix_gamma=self.fix_gamma, name=self.name)
    dx, dgamma, dbeta = autograd.grad(y, [x, quan_gamma, quan_beta], dy,
                                      retain_graph=True)
    self.assign(in_grad[0], req[0], dx)
    self.assign(in_grad[1], req[0], dgamma)
    self.assign(in_grad[2], req[0], dbeta)
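# The forward/backward pairs above follow the mx.operator.CustomOp interface,
# so each needs a companion CustomOpProp that declares inputs, outputs and
# shapes and is registered under an op_type name. The original Prop class is
# not shown here; the sketch below illustrates the usual wiring for the
# 5-input variant (the class name, op_type string, and QuantBatchNorm are
# made-up stand-ins, not names from this repository):
import mxnet as mx

@mx.operator.register("quant_batchnorm_sketch")
class QuantBatchNormProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(QuantBatchNormProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        return ['data', 'gamma', 'beta', 'moving_mean', 'moving_var']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        channel = (data_shape[1],)
        return [data_shape, channel, channel, channel, channel], [data_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return QuantBatchNorm()   # stand-in for the CustomOp defined above

# Usage would then look like:
#   y = mx.nd.Custom(x, gamma, beta, moving_mean, moving_var,
#                    op_type='quant_batchnorm_sketch')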