# Set model
model = LeNet_5()
model = model.to(device)

# Set checkpoint
checkpoint = CheckPoint(config.save_path)

# Set optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = None

# Set trainer
logger = Logger(config.save_path)
trainer = LeNet_5Trainer(config.lr, train_loader, valid_x, valid_y, model,
                         optimizer, scheduler, logger, device)

epoch_dict = 1
model_dict, optimizer_dict, epoch_dict = checkpoint.load_checkpoint(
    os.path.join(checkpoint.save_path, 'checkpoint_005.pth'))
model.load_state_dict(model_dict)
optimizer.load_state_dict(optimizer_dict)

for epoch in range(epoch_dict, config.nEpochs + 1):
    cls_loss_, accuracy = trainer.train(epoch)
    checkpoint.save_checkpoint(model, optimizer, epoch=epoch, index=epoch,
                               tag="123123")
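
# For reference, a minimal sketch of the save/load mechanics that the CheckPoint
# class above presumably wraps: plain torch.save/torch.load of the model and
# optimizer state dicts plus the epoch counter. The dictionary keys and the
# "checkpoint_%03d.pth" naming are assumptions for illustration, not necessarily
# the repo's actual on-disk format.
import os

import torch


def save_checkpoint_sketch(model, optimizer, epoch, save_path, index=0):
    """Persist model/optimizer state dicts together with the current epoch."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(state, os.path.join(save_path, "checkpoint_%03d.pth" % index))


def load_checkpoint_sketch(path):
    """Return (model_state, optimizer_state, epoch) from a saved checkpoint file."""
    state = torch.load(path, map_location="cpu")
    return state["model"], state["optimizer"], state["epoch"]
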
class ExperimentDesign:
    def __init__(self, options=None):
        self.settings = options or Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)
        self.test_input = None
        self.lr_master = None

        self.prepare()

    def prepare(self):
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_parallel()
        self._set_lr_policy()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, \
            "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        print("|===>Set GPU done!")

    def _set_dataloader(self):
        # create data loader
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)
        print("|===>Set data loader done!")

    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            # self.start_epoch = epoch
            # self.optimizer_state = optimizer_state
        print("|===>Set checkpoint done!")

    def _set_model(self):
        if self.settings.dataset == "sphere":
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "wcSphereNet":
                self.model = md.wcSphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim,
                    rate=self.settings.rate)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupported data set: " + self.settings.dataset
        print("|===>Set model done!")

    def _set_parallel(self):
        self.model = utils.data_parallel(self.model, self.settings.nGPU,
                                         self.settings.GPU)

    def _set_lr_policy(self):
        self.lr_master = utils.LRPolicy(self.settings.lr,
                                        self.settings.nIters,
                                        self.settings.lrPolicy)
        params_dict = {
            'gamma': self.settings.gamma,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }
        self.lr_master.set_params(params_dict=params_dict)

    def _set_trainer(self):
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               lr_master=self.lr_master,
                               n_epochs=self.settings.nEpochs,
                               n_iters=self.settings.nIters,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               feature_dim=self.settings.featureDim,
                               momentum=self.settings.momentum,
                               weight_decay=self.settings.weightDecay,
                               optimizer_state=self.optimizer_state,
                               logger=self.logger)

    def _draw_net(self):
        # visualize model
        if self.settings.dataset == "sphere":
            rand_input = torch.randn(1, 3, 112, 96)
        else:
            assert False, "invalid data set"

        rand_input = Variable(rand_input.cuda())
        self.test_input = rand_input

        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(rand_input)
            self.visualize.save_network(rand_output)
            print("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def pruning(self, run_count=0):
        net_type = None
        if self.settings.dataset == "sphere":
            if self.settings.netType == "wcSphereNet":
                net_type = "SphereNet"
        assert net_type is not None, "net_type for prune is NoneType"

        self.trainer.test()

        if isinstance(self.model, nn.DataParallel):
            model = self.model.module
        else:
            model = self.model

        if net_type == "SphereNet":
            model_prune = prune.SpherePrune(model)
        model_prune.run()
        self.trainer.reset_model(model_prune.model)
        self.model = self.trainer.model

        self.trainer.test()
        self.checkpoint.save_model(self.trainer.model, index=run_count,
                                   tag="pruning")

        # analyse model
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        params_num = self.model_analyse.params_count()
        self.model_analyse.flops_compute(self.test_input)

    def fine_tuning(self, run_count=0):
        # set lr
        self.settings.lr = 0.01  # 0.1
        self.settings.nIters = 12000  # 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [6000]  # [16000, 24000]
        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

        # run fine-tuning
        self.training(run_count, tag="fine-tuning")

    def retrain(self, run_count=0):
        self.settings.lr = 0.1
        self.settings.nIters = 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [16000, 24000]
        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

        # run retrain
        self.training(run_count, tag="training")

    def run(self, run_count=0):
        """
        if run_count == 0:
            print "|===> training"
            self.retrain(run_count)
        else:
            print "|===> fine-tuning"
            self.fine_tuning(run_count)
        """
        if run_count >= 1:
            print("|===> training")
            self.retrain(run_count)
            self.trainer.reset_model(self.model)

        print("|===> fine-tuning")
        self.fine_tuning(run_count)

        print("|===> pruning")
        self.pruning(run_count)

        # keep margin_linear static
        layer_count = 0
        for layer in self.model.modules():
            if isinstance(layer, md.MarginLinear):
                layer.iteration.fill_(0)
                layer.margin_type.data.fill_(1)
                layer.weight.requires_grad = False
            elif isinstance(layer, nn.Linear):
                layer.weight.requires_grad = False
                layer.bias.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                if layer.bias is not None:
                    bias_flag = True
                else:
                    bias_flag = False
                new_layer = prune.wcConv2d(layer.weight.size(1),
                                           layer.weight.size(0),
                                           kernel_size=layer.kernel_size,
                                           stride=layer.stride,
                                           padding=layer.padding,
                                           bias=bias_flag,
                                           rate=self.settings.rate)
                new_layer.weight.data.copy_(layer.weight.data)
                if layer.bias is not None:
                    new_layer.bias.data.copy_(layer.bias.data)

                if layer_count == 1:
                    self.model.conv2 = new_layer
                elif layer_count == 2:
                    self.model.conv3 = new_layer
                elif layer_count == 3:
                    self.model.conv4 = new_layer
                layer_count += 1

        print(self.model)
        self.trainer.reset_model(self.model)
        # assert False

    def training(self, run_count=0, tag="training"):
        best_top1 = 100
        # start_time = time.time()
        self.trainer.test()
        # assert False

        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.trainer.n_iters:
                break
            self.start_epoch = 0

            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()
            test_error_mean = 100 - acc_mean * 100

            # write and print result
            log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc
            self.visualize.write_log(log_str)

            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                print(colored("# %d ==>Best Result is: Top1 Error: %f\n" % (
                    run_count, best_top1), "red"))
            else:
                print(colored("# %d ==>Best Result is: Top1 Error: %f\n" % (
                    run_count, best_top1), "blue"))

            self.checkpoint.save_checkpoint(self.model,
                                            self.trainer.optimizer, epoch)

            if best_flag:
                self.checkpoint.save_model(self.model, best_flag=best_flag,
                                           tag="%s_%d" % (tag, run_count))

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()
                for name, value in self.model.named_parameters():
                    if 'weight' in name:
                        name = name.replace('.', '/')
                        self.logger.histo_summary(
                            name, value.data.cpu().numpy(),
                            run_count * self.settings.nEpochs + epoch + 1)
                        if value.grad is not None:
                            self.logger.histo_summary(
                                name + "/grad", value.grad.data.cpu().numpy(),
                                run_count * self.settings.nEpochs + epoch + 1)

        # end_time = time.time()
        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module

        # draw experimental curves
        self.visualize.draw_curves()

        # compute cost time
        # time_interval = end_time - start_time
        # t_string = "Running Time is: " + \
        #     str(datetime.timedelta(seconds=time_interval)) + "\n"
        # print(t_string)
        # write cost time to file
        # self.visualize.write_readme(t_string)

        # save experimental results
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f\n" % best_top1)

        # analyse model
        params_num = self.model_analyse.params_count()

        # save analyse result to file
        self.visualize.write_readme("Number of parameters is: %d" % params_num)
        self.model_analyse.prune_rate()
        self.model_analyse.flops_compute(self.test_input)
        return best_top1
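
# The fine_tuning()/retrain() methods above drive training with a "multi_step"
# learning-rate policy (decayRate=0.1 at the listed step iterations). Below is a
# minimal sketch of what utils.LRPolicy is assumed to compute for that policy;
# the helper name is hypothetical and only plain Python is used.
def multi_step_lr(base_lr, decay_rate, milestones, iteration):
    """Return the learning rate at `iteration` under a multi-step decay policy."""
    lr = base_lr
    for milestone in milestones:
        if iteration >= milestone:
            lr *= decay_rate
    return lr


# With the retrain() settings (lr=0.1, decayRate=0.1, step=[16000, 24000]):
assert multi_step_lr(0.1, 0.1, [16000, 24000], 0) == 0.1
assert abs(multi_step_lr(0.1, 0.1, [16000, 24000], 20000) - 0.01) < 1e-9
assert abs(multi_step_lr(0.1, 0.1, [16000, 24000], 25000) - 0.001) < 1e-9
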
class ExperimentDesign:
    def __init__(self, options=None):
        self.settings = options or Option()
        self.checkpoint = None
        self.train_loader = None
        self.test_loader = None
        self.model = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.test_input = None
        self.model_analyse = None

        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices

        self.settings.set_save_path()
        self.logger = self.set_logger()
        self.settings.paramscheck(self.logger)
        self.visualize = vs.Visualization(self.settings.save_path, self.logger)
        self.tensorboard_logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def set_logger(self):
        logger = logging.getLogger('sphereface')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')

        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, \
            "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        # create data loader
        data_loader = DataLoader(dataset=self.settings.dataset,
                                 batch_size=self.settings.batchSize,
                                 data_path=self.settings.dataPath,
                                 n_threads=self.settings.nThreads,
                                 ten_crop=self.settings.tenCrop,
                                 logger=self.logger)
        self.train_loader, self.test_loader = data_loader.getloader()

    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = epoch
            self.optimizer_state = optimizer_state

    def _set_model(self):
        if self.settings.dataset in ["sphere", "sphere_large"]:
            self.test_input = Variable(torch.randn(1, 3, 112, 96).cuda())
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)
            elif self.settings.netType == "SphereMobileNet_v2":
                self.model = md.SphereMobleNet_v2(
                    num_features=self.settings.featureDim)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupported data set: " + self.settings.dataset

    def _set_trainer(self):
        lr_master = utils.LRPolicy(self.settings.lr,
                                   self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }
        lr_master.set_params(params_dict=params_dict)

        self.trainer = Trainer(model=self.model,
                               train_loader=self.train_loader,
                               test_loader=self.test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger,
                               tensorboard_logger=self.tensorboard_logger,
                               optimizer_state=self.optimizer_state)

    def _draw_net(self):
        # visualize model
        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(self.test_input)
            self.visualize.save_network(rand_output)
            self.logger.info("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def _model_analyse(self, model):
        # analyse model
        model_analyse = utils.ModelAnalyse(model, self.visualize)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))

        # save analyse result to file
        self.visualize.write_readme(
            "Number of parameters is: %d, number of zeros is: %d, zero rate is: %f" %
            (params_num, zero_num, zero_rate))
        # model_analyse.flops_compute(self.test_input)
        model_analyse.madds_compute(self.test_input)

    def run(self, run_count=0):
        self.logger.info("|===>Start training")
        best_top1 = 100
        start_time = time.time()
        # self._model_analyse(self.model)
        # assert False

        self.trainer.test()
        # assert False

        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.settings.nIters:
                break
            self.start_epoch = 0

            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()
            test_error_mean = 100 - acc_mean * 100

            # write and print result
            log_str = "{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc
            self.visualize.write_log(log_str)

            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                self.logger.info(
                    "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                        run_count, best_top1))
            else:
                self.logger.info(
                    "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                        run_count, best_top1))

            self.checkpoint.save_checkpoint(self.model,
                                            self.trainer.optimizer, epoch)

            if best_flag:
                self.checkpoint.save_model(self.model, best_flag=best_flag)

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

        end_time = time.time()

        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module

        # draw experimental curves
        self.visualize.draw_curves()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        self.logger.info(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)

        # analyse model
        self._model_analyse(self.model)
        return best_top1
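
# A minimal sketch of the statistics utils.ModelAnalyse is assumed to report in
# _model_analyse() above: the total parameter count and the fraction of
# exactly-zero weights left behind by pruning. The function name is
# hypothetical; only standard PyTorch calls are used.
import torch.nn as nn


def count_params_and_zeros(model):
    """Return (params_num, zero_num, zero_rate) over all parameters of `model`."""
    params_num = 0
    zero_num = 0
    for param in model.parameters():
        params_num += param.numel()
        zero_num += int((param.data == 0).sum().item())
    zero_rate = zero_num * 1.0 / max(params_num, 1)
    return params_num, zero_num, zero_rate


if __name__ == "__main__":
    # e.g. on a toy module
    params_num, zero_num, zero_rate = count_params_and_zeros(nn.Linear(4, 2))
    print("params: %d, zeros: %d, zero rate: %f" % (params_num, zero_num, zero_rate))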