Code example #1
File: k8s.py  Project: cruelfate/odin
    def from_dict(cls, dict_value: Dict) -> 'Job':
        """Create a `Job` from some dict read from JSON/YAML

        :param dict_value: a Dictionary
        :returns: A job instance based on data inside the dictionary.
        """
        mounts = dict_value.get('mount', dict_value.get('mounts'))
        mounts = [
            Volume(m['path'], m['name'], m['claim']) for m in listify(mounts)
        ] if mounts is not None else None
        secrets = dict_value.get('secret', dict_value.get('secrets'))
        secrets = [populate_secret(s)
                   for s in listify(secrets)] if secrets is not None else None
        config_maps = dict_value.get('config_map',
                                     dict_value.get('config_maps'))
        config_maps = [populate_config_map(cm) for cm in listify(config_maps)
                       ] if config_maps is not None else None

        return Task(
            dict_value['name'],
            dict_value['image'],
            dict_value['command'],
            dict_value.get('args', []),
            mounts,
            secrets,
            config_maps,
            dict_value.get('num_gpus', 0),
            dict_value.get('pull_policy', 'IfNotPresent'),
            dict_value.get('node_selector'),
            dict_value.get('resource_type', "Pod"),
            dict_value.get('num_workers', 1),
            dict_value.get('inputs'),
            dict_value.get('outputs'),
        )
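
Note: every snippet on this page relies on a small `listify` helper that normalizes a value which may be either a scalar or a sequence (see calls such as `listify(metric)` and `listify(kwargs.get('hsz', []))` below). The following is a hedged sketch of that behavior to help read the examples; the real library implementation may differ in details.

def listify(x):
    # Hypothetical sketch inferred from the call sites on this page:
    # treat None as empty, pass sequences through, wrap anything else in a list.
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]

With a helper like this, `listify('acc')` and `listify(['acc', 'f1'])` both produce a list, which is why the snippets can accept either a single metric name or several.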
Code example #2
 def list_results(self, task, param_dict, user, metric, sort, event_type):
     if event_type is None or event_type == 'None':
         event_type = 'test_events'
     metrics = [x for x in listify(metric) if x.strip()]
     users = [x for x in listify(user) if x.strip()]
     if users:
         param_dict.update({'username': users})
     coll = self.db[task]
     query = self._update_query({}, **param_dict)
     all_results = list(coll.find(query))
     if not all_results:
         return BackendError(message='no information available for {} in task database [{}]'
                             .format(param_dict, task))
     experiments = mongo_to_experiment_set(task, all_results, event_type=event_type, metrics=metrics)
     if type(experiments) == BackendError:
         return experiments
     if sort is None or sort == 'None':
         return experiments
     else:
         if event_type == 'test_events':
             if sort in METRICS_SORT_ASCENDING:
                 return experiments.sort(sort, reverse=False)
             else:
                 return experiments.sort(sort)
         else:
             return BackendError(message='experiments can only be sorted when event_type=test_events')
Code example #3
 def list_results(self, task, param_dict, user, metric, sort, event_type):
     session = self.Session()
     data_experiments = []
     if event_type is None or event_type == 'None':
         event_type = 'test_events'
     metrics = [x for x in listify(metric) if x.strip()]
     users = [x for x in listify(user) if x.strip()]
     if not param_dict:
         hits = session.query(SqlExperiment).filter(SqlExperiment.task == task)
     else:
         hits = session.query(SqlExperiment).filter(SqlExperiment.task == task)
         for prop, value in param_dict.items():
             hits = hits.filter(getattr(SqlExperiment, prop) == value)
     if users:
         hits = hits.filter(SqlExperiment.username.in_(users))
     if hits.first() is None:
         return BackendError('No results in {} database for {}'.format(task, param_dict))
     for exp in hits:
         data_experiment = self.sql_result_to_data_experiment(exp, event_type, metrics)
         if type(data_experiment) is BackendError:
             return data_experiment
         else:
             data_experiments.append(data_experiment)
     experiment_set = self.get_data_experiment_set(data_experiments)
     if sort is None or sort == 'None':
         return experiment_set
     else:
         if event_type == 'test_events':
             if sort in METRICS_SORT_ASCENDING:
                 return experiment_set.sort(sort, reverse=False)
             else:
                 return experiment_set.sort(sort)
         else:
             return BackendError(message='experiments can only be sorted when event_type=test_events')
Code example #4
 def experiment_details(self, user, metric, sort, task, event_type, sha1,
                        n):
     session = self.Session()
     results = []
     metrics = listify(metric)
     users = listify(user)
     metrics_to_add = [metrics[0]] if len(metrics) == 1 else []
     phase = self.event2phase(event_type)
     hits = session.query(Experiment).filter(Experiment.sha1 == sha1). \
         filter(Experiment.task == task)
     for exp in hits:
         for event in exp.events:
             if event.phase == phase:
                 result = [
                     exp.id, exp.username, exp.label, exp.dataset, exp.sha1,
                     exp.date
                 ]
                 for m in self._get_filtered_metrics(
                         event.metrics, metrics):
                     result += [m.value]
                     if m.label not in metrics_to_add:
                         metrics_to_add += [m.label]
                 results.append(result)
     cols = ['id', 'username', 'label', 'dataset', 'sha1', 'date'
             ] + metrics_to_add
     result_frame = pd.DataFrame(results, columns=cols)
     return df_experimental_details(result_frame, sha1, users, sort, metric,
                                    n)
Code example #5
 def __init__(self, **kwargs):
     super().__init__(**kwargs)
     self.field = kwargs.get('fields', kwargs.get('field', 'text'))
     self.label = kwargs.get('label', 'label')
     self.emit_begin_toks = listify(
         kwargs.get('emit_begin_tok', [Offsets.VALUES[Offsets.PAD]]))
     self.emit_end_toks = listify(
         kwargs.get('emit_end_tok', [Offsets.VALUES[Offsets.PAD]]))
Code example #6
File: backend.py  Project: dpressel/baseline
 def experiment_details(self, user, metric, sort, task, event_type, sha1, n):
     metrics = listify(metric)
     coll = self.db[task]
     users = listify(user)
     query = self._update_query({}, username=users, sha1=sha1)
     projection = self._update_projection(event_type=event_type)
     result_frame = self._generate_data_frame(coll, metrics=metrics, query=query, projection=projection, event_type=event_type)
     return df_experimental_details(result_frame, sha1, users, sort, metric, n)
Code example #7
 def __init__(self, **kwargs):
     super().__init__(**kwargs)
     self.field = kwargs.get('fields', kwargs.get('field'))
     self.primary_feature = kwargs.get('primary_feature', 'text')
     self.emit_begin_toks = listify(
         kwargs.get('emit_begin_tok', [Offsets.VALUES[Offsets.PAD]]))
     self.emit_end_toks = listify(
         kwargs.get('emit_end_tok', [Offsets.VALUES[Offsets.PAD]]))
     self.apply_all_subwords = kwargs.get('apply_all_subwords', True)
Code example #8
 def __init__(self, **kwargs):
     super().__init__(kwargs.get('transform_fn'))
     self.max_seen = 128
     self.tokenizer = WordpieceTokenizer(
         self.read_vocab(kwargs.get('vocab_file')))
     self.mxlen = kwargs.get('mxlen', -1)
     self.dtype = kwargs.get('dtype', 'int')
     self._special_tokens = {"[CLS]", "<unk>", "<EOS>"}
     self.emit_begin_toks = listify(kwargs.get('emit_begin_tok', ['[CLS]']))
     self.emit_end_toks = listify(kwargs.get('emit_end_tok', ['[SEP]']))
Code example #9
File: core.py  Project: wenshuoliu/odin
def create_graph(  # pylint: disable=too-many-nested-blocks,too-many-branches
        task_list: List[Dict],
        external_inputs: Dict = {}) -> Graph:
    """Convert task list into a graph.

    There is a link between two chores if one chore references another in
    any of its inputs, or if the chore is listed in a key called `depends`
    (`depends` is used to force control flow when there is no explicit input dependency).

    :param task_list: A list of dictionaries describing tasks
    :param external_inputs: A dictionary of external inputs
    :raises ValueError: If a task name contains a `.` or there is a phantom dependency
    :returns: A DAG
    """
    graph: Graph = defaultdict(set)
    name2idx = {
        task.get('_name', task.get('name')): i
        for i, task in enumerate(task_list)
    }
    idx2name = {i: k for k, i in name2idx.items()}
    for name in name2idx:
        if '.' in name:
            raise ValueError(f"Names cannot contain `.` found {name} ")

    for dst, task in enumerate(task_list):
        if DEPENDENCY_KEY in task:
            for src in listify(task[DEPENDENCY_KEY]):
                src = src[1:] if src.startswith('^') else src
                if src not in name2idx:
                    raise ValueError(
                        f"Dependency `{src}` of node `{idx2name[dst]}` not found in graph."
                    )
                graph[name2idx[src]].add(dst)
        for values in task.values():
            values = listify(values)
            for value in values:
                if is_reference(value):
                    lookups = parse_reference(value)
                    src = lookups[0]
                    if src not in external_inputs:
                        if src not in name2idx:
                            raise ValueError(
                                f"Dependency `{src}` of node `{idx2name[dst]}` not found in graph."
                            )
                        graph[name2idx[src]].add(dst)
                    else:
                        LOGGER.info(
                            "No dependency required in this graph from %s to %s",
                            dst, src)
    for dst in range(len(task_list)):
        if dst not in graph:
            graph[dst] = set()
    return graph
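
A hedged usage sketch for `create_graph`, using only the `depends` key described in the docstring (plain strings are assumed not to count as references here; the exact reference syntax checked by `is_reference` is not shown on this page):

tasks = [
    {'name': 'train', 'image': 'trainer:latest'},
    {'name': 'evaluate', 'depends': 'train'},               # a single dependency string
    {'name': 'report', 'depends': ['train', 'evaluate']},   # a list works too, via listify
]
graph = create_graph(tasks)
# graph maps each source task index to the set of task indices that depend on it,
# i.e. roughly {0: {1, 2}, 1: {2}, 2: set()} for this list.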
Code example #10
File: reader.py  Project: DevSinghSachan/baseline
    def build_vocab(self, files, **kwargs):
        """Take a directory (as a string), or an array of files and build a vocabulary

        Take in a directory or an array of individual files (as a list).  If the argument is
        a string, it may be a directory, in which case, all files in the directory will be loaded
        to form a vocabulary.

        :param files: Either a directory (str), or an array of individual files
        :return:
        """
        vocab_file = kwargs.get('vocab_file')
        label_file = kwargs.get('label_file')
        if vocab_file is not None and label_file is not None:
            _vocab_allowed(self.vectorizers)
            vocab = _build_vocab_for_col(0, listify(vocab_file),
                                         self.vectorizers)
            labels = Counter(chain(*_read_from_col(0, listify(label_file))))
            self.label2index = {l: i for i, l in enumerate(labels)}
            return vocab, self.get_labels()

        label_idx = len(self.label2index)
        if isinstance(files, six.string_types):
            if os.path.isdir(files):
                base = files
                files = filter(
                    os.path.isfile,
                    [os.path.join(base, x) for x in os.listdir(base)])
            else:
                files = [files]
        vocab = {k: Counter() for k in self.vectorizers.keys()}

        for file_name in files:
            if file_name is None:
                continue
            with codecs.open(file_name, encoding='utf-8', mode='r') as f:
                for il, line in enumerate(f):
                    label, text = TSVSeqLabelReader.label_and_sentence(
                        line, self.clean_fn)
                    if len(text) == 0:
                        continue

                    for k, vectorizer in self.vectorizers.items():
                        vocab_file = vectorizer.count(text)
                        vocab[k].update(vocab_file)

                    if label not in self.label2index:
                        self.label2index[label] = label_idx
                        label_idx += 1

        vocab = _filter_vocab(vocab, kwargs.get('min_f', {}))

        return vocab, self.get_labels()
Code example #11
 def experiment_details(self, user, metric, sort, task, event_type, sha1,
                        n):
     metrics = listify(metric)
     coll = self.db[task]
     users = listify(user)
     query = self._update_query({}, username=users, sha1=sha1)
     projection = self._update_projection(event_type=event_type)
     result_frame = self._generate_data_frame(coll,
                                              metrics=metrics,
                                              query=query,
                                              projection=projection,
                                              event_type=event_type)
     return df_experimental_details(result_frame, sha1, users, sort, metric,
                                    n)
Code example #12
 def __init__(self, **kwargs):
     """Loads a BPE tokenizer"""
     super().__init__(kwargs.get('transform_fn'))
     self.max_seen = 128
     self.model_file = kwargs.get('model_file')
     self.vocab_file = kwargs.get('vocab_file')
     self.tokenizer = SavableFastBPE(self.model_file, self.vocab_file)
     self.mxlen = kwargs.get('mxlen', -1)
     self._vocab = {
         k: i
         for i, k in enumerate(self.read_vocab(self.vocab_file))
     }
     self.emit_begin_toks = listify(kwargs.get('emit_begin_tok', []))
     self.emit_end_toks = listify(kwargs.get('emit_end_tok', []))
     self._special_tokens = {"[CLS]", "<unk>", "<EOS>"}
Code example #13
File: reader.py  Project: dpressel/baseline
    def build_vocab(self, files, **kwargs):
        """Take a directory (as a string), or an array of files and build a vocabulary

        Take in a directory or an array of individual files (as a list).  If the argument is
        a string, it may be a directory, in which case, all files in the directory will be loaded
        to form a vocabulary.

        :param files: Either a directory (str), or an array of individual files
        :return:
        """
        vocab_file = kwargs.get('vocab_file')
        label_file = kwargs.get('label_file')
        if vocab_file is not None and label_file is not None:
            _vocab_allowed(self.vectorizers)
            vocab = _build_vocab_for_col(0, listify(vocab_file), self.vectorizers)
            labels = Counter(chain(*_read_from_col(0, listify(label_file))))
            self.label2index = {l: i for i, l in enumerate(labels)}
            return vocab, self.get_labels()

        label_idx = len(self.label2index)
        if isinstance(files, six.string_types):
            if os.path.isdir(files):
                base = files
                files = filter(os.path.isfile, [os.path.join(base, x) for x in os.listdir(base)])
            else:
                files = [files]
        vocab = {k: Counter() for k in self.vectorizers.keys()}

        for file_name in files:
            if file_name is None:
                continue
            with codecs.open(file_name, encoding='utf-8', mode='r') as f:
                for il, line in enumerate(f):
                    label, text = TSVSeqLabelReader.label_and_sentence(line, self.clean_fn)
                    if len(text) == 0:
                        continue

                    for k, vectorizer in self.vectorizers.items():
                        vocab_file = vectorizer.count(text)
                        vocab[k].update(vocab_file)

                    if label not in self.label2index:
                        self.label2index[label] = label_idx
                        label_idx += 1

        vocab = _filter_vocab(vocab, kwargs.get('min_f', {}))

        return vocab, self.get_labels()
Code example #14
def generate_export_task(  # pylint: disable=too-many-locals
    template: Dict,
    odin_image: str,
    claim: str,
    models: List[str],
    task: str,
    dataset_name: str,
    depends: List[str],
    metric: str = 'acc',
    export_policy: Optional[str] = None,
    pull_policy: str = ALWAYS,
) -> Dict:
    """Generate a task (an element in the task list for the pipeline) that is a export job.

    :param template: The skeleton version of the export task.
    :param odin_image: The name of the odin docker image to use.
    :param claim: The name of the pvc to use.
    :param models: A list of models to evaluate.
    :param task: The name of the task that was used in training.
    :param dataset_name: The name of the dataset used to train on.
    :param depends: The tasks that need to be run before this one
    :param metric: The metric to use when comparing models.
    :param export_policy: How to make a decision if we should export things.
    :param pull_policy: Should k8s repull your containers.

    :returns:
        Dict, The task that represents an export job to run in the odin pipeline.
    """
    if not export_policy:
        return {}
    depends = listify(depends)
    template = deepcopy(template)
    template['image'] = odin_image
    template['mounts'][0]['claim'] = claim
    template['pull_policy'] = pull_policy

    # For each model we trained add it to `args` just after `--models`
    models_idx = template['args'].index('--models') + 1
    template['args'].pop(models_idx)
    for model in (f"${{PIPE_ID}}--{name}" for name in models):
        template['args'].insert(models_idx, model)

    template['args'][template['args'].index('--task') + 1] = task
    template['args'][template['args'].index('--type') + 1] = export_policy
    template['args'][template['args'].index('--metric') + 1] = metric

    template['depends'] = deepcopy(listify(depends))
    return template
Code example #15
File: model.py  Project: baiyuang/baseline
    def stacked(self, pooled, init, **kwargs):
        """Stack 1 or more hidden layers, optionally (forming an MLP)

        :param pooled: The fixed representation of the model
        :param init: The tensorflow initializer
        :param kwargs: See below

        :Keyword Arguments:
        * *hsz* -- (``int``) The number of hidden units (defaults to `100`)

        :return: The final layer
        """

        hszs = listify(kwargs.get('hsz', []))
        if len(hszs) == 0:
            return pooled

        in_layer = pooled
        for i, hsz in enumerate(hszs):
            with tf.variable_scope('fc-{}'.format(i)):
                with tf.contrib.slim.arg_scope([fully_connected],
                                               weights_initializer=init):
                    fc = fully_connected(in_layer,
                                         hsz,
                                         activation_fn=tf.nn.relu)
                    in_layer = tf.nn.dropout(fc, self.pkeep)
        return in_layer
Code example #16
    def build_vocab(self, files, **kwargs):
        label_idx = len(self.label2index)
        files = listify(files)
        vocab = {k: Counter() for k in self.vectorizers.keys()}
        for file_name in files:
            if file_name is None: continue
            with codecs.open(file_name + self.data, encoding='utf-8', mode='r') as data_file:
                with codecs.open(file_name + self.labels, encoding='utf-8', mode='r') as label_file:
                    for d, l in zip(data_file, label_file):
                        if d.strip() == "": continue
                        label = l.rstrip()
                        text = ParallelSeqLabelReader.get_sentence(d, self.clean_fn)
                        if len(text) == 0: continue

                        for k, vectorizer in self.vectorizers.items():
                            vocab_file = vectorizer.count(text)
                            vocab[k].update(vocab_file)

                        if label not in self.label2index:
                            self.label2index[label] = label_idx
                            label_idx += 1

        vocab = _filter_vocab(vocab, kwargs.get('min_f', {}))

        return vocab, self.get_labels()
Code example #17
File: model.py  Project: dpressel/baseline
def register_model(cls, task, name=None):
    """Register a function as a plug-in"""
    if name is None:
        name = cls.__name__

    names = listify(name)

    if task not in BASELINE_MODELS:
        BASELINE_MODELS[task] = {}

    if task not in BASELINE_LOADERS:
        BASELINE_LOADERS[task] = {}

    if hasattr(cls, 'create'):
        def create(*args, **kwargs):
            return cls.create(*args, **kwargs)
    else:
        def create(*args, **kwargs):
            return cls(*args, **kwargs)

    for alias in names:
        if alias in BASELINE_MODELS[task]:
            raise Exception('Error: attempt to re-define previously registered handler {} (old: {}, new: {}) for task {} in registry'.format(alias, BASELINE_MODELS[task], cls, task))

        BASELINE_MODELS[task][alias] = create

        if hasattr(cls, 'load'):
            BASELINE_LOADERS[task][alias] = cls.load
    return cls
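
Because `name` is passed through `listify`, a class can be registered under a single alias or several aliases in one call. A hedged usage sketch (the class and alias names below are made up for illustration):

class ConvClassifier:  # hypothetical model class
    @classmethod
    def create(cls, *args, **kwargs):
        return cls()

    @classmethod
    def load(cls, path, **kwargs):
        return cls()

register_model(ConvClassifier, task='classify', name='conv')              # one alias
register_model(ConvClassifier, task='classify', name=['cnn', 'convnet'])  # several aliases at once

Each alias becomes a separate key in BASELINE_MODELS['classify'], all pointing at the same factory, and BASELINE_LOADERS['classify'] picks up the class's `load` method.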
Code example #18
File: backend.py  Project: dpressel/baseline
 def get_results(self, task, dataset, event_type, num_exps=None, num_exps_per_config=None, metric=None, sort=None, id=None, label=None):
     session = self.Session()
     results = []
     metrics = listify(metric)
     metrics_to_add = [metrics[0]] if len(metrics) == 1 else []
     phase = self.event2phase(event_type)
     if id is not None:
         hit = session.query(Experiment).get(id)
         if hit is None:
             return None
         hits = [hit]
     elif label is not None:
         hits = session.query(Experiment).filter(Experiment.label == label). \
             filter(Experiment.dataset == dataset). \
             filter(Experiment.task == task)
     else:
         hits = session.query(Experiment).filter(Experiment.dataset == dataset). \
             filter(Experiment.task == task)
     for exp in hits:
         for event in exp.events:
             if event.phase == phase:
                 result = [exp.id, exp.username, exp.label, exp.dataset, exp.sha1, exp.date]
                 for m in self._get_filtered_metrics(event.metrics, metrics):
                     result += [m.value]
                     if m.label not in metrics_to_add:
                         metrics_to_add += [m.label]
                 results.append(result)
     cols = ['id', 'username', 'label', 'dataset', 'sha1', 'date'] + metrics_to_add
     result_frame = pd.DataFrame(results, columns=cols)
     if not result_frame.empty:
         return df_get_results(result_frame, dataset, num_exps, num_exps_per_config, metric, sort)
     return None
Code example #19
File: helpers.py  Project: dpressel/baseline
def df_get_results(result_frame, dataset, num_exps, num_exps_per_config, metric, sort):
    datasets = result_frame.dataset.unique()
    if dataset not in datasets:
        return None
    dsr = result_frame[result_frame.dataset == dataset]
    if dsr.empty:
        return None
    df = pd.DataFrame()
    if num_exps_per_config is not None:
        for gname, rframe in result_frame.groupby("sha1"):
            rframe = rframe.copy()
            rframe['date'] = pd.to_datetime(rframe.date)
            rframe = rframe.sort_values(by='date', ascending=False).head(int(num_exps_per_config))
            df = df.append(rframe)
        result_frame = df

    result_frame = result_frame.drop(["id"], axis=1)
    result_frame = result_frame.groupby("sha1").agg([len, np.mean, np.std, np.min, np.max])\
        .rename(columns={'len': 'num_exps', 'amean': 'mean', 'amin': 'min', 'amax': 'max'})
    metrics = listify(metric)
    if len(metrics) == 1:
        result_frame = result_frame.sort_values([(metrics[0], 'mean')], ascending=sort_ascending(metric))
    if sort:
        result_frame = result_frame.sort_values([(sort, 'mean')], ascending=sort_ascending(metric))
    if result_frame.empty:
        return None
    if num_exps is not None:
        result_frame = result_frame.head(num_exps)
    return result_frame
Code example #20
    def task_summary(self, task, dataset, metric, event_type):
        metrics = listify(metric)

        coll = self.db[task]
        query = self._update_query({}, [], dataset)
        projection = self._update_projection(event_type=event_type)
        result_frame = self._generate_data_frame(coll,
                                                 metrics,
                                                 query,
                                                 projection,
                                                 event_type=event_type)
        if not result_frame.empty:
            datasets = result_frame.dataset.unique()
            if dataset not in datasets:
                return None
            dsr = result_frame[result_frame.dataset == dataset].sort_values(
                metric, ascending=False)
            result = dsr[metric].iloc[0]
            user = dsr.username.iloc[0]
            sha1 = dsr.sha1.iloc[0]
            date = dsr.date.iloc[0]
            summary = "For dataset {}, the best {} is {:0.3f} reported by {} on {}. " \
                      "The sha1 for the config file is {}.".format(dataset, metric, result, user, date, sha1)
            return summary

        return None
Code example #21
 def get_results(self, task, param_dict, reduction_dim, metric, sort, numexp_reduction_dim, event_type):
     metrics = [x for x in listify(metric)]
     if event_type is None or event_type == 'None':
         event_type = 'test_events'
     reduction_dim = reduction_dim if reduction_dim is not None else 'sha1'
     coll = self.db[task]
     if 'dataset' in param_dict.keys():
         value = self.get_related_datasets(task, param_dict['dataset'])
         param_dict['dataset'] = value
     query = self._update_query({}, **param_dict)
     all_results = list(coll.find(query))
     if not all_results:
         return BackendError(message='no information available for {} in task database [{}]'
                             .format(param_dict, task))
     resultset = mongo_to_experiment_set(task, all_results, event_type=event_type, metrics=metrics)
     if type(resultset) is BackendError:
         return resultset
     experiment_aggregate_set = aggregate_results(resultset, reduction_dim, event_type, numexp_reduction_dim,
                                                  param_dict)
     if sort is None or sort == 'None':
         return experiment_aggregate_set
     else:
         if event_type == 'test_events':
             if sort in METRICS_SORT_ASCENDING:
                 return experiment_aggregate_set.sort(sort, reverse=False)
             else:
                 return experiment_aggregate_set.sort(sort)
         else:
             return BackendError(message='experiments can only be sorted when event_type=test_events')
Code example #22
    def build_vocab(self, files, **kwargs):
        label_idx = len(self.label2index)
        files = listify(files)
        vocab = {k: Counter() for k in self.vectorizers.keys()}
        for file_name in files:
            if file_name is None: continue
            with codecs.open(file_name + self.data, encoding='utf-8',
                             mode='r') as data_file:
                with codecs.open(file_name + self.labels,
                                 encoding='utf-8',
                                 mode='r') as label_file:
                    for d, l in zip(data_file, label_file):
                        if d.strip() == "": continue
                        label = l.rstrip()
                        text = ParallelSeqLabelReader.get_sentence(
                            d, self.clean_fn)
                        if len(text) == 0: continue

                        for k, vectorizer in self.vectorizers.items():
                            vocab_file = vectorizer.count(text)
                            vocab[k].update(vocab_file)

                        if label not in self.label2index:
                            self.label2index[label] = label_idx
                            label_idx += 1

        vocab = _filter_vocab(vocab, kwargs.get('min_f', {}))

        return vocab, self.get_labels()
Code example #23
def generate_chore_task(template: Dict,
                        odin_image: str,
                        claim: str,
                        depends: Optional[Union[str, List[str]]],
                        pull_policy: str = ALWAYS) -> Dict:
    """Generate a task (an element in the task list in main.yaml) that is a chore job.

    :param template: The skeleton of a chore task to fill in.
    :param odin_image: The name of the odin docker image to use.
    :param claim: The name of the pvc to use.
    :param depends: Tasks that need to run before this one.
    :param pull_policy: Should k8s repull your containers.

    :returns:
        Dict, The task that will run a single chore file in an odin pipeline.
    """
    template = deepcopy(template)

    template['image'] = odin_image
    template['mount']['claim'] = claim
    template['pull_policy'] = pull_policy

    if depends:
        template['depends'] = deepcopy(listify(depends))

    return template
Code example #24
File: reader.py  Project: sagnik/baseline
    def build_vocab(self, files, **kwargs):
        if _all_predefined_vocabs(self.vectorizers):
            logger.info("Skipping building vocabulary.  All vectorizers have predefined vocabs!")
            return {k: v.vocab for k, v in self.vectorizers.items()}

        vocab_file = kwargs.get('vocab_file')
        if vocab_file is not None:
            _vocab_allowed(self.vectorizers)
            return _build_vocab_for_col(0, listify(vocab_file), self.vectorizers)

        vocabs = {k: Counter() for k in self.vectorizers.keys()}

        for file in files:
            if file is None:
                continue

            with codecs.open(file, encoding='utf-8', mode='r') as f:
                sentences = []
                for line in f:
                    sentences += line.split() + ['<EOS>']
                for k, vectorizer in self.vectorizers.items():
                    vocabs[k].update(vectorizer.count(sentences))

        vocabs = _filter_vocab(vocabs, kwargs.get('min_f', {}))
        return vocabs
Code example #25
 def get_results(self,
                 task,
                 dataset,
                 event_type,
                 num_exps=None,
                 num_exps_per_config=None,
                 metric=None,
                 sort=None):
     session = self.Session()
     results = []
     metrics = listify(metric)
     metrics_to_add = [metrics[0]] if len(metrics) == 1 else []
     phase = self.event2phase(event_type)
     hits = session.query(Experiment).filter(Experiment.dataset == dataset). \
         filter(Experiment.task == task)
     for exp in hits:
         for event in exp.events:
             if event.phase == phase:
                 result = [
                     exp.id, exp.username, exp.label, exp.dataset, exp.sha1,
                     exp.date
                 ]
                 for m in self._get_filtered_metrics(
                         event.metrics, metrics):
                     result += [m.value]
                     if m.label not in metrics_to_add:
                         metrics_to_add += [m.label]
                 results.append(result)
     cols = ['id', 'username', 'label', 'dataset', 'sha1', 'date'
             ] + metrics_to_add
     result_frame = pd.DataFrame(results, columns=cols)
     if not result_frame.empty:
         return df_get_results(result_frame, dataset, num_exps,
                               num_exps_per_config, metric, sort)
     return None
Code example #26
File: k8s.py  Project: cruelfate/odin
    def _generate_configmaps(self, task: Task) -> Optional[List[ConfigMap]]:
        """Generate configmaps based on the requirements of the job.

        Eventually we can support custom configmaps by having the job create
        configmaps from the yaml config. Then this function will combine configmaps
        on the job with these injected configmaps to yield the final full list.

        :param task: The job we are running and want to add configmaps to.
        :type task: Task
        :returns: A list of configmaps or `None`
        :rtype: Optional[List[ConfigMap]]
        """
        configmaps = task.config_maps if task.config_maps is not None else []
        command = listify(task.command)
        if command[0].startswith('odin-chores'):
            try:
                # Check that the ssh-config configmap exists
                _ = self.core_api.read_namespaced_config_map(
                    name='ssh-config', namespace=self.namespace)
                # Inject an ssh_config that will use the ssh key we inject with a secret
                ssh_config = ConfigMap('/etc/ssh/ssh_config', 'ssh-config',
                                       'ssh_config')
                # Inject a known hosts file so it can find our gitlab server.
                known_hosts = ConfigMap('/etc/ssh/ssh_known_hosts',
                                        'ssh-config', 'known_hosts')
                configmaps.extend((ssh_config, known_hosts))
            except client.rest.ApiException:
                pass
        return configmaps if configmaps else None
Code example #27
File: reader.py  Project: DevSinghSachan/baseline
 def build_vocabs(self, files, **kwargs):
     vocab_file = kwargs.get('vocab_file')
     if vocab_file is not None:
         all_vects = self.src_vectorizers.copy()
         all_vects['tgt'] = self.tgt_vectorizer
         _vocab_allowed(all_vects)
         # Only read the file once.
         text = _read_from_col(0, listify(vocab_file))
         src_vocab = _build_vocab_for_col(None,
                                          None,
                                          self.src_vectorizers,
                                          text=text)
         tgt_vocab = _build_vocab_for_col(None,
                                          None,
                                          {'tgt': self.tgt_vectorizer},
                                          text=text)
         return src_vocab, tgt_vocab['tgt']
     src_vocab = _build_vocab_for_col(0,
                                      [f + self.src_suffix for f in files],
                                      self.src_vectorizers)
     tgt_vocab = _build_vocab_for_col(0,
                                      [f + self.tgt_suffix for f in files],
                                      {'tgt': self.tgt_vectorizer})
     min_f = kwargs.get('min_f', {})
     tgt_min_f = {'tgt': min_f.pop('tgt', -1)}
     src_vocab = _filter_vocab(src_vocab, min_f)
     tgt_vocab = _filter_vocab(tgt_vocab, tgt_min_f)
     return src_vocab, tgt_vocab['tgt']
Code example #28
File: model.py  Project: bcmi220/multilingual_srl
def register_model(cls, task, name=None):
    """Register a function as a plug-in"""
    if name is None:
        name = cls.__name__

    names = listify(name)

    if task not in BASELINE_MODELS:
        BASELINE_MODELS[task] = {}

    if task not in BASELINE_LOADERS:
        BASELINE_LOADERS[task] = {}

    if hasattr(cls, 'create'):

        def create(*args, **kwargs):
            return cls.create(*args, **kwargs)
    else:

        def create(*args, **kwargs):
            return cls(*args, **kwargs)

    for alias in names:
        if alias in BASELINE_MODELS[task]:
            raise Exception(
                'Error: attempt to re-define previously registered handler {} (old: {}, new: {}) for task {} in registry'
                .format(alias, BASELINE_MODELS[task], cls, task))

        BASELINE_MODELS[task][alias] = create

        if hasattr(cls, 'load'):
            BASELINE_LOADERS[task][alias] = cls.load
    return cls
Code example #29
def fit(model, ts, vs, es, **kwargs):

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('tagger', 'dynet', kwargs.get('basedir'))
    conll_output = kwargs.get('conll_output', None)
    txts = kwargs.get('txts', None)

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        early_stopping_cmp, best_metric = get_metric_cmp(
            early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]',
                    early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    #validation_improvement_fn = kwargs.get('validation_improvement', None)

    after_train_fn = kwargs.get('after_train_fn', None)
    trainer = create_trainer(model, **kwargs)

    last_improved = 0
    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)
        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric],
                                best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New max %.3f', best_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d',
                    early_stopping_metric, best_metric, last_improved)

    if es is not None:
        logger.info('Reloading best checkpoint')
        model = model.load(model_file)
        trainer = create_trainer(model, **kwargs)
        trainer.test(es,
                     reporting_fns,
                     conll_output=conll_output,
                     txts=txts,
                     phase='Test')
Code example #30
File: model.py  Project: kiennguyen94/baseline
    def stacked(self, pooled, init, **kwargs):
        """Stack 1 or more hidden layers, optionally (forming an MLP)

        :param pooled: The fixed representation of the model
        :param init: The tensorflow initializer
        :param kwargs: See below

        :Keyword Arguments:
        * *hsz* -- (``int``) The number of hidden units (defaults to `100`)

        :return: The final layer
        """

        hszs = listify(kwargs.get('hsz', []))
        if len(hszs) == 0:
            return pooled

        in_layer = pooled
        for i, hsz in enumerate(hszs):
            fc = tf.layers.dense(in_layer,
                                 hsz,
                                 activation=tf.nn.relu,
                                 kernel_initializer=init,
                                 name='fc-{}'.format(i))
            in_layer = tf.layers.dropout(fc,
                                         self.pdrop_value,
                                         training=TRAIN_FLAG(),
                                         name='fc-dropout-{}'.format(i))
        return in_layer
Code example #31
def fit(model, ts, vs, es, **kwargs):
    epochs = int(kwargs['epochs']) if 'epochs' in kwargs else 5
    patience = int(kwargs['patience']) if 'patience' in kwargs else epochs
    conll_file = kwargs.get('conll_file', None)
    txts = kwargs.get('txts', None)
    model_file = kwargs['outfile'] if 'outfile' in kwargs and kwargs[
        'outfile'] is not None else './tagger-model-tf'
    after_train_fn = kwargs[
        'after_train_fn'] if 'after_train_fn' in kwargs else None
    trainer = TaggerTrainerTf(model, **kwargs)
    init = tf.global_variables_initializer()
    model.sess.run(init)
    saver = tf.train.Saver()
    model.save_using(saver)
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))

    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        patience = kwargs.get('patience', epochs)
        print('Doing early stopping on [%s] with patience [%d]' %
              (early_stopping_metric, patience))

    reporting_fns = listify(kwargs.get('reporting', basic_reporting))
    print('reporting', reporting_fns)

    max_metric = 0
    last_improved = 0
    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)

        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            trainer.checkpoint()
            model.save(model_file)

        elif test_metrics[early_stopping_metric] > max_metric:
            last_improved = epoch
            max_metric = test_metrics[early_stopping_metric]
            print('New max %.3f' % max_metric)
            trainer.checkpoint()
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            print('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        print('Best performance on max_metric %.3f at epoch %d' %
              (max_metric, last_improved))
    if es is not None:

        trainer.recover_last_checkpoint()
        evaluator = TaggerEvaluatorTf(model)
        test_metrics = evaluator.test(es, conll_file=conll_file, txts=txts)
        for reporting in reporting_fns:
            reporting(test_metrics, 0, 'Test')
Code example #32
File: tasks.py  Project: minsukchang/baseline
    def _configure_reporting(self, reporting, config_file, **kwargs):
        """Configure all `reporting_hooks` specified in the mead settings or overridden at the command line

        :param reporting:
        :param kwargs:
        :return:
        """
        default_reporting = self.mead_settings_config.get(
            'reporting_hooks', {})
        # Add default reporting information to the reporting settings.
        for report_type in default_reporting:
            if report_type in reporting:
                for report_arg, report_val in default_reporting[
                        report_type].items():
                    if report_arg not in reporting[report_type]:
                        reporting[report_type][report_arg] = report_val
        reporting_hooks = list(reporting.keys())
        for settings in reporting.values():
            for module in listify(
                    settings.get('module', settings.get('modules', []))):
                import_user_module(module)

        self.reporting = baseline.create_reporting(
            reporting_hooks, reporting, {
                'config_file': config_file,
                'task': self.__class__.task_name(),
                'base_dir': self.get_basedir()
            })

        self.config_params['train']['reporting'] = [
            x.step for x in self.reporting
        ]
        logging.basicConfig(level=logging.DEBUG)
Code example #33
File: helpers.py  Project: tanthml/baseline
def df_get_results(result_frame, dataset, num_exps, num_exps_per_config,
                   metric, sort):
    datasets = result_frame.dataset.unique()
    if dataset not in datasets:
        return None
    dsr = result_frame[result_frame.dataset == dataset]
    if dsr.empty:
        return None
    df = pd.DataFrame()
    if num_exps_per_config is not None:
        for gname, rframe in result_frame.groupby("sha1"):
            rframe = rframe.copy()
            rframe['date'] = pd.to_datetime(rframe.date)
            rframe = rframe.sort_values(by='date', ascending=False).head(
                int(num_exps_per_config))
            df = df.append(rframe)
        result_frame = df

    result_frame = result_frame.drop(["id"], axis=1)
    result_frame = result_frame.groupby("sha1").agg([len, np.mean, np.std, np.min, np.max])\
        .rename(columns={'len': 'num_exps', 'amean': 'mean', 'amin': 'min', 'amax': 'max'})
    metrics = listify(metric)
    if len(metrics) == 1:
        result_frame = result_frame.sort_values(
            [(metrics[0], 'mean')], ascending=sort_ascending(metric))
    if sort:
        result_frame = result_frame.sort_values(
            [(sort, 'mean')], ascending=sort_ascending(metric))
    if result_frame.empty:
        return None
    if num_exps is not None:
        result_frame = result_frame.head(num_exps)
    return result_frame
Code example #34
 def build_vocabs(self, files, **kwargs):
     if _all_predefined_vocabs(self.src_vectorizers) and isinstance(
             self.tgt_vectorizer, HasPredefinedVocab):
         logger.info(
             "Skipping building vocabulary.  All vectorizers have predefined vocabs!"
         )
         return {k: v.vocab
                 for k, v in self.src_vectorizers.items()}, {
                     'tgt': self.tgt_vectorizer.vocab
                 }
     vocab_file = kwargs.get('vocab_file')
     if vocab_file is not None:
         all_vects = self.src_vectorizers.copy()
         all_vects['tgt'] = self.tgt_vectorizer
         _vocab_allowed(all_vects)
         # Only read the file once
         text = _read_from_col(0, listify(vocab_file))
         src_vocab = _build_vocab_for_col(None,
                                          None,
                                          self.src_vectorizers,
                                          text=text)
         tgt_vocab = _build_vocab_for_col(None,
                                          None,
                                          {'tgt': self.tgt_vectorizer},
                                          text=text)
         return src_vocab, tgt_vocab['tgt']
     src_vocab = _build_vocab_for_col(self.src_col_num, files,
                                      self.src_vectorizers)
     tgt_vocab = _build_vocab_for_col(self.tgt_col_num, files,
                                      {'tgt': self.tgt_vectorizer})
     min_f = kwargs.get('min_f', {})
     tgt_min_f = {'tgt': min_f.pop('tgt', -1)}
     src_vocab = _filter_vocab(src_vocab, min_f)
     tgt_vocab = _filter_vocab(tgt_vocab, tgt_min_f)
     return src_vocab, tgt_vocab['tgt']
Code example #35
File: train.py  Project: switchfootsid/baseline
def fit(model, ts, vs, es, **kwargs):

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file(kwargs, 'tagger', 'pytorch')
    conll_output = kwargs.get('conll_output', None)
    txts = kwargs.get('txts', None)

    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        patience = kwargs.get('patience', epochs)
        print('Doing early stopping on [%s] with patience [%d]' %
              (early_stopping_metric, patience))

    reporting_fns = listify(kwargs.get('reporting', basic_reporting))
    print('reporting', reporting_fns)

    #validation_improvement_fn = kwargs.get('validation_improvement', None)

    after_train_fn = kwargs.get('after_train_fn', None)
    trainer = create_trainer(TaggerTrainerPyTorch, model, **kwargs)

    last_improved = 0
    max_metric = 0
    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)
        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            model.save(model_file)

        elif test_metrics[early_stopping_metric] > max_metric:
            #if validation_improvement_fn is not None:
            #    validation_improvement_fn(early_stopping_metric, test_metrics, epoch, max_metric, last_improved)
            last_improved = epoch
            max_metric = test_metrics[early_stopping_metric]
            print('New max %.3f' % max_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            print('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        print('Best performance on max_metric %.3f at epoch %d' %
              (max_metric, last_improved))

    if es is not None:
        print('Reloading best checkpoint')
        model = torch.load(model_file)
        trainer = create_trainer(TaggerTrainerPyTorch, model, **kwargs)
        trainer.test(es,
                     reporting_fns,
                     conll_output=conll_output,
                     txts=txts,
                     phase='Test')
Code example #36
File: hash.py  Project: wenshuoliu/odin
def hash_args(command: Union[str, List[str]], args: List[str]) -> str:
    """Hash the command and arguments of a container.

    :param command: The command that k8s will give the container.
    :param args: The arguments the container receives.
    :returns: The hash of the container args.
    """
    return sha1("".join(chain(listify(command), args)).encode('utf-8')).hexdigest()
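
Since `command` may arrive either as a single string or as a list of strings, `listify` lets the same hashing code handle both forms. A quick usage sketch against the function above (argument values are purely illustrative):

# A string command and the equivalent one-element list command hash the same,
# because listify normalizes both to a list before the pieces are joined.
digest_a = hash_args("python", ["train.py", "--epochs", "5"])
digest_b = hash_args(["python"], ["train.py", "--epochs", "5"])
assert digest_a == digest_b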
Code example #37
File: model.py  Project: dpressel/baseline
    def encode(self, embedseq, **kwargs):
        nlayers = kwargs.get('layers', 1)
        hsz = int(kwargs['hsz'])
        filts = kwargs.get('wfiltsz', None)
        if filts is None:
            filts = 5

        cnnout = stacked_cnn(embedseq, hsz, self.pdrop_value, nlayers, filts=listify(filts), training=TRAIN_FLAG())
        return cnnout
Code example #38
File: train.py  Project: dpressel/baseline
def fit(model, ts, vs, es, **kwargs):

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('tagger', 'pytorch', kwargs.get('basedir'))
    conll_output = kwargs.get('conll_output', None)
    txts = kwargs.get('txts', None)

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    #validation_improvement_fn = kwargs.get('validation_improvement', None)

    after_train_fn = kwargs.get('after_train_fn', None)
    trainer = create_trainer(model, **kwargs)

    last_improved = 0
    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)
        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            #if validation_improvement_fn is not None:
            #    validation_improvement_fn(early_stopping_metric, test_metrics, epoch, max_metric, last_improved)
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            model.save(model_file)


        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)

    if es is not None:
        logger.info('Reloading best checkpoint')
        model = torch.load(model_file)
        trainer = create_trainer(model, **kwargs)
        test_metrics = trainer.test(es, reporting_fns, conll_output=conll_output, txts=txts, phase='Test')
    return test_metrics
Code example #39
File: remote.py  Project: dpressel/baseline
def register_remote(cls, name=None):
    """Register a class as a plug-in"""
    if name is None:
        name = cls.__name__
    names = listify(name)
    for alias in names:
        if alias in BASELINE_REMOTES:
            raise Exception('Error: attempt to re-define previously registered handler {} (old: {}, new: {}) in registry'.format(alias, BASELINE_REMOTES, cls))
        BASELINE_REMOTES[alias] = cls
    return cls
Code example #40
File: backend.py  Project: dpressel/baseline
 def get_results(self, task, dataset, event_type,  num_exps=None,
                 num_exps_per_config=None, metric=None, sort=None, id=None, label=None):
     metrics = listify(metric)
     coll = self.db[task]
     query = self._update_query({}, dataset=dataset, id=id, label=label)
     projection = self._update_projection(event_type=event_type)
     result_frame = self._generate_data_frame(coll, metrics=metrics, query=query, projection=projection, event_type=event_type)
     if not result_frame.empty:
         return df_get_results(result_frame, dataset, num_exps, num_exps_per_config, metric, sort)
     return None
Code example #41
File: backend.py  Project: dpressel/baseline
 def experiment_details(self, user, metric, sort, task, event_type, sha1, n):
     session = self.Session()
     results = []
     metrics = listify(metric)
     users = listify(user)
     metrics_to_add = [metrics[0]] if len(metrics) == 1 else []
     phase = self.event2phase(event_type)
     hits = session.query(Experiment).filter(Experiment.sha1 == sha1). \
         filter(Experiment.task == task)
     for exp in hits:
         for event in exp.events:
             if event.phase == phase:
                 result = [exp.id, exp.username, exp.label, exp.dataset, exp.sha1, exp.date]
                 for m in self._get_filtered_metrics(event.metrics, metrics):
                     result += [m.value]
                     if m.label not in metrics_to_add:
                         metrics_to_add += [m.label]
                 results.append(result)
     cols = ['id', 'username', 'label', 'dataset', 'sha1', 'date'] + metrics_to_add
     result_frame = pd.DataFrame(results, columns=cols)
     return df_experimental_details(result_frame, sha1, users, sort, metric, n)
Code example #42
File: model.py  Project: dpressel/baseline
    def _stacked(self, pooled, **kwargs):
        pdrop = kwargs.get('dropout', 0.5)
        hszs = listify(kwargs.get('hsz', []))
        activation_type = kwargs.get('activation', 'relu')

        if len(hszs) == 0:
            return pooled

        last_layer = pooled
        for i, hsz in enumerate(hszs):
            last_layer = Dense(units=hsz, activation=activation_type)(last_layer)
            last_layer = Dropout(rate=pdrop)(last_layer)
        return last_layer
Code example #43
File: train.py  Project: dpressel/baseline
def fit(model, ts, vs, es=None, **kwargs):

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('seq2seq', 'pytorch', kwargs.get('basedir'))

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'bleu')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    after_train_fn = kwargs.get('after_train_fn', None)
    trainer = create_trainer(model, **kwargs)

    last_improved = 0
    for epoch in range(epochs):
        trainer.train(ts, reporting_fns)

        if after_train_fn is not None:
            after_train_fn(model)

        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)

    if es is not None:
        model.load(model_file)
        trainer = Seq2SeqTrainerPyTorch(model, **kwargs)
        test_metrics = trainer.test(es, reporting_fns, phase='Test')
    return test_metrics
Code example #44
File: model.py  Project: dpressel/baseline
 def init_stacked(self, input_dim, **kwargs):
     hszs = listify(kwargs.get('hsz', []))
     if len(hszs) == 0:
         self.stacked_layers = None
         return input_dim
     self.stacked_layers = nn.Sequential()
     layers = []
     in_layer_sz = input_dim
     for i, hsz in enumerate(hszs):
         layers.append(nn.Linear(in_layer_sz, hsz))
         layers.append(nn.ReLU())
         layers.append(nn.Dropout(self.pdrop))
         in_layer_sz = hsz
     append2seq(self.stacked_layers, layers)
     return in_layer_sz
Code example #45
File: train.py  Project: dpressel/baseline
def fit(model, ts, vs, es=None, epochs=5, do_early_stopping=True,
        early_stopping_metric='bleu', **kwargs):

    patience = int(kwargs.get('patience', epochs))
    after_train_fn = kwargs.get('after_train_fn', None)

    model_file = get_model_file('seq2seq', 'dy', kwargs.get('basedir'))

    trainer = create_trainer(model, **kwargs)

    best_metric = 0
    if do_early_stopping:
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        logger.info("Doing early stopping on [%s] with patience [%d]", early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    last_improved = 0

    for epoch in range(epochs):
        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)

        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info("New best %.3f", best_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info("Stopping due to persistent failures to improve")
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)

    if es is not None:
        logger.info('Reloading best checkpoint')
        model = model.load(model_file)
        test_metrics = trainer.test(es, reporting_fns, phase='Test')
    return test_metrics
Code example #46
File: reader.py Project: dpressel/baseline
    def build_vocabs(self, files, **kwargs):
        vocab_file = kwargs.get('vocab_file')
        if vocab_file is not None:
            all_vects = self.src_vectorizers.copy()
            all_vects['tgt'] = self.tgt_vectorizer
            _vocab_allowed(all_vects)
            # Only read the file once.
            text = _read_from_col(0, listify(vocab_file))
            src_vocab = _build_vocab_for_col(None, None, self.src_vectorizers, text=text)
            tgt_vocab = _build_vocab_for_col(None, None, {'tgt': self.tgt_vectorizer}, text=text)
            return src_vocab, tgt_vocab['tgt']
        src_vocab = _build_vocab_for_col(0, [f + self.src_suffix for f in files], self.src_vectorizers)
        tgt_vocab = _build_vocab_for_col(0, [f + self.tgt_suffix for f in files], {'tgt': self.tgt_vectorizer})
        min_f = kwargs.get('min_f', {})
        tgt_min_f = {'tgt': min_f.pop('tgt', -1)}
        src_vocab = _filter_vocab(src_vocab, min_f)
        tgt_vocab = _filter_vocab(tgt_vocab, tgt_min_f)
        return src_vocab, tgt_vocab['tgt']
Code example #47
File: train.py Project: dpressel/baseline
def fit(model, ts, vs, es, epochs=20, do_early_stopping=True, early_stopping_metric='acc', **kwargs):
    autobatchsz = kwargs.get('autobatchsz', 1)
    verbose = kwargs.get('verbose', {'print': kwargs.get('verbose_print', False), 'file': kwargs.get('verbose_file', None)})
    model_file = get_model_file('classify', 'dynet', kwargs.get('basedir'))

    best_metric = 0
    if do_early_stopping:
        patience = kwargs.get('patience', epochs)
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    trainer = create_trainer(model, **kwargs)

    last_improved = 0

    for epoch in range(epochs):
        trainer.train(ts, reporting_fns)
        test_metrics = trainer.test(vs, reporting_fns)

        if do_early_stopping is False:
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)
    if es is not None:
        logger.info('Reloading best checkpoint')
        model = model.load(model_file)
        test_metrics = trainer.test(es, reporting_fns, phase='Test', verbose=verbose)
    return test_metrics
Code example #48
File: train.py Project: dpressel/baseline
def create_lr_scheduler(**kwargs):
    """Create a learning rate scheduler.

    :Keyword Arguments:
      * *lr_scheduler_type* `str` or `list` The name of the learning rate scheduler
          if list then the first scheduler should be a warmup scheduler.
    """
    sched_type = kwargs.get('lr_scheduler_type')
    if sched_type is None:
        return None
    sched_type = listify(sched_type)
    if len(sched_type) == 2:
        warm = BASELINE_LR_SCHEDULERS.get(sched_type[0])(**kwargs)
        assert isinstance(warm, WarmupLearningRateScheduler), "First LR Scheduler must be a warmup scheduler."
        rest = BASELINE_LR_SCHEDULERS.get(sched_type[1])(**kwargs)
        return BASELINE_LR_SCHEDULERS.get('composite')(warm=warm, rest=rest, **kwargs)
    Constructor = BASELINE_LR_SCHEDULERS.get(sched_type[0])
    lrs = Constructor(**kwargs)
    logger.info(lrs)
    return lrs
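
create_lr_scheduler has two paths: a single scheduler name, or a two-element list whose first entry must be a warmup scheduler that is wrapped together with the second into a composite schedule. The registry contents are not shown here, so the toy sketch below only illustrates that dispatch pattern; the class names, constructor arguments, and hand-off rule are stand-ins, not the real BASELINE_LR_SCHEDULERS entries.

class ToyWarmupLinear:
    """Ramps the learning rate up linearly for the first warmup_steps steps."""
    def __init__(self, lr=0.001, warmup_steps=100, **kwargs):
        self.lr, self.warmup_steps = lr, warmup_steps

    def __call__(self, step):
        return self.lr * min(1.0, step / float(self.warmup_steps))


class ToyExponentialDecay:
    """Decays the learning rate geometrically every decay_steps steps."""
    def __init__(self, lr=0.001, decay_rate=0.97, decay_steps=1000, **kwargs):
        self.lr, self.decay_rate, self.decay_steps = lr, decay_rate, decay_steps

    def __call__(self, step):
        return self.lr * (self.decay_rate ** (step / float(self.decay_steps)))


class ToyComposite:
    """Runs the warmup schedule first, then hands off to the rest schedule."""
    def __init__(self, warm, rest, **kwargs):
        self.warm, self.rest = warm, rest

    def __call__(self, step):
        if step < self.warm.warmup_steps:
            return self.warm(step)
        return self.rest(step - self.warm.warmup_steps)


TOY_LR_SCHEDULERS = {'warmup_linear': ToyWarmupLinear, 'exponential': ToyExponentialDecay, 'composite': ToyComposite}

# lr_scheduler_type='exponential'                    -> a single ToyExponentialDecay
# lr_scheduler_type=['warmup_linear', 'exponential'] -> ToyComposite(warm=..., rest=...)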
Code example #49
File: model.py Project: dpressel/baseline
    def init_stacked(self, input_dim, **kwargs):

        hszs = listify(kwargs.get('hsz', []))
        if len(hszs) == 0:
            return input_dim, None

        stacked_layers = []
        isz = input_dim
        for i, hsz in enumerate(hszs):
            stacked_layers.append(Linear(hsz, isz, self.pc))
            stacked_layers.append(dy.rectify)
            stacked_layers.append(self.dropout)
            isz = hsz

        def call_stacked(input_):
            for layer in stacked_layers:
                input_ = layer(input_)
            return input_

        return hsz, call_stacked
Code example #50
File: reader.py Project: dpressel/baseline
    def build_vocab(self, files, **kwargs):
        vocab_file = kwargs.get('vocab_file')
        if vocab_file is not None:
            _vocab_allowed(self.vectorizers)
            return _build_vocab_for_col(0, listify(vocab_file), self.vectorizers)

        vocabs = {k: Counter() for k in self.vectorizers.keys()}

        for file in files:
            if file is None:
                continue

            with codecs.open(file, encoding='utf-8', mode='r') as f:
                sentences = []
                for line in f:
                    sentences += line.split() + ['<EOS>']
                for k, vectorizer in self.vectorizers.items():
                    vocabs[k].update(vectorizer.count(sentences))

        vocabs = _filter_vocab(vocabs, kwargs.get('min_f', {}))
        return vocabs
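
Both reader examples finish by pushing the raw counters through _filter_vocab with a per-field min_f mapping (with -1 used above on the target side to disable filtering). That helper is not shown on this page; a plausible sketch, assuming it simply drops tokens whose count falls below the field's minimum frequency, is:

from collections import Counter
from typing import Dict


def _filter_vocab_sketch(vocab: Dict[str, Counter], min_f: Dict[str, int]) -> Dict[str, Counter]:
    """Sketch only: keep tokens whose count meets the per-field minimum frequency."""
    filtered = {}
    for field, counts in vocab.items():
        threshold = min_f.get(field, -1)  # -1 keeps everything, matching the 'tgt' default above
        filtered[field] = Counter({tok: c for tok, c in counts.items() if c >= threshold})
    return filtered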
Code example #51
File: model.py Project: dpressel/baseline
    def stacked(self, pooled, init, **kwargs):
        """Stack 1 or more hidden layers, optionally (forming an MLP)

        :param pooled: The fixed representation of the model
        :param init: The tensorflow initializer
        :param kwargs: See below

        :Keyword Arguments:
        * *hsz* -- (``int`` or ``list``) The hidden unit count for each stacked layer (defaults to empty, meaning no stacked layers)

        :return: The final layer
        """

        hszs = listify(kwargs.get('hsz', []))
        if len(hszs) == 0:
            return pooled

        in_layer = pooled
        for i, hsz in enumerate(hszs):
            fc = tf.layers.dense(in_layer, hsz, activation=tf.nn.relu, kernel_initializer=init, name='fc-{}'.format(i))
            in_layer = tf.layers.dropout(fc, self.pdrop_value, training=TRAIN_FLAG(), name='fc-dropout-{}'.format(i))
        return in_layer
Code example #52
File: train.py Project: dpressel/baseline
def fit(model, ts, vs, es=None, **kwargs):
    """
    Train a classifier using Keras
    
    :param model: The model to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs:
        See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) --
          Stop after evaluation data is no longer improving.  Defaults to True

        * *epochs* (``int``) -- how many epochs.  Default to 20
        * *outfile* -- Model output file, defaults to classifier-model-keras
        * *patience* (``int``) --
           How many epochs where evaluation is no longer improving before we give up
        * *reporting* --
           Callbacks which may be used on reporting updates
        * *optim* --
           What optimizer to use.  Defaults to `adam`
    :return:
    """
    trainer = create_trainer(model, **kwargs)
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('classify', 'keras', kwargs.get('basedir'))

    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    max_metric = 0
    last_improved = 0

    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        test_metrics = trainer.test(vs, reporting_fns)

        if do_early_stopping is False:
            model.save(model_file)

        elif test_metrics[early_stopping_metric] > max_metric:
            last_improved = epoch
            max_metric = test_metrics[early_stopping_metric]
            logger.info('New max %.3f', max_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on max_metric %.3f at epoch %d', max_metric, last_improved)

    if es is not None:
        logger.info('Reloading best checkpoint')
        model = model.load(model_file)
        trainer = ClassifyTrainerKeras(model, **kwargs)
        trainer.test(es, reporting_fns, phase='Test')
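
All of the fit/train variants collected on this page share the same keyword surface (epochs, patience, do_early_stopping, early_stopping_metric, reporting), so a call site looks much the same regardless of backend. The snippet below is a hypothetical invocation, not taken from the repository: model, ts, vs and es are assumed to come from the surrounding framework code, and the callback signature mirrors the reporting(test_metrics, 0, 'Test') call visible in example #57.

def log_metrics(metrics, tick, phase):
    # Assumed (metrics, tick, phase) signature, matching how reporting_fns are invoked in example #57.
    print('[{}] tick {}: {}'.format(phase, tick, ', '.join('{}={:.4f}'.format(k, v) for k, v in metrics.items())))


# Hypothetical call; most variants on this page also return the final test metrics.
fit(model, ts, vs, es,
    epochs=20,
    do_early_stopping=True,
    early_stopping_metric='acc',
    patience=5,
    reporting=log_metrics)  # listify() lets a single callback stand in for a list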
Code example #53
File: train.py Project: dpressel/baseline
def fit(model, ts, vs, es=None, **kwargs):
    epochs = int(kwargs['epochs']) if 'epochs' in kwargs else 5
    patience = int(kwargs['patience']) if 'patience' in kwargs else epochs

    model_file = get_model_file('lm', 'tf', kwargs.get('basedir'))
    after_train_fn = kwargs['after_train_fn'] if 'after_train_fn' in kwargs else None
    trainer = create_trainer(model, **kwargs)
    init = tf.global_variables_initializer()
    feed_dict = {k: v for e in model.embeddings.values() for k, v in e.get_feed_dict().items()}
    model.sess.run(init, feed_dict)
    saver = tf.train.Saver()
    model.set_saver(saver)
    checkpoint = kwargs.get('checkpoint')
    if checkpoint is not None:
        latest = tf.train.latest_checkpoint(checkpoint)
        print('Reloading ' + latest)
        model.saver.restore(model.sess, latest)

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))

    best_metric = 1000
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'avg_loss')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    last_improved = 0

    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)

        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)
    if es is not None:
        trainer.recover_last_checkpoint()
        test_metrics = trainer.test(es, reporting_fns, phase='Test')
    return test_metrics
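
Note that this language-model variant seeds best_metric with a large value because its default metric, avg_loss, improves downward, while the accuracy- and BLEU-driven variants above start from 0. get_metric_cmp is what hides that difference from the loop body; a plausible sketch, assuming it returns a comparison function together with a suitable initial best value and treats loss-like metrics as lower-is-better, is:

import operator

# Sketch only: the real set of lower-is-better metric names lives in baseline, not here.
LOWER_IS_BETTER = {'avg_loss', 'loss', 'perplexity'}


def get_metric_cmp_sketch(metric, user_cmp=None):
    """Return (comparison_fn, initial_best) for early stopping on the given metric."""
    if user_cmp == 'min' or (user_cmp is None and metric in LOWER_IS_BETTER):
        return operator.lt, float('inf')  # a new value counts as an improvement when it is smaller
    return operator.gt, 0.0               # accuracy/BLEU style: bigger is better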
Code example #54
File: sample.py Project: dpressel/baseline
def add_min_max(example, key, value):
    key = (key,)
    c = value.get('constraints', DEFAULT_CONSTRAINTS.get(key[-1], DEFAULT_CONSTRAINTS['default']))
    example[key] = [listify(c), value['min'], value['max']]
Code example #55
File: sample.py Project: dpressel/baseline
def add_normal(example, key, value):
    key = (key,)
    c = value.get('constraints', DEFAULT_CONSTRAINTS.get(key[-1], DEFAULT_CONSTRAINTS['default']))
    example[key] = [listify(c), value['mu'], value['sigma']]
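
These two helpers turn a hyperparameter's sampling spec into an entry of the example dict used by the sampler: the value stored under the single-element tuple key is [constraints, param1, param2]. The snippet below is a self-contained illustration with a stubbed DEFAULT_CONSTRAINTS and a minimal listify stand-in; the real constraint table is not shown on this page.

def listify(x):  # minimal stand-in; see the sketch after example #42
    return list(x) if isinstance(x, (list, tuple)) else ([] if x is None else [x])


# Stub for the module-level constraint table; the real contents are not shown here.
DEFAULT_CONSTRAINTS = {'default': [], 'dropout': [0.0, 1.0]}


def add_min_max(example, key, value):
    key = (key,)
    c = value.get('constraints', DEFAULT_CONSTRAINTS.get(key[-1], DEFAULT_CONSTRAINTS['default']))
    example[key] = [listify(c), value['min'], value['max']]


def add_normal(example, key, value):
    key = (key,)
    c = value.get('constraints', DEFAULT_CONSTRAINTS.get(key[-1], DEFAULT_CONSTRAINTS['default']))
    example[key] = [listify(c), value['mu'], value['sigma']]


example = {}
add_min_max(example, 'lr', {'min': 1e-4, 'max': 1e-2})
add_normal(example, 'dropout', {'mu': 0.5, 'sigma': 0.1})
# example == {('lr',): [[], 0.0001, 0.01], ('dropout',): [[0.0, 1.0], 0.5, 0.1]}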
Code example #56
File: vectorizers.py Project: dpressel/baseline
    def __init__(self, **kwargs):
        super(Dict2DVectorizer, self).__init__(**kwargs)
        self.fields = listify(kwargs.get('fields', 'text'))
        self.delim = kwargs.get('token_delim', '@@')
Code example #57
File: train.py Project: dpressel/baseline
def fit(model, ts, vs, es, **kwargs):
    epochs = int(kwargs['epochs']) if 'epochs' in kwargs else 5
    patience = int(kwargs['patience']) if 'patience' in kwargs else epochs
    conll_output = kwargs.get('conll_output', None)
    span_type = kwargs.get('span_type', 'iob')
    txts = kwargs.get('txts', None)
    model_file = get_model_file('tagger', 'tf', kwargs.get('basedir'))
    after_train_fn = kwargs['after_train_fn'] if 'after_train_fn' in kwargs else None
    trainer = create_trainer(model, **kwargs)
    tables = tf.tables_initializer()
    model.sess.run(tables)
    feed_dict = {k: v for e in model.embeddings.values() for k, v in e.get_feed_dict().items()}
    init = tf.global_variables_initializer()
    model.sess.run(init, feed_dict)
    saver = tf.train.Saver()
    model.save_using(saver)
    checkpoint = kwargs.get('checkpoint')
    if checkpoint is not None:
        latest = tf.train.latest_checkpoint(checkpoint)
        print('Reloading ' + latest)
        model.saver.restore(model.sess, latest)

    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    verbose = bool(kwargs.get('verbose', False))

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    last_improved = 0
    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        if after_train_fn is not None:
            after_train_fn(model)

        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            trainer.checkpoint()
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            trainer.checkpoint()
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)
    if es is not None:

        trainer.recover_last_checkpoint()
        # What to do about overloading this??
        evaluator = TaggerEvaluatorTf(model, span_type, verbose)
        start = time.time()
        test_metrics = evaluator.test(es, conll_output=conll_output, txts=txts)
        duration = time.time() - start
        for reporting in reporting_fns:
            reporting(test_metrics, 0, 'Test')
        trainer.log.debug({'phase': 'Test', 'time': duration})
    return test_metrics
Code example #58
File: train.py Project: dpressel/baseline
def fit(model, ts, vs, es=None, **kwargs):
    """
    Train a classifier using TensorFlow

    :param model: The model to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs:
        See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) --
          Stop after evaluation data is no longer improving.  Defaults to True

        * *epochs* (``int``) -- how many epochs.  Default to 20
        * *outfile* -- Model output file, defaults to classifier-model.pyth
        * *patience* --
           How many epochs where evaluation is no longer improving before we give up
        * *reporting* --
           Callbacks which may be used on reporting updates
        * Additional arguments are supported, see :func:`baseline.tf.optimize` for full list
    :return:
    """
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    verbose = kwargs.get('verbose', {'console': kwargs.get('verbose_console', False), 'file': kwargs.get('verbose_file', None)})
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('classify', 'tf', kwargs.get('basedir'))
    ema = True if kwargs.get('ema_decay') is not None else False

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)

    trainer = create_trainer(model, **kwargs)
    tables = tf.tables_initializer()
    model.sess.run(tables)
    feed_dict = {k: v for e in model.embeddings.values() for k, v in e.get_feed_dict().items()}
    model.sess.run(tf.global_variables_initializer(), feed_dict)
    model.set_saver(tf.train.Saver())
    checkpoint = kwargs.get('checkpoint')
    if checkpoint is not None:
        latest = tf.train.latest_checkpoint(checkpoint)
        print('Reloading ' + latest)
        model.saver.restore(model.sess, latest)


    last_improved = 0

    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if do_early_stopping is False:
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)

    if es is not None:
        logger.info('Reloading best checkpoint')
        trainer.recover_last_checkpoint()
        test_metrics = trainer.test(es, reporting_fns, phase='Test', verbose=verbose)
    return test_metrics
Code example #59
File: demolib.py Project: dpressel/baseline
def train(model, ts, vs, es=None, **kwargs):
    """
    Train a classifier using TensorFlow

    :param model: The model to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs:
        See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) --
          Stop after evaluation data is no longer improving.  Defaults to True

        * *epochs* (``int``) -- how many epochs.  Default to 20
        * *outfile* -- Model output file, defaults to classifier-model.pyth
        * *patience* --
           How many epochs where evaluation is no longer improving before we give up
        * *reporting* --
           Callbacks which may be used on reporting updates
        * Additional arguments are supported, see :func:`baseline.tf.optimize` for full list
    :return:
    """
    n = int(kwargs.get('test_epochs', 5))
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('classify', 'tf', kwargs.get('basedir'))

    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        patience = kwargs.get('patience', epochs)
        print('Doing early stopping on [%s] with patience [%d]' % (early_stopping_metric, patience))

    reporting_fns = listify(kwargs.get('reporting', []))
    print('reporting', reporting_fns)

    trainer = create_trainer(model, **kwargs)
    tables = tf.tables_initializer()
    model.sess.run(tables)
    model.sess.run(tf.global_variables_initializer())
    model.set_saver(tf.train.Saver())

    max_metric = 0
    last_improved = 0

    for epoch in range(epochs):

        trainer.train(ts, reporting_fns)
        test_metrics = trainer.test(vs, reporting_fns, phase='Valid')

        if es is not None and epoch > 0 and epoch % n == 0 and epoch < epochs - 1:
            print(color('Running test', Colors.GREEN))
            trainer.test(es, reporting_fns, phase='Test')

        if do_early_stopping is False:
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif test_metrics[early_stopping_metric] > max_metric:
            last_improved = epoch
            max_metric = test_metrics[early_stopping_metric]
            print('New max %.3f' % max_metric)
            trainer.checkpoint()
            trainer.model.save(model_file)

        elif (epoch - last_improved) > patience:
            print(color('Stopping due to persistent failures to improve', Colors.RED))
            break

    if do_early_stopping is True:
        print('Best performance on max_metric %.3f at epoch %d' % (max_metric, last_improved))

    if es is not None:
        print(color('Reloading best checkpoint', Colors.GREEN))
        trainer.recover_last_checkpoint()
        trainer.test(es, reporting_fns, phase='Test')
Code example #60
File: train.py Project: dpressel/baseline
def fit(model, ts, vs, es, **kwargs):
    """
    Train a classifier using PyTorch
    :param model: The model to train
    :param ts: A training data set
    :param vs: A validation data set
    :param es: A test data set, can be None
    :param kwargs: See below

    :Keyword Arguments:
        * *do_early_stopping* (``bool``) -- Stop after eval data is not improving. Default to True
        * *epochs* (``int``) -- how many epochs.  Default to 20
        * *outfile* -- Model output file, defaults to classifier-model.pyth
        * *patience* --
           How many epochs where evaluation is no longer improving before we give up
        * *reporting* --
           Callbacks which may be used on reporting updates
        * *optim* --
           Optimizer to use, defaults to `sgd`
        * *eta, lr* (``float``) --
           Learning rate, defaults to 0.01
        * *mom* (``float``) --
           Momentum (SGD only), defaults to 0.9 if optim is `sgd`
    :return:
    """
    do_early_stopping = bool(kwargs.get('do_early_stopping', True))
    verbose = kwargs.get('verbose', {'console': kwargs.get('verbose_console', False), 'file': kwargs.get('verbose_file', None)})
    epochs = int(kwargs.get('epochs', 20))
    model_file = get_model_file('classify', 'pytorch', kwargs.get('basedir'))

    best_metric = 0
    if do_early_stopping:
        early_stopping_metric = kwargs.get('early_stopping_metric', 'acc')
        early_stopping_cmp, best_metric = get_metric_cmp(early_stopping_metric, kwargs.get('early_stopping_cmp'))
        patience = kwargs.get('patience', epochs)
        logger.info('Doing early stopping on [%s] with patience [%d]', early_stopping_metric, patience)

    reporting_fns = listify(kwargs.get('reporting', []))
    logger.info('reporting %s', reporting_fns)


    trainer = create_trainer(model, **kwargs)

    last_improved = 0

    for epoch in range(epochs):
        trainer.train(ts, reporting_fns)
        test_metrics = trainer.test(vs, reporting_fns)

        if do_early_stopping is False:
            model.save(model_file)

        elif early_stopping_cmp(test_metrics[early_stopping_metric], best_metric):
            last_improved = epoch
            best_metric = test_metrics[early_stopping_metric]
            logger.info('New best %.3f', best_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            logger.info('Stopping due to persistent failures to improve')
            break

    if do_early_stopping is True:
        logger.info('Best performance on %s: %.3f at epoch %d', early_stopping_metric, best_metric, last_improved)

    if es is not None:
        logger.info('Reloading best checkpoint')
        model = torch.load(model_file)
        trainer = create_trainer(model, **kwargs)
        test_metrics = trainer.test(es, reporting_fns, phase='Test', verbose=verbose)
    return test_metrics