class ExperimentDesign: def __init__(self, options=None): self.settings = options or Option() self.checkpoint = None self.data_loader = None self.model = None self.optimizer_state = None self.trainer = None self.start_epoch = 0 self.model_analyse = None self.visualize = vs.Visualization(self.settings.save_path) self.logger = vs.Logger(self.settings.save_path) self.test_input = None self.lr_master = None self.prepare() def prepare(self): self._set_gpu() self._set_dataloader() self._set_model() self._set_checkpoint() self._set_parallel() self._set_lr_policy() self._set_trainer() self._draw_net() def _set_gpu(self): # set torch seed # init random seed torch.manual_seed(self.settings.manualSeed) torch.cuda.manual_seed(self.settings.manualSeed) assert self.settings.GPU <= torch.cuda.device_count( ) - 1, "Invalid GPU ID" torch.cuda.set_device(self.settings.GPU) print("|===>Set GPU done!") def _set_dataloader(self): # create data loader self.data_loader = DataLoader(dataset=self.settings.dataset, batch_size=self.settings.batchSize, data_path=self.settings.dataPath, n_threads=self.settings.nThreads, ten_crop=self.settings.tenCrop) print("|===>Set data loader done!") def _set_checkpoint(self): assert self.model is not None, "please create model first" self.checkpoint = CheckPoint(self.settings.save_path) if self.settings.retrain is not None: model_state = self.checkpoint.load_model(self.settings.retrain) self.model = self.checkpoint.load_state(self.model, model_state) if self.settings.resume is not None: model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint( self.settings.resume) self.model = self.checkpoint.load_state(self.model, model_state) # self.start_epoch = epoch # self.optimizer_state = optimizer_state print("|===>Set checkpoint done!") def _set_model(self): if self.settings.dataset == "sphere": if self.settings.netType == "SphereNet": self.model = md.SphereNet( depth=self.settings.depth, num_features=self.settings.featureDim) elif self.settings.netType == "SphereNIN": self.model = md.SphereNIN( num_features=self.settings.featureDim) elif self.settings.netType == "wcSphereNet": self.model = md.wcSphereNet( depth=self.settings.depth, num_features=self.settings.featureDim, rate=self.settings.rate) else: assert False, "use %s data while network is %s" % ( self.settings.dataset, self.settings.netType) else: assert False, "unsupport data set: " + self.settings.dataset print("|===>Set model done!") def _set_parallel(self): self.model = utils.data_parallel(self.model, self.settings.nGPU, self.settings.GPU) def _set_lr_policy(self): self.lr_master = utils.LRPolicy(self.settings.lr, self.settings.nIters, self.settings.lrPolicy) params_dict = { 'gamma': self.settings.gamma, 'step': self.settings.step, 'end_lr': self.settings.endlr, 'decay_rate': self.settings.decayRate } self.lr_master.set_params(params_dict=params_dict) def _set_trainer(self): train_loader, test_loader = self.data_loader.getloader() self.trainer = Trainer(model=self.model, lr_master=self.lr_master, n_epochs=self.settings.nEpochs, n_iters=self.settings.nIters, train_loader=train_loader, test_loader=test_loader, feature_dim=self.settings.featureDim, momentum=self.settings.momentum, weight_decay=self.settings.weightDecay, optimizer_state=self.optimizer_state, logger=self.logger) def _draw_net(self): # visualize model if self.settings.dataset == "sphere": rand_input = torch.randn(1, 3, 112, 96) else: assert False, "invalid data set" rand_input = Variable(rand_input.cuda()) self.test_input = rand_input if 
self.settings.drawNetwork: rand_output, _ = self.trainer.forward(rand_input) self.visualize.save_network(rand_output) print("|===>Draw network done!") self.visualize.write_settings(self.settings) def pruning(self, run_count=0): net_type = None if self.settings.dataset == "sphere": if self.settings.netType == "wcSphereNet": net_type = "SphereNet" assert net_type is not None, "net_type for prune is NoneType" self.trainer.test() if isinstance(self.model, nn.DataParallel): model = self.model.module else: model = self.model if net_type == "SphereNet": model_prune = prune.SpherePrune(model) model_prune.run() self.trainer.reset_model(model_prune.model) self.model = self.trainer.model self.trainer.test() self.checkpoint.save_model(self.trainer.model, index=run_count, tag="pruning") # analyse model self.model_analyse = utils.ModelAnalyse(self.trainer.model, self.visualize) params_num = self.model_analyse.params_count() self.model_analyse.flops_compute(self.test_input) def fine_tuning(self, run_count=0): # set lr self.settings.lr = 0.01 # 0.1 self.settings.nIters = 12000 # 28000 self.settings.lrPolicy = "multi_step" self.settings.decayRate = 0.1 self.settings.step = [6000] # [16000, 24000] self._set_lr_policy() self.trainer.reset_lr(self.lr_master, self.settings.nIters) # run fine-tuning self.training(run_count, tag="fine-tuning") def retrain(self, run_count=0): self.settings.lr = 0.1 self.settings.nIters = 28000 self.settings.lrPolicy = "multi_step" self.settings.decayRate = 0.1 self.settings.step = [16000, 24000] self._set_lr_policy() self.trainer.reset_lr(self.lr_master, self.settings.nIters) # run retrain self.training(run_count, tag="training") def run(self, run_count=0): """ if run_count == 0: print "|===> training" self.retrain(run_count) else: print "|===> fine-tuning" self.fine_tuning(run_count) """ if run_count >= 1: print("|===> training") self.retrain(run_count) self.trainer.reset_model(self.model) print("|===> fine-tuning") self.fine_tuning(run_count) print("|===> pruning") self.pruning(run_count) # keep margin_linear static layer_count = 0 for layer in self.model.modules(): if isinstance(layer, md.MarginLinear): layer.iteration.fill_(0) layer.margin_type.data.fill_(1) layer.weight.requires_grad = False elif isinstance(layer, nn.Linear): layer.weight.requires_grad = False layer.bias.requires_grad = False elif isinstance(layer, nn.Conv2d): if layer.bias is not None: bias_flag = True else: bias_flag = False new_layer = prune.wcConv2d(layer.weight.size(1), layer.weight.size(0), kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding, bias=bias_flag, rate=self.settings.rate) new_layer.weight.data.copy_(layer.weight.data) if layer.bias is not None: new_layer.bias.data.copy_(layer.bias.data) if layer_count == 1: self.model.conv2 = new_layer elif layer_count == 2: self.model.conv3 = new_layer elif layer_count == 3: self.model.conv4 = new_layer layer_count += 1 print(self.model) self.trainer.reset_model(self.model) # assert False def training(self, run_count=0, tag="training"): best_top1 = 100 # start_time = time.time() self.trainer.test() # assert False for epoch in range(self.start_epoch, self.settings.nEpochs): if self.trainer.iteration >= self.trainer.n_iters: break start_epoch = 0 # training and testing train_error, train_loss, train5_error = self.trainer.train( epoch=epoch) acc_mean, acc_std, acc_all = self.trainer.test() test_error_mean = 100 - acc_mean * 100 # write and print result log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % ( epoch, train_error, train_loss, 
train5_error, acc_mean, acc_std) for acc in acc_all: log_str += "%.4f\t" % acc self.visualize.write_log(log_str) best_flag = False if best_top1 >= test_error_mean: best_top1 = test_error_mean best_flag = True print( colored( "# %d ==>Best Result is: Top1 Error: %f\n" % (run_count, best_top1), "red")) else: print( colored( "# %d ==>Best Result is: Top1 Error: %f\n" % (run_count, best_top1), "blue")) self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer, epoch) if best_flag: self.checkpoint.save_model(self.model, best_flag=best_flag, tag="%s_%d" % (tag, run_count)) if (epoch + 1) % self.settings.drawInterval == 0: self.visualize.draw_curves() for name, value in self.model.named_parameters(): if 'weight' in name: name = name.replace('.', '/') self.logger.histo_summary( name, value.data.cpu().numpy(), run_count * self.settings.nEpochs + epoch + 1) if value.grad is not None: self.logger.histo_summary( name + "/grad", value.grad.data.cpu().numpy(), run_count * self.settings.nEpochs + epoch + 1) # end_time = time.time() if isinstance(self.model, nn.DataParallel): self.model = self.model.module # draw experimental curves self.visualize.draw_curves() # compute cost time # time_interval = end_time - start_time # t_string = "Running Time is: " + \ # str(datetime.timedelta(seconds=time_interval)) + "\n" # print(t_string) # write cost time to file # self.visualize.write_readme(t_string) # save experimental results self.model_analyse = utils.ModelAnalyse(self.trainer.model, self.visualize) self.visualize.write_readme("Best Result of all is: Top1 Error: %f\n" % best_top1) # analyse model params_num = self.model_analyse.params_count() # save analyse result to file self.visualize.write_readme("Number of parameters is: %d" % (params_num)) self.model_analyse.prune_rate() self.model_analyse.flops_compute(self.test_input) return best_top1
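# A minimal driver sketch for the pruning pipeline above. Illustration only:
# the entry point and the number of prune/fine-tune rounds are assumptions,
# not part of the original code.
if __name__ == "__main__":
    experiment = ExperimentDesign()
    for run_count in range(2):
        # run_count 0 fine-tunes and prunes; later rounds retrain first
        # (see ExperimentDesign.run above)
        experiment.run(run_count)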
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=config.step,
                                                 gamma=0.1)

# Set trainer
logger = Logger(config.save_path)
trainer = AlexNetTrainer(config.lr, train_loader, valid_loader, model,
                         optimizer, scheduler, logger, device)
print(model)

time_start = time.time()
for epoch in range(1, config.nEpochs + 1):
    cls_loss_, accuracy, accuracy_valid = trainer.train(epoch)
    checkpoint.save_model(model, index=epoch)
time_end = time.time()
print(time_end - time_start)

# # count the number of samples
# imagedb.count_num_sample()
#
# # record the character set
# imagedb.write_char_set()
#
# # randomly select a subset of characters for the model to use
# imagedb.random_select()
#
# # load data
# train_dataset, valid_dataset = imagedb.load_data()
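# Optional sanity check (a sketch, not part of the original script): inspect
# the learning rate after training. With gamma=0.1, MultiStepLR multiplies the
# base lr by 0.1 at every milestone epoch listed in config.step, assuming
# scheduler.step() is called once per epoch inside AlexNetTrainer.train.
for param_group in optimizer.param_groups:
    print("final learning rate:", param_group["lr"])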
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# Set dataloader
kwargs = {'num_workers': config.nThreads, 'pin_memory': True} if use_cuda else {}
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
train_loader = torch.utils.data.DataLoader(
    FaceDataset(config.annoPath, transform=transform, is_train=True),
    batch_size=config.batchSize,
    shuffle=True,
    **kwargs)

# Set model
model = ONet()
model = model.to(device)

# Set checkpoint
checkpoint = CheckPoint(config.save_path)

# Set optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=config.step,
                                                 gamma=0.1)

# Set trainer
logger = Logger(config.save_path)
trainer = ONetTrainer(config.lr, train_loader, model, optimizer, scheduler,
                      logger, device)

for epoch in range(1, config.nEpochs + 1):
    trainer.train(epoch)
    checkpoint.save_model(model, index=epoch)
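# Optional final export (a sketch; the file name is illustrative and this line
# is not part of the original script): keep a plain state_dict alongside the
# per-epoch CheckPoint files written above.
torch.save(model.state_dict(), "onet_final.pth")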
class ExperimentDesign(object): """ run experiments with pre-defined pipeline """ def __init__(self): self.settings = Option() self.checkpoint = None self.data_loader = None self.model = None self.trainer = None self.seg_opt_state = None self.fc_opt_state = None self.aux_fc_state = None self.start_epoch = 0 self.model_analyse = None self.visualize = vs.Visualization(self.settings.save_path) self.logger = vs.Logger(self.settings.save_path) self.prepare() def prepare(self): """ preparing experiments """ self._set_gpu() self._set_dataloader() self._set_model() self._set_checkpoint() self._set_trainer() def _set_gpu(self): # set torch seed # init random seed torch.manual_seed(self.settings.manualSeed) torch.cuda.manual_seed(self.settings.manualSeed) assert self.settings.GPU <= torch.cuda.device_count( ) - 1, "Invalid GPU ID" torch.cuda.set_device(self.settings.GPU) cudnn.benchmark = True def _set_dataloader(self): # create data loader self.data_loader = DataLoader(dataset=self.settings.dataset, batch_size=self.settings.batchSize, data_path=self.settings.dataPath, n_threads=self.settings.nThreads, ten_crop=self.settings.tenCrop) def _set_checkpoint(self): assert self.model is not None, "please create model first" self.checkpoint = CheckPoint(self.settings.save_path) if self.settings.retrain is not None: model_state = self.checkpoint.load_model(self.settings.retrain) self.model = self.checkpoint.load_state(self.model, model_state) if self.settings.resume is not None: check_point_params = torch.load(self.settings.resume) model_state = check_point_params["model"] self.seg_opt_state = check_point_params["seg_opt"] self.fc_opt_state = check_point_params["fc_opt"] self.aux_fc_state = check_point_params["aux_fc"] self.model = self.checkpoint.load_state(self.model, model_state) self.start_epoch = 90 def _set_model(self): print("netType:", self.settings.netType) if self.settings.dataset in ["cifar10", "cifar100"]: if self.settings.netType == "DARTSNet": genotype = md.genotypes.DARTS self.model = md.DARTSNet(self.settings.init_channels, self.settings.nClasses, self.settings.layers, self.settings.auxiliary, genotype) elif self.settings.netType == "PreResNet": self.model = md.PreResNet(depth=self.settings.depth, num_classes=self.settings.nClasses, wide_factor=self.settings.wideFactor) elif self.settings.netType == "CifarResNeXt": self.model = md.CifarResNeXt(self.settings.cardinality, self.settings.depth, self.settings.nClasses, self.settings.base_width, self.settings.widen_factor) elif self.settings.netType == "ResNet": self.model = md.ResNet(self.settings.depth, self.settings.nClasses) else: assert False, "use %s data while network is %s" % ( self.settings.dataset, self.settings.netType) else: assert False, "unsupported data set: " + self.settings.dataset def _set_trainer(self): # set lr master lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs, self.settings.lrPolicy) params_dict = { 'power': self.settings.power, 'step': self.settings.step, 'end_lr': self.settings.endlr, 'decay_rate': self.settings.decayRate } lr_master.set_params(params_dict=params_dict) # set trainer train_loader, test_loader = self.data_loader.getloader() self.trainer = Trainer(model=self.model, train_loader=train_loader, test_loader=test_loader, lr_master=lr_master, settings=self.settings, logger=self.logger) if self.seg_opt_state is not None: self.trainer.resume(aux_fc_state=self.aux_fc_state, seg_opt_state=self.seg_opt_state, fc_opt_state=self.fc_opt_state) def _save_best_model(self, model, aux_fc): 
        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        aux_fc_state = []
        for i in range(len(aux_fc)):
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())
        check_point_params["aux_fc"] = aux_fc_state
        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "best_model_with_AuxFC.pth"))

    def _save_checkpoint(self,
                         model,
                         seg_optimizers,
                         fc_optimizers,
                         aux_fc,
                         index=0):
        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        seg_opt_state = []
        fc_opt_state = []
        aux_fc_state = []
        for i in range(len(seg_optimizers)):
            seg_opt_state.append(seg_optimizers[i].state_dict())
        for i in range(len(fc_optimizers)):
            fc_opt_state.append(fc_optimizers[i].state_dict())
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())
        check_point_params["seg_opt"] = seg_opt_state
        check_point_params["fc_opt"] = fc_opt_state
        check_point_params["aux_fc"] = aux_fc_state
        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "checkpoint_%03d.pth" % index))

    def run(self, run_count=0):
        """
        training and testing
        """
        best_top1 = 100
        best_top5 = 100
        start_time = time.time()

        if self.settings.resume is not None or self.settings.retrain is not None:
            self.trainer.test(0)

        for epoch in range(self.start_epoch, self.settings.nEpochs):
            self.start_epoch = 0
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)

            # write and print result
            if isinstance(train_error, np.ndarray):
                log_str = "%d\t" % (epoch)
                for i in range(len(train_error)):
                    log_str += "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                        train_error[i],
                        train_loss[i],
                        test_error[i],
                        test_loss[i],
                        train5_error[i],
                        test5_error[i],
                    )
                best_flag = False
                if best_top1 >= test_error[-1]:
                    best_top1 = test_error[-1]
                    best_top5 = test5_error[-1]
                    best_flag = True
            else:
                log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                    epoch, train_error, train_loss, test_error, test_loss,
                    train5_error, test5_error)
                best_flag = False
                if best_top1 >= test_error:
                    best_top1 = test_error
                    best_top5 = test5_error
                    best_flag = True

            self.visualize.write_log(log_str)

            if best_flag:
                self.checkpoint.save_model(self.trainer.model,
                                           best_flag=best_flag)
                self._save_best_model(self.trainer.model, self.trainer.auxfc)
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                        % (run_count, best_top1, best_top5), "red"))
            else:
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                        % (run_count, best_top1, best_top5), "blue"))

            self._save_checkpoint(self.trainer.model,
                                  self.trainer.seg_optimizer,
                                  self.trainer.fc_optimizer,
                                  self.trainer.auxfc)

        end_time = time.time()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        print(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)
        # save experimental results
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f, Top5 Error: %f\n" %
            (best_top1, best_top5))
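# A sketch of restoring the dictionary written by _save_checkpoint above
# (shown as a comment because the file name and the aux_fc_modules list are
# illustrative assumptions, not objects defined in this file):
#
#     check_point_params = torch.load("checkpoint_010.pth")
#     model.load_state_dict(check_point_params["model"])
#     for fc, state in zip(aux_fc_modules, check_point_params["aux_fc"]):
#         fc.load_state_dict(state)
#     # "seg_opt" and "fc_opt" hold the optimizer state_dicts, one per segment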
train_loader = torch.utils.data.DataLoader(
    FaceDataset(train_config.annoPath, transform=transform, is_train=True),
    batch_size=train_config.batchSize,
    shuffle=True,
    **kwargs)

# Set model
model = ONet(config.NUM_LANDMARKS)
model = model.to(device)

# Set checkpoint
checkpoint = CheckPoint(train_config.save_path)

# Set optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=train_config.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=train_config.step,
                                                 gamma=0.1)

# Set trainer
logger = Logger(train_config.save_path)
trainer = ONetTrainer(train_config.lr, train_loader, model, optimizer,
                      scheduler, logger, device)

for epoch in range(1, train_config.nEpochs + 1):
    trainer.train(epoch)
    checkpoint.save_model(model,
                          index=epoch,
                          tag=str(config.NUM_LANDMARKS) + '_landmarks')
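# A hedged inference sketch for the trained landmark model (shown as a comment:
# the batch layout of FaceDataset and the ONet output shape, e.g.
# 2 * config.NUM_LANDMARKS normalised coordinates, are assumptions about this
# project rather than facts taken from the code above):
#
#     model.eval()
#     with torch.no_grad():
#         images = next(iter(train_loader))[0]   # assumes images come first in the batch
#         landmarks = model(images.to(device))
#         print(landmarks.shape)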
class ExperimentDesign: def __init__(self, options=None): self.settings = options or Option() self.checkpoint = None self.train_loader = None self.test_loader = None self.model = None self.optimizer_state = None self.trainer = None self.start_epoch = 0 self.test_input = None self.model_analyse = None os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices self.settings.set_save_path() self.logger = self.set_logger() self.settings.paramscheck(self.logger) self.visualize = vs.Visualization(self.settings.save_path, self.logger) self.tensorboard_logger = vs.Logger(self.settings.save_path) self.prepare() def set_logger(self): logger = logging.getLogger('sphereface') file_formatter = logging.Formatter( '%(asctime)s %(levelname)s: %(message)s') console_formatter = logging.Formatter('%(message)s') # file log file_handler = logging.FileHandler( os.path.join(self.settings.save_path, "train_test.log")) file_handler.setFormatter(file_formatter) # console log console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(console_formatter) logger.addHandler(file_handler) logger.addHandler(console_handler) logger.setLevel(logging.INFO) return logger def prepare(self): self._set_gpu() self._set_dataloader() self._set_model() self._set_checkpoint() self._set_trainer() self._draw_net() def _set_gpu(self): # set torch seed # init random seed torch.manual_seed(self.settings.manualSeed) torch.cuda.manual_seed(self.settings.manualSeed) assert self.settings.GPU <= torch.cuda.device_count( ) - 1, "Invalid GPU ID" torch.cuda.set_device(self.settings.GPU) cudnn.benchmark = True def _set_dataloader(self): # create data loader data_loader = DataLoader(dataset=self.settings.dataset, batch_size=self.settings.batchSize, data_path=self.settings.dataPath, n_threads=self.settings.nThreads, ten_crop=self.settings.tenCrop, logger=self.logger) self.train_loader, self.test_loader = data_loader.getloader() def _set_checkpoint(self): assert self.model is not None, "please create model first" self.checkpoint = CheckPoint(self.settings.save_path, self.logger) if self.settings.retrain is not None: model_state = self.checkpoint.load_model(self.settings.retrain) self.model = self.checkpoint.load_state(self.model, model_state) if self.settings.resume is not None: model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint( self.settings.resume) self.model = self.checkpoint.load_state(self.model, model_state) self.start_epoch = epoch self.optimizer_state = optimizer_state def _set_model(self): if self.settings.dataset in ["sphere", "sphere_large"]: self.test_input = Variable(torch.randn(1, 3, 112, 96).cuda()) if self.settings.netType == "SphereNet": self.model = md.SphereNet( depth=self.settings.depth, num_features=self.settings.featureDim) elif self.settings.netType == "SphereNIN": self.model = md.SphereNIN( num_features=self.settings.featureDim) elif self.settings.netType == "SphereMobileNet_v2": self.model = md.SphereMobleNet_v2( num_features=self.settings.featureDim) else: assert False, "use %s data while network is %s" % ( self.settings.dataset, self.settings.netType) else: assert False, "unsupport data set: " + self.settings.dataset def _set_trainer(self): lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs, self.settings.lrPolicy) params_dict = { 'power': self.settings.power, 'step': self.settings.step, 'end_lr': self.settings.endlr, 'decay_rate': self.settings.decayRate } lr_master.set_params(params_dict=params_dict) 
self.trainer = Trainer(model=self.model, train_loader=self.train_loader, test_loader=self.test_loader, lr_master=lr_master, settings=self.settings, logger=self.logger, tensorboard_logger=self.tensorboard_logger, optimizer_state=self.optimizer_state) def _draw_net(self): # visualize model if self.settings.drawNetwork: rand_output, _ = self.trainer.forward(self.test_input) self.visualize.save_network(rand_output) self.logger.info("|===>Draw network done!") self.visualize.write_settings(self.settings) def _model_analyse(self, model): # analyse model model_analyse = utils.ModelAnalyse(model, self.visualize) params_num = model_analyse.params_count() zero_num = model_analyse.zero_count() zero_rate = zero_num * 1.0 / params_num self.logger.info("zero rate is: {}".format(zero_rate)) # save analyse result to file self.visualize.write_readme( "Number of parameters is: %d, number of zeros is: %d, zero rate is: %f" % (params_num, zero_num, zero_rate)) # model_analyse.flops_compute(self.test_input) model_analyse.madds_compute(self.test_input) def run(self, run_count=0): self.logger.info("|===>Start training") best_top1 = 100 start_time = time.time() # self._model_analyse(self.model) # assert False self.trainer.test() # assert False for epoch in range(self.start_epoch, self.settings.nEpochs): if self.trainer.iteration >= self.settings.nIters: break self.start_epoch = 0 # training and testing train_error, train_loss, train5_error = self.trainer.train( epoch=epoch) acc_mean, acc_std, acc_all = self.trainer.test() test_error_mean = 100 - acc_mean * 100 # write and print result log_str = "{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format( epoch, train_error, train_loss, train5_error, acc_mean, acc_std) for acc in acc_all: log_str += "%.4f\t" % acc self.visualize.write_log(log_str) best_flag = False if best_top1 >= test_error_mean: best_top1 = test_error_mean best_flag = True self.logger.info( "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format( run_count, best_top1)) else: self.logger.info( "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format( run_count, best_top1)) self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer, epoch) if best_flag: self.checkpoint.save_model(self.model, best_flag=best_flag) if (epoch + 1) % self.settings.drawInterval == 0: self.visualize.draw_curves() end_time = time.time() if isinstance(self.model, nn.DataParallel): self.model = self.model.module # draw experimental curves self.visualize.draw_curves() # compute cost time time_interval = end_time - start_time t_string = "Running Time is: " + \ str(datetime.timedelta(seconds=time_interval)) + "\n" self.logger.info(t_string) # write cost time to file self.visualize.write_readme(t_string) # analyse model self._model_analyse(self.model) return best_top1
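# A minimal driver sketch (illustration only; the entry point below is not
# part of the original file and all hyper-parameters come from Option).
if __name__ == "__main__":
    experiment = ExperimentDesign()
    best_top1 = experiment.run()
    experiment.logger.info("finished with Top1 Error: {:f}".format(best_top1))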
class Experiment(object): """ run experiments with pre-defined pipeline """ def __init__(self, options=None, conf_path=None): self.settings = options or Option(conf_path) self.checkpoint = None self.train_loader = None self.val_loader = None self.ori_model = None self.pruned_model = None self.segment_wise_trainer = None self.aux_fc_state = None self.aux_fc_opt_state = None self.seg_opt_state = None self.current_pivot_index = None self.is_segment_wise_finetune = False self.is_channel_selection = False self.epoch = 0 self.feature_cache_origin = {} self.feature_cache_pruned = {} os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu self.settings.set_save_path() self.write_settings() self.logger = self.set_logger() self.tensorboard_logger = TensorboardLogger(self.settings.save_path) self.prepare() def write_settings(self): """ save experimental settings to a file """ with open(os.path.join(self.settings.save_path, "settings.log"), "w") as f: for k, v in self.settings.__dict__.items(): f.write(str(k) + ": " + str(v) + "\n") def set_logger(self): """ initialize logger """ logger = logging.getLogger('channel_selection') file_formatter = logging.Formatter( '%(asctime)s %(levelname)s: %(message)s') console_formatter = logging.Formatter('%(message)s') # file log file_handler = logging.FileHandler( os.path.join(self.settings.save_path, "train_test.log")) file_handler.setFormatter(file_formatter) # console log console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(console_formatter) logger.addHandler(file_handler) logger.addHandler(console_handler) logger.setLevel(logging.INFO) return logger def prepare(self): """ preparing experiments """ self._set_gpu() self._set_dataloader() self._set_model() self._cal_pivot() self._set_checkpoint() self._set_trainier() def _set_gpu(self): """ initialize the seed of random number generator """ # set torch seed # init random seed torch.manual_seed(self.settings.seed) torch.cuda.manual_seed(self.settings.seed) torch.cuda.set_device(0) cudnn.benchmark = True def _set_dataloader(self): """ create train loader and validation loader for channel pruning """ if self.settings.dataset == 'cifar10': data_root = os.path.join(self.settings.data_path, "cifar") norm_mean = [0.49139968, 0.48215827, 0.44653124] norm_std = [0.24703233, 0.24348505, 0.26158768] train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std) ]) val_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std) ]) train_dataset = datasets.CIFAR10(root=data_root, train=True, transform=train_transform, download=True) val_dataset = datasets.CIFAR10(root=data_root, train=False, transform=val_transform) self.train_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=self.settings.batch_size, shuffle=True, pin_memory=True, num_workers=self.settings.n_threads) self.val_loader = torch.utils.data.DataLoader( dataset=val_dataset, batch_size=self.settings.batch_size, shuffle=False, pin_memory=True, num_workers=self.settings.n_threads) elif self.settings.dataset == 'imagenet': dataset_path = os.path.join(self.settings.data_path, "imagenet") traindir = os.path.join(dataset_path, "train") valdir = os.path.join(dataset_path, 'val') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.train_loader = torch.utils.data.DataLoader( datasets.ImageFolder( traindir, transforms.Compose([ 
transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ])), batch_size=self.settings.batch_size, shuffle=True, num_workers=self.settings.n_threads, pin_memory=True) self.val_loader = torch.utils.data.DataLoader( datasets.ImageFolder( valdir, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ])), batch_size=self.settings.batch_size, shuffle=False, num_workers=self.settings.n_threads, pin_memory=True) def _set_trainier(self): """ initialize segment-wise trainer trainer """ # initialize segment-wise trainer self.segment_wise_trainer = SegmentWiseTrainer( ori_model=self.ori_model, pruned_model=self.pruned_model, train_loader=self.train_loader, val_loader=self.val_loader, settings=self.settings, logger=self.logger, tensorboard_logger=self.tensorboard_logger) if self.aux_fc_state is not None: self.segment_wise_trainer.update_aux_fc(self.aux_fc_state, self.aux_fc_opt_state, self.seg_opt_state) def _set_model(self): """ get model """ if self.settings.dataset in ["cifar10", "cifar100"]: if self.settings.net_type == "preresnet": self.ori_model = md.PreResNet( depth=self.settings.depth, num_classes=self.settings.n_classes) self.pruned_model = md.PreResNet( depth=self.settings.depth, num_classes=self.settings.n_classes) else: assert False, "use {} data while network is {}".format( self.settings.dataset, self.settings.net_type) elif self.settings.dataset in ["imagenet"]: if self.settings.net_type == "resnet": self.ori_model = md.ResNet(depth=self.settings.depth, num_classes=self.settings.n_classes) self.pruned_model = md.ResNet( depth=self.settings.depth, num_classes=self.settings.n_classes) else: assert False, "use {} data while network is {}".format( self.settings.dataset, self.settings.net_type) else: assert False, "unsupported data set: {}".format( self.settings.dataset) def _set_checkpoint(self): """ load pre-trained model or resume checkpoint """ assert self.ori_model is not None and self.pruned_model is not None, "please create model first" self.checkpoint = CheckPoint(self.settings.save_path, self.logger) self._load_retrain() self._load_resume() def _load_retrain(self): """ load pre-trained model """ if self.settings.retrain is not None: check_point_params = torch.load(self.settings.retrain) if "ori_model" not in check_point_params: model_state = check_point_params self.ori_model = self.checkpoint.load_state( self.ori_model, model_state) self.pruned_model = self.checkpoint.load_state( self.pruned_model, model_state) self.logger.info("|===>load restrain file: {}".format( self.settings.retrain)) else: ori_model_state = check_point_params["ori_model"] pruned_model_state = check_point_params["pruned_model"] # self.current_block_count = check_point_params["current_pivot"] self.aux_fc_state = check_point_params["aux_fc"] # self.replace_layer() self.ori_model = self.checkpoint.load_state( self.ori_model, ori_model_state) self.pruned_model = self.checkpoint.load_state( self.pruned_model, pruned_model_state) self.logger.info("|===>load pre-trained model: {}".format( self.settings.retrain)) def _load_resume(self): """ load resume checkpoint """ if self.settings.resume is not None: check_point_params = torch.load(self.settings.resume) ori_model_state = check_point_params["ori_model"] pruned_model_state = check_point_params["pruned_model"] self.aux_fc_state = check_point_params["aux_fc"] self.aux_fc_opt_state = check_point_params["aux_fc_opt"] self.seg_opt_state = check_point_params["seg_opt"] 
self.current_pivot_index = check_point_params["current_pivot"] self.is_segment_wise_finetune = check_point_params[ "segment_wise_finetune"] self.is_channel_selection = check_point_params["channel_selection"] self.epoch = check_point_params["epoch"] self.epoch = self.settings.segment_wise_n_epochs self.current_block_count = check_point_params[ "current_block_count"] if self.is_channel_selection or \ (self.is_segment_wise_finetune and self.current_pivot_index > self.settings.pivot_set[0]): self.replace_layer() self.ori_model = self.checkpoint.load_state( self.ori_model, ori_model_state) self.pruned_model = self.checkpoint.load_state( self.pruned_model, pruned_model_state) self.logger.info("|===>load resume file: {}".format( self.settings.resume)) def _cal_pivot(self): """ calculate the inserted layer for additional loss """ self.num_segments = self.settings.n_losses + 1 num_block_per_segment = ( block_num[self.settings.net_type + str(self.settings.depth)] // self.num_segments) + 1 pivot_set = [] for i in range(self.num_segments - 1): pivot_set.append(num_block_per_segment * (i + 1)) self.settings.pivot_set = pivot_set self.logger.info("pivot set: {}".format(pivot_set)) def segment_wise_fine_tune(self, index): """ conduct segment-wise fine-tuning :param index: segment index """ best_top1 = 100 best_top5 = 100 start_epoch = 0 if self.is_segment_wise_finetune and self.epoch != 0: start_epoch = self.epoch + 1 self.epoch = 0 for epoch in range(start_epoch, self.settings.segment_wise_n_epochs): train_error, train_loss, train5_error = self.segment_wise_trainer.train( epoch, index) val_error, val_loss, val5_error = self.segment_wise_trainer.val( epoch) # write and print result if isinstance(train_error, list): best_flag = False if best_top1 >= val_error[-1]: best_top1 = val_error[-1] best_top5 = val5_error[-1] best_flag = True else: best_flag = False if best_top1 >= val_error: best_top1 = val_error best_top5 = val5_error best_flag = True if best_flag: self.checkpoint.save_model( ori_model=self.ori_model, pruned_model=self.pruned_model, aux_fc=self.segment_wise_trainer.aux_fc, current_pivot=self.current_pivot_index, segment_wise_finetune=True, index=index) self.logger.info( "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n". 
format(best_top1, best_top5)) self.logger.info( "|===>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n" .format(100 - best_top1, 100 - best_top5)) if self.settings.dataset in ["imagenet"]: self.checkpoint.save_checkpoint( ori_model=self.ori_model, pruned_model=self.pruned_model, aux_fc=self.segment_wise_trainer.aux_fc, aux_fc_opt=self.segment_wise_trainer.fc_optimizer, seg_opt=self.segment_wise_trainer.seg_optimizer, current_pivot=self.current_pivot_index, segment_wise_finetune=True, index=index, epoch=epoch) else: self.checkpoint.save_checkpoint( ori_model=self.ori_model, pruned_model=self.pruned_model, aux_fc=self.segment_wise_trainer.aux_fc, aux_fc_opt=self.segment_wise_trainer.fc_optimizer, seg_opt=self.segment_wise_trainer.seg_optimizer, current_pivot=self.current_pivot_index, segment_wise_finetune=True, index=index) def replace_layer(self): """ Replace the convolutional layer to mask convolutional layer """ block_count = 0 if self.settings.net_type in ["preresnet", "resnet"]: for module in self.pruned_model.modules(): if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)): block_count += 1 layer = module.conv2 if block_count <= self.current_block_count and not isinstance( layer, MaskConv2d): temp_conv = MaskConv2d(in_channels=layer.in_channels, out_channels=layer.out_channels, kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding, bias=(layer.bias is not None)) temp_conv.weight.data.copy_(layer.weight.data) if layer.bias is not None: temp_conv.bias.data.copy_(layer.bias.data) module.conv2 = temp_conv if isinstance(module, Bottleneck): layer = module.conv3 if block_count <= self.current_block_count and not isinstance( layer, MaskConv2d): temp_conv = MaskConv2d( in_channels=layer.in_channels, out_channels=layer.out_channels, kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding, bias=(layer.bias is not None)) temp_conv.weight.data.copy_(layer.weight.data) if layer.bias is not None: temp_conv.bias.data.copy_(layer.bias.data) module.conv3 = temp_conv def channel_selection(self): """ conduct channel selection """ # get testing error self.segment_wise_trainer.val(0) time_start = time.time() restart_index = None # find restart segment index if self.current_pivot_index: if self.current_pivot_index in self.settings.pivot_set: restart_index = self.settings.pivot_set.index( self.current_pivot_index) else: restart_index = len(self.settings.pivot_set) for index in range(self.num_segments): if restart_index is not None: if index < restart_index: continue elif index == restart_index: if self.is_channel_selection and self.current_block_count == self.current_pivot_index: self.is_channel_selection = False continue if index == self.num_segments - 1: self.current_pivot_index = self.segment_wise_trainer.final_block_count else: self.current_pivot_index = self.settings.pivot_set[index] # fine tune the network with additional loss and final loss if (not self.is_segment_wise_finetune and not self.is_channel_selection) or \ (self.is_segment_wise_finetune and self.epoch != self.settings.segment_wise_n_epochs - 1): self.segment_wise_fine_tune(index) else: self.is_segment_wise_finetune = False # load best model best_model_path = os.path.join( self.checkpoint.save_path, 'model_{:0>3d}_swft.pth'.format(index)) check_point_params = torch.load(best_model_path) ori_model_state = check_point_params["ori_model"] pruned_model_state = check_point_params["pruned_model"] aux_fc_state = check_point_params["aux_fc"] self.ori_model = self.checkpoint.load_state( 
self.ori_model, ori_model_state) self.pruned_model = self.checkpoint.load_state( self.pruned_model, pruned_model_state) self.segment_wise_trainer.update_model(self.ori_model, self.pruned_model, aux_fc_state) # replace the baseline model if index == 0: if self.settings.net_type in ['preresnet']: self.ori_model.conv = copy.deepcopy(self.pruned_model.conv) for ori_module, pruned_module in zip( self.ori_model.modules(), self.pruned_model.modules()): if isinstance(ori_module, PreBasicBlock): ori_module.bn1 = copy.deepcopy(pruned_module.bn1) ori_module.bn2 = copy.deepcopy(pruned_module.bn2) ori_module.conv1 = copy.deepcopy( pruned_module.conv1) ori_module.conv2 = copy.deepcopy( pruned_module.conv2) if ori_module.downsample is not None: ori_module.downsample = copy.deepcopy( pruned_module.downsample) self.ori_model.bn = copy.deepcopy(self.pruned_model.bn) self.ori_model.fc = copy.deepcopy(self.pruned_model.fc) elif self.settings.net_type in ['resnet']: self.ori_model.conv1 = copy.deepcopy( self.pruned_model.conv) self.ori_model.bn1 = copy.deepcopy(self.pruned_model.bn1) for ori_module, pruned_module in zip( self.ori_model.modules(), self.pruned_model.modules()): if isinstance(ori_module, BasicBlock): ori_module.conv1 = copy.deepcopy( pruned_module.conv1) ori_module.conv2 = copy.deepcopy( pruned_module.conv2) ori_module.bn1 = copy.deepcopy(pruned_module.bn1) ori_module.bn2 = copy.deepcopy(pruned_module.bn2) if ori_module.downsample is not None: ori_module.downsample = copy.deepcopy( pruned_module.downsample) if isinstance(ori_module, Bottleneck): ori_module.conv1 = copy.deepcopy( pruned_module.conv1) ori_module.conv2 = copy.deepcopy( pruned_module.conv2) ori_module.conv3 = copy.deepcopy( pruned_module.conv3) ori_module.bn1 = copy.deepcopy(pruned_module.bn1) ori_module.bn2 = copy.deepcopy(pruned_module.bn2) ori_module.bn3 = copy.deepcopy(pruned_module.bn3) if ori_module.downsample is not None: ori_module.downsample = copy.deepcopy( pruned_module.downsample) self.ori_model.fc = copy.deepcopy(self.pruned_model.fc) aux_fc_state = [] for i in range(len(self.segment_wise_trainer.aux_fc)): if isinstance(self.segment_wise_trainer.aux_fc[i], nn.DataParallel): temp_state = self.segment_wise_trainer.aux_fc[ i].module.state_dict() else: temp_state = self.segment_wise_trainer.aux_fc[ i].state_dict() aux_fc_state.append(temp_state) self.segment_wise_trainer.update_model(self.ori_model, self.pruned_model, aux_fc_state) self.segment_wise_trainer.val(0) # conduct channel selection # contains [0:index] segments net_origin_list = [] net_pruned_list = [] for j in range(index + 1): net_origin_list += utils.model2list( self.segment_wise_trainer.ori_segments[j]) net_pruned_list += utils.model2list( self.segment_wise_trainer.pruned_segments[j]) net_origin = nn.Sequential(*net_origin_list) net_pruned = nn.Sequential(*net_pruned_list) self._seg_channel_selection( net_origin=net_origin, net_pruned=net_pruned, aux_fc=self.segment_wise_trainer.aux_fc[index], pivot_index=self.current_pivot_index, index=index) # update optimizer aux_fc_state = [] for i in range(len(self.segment_wise_trainer.aux_fc)): if isinstance(self.segment_wise_trainer.aux_fc[i], nn.DataParallel): temp_state = self.segment_wise_trainer.aux_fc[ i].module.state_dict() else: temp_state = self.segment_wise_trainer.aux_fc[ i].state_dict() aux_fc_state.append(temp_state) self.segment_wise_trainer.update_model(self.ori_model, self.pruned_model, aux_fc_state) self.checkpoint.save_checkpoint( self.ori_model, self.pruned_model, self.segment_wise_trainer.aux_fc, 
self.segment_wise_trainer.fc_optimizer, self.segment_wise_trainer.seg_optimizer, self.current_pivot_index, channel_selection=True, index=index, block_count=self.current_pivot_index) self.logger.info(self.ori_model) self.logger.info(self.pruned_model) self.segment_wise_trainer.val(0) self.current_pivot_index = None self.checkpoint.save_model(self.ori_model, self.pruned_model, self.segment_wise_trainer.aux_fc, self.segment_wise_trainer.final_block_count, index=self.num_segments) time_interval = time.time() - time_start log_str = "cost time: {}".format( str(datetime.timedelta(seconds=time_interval))) self.logger.info(log_str) def _hook_origin_feature(self, module, input, output): gpu_id = str(output.get_device()) self.feature_cache_origin[gpu_id] = output def _hook_pruned_feature(self, module, input, output): gpu_id = str(output.get_device()) self.feature_cache_pruned[gpu_id] = output @staticmethod def _concat_gpu_data(data): data_cat = data["0"] for i in range(1, len(data)): data_cat = torch.cat((data_cat, data[str(i)].cuda(0))) return data_cat def _layer_channel_selection(self, net_origin, net_pruned, aux_fc, module, block_count, layer_name="conv2"): """ conduct channel selection for module :param net_origin: original network segments :param net_pruned: pruned network segments :param aux_fc: auxiliary fully-connected layer :param module: the module need to be pruned :param block_count: current block no. :param layer_name: the name of layer need to be pruned """ self.logger.info( "|===>layer-wise channel selection: block-{}-{}".format( block_count, layer_name)) # layer-wise channel selection if layer_name == "conv2": layer = module.conv2 elif layer_name == "conv3": layer = module.conv3 else: assert False, "unsupport layer: {}".format(layer_name) if not isinstance(layer, MaskConv2d): temp_conv = MaskConv2d(in_channels=layer.in_channels, out_channels=layer.out_channels, kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding, bias=(layer.bias is not None)) temp_conv.weight.data.copy_(layer.weight.data) if layer.bias is not None: temp_conv.bias.data.copy_(layer.bias.data) temp_conv.pruned_weight.data.fill_(0) temp_conv.d.fill_(0) if layer_name == "conv2": module.conv2 = temp_conv elif layer_name == "conv3": module.conv3 = temp_conv layer = temp_conv # define criterion criterion_mse = nn.MSELoss().cuda() criterion_softmax = nn.CrossEntropyLoss().cuda() # register hook if layer_name == "conv2": hook_origin = net_origin[block_count].conv2.register_forward_hook( self._hook_origin_feature) hook_pruned = module.conv2.register_forward_hook( self._hook_pruned_feature) elif layer_name == "conv3": hook_origin = net_origin[block_count].conv3.register_forward_hook( self._hook_origin_feature) hook_pruned = module.conv3.register_forward_hook( self._hook_pruned_feature) net_origin_parallel = utils.data_parallel(net_origin, self.settings.n_gpus) net_pruned_parallel = utils.data_parallel(net_pruned, self.settings.n_gpus) # avoid computing the gradient for params in net_origin_parallel.parameters(): params.requires_grad = False for params in net_pruned_parallel.parameters(): params.requires_grad = False net_origin_parallel.eval() net_pruned_parallel.eval() layer.pruned_weight.requires_grad = True aux_fc.cuda() logger_counter = 0 record_time = utils.AverageMeter() for channel in range(layer.in_channels): if layer.d.eq(0).sum() <= math.floor( layer.in_channels * self.settings.pruning_rate): break time_start = time.time() cum_grad = None record_selection_mse_loss = utils.AverageMeter() 
record_selection_softmax_loss = utils.AverageMeter() record_selection_loss = utils.AverageMeter() img_count = 0 for i, (images, labels) in enumerate(self.train_loader): images = images.cuda() labels = labels.cuda() net_origin_parallel(images) output = net_pruned_parallel(images) softmax_loss = criterion_softmax(aux_fc(output), labels) origin_feature = self._concat_gpu_data( self.feature_cache_origin) self.feature_cache_origin = {} pruned_feature = self._concat_gpu_data( self.feature_cache_pruned) self.feature_cache_pruned = {} mse_loss = criterion_mse(pruned_feature, origin_feature) loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight loss.backward() record_selection_loss.update(loss.item(), images.size(0)) record_selection_mse_loss.update(mse_loss.item(), images.size(0)) record_selection_softmax_loss.update(softmax_loss.item(), images.size(0)) if cum_grad is None: cum_grad = layer.pruned_weight.grad.data.clone() else: cum_grad.add_(layer.pruned_weight.grad.data) layer.pruned_weight.grad = None img_count += images.size(0) if self.settings.max_samples != -1 and img_count >= self.settings.max_samples: break # write tensorboard log self.tensorboard_logger.scalar_summary( tag="S-block-{}_{}_LossAll".format(block_count, layer_name), value=record_selection_loss.avg, step=logger_counter) self.tensorboard_logger.scalar_summary( tag="S-block-{}_{}_MSELoss".format(block_count, layer_name), value=record_selection_mse_loss.avg, step=logger_counter) self.tensorboard_logger.scalar_summary( tag="S-block-{}_{}_SoftmaxLoss".format(block_count, layer_name), value=record_selection_softmax_loss.avg, step=logger_counter) cum_grad.abs_() # calculate gradient F norm grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0) # find grad_fnorm with maximum absolute gradient while True: _, max_index = torch.topk(grad_fnorm, 1) if layer.d[max_index[0]] == 0: layer.d[max_index[0]] = 1 layer.pruned_weight.data[:, max_index[ 0], :, :] = layer.weight[:, max_index[0], :, :].data.clone( ) break else: grad_fnorm[max_index[0]] = -1 # fine-tune average meter record_finetune_softmax_loss = utils.AverageMeter() record_finetune_mse_loss = utils.AverageMeter() record_finetune_loss = utils.AverageMeter() record_finetune_top1_error = utils.AverageMeter() record_finetune_top5_error = utils.AverageMeter() # define optimizer params_list = [] params_list.append({ "params": layer.pruned_weight, "lr": self.settings.layer_wise_lr }) if layer.bias is not None: layer.bias.requires_grad = True params_list.append({"params": layer.bias, "lr": 0.001}) optimizer = torch.optim.SGD( params=params_list, weight_decay=self.settings.weight_decay, momentum=self.settings.momentum, nesterov=True) img_count = 0 for epoch in range(1): for i, (images, labels) in enumerate(self.train_loader): images = images.cuda() labels = labels.cuda() features = net_pruned_parallel(images) net_origin_parallel(images) output = aux_fc(features) softmax_loss = criterion_softmax(output, labels) origin_feature = self._concat_gpu_data( self.feature_cache_origin) self.feature_cache_origin = {} pruned_feature = self._concat_gpu_data( self.feature_cache_pruned) self.feature_cache_pruned = {} mse_loss = criterion_mse(pruned_feature, origin_feature) top1_error, _, top5_error = utils.compute_singlecrop( outputs=output, labels=labels, loss=softmax_loss, top5_flag=True, mean_flag=True) # update parameters optimizer.zero_grad() loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight loss.backward() 
torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=10.0) layer.pruned_weight.grad.data.mul_( layer.d.unsqueeze(0).unsqueeze(2).unsqueeze( 3).expand_as(layer.pruned_weight)) optimizer.step() # update record info record_finetune_softmax_loss.update( softmax_loss.item(), images.size(0)) record_finetune_mse_loss.update(mse_loss.item(), images.size(0)) record_finetune_loss.update(loss.item(), images.size(0)) record_finetune_top1_error.update(top1_error, images.size(0)) record_finetune_top5_error.update(top5_error, images.size(0)) img_count += images.size(0) if self.settings.max_samples != -1 and img_count >= self.settings.max_samples: break layer.pruned_weight.grad = None if layer.bias is not None: layer.bias.requires_grad = False self.tensorboard_logger.scalar_summary( tag="F-block-{}_{}_SoftmaxLoss".format(block_count, layer_name), value=record_finetune_softmax_loss.avg, step=logger_counter) self.tensorboard_logger.scalar_summary( tag="F-block-{}_{}_Loss".format(block_count, layer_name), value=record_finetune_loss.avg, step=logger_counter) self.tensorboard_logger.scalar_summary( tag="F-block-{}_{}_MSELoss".format(block_count, layer_name), value=record_finetune_mse_loss.avg, step=logger_counter) self.tensorboard_logger.scalar_summary( tag="F-block-{}_{}_Top1Error".format(block_count, layer_name), value=record_finetune_top1_error.avg, step=logger_counter) self.tensorboard_logger.scalar_summary( tag="F-block-{}_{}_Top5Error".format(block_count, layer_name), value=record_finetune_top5_error.avg, step=logger_counter) # write log information to file self._write_log( dir_name=os.path.join(self.settings.save_path, "log"), file_name="log_block-{:0>2d}_{}.txt".format( block_count, layer_name), log_str= "{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n". 
format(int(layer.d.sum()), record_selection_loss.avg, record_selection_mse_loss.avg, record_selection_softmax_loss.avg, record_finetune_loss.avg, record_finetune_mse_loss.avg, record_finetune_softmax_loss.avg, record_finetune_top1_error.avg, record_finetune_top5_error.avg)) log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format( block_count, layer_name, int(layer.d.sum()), layer.d.size(0)) log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format( record_selection_loss.avg, record_selection_mse_loss.avg, record_selection_softmax_loss.avg) log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format( record_finetune_loss.avg, record_finetune_mse_loss.avg, record_finetune_softmax_loss.avg) log_str += "top1error: {:4f}\ttop5error: {:4f}".format( record_finetune_top1_error.avg, record_finetune_top5_error.avg) self.logger.info(log_str) logger_counter += 1 time_interval = time.time() - time_start record_time.update(time_interval) for params in net_origin_parallel.parameters(): params.requires_grad = True for params in net_pruned_parallel.parameters(): params.requires_grad = True # remove hook hook_origin.remove() hook_pruned.remove() log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format( block_count, layer_name, str(datetime.timedelta(seconds=record_time.sum)), str(datetime.timedelta(seconds=record_time.avg))) self.logger.info(log_str) log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format( record_finetune_loss.avg, record_finetune_mse_loss.avg, record_finetune_softmax_loss.avg, record_finetune_top1_error.avg, record_finetune_top5_error.avg) self.logger.info(log_str) self.logger.info("|===>remove hook") @staticmethod def _write_log(dir_name, file_name, log_str): """ Write log to file :param dir_name: the path of directory :param file_name: the name of the saved file :param log_str: the string that need to be saved """ if not os.path.isdir(dir_name): os.mkdir(dir_name) with open(os.path.join(dir_name, file_name), "a+") as f: f.write(log_str) def _seg_channel_selection(self, net_origin, net_pruned, aux_fc, pivot_index, index): """ conduct segment channel selection :param net_origin: original network segments :param net_pruned: pruned network segments :param aux_fc: auxiliary fully-connected layer :param pivot_index: the layer index of the additional loss :param index: the index of segment :return: """ block_count = 0 if self.settings.net_type in ["preresnet", "resnet"]: for module in net_pruned.modules(): if isinstance(module, (PreBasicBlock, BasicBlock)): block_count += 1 # We will not prune the pruned blocks again if not isinstance(module.conv2, MaskConv2d): self._layer_channel_selection(net_origin=net_origin, net_pruned=net_pruned, aux_fc=aux_fc, module=module, block_count=block_count, layer_name="conv2") self.logger.info("|===>checking layer type: {}".format( type(module.conv2))) self.checkpoint.save_model( self.ori_model, self.pruned_model, self.segment_wise_trainer.aux_fc, pivot_index, channel_selection=True, index=index, block_count=block_count) self.checkpoint.save_checkpoint( self.ori_model, self.pruned_model, self.segment_wise_trainer.aux_fc, self.segment_wise_trainer.fc_optimizer, self.segment_wise_trainer.seg_optimizer, pivot_index, channel_selection=True, index=index, block_count=block_count) elif isinstance(module, Bottleneck): block_count += 1 if not isinstance(module.conv2, MaskConv2d): 
self._layer_channel_selection(net_origin=net_origin, net_pruned=net_pruned, aux_fc=aux_fc, module=module, block_count=block_count, layer_name="conv2") if not isinstance(module.conv3, MaskConv2d): self._layer_channel_selection(net_origin=net_origin, net_pruned=net_pruned, aux_fc=aux_fc, module=module, block_count=block_count, layer_name="conv3") self.checkpoint.save_model( self.ori_model, self.pruned_model, self.segment_wise_trainer.aux_fc, pivot_index, channel_selection=True, index=index, block_count=block_count) self.checkpoint.save_checkpoint( self.ori_model, self.pruned_model, self.segment_wise_trainer.aux_fc, self.segment_wise_trainer.fc_optimizer, self.segment_wise_trainer.seg_optimizer, pivot_index, channel_selection=True, index=index, block_count=block_count)
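# A minimal driver sketch for the channel-selection pipeline above
# (illustration only; the entry point and the config path are hypothetical).
if __name__ == "__main__":
    experiment = Experiment(conf_path="channel_selection.conf")  # hypothetical config file
    experiment.channel_selection()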