class EmdrosApplication:
    """Base class for applications built on an Emdros database.

    On construction it imports a kernel configuration module, merges the
    configuration layers (kernel default file or the one named by option
    'CFG', a local file named by option 'cfg', then the command-line
    options), opens an MQLEngine, and optionally sets up reference/output
    managers and per-mode label managers.  If the 'gui' option is set, a
    GUI is created and its main loop is entered from within the
    constructor.

    next() and quit() are placeholders meant to be overridden.
    """

    # Class-level defaults: __init__ only sets these conditionally
    # (depending on DO_REF / DO_OUT / DO_LBL and the 'gui' option), but
    # setFormat(), updateLabels() and the write helpers read them
    # unconditionally.
    ref = None
    out = None
    lbl = None
    lblManagers = None
    gui = None

    def __init__(self,
                 options=None,
                 title='EmdrosApplication',
                 DO_REF=True,
                 DO_OUT=True,
                 DO_LBL=True):
        """Configure the application and, if requested, start the GUI.

        :param options: program options (an ``Options`` instance); a fresh
            ``Options()`` is created when None.
        :param title: application title, also passed to the GUI.
        :param DO_REF: create a ReferenceManager (``self.ref``).
        :param DO_OUT: create an OutputManager (``self.out``).
        :param DO_LBL: create label managers (``self.lblManagers``) and
            select the current mode's one as ``self.lbl``.
        """
        import importlib

        if options is None:
            options = Options()
        self.options = options

        # The 'kernel' option names a module inside the
        # emdros_application.syscfg package (a trailing .py/.pyc is
        # stripped); otherwise the system default kernel is used.
        kernel_cfg_name = options.get('kernel')
        if kernel_cfg_name is not None:
            kernel_cfg_name = 'emdros_application.syscfg.' + re.sub(
                r'\.py[c]?$', '', kernel_cfg_name)
        else:
            kernel_cfg_name = syscfg.config.DEFAULT_KERNEL
        kernel = importlib.import_module(kernel_cfg_name)
        self.kernel = kernel

        self.title = title
        self.DO_REF = DO_REF
        self.DO_OUT = DO_OUT
        self.DO_LBL = DO_LBL

        self.spinner = Spinner()

        # configure() also sets the attributes used below (database, usr,
        # pwd, backend, domain, verbosity flags, streams, modes, ...).
        self.cfg = self.configure(self.options, kernel)
        self.modeCfgs = self.setupModeConfigurations(kernel)

        self.mql = MQLEngine(database=self.database,
                             usr=self.usr,
                             pwd=self.pwd,
                             be=self.backend,
                             domainfile=self.domainqueryfile,
                             domain=self.domain,
                             VERBOSE=self.VERBOSE,
                             verbose=self.verbose,
                             test=self.test,
                             outstream=self.outstream,
                             errstream=self.errstream,
                             kernel=kernel)

        if self.DO_OUT or self.DO_REF:
            self.ref, self.out = self.setupOutput(kernel=kernel)

        if self.DO_LBL:
            self.lbl = self.setupLabelManagers(self.options.args,
                                               kernel=kernel)

        if self.options.get('gui'):
            self.gui = GUI(title=title, app=self)
            # Blocks here until the GUI main loop ends.
            self.gui.mainloop()
        else:
            self.gui = None

    def next(self):
        """Placeholder action; overload this with something useful."""
        if self.gui:
            self.gui.writeln('EmdrosApplication says: "next"')

    def quit(self):
        """Terminate the program; overload this with something useful.

        Uses sys.exit() instead of the site-provided exit(), which does not
        exist when Python runs without the site module (e.g. ``-S``).
        """
        import sys
        sys.exit()

    def configure(self, options, kernel):
        """Build the merged configuration and derive application settings.

        Precedence, lowest to highest: the kernel's default configuration
        file (overridden by program option -C/--CFG), a local configuration
        file named by option 'cfg', and the command-line options.

        Side effects: sets self.cfg and the attributes auto, verbose,
        VERBOSE, backend, database, usr, pwd, modes, mode, formats, format,
        json, domain, domainqueryfile, test, outstream and errstream.

        :returns: the merged Configuration (also stored as self.cfg).
        """
        cfg_name = options.get('CFG')
        if cfg_name is None:
            cfg_name = kernel.DEFAULT_KERNEL_CFG
        self.cfg = Configuration(cfg_name, kernel=kernel)

        # settings in the local cfg override settings in CFG_FILE
        local_cfg = options.get('cfg')
        if local_cfg is not None:
            self.cfg.overrideWith(Configuration(local_cfg, kernel=kernel))

        # command line options override both configuration files
        self.cfg.overrideWith(options.toConfiguration(kernel=kernel))

        self.auto = self.cfg.get('auto')
        self.verbose = self.cfg.get('verbose')
        self.VERBOSE = self.cfg.get('VERBOSE')

        backends = {
            'NO_BACKEND': EmdrosPy.kBackendNone,
            'POSTGRES': EmdrosPy.kPostgreSQL,
            'MYSQL': EmdrosPy.kMySQL,
            'SQL2': EmdrosPy.kSQLite2,
            'SQL3': EmdrosPy.kSQLite3
        }

        if self.cfg.get('backend') is not None:
            self.backend = backends[self.cfg.get('backend')]
        else:
            try:
                self.backend = backends[kernel.DEFAULT_BACKEND]
            except (AttributeError, KeyError):
                # kernel has no DEFAULT_BACKEND, or it is not a known name
                exitOnError('no backend defined in configuration files')

        self.database = self.cfg.get('database')
        # empty/missing credentials fall back to the module defaults
        self.usr = self.cfg.get('usr') or DEFAULT_USR
        self.pwd = self.cfg.get('pwd') or DEFAULT_PWD

        if self.cfg.has_key('modes'):
            self.modes = self.cfg.get_multi('modes')  # has proper order
        else:
            self.modes = self.options.validVals('mode')
        self.mode = self.cfg.get('mode')

        if self.cfg.has_key('formats'):
            self.formats = self.cfg.get_multi('formats')
        else:
            self.formats = self.options.validVals('format')
            self.cfg.set_multi('formats', self.formats)
        self.format = self.cfg.get('format')
        self.json = self.cfg.get('jsonoverride')

        self.domain = self.cfg.get('domain')
        self.domainqueryfile = self.cfg.get('domainqueryfile')

        self.test = self.cfg.get('test')

        # SECURITY NOTE(review): eval() executes arbitrary Python taken from
        # the merged configuration, which includes the command-line layer
        # merged above.  Values are presumably expressions such as
        # 'sys.stdout' -- consider an explicit lookup table instead.
        self.outstream = eval(self.cfg.get('stdout'))
        self.errstream = eval(self.cfg.get('stderr'))

        return self.cfg

    def setupModeConfigurations(self, kernel):
        """Load the per-mode Configuration named by key 'modecfg' per mode.

        Modes without a 'modecfg' entry are skipped with a warning.

        :returns: dict mapping mode name to its Configuration.
        """
        modeCfgs = {}
        for m in self.modes:
            if self.cfg.has_key('modecfg', m):
                cfgname = self.cfg.get('modecfg', m)
                cfg = Configuration(cfgname,
                                    kernel=kernel,
                                    verbose=self.VERBOSE)
                if cfg is not None:
                    self.Vmsg('cfg for %s mode: %s' % (m, cfgname))
                modeCfgs[m] = cfg
            else:
                warning(
                    "no key 'modecfg' defined for mode %s in configuration file"
                    % m)
        return modeCfgs

    def setupLabelManagers(self, opt_otypes, kernel=None):
        """Create one LabelManager per mode; return the current mode's one.

        Only the current mode's manager receives the object types given on
        the command line (``opt_otypes``); the others get an empty list.
        If the 'sync' option is set, labels are updated for all modes (with
        'auto') or for the modes the user toggles interactively.
        """
        self.lblManagers = {}
        for m in self.modes:
            opt_o = opt_otypes if m == self.mode else []
            try:
                self.lblManagers[m] = LabelManager(mql=self.mql,
                                                   mode=m,
                                                   globalcfg=self.cfg,
                                                   local_cfg=self.modeCfgs[m],
                                                   opt_otypes=opt_o,
                                                   kernel=kernel)
            except Exception:
                # presumably reports the active exception and exits -- confirm
                exitOnError()
        if self.options.get('sync'):
            if self.options.get('auto'):
                mlist = self.modes
            else:
                mlist = userCheckList(self.modes,
                                      question="toggle modes to update")

            if not mlist:
                writeln('no labels will be updated', outstream=self.errstream)
            else:
                for m in mlist:
                    self.updateLabels(m)

        return self.lblManagers[self.mode]

    def setupOutput(self, kernel):
        """Create the reference and output managers enabled by DO_REF/DO_OUT.

        :returns: tuple (ref, out); each is None when its flag is off.
        """
        if self.DO_REF:
            ref = ReferenceManager(self.mql, self.cfg, kernel=kernel)
        else:
            ref = None

        if self.DO_OUT:
            out = OutputManager(self.mql,
                                self.cfg,
                                format=self.format,
                                json=self.json,
                                outstream=self.outstream,
                                msgstream=self.errstream,
                                verbose=self.verbose,
                                VERBOSE=self.VERBOSE,
                                kernel=kernel)
            if out is None:
                writeln('output manager based on %s could not be initiated' %
                        self.json)
        else:
            out = None
        return ref, out

    def Vmsg(self, msg):
        """Write msg to the error stream only when VERBOSE is set."""
        if self.VERBOSE:
            writeln(msg, outstream=self.errstream)

    def vmsg(self, msg):
        """Write msg to the error stream when verbose or VERBOSE is set."""
        if self.verbose or self.VERBOSE:
            writeln(msg, outstream=self.errstream)

    def setMode(self, mode):
        """Switch the current mode, propagating it to MQL, output and labels.

        :raises ValueError: if mode is not one of self.modes.
        """
        if mode not in self.modes:
            raise ValueError(
                'wrong value %s in EmdrosApplication.setMode(): %s' %
                (mode, repr(self.modes)))
        self.mode = mode
        self.mql.mode = mode
        if self.DO_OUT and self.out is not None:
            self.out.mode = mode
        if self.DO_LBL and self.lblManagers is not None:
            self.lbl = self.lblManagers[mode]

    def setFormat(self, format):
        """Switch the output format, propagating it to the output manager.

        :raises ValueError: if format is not one of self.formats.
        """
        if format not in self.formats:
            raise ValueError(
                'wrong value %s in EmdrosApplication.setFormat(): %s' %
                (format, repr(self.formats)))
        self.format = format
        if self.out is not None:
            self.out.format = format

    def updateLabels(self, mode=None, forced=False, mql=None, objTypes=None):
        """Update all labels of the given mode (default: current mode).

        Warns instead of failing when no label manager exists for the mode.
        """
        if mode is None:
            mode = self.mode
        manager = self.lblManagers.get(mode) if self.lblManagers else None
        if manager is None:
            warning('no label manager for %s existing' % mode)
        else:
            if mql is None:
                mql = self.mql
            manager.updateAll(forced=forced, mql=mql, objTypes=objTypes)

    def write(self, msg='', outstream=None):
        """Write msg to the GUI if active, else to outstream."""
        if self.gui:
            self.gui.write(msg, outstream=outstream)
        else:
            write(msg, outstream=outstream)

    def writeln(self, msg='', outstream=None):
        """Write msg plus newline to the GUI if active, else to outstream."""
        if self.gui:
            self.gui.writeln(msg)
        else:
            writeln(msg, outstream=outstream)

    def userMultiLineInput(self, msg=None):
        """Read multi-line user input from the GUI if active, else the console."""
        if self.gui:
            return self.gui.userMultiLineInput()
        else:
            return userMultiLineInput(msg)
# Example No. 2
# 0
def _parse_bool(value):
    """Interpret a configuration value as a boolean.

    ``bool('False')`` is True, so string values read from a configuration
    file must be parsed explicitly.  Strings follow configparser's
    conventions (case-insensitive): '1', 'yes', 'true', 'on' are true;
    '0', 'no', 'false', 'off' and the empty string are false.  Non-string
    values are passed through ``bool()``.

    :raises ValueError: for any other string.
    """
    if not isinstance(value, str):
        return bool(value)
    text = value.strip().lower()
    if text in ('1', 'yes', 'true', 'on'):
        return True
    if text in ('', '0', 'no', 'false', 'off'):
        return False
    raise ValueError('not a boolean value: %r' % value)


def _instantiate_section(configuration, section):
    """Return ``instantiate(module, name)`` using *section*'s 'module'/'name' keys."""
    return instantiate(configuration.get(section, 'module'),
                       configuration.get(section, 'name'))


def train(config_file):
    """Train and validate a clustering experiment described by a config file.

    The configuration names, via 'module'/'name' keys, the model, loss,
    optimizer, experiment and dataset objects to instantiate, and holds
    loader parameters in the 'train loader', 'valid loader' and optional
    'test loader' sections.

    :param config_file: path passed to ``Configuration``.
    """
    configuration = Configuration(config_file)

    # Instantiate model, loss and optimizer
    model = _instantiate_section(configuration, 'model')
    loss = _instantiate_section(configuration, 'loss')
    optimizer = _instantiate_section(configuration, 'optimizer')

    # Instantiate experiment (instantiate() presumably returns a class,
    # since it is called below -- confirm)
    experiment_cls = _instantiate_section(configuration, 'experiment')
    cluster_method = configuration.get('experiment', 'method')
    k = int(configuration.get('Cluster parameters', 'nb_clusters'))
    experiment = experiment_cls(model, cluster_method, k, loss, optimizer,
                                configuration)

    # Instantiate metrics
    train_metrics = instantiate_metrics(
        configuration.get_section('train metrics'), experiment, 'train')
    val_metrics = None
    if configuration.has_section('val metrics'):
        val_metrics = instantiate_metrics(
            configuration.get_section('val metrics'), experiment, 'val')
    cluster_metrics = instantiate_metrics(
        configuration.get_section('cluster metrics'), experiment, 'cluster')

    transform = transforms.Compose([transforms.ToTensor()])
    pin_memory = torch.cuda.is_available()

    # Train loader parameters; shuffle is parsed explicitly because
    # bool() of any non-empty string (including 'False') is True.
    train_batch_size = int(configuration.get('train loader', 'batch size'))
    train_shuffle = _parse_bool(configuration.get('train loader', 'shuffle'))
    train_num_workers = int(configuration.get('train loader', 'num workers'))
    train_skip = int(configuration.get('train loader', 'skip'))
    train_split = configuration.get('train loader', 'split')
    train_dataset_cls = _instantiate_section(configuration, 'train loader')

    # Validation loader parameters
    val_batch_size = int(configuration.get('valid loader', 'batch size'))
    val_num_workers = int(configuration.get('valid loader', 'num workers'))
    val_skip = int(configuration.get('valid loader', 'skip'))
    val_split = configuration.get('valid loader', 'split')
    valid_dataset_cls = _instantiate_section(configuration, 'valid loader')

    # Obtaining train & valid dataloaders
    train_dataset = train_dataset_cls(split=train_split, skip=train_skip,
                                      flattened=False, transform=transform)
    valid_dataset = valid_dataset_cls(split=val_split, skip=val_skip,
                                      flattened=False, transform=transform)

    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size,
                                  shuffle=train_shuffle,
                                  num_workers=train_num_workers,
                                  pin_memory=pin_memory)
    val_dataloader = DataLoader(valid_dataset, batch_size=val_batch_size,
                                shuffle=False, num_workers=val_num_workers,
                                pin_memory=pin_memory)

    # The guard checks the same section the keys are read from (previously
    # it checked 'test dataloader' but read from 'test loader').
    test_dataloader = None
    if configuration.has_section('test loader'):
        file_path = configuration.get('test loader', 'file path')
        test_batch_size = int(configuration.get('test loader', 'batch size'))
        test_shuffle = _parse_bool(configuration.get('test loader', 'shuffle'))
        # NOTE(review): key is 'num_workers' here but 'num workers' in the
        # train/valid sections -- confirm which spelling configs use.
        test_num_workers = int(configuration.get('test loader', 'num_workers'))

        test_dataset_cls = _instantiate_section(configuration, 'test loader')
        test_dataset = test_dataset_cls(file_path, transform=None)
        test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size,
                                     shuffle=test_shuffle,
                                     num_workers=test_num_workers,
                                     pin_memory=pin_memory)
        # NOTE(review): test_dataloader is built but not passed to
        # train_and_validate() below -- TODO decide whether it should be.

    metrics = {'train': train_metrics, 'val': val_metrics,
               'cluster': cluster_metrics}

    experiment.train_and_validate(train_dataloader,
                                  val_dataloader,
                                  metrics,
                                  k)