import random import logging import numpy as np import math from baseline.utils import exporter __all__ = [] export = exporter(__all__) logger = logging.getLogger('baseline') @export class DataFeed: """Data collection that, when iterated, produces an epoch of data This class manages producing a dataset to the trainer, by iterating an epoch and producing a single step at a time. The data can be shuffled per epoch, if requested, otherwise it is returned in the order of the dateset """ def __init__(self): self.steps = 0 self.shuffle = False def _batch(self, i): pass def __getitem__(self, i): return self._batch(i) def __iter__(self): shuffle = np.random.permutation(np.arange(self.steps)) if self.shuffle else np.arange(self.steps)
import math
import time
import platform
from enum import Enum
from pprint import pformat
from collections import defaultdict
from multiprocessing.managers import BaseManager
import numpy as np
# NOTE(review): the export-decorator factory is bound here via
# `export as exporter`, whereas other baseline modules use
# `from baseline.utils import exporter`; confirm that `baseline.utils.export`
# also accepts a list and returns a decorator.
from baseline.utils import export as exporter
from baseline.utils import import_user_module, optional_params, register
from mead.utils import hash_config
from hpctl.utils import Label

# Public names of this module; `export` below is built from this list.
__all__ = []
# Presumably `export` appends each decorated object's name to `__all__`
# -- TODO confirm against baseline.utils.
export = exporter(__all__)

# Registry dict that `register_results` hands to `register`; presumably it maps
# a registration name to a results class -- verify against baseline.utils.register.
RESULTS = {}


@export
@optional_params
def register_results(cls, name=None):
    """Register a results class in the module-level ``RESULTS`` registry.

    Delegates to ``register`` with the ``RESULTS`` dict and the category label
    ``'results'``. ``optional_params`` presumably lets this be used either bare
    (``@register_results``) or with arguments (``@register_results(name=...)``)
    -- confirm in baseline.utils.

    :param cls: The object being registered (a class, per the parameter name).
    :param name: Optional registration name, forwarded to ``register`` unchanged;
        how ``None`` is resolved is decided by ``register``.
    :return: Whatever ``register`` returns (presumably ``cls``, so the decorated
        class stays usable -- TODO confirm).
    """
    return register(cls, RESULTS, name, 'results')


# NOTE(review): `six` is not imported in the visible part of this file; confirm an
# `import six` exists above. On Python 2, `six.python_2_unicode_compatible` raises
# ValueError for a class that does not define its own `__str__`, and none is
# visible here -- presumably one is defined further down.
@six.python_2_unicode_compatible
class States(Enum):
    """The DONE / KILLED / RUNNING / WAITING states.

    Each member's value is a single Unicode symbol, presumably used when
    displaying the state.
    """

    DONE = '\u2714'  # U+2714 HEAVY CHECK MARK
    KILLED = '\u2717'  # U+2717 BALLOT X
    RUNNING = '\u21bb'  # U+21BB CLOCKWISE OPEN CIRCLE ARROW
    WAITING = '\u231b'  # U+231B HOURGLASS