def reset(self, param: ParamDict):
    assert self.policy_net is not None, "Error: you should call init before reset!"
    policy = param.require("policy state_dict")
    if "fixed policy" in param:
        self.is_fixed = param.require("fixed policy")
    self.policy_net.load_state_dict(policy)
def __init__(self, config: ParamDict, environment: Environment, policy: Policy, filter_op: Filter):
    seed, gpu = config.require("seed", "gpu")

    # replay buffer
    self._replay_buffer = []
    self._batch_size = 0

    # device and seed
    self.device = decide_device(gpu)
    self.seed = seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # policy which will be copied to each child thread before every roll-out
    self._policy = deepcopy(policy)
    # environment which will be copied to child threads and is not initialized in the main thread
    self._environment = deepcopy(environment)
    # filter which will be copied to child threads and also kept in the main thread
    self._filter = deepcopy(filter_op)

    self._filter.init()
    self._filter.to_device(self.device)
    self._policy.init()
    self._policy.to_device(self.device)
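# The decide_device() helper used above (and in the policy worker below) is not shown in this
# section. A minimal sketch of what it is assumed to do, i.e. map a GPU index to a torch.device
# with a CPU fallback, is given here for illustration only; it is not the repository's actual
# implementation. It relies on torch being imported at module level, as the code above does.
def decide_device_sketch(gpu: int) -> torch.device:
    # negative index or no CUDA available -> CPU, otherwise the requested CUDA device
    if gpu >= 0 and torch.cuda.is_available():
        return torch.device(f"cuda:{gpu}")
    return torch.device("cpu")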
def broadcast(self, config: ParamDict):
    policy_state, filter_state, max_step, self._batch_size, fixed_env, fixed_policy, fixed_filter = \
        config.require("policy state dict", "filter state dict", "trajectory max step", "batch size",
                       "fixed environment", "fixed policy", "fixed filter")
    self._replay_buffer = []

    policy_state["fixed policy"] = fixed_policy
    filter_state["fixed filter"] = fixed_filter
    cmd = ParamDict({"trajectory max step": max_step,
                     "fixed environment": fixed_env,
                     "filter state dict": filter_state})

    assert self._sync_signal.value < 1, \
        "Last sync event not finished due to some error; some sub-processes may have died, abort"

    # tell sub-processes to reset
    self._sync_signal.value = len(self._policy_proc) + len(self._environment_proc)

    # sync net parameters
    with self._policy_lock:
        for _ in range(len(self._policy_proc)):
            self._param_pipe.send(policy_state)

    # wait for all agents' ready feedback
    while self._sync_signal.value > 0:
        sleep(0.01)

    # send roll-out commands
    with self._environment_lock:
        for _ in range(self._batch_size):
            self._control_pipe.send(cmd)
def bc_step(config: ParamDict, policy: Policy, demo):
    lr_init, lr_factor, l2_reg, bc_method, batch_sz, i_iter = \
        config.require("lr", "lr factor", "l2 reg", "bc method", "batch size", "current training iter")
    states, actions = demo

    # ---- annealing on learning rate ---- #
    lr = max(lr_init + lr_factor * i_iter, 1.e-8)
    optimizer = Adam(policy.policy_net.parameters(), weight_decay=l2_reg, lr=lr)

    # ---- behavior cloning from demonstrations ---- #
    total_len = states.size(0)
    idx = torch.randperm(total_len, device=policy.device)
    err = 0.
    for i_b in range(int(np.ceil(total_len / batch_sz))):
        idx_b = idx[i_b * batch_sz:(i_b + 1) * batch_sz]
        s_b = states[idx_b]
        a_b = actions[idx_b]

        optimizer.zero_grad()
        a_mean_pred, a_logvar_pred = policy.policy_net(s_b)
        # regress only the predicted mean; the 0. * logvar term keeps the log-variance head in the graph
        bc_loss = mse_loss(a_mean_pred + 0. * a_logvar_pred, a_b)
        err += bc_loss.item() * s_b.size(0) / total_len
        bc_loss.backward()
        optimizer.step()
    return err
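# A hypothetical usage sketch for bc_step. The config keys mirror the require() call above;
# the numeric values and the wrapper name _example_bc_usage are illustrative assumptions and
# not part of this repository.
def _example_bc_usage(policy: Policy, demo_states, demo_actions, iters: int = 100):
    for it in range(iters):
        bc_config = ParamDict({"lr": 3e-4, "lr factor": -1e-6, "l2 reg": 1e-4,
                               "bc method": "l2", "batch size": 256,
                               "current training iter": it})  # bc_step anneals lr with this counter
        avg_err = bc_step(bc_config, policy, (demo_states, demo_actions))
        print(f"BC iter {it}: mse={avg_err:.5f}")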
def __init__(self, config: ParamDict, environment: Environment, policy: Policy, filter_op: Filter):
    threads, gpu = config.require("threads", "gpu")
    threads_gpu = config["gpu threads"] if "gpu threads" in config else 2
    super(Agent_async, self).__init__(config, environment, policy, filter_op)

    # sync signal, -1: terminate, 0: normal running, >0: restart and wait for parameter update
    self._sync_signal = Value('i', 0)
    # environment sub-process list
    self._environment_proc = []
    # policy sub-process list
    self._policy_proc = []
    # used to synchronize policy parameters
    self._param_pipe = None
    self._policy_lock = Lock()
    # used to synchronize roll-out commands
    self._control_pipe = None
    self._environment_lock = Lock()

    step_pipe = []
    cmd_pipe_child, cmd_pipe_parent = Pipe(duplex=True)
    param_pipe_child, param_pipe_parent = Pipe(duplex=False)
    self._control_pipe = cmd_pipe_parent
    self._param_pipe = param_pipe_parent

    for i_envs in range(threads):
        child_name = f"environment_{i_envs}"
        step_pipe_pi, step_pipe_env = Pipe(duplex=True)
        step_lock = Lock()
        worker_cfg = ParamDict({"seed": self.seed + 1024 + i_envs,
                                "gpu": gpu})
        child = Process(target=Agent_async._environment_worker, name=child_name,
                        args=(worker_cfg, cmd_pipe_child, step_pipe_env, self._environment_lock,
                              step_lock, self._sync_signal, deepcopy(environment), deepcopy(filter_op)))
        self._environment_proc.append(child)
        step_pipe.append((step_pipe_pi, step_lock))
        child.start()

    for i_policies in range(threads_gpu):
        child_name = f"policy_{i_policies}"
        worker_cfg = ParamDict({"seed": self.seed + 2048 + i_policies,
                                "gpu": gpu})
        child = Process(target=Agent_async._policy_worker, name=child_name,
                        args=(worker_cfg, param_pipe_child, step_pipe, self._policy_lock,
                              self._sync_signal, deepcopy(policy)))
        self._policy_proc.append(child)
        child.start()

    # give the sub-processes time to start up
    sleep(5)
def init(self, config: ParamDict):
    env_name, tag, short, seed = config.require("env name", "tag", "short", "seed")
    self.default_name = f"{tag}-{env_name}-{short}-{seed}"
    self._log_dir = log_dir(config)
    self.loggerx = SummaryWriter(log_dir=self._log_dir)
    self._epoch_train = 0
    self._epoch_val = 0
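# The log_dir() helper used above is not shown in this section. A minimal sketch of what it is
# assumed to do, i.e. build a run directory from the same config fields used for default_name,
# is given below. The "log root" key, the timestamp suffix, and the layout are assumptions for
# illustration only.
def log_dir_sketch(config: ParamDict) -> str:
    import os
    import time
    env_name, tag, short, seed = config.require("env name", "tag", "short", "seed")
    root = config["log root"] if "log root" in config else "runs"
    path = os.path.join(root, f"{tag}-{env_name}-{short}-{seed}-{time.strftime('%m%d_%H%M%S')}")
    os.makedirs(path, exist_ok=True)
    return path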
def trpo_step(config: ParamDict, batch: StepDictList, policy: PolicyWithValue):
    max_kl, damping, l2_reg = config.require("max kl", "damping", "l2 reg")
    states, actions, advantages, returns = get_tensor(batch, policy.device)

    """update critic"""
    update_value_net(policy.value_net, states, returns, l2_reg)

    """update policy"""
    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states, actions).detach()

    """define the loss function for TRPO"""
    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                log_probs = policy.policy_net.get_log_prob(states, actions)
        else:
            log_probs = policy.policy_net.get_log_prob(states, actions)
        action_loss = -advantages * torch.exp(log_probs - fixed_log_probs)
        return action_loss.mean()

    """define the Hessian-vector product of the KL divergence"""
    def Fvp(v):
        kl = policy.policy_net.get_kl(states).mean()

        grads = torch.autograd.grad(kl, policy.policy_net.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * v).sum()
        grads = torch.autograd.grad(kl_v, policy.policy_net.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).detach()

        return flat_grad_grad_kl + v * damping

    for _ in range(2):
        loss = get_loss()
        grads = torch.autograd.grad(loss, policy.policy_net.parameters())
        loss_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
        stepdir = cg(Fvp, -loss_grad, nsteps=10)

        shs = 0.5 * (stepdir.dot(Fvp(stepdir)))
        lm = (max_kl / shs).sqrt_()
        fullstep = stepdir * lm
        expected_improve = -loss_grad.dot(fullstep)

        prev_params = policy.policy_net.get_flat_params().detach()
        success, new_params = line_search(policy.policy_net, get_loss, prev_params, fullstep, expected_improve)
        policy.policy_net.set_flat_params(new_params)
        if not success:
            return False
    return True
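# trpo_step relies on cg() and line_search() helpers defined elsewhere in the repository.
# Below is a minimal sketch of the standard routines they are assumed to implement: conjugate
# gradient for solving Fvp(x) = b, and a backtracking line search over the flat parameter
# vector. The signatures follow the call sites above; the bodies are illustrative assumptions,
# not the repository's actual implementation.
def cg_sketch(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros_like(b)
    r = b.clone()                      # residual b - A x, with x = 0 initially
    p = b.clone()
    rdotr = r.dot(r)
    for _ in range(nsteps):
        Ap = Avp(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


def line_search_sketch(net, loss_fn, prev_params, fullstep, expected_improve,
                       max_backtracks=10, accept_ratio=0.1):
    f0 = loss_fn(True).item()
    for n in range(max_backtracks):
        stepfrac = 0.5 ** n
        new_params = prev_params + stepfrac * fullstep
        net.set_flat_params(new_params)
        f_new = loss_fn(True).item()
        actual_improve = f0 - f_new
        # accept the step if the realized improvement is a reasonable fraction of the prediction
        if actual_improve > 0 and actual_improve / (stepfrac * expected_improve) > accept_ratio:
            return True, new_params
    return False, prev_params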
def broadcast(self, config: ParamDict):
    policy_state, filter_state, self._max_step, self._batch_size, self._fixed_env, fixed_policy, fixed_filter = \
        config.require("policy state dict", "filter state dict", "trajectory max step", "batch size",
                       "fixed environment", "fixed policy", "fixed filter")
    self._replay_buffer = []

    policy_state["fixed policy"] = fixed_policy
    filter_state["fixed filter"] = fixed_filter
    self._filter.reset(filter_state)
    self._policy.reset(policy_state)
def ppo_step(config: ParamDict, batch: SampleBatch, policy: PolicyWithValue):
    lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
        config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
                       "current training iter", "max iter", "optimize batch size")
    lam_entropy = 0.
    states, actions, advantages, returns = get_tensor(batch, policy.device)

    lr_mult = max(1.0 - i_iter / max_iter, 0.)
    clip_epsilon = clip_epsilon * lr_mult
    optimizer_policy = Adam(policy.policy_net.parameters(), lr=lr * lr_mult, weight_decay=l2_reg)
    optimizer_value = Adam(policy.value_net.parameters(), lr=lr * lr_mult, weight_decay=l2_reg)

    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states, actions).detach()

    for _ in range(policy_iter):
        inds = torch.randperm(states.size(0), device=states.device)

        """perform mini-batch PPO update"""
        for i_b in range(inds.size(0) // mini_batch_sz):
            # index with the shuffled permutation so each epoch visits the data in random order
            idx = inds[i_b * mini_batch_sz:(i_b + 1) * mini_batch_sz]
            states_i = states[idx]
            actions_i = actions[idx]
            returns_i = returns[idx]
            advantages_i = advantages[idx]
            log_probs_i = fixed_log_probs[idx]

            """update critic"""
            for _ in range(1):
                value_loss = F.mse_loss(policy.value_net(states_i), returns_i)
                optimizer_value.zero_grad()
                value_loss.backward()
                torch.nn.utils.clip_grad_norm_(policy.value_net.parameters(), 0.5)
                optimizer_value.step()

            """update policy"""
            log_probs, entropy = policy.policy_net.get_log_prob_entropy(states_i, actions_i)
            ratio = (log_probs - log_probs_i).clamp_max(15.).exp()
            surr1 = ratio * advantages_i
            surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages_i
            policy_surr = -torch.min(surr1, surr2).mean() - entropy.mean() * lam_entropy
            optimizer_policy.zero_grad()
            policy_surr.backward()
            torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
            optimizer_policy.step()
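# Both ppo_step variants consume pre-computed `advantages` and `returns` from get_tensor(),
# which is not shown in this section. A common way to produce these quantities is generalized
# advantage estimation (GAE); the sketch below is one standard formulation under that
# assumption and is not taken from this repository.
def gae_sketch(rewards, values, masks, gamma=0.99, lam=0.95):
    # rewards, values, masks: 1-D tensors of equal length; masks[t] = 0 at episode ends
    advantages = torch.zeros_like(rewards)
    gae = 0.
    for t in reversed(range(rewards.size(0))):
        next_value = values[t + 1] if t + 1 < values.size(0) else 0.
        delta = rewards[t] + gamma * next_value * masks[t] - values[t]
        gae = delta + gamma * lam * masks[t] * gae
        advantages[t] = gae
    returns = advantages + values
    # PPO/TRPO updates usually normalize the advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns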
def __init__(self, config_dict: ParamDict, env_info: dict):
    # this check fails mainly when the config has not been passed to the environment first,
    # which is what fills in the "action dim" and "state dim" fields
    assert "state dim" in env_info and "action dim" in env_info, \
        f"Error: key 'state dim' or 'action dim' not in env_info, which only contains {env_info.keys()}"
    super(PolicyOnly, self).__init__()
    activation = config_dict.require("activation")

    self.policy_net = None  # TODO: do some actual work here

    self._action_dim = env_info["action dim"]
    self._state_dim = env_info["state dim"]
    self.is_fixed = False
    self._activation = activation
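# The TODO above leaves policy_net unset. Callers such as bc_step expect policy_net(s) to
# return an action mean and a log-variance. A minimal sketch of such a Gaussian head, built
# from the stored dims and activation, might look like the following; the hidden sizes and
# the class name GaussianMLPSketch are assumptions, not this repository's actual network.
class GaussianMLPSketch(torch.nn.Module):
    def __init__(self, state_dim, action_dim, activation=torch.nn.Tanh, hidden=(64, 64)):
        super().__init__()
        layers, last = [], state_dim
        for h in hidden:
            layers += [torch.nn.Linear(last, h), activation()]
            last = h
        self.body = torch.nn.Sequential(*layers)
        self.mean_head = torch.nn.Linear(last, action_dim)
        # state-independent log-variance, one value per action dimension
        self.logvar = torch.nn.Parameter(torch.zeros(action_dim))

    def forward(self, states):
        feat = self.body(states)
        mean = self.mean_head(feat)
        return mean, self.logvar.expand_as(mean)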
def ppo_step(config: ParamDict, replay_memory: StepDictList, policy: PolicyWithValue):
    lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
        config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
                       "current training iter", "max iter", "optimize batch size")
    lam_entropy = 0.0
    states, actions, advantages, returns = get_tensor(policy, replay_memory, policy.device)

    """update critic"""
    update_value_net(policy.value_net, states, returns, l2_reg)

    """update policy"""
    lr_mult = max(1.0 - i_iter / max_iter, 0.)
    clip_epsilon = clip_epsilon * lr_mult
    optimizer = Adam(policy.policy_net.parameters(), lr=lr * lr_mult, weight_decay=l2_reg)

    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states, actions).detach()

    for _ in range(policy_iter):
        # reshuffle the sample order every epoch
        inds = torch.randperm(states.size(0), device=states.device)

        """perform mini-batch PPO update"""
        for i_b in range(int(np.ceil(states.size(0) / mini_batch_sz))):
            ind = inds[i_b * mini_batch_sz:min((i_b + 1) * mini_batch_sz, inds.size(0))]
            log_probs, entropy = policy.policy_net.get_log_prob_entropy(states[ind], actions[ind])
            ratio = torch.exp(log_probs - fixed_log_probs[ind])
            surr1 = ratio * advantages[ind]
            surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages[ind]
            policy_surr = -torch.min(surr1, surr2).mean()
            policy_surr -= entropy.mean() * lam_entropy

            optimizer.zero_grad()
            policy_surr.backward()
            torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
            optimizer.step()
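# update_value_net() is called by trpo_step, this ppo_step variant, and mcpo_step but is not
# shown in this section. A common choice is an L-BFGS regression of the value network onto the
# returns with explicit L2 regularization; the sketch below assumes that form and is
# illustrative only, not the repository's actual helper.
def update_value_net_sketch(value_net, states, returns, l2_reg, max_iter=25):
    optimizer = torch.optim.LBFGS(value_net.parameters(), max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        values_pred = value_net(states)
        loss = (values_pred - returns).pow(2).mean()
        # manual L2 regularization on all parameters
        for p in value_net.parameters():
            loss = loss + l2_reg * p.pow(2).sum()
        loss.backward()
        return loss

    optimizer.step(closure)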
def verify(self, config: ParamDict):
    policy_state, filter_state, max_step, max_iter, display, fixed_env, fixed_policy, fixed_filter = \
        config.require("policy state dict", "filter state dict", "trajectory max step", "max iter",
                       "display", "fixed environment", "fixed policy", "fixed filter")
    policy_state["fixed policy"] = fixed_policy
    filter_state["fixed filter"] = fixed_filter

    self._environment.init(display=display)
    self._policy.reset(policy_state)
    self._filter.reset(filter_state)

    r_sum = []
    for i in range(max_iter):
        current_step = self._environment.reset(random=not fixed_env)

        # sampling
        r_sum.append(0.)
        i_step = 0
        for _ in range(max_step):
            current_step = self._filter.operate_currentStep(current_step)
            last_step = self._policy.step([current_step])[0]
            last_step, current_step, done = self._environment.step(last_step)
            record_step = self._filter.operate_recordStep(last_step)
            r_sum[-1] += record_step['r']
            i_step += 1
            if done:
                break
        print(f"Verification iter {i}: reward={r_sum[-1]}; step={i_step}")

    # finalization
    self._environment.finalize()
    r_sum = np.asarray(r_sum, dtype=np.float32)
    r_sum_mean = r_sum.mean()
    r_sum_std = r_sum.std()
    r_sum_max = r_sum.max()
    r_sum_min = r_sum.min()
    print(f"Verification done with reward sum: avg={r_sum_mean}, std={r_sum_std}, "
          f"min={r_sum_min}, max={r_sum_max}")
    return r_sum_mean, r_sum_std, r_sum_min, r_sum_max
def verify(self, config: ParamDict, environment: Environment):
    max_iter, max_step, random = \
        config.require("verify iter", "verify max step", "verify random")

    r_sum = []
    for _ in range(max_iter):
        current_step = environment.reset(random=random)

        # sampling, one reward-sum entry per episode
        r_sum.append(0.)
        for _ in range(max_step):
            last_step = self._policy.step([current_step])[0]
            last_step, current_step, done = environment.step(last_step)
            r_sum[-1] += last_step['r']
            if done:
                break

    # finalization
    r_sum = torch.as_tensor(r_sum, dtype=torch.float32, device=self.device)
    r_sum_mean = r_sum.mean()
    r_sum_std = r_sum.std()
    r_sum_max = r_sum.max()
    r_sum_min = r_sum.min()
    print(f"Verification done with reward sum: avg={r_sum_mean}, std={r_sum_std}, "
          f"min={r_sum_min}, max={r_sum_max}")
    return r_sum_mean, r_sum_std, r_sum_min, r_sum_max
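# A hypothetical usage sketch for this verify() variant. The key names mirror the require()
# call above; the agent/environment objects and the numeric values are placeholders, not
# values taken from this repository.
def _example_verify_usage(agent, environment: Environment):
    cfg = ParamDict({"verify iter": 10, "verify max step": 1000, "verify random": True})
    mean_r, std_r, min_r, max_r = agent.verify(cfg, environment)
    return mean_r, std_r, min_r, max_r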
def _policy_worker(setups: ParamDict, pipe_param, pipe_steps, read_lock, sync_signal, policy):
    gpu, seed = setups.require("gpu", "seed")

    device = decide_device(gpu)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    policy.to_device(device)
    policy.init()

    # -1: syncing, 0: waiting for states
    local_state = 0
    max_batchsz = 8

    def _get_piped_data(pipe, lock):
        with lock:
            if pipe.poll():
                return pipe.recv()
            else:
                return None

    while sync_signal.value >= 0:
        # check the sync counter for a sync event and wait for new parameters
        if sync_signal.value > 0:
            # on a sync signal: drain the step pipes, load the new parameters, decrease the
            # sync counter, and set the state machine to -1 so this worker is not reset again
            for _pipe, _lock in pipe_steps:
                while _get_piped_data(_pipe, _lock) is not None:
                    pass
            if local_state >= 0:
                _policy_state = _get_piped_data(pipe_param, read_lock)
                if _policy_state is not None:
                    # set new parameters
                    policy.reset(_policy_state)
                    with sync_signal.get_lock():
                        sync_signal.value -= 1
                    local_state = -1
            else:
                sleep(0.01)
        # when the sync ends, recover the state machine from the syncing state
        elif sync_signal.value == 0 and local_state == -1:
            local_state = 0
        # waiting for states (states are lists of dicts)
        elif sync_signal.value == 0 and local_state == 0:
            idx = []
            data = []
            for i, (_pipe, _lock) in enumerate(pipe_steps):
                if len(idx) >= max_batchsz:
                    break
                _steps = _get_piped_data(_pipe, _lock)
                if _steps is not None:
                    data.append(_steps)
                    idx.append(i)
            if len(idx) > 0:
                # batch the collected states and run the policy
                with torch.no_grad():
                    data = policy.step(data)
                # send the actions back to the environment workers they came from
                for i, d in zip(idx, data):
                    with pipe_steps[i][1]:
                        pipe_steps[i][0].send(d)
            else:
                sleep(0.00001)

    # finalization
    policy.finalize()
    pipe_param.close()
    for _pipe, _lock in pipe_steps:
        _pipe.close()
    print("Policy sub-process exited")
def mcpo_step(config: ParamDict, batch: SampleBatch, policy: PolicyWithValue, demo=None):
    global mmd_cache
    max_kl, damping, l2_reg, bc_method, d_init, d_factor, d_max, i_iter = \
        config.require("max kl", "damping", "l2 reg", "bc method", "constraint",
                       "constraint factor", "constraint max", "current training iter")
    states, actions, advantages, returns, demo_states, demo_actions = \
        get_tensor(batch, demo, policy.device)

    # ---- annealing on constraint tolerance ---- #
    d = min(d_init + (d_factor * i_iter) ** 2, d_max)

    # ---- update critic ---- #
    update_value_net(policy.value_net, states, returns, l2_reg)

    # ---- update policy ---- #
    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states, actions).detach()

    # ---- define the reward loss function for MCPO ---- #
    def RL_loss():
        log_probs = policy.policy_net.get_log_prob(states, actions)
        action_loss = -advantages * torch.exp(log_probs - fixed_log_probs)
        return action_loss.mean()

    # ---- define the KL divergence for the trust region ---- #
    def Dkl():
        kl = policy.policy_net.get_kl(states)
        return kl.mean()

    # ---- define MMD constraint from demonstrations ---- #
    # ---- if oversampling is used, the recommended value is 5 ---- #
    def Dmmd(oversample=0):
        from_demo = torch.cat((demo_states, demo_actions), dim=1)
        # use the distance between (s_e, a_e) and (s_e, pi(s_e)) instead of (s, a) when a is not exact
        a, _, var = policy.policy_net(demo_states)
        if oversample > 0:
            sample_a = Normal(torch.zeros_like(a), torch.ones_like(a)).sample((oversample,)) * var + a
            sample_s = demo_states.expand(oversample, -1, -1)
            from_policy = torch.cat((sample_s, sample_a), dim=2).view(-1, demo_states.size(-1) + a.size(-1))
        else:
            from_policy = torch.cat((demo_states, a + 0. * var), dim=1)
        return mmd(from_policy, from_demo, mmd_cache)

    # ---- define L2 constraint from demonstrations ---- #
    def Dl2():
        import torch.nn.functional as F
        # use the mean action only; the 0. * logstd term keeps the log-std head in the graph
        mean, logstd = policy.policy_net(demo_states)
        return F.mse_loss(mean + 0. * logstd, demo_actions)

    # ---- define BC from demonstrations ---- #
    if bc_method == "l2":
        Dc = Dl2
    else:
        Dc = Dmmd

    def DO_BC():
        assert demo_states is not None, "I should not arrive here with demos == None"
        dist = Dc()
        if dist > d:
            policy_optimizer = torch.optim.Adam(policy.policy_net.parameters())
            # print(f"Debug: constraint not met, refining until it satisfies {dist} < {d}")
            for _ in range(500):
                policy_optimizer.zero_grad()
                dist.backward()
                policy_optimizer.step()
                dist = Dc()
                if dist < d:
                    break
            # print(f"Debug: BC margin={d - dist}")
        else:
            print(f"Debug: constraint met, {dist.item()} < {d}")
        x = policy.policy_net.get_flat_params().detach()
        return x

    # ---- define grad funcs ---- #
    def Hvp_f(v, damping=damping):
        kl = Dkl()
        grads = torch.autograd.grad(kl, policy.policy_net.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * v.detach()).sum()
        grads = torch.autograd.grad(kl_v, policy.policy_net.parameters())
        flat_grad_grad_kl = torch.cat([grad.view(-1) for grad in grads]).detach()
        return flat_grad_grad_kl + v * damping

    f0 = RL_loss()
    grads = torch.autograd.grad(f0, policy.policy_net.parameters())
    g = torch.cat([grad.view(-1) for grad in grads]).detach()
    f0.detach_()

    if demo_states is None:
        d_value = -1.e7
        c = 1.e7
        b = torch.zeros_like(g)
    else:
        d_value = Dc()
        c = (d - d_value).detach()
        grads = torch.autograd.grad(d_value, policy.policy_net.parameters())
        b = torch.cat([grad.view(-1) for grad in grads]).detach()

    # ---- update policy net with the CG-LineSearch algorithm ---- #
    d_theta, line_search_check_range, case = mcpo_optim(g, b, c, Hvp_f, max_kl, DO_BC)

    if torch.isnan(d_theta).any():
        if torch.isnan(b).any():
            print("b is NaN when Dc={}. Rejecting this step!".format(d_value))
        else:
            print("net parameter is NaN. Rejecting this step!")
        success = False
    elif line_search_check_range is not None:
        expected_df = g @ d_theta
        success, new_params = line_search(policy.policy_net, expected_df, f0, RL_loss, Dc, d,
                                          d_theta, line_search_check_range)
        policy.policy_net.set_flat_params(new_params)
    else:
        # here d_theta comes from BC, so skip the line-search procedure
        success = True
    return (case if success else -1), d_value
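# mcpo_step calls mmd() with a global mmd_cache, neither of which is shown in this section.
# A standard squared-MMD estimate with an RBF kernel is sketched below; the cache is assumed
# to hold a kernel bandwidth fixed from the first call via the median heuristic. This is an
# illustration of the constraint, not the repository's implementation.
def mmd_sketch(x, y, cache):
    # pairwise squared distances between all samples from both sets
    z = torch.cat((x, y), dim=0)
    dists = torch.cdist(z, z, p=2).pow(2)
    if "sigma2" not in cache:
        # median heuristic for the kernel bandwidth, kept constant after the first call
        cache["sigma2"] = dists.detach().median().clamp_min(1e-6)
    k = torch.exp(-dists / (2. * cache["sigma2"]))
    n = x.size(0)
    k_xx = k[:n, :n].mean()
    k_yy = k[n:, n:].mean()
    k_xy = k[:n, n:].mean()
    # biased estimate of squared MMD between the policy and demonstration samples
    return k_xx + k_yy - 2. * k_xy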
def reset(self, param: ParamDict):
    self.mean, self.errsum, self.n_step, self.is_fixed = \
        param.require("zfilter mean", "zfilter errsum", "zfilter n_step", "fixed filter")
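# This reset() belongs to a running z-score filter ("zfilter"): mean, errsum, and n_step are
# the running statistics it restores. A minimal sketch of how such a filter typically updates
# its statistics (Welford's algorithm) and normalizes an observation follows; the helper names
# are assumptions for illustration, not the filter's actual methods.
def _zfilter_update_sketch(mean, errsum, n_step, x):
    # Welford's online update of the mean and the sum of squared deviations
    n_step = n_step + 1
    delta = x - mean
    mean = mean + delta / n_step
    errsum = errsum + delta * (x - mean)
    return mean, errsum, n_step


def _zfilter_normalize_sketch(mean, errsum, n_step, x, eps=1e-8):
    var = errsum / max(n_step - 1, 1)
    return (x - mean) / (np.sqrt(var) + eps)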
def _environment_worker(setups: ParamDict, pipe_cmd, pipe_step, read_lock, step_lock, sync_signal,
                        environment, filter_op):
    gpu, seed = setups.require("gpu", "seed")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    environment.init(display=False)
    filter_op.to_device(torch.device("cpu"))
    filter_op.init()

    # -1: syncing, 0: waiting for command, 1: waiting for action
    local_state = 0
    step_buffer = []
    cmd = None

    def _get_piped_data(pipe, lock):
        with lock:
            if pipe.poll(0.001):
                return pipe.recv()
            else:
                return None

    while sync_signal.value >= 0:
        # check the sync counter for a sync event
        if sync_signal.value > 0 and local_state >= 0:
            # on a sync signal: drain the pipes, clear the step buffer, decrease the sync
            # counter, and set the state machine to -1 so this worker is not reset again
            while _get_piped_data(pipe_cmd, read_lock) is not None:
                pass
            while _get_piped_data(pipe_step, step_lock) is not None:
                pass
            step_buffer.clear()
            with sync_signal.get_lock():
                sync_signal.value -= 1
            local_state = -1
        # when the sync ends, recover from the syncing state and wait for a new command
        elif sync_signal.value == 0 and local_state == -1:
            local_state = 0
        # idle and waiting for a new roll-out command
        elif sync_signal.value == 0 and local_state == 0:
            cmd = _get_piped_data(pipe_cmd, read_lock)
            if cmd is not None:
                step_buffer.clear()
                cmd.require("fixed environment", "trajectory max step")
                current_step = environment.reset(random=not cmd["fixed environment"])
                filter_op.reset(cmd["filter state dict"])
                policy_step = filter_op.operate_currentStep(current_step)
                with step_lock:
                    pipe_step.send(policy_step)
                local_state = 1
        # waiting for an action from a policy worker
        elif sync_signal.value == 0 and local_state == 1:
            last_step = _get_piped_data(pipe_step, step_lock)
            if last_step is not None:
                last_step, current_step, done = environment.step(last_step)
                record_step = filter_op.operate_recordStep(last_step)
                step_buffer.append(record_step)
                if len(step_buffer) >= cmd["trajectory max step"] or done:
                    traj = filter_op.operate_stepList(step_buffer, done=done)
                    with read_lock:
                        pipe_cmd.send(traj)
                    local_state = 0
                else:
                    policy_step = filter_op.operate_currentStep(current_step)
                    with step_lock:
                        pipe_step.send(policy_step)

    # finalization
    environment.finalize()
    filter_op.finalize()
    pipe_cmd.close()
    pipe_step.close()
    print("Environment sub-process exited")
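# Both worker loops above exit when sync_signal.value becomes negative (see the comment in
# Agent_async.__init__: -1 means terminate). A graceful-shutdown helper for the parent process
# could therefore look like the sketch below; the method name finalize_sketch and the join
# timeout are assumptions, not part of this repository.
def finalize_sketch(self):
    with self._sync_signal.get_lock():
        self._sync_signal.value = -1          # ask all sub-processes to leave their loops
    for proc in self._environment_proc + self._policy_proc:
        proc.join(timeout=10)
    self._control_pipe.close()
    self._param_pipe.close()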