def __init__(self, name: str, model: Module, spec=None, ema=0, **opt_options):
    self.name = name
    self._model = model
    self._avg_model = model
    self.training = False
    self.optimizer = None
    if model is None or spec is None:
        return
    logger.debug("Trainable '%s': Create optimizer for module '%s'",
                 self.name, self._model.name)
    self.optimizer = build_optimizer(self._model.parameters(), spec, **opt_options)
    State().register_optimizer(self.name, self.optimizer)
    self.ema = None
    if ema:
        assert 0 < ema < 1
        self.ema = ema
        self._avg_model = copy.deepcopy(self._model)
        self._avg_model.name += "-ema"
        logger.debug("Trainable '%s': Create ema module '%s'",
                     self.name, self._avg_model.name)
        State().register_module(self._avg_model)
        self._avg_model.eval()
        for param in self._avg_model.parameters():
            param.requires_grad_(False)
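# When `ema` is set, the Trainable keeps a frozen deep copy of the model whose
# parameters are an exponential moving average of the live model's. The update
# step itself is not shown in this snippet; below is a minimal sketch of the
# usual form of that step. `update_ema` is a hypothetical helper, not a
# function from this codebase.
import torch


@torch.no_grad()
def update_ema(avg_model, model, decay):
    """Exponentially average `model`'s parameters into `avg_model`:
    avg <- decay * avg + (1 - decay) * current."""
    for avg_p, p in zip(avg_model.parameters(), model.parameters()):
        avg_p.mul_(decay).add_(p, alpha=1 - decay)
    # Buffers (e.g. BatchNorm running stats) are typically copied verbatim.
    for avg_b, b in zip(avg_model.buffers(), model.buffers()):
        avg_b.copy_(b)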
def finalize_init(self):
    logger.debug("Finalize module '%s'", self.name)
    winit = functools.partial(
        weights_init,
        nonlinearity=getattr(self, 'nonlinearity', None),
        output_nonlinearity=getattr(self, 'output_nonlinearity', None))
    self.apply(winit)
    State().register_module(self)
    self.to(device=State().device)
    self.eval()
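# `weights_init` is defined elsewhere in the repo and is not shown here. A
# plausible sketch of a nonlinearity-aware initializer with this signature,
# assuming it selects a gain per layer type (the handling of
# `output_nonlinearity` for the final layer is omitted):
import torch.nn as nn


def weights_init(module, nonlinearity=None, output_nonlinearity=None):
    """Hypothetical sketch: Kaiming init scaled for the module's activation."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight,
                                nonlinearity=nonlinearity or 'leaky_relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)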
def estimate_metric(self, p, q, critic):
    px, py = _get_x_y(p)
    qx, qy = _get_x_y(q)
    bs = px.size(0)
    size = px.size()
    epsilon = torch.rand(bs, 1, device=State().device)
    inter = px.view(bs, -1) * epsilon + qx.view(bs, -1) * (1 - epsilon)
    # Interpolate labels as well ??
    gradient, _ = get_gradient_wrt(critic, (inter.view(*size), py))
    penalty = gradient.view(gradient.size(0), -1).norm(p=2, dim=-1).sub(1).pow(2).mean()
    return -penalty
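# This is the WGAN-GP gradient penalty (Gulrajani et al.):
# E_xhat[(||grad_xhat critic(xhat)||_2 - 1)^2], evaluated at random
# interpolates between real and generated samples, and returned negated so
# that maximizing the estimate penalizes it. The `get_gradient_wrt` helper is
# defined elsewhere in the repo; a plausible sketch, assuming the critic takes
# an (x, y) pair:
import torch


def get_gradient_wrt(critic, inputs):
    """Hypothetical sketch: gradient of the critic's output w.r.t. its input,
    keeping the graph so the penalty itself can be backpropagated."""
    x, y = inputs
    x = x.detach().requires_grad_(True)
    out = critic(x, y)
    grad, = torch.autograd.grad(outputs=out.sum(), inputs=x,
                                create_graph=True, retain_graph=True)
    return grad, out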
class Experiment(nauka.exp.Experiment, ExperimentInterface):
    @classmethod
    def project_workdir(cls, args):
        if args.workdir:
            return args.workdir
        else:
            return os.path.join(args.basedir, cls.__name__)

    @classmethod
    def project_logdir(cls, args):
        pworkdir = cls.project_workdir(args)
        return os.path.join(pworkdir, "logs")

    def __init__(self, args):
        """Initialize Experiment.

        :param args: experiment's arguments, taken from a subcommand.
        """
        args = type(args)(**args.__dict__)
        args.__dict__.pop("__argp__", None)
        args.__dict__.pop("__argv__", None)
        args.__dict__.pop("__cls__", None)
        self.args = args
        # This is the point where state is created!
        self.state = State(self.name, self.__class__.__name__, args,
                           self.hyperparams)
        workdir = os.path.join(self.project_workdir(args), self.name)
        super(Experiment, self).__init__(workdir)
        self.mkdirp(self.logdir)
        logger.info("Initializing experiment with name: {}".format(self.name))
        self.logging = getLogger(self.__class__.__name__)
        # TODO move to State object
        self.tracking = None
        if args.tracking is not None:
            self.tracking = Wandb(args.tracking.key,
                                  args.tracking.entity,
                                  project=self.__class__.__name__,
                                  name=self.name,
                                  config=flatten_ns_to_dict(self.state.hyperparams),
                                  id=self.hash(),
                                  dir=self.workdir,
                                  )

    @property
    def hyperparams(self):
        args = copy.deepcopy(self.args)
        del args.verbosity
        del args.workdir
        del args.basedir
        del args.datadir
        del args.tmpdir
        del args.name
        del args.cuda
        del args.tracking
        del args.fastdebug
        return args

    @property
    def info(self):
        return self.state.info

    @property
    def device(self):
        return self.state.device

    @property
    def iter(self):
        return self.state.info.iter

    @iter.setter
    def iter(self, iter_):
        self.state.info.iter = iter_
        # TODO Substitute with progress bar
        sys.stdout.write("{}/{}\r".format(self.iter, self.args.train_iters))
        sys.stdout.flush()
        self.logging.debug("========= Iter: %d/%d ===========",
                           self.iter, self.args.train_iters)

    @property
    def inter(self):
        return self.state.info.inter

    @inter.setter
    def inter(self, inter_):
        self.state.info.inter = inter_

    @property
    def datadir(self):
        """Return the root directory where datasets reside."""
        return self.args.datadir

    @property
    def workdir(self):
        return self.workDir

    @property
    def logdir(self):
        """Return the directory where the experiment log will reside."""
        return os.path.join(self.workdir, "logs")

    @property
    def exitcode(self):
        return 0 if self.is_done else 1

    def hash(self, length=32):
        hparams = nested_ns_to_dict(self.hyperparams)
        hparams['mingle'] = self.args.name
        s = _pbkdf2(length, json.dumps(hparams, sort_keys=True),
                    salt="hyperparameters", rounds=100001)
        return binascii.hexlify(s).decode('utf-8', errors='strict')

    # TODO summarize function
    def log(self, **kwargs):
        self.logging.debug("Iter=%d: %s", self.iter, kwargs)
        if self.tracking:
            self.tracking.log(kwargs, step=self.iter)

    def dump(self, path):
        """Dump state to the directory `path`.

        When invoked by the snapshot machinery, `path/` may be assumed to
        already exist. The state must be saved under that directory, but the
        contents of that directory and any hierarchy underneath it are
        completely freeform, except that the subdirectory `path/.experiment`
        must not be touched.

        When invoked by the snapshot machinery, the path's basename as given
        by os.path.basename(path) will be the number this snapshot will be
        assigned, and it is equal to self.nextSnapshotNum.
        """
        self.state.dump(path)
        return self

    def load(self, path):
        """Load state from the given `path`.

        Restore the experiment to a state as close as possible to the one the
        experiment was in at the moment of the dump() that generated the
        checkpoint with the given `path`.
        """
        self.state.load(path)
        return self

    def fromScratch(self):
        """Start a fresh experiment, from scratch."""
        password = "******".format(self.args.seed)
        PRNG.seed(password)
        if self.tracking:
            self.tracking.init()
        self.define()
        if self.tracking:
            self.tracking.watch(self.state.modules)
        return self

    def fromSnapshot(self, path):
        """Start an experiment from a snapshot.

        Most likely, this method will invoke self.load(path) at an opportune
        time in its implementation.

        Returns `self`.
        """
        return self.load(path).fromScratch()

    def interval(self):
        """An interval is defined as the computation- and time-span between
        two snapshots.

        Hard Requirements
        -----------------
        - By definition, one may not invoke snapshot() within an interval.
        - Corollary: The work done by an interval is either fully recorded or
          not recorded at all.

        For reproducibility purposes, all PRNGs are reseeded at the beginning
        of every interval.
        """
        password = "******".format(self.args.seed, self.inter)
        PRNG.seed(password)
        self.execute()
        self.inter += 1
        return self

    def run(self):
        """Run by intervals until experiment completion."""
        try:
            self.state.log_setting()
            while not self.is_done:
                self.interval().snapshot().purge()
        except KeyboardInterrupt:
            pass
        return self
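# `hash()` derives a stable run id from the hyperparameters via the repo's
# `_pbkdf2` helper, which is not shown here. A plausible sketch built on the
# standard library (the choice of SHA-256 is an assumption):
import hashlib


def _pbkdf2(dklen, password, salt="", rounds=100001):
    """Hypothetical sketch: derive `dklen` bytes from the hyperparameter JSON
    via PBKDF2-HMAC, so identical configurations always map to the same id."""
    return hashlib.pbkdf2_hmac("sha256",
                               password.encode("utf-8"),
                               salt.encode("utf-8"),
                               rounds,
                               dklen)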
def state(self, state_):
    if not State().is_cuda:
        return
    self.torch.cuda.set_rng_state(state_, device=State().device)
def state(self):
    if not State().is_cuda:
        return
    return self.torch.cuda.get_rng_state(device=State().device)
def seed(self, password):
    if not State().is_cuda:
        return
    seed = self.get_random_state(password, salt="torch.cuda")
    self.torch.cuda.manual_seed(seed)
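# These PRNG snippets derive per-backend seeds from a single password plus a
# backend-specific salt, so each backend (numpy, torch, torch.cuda, ...) gets
# an independent but reproducible stream. The `get_random_state` helper is not
# shown in the repo; a plausible sketch:
import hashlib


def get_random_state(password, salt=""):
    """Hypothetical sketch: deterministic 64-bit integer seed from a
    password/salt pair, suitable for `manual_seed`."""
    digest = hashlib.pbkdf2_hmac("sha256",
                                 password.encode("utf-8"),
                                 salt.encode("utf-8"),
                                 100001)
    return int.from_bytes(digest[:8], "little")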
class AbstractDataset(object, metaclass=ABCMeta):
    def __init__(self, root, num_threads=1, download=False, load=True,
                 splits=(1, ), **options):
        self.state = State()
        self.root = os.path.join(os.path.expanduser(root))
        self.root = os.path.join(self.root, self.__class__.__name__)
        assert num_threads >= 0
        self.num_threads = num_threads
        self.splits = splits
        self.options = options
        self.n_epochs = 0
        if download is True and self.check_exists(self.root) is not True:
            self.download(self.root)
        self._data = []
        if load is True:
            self.load()

    def download(self, root):
        pass

    def check_exists(self, root):
        return True

    @abstractmethod
    def prepare(self, root, **options):
        """Return a `torch.utils.data.Dataset` implementation."""
        pass

    def transform(self, batch):
        return batch

    @property
    def data(self):
        if not self._data or \
           any(not isinstance(ds, torch.utils.data.Dataset) for ds in self._data):
            raise ValueError("Call `load` method first.")
        return self._data

    @property
    def N(self):
        return self.n_data

    def load(self):
        if self.check_exists(self.root) is False:
            raise RuntimeError(self.__class__.__name__ + ' not found.' +
                               ' You can use download=True to download it')
        self._data = self.prepare(self.root, **self.options)
        self.n_data = len(self._data)
        self.splits = prepare_splits(self.splits, self.n_data)
        self.n_splits = len(self.splits)
        self._data = torch.utils.data.random_split(self._data, self.splits)

    def build_loader(self, batch_size, sampler, split=0):
        """Return a `torch.utils.data.DataLoader` interface."""
        return torch.utils.data.DataLoader(
            dataset=self.data[split],
            batch_size=batch_size,
            sampler=sampler,
            # drop_last=True,
            num_workers=self.num_threads,
            pin_memory=self.state.is_cuda,
            worker_init_fn=_worker_init_fn,
        )

    def _fetch(self, loader, stream):
        batch = next(loader)
        if self.state.is_cuda:
            with torch.cuda.stream(stream):
                batch = [x.cuda(non_blocking=True) for x in batch]
                batch = self.transform(batch)
        else:
            batch = self.transform(batch)
        return batch

    def infinite_sampler(self, name: str, batch_size: int, split=0, resume=True):
        fetch_stream = None
        if self.state.is_cuda:
            fetch_stream = torch.cuda.Stream()
        if resume is True:
            batches_seen = self.state.samplers(name)
        else:
            batches_seen = 0
        sampler = MultiEpochSampler(self.data[split], batches_seen, batch_size)
        loader = iter(self.build_loader(batch_size, sampler, split=split))
        next_batch = self._fetch(loader, fetch_stream)
        while True:
            if self.state.is_cuda:
                torch.cuda.current_stream().wait_stream(fetch_stream)
            current_batch = next_batch
            if self.state.is_cuda:
                for x in current_batch:
                    x.record_stream(torch.cuda.current_stream())
            next_batch = self._fetch(loader, fetch_stream)
            if resume:
                self.state._samplers[name] += 1  # TODO FIXME
            yield current_batch
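# Subclasses only have to implement `prepare`; `infinite_sampler` then yields
# batches forever, prefetching the next batch on a side CUDA stream while the
# current one is consumed. A hypothetical usage sketch (the MNIST subclass and
# the torchvision dependency are assumptions, not part of this codebase):
import torchvision


class MNIST(AbstractDataset):
    def prepare(self, root, **options):
        # Delegate storage/decoding to torchvision; AbstractDataset handles
        # splitting, loaders, and resumable sampling.
        return torchvision.datasets.MNIST(
            root, train=True, download=True,
            transform=torchvision.transforms.ToTensor())


data = MNIST("~/data", splits=(0.9, 0.1))  # fractions, resolved by prepare_splits
train_stream = data.infinite_sampler("train", batch_size=64, split=0)
images, labels = next(train_stream)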