Example #1
    def __init__(self, args):
        """Initialize Experiment.

        :param args: experiment's arguments, taken from a subcommand.

        """
        args = type(args)(**args.__dict__)
        args.__dict__.pop("__argp__", None)
        args.__dict__.pop("__argv__", None)
        args.__dict__.pop("__cls__", None)
        self.args = args
        # This is the point where the State object is created.
        self.state = State(self.name, self.__class__.__name__,
                           args, self.hyperparams)
        workdir = os.path.join(self.project_workdir(args), self.name)
        super(Experiment, self).__init__(workdir)
        self.mkdirp(self.logdir)
        logger.info("Initializing experiment with name: {}".format(self.name))
        self.logging = getLogger(self.__class__.__name__)

        # TODO move to State object
        self.tracking = None
        if args.tracking is not None:
            self.tracking = Wandb(args.tracking.key, args.tracking.entity,
                                  project=self.__class__.__name__,
                                  name=self.name,
                                  config=flatten_ns_to_dict(self.state.hyperparams),
                                  id=self.hash(),
                                  dir=self.workdir,
                                  )
Example #2
    def __init__(self,
                 name: str,
                 model: Module,
                 spec=None,
                 ema=0,
                 **opt_options):
        self.name = name
        self._model = model
        self._avg_model = model
        self.training = False
        self.optimizer = None

        if model is None or spec is None:
            return

        logger.debug("Trainable '%s': Create optimizer for module '%s'",
                     self.name, self._model.name)
        self.optimizer = build_optimizer(self._model.parameters(), spec,
                                         **opt_options)
        State().register_optimizer(self.name, self.optimizer)
        self.ema = None
        if ema:
            assert 0 < ema < 1
            self.ema = ema
            self._avg_model = copy.deepcopy(self._model)
            self._avg_model.name += "-ema"
            logger.debug("Trainable '%s': Create ema module '%s'", self.name,
                         self._avg_model.name)
            State().register_module(self._avg_model)
            self._avg_model.eval()
            for param in self._avg_model.parameters():
                param.requires_grad_(False)
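
Example #2 stores the decay factor and the averaged copy, but the averaging step itself does not appear in these snippets. Below is a minimal sketch of how such an EMA update is typically applied after each optimizer step; the helper name `update_ema` is an assumption, not part of the original code.

import torch

@torch.no_grad()
def update_ema(model, avg_model, decay):
    # Hypothetical helper: blend the live weights into the averaged copy,
    # avg = decay * avg + (1 - decay) * current.
    for p, avg_p in zip(model.parameters(), avg_model.parameters()):
        avg_p.mul_(decay).add_(p, alpha=1 - decay)
    # Buffers (e.g. BatchNorm running stats) are usually copied verbatim.
    for b, avg_b in zip(model.buffers(), avg_model.buffers()):
        avg_b.copy_(b)
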
Example #3
    def finalize_init(self):
        logger.debug("Finalize module '%s'", self.name)
        winit = functools.partial(
            weights_init,
            nonlinearity=getattr(self, 'nonlinearity', None),
            output_nonlinearity=getattr(self, 'output_nonlinearity', None))
        self.apply(winit)
        State().register_module(self)
        self.to(device=State().device)
        self.eval()
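
The `weights_init` helper applied above is not defined anywhere in these examples. The sketch below is a plausible shape for it, under the assumption that it dispatches on module type and activation; the body is a guess, and `output_nonlinearity` handling is omitted.

import torch.nn as nn

def weights_init(module, nonlinearity=None, output_nonlinearity=None):
    # Hypothetical stand-in: re-initialize affine layers with a scheme
    # matched to the activation that follows them.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        if nonlinearity in ("relu", "leaky_relu"):
            nn.init.kaiming_normal_(module.weight, nonlinearity=nonlinearity)
        else:
            nn.init.xavier_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
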
Example #4
    def estimate_metric(self, p, q, critic):
        px, py = _get_x_y(p)
        qx, qy = _get_x_y(q)
        bs = px.size(0)
        size = px.size()
        epsilon = torch.rand(bs, 1, device=State().device)
        inter = px.view(bs, -1) * epsilon + qx.view(bs, -1) * (1 - epsilon)
        # Interpolate labels as well?
        gradient, _ = get_gradient_wrt(critic, (inter.view(*size), py))
        penalty = gradient.view(gradient.size(0),
                                -1).norm(p=2, dim=-1).sub(1).pow(2).mean()
        return -penalty
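
Example #4 is a WGAN-GP-style gradient penalty: it interpolates between the two sample batches and penalizes the critic's gradient norm for deviating from 1. The `get_gradient_wrt` helper it relies on is not shown; below is a minimal sketch of what it presumably does, using `torch.autograd.grad`. The exact signature and the (x, y) calling convention are assumptions.

import torch

def get_gradient_wrt(critic, inputs):
    # Hypothetical stand-in: gradient of the critic's summed output with
    # respect to the interpolated inputs x, keeping the graph so the
    # penalty itself can be backpropagated through.
    x, y = inputs
    x = x.detach().requires_grad_(True)
    out = critic(x, y)
    grad, = torch.autograd.grad(out.sum(), x, create_graph=True)
    return grad, out
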
Example #5
    def __init__(self,
                 root,
                 num_threads=1,
                 download=False,
                 load=True,
                 splits=(1, ),
                 **options):
        self.state = State()
        self.root = os.path.expanduser(root)
        self.root = os.path.join(self.root, self.__class__.__name__)
        assert num_threads >= 0
        self.num_threads = num_threads
        self.splits = splits
        self.options = options
        self.n_epochs = 0

        if download and not self.check_exists(self.root):
            self.download(self.root)

        self._data = []
        if load:
            self.load()
Example #6
class Experiment(nauka.exp.Experiment, ExperimentInterface):
    @classmethod
    def project_workdir(cls, args):
        if args.workdir:
            return args.workdir
        else:
            return os.path.join(args.basedir, cls.__name__)

    @classmethod
    def project_logdir(cls, args):
        pworkdir = cls.project_workdir(args)
        return os.path.join(pworkdir, "logs")

    def __init__(self, args):
        """Initialize Experiment.

        :param args: experiment's arguments, taken from a subcommand.

        """
        args = type(args)(**args.__dict__)
        args.__dict__.pop("__argp__", None)
        args.__dict__.pop("__argv__", None)
        args.__dict__.pop("__cls__", None)
        self.args = args
        # This is the point where the State object is created.
        self.state = State(self.name, self.__class__.__name__,
                           args, self.hyperparams)
        workdir = os.path.join(self.project_workdir(args), self.name)
        super(Experiment, self).__init__(workdir)
        self.mkdirp(self.logdir)
        logger.info("Initializing experiment with name: {}".format(self.name))
        self.logging = getLogger(self.__class__.__name__)

        # TODO move to State object
        self.tracking = None
        if args.tracking is not None:
            self.tracking = Wandb(args.tracking.key, args.tracking.entity,
                                  project=self.__class__.__name__,
                                  name=self.name,
                                  config=flatten_ns_to_dict(self.state.hyperparams),
                                  id=self.hash(),
                                  dir=self.workdir,
                                  )

    @property
    def hyperparams(self):
        args = copy.deepcopy(self.args)
        del args.verbosity
        del args.workdir
        del args.basedir
        del args.datadir
        del args.tmpdir
        del args.name
        del args.cuda
        del args.tracking
        del args.fastdebug
        return args

    @property
    def info(self):
        return self.state.info

    @property
    def device(self):
        return self.state.device

    @property
    def iter(self):
        return self.state.info.iter

    @iter.setter
    def iter(self, iter_):
        self.state.info.iter = iter_
        # TODO Substitute with progress bar
        sys.stdout.write("{}/{}\r".format(self.iter, self.args.train_iters))
        sys.stdout.flush()
        self.logging.debug("=========  Iter: %d/%d  ===========",
                           self.iter, self.args.train_iters)

    @property
    def inter(self):
        return self.state.info.inter

    @inter.setter
    def inter(self, inter_):
        self.state.info.inter = inter_

    @property
    def datadir(self):
        """Returns the root directory where datasets reside."""
        return self.args.datadir

    @property
    def workdir(self):
        return self.workDir

    @property
    def logdir(self):
        """Return the directory where experiment log will reside."""
        return os.path.join(self.workdir, "logs")

    @property
    def exitcode(self):
        return 0 if self.is_done else 1

    def hash(self, length=32):
        hparams = nested_ns_to_dict(self.hyperparams)
        hparams['mingle'] = self.args.name
        s = _pbkdf2(length,
                    json.dumps(hparams, sort_keys=True),
                    salt="hyperparameters", rounds=100001)
        return binascii.hexlify(s).decode('utf-8', errors='strict')

    # TODO summarize function
    def log(self, **kwargs):
        self.logging.debug("Iter=%d: %s", self.iter, kwargs)
        if self.tracking:
            self.tracking.log(kwargs, step=self.iter)

    def dump(self, path):
        """Dump state to the directory `path`

        When invoked by the snapshot machinery, `path/` may be assumed to
        already exist. The state must be saved under that directory, but
        the contents of that directory and any hierarchy underneath it are
        completely freeform, except that the subdirectory `path/.experiment`
        must not be touched.

        When invoked by the snapshot machinery, the path's basename as given
        by os.path.basename(path) will be the number this snapshot will be
        assigned, and it is equal to self.nextSnapshotNum.

        """
        self.state.dump(path)
        return self

    def load(self, path):
        """Load state from given `path`.

        Restore the experiment to a state as close as possible to the one
        the experiment was in at the moment of the dump() that generated the
        checkpoint with the given `path`.

        """
        self.state.load(path)
        return self

    def fromScratch(self):
        """Start a fresh experiment, from scratch."""
        password = "******".format(self.args.seed)
        PRNG.seed(password)
        if self.tracking:
            self.tracking.init()
        self.define()
        if self.tracking:
            self.tracking.watch(self.state.modules)
        return self

    def fromSnapshot(self, path):
        """Start an experiment from a snapshot.

        Most likely, this method will invoke self.load(path) at an opportune
        time in its implementation.

        Returns `self`.
        """
        return self.load(path).fromScratch()

    def interval(self):
        """An interval is defined as the computation- and time-span between two
        snapshots.

        Hard Requirements
        -----------------
           - By definition, one may not invoke snapshot() within an interval.
           - Corollary: The work done by an interval is either fully recorded
             or not recorded at all.

        For reproducibility purposes, all PRNGs are reseeded at the beginning
        of every interval.
        """
        password = "******".format(self.args.seed,
                                                    self.inter)
        PRNG.seed(password)
        self.execute()
        self.inter += 1
        return self

    def run(self):
        """Run by intervals until experiment completion."""
        try:
            self.state.log_setting()
            while not self.is_done:
                self.interval().snapshot().purge()
        except KeyboardInterrupt:
            pass
        return self
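
`hash()` in Example #6 derives a stable run identifier by key-stretching the JSON-serialized hyperparameters. The `_pbkdf2` helper it calls is not included in these examples; below is a minimal sketch assuming it wraps `hashlib.pbkdf2_hmac`, with defaults mirroring the call site.

import hashlib

def _pbkdf2(length, password, salt="", rounds=100001):
    # Hypothetical sketch: derive `length` deterministic bytes from the
    # serialized hyperparameters; hexlify-ing them yields the run id.
    return hashlib.pbkdf2_hmac("sha256",
                               password.encode("utf-8"),
                               salt.encode("utf-8"),
                               rounds,
                               dklen=length)
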
Example #7
    def state(self, state_):
        if not State().is_cuda:
            return
        self.torch.cuda.set_rng_state(state_, device=State().device)
Example #8
    def state(self):
        if not State().is_cuda:
            return
        return self.torch.cuda.get_rng_state(device=State().device)
Example #9
    def seed(self, password):
        if not State().is_cuda:
            return
        seed = self.get_random_state(password, salt="torch.cuda")
        self.torch.cuda.manual_seed(seed)
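
Examples #7-#9 reseed the CUDA PRNG deterministically from an experiment password. The `get_random_state` method they call is not shown; the sketch below hashes password and salt into an integer seed. The name and signature are taken from the call site, but the body is an assumption.

import hashlib

def get_random_state(self, password, salt=""):
    # Hypothetical sketch: map (salt, password) to a stable 64-bit seed so
    # every PRNG in the experiment reseeds reproducibly.
    digest = hashlib.sha256((salt + password).encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little")
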
Example #10
class AbstractDataset(object, metaclass=ABCMeta):
    def __init__(self,
                 root,
                 num_threads=1,
                 download=False,
                 load=True,
                 splits=(1, ),
                 **options):
        self.state = State()
        self.root = os.path.expanduser(root)
        self.root = os.path.join(self.root, self.__class__.__name__)
        assert num_threads >= 0
        self.num_threads = num_threads
        self.splits = splits
        self.options = options
        self.n_epochs = 0

        if download and not self.check_exists(self.root):
            self.download(self.root)

        self._data = []
        if load:
            self.load()

    def download(self, root):
        pass

    def check_exists(self, root):
        return True

    @abstractmethod
    def prepare(self, root, **options):
        """Return a `torch.utils.data.Dataset` implementation."""
        pass

    def transform(self, batch):
        return batch

    @property
    def data(self):
        if not self._data or \
                any(not isinstance(ds, torch.utils.data.Dataset) for ds in self._data):
            raise ValueError("Call `load` method first.")
        return self._data

    @property
    def N(self):
        return self.n_data

    def load(self):
        if not self.check_exists(self.root):
            raise RuntimeError(self.__class__.__name__ + ' not found.' +
                               ' You can use download=True to download it.')
        self._data = self.prepare(self.root, **self.options)
        self.n_data = len(self._data)
        self.splits = prepare_splits(self.splits, self.n_data)
        self.n_splits = len(self.splits)
        self._data = torch.utils.data.random_split(self._data, self.splits)

    def build_loader(self, batch_size, sampler, split=0):
        """Return a `torch.utils.data.DataLoader` interface."""
        return torch.utils.data.DataLoader(
            dataset=self.data[split],
            batch_size=batch_size,
            sampler=sampler,
            #  drop_last=True,
            num_workers=self.num_threads,
            pin_memory=self.state.is_cuda,
            worker_init_fn=_worker_init_fn,
        )

    def _fetch(self, loader, stream):
        batch = next(loader)
        if self.state.is_cuda:
            with torch.cuda.stream(stream):
                batch = [x.cuda(non_blocking=True) for x in batch]
                batch = self.transform(batch)
        else:
            batch = self.transform(batch)
        return batch

    def infinite_sampler(self,
                         name: str,
                         batch_size: int,
                         split=0,
                         resume=True):
        fetch_stream = None
        if self.state.is_cuda:
            fetch_stream = torch.cuda.Stream()
        if resume is True:
            batches_seen = self.state.samplers(name)
        else:
            batches_seen = 0
        sampler = MultiEpochSampler(self.data[split], batches_seen, batch_size)
        loader = iter(self.build_loader(batch_size, sampler, split=split))
        next_batch = self._fetch(loader, fetch_stream)
        while True:
            if self.state.is_cuda:
                torch.cuda.current_stream().wait_stream(fetch_stream)
            current_batch = next_batch
            if self.state.is_cuda:
                for x in current_batch:
                    x.record_stream(torch.cuda.current_stream())
            next_batch = self._fetch(loader, fetch_stream)
            if resume:
                self.state._samplers[name] += 1  # TODO FIXME
            yield current_batch
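
For context, a hypothetical driver loop for the prefetching sampler in Example #10; `MyDataset`, the two-tensor batch unpacking, and every value below are illustrative, not part of the original examples.

dataset = MyDataset(root="~/data", num_threads=4, download=True)
batches = dataset.infinite_sampler("train", batch_size=64)
for _ in range(1000):
    x, y = next(batches)  # already on the GPU and transformed when CUDA is on
    # ... one training step on (x, y) ...
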