예제 #1
0
class XPCtlReporting(EpochReportingHook):
    """Reporting hook that buffers metrics in memory and writes them to an xpctl repo.

    `_step` appends one record per report to `self.log`; `done` hands the
    whole log to the repo in a single `put_result` call.
    """

    def __init__(self, **kwargs):
        """Read the hook settings from `kwargs` and create the experiment repo.

        :param kwargs: Hook settings. `cred`, `config_file` and `task` are
            required (a missing key raises `KeyError`); `label`, `user`,
            `host`, `checkpoint_store` and `save_model` are optional.
        """
        super(XPCtlReporting, self).__init__(**kwargs)
        # `cred`, `config_file` and `task` are indexed directly so that a
        # missing setting fails immediately with a KeyError.
        self.cred = read_config_file_or_json(kwargs['cred'])
        self.label = kwargs.get('label', None)
        self.exp_config = read_config_file_or_json(kwargs['config_file'])
        self.task = kwargs['task']
        self.print_fn = print
        self.username = kwargs.get('user', getpass.getuser())
        self.hostname = kwargs.get('host', socket.gethostname())
        # Stays None unless `save_model` is set and `done` finds a checkpoint.
        self.checkpoint_base = None
        self.checkpoint_store = kwargs.get('checkpoint_store', '/data/model-checkpoints')
        self.save_model = kwargs.get('save_model', False)  # optionally save the model

        self.repo = ExperimentRepo().create_repo(**self.cred)
        self.log = []

    def _step(self, metrics, tick, phase, tick_type, **kwargs):
        """Write intermediate results to an in-memory log that will be pushed to the xpctl repo

        :param metrics: A map of metrics to scores
        :param tick: The time (resolution defined by `tick_type`)
        :param phase: The phase of training (`Train`, `Valid`, `Test`)
        :param tick_type: The resolution of tick (`STEP`, `EPOCH`)
        :return:
        """
        msg = {'tick_type': tick_type, 'tick': tick, 'phase': phase}
        # Metrics are merged last, so a metric named like one of the fixed
        # keys above overrides it.
        msg.update(metrics)
        self.log.append(msg)

    def done(self):
        """Write the log to the xpctl database

        When `save_model` is set, the checkpoint written by this process is
        looked up first and passed along with the results.

        :raises KeyError: If `save_model` is set and the configured backend is
            not one of `default`, `tensorflow` or `pytorch`.
        """
        if self.save_model:
            self.backend = self.exp_config.get('backend', 'default')
            # Map the configured backend name to the abbreviation used in
            # checkpoint file names.
            backends = {'default': 'tf', 'tensorflow': 'tf', 'pytorch': 'pyt'}
            self.checkpoint_base = self._search_checkpoint_base(self.task, backends[self.backend])

        self.repo.put_result(self.task, self.exp_config, self.log,
                             checkpoint_base=self.checkpoint_base,
                             checkpoint_store=self.checkpoint_store,
                             print_fn=self.print_fn,
                             hostname=self.hostname,
                             username=self.username,
                             label=self.label)

    @staticmethod
    def _search_checkpoint_base(task, backend):
        """Finds if the checkpoint exists as a zip file or a bunch of files.

        :param task: The task name used as the checkpoint name prefix
        :param backend: The backend abbreviation used in the checkpoint name
        :return: The zip file name if it exists, otherwise the checkpoint base
            name if its `.graph` file exists, otherwise `None`
        """
        # Checkpoint names embed the pid of the current process.
        base = "{}-model-{}-{}".format(task, backend, os.getpid())
        zipped = "{}.zip".format(base)
        if os.path.exists(zipped):
            return zipped
        # A checkpoint stored as loose files is detected through `<base>.graph`;
        # the placeholder is required so the base name is part of the path.
        if os.path.exists("{}.graph".format(base)):
            return base
        return None
예제 #2
0
    def __init__(self, **kwargs):
        """Configure the hook from keyword settings and create the experiment repo.

        :param kwargs: `cred`, `config_file` and `task` must be present
            (`KeyError` otherwise); `label`, `user`, `host`,
            `checkpoint_store` and `save_model` fall back to defaults.
        """
        super(XPCtlReporting, self).__init__(**kwargs)
        # Required: repo credentials (a missing key raises KeyError).
        cred_source = kwargs['cred']
        self.cred = read_config_file_or_json(cred_source)
        self.label = kwargs.get('label')
        # Required: the experiment configuration and the task name.
        config_source = kwargs['config_file']
        self.exp_config = read_config_file_or_json(config_source)
        self.task = kwargs['task']
        self.print_fn = print
        # The local user and host are always looked up; they serve only as
        # defaults when `user` / `host` are not supplied.
        local_user = getpass.getuser()
        self.username = kwargs.get('user', local_user)
        local_host = socket.gethostname()
        self.hostname = kwargs.get('host', local_host)
        self.checkpoint_base = None
        default_store = '/data/model-checkpoints'
        self.checkpoint_store = kwargs.get('checkpoint_store', default_store)
        # optionally save the model
        self.save_model = kwargs.get('save_model', False)

        self.repo = ExperimentRepo().create_repo(**self.cred)
        self.log = []
예제 #3
0
파일: cli.py 프로젝트: ZhenyueQin/baseline
    def get():
        """Return the shared experiment repo, creating it on first use.

        Echoes the connection outcome; exits the process with status 1 when
        no repo could be created.
        """
        repo = RepoManager.central_repo
        if repo is None:
            # Nothing cached yet: build the repo from the configured db settings.
            repo = ExperimentRepo.create_repo(
                RepoManager.dbtype, RepoManager.dbhost, RepoManager.dbport,
                RepoManager.dbuser, RepoManager.dbpass)
            RepoManager.central_repo = repo

        if repo is None:
            click.echo("db connection unsuccessful, aborting")
            sys.exit(1)

        success_msg = "db {} connection successful with [host]: {}, [port]: {}".format(
            RepoManager.dbtype, RepoManager.dbhost, RepoManager.dbport)
        click.echo(success_msg)
        return repo
예제 #4
0
파일: report.py 프로젝트: dpressel/baseline
 def put_result(self, label):
     """Push the stored config and reporting log of one experiment run to the repo.

     :param label: Object whose `exp`, `sha1` and `name` attributes form the
         directory holding `config.json` and the `reporting-*.log` file
     :return: The repo's `put_result` outcome as a string, or the error
         message if the repo could not be created
     """
     # Wait to create the experiment repo until after the fork
     if self.repo is None:
         try:
             self.repo = ExperimentRepo.create_repo(**self.xpctl_config)
         except Exception as err:
             return str(err)
     run_dir = os.path.join(label.exp, label.sha1, label.name)
     config = read_config_file(os.path.join(run_dir, 'config.json'))
     task = config.get('task')
     # Only the first matching reporting log is used.
     matching_logs = glob.glob(os.path.join(run_dir, 'reporting-*.log'))
     logs = read_logs(matching_logs[0])
     outcome = self.repo.put_result(task, config, logs, print_fn=dummy_print, label=self.name)
     return str(outcome)
예제 #5
0
파일: cli.py 프로젝트: dpressel/baseline
    def get():
        """Return the cached central repo, building it first if none is cached.

        A message describing the connection outcome is echoed; when the repo
        cannot be created the process exits with status 1.
        """
        if RepoManager.central_repo is None:
            # First use: create the repo from the configured connection settings.
            connection_args = (RepoManager.dbhost,
                               RepoManager.dbport,
                               RepoManager.dbuser,
                               RepoManager.dbpass,
                               RepoManager.dbtype)
            RepoManager.central_repo = ExperimentRepo.create_repo(*connection_args)

        central = RepoManager.central_repo
        if central is None:
            click.echo("db connection unsuccessful, aborting")
            sys.exit(1)

        template = "db {} connection successful with [host]: {}, [port]: {}"
        click.echo(template.format(RepoManager.dbtype, RepoManager.dbhost, RepoManager.dbport))
        return central
예제 #6
0
    def __init__(self, **kwargs):
        """Initialise the reporting hook from keyword settings.

        :param kwargs: Must contain `cred`, `config_file` and `task` (a
            missing key raises `KeyError`). `label`, `user`, `host`,
            `checkpoint_store` and `save_model` are optional.
        """
        super(XPCtlReporting, self).__init__(**kwargs)
        # Mandatory settings are indexed directly so their absence raises.
        raw_cred = kwargs['cred']
        self.cred = read_config_file_or_json(raw_cred)
        self.label = kwargs.get('label')
        raw_config = kwargs['config_file']
        self.exp_config = read_config_file_or_json(raw_config)
        self.task = kwargs['task']
        self.print_fn = print
        # Defaults for user and host come from the local machine and are
        # computed whether or not an override is given.
        fallback_user = getpass.getuser()
        fallback_host = socket.gethostname()
        self.username = kwargs.get('user', fallback_user)
        self.hostname = kwargs.get('host', fallback_host)
        self.checkpoint_base = None
        self.checkpoint_store = kwargs.get(
            'checkpoint_store', '/data/model-checkpoints')
        # optionally save the model
        self.save_model = kwargs.get('save_model', False)

        self.repo = ExperimentRepo().create_repo(**self.cred)
        self.log = []
예제 #7
0
파일: report.py 프로젝트: wxiaopei/baseline
 def put_result(self, label):
     """Upload one experiment run (its config and reporting log) to the repo.

     :param label: Object providing `exp`, `sha1` and `name`, which are
         joined into the directory containing `config.json` and the
         `reporting-*.log` file
     :return: `str()` of what the repo's `put_result` returns, or the error
         text if creating the repo failed
     """
     # Wait to create the experiment repo until after the fork
     if self.repo is None:
         try:
             self.repo = ExperimentRepo.create_repo(**self.xpctl_config)
         except Exception as exc:
             return str(exc)
     base_dir = os.path.join(label.exp, label.sha1, label.name)
     config_path = os.path.join(base_dir, 'config.json')
     config = read_config_file(config_path)
     task = config.get('task')
     log_pattern = os.path.join(base_dir, 'reporting-*.log')
     # The first match of the pattern is taken as the log to upload.
     first_log = glob.glob(log_pattern)[0]
     logs = read_logs(first_log)
     result = self.repo.put_result(
         task, config, logs, print_fn=dummy_print, label=self.name)
     return str(result)
예제 #8
0
class XPCtlReporting(EpochReportingHook):
    """Reporting hook that collects metrics in memory and writes them to an xpctl repo.

    Each `_step` call appends one record to `self.log`; `done` passes the
    accumulated log to the repo with a single `put_result` call.
    """

    def __init__(self, **kwargs):
        """Read the hook settings from `kwargs` and create the experiment repo.

        :param kwargs: Hook settings. `cred`, `config_file` and `task` are
            required (a missing key raises `KeyError`); `label`, `user`,
            `host`, `checkpoint_store` and `save_model` are optional.
        """
        super(XPCtlReporting, self).__init__(**kwargs)
        # `cred`, `config_file` and `task` are indexed directly so that a
        # missing setting fails immediately with a KeyError.
        self.cred = read_config_file_or_json(kwargs['cred'])
        self.label = kwargs.get('label', None)
        self.exp_config = read_config_file_or_json(kwargs['config_file'])
        self.task = kwargs['task']
        self.print_fn = print
        self.username = kwargs.get('user', getpass.getuser())
        self.hostname = kwargs.get('host', socket.gethostname())
        # Stays None unless `save_model` is set and `done` finds a checkpoint.
        self.checkpoint_base = None
        self.checkpoint_store = kwargs.get('checkpoint_store',
                                           '/data/model-checkpoints')
        self.save_model = kwargs.get('save_model',
                                     False)  # optionally save the model

        self.repo = ExperimentRepo().create_repo(**self.cred)
        self.log = []

    def _step(self, metrics, tick, phase, tick_type, **kwargs):
        """Write intermediate results to an in-memory log that will be pushed to the xpctl repo

        :param metrics: A map of metrics to scores
        :param tick: The time (resolution defined by `tick_type`)
        :param phase: The phase of training (`Train`, `Valid`, `Test`)
        :param tick_type: The resolution of tick (`STEP`, `EPOCH`)
        :return:
        """
        msg = {'tick_type': tick_type, 'tick': tick, 'phase': phase}
        # Metrics are merged last, so a metric named like one of the fixed
        # keys above overrides it.
        msg.update(metrics)
        self.log.append(msg)

    def done(self):
        """Write the log to the xpctl database

        When `save_model` is set, the checkpoint written by this process is
        looked up first and passed along with the results.

        :raises KeyError: If `save_model` is set and the configured backend is
            not one of `default`, `tensorflow` or `pytorch`.
        """
        if self.save_model:
            self.backend = self.exp_config.get('backend', 'default')
            # Map the configured backend name to the abbreviation used in
            # checkpoint file names.
            backends = {'default': 'tf', 'tensorflow': 'tf', 'pytorch': 'pyt'}
            self.checkpoint_base = self._search_checkpoint_base(
                self.task, backends[self.backend])

        self.repo.put_result(self.task,
                             self.exp_config,
                             self.log,
                             checkpoint_base=self.checkpoint_base,
                             checkpoint_store=self.checkpoint_store,
                             print_fn=self.print_fn,
                             hostname=self.hostname,
                             username=self.username,
                             label=self.label)

    @staticmethod
    def _search_checkpoint_base(task, backend):
        """Finds if the checkpoint exists as a zip file or a bunch of files.

        :param task: The task name used as the checkpoint name prefix
        :param backend: The backend abbreviation used in the checkpoint name
        :return: The zip file name if it exists, otherwise the checkpoint base
            name if its `.graph` file exists, otherwise `None`
        """
        # Checkpoint names embed the pid of the current process.
        base = "{}-model-{}-{}".format(task, backend, os.getpid())
        zipped = "{}.zip".format(base)
        if os.path.exists(zipped):
            return zipped
        # A checkpoint stored as loose files is detected through `<base>.graph`;
        # the placeholder is required so the base name is part of the path.
        if os.path.exists("{}.graph".format(base)):
            return base
        return None