def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1,
                residual_block=None, dropout=None, mixup=False,
                BN_momentum=0.1, conv_options=None, fixup=False):
    downsample = None
    out_planes = planes * expansion
    if stride != 1 or self.inplanes != out_planes:
        if fixup:
            downsample = conv1x1(self.inplanes, out_planes, stride,
                                 conv_options=conv_options)
        else:
            downsample = nn.Sequential(
                # nn.Conv2d(self.inplanes, out_planes,
                #           kernel_size=1, stride=stride, bias=False),
                conv1x1(self.inplanes, out_planes, stride, conv_options=conv_options),
                nn.BatchNorm2d(planes * expansion, BN_momentum),
            )
    if residual_block is not None:
        residual_block = residual_block(out_planes)

    layers = []
    layers.append(block(self.inplanes, planes, stride, expansion=expansion,
                        downsample=downsample, groups=groups,
                        residual_block=residual_block, dropout=dropout,
                        BN_momentum=BN_momentum, conv_options=conv_options))
    self.inplanes = planes * expansion
    for i in range(1, blocks):
        layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups,
                            residual_block=residual_block, dropout=dropout,
                            BN_momentum=BN_momentum, conv_options=conv_options))
    if mixup:
        layers.append(MixUp())
    return nn.Sequential(*layers)
def create_mixup_or_none(alpha, num_classes, comm):
    from utils.mixup import MixUp

    # Create different random generators over workers.
    rng = np.random.RandomState(726 + comm.rank)
    if alpha > 0:
        return MixUp(alpha, num_classes, rng)
    return None
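# Hedged sketch (not from the source): a stand-alone illustration of the per-worker
# seeding idea used by create_mixup_or_none above. `rank` is a hypothetical stand-in
# for comm.rank; the beta parameters are illustrative only.
import numpy as np

def make_worker_rng(rank, base_seed=726):
    # Each worker offsets the base seed by its rank so workers draw different mixup lambdas.
    return np.random.RandomState(base_seed + rank)

rngs = [make_worker_rng(r) for r in range(4)]
print([rng.beta(0.2, 0.2) for rng in rngs])  # four different draws, one per worker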
def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1,
                residual_block=None, dropout=None, mixup=False,
                dp_type=None, dp_percentage=None, dev='cpu'):
    downsample = None
    out_planes = planes * expansion
    if stride != 1 or self.inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * expansion),
        )
    if residual_block is not None:
        residual_block = residual_block(out_planes)

    layers = []
    layers.append(
        block(self.inplanes, planes, stride, expansion=expansion,
              downsample=downsample, groups=groups, residual_block=residual_block,
              dropout=dropout, dropout_type=dp_type, drop_percentage=dp_percentage,
              dev=dev))
    self.inplanes = planes * expansion
    for i in range(1, blocks):
        layers.append(
            block(self.inplanes, planes, expansion=expansion, groups=groups,
                  residual_block=residual_block, dropout=dropout,
                  dropout_type=dp_type, drop_percentage=dp_percentage, dev=dev))
    if mixup:
        layers.append(MixUp())
    return nn.Sequential(*layers)
def _step(self, inputs_batch, target_batch, training=False, chunk_batch=1):
    outputs = []
    total_loss = 0

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for inputs, target in zip(inputs_batch.chunk(chunk_batch, dim=0),
                              target_batch.chunk(chunk_batch, dim=0)):
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None:
                input_mixup = MixUp()
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, self.mixup, inputs.size(0))
                inputs = input_mixup(inputs)

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())

        if training:
            self.optimizer.pre_backward()
            if self.grad_scale is not None:
                loss = loss * self.grad_scale
            loss.backward()  # accumulate gradient

        total_loss += float(loss)

    if training:  # post gradient accumulation
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()  # SGD step
        self.training_steps += 1

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad
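# Hedged sketch (not from the source): a minimal, self-contained illustration of the
# chunk_batch gradient-accumulation pattern used by _step above. The model, data and
# chunk count are made up for the example.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs_batch = torch.randn(8, 10)
target_batch = torch.randint(0, 2, (8,))
chunk_batch = 4

optimizer.zero_grad()
for inputs, target in zip(inputs_batch.chunk(chunk_batch, dim=0),
                          target_batch.chunk(chunk_batch, dim=0)):
    loss = criterion(model(inputs), target) / chunk_batch  # scale so gradients average
    loss.backward()                                        # accumulate gradients over chunks
optimizer.step()                                           # single SGD step per macro-batch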
def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1,
                residual_block=None, dropout=None, mixup=False):
    CBN = get_bn_folding_module(nn.Conv2d, nn.BatchNorm2d)
    downsample = None
    out_planes = planes * expansion
    if stride != 1 or self.inplanes != out_planes:
        downsample = nn.Sequential(
            CBN(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=True))
    if residual_block is not None:
        residual_block = residual_block(out_planes)

    layers = []
    layers.append(block(self.inplanes, planes, stride, expansion=expansion,
                        downsample=downsample, groups=groups,
                        residual_block=residual_block, dropout=dropout))
    self.inplanes = planes * expansion
    for i in range(1, blocks):
        layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups,
                            residual_block=residual_block, dropout=dropout))
    if mixup:
        layers.append(MixUp())
    return nn.Sequential(*layers)
def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1,
                residual_block=None, dropout=None, mixup=False, absorb_bn=False):
    downsample = None
    out_planes = planes * expansion
    if stride != 1 or self.inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes,
                      kernel_size=1, stride=stride, bias=absorb_bn),
        )
        if not absorb_bn:
            downsample.add_module('1', nn.BatchNorm2d(planes * expansion))
    if residual_block is not None:
        residual_block = residual_block(out_planes)

    layers = []
    layers.append(block(self.inplanes, planes, stride, expansion=expansion,
                        downsample=downsample, groups=groups,
                        residual_block=residual_block, dropout=dropout,
                        absorb_bn=absorb_bn))
    self.inplanes = planes * expansion
    for i in range(1, blocks):
        layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups,
                            residual_block=residual_block, dropout=dropout,
                            absorb_bn=absorb_bn))
    if mixup:
        layers.append(MixUp())
    return nn.Sequential(*layers)
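# Hedged sketch (not from the source): a minimal, self-contained version of the
# downsample shortcut that the _make_layer variants above build whenever the stride
# or the channel count changes. The shapes are made up for illustration.
import torch
import torch.nn as nn

inplanes, out_planes, stride = 64, 128, 2
downsample = nn.Sequential(
    nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
    nn.BatchNorm2d(out_planes),
)
x = torch.randn(1, inplanes, 56, 56)
print(downsample(x).shape)  # torch.Size([1, 128, 28, 28]) -- matches the residual branch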
def _step(self, inputs_batch, target_batch, training=False, average_output=False,
          chunk_batch=1, scheduled_instructions=None):
    outputs = []
    total_loss = 0
    # self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) else False
    # if scheduled_instructions is None:
    #     self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) else False
    # else:
    #     self.input_grad_statistics = scheduled_instructions['collect_stat']
    meters = {name: {'mean': AverageMeter(), 'std': AverageMeter()}
              for name in self.module_to_hook.keys()}
    grad_log_stats = {}

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                             target_batch.chunk(chunk_batch, dim=0))):
        self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) and i == 0 else False
        if training:
            if self.input_grad_statistics:
                BK_hookF = {}
                for name, module in self.module_to_hook.items():
                    # print("name: " + str(name))
                    BK_hookF[name] = Hook(module, name, True)
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None or self.cutmix is not None:
                input_mixup = CutMix() if self.cutmix else MixUp()
                mix_val = self.mixup or self.cutmix
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                inputs = input_mixup(inputs)

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        if average_output:
            if isinstance(output, list) or isinstance(output, tuple):
                output = [_average_duplicates(out, target) if out is not None else None
                          for out in output]
            else:
                output = _average_duplicates(output, target)
        loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())
        total_loss += float(loss)

        if training:
            if i == 0:
                self.optimizer.pre_backward()
            if self.grad_scale is not None:
                loss = loss * self.grad_scale
            if self.loss_scale is not None:
                pass  # moran
                # pdb.set_trace()
                # loss = loss * self.loss_scale
            loss.backward()  # accumulate gradient

            ## moran - gather lognorm statistics
            if self.input_grad_statistics:
                # pdb.set_trace()
                # total = reduce((lambda total, item: total + item.alloc), snds)
                for hook in BK_hookF.values():
                    meters[hook.name]['mean'].update(float(hook.grad_log_mean), inputs.size(0))
                    meters[hook.name]['std'].update(float(hook.grad_log_std), inputs.size(0))
                # curr_grad_log_stats = {hook.name: {'mean': hook.grad_log_mean, 'std': hook.grad_log_std} for hook in BK_hookF.values()}
                # pdb.set_trace()
            if self.input_grad_statistics:
                for hook in BK_hookF.values():
                    hook.close()
            ## moran - end
            grad_log_stats = {name: {'mean': met['mean'].avg, 'std': met['std'].avg}
                              for name, met in meters.items()}
            # print("grad_log_stats! loop 1")

    if training:  # post gradient accumulation
        if self.loss_scale is not None:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                # p.grad.data.div_(self.loss_scale)  # moran
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        # dont_update_grad = False
        # for param in self.model.parameters():
        #     param.grad[torch.isnan(param.grad)] = 0
        #     if torch.isnan(param.grad).any():  # moran
        #         # pdb.set_trace()
        #         dont_update_grad = True
        #         break
        # if dont_update_grad:
        #     pass
        # else:
        #     self.optimizer.step()  # SGD step
        self.optimizer.step()  # SGD step
        self.training_steps += 1

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad, grad_log_stats
def _step(self, inputs_batch, target_batch, training=False, average_output=False,
          chunk_batch=1, scheduled_instructions=None, iter=0):
    outputs = []
    total_loss = 0
    grad_log_stats = {}
    # self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) else False
    # if scheduled_instructions is None:
    #     self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) else False
    # else:
    #     self.input_grad_statistics = scheduled_instructions['collect_stat']
    meters = {name: {'mean': AverageMeter(), 'std': AverageMeter()}
              for name in self.module_to_hook.keys()}
    grad_log_stats = {}
    # pdb.set_trace()
    if self.epoch == 0:
        for name, module in self.module_to_hook.items():
            if hasattr(module, 'enable_grad_quantizer'):
                module.enable_grad_quantizer = False  # this is the default start behavior
                module.fp_x.fill_(self.fp_bits)

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                             target_batch.chunk(chunk_batch, dim=0))):
        # pdb.set_trace()
        # self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) and i==0 else False
        if training:
            BK_hookF = {}
            # if (self.epoch != 0) and (iter == 0):
            if (self.epoch != 0) and (iter in self.iters_in_fp32):
                # print("RULE 0")
                # pdb.set_trace()
                for name, module in self.module_to_hook.items():
                    if hasattr(module, 'enable_grad_quantizer'):
                        if iter == self.iters_in_fp32[0]:
                            module.mu.reset()
                        # if self.epoch == 1:
                        module.enable_grad_quantizer = False  # this is the default start behavior
                        # else:
                        #     module.enable_grad_quantizer = True
                        BK_hookF[name] = Hook(module, name, True, 'collect and quantize')
                    else:
                        BK_hookF[name] = Hook(module, name, True, 'collect only')
            # elif (self.epoch != 0) and (iter == 1):
            elif (self.epoch != 0) and (iter in self.iters_wram_up_with_q):
                # print("RULE 1")
                # pdb.set_trace()
                # self.highest_exp_bits = 0
                for name, module in self.module_to_hook.items():
                    if hasattr(module, 'enable_grad_quantizer'):
                        module.enable_grad_quantizer = True  # this is the default start behavior
                        # if self.epoch in [1,2,3,4,9,10,20,30,40,50,60,70,78,79,80,81,82,83,84,85,86,90,100,110,120]:
                        # if self.epoch in [1, 2, 3, 4, 30, 31, 32, 33, 60, 61, 62, 63, 90, 91, 92, 93]:
                        # if self.epoch in [1, 2, 3, 4, 80, 81, 82, 83]:  # 30, 31, 32, 60, 61, 62, 90, 91, 92]:
                        # if self.epoch in [1, 2, 3, 4, 30, 31, 32, 60, 61, 62, 80, 81, 82]:
                        # module.loss_scale.fill_(2 ** (-module.mu - 3))  # update loss scale
                        module.loss_scale.fill_(2 ** (-module.mu.avg - 3))  # update loss scale
                        # self.highest_exp_bits = module.fp8_grad_args['exp_width'] if module.fp8_grad_args['exp_width'] > self.highest_exp_bits else self.highest_exp_bits
                        # self.highest_exp_bits = 4
                        # module.enable_grad_quantizer = False  # this is the default start behavior
                        # print("name: " + str(name))
                        BK_hookF[name] = Hook(module, name, True, 'collect and quantize')
                    else:
                        BK_hookF[name] = Hook(module, name, True, 'collect only')
                for name, module in self.module_to_hook.items():
                    if hasattr(module, 'enable_grad_quantizer'):
                        if iter == self.iters_wram_up_with_q[0]:
                            module.mu.reset()
                        module.fp8_grad_args = dict(
                            exp_width=self.highest_exp_bits,
                            man_width=int(module.fp_x - 1 - self.highest_exp_bits),
                            exp_bias=(2 ** (self.highest_exp_bits - 1)) - 1,
                            roundingMode=0,
                            lfsrVal=0)
            # elif (self.epoch != 0) and (iter > 1):
            elif (self.epoch != 0) and (iter > self.iters_wram_up_with_q[-1]):
                # print("RULE 2")
                # pdb.set_trace()
                # self.highest_exp_bits = 0
                for name, module in self.module_to_hook.items():
                    if hasattr(module, 'enable_grad_quantizer'):
                        # module.enable_grad_quantizer = False  # this is the default start behavior
                        module.enable_grad_quantizer = True  # this is the default start behavior
                        # if self.epoch in [1, 2, 3, 4, 9, 10, 20, 30, 40, 50, 60, 70, 78, 79, 80, 81, 82, 83, 84, 85, 86, 90, 100, 110, 120]:
                        # if self.epoch in [1, 2, 3, 4, 79, 80, 81, 82, 83, 84]:
                        # if self.epoch in [1, 2, 3, 4, 30, 31, 32, 33, 60, 61, 62, 63, 90, 91, 92, 93]:
                        # if self.epoch in [1, 30, 31, 32, 60, 61, 62, 90, 91, 92]:
                        # self.highest_exp_bits = module.fp8_grad_args['exp_width'] if module.fp8_grad_args['exp_width'] > self.highest_exp_bits else self.highest_exp_bits
                        # self.highest_exp_bits = 4
                        # if self.epoch in [1, 2, 3, 4, 80, 81, 82, 83]:  # 30, 31, 32, 60, 61, 62, 90, 91, 92]:
                        # if self.epoch in [1, 2, 3, 4, 30, 31, 32, 60, 61, 62, 80, 81, 82]:
                        # module.loss_scale.fill_(2 ** (-module.mu - 3))  # update loss scale
                        module.loss_scale.fill_(2 ** (-module.mu.avg - 3))  # update loss scale
                for name, module in self.module_to_hook.items():
                    if hasattr(module, 'enable_grad_quantizer'):
                        module.fp8_grad_args = dict(
                            exp_width=self.highest_exp_bits,
                            man_width=int(module.fp_x - 1 - self.highest_exp_bits),
                            exp_bias=(2 ** (self.highest_exp_bits - 1)) - 1,
                            roundingMode=0,
                            lfsrVal=0)
            elif self.epoch == 0:
                # print("RULE 3")
                for name, module in self.module_to_hook.items():
                    BK_hookF[name] = Hook(module, name, True, 'collect only')

        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None or self.cutmix is not None:
                input_mixup = CutMix() if self.cutmix else MixUp()
                mix_val = self.mixup or self.cutmix
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                inputs = input_mixup(inputs)

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        if average_output:
            if isinstance(output, list) or isinstance(output, tuple):
                output = [_average_duplicates(out, target) if out is not None else None
                          for out in output]
            else:
                output = _average_duplicates(output, target)
        loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())
        total_loss += float(loss)

        if training:
            if i == 0:
                self.optimizer.pre_backward()
            if self.grad_scale is not None:
                pass
                # loss = loss * self.grad_scale
            if self.loss_scale is not None:
                pass  # moran
                # pdb.set_trace()
                # loss = loss * self.loss_scale
            loss.backward()  # accumulate gradient

            ## moran - gather lognorm statistics
            if BK_hookF:
                # pdb.set_trace()
                # total = reduce((lambda total, item: total + item.alloc), snds)
                for hook in BK_hookF.values():
                    meters[hook.name]['mean'].update(float(hook.grad_log_mean), inputs.size(0))
                    meters[hook.name]['std'].update(float(hook.grad_log_std), inputs.size(0))
                # curr_grad_log_stats = {hook.name: {'mean': hook.grad_log_mean, 'std': hook.grad_log_std} for hook in BK_hookF.values()}
                # pdb.set_trace()
                grad_log_stats = {name: {'mean': met['mean'].avg, 'std': met['std'].avg}
                                  for name, met in meters.items()}
            if BK_hookF:
                for hook in BK_hookF.values():
                    hook.close()
            ## moran - end
            # print("grad_log_stats! loop 1")

    if training:  # post gradient accumulation
        if self.loss_scale is not None:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                # p.grad.data.div_(self.loss_scale)  # moran
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        # dont_update_grad = False
        # for param in self.model.parameters():
        #     param.grad[torch.isnan(param.grad)] = 0
        #     if torch.isnan(param.grad).any():  # moran
        #         # pdb.set_trace()
        #         dont_update_grad = True
        #         break
        # if dont_update_grad:
        #     pass
        # else:
        #     self.optimizer.step()  # SGD step
        if (self.epoch != 0) and (iter == 0):
            pass
        else:
            self.optimizer.step()  # SGD step
        self.training_steps += 1

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad, grad_log_stats
def _step(self, inputs_batch, target_batch, training=False, average_output=False, chunk_batch=1):
    outputs = []
    total_loss = 0

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                             target_batch.chunk(chunk_batch, dim=0))):
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None or self.cutmix is not None:
                input_mixup = CutMix() if self.cutmix else MixUp()
                mix_val = self.mixup or self.cutmix
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                inputs = input_mixup(inputs)

            if self.monitor > 3:
                for mname, m in self.model.named_modules():
                    if isinstance(m, nn.BatchNorm2d):
                        if self.counter % self.monitor_freq == 0:
                            if self.monitor < 5:
                                self.writer.add_histogram(f"BN_Weights/{mname}",
                                                          m.weight.view(-1), self.counter)
                                self.writer.add_histogram(f"BN_Biases/{mname}",
                                                          m.bias.view(-1), self.counter)
                            self.writer.add_scalar(f"BN_weight_L2/{mname}",
                                                   (m.weight.data ** 2).sum(), self.counter)
                            self.writer.add_scalar(f"BN_bias_L2/{mname}",
                                                   (m.bias.data ** 2).sum(), self.counter)
                            if self.monitor > 5:
                                self.writer.add_scalar(f"BN_Running_Var/{mname}",
                                                       (m.running_var ** 2).sum(), self.counter)
                                self.writer.add_scalar(f"BN_Running_Mean/{mname}",
                                                       (m.running_mean ** 2).sum(), self.counter)
                    elif isinstance(m, AConv2d):
                        self.writer.add_scalar(f"Conv_weight_L2/{mname}",
                                               (m.weight.data ** 2).sum(), self.counter)
                        ## self.writer.add_scalar(f"Conv_bias_L2/{mname}", (m.bias.data ** 2).sum(), self.counter)
                self.counter = self.counter + 1
            elif self.monitor > 2:
                if self.counter % self.monitor_freq == 0:
                    for mname, m in self.model.named_modules():
                        if isinstance(m, nn.BatchNorm2d):
                            if self.monitor < 5:
                                self.writer.add_histogram(f"BN_Weights/{mname}",
                                                          m.weight.data.view(-1), self.counter)
                                ## self.writer.add_histogram(f"BN_Biases/{mname}", m.bias.data.view(-1), self.counter)
                self.counter = self.counter + 1

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        if average_output:
            if isinstance(output, list) or isinstance(output, tuple):
                output = [_average_duplicates(out, target) if out is not None else None
                          for out in output]
            else:
                output = _average_duplicates(output, target)
        loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())
        total_loss += float(loss)

        if training:
            if i == 0:
                self.optimizer.pre_backward()
            if self.grad_scale is not None:
                loss = loss * self.grad_scale
            if self.loss_scale is not None:
                loss = loss * self.loss_scale
            loss.backward()  # accumulate gradient

    if training:  # post gradient accumulation
        if self.loss_scale is not None:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                p.grad.data.div_(self.loss_scale)
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()  # SGD step
        self.training_steps += 1

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad
def _step(self, inputs_batch, target_batch, training=False, average_output=False, chunk_batch=1):
    outputs = []
    total_loss = 0

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                             target_batch.chunk(chunk_batch, dim=0))):
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None:
                input_mixup = MixUp()
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, self.mixup, inputs.size(0))
                inputs = input_mixup(inputs)

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        if average_output:
            if isinstance(output, list) or isinstance(output, tuple):
                output = [_average_duplicates(out, target) if out is not None else None
                          for out in output]
            else:
                output = _average_duplicates(output, target)

        if isinstance(self.criterion, nn.KLDivLoss):
            emb = torch.zeros(output.shape)
            for t in range(target.shape[0]):
                emb[t] = self.output_embed_fp32[target[t].tolist()]
            loss = self.criterion(F.log_softmax(output), F.softmax(emb.to(output)))
        else:
            loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())
        total_loss += float(loss)

        if training:
            if i == 0:
                self.optimizer.pre_backward()
            if self.grad_scale is not None:
                loss = loss * self.grad_scale
            if self.loss_scale is not None:
                loss = loss * self.loss_scale
            loss.backward()  # accumulate gradient
            if self.update_only_th and not self.optimize_rounding:
                for p in self.model.parameters():
                    if p.shape[0] == 1000 or p.dim() == 2:
                        p.grad = None

    if training:  # post gradient accumulation
        if self.loss_scale is not None:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                p.grad.data.div_(self.loss_scale)
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()  # SGD step
        self.training_steps += 1

        if self.optimize_rounding:
            sd = self.model.state_dict()
            for key in sd:
                if 'quantize_weight' in key and 'range' in key:
                    trange = sd[key]
                    tzp = sd[key.replace('range', 'zero_point')]
                    weights_name = key.replace('quantize_weight.running_range', 'weight')
                    # import pdb; pdb.set_trace()
                    t1 = self.fp_state_dict[weights_name.replace('module.', '')]
                    t2 = sd[weights_name]
                    new_weight = quant_round_constrain(t1, t2, trange, tzp)
                    sd[weights_name] = new_weight

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad
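# Hedged sketch (not from the source): the KLDivLoss pairing used in the distillation
# branch above, shown on made-up tensors. The student output goes through log_softmax
# and the (embedding-based) teacher target through softmax, as in the snippet.
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
loss = F.kl_div(F.log_softmax(student_logits, dim=-1),
                F.softmax(teacher_logits, dim=-1),
                reduction='batchmean')
print(float(loss))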
def _step(self, inputs_batch, target_batch, training=False, average_output=False, chunk_batch=1):
    outputs = []
    total_loss = 0

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                             target_batch.chunk(chunk_batch, dim=0))):
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None or self.cutmix is not None:
                input_mixup = CutMix() if self.cutmix else MixUp()
                mix_val = self.mixup or self.cutmix
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                inputs = input_mixup(inputs)

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        if average_output:
            if isinstance(output, list) or isinstance(output, tuple):
                output = [_average_duplicates(out, target) if out is not None else None
                          for out in output]
            else:
                output = _average_duplicates(output, target)
        loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())
        total_loss += float(loss)

        if training:
            if i == 0:
                self.optimizer.pre_backward()
            if self.grad_scale is not None:
                loss = loss * self.grad_scale
            if self.loss_scale is not None:
                loss = loss * self.loss_scale
            loss.backward()  # accumulate gradient

    if training:  # post gradient accumulation
        if self.loss_scale is not None:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                p.grad.data.div_(self.loss_scale)
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()  # SGD step
        self.training_steps += 1

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad
def _step(self, inputs_batch, target_batch, training=False, average_output=False, chunk_batch=1):
    outputs = []
    total_loss = 0

    if training:
        self.optimizer.zero_grad()
        self.optimizer.update(self.epoch, self.training_steps)

    for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                             target_batch.chunk(chunk_batch, dim=0))):
        BK_hookF = {}
        for name, module in self.module_to_hook.items():
            # print("name: " + str(name))
            BK_hookF[name] = Hook(module, name, True)
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)
        mixup = None
        if training:
            self.optimizer.pre_forward()
            if self.mixup is not None or self.cutmix is not None:
                input_mixup = CutMix() if self.cutmix else MixUp()
                mix_val = self.mixup or self.cutmix
                mixup_modules = [input_mixup]  # input mixup
                mixup_modules += [m for m in self.model.modules()
                                  if isinstance(m, MixUp)]
                mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                inputs = input_mixup(inputs)

        # compute output
        output = self.model(inputs)

        if mixup is not None:
            target = mixup.mix_target(target, output.size(-1))
        if average_output:
            if isinstance(output, list) or isinstance(output, tuple):
                output = [_average_duplicates(out, target) if out is not None else None
                          for out in output]
            else:
                output = _average_duplicates(output, target)
        loss = self.criterion(output, target)
        grad = None

        if chunk_batch > 1:
            loss = loss / chunk_batch

        if isinstance(output, list) or isinstance(output, tuple):
            output = output[0]

        outputs.append(output.detach())
        total_loss += float(loss)

        if training:
            if i == 0:
                self.optimizer.pre_backward()
            if self.grad_scale is not None:
                loss = loss * self.grad_scale
            if self.loss_scale is not None:
                loss = loss * self.loss_scale
            loss.backward()  # accumulate gradient

        ## moran - gather lognorm statistics
        grad_log_stats = []
        if self.input_grad_statistics:
            # total = reduce((lambda total, item: total + item.alloc), snds)
            curr_grad_log_stats = [{'layer_name': hook.name,
                                    'mean': hook.grad_log_mean,
                                    'std': hook.grad_log_std}
                                   for hook in BK_hookF.values()]
            # pdb.set_trace()  # debug breakpoint left disabled
        for hook in BK_hookF.values():
            hook.close()
        ## moran - end

    if training:  # post gradient accumulation
        if self.loss_scale is not None:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                p.grad.data.div_(self.loss_scale)
        if self.grad_clip > 0:
            grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()  # SGD step
        self.training_steps += 1

    outputs = torch.cat(outputs, dim=0)
    return outputs, total_loss, grad, grad_log_stats
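# Hedged sketch (not from the source): a minimal stand-in for the Hook class used above,
# collecting log-magnitude statistics of a module's output gradient via a backward hook.
# The names (SimpleHook, grad_log_mean, grad_log_std) mirror the snippet, but the
# implementation here is only an assumption about what Hook actually does.
import torch
import torch.nn as nn

class SimpleHook:
    def __init__(self, module, name):
        self.name = name
        self.grad_log_mean = None
        self.grad_log_std = None
        self.handle = module.register_full_backward_hook(self._hook)

    def _hook(self, module, grad_input, grad_output):
        # log2 of gradient magnitudes, with a small epsilon to avoid log(0)
        logs = torch.log2(grad_output[0].abs() + 1e-12)
        self.grad_log_mean = logs.mean().item()
        self.grad_log_std = logs.std().item()

    def close(self):
        self.handle.remove()

layer = nn.Linear(10, 10)
hook = SimpleHook(layer, 'fc')
out = layer(torch.randn(4, 10))
(out ** 2).sum().backward()
print(hook.name, hook.grad_log_mean, hook.grad_log_std)
hook.close()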