Exemple #1
0
 def __setstate__(self, state):
     """Restore this object from a pickled state dictionary.

     Args:
         state: mapping that provides the keys ``'init_args'`` (handed to
             ``Serializable.__setstate__``), ``'network_params'`` (handed
             to ``self.set_params``) and ``'filter'`` (an iterable with
             one parameter entry per element of ``self.obs_filters``).

     Raises:
         KeyError: if one of the three keys is missing from ``state``.
     """
     Serializable.__setstate__(self, state['init_args'])
     self.set_params(state['network_params'])
     # A plain loop instead of a list comprehension: set_params is called
     # purely for its side effect, so collecting the return values into a
     # throwaway list was wasted work and obscured the intent.
     # zip() stops at the shorter sequence, so surplus filters or surplus
     # saved entries are silently left untouched (same as before).
     for obs_filter, params in zip(self.obs_filters, state['filter']):
         obs_filter.set_params(params)
Exemple #2
0
    def __init__(self,
                 *args,
                 init_std=1.,
                 min_std=1e-6,
                 cell_type='lstm',
                 **kwargs):
        """Initialise a recurrent policy and build its graph.

        Args:
            *args: forwarded positionally to ``Policy.__init__``.
            init_std: its natural log is stored as ``init_log_std``.
            min_std: its natural log is stored as ``min_log_std``.
            cell_type: stored as ``_cell_type``; presumably selects the
                recurrent cell created by ``build_graph`` -- TODO confirm.
            **kwargs: forwarded as keywords to ``Policy.__init__``.
        """
        # Keep this as the first statement: locals() must contain only the
        # constructor arguments when they are recorded for serialization.
        Serializable.quick_init(self, locals())
        Policy.__init__(self, *args, **kwargs)

        # Standard deviations are kept in log space.
        self.min_log_std = np.log(min_std)
        self.init_log_std = np.log(init_std)

        # Static configuration of this policy.
        self._cell_type = cell_type
        self.recurrent = True

        # Attributes declared up front as None; presumably filled in by
        # build_graph -- TODO confirm against its implementation.
        self.init_policy = self.policy_params = None
        self.obs_var = self.action_var = None
        self.mean_var = self.log_std_var = None
        self._dist = self._hidden_state = None

        self.build_graph()
        # self.cell is not assigned in this constructor, so it has to exist
        # once build_graph() has returned (or come from a base class).
        self._zero_hidden = self.cell.zero_state(1, tf.float32)
Exemple #3
0
    def __init__(self,
                 obs_dim,
                 action_dim,
                 name='policy',
                 hidden_sizes=(32, 32),
                 learn_std=True,
                 hidden_nonlinearity='tanh',
                 output_nonlinearity=None,
                 **kwargs):
        """Store the policy configuration and declare its empty slots.

        Args:
            obs_dim: stored as ``self.obs_dim``.
            action_dim: stored as ``self.action_dim``.
            name: stored as ``self.name``.
            hidden_sizes: stored as ``self.hidden_sizes``.
            learn_std: stored as ``self.learn_std``.
            hidden_nonlinearity: key looked up in ``self._activations``;
                the result is stored as ``self.hidden_nonlinearity``.
            output_nonlinearity: key looked up in ``self._activations``;
                the result is stored as ``self.output_nonlinearity``.
            **kwargs: accepted but not used inside this constructor.

        Raises:
            KeyError: if either nonlinearity key is missing from
                ``self._activations`` (assuming it is a plain mapping).
        """
        # Keep this as the first statement: locals() must contain only the
        # constructor arguments when they are recorded for serialization.
        Serializable.quick_init(self, locals())

        self.obs_dim, self.action_dim, self.name = obs_dim, action_dim, name
        self.hidden_sizes, self.learn_std = hidden_sizes, learn_std

        # Translate the activation keys into whatever _activations maps
        # them to (hidden first, then output).
        activations = self._activations
        self.hidden_nonlinearity = activations[hidden_nonlinearity]
        self.output_nonlinearity = activations[output_nonlinearity]

        # Slots that start out empty and are assigned elsewhere.
        self._dist = self.policy_params = None
        self._assign_ops = self._assign_phs = None
        self.policy_params_keys = self.policy_params_ph = None
Exemple #4
0
    def __init__(self, obs_dim, action_dim, name='np_policy', **kwargs):
        """Store the basic policy dimensions and create a default filter.

        Args:
            obs_dim: stored as ``self.obs_dim``; also used as the single
                entry of the shape tuple given to the default ``Filter``.
            action_dim: stored as ``self.action_dim``.
            name: stored as ``self.name``.
            **kwargs: accepted but not used inside this constructor.
        """
        # Keep this as the first statement: locals() must contain only the
        # constructor arguments when they are recorded for serialization.
        Serializable.quick_init(self, locals())

        self.obs_dim, self.action_dim, self.name = obs_dim, action_dim, name

        # Slots that start out empty and are assigned elsewhere.
        self._dist = self.policy_params = None
        self.policy_params_batch = self._num_deltas = None

        # One filter shaped like a single observation vector.
        default_filter = Filter((self.obs_dim, ))
        self.obs_filters = [default_filter]
Exemple #5
0
 def __getstate__(self):
     """Collect everything needed to pickle this object.

     Returns:
         dict with three entries: ``'init_args'`` (the result of
         ``Serializable.__getstate__``), ``'network_params'`` (the result
         of ``self.get_params()``) and ``'filter'`` (a list holding
         ``get_params()`` of every element of ``self.obs_filters``).
     """
     init_args = Serializable.__getstate__(self)
     network_params = self.get_params()
     filter_params = [flt.get_params() for flt in self.obs_filters]
     return {
         'init_args': init_args,
         'network_params': network_params,
         'filter': filter_params,
     }
Exemple #6
0
    def __init__(self,
                 obs_dim,
                 action_dim,
                 name='np_policy',
                 hidden_sizes=(64, 64),
                 hidden_nonlinearity='tanh',
                 output_nonlinearity=None,
                 normalization='first',
                 **kwargs):
        """Create zero-initialised layer parameters and per-layer filters.

        For every entry of ``hidden_sizes`` a weight matrix ``W_<i>`` of
        shape ``(hidden_size, input_size)`` and a bias ``b_<i>`` of shape
        ``(hidden_size,)`` are stored in ``self.policy_params``, followed
        by ``W_out`` / ``b_out`` for the output layer. One filter is
        appended to ``self.obs_filters`` per layer (hidden layers plus the
        output layer), shaped like that layer's input.

        Args:
            obs_dim: input size of the first layer.
            action_dim: number of rows of ``W_out`` and length of ``b_out``.
            name: forwarded to ``NpPolicy.__init__``.
            hidden_sizes: sized iterable with the hidden layer widths.
            hidden_nonlinearity: key looked up in ``self._activations``.
            output_nonlinearity: key looked up in ``self._activations``.
            normalization: ``'all'`` puts a ``MeanStdFilter`` in front of
                every layer, ``'first'`` only in front of the first layer
                (the output layer when there are no hidden layers);
                ``None`` / ``'none'`` use a plain ``Filter`` everywhere.
            **kwargs: forwarded to ``NpPolicy.__init__``.

        Raises:
            AssertionError: if ``normalization`` is not one of the four
                accepted values (only while assertions are enabled).
        """
        # Keep this as the first statement: locals() must contain only the
        # constructor arguments when they are recorded for serialization.
        Serializable.quick_init(self, locals())
        NpPolicy.__init__(self, obs_dim, action_dim, name, **kwargs)

        assert normalization in ['all', 'first', None, 'none']

        # NOTE(review): obs_filter (singular) is not read again in this
        # constructor; the per-layer list below is what gets built here.
        self.obs_filter = MeanStdFilter(shape=(obs_dim, ))
        self.hidden_nonlinearity = self._activations[hidden_nonlinearity]
        self.output_nonlinearity = self._activations[output_nonlinearity]
        self.hidden_sizes = hidden_sizes
        self.policy_params = OrderedDict()
        self.obs_filters = []

        normalize_all = normalization == 'all'
        normalize_first = normalization == 'first'

        # Hidden layers: parameters first, then the filter for the layer's
        # input; in_size always carries the width feeding the next layer.
        in_size = obs_dim
        for layer_idx, out_size in enumerate(hidden_sizes):
            weights = np.zeros((out_size, in_size), dtype=np.float64)
            bias = np.zeros((out_size, ))
            self.policy_params['W_%d' % layer_idx] = weights
            self.policy_params['b_%d' % layer_idx] = bias

            use_mean_std = normalize_all or (normalize_first
                                             and layer_idx == 0)
            filter_cls = MeanStdFilter if use_mean_std else Filter
            self.obs_filters.append(filter_cls(shape=(in_size, )))

            in_size = out_size

        # Output layer: with 'first' it is only normalised when it is the
        # sole layer, i.e. when no hidden layers were requested.
        use_mean_std = normalize_all or (normalize_first
                                         and len(hidden_sizes) == 0)
        filter_cls = MeanStdFilter if use_mean_std else Filter
        self.obs_filters.append(filter_cls(shape=(in_size, )))

        self.policy_params['W_out'] = np.zeros((action_dim, in_size),
                                               dtype=np.float64)
        self.policy_params['b_out'] = np.zeros((action_dim, ))
Exemple #7
0
    def __init__(self,
                 *args,
                 squashed=False,
                 init_std=1.,
                 min_std=1e-6,
                 **kwargs):
        """Initialise the policy and build its graph.

        Args:
            *args: forwarded positionally to ``Policy.__init__``.
            squashed: stored as ``self.squashed``.
            init_std: its natural log is stored as ``init_log_std``.
            min_std: its natural log is stored as ``min_log_std``.
            **kwargs: forwarded as keywords to ``Policy.__init__``.
        """
        # Keep this as the first statement: locals() must contain only the
        # constructor arguments when they are recorded for serialization.
        Serializable.quick_init(self, locals())
        Policy.__init__(self, *args, **kwargs)

        # Standard deviations are kept in log space.
        self.min_log_std = np.log(min_std)
        self.init_log_std = np.log(init_std)

        self.squashed = squashed

        # Attributes declared up front as None; presumably filled in by
        # build_graph -- TODO confirm against its implementation.
        self.init_policy = self.policy_params = None
        self.obs_var = self.action_var = None
        self.mean_var = self.log_std_var = None
        self._dist = None

        self.build_graph()
Exemple #8
0
 def __setstate__(self, state):
     """Restore this object from a pickled state dictionary.

     Args:
         state: mapping that provides ``'init_args'`` (handed to
             ``Serializable.__setstate__``) and ``'network_params'``
             (handed to ``self.set_params``).

     Raises:
         KeyError: if either key is missing from ``state``.
     """
     init_args = state['init_args']
     Serializable.__setstate__(self, init_args)
     # NOTE: no TensorFlow global-variables initializer is run here; that
     # call was deliberately left disabled, so the parameters come only
     # from set_params below.
     network_params = state['network_params']
     self.set_params(network_params)
Exemple #9
0
 def __getstate__(self):
     """Collect everything needed to pickle this object.

     Returns:
         dict with ``'init_args'`` (the result of
         ``Serializable.__getstate__``) and ``'network_params'`` (the
         result of ``self.get_param_values()``).
     """
     init_args = Serializable.__getstate__(self)
     param_values = self.get_param_values()
     return {'init_args': init_args, 'network_params': param_values}