示例#1
0
    def __init__(self, settings=None, **kwargs):
        """Store the settings and create the loss, optimizer and model factories.

        Args:
            settings: Nested, subscriptable configuration, or ``None``. When
                ``None``, both the data-parallel flag and the show-model-info
                flag are set to ``False``.
            **kwargs: Accepted for interface compatibility; not read here.
        """
        has_settings = settings is not None
        self.is_data_parallel = (settings['base']['isDataParallel']
                                 if has_settings else False)
        self.is_showmode_info = (settings['ui']['base']['isShowModelInfo']
                                 if has_settings else False)

        self.settings = settings

        # The handler is only needed here to resolve the GPU ids and the
        # checkpoint directory, so it is not kept on the instance.
        handler = SettingHandler(settings)
        self.gpu_ids = handler.get_GPU_ID()
        self.checkpoints_dir = handler.get_check_points_dir()

        # Factories are created in the same order as before: loss, optimizer,
        # then model.
        self.loss_factory = lf.LossFactory(settings)
        self.optimizer_factory = OptimizerFactory(settings)
        self.show_model_lst = []

        self.model_factory = ModelFactory(settings)
示例#2
0
    def __init__(self, settings):
        """Build the CycleGAN networks, losses and optimizers from ``settings``.

        Reads the ``settings['controllers']['cycleGAN']`` section for the image
        pool size, the generator/discriminator model names, the cycle and
        identity loss names and the optimizer names. When the GPU id list
        returned by ``SettingHandler.get_GPU_ID()`` is non-empty, networks and
        criteria are moved to the first listed GPU.

        Args:
            settings: Nested, subscriptable configuration object.
        """
        import aimaker.models.model_factory as mf
        import aimaker.loss.loss_factory as lf
        import aimaker.optimizers.optimizer_factory as of

        self.settings = settings
        ch = SettingHandler(settings)

        self.gpu_ids = ch.get_GPU_ID()
        self.checkpoints_dir = ch.get_check_points_dir()
        model_factory = mf.ModelFactory(settings)
        loss_factory = lf.LossFactory(settings)
        optimizer_factory = of.OptimizerFactory(settings)

        cfg = settings['controllers']['cycleGAN']

        # for discriminator regularization
        self.pool_fake_A = ImagePool(int(cfg['imagePoolSize']))
        self.pool_fake_B = ImagePool(int(cfg['imagePoolSize']))

        name = cfg['generatorModel']
        self.netG_A = model_factory.create(name)
        self.netG_B = model_factory.create(name)
        if len(self.gpu_ids):
            self.netG_A = self.netG_A.cuda(self.gpu_ids[0])
            self.netG_B = self.netG_B.cuda(self.gpu_ids[0])

        # The discriminators are created whether or not a GPU is configured.
        # They used to be nested under the generators' GPU branch, so a
        # CPU-only run never defined netD_A/netD_B and the discriminator
        # optimizers below failed with AttributeError when isTrain was set.
        name = cfg['discriminatorModel']
        self.netD_A = model_factory.create(name)
        self.netD_B = model_factory.create(name)
        if len(self.gpu_ids):
            self.netD_A = self.netD_A.cuda(self.gpu_ids[0])
            self.netD_B = self.netD_B.cuda(self.gpu_ids[0])

        self.loadModels()

        self.criterionGAN = loss_factory.create("GANLoss")
        self.criterionCycle = loss_factory.create(cfg['cycleLoss'])
        self.criterionIdt = loss_factory.create(cfg['idtLoss'])
        if len(self.gpu_ids):
            self.criterionGAN = self.criterionGAN.cuda(self.gpu_ids[0])
            self.criterionCycle = self.criterionCycle.cuda(self.gpu_ids[0])
            self.criterionIdt = self.criterionIdt.cuda(self.gpu_ids[0])

        # initialize optimizers
        # A single generator optimizer covers the parameters of both
        # generators; it is created regardless of the isTrain flag.
        self.optimizer_G = optimizer_factory.create(
            cfg['generatorOptimizer'])(
                it.chain(self.netG_A.parameters(),
                         self.netG_B.parameters()), settings)

        if settings['data']['base']['isTrain']:
            self.optimizer_D_A = optimizer_factory.create(
                cfg['D_AOptimizer'])(self.netD_A.parameters(), settings)
            self.optimizer_D_B = optimizer_factory.create(
                cfg['D_BOptimizer'])(self.netD_B.parameters(), settings)

        if settings['ui']['base']['isShowModelInfo']:
            self.showModel()
示例#3
0
class Trainer:
    def __init__(self, setting_path="settings"):
        """Load the settings and build datasets, loaders, viewers and helpers.

        Args:
            setting_path: Passed to ``load_setting``; the result is wrapped in
                an ``EasyDict`` and stored as ``self.settings``.

        The validation members (``valid_dataset``, ``valid_data_loader`` and
        ``validator``) are always defined and are ``None`` when
        ``settings.data.base.isValid`` is false.
        """
        self.settings = settings = EasyDict(load_setting(setting_path))
        self.sh = SettingHandler(settings)
        self.controller = self.sh.get_controller()
        self.dataset = DatasetFactory(settings).create(
            settings.data.base.datasetName)

        is_valid = settings.data.base.isValid

        # Define every validation attribute up front so that reading one of
        # them with validation disabled yields None instead of AttributeError.
        self.valid_dataset = None
        self.valid_data_loader = None
        self.validator = None

        if is_valid:
            self.valid_dataset = DatasetFactory(settings).create(
                settings['data']['base']['valid']['datasetName'])
        self.data_loader = self.dataset.getDataLoader()
        if is_valid:
            self.valid_data_loader = self.valid_dataset.getDataLoader()

        # 'sheckpoint_dir' is a historical misspelling kept for backward
        # compatibility; 'checkpoints_dir' is the correctly spelled alias.
        self.sheckpoint_dir = self.sh.get_check_points_dir()
        self.checkpoints_dir = self.sheckpoint_dir
        self.viz = viewer.TensorBoardXViewer(settings)
        self.train_monitor = viewer.TrainMonitor(settings)

        self.n_update_graphs = self.sh.get_update_interval_of_graphs(
            self.dataset)
        self.n_update_images = self.sh.get_update_interval_of_images(
            self.dataset)
        if is_valid:
            self.validator = Validator(settings)
        # Index passed to the viewer's updateGraphs for each mode.
        self.idx_dic = EasyDict({'train': 0, 'test': 1, 'valid': 2})

    def _getInfo(self):
        """Return a fresh progress record starting at epoch 0.

        The record holds ``current_epoch`` plus one counter pair
        (``v_iter``, ``current_n_iter``) for each of the ``train``, ``test``
        and ``valid`` modes, all initialised to zero.
        """
        progress = EasyDict()
        progress.current_epoch = 0
        for mode_name in ("train", "test", "valid"):
            # A new counter record per mode so the modes never share state.
            setattr(progress, mode_name,
                    EasyDict({"v_iter": 0, "current_n_iter": 0}))
        return progress

    def train(self):
        """Run the epoch loop: train/test/valid passes and periodic model saves.

        Progress is resumed from the JSON file at ``settings.base.infoPath``
        when that file exists; otherwise a fresh record from ``_getInfo`` is
        used. Whatever is raised inside the loop is printed as a traceback,
        after which the models are saved and, when ``isView`` is set, the
        viewer is torn down.
        """
        import traceback

        n_epoch = self.settings['base']['nEpoch']
        if self.settings['base']['isView']:
            self.viz.initGraphs()
            self.viz.initImages()

        train_n_iter = len(self.data_loader)
        if self.settings.data.base.isValid:
            valid_n_iter = len(self.valid_data_loader)

        info_path = self.settings.base.infoPath
        if os.path.exists(info_path):
            # Use a context manager so the resume file is closed right away
            # instead of leaking the handle.
            with open(info_path) as fp:
                info = EasyDict(json.load(fp))
        else:
            info = self._getInfo()

        try:
            model_save_interval = self.sh.get_model_save_interval()
            if self.settings.data.base.isValid:
                model_save_interval_valid = \
                    self.sh.get_model_save_interval_for_valid()
            for current_epoch in range(info.current_epoch, n_epoch):
                info.current_epoch = current_epoch

                # The train and test passes share the main dataset and differ
                # only in the mode name and the settings flag enabling them.
                # Both are given the training loader length measured above.
                for mode, flag in (('train', 'isTrain'), ('test', 'isTest')):
                    if not self.settings['data']['base'][flag]:
                        continue
                    info.mode = mode
                    print('{}:'.format(mode))
                    self.dataset.set_mode(mode)
                    self.dataset.getTransforms()
                    self.viz.setMode(mode)
                    self.data_loader = self.dataset.getDataLoader()
                    self.controller.set_mode(mode)
                    info = self._learning(mode, info, current_epoch,
                                          self.data_loader, train_n_iter)

                if (self.valid_dataset is not None
                        and self.settings['data']['base']['isValid']):
                    info.mode = "valid"
                    print('{}:'.format('valid'))
                    self.valid_dataset.set_mode('valid')
                    self.valid_dataset.getTransforms()
                    self.viz.setMode('valid')
                    self.valid_data_loader = \
                        self.valid_dataset.getDataLoader()
                    self.controller.set_mode('valid')
                    info = self._learning('valid', info, current_epoch,
                                          self.valid_data_loader,
                                          valid_n_iter)
                    # Epoch 0 is skipped for the validator model export.
                    if (current_epoch != 0
                            and not current_epoch % model_save_interval_valid):
                        self.controller._save_model(
                            self.controller.get_model(),
                            self.settings['validator']['base']['modelPath'],
                            is_fcnt=False)
                        self.validator.upload()

                if not current_epoch % model_save_interval:
                    self.controller.save_models()
        except BaseException:
            # Deliberately catch-all: _learning calls sys.exit() after a
            # Ctrl-C, and the models must still be saved below in that case.
            traceback.print_exc()
        self.controller.save_models()
        if self.settings['base']['isView']:
            self.viz.destructVisdom()

    def _saveInfo(self, info):
        with open(self.settings.base.infoPath, 'w') as fp:
            json.dump(info, fp)

    def _learning(self, mode, info, current_epoch, data_loader, train_n_iter):
        n_iter = len(data_loader)
        ratio = train_n_iter / n_iter
        v_iter = info[mode].v_iter
        self.viz.setTotalDataLoaderLength(len(self.data_loader))
        is_volatile = False if mode == 'train' else True
        for current_n_iter, data in enumerate(data_loader):
            if current_n_iter < info[mode].current_n_iter:
                continue
            info[mode].current_n_iter = current_n_iter
            info[mode].v_iter = v_iter
            try:
                self.controller.set_input(data, is_volatile)
                self.controller.forward()
                if mode == 'train':
                    self.controller.backward()
                loss_dic = self.controller.get_losses()
                if mode == 'valid':
                    self.validator.setLosses(loss_dic)
                self.train_monitor.setLosses(loss_dic)
                self.train_monitor.dumpCurrentProgress(current_epoch,
                                                       current_n_iter, n_iter)

                if not current_n_iter % self.n_update_graphs:
                    if self.settings['base']['isView']:
                        self.viz.updateGraphs(ratio * v_iter,
                                              loss_dic,
                                              idx=self.idx_dic[mode])
                if not current_n_iter % self.n_update_images:
                    if self.settings['base']['isView']:
                        self.viz.updateImages(self.controller.get_images(),
                                              current_n_iter)
            except KeyboardInterrupt:
                self._saveInfo(info)
                sys.exit()
                break
            except FileNotFoundError:
                import traceback
                traceback.print_exc()
            except:
                import traceback
                traceback.print_exc()
                break
            v_iter += 1
        #if mode == 'valid':
        #    self.validator.uploadModelIfSOTA(current_epoch)
        info[mode].v_iter = v_iter
        info[mode].current_n_iter = 0  # reset
        self.train_monitor.dumpAverageLossOnEpoch(current_epoch)
        self.train_monitor.flash()
        return info