def __init__(self,
             s_dim,
             visual_sources,
             visual_resolution,
             a_dim,
             is_continuous,
             options_num=4,
             dc=0.01,
             terminal_mask=False,
             eps=0.1,
             epoch=4,
             pi_beta=1.0e-3,
             lr=5.0e-4,
             lambda_=0.95,
             epsilon=0.2,
             value_epsilon=0.2,
             kl_reverse=False,
             kl_target=0.02,
             kl_target_cutoff=2,
             kl_target_earlystop=4,
             kl_beta=(0.7, 1.3),
             kl_alpha=1.5,
             kl_coef=1.0,
             hidden_units=None,
             **kwargs):
    """Set up hyper-parameters, the network, the optimizer and the data buffer.

    Args:
        s_dim, visual_sources, visual_resolution, a_dim, is_continuous:
            Forwarded unchanged to the base-class initializer.
        options_num: Option count; passed to ``NetWork`` and used as the
            leading dimension of ``self.log_std`` in the continuous case.
        dc, terminal_mask, eps: Stored as attributes; they are consumed
            outside this method.
        epoch, pi_beta, lambda_, epsilon, value_epsilon, kl_reverse,
        kl_alpha: Stored as attributes of the same name.
        lr: Passed to ``self.init_lr``; the result is stored as ``self.lr``
            and given to ``self.init_optimizer``.
        kl_target: Stored, and used as the base of the derived KL
            thresholds below.
        kl_target_cutoff: Multiplier of ``kl_target`` giving
            ``self.kl_cutoff``.
        kl_target_earlystop: Multiplier of ``kl_target`` giving
            ``self.kl_stop``.
        kl_beta: Sequence whose first and last elements multiply
            ``kl_target`` to give ``self.kl_low`` and ``self.kl_high``.
        kl_coef: Wrapped in a ``tf.float32`` constant as ``self.kl_coef``.
        hidden_units: Mapping handed to ``NetWork``. ``None`` (the default)
            selects ``[32, 32]`` for each of the keys ``'share'``, ``'q'``,
            ``'intra_option'`` and ``'termination'``.
        **kwargs: Forwarded to the base-class initializer.
    """
    # Build the default per call: a dict literal in the signature would be a
    # single object shared by every instance (and by NetWork, which receives
    # it), so a mutation anywhere would leak into later instantiations.
    if hidden_units is None:
        hidden_units = {
            'share': [32, 32],
            'q': [32, 32],
            'intra_option': [32, 32],
            'termination': [32, 32]
        }

    super().__init__(s_dim=s_dim,
                     visual_sources=visual_sources,
                     visual_resolution=visual_resolution,
                     a_dim=a_dim,
                     is_continuous=is_continuous,
                     **kwargs)

    self.pi_beta = pi_beta
    self.epoch = epoch
    self.lambda_ = lambda_
    self.epsilon = epsilon
    self.value_epsilon = value_epsilon
    self.kl_reverse = kl_reverse
    self.kl_target = kl_target
    self.kl_alpha = kl_alpha
    self.kl_coef = tf.constant(kl_coef, dtype=tf.float32)

    # Thresholds derived from kl_target.
    self.kl_cutoff = kl_target * kl_target_cutoff
    self.kl_stop = kl_target * kl_target_earlystop
    self.kl_low = kl_target * kl_beta[0]
    self.kl_high = kl_target * kl_beta[-1]

    self.options_num = options_num
    self.dc = dc
    self.terminal_mask = terminal_mask
    self.eps = eps

    # NOTE(review): self.feat_dim and self.other_tv are not set in this
    # method; presumably the base-class initializer provides them - confirm.
    self.net = NetWork(self.feat_dim, self.a_dim, self.options_num,
                       hidden_units, self.is_continuous)
    if self.is_continuous:
        # Trainable log-std, one row per option, initialised to -0.5.
        self.log_std = tf.Variable(
            initial_value=-0.5 * np.ones((self.options_num, self.a_dim),
                                         dtype=np.float32),
            trainable=True)  # [P, A]
        self.net_tv = (self.net.trainable_variables + [self.log_std] +
                       self.other_tv)
    else:
        self.net_tv = self.net.trainable_variables + self.other_tv

    self.lr = self.init_lr(lr)
    self.optimizer = self.init_optimizer(self.lr)

    self.model_recorder(dict(model=self.net, optimizer=self.optimizer))

    self.initialize_data_buffer(data_name_list=[
        's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done', 'value',
        'log_prob', 'beta_adv', 'last_options', 'options'
    ])
def _net():
    """Return a new ``NetWork`` built from the enclosing scope's values.

    ``self`` and ``hidden_units`` are not parameters; they are looked up in
    the surrounding scope when the helper is called.
    """
    feature_size = self.feat_dim
    action_size = self.a_dim
    return NetWork(feature_size, action_size, hidden_units)

self.dueling_net = _net()
def _q_net():
    """Return a new ``NetWork`` built from the enclosing scope's values.

    ``self`` and ``hidden_units`` are not parameters; they are looked up in
    the surrounding scope when the helper is called.
    """
    feature_size = self.feat_dim
    action_size = self.a_dim
    return NetWork(feature_size, action_size, hidden_units)
def _net():
    """Return a new ``NetWork`` that also receives ``self.quantiles_idx``.

    ``self`` and ``hidden_units`` are not parameters; they are looked up in
    the surrounding scope when the helper is called.
    """
    feature_size = self.feat_dim
    action_size = self.a_dim
    quantile_index = self.quantiles_idx
    return NetWork(feature_size, action_size, quantile_index, hidden_units)

self.q_net = _net()