def binary_search(model, target):
    """
    Binary search of the pruning threshold so that the FLOPs ratio of the pruned model matches the target.
    :param model: model wrapper whose pruning threshold (model.pt) is searched
    :param target: target FLOPs compression ratio
    :return:
    """
    step = 0.01
    status = 1.0
    stop = 0.001
    counter = 1
    max_iter = 100
    flops = get_flops(model)
    params = get_parameters(model)

    while abs(status - target) > stop and counter <= max_iter:
        status_old = status
        # calculate flops and status
        model.set_parameters()
        flops_prune = get_flops(model)
        status = flops_prune / flops
        params_prune = get_parameters(model)
        params_compression_ratio = params_prune / params

        string = 'Iter {:<3}: current step={:1.8f}, current threshold={:2.8f}, status (FLOPs ratio) = {:2.4f}, ' \
                 'params ratio = {:2.4f}.\n' \
            .format(counter, step, model.pt, status, params_compression_ratio)
        print(string)

        if abs(status - target) > stop:
            # halve the step once the search starts oscillating around the target
            flag = False if counter == 1 else (status_old >= target) == (status < target)
            if flag:
                step /= 2
            # calculate the next threshold
            if status > target:
                model.pt += step
            elif status < target:
                model.pt -= step
            model.pt = max(model.pt, 0)
            counter += 1
            # deal with an unexpected status
            if model.pt < 0 or status <= 0:
                print('Status {} or threshold {} is out of range'.format(status, model.pt))
                break
        else:
            print('The target compression ratio is achieved. The loop is stopped.')
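# ------------------------------------------------------------------------------------------------------------------
# Illustrative only, not part of the repo: a self-contained toy that mirrors the step-halving search used by
# binary_search() above. The hypothetical callable `ratio(th)` stands in for the FLOPs ratio measured after pruning
# with threshold `th` (i.e. get_flops(model) / flops after model.set_parameters() in the real code).
# ------------------------------------------------------------------------------------------------------------------
def _toy_threshold_search(ratio, target, step=0.01, stop=0.001, max_iter=100):
    th, status, counter = 0.0, 1.0, 1
    while abs(status - target) > stop and counter <= max_iter:
        status_old, status = status, ratio(th)
        if abs(status - target) <= stop:
            break
        # halve the step once the search starts oscillating around the target
        if counter > 1 and (status_old >= target) == (status < target):
            step /= 2
        th = max(th + step if status > target else th - step, 0)
        counter += 1
    return th


# Example (hypothetical, monotone FLOPs-ratio model): _toy_threshold_search(lambda th: max(1 - 5 * th, 0), 0.5) ~ 0.1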
def reset_after_optimization(self):
    # During the reloading and testing phase, the searched sparse model is already loaded at initialization.
    # During the training phase, the sparse model from the searching stage is still in memory and is reset here.
    if not self.converging and not self.args.test_only:
        self.model.get_model().reset_after_searching()
        self.converging = True
        self.optimizer = utility.make_optimizer_dhp(self.args, self.model, converging=self.converging)
        self.scheduler = utility.make_scheduler_dhp(
            self.args, self.optimizer,
            int(self.args.lr_decay_step.split('+')[1]),
            converging=self.converging)
    print(self.model.get_model(), file=self.ckp.log_file)

    # calculate flops and number of parameters
    self.flops_prune = get_flops(self.model.get_model())
    self.flops_compression_ratio = self.flops_prune / self.flops
    self.params_prune = get_parameters(self.model.get_model())
    self.params_compression_ratio = self.params_prune / self.params

    # reset the tensorboardX summary writer
    if not self.args.test_only and self.args.summary:
        self.writer = SummaryWriter(os.path.join(self.args.dir_save, self.args.save), comment='converging')

    # get the number of searching epochs
    if os.path.exists(os.path.join(self.ckp.dir, 'epochs.pt')):
        self.epochs_searching = torch.load(os.path.join(self.ckp.dir, 'epochs.pt'))
def reset_after_searching(self):
    # In Phases 1 & 3, the model is reset here.
    # In Phases 2 & 4, the model is reset at initialization.
    # In Phases 1 & 3, the optimizer and scheduler are reset.
    # In Phase 2, the optimizer and scheduler are not used.
    # In Phase 4, the optimizer and scheduler are already set during the initialization of the trainer.
    # During the converging stage, self.converging = True, so there is no need to set lr_adjust_flag in
    # make_optimizer_hinge and make_scheduler_hinge.
    if not self.converging and not self.args.test_only:
        self.model.get_model().reset_after_searching()
        self.converging = True
        del self.optimizer, self.scheduler
        torch.cuda.empty_cache()
        decay = self.args.decay if len(self.args.decay.split('+')) == 1 else self.args.decay.split('+')[1]
        self.optimizer = utility.make_optimizer_dhp(self.args, self.model, converging=self.converging)
        self.scheduler = utility.make_scheduler_dhp(self.args, self.optimizer, decay, converging=self.converging)

    self.flops_prune = get_flops(self.model.get_model())
    self.flops_compression_ratio = self.flops_prune / self.flops
    self.params_prune = get_parameters(self.model.get_model())
    self.params_compression_ratio = self.params_prune / self.params

    if not self.args.test_only and self.args.summary:
        self.writer = SummaryWriter(os.path.join(self.args.dir_save, self.args.save), comment='converging')

    if os.path.exists(os.path.join(self.ckp.dir, 'epochs.pt')):
        self.epochs_searching = torch.load(os.path.join(self.ckp.dir, 'epochs.pt'))
def __init__(self, args, loader, my_model, my_loss, ckp=None, writer=None, converging=False, model_teacher=None):
    self.args = args
    self.ckp = ckp
    self.loader_train = loader.loader_train
    self.loader_test = loader.loader_test
    self.model = my_model
    self.model_teacher = model_teacher
    self.loss = my_loss
    self.writer = writer
    self.loss_mse = nn.MSELoss() if self.args.distillation_inter == 'kd' else None

    if args.data_train.find('CIFAR') >= 0:
        self.input_dim = (3, 32, 32)
    elif args.data_train.find('Tiny_ImageNet') >= 0:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)

    set_output_dimension(self.model.get_model(), self.input_dim)
    self.flops = get_flops(self.model.get_model())
    self.flops_prune = self.flops  # at initialization, no pruning is conducted
    self.flops_compression_ratio = self.flops_prune / self.flops
    self.params = get_parameters(self.model.get_model())
    self.params_prune = self.params
    self.params_compression_ratio = self.params_prune / self.params
    self.flops_ratio_log = []
    self.params_ratio_log = []
    self.converging = converging

    self.ckp.write_log(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams {:.2f} [k]'.format(self.flops / 10. ** 9, self.params / 10. ** 3))
    self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
    self.ckp.write_log(
        'Flops: {:.4f} [G] calculated by the original counter. \nMake sure that the two calculated '
        'Flops are the same.\n'.format(self.flops_another / 10. ** 9))

    self.optimizer = utility.make_optimizer_dhp(args, self.model, ckp=ckp, converging=converging)
    self.scheduler = utility.make_scheduler_dhp(args, self.optimizer, args.decay.split('+')[0],
                                                converging=converging)
    self.device = torch.device('cpu' if args.cpu else 'cuda')

    if args.model.find('INQ') >= 0:
        self.inq_steps = args.inq_steps
    else:
        self.inq_steps = None
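# ------------------------------------------------------------------------------------------------------------------
# Illustrative helper, not part of the repo: the log above asks the user to make sure that the two FLOPs counters
# (get_flops and get_model_flops) agree. Assuming both return raw FLOPs as plain numbers, a hypothetical automated
# check could look like this:
# ------------------------------------------------------------------------------------------------------------------
def _check_flops_consistency(flops_a, flops_b, tol=1e-3):
    # relative disagreement between the two counters, guarded against division by zero
    rel_err = abs(flops_a - flops_b) / max(abs(flops_a), 1)
    assert rel_err <= tol, 'The two FLOPs counters disagree by {:.2%}'.format(rel_err)


# Example (hypothetical): _check_flops_consistency(self.flops, self.flops_another) right after the logging above.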
def __init__(self, args, loader, my_model, my_loss, ckp, writer=None, converging=False):
    self.args = args
    self.scale = args.scale
    self.ckp = ckp
    self.loader_train = loader.loader_train
    self.loader_test = loader.loader_test
    self.model = my_model
    self.loss = my_loss
    self.writer = writer

    self.optimizer = utility.make_optimizer_dhp(args, self.model, ckp, converging=converging)
    self.scheduler = utility.make_scheduler_dhp(
        args, self.optimizer,
        int(args.lr_decay_step.split('+')[0]),
        converging=converging)

    if self.args.model.lower().find('unet') >= 0 or self.args.model.lower().find('dncnn') >= 0:
        self.input_dim = (1, args.input_dim, args.input_dim)
    else:
        self.input_dim = (3, args.input_dim, args.input_dim)

    set_output_dimension(self.model.get_model(), self.input_dim)
    self.flops = get_flops(self.model.get_model())
    self.flops_prune = self.flops  # at initialization, no pruning is conducted
    self.flops_compression_ratio = self.flops_prune / self.flops
    self.params = get_parameters(self.model.get_model())
    self.params_prune = self.params
    self.params_compression_ratio = self.params_prune / self.params
    self.flops_ratio_log = []
    self.params_ratio_log = []
    self.converging = converging

    self.ckp.write_log(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams {:.2f} [k]'.format(self.flops / 10. ** 9, self.params / 10. ** 3))
    self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
    self.ckp.write_log(
        'Flops: {:.4f} [G] calculated by the original counter. \nMake sure that the two calculated '
        'Flops are the same.\n'.format(self.flops_another / 10. ** 9))

    self.error_last = 1e8
def __init__(self, args, loader, my_model, my_loss, ckp, writer=None):
    self.args = args
    self.ckp = ckp
    self.loader_train = loader.loader_train
    self.loader_test = loader.loader_test
    self.model = my_model
    self.loss = my_loss
    self.writer = writer

    if args.data_train.find('CIFAR') >= 0:
        self.input_dim = (3, 32, 32)
    elif args.data_train.find('Tiny_ImageNet') >= 0:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)

    set_output_dimension(self.model.get_model(), self.input_dim)
    self.flops = get_flops(self.model.get_model())
    self.params = get_parameters(self.model.get_model())
    self.ckp.write_log(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams {:.2f} [k]'.format(self.flops / 10. ** 9, self.params / 10. ** 3))
    self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
    self.ckp.write_log(
        'Flops: {:.4f} [G] calculated by the original counter. \nMake sure that the two calculated '
        'Flops are the same.\n'.format(self.flops_another / 10. ** 9))

    self.optimizer = utility.make_optimizer(args, self.model, ckp=ckp)
    self.scheduler = utility.make_scheduler(args, self.optimizer)
    self.device = torch.device('cpu' if args.cpu else 'cuda')

    if args.model.find('INQ') >= 0:
        self.inq_steps = args.inq_steps
    else:
        self.inq_steps = None
# Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
# ==================================================================================================================
t.reset_after_optimization()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)

# ==================================================================================================================
# Step 4: Continue training / Testing -> continue to train the pruned model to reach a higher accuracy
# ==================================================================================================================
while not t.terminate():
    t.train()
    t.test()

set_output_dimension(model.get_model(), t.input_dim)
flops = get_flops(model.get_model())
params = get_parameters(model.get_model())
print('\nThe computational complexity and number of parameters of the current network are as follows.'
      '\nFlops: {:.4f} [G]\tParams {:.2f} [k]\n'.format(flops / 10. ** 9, params / 10. ** 3))

if args.summary:
    t.writer.close()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)

checkpoint.done()
def train(self):
    self.loss.step()
    epoch = self.scheduler.last_epoch + 1
    learning_rate = self.scheduler.get_lr()[0]
    idx_scale = self.args.scale

    if not self.converging:
        stage = 'Searching Stage'
    else:
        stage = 'Finetuning Stage (Searching Epoch {})'.format(self.epochs_searching)
    self.ckp.write_log('\n[Epoch {}]\tLearning rate: {:.2e}\t{}'.format(epoch, Decimal(learning_rate), stage))
    self.loss.start_log()
    self.model.train()
    timer_data, timer_model = utility.timer(), utility.timer()

    for batch, (lr, hr, _) in enumerate(self.loader_train):
        lr, hr = self.prepare([lr, hr])
        timer_data.hold()
        timer_model.tic()

        self.optimizer.zero_grad()
        sr = self.model(idx_scale, lr)
        loss = self.loss(sr, hr)
        if loss.item() < self.args.skip_threshold * self.error_last:
            # Adam
            loss.backward()
            self.optimizer.step()
            # proximal operator
            if not self.converging:
                self.model.get_model().proximal_operator(learning_rate)
                # check the compression ratio
                if (batch + 1) % self.args.compression_check_frequency == 0:
                    # set the channels of the potentially pruned model
                    self.model.get_model().set_parameters()
                    # update the flops and number of parameters
                    self.flops_prune = get_flops(self.model.get_model())
                    self.flops_compression_ratio = self.flops_prune / self.flops
                    self.params_prune = get_parameters(self.model.get_model())
                    self.params_compression_ratio = self.params_prune / self.params
                    self.flops_ratio_log.append(self.flops_compression_ratio)
                    self.params_ratio_log.append(self.params_compression_ratio)
                    if self.terminate():
                        break
                if (batch + 1) % 1000 == 0:
                    self.model.get_model().latent_vector_distribution(epoch, batch + 1, self.ckp.dir)
                    self.model.get_model().per_layer_compression_ratio(epoch, batch + 1, self.ckp.dir)
        else:
            print('Skip this batch {}! (Loss: {}) (Threshold: {})'.format(
                batch + 1, loss.item(), self.args.skip_threshold * self.error_last))
        timer_model.hold()

        if (batch + 1) % self.args.print_every == 0:
            self.ckp.write_log(
                '[{}/{}]\t{}\t{:.3f}+{:.3f}s'
                '\tFlops Ratio: {:.2f}% = {:.4f} G / {:.4f} G'
                '\tParams Ratio: {:.2f}% = {:.2f} k / {:.2f} k'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(), timer_data.release(),
                    self.flops_compression_ratio * 100,
                    self.flops_prune / 10. ** 9, self.flops / 10. ** 9,
                    self.params_compression_ratio * 100,
                    self.params_prune / 10. ** 3, self.params / 10. ** 3))
        timer_data.tic()

    self.loss.end_log(len(self.loader_train))
    self.error_last = self.loss.log[-1, -1]
    self.scheduler.step()
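# ------------------------------------------------------------------------------------------------------------------
# Illustrative only, not the repo's implementation: proximal_operator() above is a method of the pruned model.
# For an l1 sparsity penalty on per-layer latent vectors, the corresponding proximal step is element-wise
# soft-thresholding. The sketch below assumes a hypothetical iterable of latent-vector tensors and a regularization
# weight `reg`; the actual method may differ.
# ------------------------------------------------------------------------------------------------------------------
import torch


def _soft_threshold_latent_vectors(latent_vectors, lr, reg):
    # prox_{lr * reg * ||.||_1}(z) = sign(z) * max(|z| - lr * reg, 0), applied element-wise and in place
    with torch.no_grad():
        for z in latent_vectors:
            z.copy_(torch.sign(z) * torch.clamp(z.abs() - lr * reg, min=0))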
def train(self):
    epoch, lr = self.start_epoch()
    self.model.begin(epoch, self.ckp)  # TODO: investigate why self.model.train() is not used directly
    self.loss.start_log()
    timer_data, timer_model = utility.timer(), utility.timer()
    n_samples = 0

    for batch, (img, label) in enumerate(self.loader_train):
        if (self.args.data_train == 'ImageNet' or self.args.model.lower() == 'efficientnet_hh') \
                and not self.converging:
            # use a smaller batch during the searching stage
            if self.args.model == 'ResNet_ImageNet_HH' or self.args.model == 'RegNet_ImageNet_HH':
                divider = 4
            else:
                divider = 2
            print('Divider is {}'.format(divider))
            batch_size = img.shape[0] // divider
            img = img[:batch_size]
            label = label[:batch_size]
        img, label = self.prepare(img, label)
        n_samples += img.size(0)
        timer_data.hold()
        timer_model.tic()

        self.optimizer.zero_grad()
        prediction = self.model(img)
        if (not self.converging and self.args.distillation_stage == 'c') or \
                (self.converging and not self.args.distillation_final):
            loss, _ = self.loss(prediction, label)
        else:
            with torch.no_grad():
                prediction_teacher = self.model_teacher(img)
            if not self.args.distillation_inter:
                prediction = [prediction]
                prediction_teacher = [prediction_teacher]
            loss, _ = self.loss(prediction[0], label)
            # distillation on the final predictions
            if self.args.distillation_final == 'kd':
                loss_distill_final = distillation(prediction[0], prediction_teacher[0], T=4)
                loss = 0.4 * loss_distill_final + 0.6 * loss
            elif self.args.distillation_final == 'sp':
                loss_distill_final = similarity_preserving(prediction[0], prediction_teacher[0]) * 3000
                loss = loss_distill_final + loss
            # distillation on the intermediate features
            if self.args.distillation_inter == 'kd':
                loss_distill_inter = 0
                for p, pt in zip(prediction[1], prediction_teacher[1]):
                    loss_distill_inter += self.loss_mse(p, pt)
                loss_distill_inter = loss_distill_inter / len(prediction[1]) * self.args.distill_beta
                loss = loss_distill_inter + loss
            elif self.args.distillation_inter == 'sp':
                loss_distill_inter = 0
                for p, pt in zip(prediction[1], prediction_teacher[1]):
                    loss_distill_inter += similarity_preserving(p, pt)
                loss_distill_inter = loss_distill_inter / len(prediction[1]) * 3000 * self.args.distill_beta
                loss = loss_distill_inter + loss
            # else: self.args.distillation_inter == '', do nothing here

        # SGD
        loss.backward()
        self.optimizer.step()

        if not self.converging and self.args.use_prox:
            # proximal operator
            self.model.get_model().proximal_operator(lr)
            if (batch + 1) % self.args.compression_check_frequency == 0:
                self.model.get_model().set_parameters()
                self.flops_prune = get_flops(self.model.get_model())
                self.flops_compression_ratio = self.flops_prune / self.flops
                self.params_prune = get_parameters(self.model.get_model())
                self.params_compression_ratio = self.params_prune / self.params
                self.flops_ratio_log.append(self.flops_compression_ratio)
                self.params_ratio_log.append(self.params_compression_ratio)
                if self.terminate():
                    break
            if (batch + 1) % 300 == 0:
                self.model.get_model().latent_vector_distribution(epoch, batch + 1, self.ckp.dir)
                self.model.get_model().per_layer_compression_ratio(epoch, batch + 1, self.ckp.dir)
        timer_model.hold()

        if (batch + 1) % self.args.print_every == 0:
            s = '{}/{} ({:.0f}%)\tNLL: {:.3f} Top1: {:.2f} / Top5: {:.2f}\t'.format(
                n_samples, len(self.loader_train.dataset),
                100.0 * n_samples / len(self.loader_train.dataset),
                *(self.loss.log_train[-1, :] / n_samples))
            if self.converging or (not self.converging and self.args.distillation_stage == 's'):
                if self.args.distillation_final:
                    s += 'DFinal: {:.3f} '.format(loss_distill_final)
                if self.args.distillation_inter:
                    s += 'DInter: {:.3f}'.format(loss_distill_inter)
                if self.args.distillation_final or self.args.distillation_inter:
                    s += '\t'
            s += 'Time: {:.1f}+{:.1f}s\t'.format(timer_model.release(), timer_data.release())
            if hasattr(self, 'flops_compression_ratio') and hasattr(self, 'params_compression_ratio'):
                s += 'Flops: {:.2f}% = {:.4f} [G] / {:.4f} [G]\t' \
                     'Params: {:.2f}% = {:.2f} [k] / {:.2f} [k]'.format(
                        self.flops_compression_ratio * 100, self.flops_prune / 10. ** 9, self.flops / 10. ** 9,
                        self.params_compression_ratio * 100, self.params_prune / 10. ** 3, self.params / 10. ** 3)
            self.ckp.write_log(s)

        if self.args.summary:
            if (batch + 1) % 50 == 0:
                for name, param in self.model.named_parameters():
                    if name.find('features') >= 0 and name.find('weight') >= 0:
                        self.writer.add_scalar('data/' + name,
                                               param.clone().cpu().data.abs().mean().numpy(),
                                               1000 * (epoch - 1) + batch)
                        if param.grad is not None:
                            self.writer.add_scalar('data/' + name + '_grad',
                                                   param.grad.clone().cpu().data.abs().mean().numpy(),
                                                   1000 * (epoch - 1) + batch)
            if (batch + 1) == 500:
                for name, param in self.model.named_parameters():
                    if name.find('features') >= 0 and name.find('weight') >= 0:
                        self.writer.add_histogram(name, param.clone().cpu().data.numpy(),
                                                  1000 * (epoch - 1) + batch)
                        if param.grad is not None:
                            self.writer.add_histogram(name + '_grad',
                                                      param.grad.clone().cpu().data.numpy(),
                                                      1000 * (epoch - 1) + batch)
        timer_data.tic()

        if not self.converging and epoch == self.args.epochs_grad and batch == 1:
            break

    self.model.log(self.ckp)  # TODO: why is this used?
    self.loss.end_log(len(self.loader_train.dataset))
def train(self):
    epoch, lr = self.start_epoch()
    self.model.begin(epoch, self.ckp)  # TODO: investigate why self.model.train() is not used directly
    self.loss.start_log()
    timer_data, timer_model = utility.timer(), utility.timer()
    n_samples = 0

    for batch, (img, label) in enumerate(self.loader_train):
        img, label = self.prepare(img, label)
        n_samples += img.size(0)
        timer_data.hold()
        timer_model.tic()

        self.optimizer.zero_grad()
        prediction = self.model(img)
        loss, _ = self.loss(prediction, label)

        # SGD
        loss.backward()
        self.optimizer.step()

        # proximal operator
        if not self.converging:
            self.model.get_model().proximal_operator(lr)
            if (batch + 1) % self.args.compression_check_frequency == 0:
                self.model.get_model().set_parameters()
                self.flops_prune = get_flops(self.model.get_model())
                self.flops_compression_ratio = self.flops_prune / self.flops
                self.params_prune = get_parameters(self.model.get_model())
                self.params_compression_ratio = self.params_prune / self.params
                self.flops_ratio_log.append(self.flops_compression_ratio)
                self.params_ratio_log.append(self.params_compression_ratio)
            if (batch + 1) % 300 == 0:
                self.model.get_model().latent_vector_distribution(epoch, batch + 1, self.ckp.dir)
                self.model.get_model().per_layer_compression_ratio(epoch, batch + 1, self.ckp.dir)
        timer_model.hold()

        if (batch + 1) % self.args.print_every == 0:
            self.ckp.write_log('{}/{} ({:.0f}%)\t'
                               'NLL: {:.3f}\tTop1: {:.2f} / Top5: {:.2f}\t'
                               'Time: {:.1f}+{:.1f}s\t'
                               'Flops Ratio: {:.2f}% = {:.4f} [G] / {:.4f} [G]\t'
                               'Params Ratio: {:.2f}% = {:.2f} [k] / {:.2f} [k]'.format(
                                   n_samples, len(self.loader_train.dataset),
                                   100.0 * n_samples / len(self.loader_train.dataset),
                                   *(self.loss.log_train[-1, :] / n_samples),
                                   timer_model.release(), timer_data.release(),
                                   self.flops_compression_ratio * 100,
                                   self.flops_prune / 10. ** 9, self.flops / 10. ** 9,
                                   self.params_compression_ratio * 100,
                                   self.params_prune / 10. ** 3, self.params / 10. ** 3))

        if not self.converging and self.terminate():
            break

        if self.args.summary:
            if (batch + 1) % 50 == 0:
                for name, param in self.model.named_parameters():
                    if name.find('features') >= 0 and name.find('weight') >= 0:
                        self.writer.add_scalar('data/' + name,
                                               param.clone().cpu().data.abs().mean().numpy(),
                                               1000 * (epoch - 1) + batch)
                        if param.grad is not None:
                            self.writer.add_scalar('data/' + name + '_grad',
                                                   param.grad.clone().cpu().data.abs().mean().numpy(),
                                                   1000 * (epoch - 1) + batch)
            if (batch + 1) == 500:
                for name, param in self.model.named_parameters():
                    if name.find('features') >= 0 and name.find('weight') >= 0:
                        self.writer.add_histogram(name, param.clone().cpu().data.numpy(),
                                                  1000 * (epoch - 1) + batch)
                        if param.grad is not None:
                            self.writer.add_histogram(name + '_grad',
                                                      param.grad.clone().cpu().data.numpy(),
                                                      1000 * (epoch - 1) + batch)
        timer_data.tic()

    self.model.log(self.ckp)  # TODO: why is this used?
    self.loss.end_log(len(self.loader_train.dataset))
# Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
# ==================================================================================================================
t.reset_after_searching()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)

# ==================================================================================================================
# Step 4: Finetuning / Testing -> finetune the pruned model to reach a higher accuracy
# ==================================================================================================================
while not t.terminate():
    t.train()
    t.test()

set_output_dimension(network_model.get_model(), t.input_dim)
flops = get_flops(network_model.get_model())
params = get_parameters(network_model.get_model())
print('\nThe computational complexity and number of parameters of the current network are as follows.'
      '\nFlops: {:.4f} [G]\tParams {:.2f} [k]\n'.format(flops / 10. ** 9, params / 10. ** 3))

if args.summary:
    t.writer.close()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)

checkpoint.done()