コード例 #1
0
    def getAgent(self, collector):
        """Construct the agent for this experiment and store it on ``self.agent``.

        The agent class is resolved from ``self.exp.agent`` through the
        module-level ``getAgent`` helper (the method's own name is not in
        scope inside its body, so the bare name refers to the global). The
        agent receives ``self.params`` with ``'gamma'`` set from
        ``self.getGamma()``.

        Args:
            collector: passed unchanged as the agent's last constructor
                argument.

        Returns:
            The newly constructed agent, which is also assigned to
            ``self.agent``.
        """
        agent_cls = getAgent(self.exp.agent)

        # Layer the discount factor over the experiment's own parameters.
        agent_params = merge(self.params, {'gamma': self.getGamma()})

        self.agent = agent_cls(
            self.features,
            self.actions,
            agent_params,
            self.seed,
            collector,
        )
        return self.agent
コード例 #2
0
    def test_merge(self):
        """merge() overrides shared keys and adds new keys from the second dict."""
        # Base functionality: 'b' is overridden, 'd' is added, and the
        # untouched keys 'a' and 'c' carry over as-is.
        base = {
            'a': [1, 2, 3],
            'b': False,
            'c': {'aa': [4, 5, 6]},
        }
        overrides = {'b': True, 'd': 22}

        merged = merge(base, overrides)

        self.assertDictEqual(merged, {
            'a': [1, 2, 3],
            'b': True,
            'c': {'aa': [4, 5, 6]},
            'd': 22,
        })
コード例 #3
0
ファイル: Slurm.py プロジェクト: ajjacobs/PyExpUtils
def buildParallel(executable: str,
                  tasks: Iterator[Any],
                  opts: Dict[str, Any],
                  parallelOpts: 'Dict[str, Any] | None' = None):
    """Build a ``Parallel`` job whose tasks are each launched through ``srun``.

    Args:
        executable: command run for each task. It is prefixed with
            ``srun -N<nodes> -n<threads> --exclusive``.
        tasks: task arguments handed to ``Parallel.build`` under ``'tasks'``.
        opts: Slurm options.
            - ``'ntasks'`` (required) becomes ``'cores'``.
            - ``'nodes-per-process'`` (default 1) fills srun's ``-N``.
            - ``'threads-per-process'`` (default 1) fills srun's ``-n``.
        parallelOpts: optional overrides merged over the defaults built here
            (keys in this dict win). ``None`` means no overrides.

    Returns:
        Whatever ``Parallel.build`` returns for the merged options.

    Raises:
        KeyError: if ``opts`` has no ``'ntasks'`` entry.
    """
    # A None default avoids one mutable dict being shared by every call.
    if parallelOpts is None:
        parallelOpts = {}

    nodes = opts.get('nodes-per-process', 1)
    threads = opts.get('threads-per-process', 1)

    # NOTE(review): srun's -n is the number of tasks, not threads per task
    # (that would be -c/--cpus-per-task). Confirm which one is intended.
    defaults = {
        'executable': f'srun -N{nodes} -n{threads} --exclusive {executable}',
        'tasks': tasks,
        'cores': opts['ntasks'],
        # Because srun interacts with the scheduler, a slight delay helps
        # prevent intermittent errors.
        'delay': 0.5,
    }
    return Parallel.build(merge(defaults, parallelOpts))
コード例 #4
0
    def interpolateSavePath(self, idx, permute='metaParameters', key=None):
        """Fill in a save-path template for permutation ``idx``.

        Args:
            idx: permutation index, forwarded to ``getPermutation`` and
                ``getRun``.
            permute: the permutable section to draw parameters from.
            key: the template to interpolate. When ``None``, the global
                config's ``save_path`` is used.

        Returns:
            ``interpolate(key, context)``. The context is this instance's
            attributes plus three special keys:
            - ``'params'``: the hyphenated string of the selected parameters.
            - ``'run'``: the run number as a string.
            - ``'name'``: the experiment name.
        """
        if key is None:
            key = getConfig().save_path

        permutation = self.getPermutation(idx, permute)

        # The special keys are layered over the instance attributes, so they
        # win on any name clash.
        context = merge(self.__dict__, {
            'params': hyphenatedStringify(pick(permutation, permute)),
            'run': str(self.getRun(idx, permute)),
            'name': self.getExperimentName(),
        })

        return interpolate(key, context)
コード例 #5
0
    def __init__(self, features: int, actions: int, params: Dict, seed: int,
                 collector: Collector):
        """Extend the base agent with an extra output head ``h`` and its own optimizer.

        Params read here:
            - ``self.params['h_grad']`` (default ``False``) is passed as
              ``grad=`` when the head is added to ``value_net``.
            - ``self.optimizer_params['alpha']`` is required.
            - ``self.optimizer_params['eta']`` (default ``1.0``) scales
              ``alpha`` for the ``h`` optimizer only.
        """
        super().__init__(features, actions, params, seed, collector)

        self.h_grad: bool = self.params.get('h_grad', False)

        # Attach a biased, zero-initialised output head to the value network.
        self.h = self.value_net.addOutput(
            actions,
            grad=self.h_grad,
            bias=True,
            initial_value=0,
        )
        # value_net just changed, so re-copy it into the target network.
        self.initializeTargetNet()

        # The h head gets a separate optimizer whose step size is alpha * eta.
        base_alpha = self.optimizer_params['alpha']
        step_scale = self.optimizer_params.get('eta', 1.0)
        h_optimizer_params = merge(self.optimizer_params,
                                   {'alpha': base_alpha * step_scale})
        self.h_optimizer = deserializeOptimizer(self.h.parameters(),
                                                h_optimizer_params)

        # Direct handles to the head's weight and bias.
        self.h_weight = self.h.weight
        self.h_bias_weight = self.h.bias
コード例 #6
0
    def getPermutation(self, idx, keys='metaParameters', Model=None):
        """Return the full config for sweep permutation ``idx``.

        The permutation chosen from the permutable section ``keys`` is merged
        over the base dict ``self._d``, so values from the permutation win.

        Args:
            idx: which permutation of the sweep to select.
            keys: name of the permutable section to sweep over.
            Model: optional callable. When truthy, it is applied to the
                merged dict.

        Returns:
            ``Model(merged)`` when ``Model`` is given, otherwise the merged
            dict.
        """
        chosen = getParameterPermutation(self.permutable(keys), idx)
        combined = merge(self._d, chosen)

        if not Model:
            return combined
        return Model(combined)
コード例 #7
0
 def __init__(self, params, rng=None):
     """Initialise the parent with ``params``, forcing ``scale_output`` to False."""
     forced = {'scale_output': False}
     # The forced value is merged last, so it overrides any caller setting.
     super().__init__(merge(params, forced), rng=rng)