def create_mtcnn_net(self, p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True):
    """Build the PNet/RNet/ONet stages of MTCNN from saved weight files.

    Args:
        p_model_path: path to the PNet weights, or None to skip PNet.
        r_model_path: path to the RNet weights, or None to skip RNet.
        o_model_path: path to the ONet weights, or None to skip ONet.
        use_cuda: move loaded networks to GPU when one is available.

    Returns:
        Tuple ``(pnet, rnet, onet)``; each entry is None when the
        corresponding path was not supplied.
    """
    self.device = torch.device(
        "cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")

    # Bug fix: the original unconditionally called os.path.split(p_model_path),
    # which raised TypeError whenever the default p_model_path=None was used.
    dirname = os.path.split(p_model_path)[0] if p_model_path is not None else ""
    checkpoint = CheckPoint(dirname)

    def _load(net, model_path):
        # Restore weights into `net`, optionally move it to self.device,
        # and switch to inference mode.
        model_state = checkpoint.load_model(model_path)
        net = checkpoint.load_state(net, model_state)
        if use_cuda:
            net.to(self.device)
        net.eval()
        return net

    pnet = _load(PNet(), p_model_path) if p_model_path is not None else None
    rnet = _load(RNet(), r_model_path) if r_model_path is not None else None
    onet = _load(ONet(), o_model_path) if o_model_path is not None else None
    return pnet, rnet, onet
class ExperimentDesign:
    """Experiment pipeline for training and channel-pruning sphere-face nets.

    Builds data loaders, model, checkpointing, LR policy and trainer from
    `Option` settings, then alternates training / fine-tuning / pruning in
    `run()`.
    """

    def __init__(self, options=None):
        # Use caller-provided options, falling back to the default Option().
        self.settings = options or Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None
        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.model_analyse = None
        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)
        self.test_input = None
        self.lr_master = None
        # Eagerly wire up the whole pipeline; __init__ has side effects
        # (GPU selection, data loading, model construction).
        self.prepare()

    def prepare(self):
        """Run all setup steps in dependency order."""
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_parallel()
        self._set_lr_policy()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        """Seed RNGs and select the configured CUDA device."""
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        print("|===>Set GPU done!")

    def _set_dataloader(self):
        """Create the project DataLoader wrapper from settings."""
        # create data loader
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)
        print("|===>Set data loader done!")

    def _set_checkpoint(self):
        """Create CheckPoint and optionally restore retrain/resume state."""
        assert self.model is not None, "please create model first"
        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            # retrain: load weights only, start training from scratch
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)
        if self.settings.resume is not None:
            # resume: weights are restored but epoch/optimizer restore is
            # deliberately disabled below (commented out).
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            # self.start_epoch = epoch
            # self.optimizer_state = optimizer_state
        print("|===>Set checkpoint done!")

    def _set_model(self):
        """Instantiate the network named by settings.netType (sphere data only)."""
        if self.settings.dataset == "sphere":
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "wcSphereNet":
                # wcSphereNet additionally takes the pruning rate.
                self.model = md.wcSphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim,
                    rate=self.settings.rate)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupport data set: " + self.settings.dataset
        print("|===>Set model done!")

    def _set_parallel(self):
        """Wrap the model for multi-GPU execution via the project helper."""
        self.model = utils.data_parallel(self.model, self.settings.nGPU,
                                         self.settings.GPU)

    def _set_lr_policy(self):
        """(Re)build the learning-rate schedule from current settings."""
        self.lr_master = utils.LRPolicy(self.settings.lr,
                                        self.settings.nIters,
                                        self.settings.lrPolicy)
        params_dict = {
            'gamma': self.settings.gamma,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }
        self.lr_master.set_params(params_dict=params_dict)

    def _set_trainer(self):
        """Create the Trainer bound to the loaders, model and LR policy."""
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               lr_master=self.lr_master,
                               n_epochs=self.settings.nEpochs,
                               n_iters=self.settings.nIters,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               feature_dim=self.settings.featureDim,
                               momentum=self.settings.momentum,
                               weight_decay=self.settings.weightDecay,
                               optimizer_state=self.optimizer_state,
                               logger=self.logger)

    def _draw_net(self):
        """Create a fixed random test input and optionally plot the network."""
        # visualize model
        if self.settings.dataset == "sphere":
            # 112x96 matches the sphere-face crop size used by this model.
            rand_input = torch.randn(1, 3, 112, 96)
        else:
            assert False, "invalid data set"
        rand_input = Variable(rand_input.cuda())
        # Kept for later FLOPs analysis in pruning()/training().
        self.test_input = rand_input
        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(rand_input)
            self.visualize.save_network(rand_output)
            print("|===>Draw network done!")
        self.visualize.write_settings(self.settings)

    def pruning(self, run_count=0):
        """Prune the current model, re-test it and save/analyse the result.

        Args:
            run_count: index of the current prune/fine-tune cycle, used to
                tag saved model files.
        """
        net_type = None
        if self.settings.dataset == "sphere":
            if self.settings.netType == "wcSphereNet":
                net_type = "SphereNet"
        assert net_type is not None, "net_type for prune is NoneType"
        # Baseline accuracy before pruning.
        self.trainer.test()
        if isinstance(self.model, nn.DataParallel):
            # Unwrap so the pruner sees the raw module.
            model = self.model.module
        else:
            model = self.model
        if net_type == "SphereNet":
            model_prune = prune.SpherePrune(model)
        model_prune.run()
        # Swap the pruned model into the trainer and measure accuracy drop.
        self.trainer.reset_model(model_prune.model)
        self.model = self.trainer.model
        self.trainer.test()
        self.checkpoint.save_model(self.trainer.model,
                                   index=run_count, tag="pruning")
        # analyse model
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        params_num = self.model_analyse.params_count()
        self.model_analyse.flops_compute(self.test_input)

    def fine_tuning(self, run_count=0):
        """Short training run with a reduced LR schedule after pruning."""
        # set lr
        self.settings.lr = 0.01  # 0.1
        self.settings.nIters = 12000  # 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [6000]  # [16000, 24000]
        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)
        # run fine-tuning
        self.training(run_count, tag="fine-tuning")

    def retrain(self, run_count=0):
        """Full-length training run with the standard LR schedule."""
        self.settings.lr = 0.1
        self.settings.nIters = 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [16000, 24000]
        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)
        # run retrain
        self.training(run_count, tag="training")

    def run(self, run_count=0):
        """One train/fine-tune/prune cycle, then freeze classifier layers
        and wrap Conv2d layers with prunable wcConv2d replacements.

        Args:
            run_count: cycle index; full retraining only happens from the
                second cycle on (run_count >= 1).
        """
        # NOTE(review): nesting below is reconstructed from a collapsed
        # source — fine-tuning/pruning appear to run every cycle while
        # retrain only runs for run_count >= 1; confirm against history.
        if run_count >= 1:
            print("|===> training")
            self.retrain(run_count)
            self.trainer.reset_model(self.model)
        print("|===> fine-tuning")
        self.fine_tuning(run_count)
        print("|===> pruning")
        self.pruning(run_count)
        # keep margin_linear static
        layer_count = 0
        for layer in self.model.modules():
            if isinstance(layer, md.MarginLinear):
                # Freeze the margin layer: reset its schedule and weights.
                layer.iteration.fill_(0)
                layer.margin_type.data.fill_(1)
                layer.weight.requires_grad = False
            elif isinstance(layer, nn.Linear):
                layer.weight.requires_grad = False
                layer.bias.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                if layer.bias is not None:
                    bias_flag = True
                else:
                    bias_flag = False
                # Rebuild this conv as a width-compressible wcConv2d with
                # identical geometry and copied weights.
                new_layer = prune.wcConv2d(layer.weight.size(1),
                                           layer.weight.size(0),
                                           kernel_size=layer.kernel_size,
                                           stride=layer.stride,
                                           padding=layer.padding,
                                           bias=bias_flag,
                                           rate=self.settings.rate)
                new_layer.weight.data.copy_(layer.weight.data)
                if layer.bias is not None:
                    new_layer.bias.data.copy_(layer.bias.data)
                # Only conv2..conv4 are replaced; conv1 (layer_count == 0)
                # is intentionally left untouched.
                if layer_count == 1:
                    self.model.conv2 = new_layer
                elif layer_count == 2:
                    self.model.conv3 = new_layer
                elif layer_count == 3:
                    self.model.conv4 = new_layer
                layer_count += 1
        print(self.model)
        self.trainer.reset_model(self.model)
        # assert False

    def training(self, run_count=0, tag="training"):
        """Train until nEpochs or the trainer's iteration budget is spent.

        Args:
            run_count: cycle index used in log lines and saved-model tags.
            tag: label ("training"/"fine-tuning") for saved best models.

        Returns:
            Best top-1 error (percent) observed over the run.
        """
        best_top1 = 100
        # start_time = time.time()
        self.trainer.test()
        # assert False
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.trainer.n_iters:
                break
            # NOTE(review): this local assignment is dead code — the sibling
            # pipeline resets `self.start_epoch = 0` here instead; likely a
            # missing `self.` — confirm before changing behavior.
            start_epoch = 0
            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()
            # Accuracy is in [0, 1]; convert to a percent error.
            test_error_mean = 100 - acc_mean * 100
            # write and print result
            log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                epoch, train_error, train_loss, train5_error,
                acc_mean, acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc
            self.visualize.write_log(log_str)
            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f\n" % (
                            run_count, best_top1), "red"))
            else:
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f\n" % (
                            run_count, best_top1), "blue"))
            self.checkpoint.save_checkpoint(self.model,
                                            self.trainer.optimizer, epoch)
            if best_flag:
                self.checkpoint.save_model(self.model, best_flag=best_flag,
                                           tag="%s_%d" % (tag, run_count))
            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()
            # Histogram summaries of weights (and grads when present) for
            # tensorboard; '.' is illegal in summary names, hence replace.
            for name, value in self.model.named_parameters():
                if 'weight' in name:
                    name = name.replace('.', '/')
                    self.logger.histo_summary(
                        name, value.data.cpu().numpy(),
                        run_count * self.settings.nEpochs + epoch + 1)
                    if value.grad is not None:
                        self.logger.histo_summary(
                            name + "/grad", value.grad.data.cpu().numpy(),
                            run_count * self.settings.nEpochs + epoch + 1)
        # end_time = time.time()
        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()
        # compute cost time
        # time_interval = end_time - start_time
        # t_string = "Running Time is: " + \
        #     str(datetime.timedelta(seconds=time_interval)) + "\n"
        # print(t_string)
        # write cost time to file
        # self.visualize.write_readme(t_string)
        # save experimental results
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f\n" % best_top1)
        # analyse model
        params_num = self.model_analyse.params_count()
        # save analyse result to file
        self.visualize.write_readme("Number of parameters is: %d" % (params_num))
        self.model_analyse.prune_rate()
        self.model_analyse.flops_compute(self.test_input)
        return best_top1
class ExperimentDesign(object):
    """Run experiments with a pre-defined pipeline.

    Classification pipeline (CIFAR-10/100) with auxiliary classifiers:
    builds loaders, model, checkpointing and trainer from `Option`
    settings, then trains/tests in `run()` while saving the best model
    and per-epoch checkpoints (including segment/FC optimizer state).
    """

    def __init__(self):
        self.settings = Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None
        self.trainer = None
        # Optimizer / aux-classifier state restored from a resume file.
        self.seg_opt_state = None
        self.fc_opt_state = None
        self.aux_fc_state = None
        self.start_epoch = 0
        self.model_analyse = None
        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)
        # Eagerly wire up the whole pipeline (side effects: GPU, data, model).
        self.prepare()

    def prepare(self):
        """Run all setup steps in dependency order."""
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        """Seed RNGs, select the CUDA device and enable cudnn autotuning."""
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """Create the project DataLoader wrapper from settings."""
        # create data loader
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)

    def _set_checkpoint(self):
        """Create CheckPoint and optionally restore retrain/resume state."""
        assert self.model is not None, "please create model first"
        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            # retrain: restore weights only.
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)
        if self.settings.resume is not None:
            # resume: restore weights plus optimizer/aux-FC state; the
            # start epoch is hard-coded to 90 here.
            check_point_params = torch.load(self.settings.resume)
            model_state = check_point_params["model"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.fc_opt_state = check_point_params["fc_opt"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = 90

    def _set_model(self):
        """Instantiate the network named by settings.netType (CIFAR only)."""
        print("netType:", self.settings.netType)
        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.netType == "DARTSNet":
                genotype = md.genotypes.DARTS
                self.model = md.DARTSNet(self.settings.init_channels,
                                         self.settings.nClasses,
                                         self.settings.layers,
                                         self.settings.auxiliary,
                                         genotype)
            elif self.settings.netType == "PreResNet":
                self.model = md.PreResNet(depth=self.settings.depth,
                                          num_classes=self.settings.nClasses,
                                          wide_factor=self.settings.wideFactor)
            elif self.settings.netType == "CifarResNeXt":
                self.model = md.CifarResNeXt(self.settings.cardinality,
                                             self.settings.depth,
                                             self.settings.nClasses,
                                             self.settings.base_width,
                                             self.settings.widen_factor)
            elif self.settings.netType == "ResNet":
                self.model = md.ResNet(self.settings.depth,
                                       self.settings.nClasses)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupported data set: " + self.settings.dataset

    def _set_trainer(self):
        """Create the LR policy and the Trainer, resuming optimizer state."""
        # set lr master
        lr_master = utils.LRPolicy(self.settings.lr,
                                   self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }
        lr_master.set_params(params_dict=params_dict)
        # set trainer
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger)
        if self.seg_opt_state is not None:
            # Push resumed optimizer / aux-FC state into the trainer.
            self.trainer.resume(aux_fc_state=self.aux_fc_state,
                                seg_opt_state=self.seg_opt_state,
                                fc_opt_state=self.fc_opt_state)

    def _save_best_model(self, model, aux_fc):
        """Save model + auxiliary-classifier weights as the best snapshot.

        Args:
            model: the (possibly DataParallel-wrapped) network.
            aux_fc: list of auxiliary classifier modules.
        """
        check_point_params = {}
        # Always store the unwrapped module's state_dict so the file can be
        # loaded without DataParallel.
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        aux_fc_state = []
        for fc in aux_fc:
            if isinstance(fc, nn.DataParallel):
                aux_fc_state.append(fc.module.state_dict())
            else:
                aux_fc_state.append(fc.state_dict())
        check_point_params["aux_fc"] = aux_fc_state
        torch.save(check_point_params,
                   os.path.join(self.checkpoint.save_path,
                                "best_model_with_AuxFC.pth"))

    def _save_checkpoint(self, model, seg_optimizers, fc_optimizers,
                         aux_fc, index=0):
        """Save a full training checkpoint (model + all optimizer states).

        Args:
            model: the (possibly DataParallel-wrapped) network.
            seg_optimizers: optimizers for network segments.
            fc_optimizers: optimizers for the auxiliary classifiers.
            aux_fc: list of auxiliary classifier modules.
            index: checkpoint file index.
        """
        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        seg_opt_state = []
        fc_opt_state = []
        aux_fc_state = []
        for seg_opt in seg_optimizers:
            seg_opt_state.append(seg_opt.state_dict())
        # NOTE(review): reconstructed from a collapsed source — aux_fc state
        # is collected inside the fc-optimizer loop (assumes
        # len(fc_optimizers) == len(aux_fc)); confirm against history.
        for i in range(len(fc_optimizers)):
            fc_opt_state.append(fc_optimizers[i].state_dict())
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())
        check_point_params["seg_opt"] = seg_opt_state
        check_point_params["fc_opt"] = fc_opt_state
        check_point_params["aux_fc"] = aux_fc_state
        torch.save(check_point_params,
                   os.path.join(self.checkpoint.save_path,
                                "checkpoint_%03d.pth" % index))

    def run(self, run_count=0):
        """Training and testing loop.

        Args:
            run_count: experiment index used in log lines.
        """
        best_top1 = 100
        best_top5 = 100
        start_time = time.time()
        if self.settings.resume is not None or self.settings.retrain is not None:
            # Sanity-check restored weights before training.
            self.trainer.test(0)

        for epoch in range(self.start_epoch, self.settings.nEpochs):
            # Resume only shifts the first pass; later passes start at 0.
            self.start_epoch = 0
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)

            # write and print result
            if isinstance(train_error, np.ndarray):
                # Per-segment metrics: log every entry; the final segment
                # ([-1]) decides the best model.
                log_str = "%d\t" % (epoch)
                for i in range(len(train_error)):
                    log_str += "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                        train_error[i], train_loss[i], test_error[i],
                        test_loss[i], train5_error[i], test5_error[i],
                    )
                best_flag = False
                if best_top1 >= test_error[-1]:
                    best_top1 = test_error[-1]
                    best_top5 = test5_error[-1]
                    best_flag = True
            else:
                log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                    epoch, train_error, train_loss, test_error,
                    test_loss, train5_error, test5_error)
                best_flag = False
                if best_top1 >= test_error:
                    best_top1 = test_error
                    best_top5 = test5_error
                    best_flag = True

            self.visualize.write_log(log_str)
            if best_flag:
                self.checkpoint.save_model(self.trainer.model,
                                           best_flag=best_flag)
                self._save_best_model(self.trainer.model, self.trainer.auxfc)
                # Bug fix: these were Python 2 `print colored(...)`
                # statements — a SyntaxError under Python 3 (the rest of the
                # file uses the print() function).
                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n" % (
                        run_count, best_top1, best_top5), "red"))
            else:
                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n" % (
                        run_count, best_top1, best_top5), "blue"))
            self._save_checkpoint(self.trainer.model,
                                  self.trainer.seg_optimizer,
                                  self.trainer.fc_optimizer,
                                  self.trainer.auxfc)

        end_time = time.time()
        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        print(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)
        # save experimental results
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f, Top5 Error: %f\n" % (
                best_top1, best_top5))
class ExperimentDesign:
    """Sphere-face training pipeline with file/console logging.

    Configures CUDA visibility, a `logging` logger and tensorboard logging,
    then builds loaders, model, checkpointing and trainer from `Option`
    settings; `run()` trains/tests and analyses the final model.
    """

    def __init__(self, options=None):
        # Use caller-provided options, falling back to the default Option().
        self.settings = options or Option()
        self.checkpoint = None
        self.train_loader = None
        self.test_loader = None
        self.model = None
        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.test_input = None
        self.model_analyse = None
        # Pin device enumeration order and restrict visible GPUs before any
        # CUDA context is created.
        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices
        self.settings.set_save_path()
        self.logger = self.set_logger()
        self.settings.paramscheck(self.logger)
        self.visualize = vs.Visualization(self.settings.save_path, self.logger)
        self.tensorboard_logger = vs.Logger(self.settings.save_path)
        # Eagerly wire up the whole pipeline (side effects: GPU, data, model).
        self.prepare()

    def set_logger(self):
        """Build the 'sphereface' logger writing to both a file and stdout.

        Returns:
            The configured logging.Logger instance.
        """
        logger = logging.getLogger('sphereface')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)
        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """Run all setup steps in dependency order."""
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        """Seed RNGs, select the CUDA device and enable cudnn autotuning."""
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """Create the data loader and unpack its train/test loaders."""
        # create data loader
        data_loader = DataLoader(dataset=self.settings.dataset,
                                 batch_size=self.settings.batchSize,
                                 data_path=self.settings.dataPath,
                                 n_threads=self.settings.nThreads,
                                 ten_crop=self.settings.tenCrop,
                                 logger=self.logger)
        self.train_loader, self.test_loader = data_loader.getloader()

    def _set_checkpoint(self):
        """Create CheckPoint and optionally restore retrain/resume state."""
        assert self.model is not None, "please create model first"
        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        if self.settings.retrain is not None:
            # retrain: restore weights only.
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)
        if self.settings.resume is not None:
            # resume: restore weights, optimizer state and the epoch to
            # continue from.
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = epoch
            self.optimizer_state = optimizer_state

    def _set_model(self):
        """Instantiate the network named by settings.netType (sphere data)."""
        if self.settings.dataset in ["sphere", "sphere_large"]:
            # Fixed random probe input (112x96 sphere-face crop) reused for
            # network drawing and model analysis.
            self.test_input = Variable(torch.randn(1, 3, 112, 96).cuda())
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "SphereMobileNet_v2":
                # Name as spelled in the models package ("Moble").
                self.model = md.SphereMobleNet_v2(
                    num_features=self.settings.featureDim)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupport data set: " + self.settings.dataset

    def _set_trainer(self):
        """Create the LR policy and the Trainer bound to loaders/loggers."""
        lr_master = utils.LRPolicy(self.settings.lr,
                                   self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }
        lr_master.set_params(params_dict=params_dict)
        self.trainer = Trainer(model=self.model,
                               train_loader=self.train_loader,
                               test_loader=self.test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger,
                               tensorboard_logger=self.tensorboard_logger,
                               optimizer_state=self.optimizer_state)

    def _draw_net(self):
        """Optionally run the probe input through the net and plot the graph."""
        # visualize model
        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(self.test_input)
            self.visualize.save_network(rand_output)
            self.logger.info("|===>Draw network done!")
        self.visualize.write_settings(self.settings)

    def _model_analyse(self, model):
        """Log parameter count, zero rate and MAdds for `model`.

        Args:
            model: the network to analyse.
        """
        # analyse model
        model_analyse = utils.ModelAnalyse(model, self.visualize)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))
        # save analyse result to file
        self.visualize.write_readme(
            "Number of parameters is: %d, number of zeros is: %d, zero rate is: %f" % (
                params_num, zero_num, zero_rate))
        # model_analyse.flops_compute(self.test_input)
        model_analyse.madds_compute(self.test_input)

    def run(self, run_count=0):
        """Train until nEpochs or the iteration budget, tracking best top-1.

        Args:
            run_count: experiment index used in log lines.

        Returns:
            Best top-1 error (percent) observed over the run.
        """
        self.logger.info("|===>Start training")
        best_top1 = 100
        start_time = time.time()
        # self._model_analyse(self.model)
        # assert False
        self.trainer.test()
        # assert False
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.settings.nIters:
                break
            # Resume only shifts the first pass; later passes start at 0.
            self.start_epoch = 0
            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()
            # Accuracy is in [0, 1]; convert to a percent error.
            test_error_mean = 100 - acc_mean * 100
            # write and print result
            log_str = "{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                epoch, train_error, train_loss, train5_error,
                acc_mean, acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc
            self.visualize.write_log(log_str)
            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                self.logger.info(
                    "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                        run_count, best_top1))
            else:
                self.logger.info(
                    "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                        run_count, best_top1))
            self.checkpoint.save_checkpoint(self.model,
                                            self.trainer.optimizer, epoch)
            if best_flag:
                self.checkpoint.save_model(self.model, best_flag=best_flag)
            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()
        end_time = time.time()
        if isinstance(self.model, nn.DataParallel):
            # Unwrap before analysis so layer attributes are reachable.
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()
        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        self.logger.info(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)
        # analyse model
        self._model_analyse(self.model)
        return best_top1