def __init__(self, settings=None, **kwargs):
    """Set up the state shared by controllers from ``settings``.

    The data-parallel and model-info flags are read from ``settings`` when
    it is given and default to ``False`` otherwise. The GPU ids and the
    checkpoint directory are resolved through ``SettingHandler``, and the
    loss, optimizer and model factories are built from the same settings.

    Args:
        settings: Nested configuration mapping, or ``None``.
        **kwargs: Accepted but not used by this constructor.
    """
    has_settings = settings is not None
    self.is_data_parallel = (
        settings['base']['isDataParallel'] if has_settings else False)
    self.is_showmode_info = (
        settings['ui']['base']['isShowModelInfo'] if has_settings else False)
    self.settings = settings

    # The handler and the factories receive ``settings`` even when it is None.
    handler = SettingHandler(settings)
    self.gpu_ids = handler.get_GPU_ID()
    self.checkpoints_dir = handler.get_check_points_dir()

    self.loss_factory = lf.LossFactory(settings)
    self.optimizer_factory = OptimizerFactory(settings)
    self.show_model_lst = []
    self.model_factory = ModelFactory(settings)
def __init__(self, settings):
    """Build the CycleGAN networks, losses and optimizers from ``settings``.

    Everything specific to this controller is read from
    ``settings['controllers']['cycleGAN']``. Two generators and two
    discriminators are created, moved to the first configured GPU when any
    GPU id is set, and ``self.loadModels()`` is called before the loss
    criteria and optimizers are created.

    Args:
        settings: Nested configuration mapping.
    """
    import aimaker.models.model_factory as mf
    import aimaker.loss.loss_factory as lf
    import aimaker.optimizers.optimizer_factory as of

    self.settings = settings
    handler = SettingHandler(settings)
    self.gpu_ids = handler.get_GPU_ID()
    self.checkpoints_dir = handler.get_check_points_dir()

    models = mf.ModelFactory(settings)
    losses = lf.LossFactory(settings)
    optimizers = of.OptimizerFactory(settings)
    cfg = settings['controllers']['cycleGAN']

    # for discriminator regularization
    self.pool_fake_A = ImagePool(int(cfg['imagePoolSize']))
    self.pool_fake_B = ImagePool(int(cfg['imagePoolSize']))

    # Generators: both are created from the same model name.
    generator_name = cfg['generatorModel']
    self.netG_A = models.create(generator_name)
    self.netG_B = models.create(generator_name)
    if len(self.gpu_ids):
        device = self.gpu_ids[0]
        self.netG_A = self.netG_A.cuda(device)
        self.netG_B = self.netG_B.cuda(device)

    # Discriminators: both are created from the same model name.
    discriminator_name = cfg['discriminatorModel']
    self.netD_A = models.create(discriminator_name)
    self.netD_B = models.create(discriminator_name)
    if len(self.gpu_ids):
        device = self.gpu_ids[0]
        self.netD_A = self.netD_A.cuda(device)
        self.netD_B = self.netD_B.cuda(device)

    self.loadModels()

    # Loss criteria.
    self.criterionGAN = losses.create("GANLoss")
    self.criterionCycle = losses.create(cfg['cycleLoss'])
    self.criterionIdt = losses.create(cfg['idtLoss'])
    if len(self.gpu_ids):
        device = self.gpu_ids[0]
        self.criterionGAN = self.criterionGAN.cuda(device)
        self.criterionCycle = self.criterionCycle.cuda(device)
        self.criterionIdt = self.criterionIdt.cuda(device)

    # initialize optimizers
    # The generator optimizer is always built; it covers both generators.
    build_g_optimizer = optimizers.create(cfg['generatorOptimizer'])
    generator_params = it.chain(self.netG_A.parameters(),
                                self.netG_B.parameters())
    self.optimizer_G = build_g_optimizer(generator_params, settings)

    # The discriminator optimizers are only built when training is enabled.
    if settings['data']['base']['isTrain']:
        build_d_a_optimizer = optimizers.create(cfg['D_AOptimizer'])
        self.optimizer_D_A = build_d_a_optimizer(self.netD_A.parameters(),
                                                 settings)
        build_d_b_optimizer = optimizers.create(cfg['D_BOptimizer'])
        self.optimizer_D_B = build_d_b_optimizer(self.netD_B.parameters(),
                                                 settings)

    if settings['ui']['base']['isShowModelInfo']:
        self.showModel()
class Trainer:
    """Run the configured train / test / valid phases epoch by epoch.

    The settings loaded in ``__init__`` decide which phases run, how often
    graphs and images are pushed to the viewer, and how often models are
    saved. Progress is kept in an ``info`` mapping that ``_saveInfo`` writes
    to ``settings.base.infoPath`` and that ``train`` reads back on start.
    """

    def __init__(self, setting_path="settings"):
        """Load the settings and build every collaborator used by ``train``.

        Args:
            setting_path: Passed to ``load_setting`` to read the settings.
        """
        self.settings = settings = EasyDict(load_setting(setting_path))
        self.sh = SettingHandler(settings)
        self.controller = self.sh.get_controller()
        self.dataset = DatasetFactory(settings).create(
            settings.data.base.datasetName)
        self.valid_dataset = None
        if settings.data.base.isValid:
            self.valid_dataset = DatasetFactory(settings).create(
                settings['data']['base']['valid']['datasetName'])
        self.data_loader = self.dataset.getDataLoader()
        if settings.data.base.isValid:
            self.valid_data_loader = self.valid_dataset.getDataLoader()
        # The misspelled attribute is kept for backward compatibility; the
        # correctly spelled alias refers to the same value.
        self.sheckpoint_dir = self.sh.get_check_points_dir()
        self.checkpoints_dir = self.sheckpoint_dir
        self.viz = viewer.TensorBoardXViewer(settings)
        self.train_monitor = viewer.TrainMonitor(settings)
        self.n_update_graphs = self.sh.get_update_interval_of_graphs(
            self.dataset)
        self.n_update_images = self.sh.get_update_interval_of_images(
            self.dataset)
        if settings.data.base.isValid:
            self.validator = Validator(settings)
        # Index passed to the viewer as ``idx`` for each mode's graph.
        self.idx_dic = EasyDict({'train': 0, 'test': 1, 'valid': 2})

    def _getInfo(self):
        """Return a fresh progress record starting at epoch 0, iteration 0."""
        info = EasyDict()
        info.current_epoch = 0
        info.train = EasyDict({"v_iter": 0, "current_n_iter": 0})
        info.test = EasyDict({"v_iter": 0, "current_n_iter": 0})
        info.valid = EasyDict({"v_iter": 0, "current_n_iter": 0})
        return info

    def _readInfoFile(self):
        """Return the JSON content of ``settings.base.infoPath``.

        The file is opened in a ``with`` block so the handle is always closed.
        """
        with open(self.settings.base.infoPath) as fp:
            return json.load(fp)

    def train(self):
        """Run all enabled phases for every remaining epoch.

        Training starts from the epoch stored in the info file when that
        file exists, otherwise from a fresh record. Any exception raised
        inside the epoch loop is printed and the models are saved; the
        method then returns normally instead of propagating it.
        """
        import traceback

        n_epoch = self.settings['base']['nEpoch']
        if self.settings['base']['isView']:
            self.viz.initGraphs()
            self.viz.initImages()
        train_n_iter = len(self.data_loader)
        if self.settings.data.base.isValid:
            valid_n_iter = len(self.valid_data_loader)

        if os.path.exists(self.settings.base.infoPath):
            info = EasyDict(self._readInfoFile())
        else:
            info = self._getInfo()

        try:
            model_save_interval = self.sh.get_model_save_interval()
            if self.settings.data.base.isValid:
                model_save_interval_valid = (
                    self.sh.get_model_save_interval_for_valid())

            for current_epoch in range(info.current_epoch, n_epoch):
                info.current_epoch = current_epoch

                if self.settings['data']['base']['isTrain']:
                    info.mode = "train"
                    print('{}:'.format('train'))
                    self.dataset.set_mode('train')
                    self.dataset.getTransforms()
                    self.viz.setMode('train')
                    self.data_loader = self.dataset.getDataLoader()
                    self.controller.set_mode('train')
                    info = self._learning('train', info, current_epoch,
                                          self.data_loader, train_n_iter)

                if self.settings['data']['base']['isTest']:
                    info.mode = "test"
                    print('{}:'.format('test'))
                    self.dataset.set_mode('test')
                    self.dataset.getTransforms()
                    self.viz.setMode('test')
                    self.data_loader = self.dataset.getDataLoader()
                    self.controller.set_mode('test')
                    info = self._learning('test', info, current_epoch,
                                          self.data_loader, train_n_iter)

                if (self.valid_dataset is not None
                        and self.settings['data']['base']['isValid']):
                    info.mode = "valid"
                    print('{}:'.format('valid'))
                    self.valid_dataset.set_mode('valid')
                    self.valid_dataset.getTransforms()
                    self.viz.setMode('valid')
                    self.valid_data_loader = (
                        self.valid_dataset.getDataLoader())
                    self.controller.set_mode('valid')
                    info = self._learning('valid', info, current_epoch,
                                          self.valid_data_loader,
                                          valid_n_iter)
                    # Epoch 0 is skipped; afterwards the model is written
                    # and uploaded every ``model_save_interval_valid`` epochs.
                    if (current_epoch != 0
                            and not current_epoch % model_save_interval_valid):
                        self.controller._save_model(
                            self.controller.get_model(),
                            self.settings['validator']['base']['modelPath'],
                            is_fcnt=False)
                        self.validator.upload()

                if not current_epoch % model_save_interval:
                    self.controller.save_models()
        except BaseException:
            # BaseException (not Exception) on purpose: this also catches the
            # SystemExit that ``_learning`` raises on Ctrl-C, so the models
            # are still saved before this method returns.
            traceback.print_exc()
            self.controller.save_models()

        if self.settings['base']['isView']:
            self.viz.destructVisdom()

    def _saveInfo(self, info):
        """Write ``info`` as JSON to ``settings.base.infoPath``."""
        with open(self.settings.base.infoPath, 'w') as fp:
            json.dump(info, fp)

    def _learning(self, mode, info, current_epoch, data_loader, train_n_iter):
        """Iterate once over ``data_loader`` in the given mode.

        Args:
            mode: ``'train'``, ``'test'`` or ``'valid'``; the backward pass
                runs only for ``'train'``.
            info: Progress record; ``info[mode]`` is updated in place.
            current_epoch: Epoch number passed to the progress monitor.
            data_loader: Iterable of batches for this mode.
            train_n_iter: Length of the training loader, used to scale the
                graph position of this mode's points.

        Returns:
            The same ``info`` object, with ``v_iter`` advanced and
            ``current_n_iter`` reset to 0.

        Raises:
            SystemExit: On ``KeyboardInterrupt``, after ``info`` is saved.
        """
        import traceback

        n_iter = len(data_loader)
        ratio = train_n_iter / n_iter
        v_iter = info[mode].v_iter
        self.viz.setTotalDataLoaderLength(len(self.data_loader))
        is_volatile = mode != 'train'

        for current_n_iter, data in enumerate(data_loader):
            # Skip batches already processed before an interrupted run.
            if current_n_iter < info[mode].current_n_iter:
                continue
            info[mode].current_n_iter = current_n_iter
            info[mode].v_iter = v_iter
            try:
                self.controller.set_input(data, is_volatile)
                self.controller.forward()
                if mode == 'train':
                    self.controller.backward()
                loss_dic = self.controller.get_losses()
                if mode == 'valid':
                    self.validator.setLosses(loss_dic)
                self.train_monitor.setLosses(loss_dic)
                self.train_monitor.dumpCurrentProgress(current_epoch,
                                                       current_n_iter, n_iter)
                if (not current_n_iter % self.n_update_graphs
                        and self.settings['base']['isView']):
                    self.viz.updateGraphs(ratio * v_iter, loss_dic,
                                          idx=self.idx_dic[mode])
                if (not current_n_iter % self.n_update_images
                        and self.settings['base']['isView']):
                    self.viz.updateImages(self.controller.get_images(),
                                          current_n_iter)
            except KeyboardInterrupt:
                # Persist progress so the next run can resume, then exit.
                self._saveInfo(info)
                sys.exit()
            except FileNotFoundError:
                # A missing file is reported and iteration continues.
                traceback.print_exc()
            except Exception:
                # Any other error ends this mode's pass for the epoch.
                traceback.print_exc()
                break
            v_iter += 1

        info[mode].v_iter = v_iter
        info[mode].current_n_iter = 0  # reset
        self.train_monitor.dumpAverageLossOnEpoch(current_epoch)
        self.train_monitor.flash()
        return info