def __init__(self, input_size, output_size, num_env, num_step, gamma,
             lam=0.95, learning_rate=1e-4, ent_coef=0.01, clip_grad_norm=0.5,
             epoch=3, batch_size=128, ppo_eps=0.1, update_proportion=0.25,
             use_gae=True, use_cuda=False, use_noisy_net=False,
             hidden_dim=512):
    """Build the policy network, a generator/discriminator pair and three optimisers.

    The policy (``CnnActorCriticNetwork``), the generator (``NetG``) and the
    discriminator (``NetD``) each get their own Adam optimiser. All three
    networks are moved to ``self.device`` at the end.
    """
    self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)

    # Rollout geometry and space sizes.
    self.num_env = num_env
    self.num_step = num_step
    self.input_size = input_size
    self.output_size = output_size

    # Return / advantage estimation.
    self.gamma = gamma
    self.lam = lam
    self.use_gae = use_gae

    # Optimisation hyper-parameters.
    self.epoch = epoch
    self.batch_size = batch_size
    self.ppo_eps = ppo_eps
    self.ent_coef = ent_coef
    self.clip_grad_norm = clip_grad_norm
    self.update_proportion = update_proportion

    self.device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # NetG's latent width comes from hidden_dim; NetD is built with z_dim=1.
    self.netG = NetG(z_dim=hidden_dim)
    self.netD = NetD(z_dim=1)
    for net in (self.netG, self.netD):
        net.apply(weights_init)

    # The policy uses Adam's default betas; both GAN nets use (0.5, 0.999).
    gan_betas = (0.5, 0.999)
    self.optimizer_policy = optim.Adam(self.model.parameters(), lr=learning_rate)
    self.optimizer_G = optim.Adam(self.netG.parameters(), lr=learning_rate, betas=gan_betas)
    self.optimizer_D = optim.Adam(self.netD.parameters(), lr=learning_rate, betas=gan_betas)

    self.netG = self.netG.to(self.device)
    self.netD = self.netD.to(self.device)
    self.model = self.model.to(self.device)
def __init__(
        self,
        input_size,
        action_size,
        num_env,
        num_step,
        gamma,
        lam=0.95,
        learning_rate=1e-4,
        ent_coef=0.01,
        clip_grad_norm=0.5,
        epoch=3,
        batch_size=128,
        ppo_eps=0.1,
        update_proportion=0.25,
        use_gae=True,
        use_cuda=False,
        use_noisy_net=False,
        device=None,
):
    """Build the actor-critic network, the RND module and their shared optimiser.

    ``device`` overrides ``use_cuda`` when given. Only the RND predictor's
    parameters (not its target's) are handed to the optimiser, together with
    the policy network's parameters.
    """
    self.model = CnnActorCriticNetwork(input_size, action_size, use_noisy_net)

    # Rollout geometry and space sizes.
    self.num_env = num_env
    self.num_step = num_step
    self.input_size = input_size
    self.action_size = action_size

    # Return / advantage estimation.
    self.gamma = gamma
    self.lam = lam
    self.use_gae = use_gae

    # PPO optimisation hyper-parameters.
    self.epoch = epoch
    self.batch_size = batch_size
    self.ppo_eps = ppo_eps
    self.ent_coef = ent_coef
    self.clip_grad_norm = clip_grad_norm
    self.update_proportion = update_proportion

    if device is None:
        device = torch.device('cuda' if use_cuda else 'cpu')
    self.device = device

    self.rnd = RNDModel(input_size, action_size)
    trainable = [*self.model.parameters(), *self.rnd.predictor.parameters()]
    self.optimizer = optim.Adam(trainable, lr=learning_rate)

    self.rnd = self.rnd.to(self.device)
    self.model = self.model.to(self.device)

    # action_size x action_size identity, kept on the device; the original
    # author labelled it a variance matrix.
    self.action_std_eye = torch.eye(self.action_size).to(self.device)
    self.action_std_eye.requires_grad = False
def __init__(self, input_size, output_size, num_env, num_step, gamma,
             lam=0.95, learning_rate=1e-4, ent_coef=0.01, clip_grad_norm=0.5,
             epoch=3, batch_size=128, ppo_eps=0.1, eta=0.01, use_gae=True,
             use_cuda=False, use_noisy_net=False, gpu=None):
    """Build a policy network and an ICM, both wrapped for distributed training.

    ``self.model`` is a ``CnnActorCriticNetwork`` and ``self.icm`` an
    ``ICMModel``; one Adam optimiser updates both. Each module is moved to
    GPU ``gpu`` and wrapped in ``DistributedDataParallel`` on that device.

    NOTE(review): ``use_cuda`` is accepted but unused here, and ``gpu`` is
    passed as ICMModel's third argument where other variants pass a
    ``use_cuda`` flag (a ``gpu`` of 0 would be falsy) -- confirm against
    ICMModel's signature. ``gpu=None`` presumably is not supported by the
    ``device_ids=[gpu]`` wrapping below.
    """
    self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)

    # Rollout geometry and space sizes.
    self.num_env = num_env
    self.num_step = num_step
    self.input_size = input_size
    self.output_size = output_size

    # Return / advantage estimation.
    self.gamma = gamma
    self.lam = lam
    self.use_gae = use_gae

    # Optimisation hyper-parameters.
    self.epoch = epoch
    self.batch_size = batch_size
    self.ppo_eps = ppo_eps
    self.ent_coef = ent_coef
    self.clip_grad_norm = clip_grad_norm
    self.eta = eta  # intrinsic-reward scale

    self.device = gpu
    self.icm = ICMModel(input_size, output_size, gpu)

    joint_params = [*self.model.parameters(), *self.icm.parameters()]
    self.optimizer = optim.Adam(joint_params, lr=learning_rate)

    # Keep the placement and wrapping order (icm first, then model): DDP
    # construction synchronises parameters across processes.
    self.icm = self.icm.cuda(gpu)
    self.model = self.model.cuda(gpu)
    self.icm = nn.parallel.DistributedDataParallel(self.icm, device_ids=[gpu])
    self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[gpu])
def __init__(self, input_size, output_size, num_env, num_step, gamma,
             history_size=4, lam=0.95, learning_rate=1e-4, ent_coef=0.01,
             clip_grad_norm=0.5, epoch=3, batch_size=128, ppo_eps=0.1,
             update_proportion=0.25, use_gae=True, use_cuda=False,
             use_noisy_net=False, hidden_dim=512):
    """Build the policy network and a VAE that share one Adam optimiser.

    ``history_size`` is forwarded to ``CnnActorCriticNetwork``;
    ``hidden_dim`` becomes the VAE's ``z_dim``.
    """
    self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net, history_size)

    # Rollout geometry and space sizes.
    self.num_env = num_env
    self.num_step = num_step
    self.input_size = input_size
    self.output_size = output_size

    # Return / advantage estimation.
    self.gamma = gamma
    self.lam = lam
    self.use_gae = use_gae

    # Optimisation hyper-parameters.
    self.epoch = epoch
    self.batch_size = batch_size
    self.ppo_eps = ppo_eps
    self.ent_coef = ent_coef
    self.clip_grad_norm = clip_grad_norm
    self.update_proportion = update_proportion

    self.device = torch.device('cuda') if use_cuda else torch.device('cpu')

    self.vae = VAE(input_size, z_dim=hidden_dim)
    joint_params = [*self.model.parameters(), *self.vae.parameters()]
    self.optimizer = optim.Adam(joint_params, lr=learning_rate)

    self.vae = self.vae.to(self.device)
    self.model = self.model.to(self.device)
def __init__(self, input_size, output_size, num_env, num_step, gamma,
             lam=0.95, learning_rate=1e-4, ent_coef=0.01, clip_grad_norm=0.5,
             epoch=3, batch_size=128, ppo_eps=0.1, eta=0.01, use_gae=True,
             use_cuda=False, use_noisy_net=False):
    """Build the policy network and an ICM optimised jointly by one Adam optimiser."""
    self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)

    # Rollout geometry and space sizes.
    self.num_env = num_env
    self.num_step = num_step
    self.input_size = input_size
    self.output_size = output_size

    # Return / advantage estimation.
    self.gamma = gamma
    self.lam = lam
    self.use_gae = use_gae

    # Optimisation hyper-parameters.
    self.epoch = epoch
    self.batch_size = batch_size
    self.ppo_eps = ppo_eps
    self.ent_coef = ent_coef
    self.clip_grad_norm = clip_grad_norm
    self.eta = eta  # intrinsic-reward scale

    self.device = torch.device('cuda') if use_cuda else torch.device('cpu')

    self.icm = ICMModel(input_size, output_size, use_cuda)
    joint_params = [*self.model.parameters(), *self.icm.parameters()]
    self.optimizer = optim.Adam(joint_params, lr=learning_rate)

    self.icm = self.icm.to(self.device)
    self.model = self.model.to(self.device)
def main():
    """Load a saved actor-critic checkpoint and run it indefinitely in one rendering worker.

    Prints "Model file not found" and returns if the checkpoint
    ``<save_dir>/<env_name>.model`` does not exist.
    """
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')

    # A throw-away env instance is used only to read the space sizes.
    probe_env = gym.make(args.env_name)
    input_size = probe_env.observation_space.shape
    output_size = probe_env.action_space.n
    if 'Breakout' in args.env_name:
        output_size -= 1  # Breakout runs with one action fewer than the env reports
    probe_env.close()

    model_path = os.path.join(args.save_dir, args.env_name + '.model')
    if not os.path.exists(model_path):
        print("Model file not found")
        return

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    model = model.to(device)
    # map_location=None is torch.load's default, i.e. keep saved devices.
    state_dict = torch.load(model_path, map_location=None if args.cuda else 'cpu')
    model.load_state_dict(state_dict)

    num_worker = 1
    parent_conn, child_conn = Pipe()
    worker = AtariEnvironment(
        args.env_name,
        True,  # is_render
        0,  # worker index
        child_conn,
        sticky_action=False,
        p=args.sticky_action_prob,
        max_episode_steps=args.max_episode_steps)
    worker.start()

    states = torch.zeros(num_worker, 4, 84, 84)
    while True:
        actions = get_action(model, device, states / 255.)
        parent_conn.send(actions)

        next_state, _reward, _done, _real_done, _log_reward = parent_conn.recv()
        states = torch.from_numpy(np.stack([next_state])).type(torch.FloatTensor)
def _normalize_obs(obs, obs_rms):
    """Standardise ``obs`` with the running mean/variance in ``obs_rms`` and clip to [-5, 5]."""
    return ((obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5)


def _save_rms(reward_rms, obs_rms):
    """Pickle the reward and observation running statistics into the working directory."""
    with open('reward_rms.pkl', 'wb') as f:
        dill.dump(reward_rms, f)
    with open('obs_rms.pkl', 'wb') as f:
        dill.dump(obs_rms, f)


def _load_rms():
    """Load the statistics written by ``_save_rms``.

    dill (like pickle) can execute arbitrary code while loading: only point
    this at files produced by this script.
    """
    with open('reward_rms.pkl', 'rb') as f:
        reward_rms = dill.load(f)
    with open('obs_rms.pkl', 'rb') as f:
        obs_rms = dill.load(f)
    return reward_rms, obs_rms


def _save_checkpoint(model, rnd, model_path, predictor_path, target_path):
    """Save the policy and both RND networks' state dicts."""
    torch.save(model.state_dict(), model_path)
    torch.save(rnd.predictor.state_dict(), predictor_path)
    torch.save(rnd.target.state_dict(), target_path)


def _warmup_obs_rms(parent_conns, obs_rms, output_size, num_worker, num_step,
                    pre_obs_norm_steps):
    """Step every worker with uniformly random actions to seed ``obs_rms``."""
    next_obs = []
    for _ in range(num_step * pre_obs_norm_steps):
        actions = np.random.randint(0, output_size, size=(num_worker,))
        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)
        for parent_conn in parent_conns:
            next_state, _reward, _done, _real_done, _log_reward = parent_conn.recv()
            # Channel 3 = last of the 4 stacked frames (presumably the newest).
            next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))
        if len(next_obs) % (num_step * num_worker) == 0:
            obs_rms.update(np.stack(next_obs))
            next_obs = []


def main():
    """Train a PPO agent with a Random Network Distillation (RND) bonus.

    Runs ``args.num_worker`` environment workers, alternates n-step rollouts
    with PPO/RND updates, checkpoints every ``args.save_interval`` updates
    and, when ``args.terminate`` is set, stops once ``global_step`` exceeds
    ``args.terminate_steps`` (saving weights and normalisation statistics).
    """
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')

    # Probe an ObstacleTower instance only to read the space sizes.
    # NOTE(review): the training workers below are AtariEnvironment(args.env_name);
    # confirm their observation/action spaces match these sizes.
    seed = np.random.randint(0, 100)
    env = ObstacleTowerEnv('../ObstacleTower/obstacletower', worker_id=seed,
                           retro=True, config={'total-floors': 12},
                           greyscale=True, timeout_wait=300)
    env._flattener = ActionFlattener([2, 3, 2, 1])
    env._action_space = env._flattener.action_space
    input_size = env.observation_space.shape
    output_size = env.action_space.n
    env.close()

    is_render = False
    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, 'main.model')
    predictor_path = os.path.join(args.save_dir, 'main.pred')
    target_path = os.path.join(args.save_dir, 'main.target')

    writer = SummaryWriter()
    # This filter only feeds intrinsic-reward normalisation, so it must use the
    # intrinsic discount (the same one make_train_data uses for int returns).
    discounted_reward = RewardForwardFilter(args.int_gamma)

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    rnd = RNDModel(input_size, output_size)
    model = model.to(device)
    rnd = rnd.to(device)
    # Only the RND predictor is optimised; its target is never handed to Adam.
    optimizer = optim.Adam(list(model.parameters()) + list(rnd.predictor.parameters()),
                           lr=args.lr)

    if args.load_model:
        print('Loading model...')
        if args.cuda:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(torch.load(model_path, map_location='cpu'))

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(args.num_worker):
        parent_conn, child_conn = Pipe()
        work = AtariEnvironment(
            args.env_name,
            is_render,
            idx,
            child_conn,
            sticky_action=args.sticky_action,
            p=args.sticky_action_prob,
            max_episode_steps=args.max_episode_steps)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([args.num_worker, 4, 84, 84])

    sample_env_index = 0  # environment whose episodes are logged
    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    global_update = 0
    global_step = 0

    print("Load RMS =", args.load_rms)
    if args.load_rms:
        print("Loading RMS values for observation and reward normalization")
        reward_rms, obs_rms = _load_rms()
    else:
        reward_rms = RunningMeanStd()
        obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
        print('Initializing observation normalization...')
        _warmup_obs_rms(parent_conns, obs_rms, output_size, args.num_worker,
                        args.num_step, args.pre_obs_norm_steps)
        _save_rms(reward_rms, obs_rms)

    print('Training...')
    while True:
        total_state, total_reward, total_done, total_action = [], [], [], []
        total_int_reward, total_next_obs = [], []
        total_ext_values, total_int_values, total_action_probs = [], [], []
        global_step += args.num_worker * args.num_step
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(args.num_step):
            actions, value_ext, value_int, action_probs = get_action(
                model, device, np.float32(states) / 255.)
            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                next_state, reward, done, real_done, log_reward = parent_conn.recv()
                next_states.append(next_state)
                rewards.append(reward)
                dones.append(done)
                real_dones.append(real_done)
                log_rewards.append(log_reward)
                next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)

            intrinsic_reward = compute_intrinsic_reward(
                rnd, device, _normalize_obs(next_obs, obs_rms))
            intrinsic_reward = np.hstack(intrinsic_reward)

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_action_probs.append(action_probs)

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_index]
            sample_step += 1
            if real_dones[sample_env_index]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall, sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall, global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0

        # Bootstrap values for the state after the last rollout step.
        _, value_ext, value_int, _ = get_action(model, device, np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)

        # Reorder from (step, env, ...) to env-major and flatten.
        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape([-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose([1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_action_probs = np.vstack(total_action_probs)

        # Step 2. normalise intrinsic reward by the running std of its discounted sum.
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([discounted_reward.update(reward_per_step)
                                         for reward_per_step in total_int_reward.T])
        mean, std, count = (np.mean(total_reward_per_env), np.std(total_reward_per_env),
                            len(total_reward_per_env))
        reward_rms.update_from_moments(mean, std ** 2, count)
        total_int_reward /= np.sqrt(reward_rms.var)

        writer.add_scalar('data/int_reward_per_epi',
                          np.sum(total_int_reward) / args.num_worker, sample_episode)
        writer.add_scalar('data/int_reward_per_rollout',
                          np.sum(total_int_reward) / args.num_worker, global_update)
        writer.add_scalar('data/max_prob', total_logging_action_probs.max(1).mean(), sample_episode)

        # Step 3. targets and advantages.
        ext_target, ext_adv = make_train_data(
            total_reward, total_done, total_ext_values, args.ext_gamma,
            args.gae_lambda, args.num_step, args.num_worker, args.use_gae)
        # Intrinsic returns are non-episodic: no done flags.
        int_target, int_adv = make_train_data(
            total_int_reward, np.zeros_like(total_int_reward), total_int_values,
            args.int_gamma, args.gae_lambda, args.num_step, args.num_worker, args.use_gae)
        total_adv = int_adv * args.int_coef + ext_adv * args.ext_coef

        # Step 4. update observation normalisation.
        obs_rms.update(total_next_obs)

        # Step 5. train.
        train_model(args, device, output_size, model, rnd, optimizer,
                    np.float32(total_state) / 255., ext_target, int_target,
                    total_action, total_adv, _normalize_obs(total_next_obs, obs_rms),
                    total_action_probs)

        if global_step % (args.num_worker * args.num_step * args.save_interval) == 0:
            print('Now Global Step :{}'.format(global_step))
            _save_checkpoint(model, rnd, model_path, predictor_path, target_path)

        if args.terminate and global_step > args.terminate_steps:
            # Persist the final weights too, not only the last periodic checkpoint.
            _save_checkpoint(model, rnd, model_path, predictor_path, target_path)
            _save_rms(reward_rms, obs_rms)
            break

    writer.close()
class ICMAgent(object):
    """PPO agent with an Intrinsic Curiosity Module (ICM) exploration bonus.

    ``self.model`` (a ``CnnActorCriticNetwork`` returning policy scores and a
    value) and ``self.icm`` (an ``ICMModel``) are optimised jointly by one
    Adam optimiser. The intrinsic reward is ``eta`` times the per-sample
    mean squared error between the ICM's real and predicted next-state
    features.
    """

    def __init__(
            self,
            input_size,
            output_size,
            num_env,
            num_step,
            gamma,
            lam=0.95,
            learning_rate=1e-4,
            ent_coef=0.01,
            clip_grad_norm=0.5,
            epoch=3,
            batch_size=128,
            ppo_eps=0.1,
            eta=0.01,
            use_gae=True,
            use_cuda=False,
            use_noisy_net=False):
        self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.eta = eta
        self.ppo_eps = ppo_eps
        # NOTE: stored but currently unused -- gradient clipping is not applied.
        self.clip_grad_norm = clip_grad_norm
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.icm = ICMModel(input_size, output_size, use_cuda)
        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.icm.parameters()),
                                    lr=learning_rate)
        self.icm = self.icm.to(self.device)
        self.model = self.model.to(self.device)

    def get_action(self, state):
        """Sample one action per row of ``state``.

        Returns:
            (actions as a numpy int array, squeezed values as numpy,
             detached policy scores/logits tensor on ``self.device``).
        """
        state = torch.Tensor(state).to(self.device).float()
        policy, value = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()
        action = self.random_choice_prob_index(action_prob)
        return action, value.data.cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Inverse-CDF sampling: draw one uniform number per row of ``p`` and
        return the first index whose cumulative probability exceeds it."""
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, state, next_state, action):
        """Return ``eta * mean((phi(s') - phi_hat(s'))**2)`` per transition as numpy.

        ``action`` is a sequence of integer action indices; it is one-hot
        encoded to width ``self.output_size`` before being fed to the ICM.
        """
        state = torch.FloatTensor(state).to(self.device)
        next_state = torch.FloatTensor(next_state).to(self.device)
        action = torch.LongTensor(action).to(self.device)
        action_onehot = F.one_hot(action, num_classes=self.output_size).float()

        # Inference only: no autograd graph is needed for the reward.
        with torch.no_grad():
            real_next_state_feature, pred_next_state_feature, _ = self.icm(
                [state, next_state, action_onehot])
            intrinsic_reward = self.eta * F.mse_loss(
                real_next_state_feature, pred_next_state_feature, reduction='none').mean(-1)
        return intrinsic_reward.cpu().numpy()

    def train_model(self, s_batch, next_s_batch, target_batch, y_batch, adv_batch, old_policy):
        """Run ``self.epoch`` epochs of minibatch PPO + ICM updates.

        Args:
            s_batch / next_s_batch: states and next states of the rollout.
            target_batch: value targets.
            y_batch: integer actions taken.
            adv_batch: advantages.
            old_policy: list of rollout-time policy-score tensors; stacked and
                permuted (1, 0, 2) then flattened to ``(-1, output_size)`` --
                presumably a per-step list of (num_env, output_size) tensors.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_s_batch = torch.FloatTensor(next_s_batch).to(self.device)
        target_batch = torch.FloatTensor(target_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(1, 0, 2).contiguous().view(
                -1, self.output_size).to(self.device)
            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)

        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(len(s_batch) // self.batch_size):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

                # Curiosity losses: inverse model predicts the action, forward
                # model predicts the (detached) next-state features.
                action_onehot = F.one_hot(y_batch[sample_idx], num_classes=self.output_size).float()
                real_next_state_feature, pred_next_state_feature, pred_action = self.icm(
                    [s_batch[sample_idx], next_s_batch[sample_idx], action_onehot])
                inverse_loss = F.cross_entropy(pred_action, y_batch[sample_idx])
                forward_loss = F.mse_loss(pred_next_state_feature, real_next_state_feature.detach())

                # PPO clipped surrogate.
                policy, value = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])
                ratio = torch.exp(log_prob - log_prob_old[sample_idx])
                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps, 1.0 + self.ppo_eps) * adv_batch[sample_idx]
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(value.sum(1), target_batch[sample_idx])
                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                # Entropy weight comes from the constructor (was hard-coded 0.001).
                loss = (actor_loss + 0.5 * critic_loss - self.ent_coef * entropy) \
                    + forward_loss + inverse_loss
                loss.backward()
                self.optimizer.step()
class GANAgent(object):
    """PPO agent whose curiosity bonus comes from a generator/discriminator pair.

    ``self.netG`` maps an observation to ``(reconstruction, latent_i,
    latent_o)``; the intrinsic reward is half the squared distance between
    the two latent codes. ``self.netD`` returns ``(prediction, features)``
    and is trained as a real/fake classifier. The generator loss mixes a
    feature-matching (adv), an L1 reconstruction (con) and a latent (enc)
    term -- this appears to follow GANomaly.
    """

    def __init__(self, input_size, output_size, num_env, num_step, gamma, lam=0.95,
                 learning_rate=1e-4, ent_coef=0.01, clip_grad_norm=0.5, epoch=3,
                 batch_size=128, ppo_eps=0.1, update_proportion=0.25, use_gae=True,
                 use_cuda=False, use_noisy_net=False, hidden_dim=512):
        self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        # NOTE: stored but currently unused -- gradient clipping is not applied.
        self.clip_grad_norm = clip_grad_norm
        # Fraction of each minibatch (random mask) used for the GAN losses.
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.netG = NetG(z_dim=hidden_dim)
        self.netD = NetD(z_dim=1)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

        self.optimizer_policy = optim.Adam(list(self.model.parameters()), lr=learning_rate)
        self.optimizer_G = optim.Adam(list(self.netG.parameters()), lr=learning_rate,
                                      betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(list(self.netD.parameters()), lr=learning_rate,
                                      betas=(0.5, 0.999))

        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)
        self.model = self.model.to(self.device)

    def reconstruct(self, state):
        """Return the generator's reconstruction of a single (unbatched) state as numpy."""
        state = torch.Tensor(state).to(self.device).float()
        # The agent has no VAE; netG's first output is the reconstruction.
        reconstructed = self.netG(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.detach().cpu().numpy()

    def get_action(self, state):
        """Sample one action per row of ``state``.

        Returns:
            (actions, extrinsic values, intrinsic values, detached policy scores).
        """
        state = torch.Tensor(state).to(self.device).float()
        policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()
        action = self.random_choice_prob_index(action_prob)
        return (action, value_ext.data.cpu().numpy().squeeze(),
                value_int.data.cpu().numpy().squeeze(), policy.detach())

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Inverse-CDF sampling: draw one uniform number per row of ``p`` and
        return the first index whose cumulative probability exceeds it."""
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        """Return ``0.5 * ||latent_i - latent_o||^2`` per observation as numpy.

        The reconstruction error itself is not part of the reward.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        with torch.no_grad():
            _, embedding, reconstructed_embedding = self.netG(obs)
            intrinsic_reward = (embedding - reconstructed_embedding).pow(2).sum(1) / 2
        return intrinsic_reward.cpu().numpy()

    @staticmethod
    def _masked_mean(values, mask):
        """Mean of ``values`` over entries where ``mask`` is 1 (0 if the mask is empty)."""
        return (values * mask).sum() / mask.sum().clamp(min=1.0)

    def _gan_step(self, real_obs, reinit_collapsed_d=False):
        """Run one generator update and one discriminator update on ``real_obs``.

        Only a random ``update_proportion`` subset of the samples contributes
        to the losses. If ``reinit_collapsed_d`` is set and the discriminator
        loss drops below 1e-5, netD is re-initialised.

        Returns:
            (err_g_adv, err_g_con, err_g_enc, err_d) as Python floats.
        """
        gen_obs, latent_i, latent_o = self.netG(real_obs)

        # Discriminator predictions for its own update. The fake images are
        # detached: otherwise this loss would backprop into netG's graph after
        # optimizer_G.step() modified netG's weights in place (autograd error).
        pred_real, feature_real = self.netD(real_obs)
        pred_fake, _ = self.netD(gen_obs.detach())

        # ---- generator update ----
        self.optimizer_G.zero_grad()
        err_g_adv_per_img = F.mse_loss(
            self.netD(real_obs)[1], self.netD(gen_obs)[1], reduction='none').mean(
                axis=list(range(1, len(feature_real.shape))))
        err_g_con_per_img = F.l1_loss(real_obs, gen_obs, reduction='none').mean(
            axis=list(range(1, len(gen_obs.shape))))
        err_g_enc_per_img = F.mse_loss(latent_i, latent_o, reduction='none').mean(-1)

        mask = torch.rand(len(err_g_con_per_img)).to(self.device)
        mask = (mask < self.update_proportion).type(torch.FloatTensor).to(self.device)

        mean_err_g_adv = self._masked_mean(err_g_adv_per_img, mask)
        mean_err_g_con = self._masked_mean(err_g_con_per_img, mask)
        mean_err_g_enc = self._masked_mean(err_g_enc_per_img, mask)

        # Loss weights (adv, con, enc).
        w_adv, w_con, w_enc = 1, 50, 1
        mean_err_g = mean_err_g_adv * w_adv + mean_err_g_con * w_con + mean_err_g_enc * w_enc
        mean_err_g.backward()
        self.optimizer_G.step()

        # ---- discriminator update ----
        # zero_grad also clears netD grads accumulated by the generator loss.
        self.optimizer_D.zero_grad()
        err_d_real_per_img = F.binary_cross_entropy(
            pred_real, torch.ones_like(pred_real), reduction='none')
        err_d_fake_per_img = F.binary_cross_entropy(
            pred_fake, torch.zeros_like(pred_fake), reduction='none')
        mean_err_d = (self._masked_mean(err_d_real_per_img, mask)
                      + self._masked_mean(err_d_fake_per_img, mask)) / 2
        mean_err_d.backward()
        self.optimizer_D.step()

        if reinit_collapsed_d and mean_err_d.item() < 1e-5:
            self.netD.apply(weights_init)
            print('Reloading net d')

        return (mean_err_g_adv.item(), mean_err_g_con.item(),
                mean_err_g_enc.item(), mean_err_d.item())

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        """Run ``self.epoch`` epochs of minibatch GAN + PPO updates.

        Args:
            s_batch: states; target_ext_batch / target_int_batch: value
                targets for the two critic heads; y_batch: integer actions;
                adv_batch: combined advantages; next_obs_batch: observations
                the GAN is trained on; old_policy: list of rollout-time
                policy-score tensors (stacked, permuted (1, 0, 2) and
                flattened to ``(-1, output_size)``).

        Returns:
            Four float arrays with one entry per minibatch: generator adv,
            con and enc losses and the discriminator loss.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1, self.output_size).to(self.device)
            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)

        history = ([], [], [], [])
        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(len(s_batch) // self.batch_size):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

                errs = self._gan_step(next_obs_batch[sample_idx], reinit_collapsed_d=True)
                for log, err in zip(history, errs):
                    log.append(err)

                # ---- PPO policy update ----
                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])
                ratio = torch.exp(log_prob - log_prob_old[sample_idx])
                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps, 1.0 + self.ppo_eps) * adv_batch[sample_idx]
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1), target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1), target_int_batch[sample_idx])
                critic_loss = critic_ext_loss + critic_int_loss
                entropy = m.entropy().mean()

                self.optimizer_policy.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy
                loss.backward()
                self.optimizer_policy.step()

        return tuple(np.array(log, dtype=np.float64) for log in history)

    def train_just_vae(self, s_batch, next_obs_batch):
        """Train only the generator/discriminator for ``self.epoch`` epochs.

        ``s_batch`` only determines the number of samples. Unlike
        ``train_model``, netD is never re-initialised here.

        Returns:
            Same four per-minibatch loss arrays as ``train_model``.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        history = ([], [], [], [])
        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(len(s_batch) // self.batch_size):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]
                errs = self._gan_step(next_obs_batch[sample_idx], reinit_collapsed_d=False)
                for log, err in zip(history, errs):
                    log.append(err)

        return tuple(np.array(log, dtype=np.float64) for log in history)
class RNDAgent(object):
    """PPO agent with a Random Network Distillation (RND) curiosity module.

    The actor-critic network (``CnnActorCriticNetwork``) returns policy
    logits plus separate extrinsic and intrinsic value heads. The RND
    predictor is trained jointly with the policy by a single Adam optimizer;
    the RND target network is never handed to that optimizer.
    """

    def __init__(self, input_size, output_size, num_env, num_step, gamma,
                 lam=0.95, learning_rate=1e-4, ent_coef=0.01,
                 clip_grad_norm=0.5, epoch=3, batch_size=128, ppo_eps=0.1,
                 update_proportion=0.25, use_gae=True, use_cuda=False,
                 use_noisy_net=False):
        self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        # NOTE(review): stored but not read in this class; global_grad_norm_
        # below only receives the parameter list.
        self.clip_grad_norm = clip_grad_norm
        # Probability that a sample contributes to the RND predictor loss.
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.rnd = RNDModel(input_size, output_size)
        # Only the RND *predictor* is optimised; the target keeps its init.
        self.optimizer = optim.Adam(
            list(self.model.parameters()) + list(self.rnd.predictor.parameters()),
            lr=learning_rate)

        self.rnd = self.rnd.to(self.device)
        self.model = self.model.to(self.device)

    def get_action(self, state):
        """Sample one action per row of ``state`` from the current policy.

        Returns:
            (actions as a NumPy int array, squeezed extrinsic values,
             squeezed intrinsic values, detached policy logits tensor).
        """
        state = torch.Tensor(state).to(self.device).float()
        # Rollout-time inference only: no autograd graph is needed.
        with torch.no_grad():
            policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).cpu().numpy()
        action = self.random_choice_prob_index(action_prob)
        return (action,
                value_ext.cpu().numpy().squeeze(),
                value_int.cpu().numpy().squeeze(),
                policy.detach())

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Draw one index per slice of ``p`` along ``axis`` by inverse-CDF sampling.

        ``p`` holds probabilities along ``axis``; one uniform draw is made per
        slice and the first index whose cumulative probability exceeds it wins.
        """
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, next_obs):
        """Return the RND novelty bonus for each row of ``next_obs``.

        The bonus is half the squared L2 distance (summed over dim 1) between
        the target and predictor embeddings, as a NumPy array.
        """
        next_obs = torch.FloatTensor(next_obs).to(self.device)
        with torch.no_grad():
            target_next_feature = self.rnd.target(next_obs)
            predict_next_feature = self.rnd.predictor(next_obs)
        intrinsic_reward = (target_next_feature - predict_next_feature).pow(2).sum(1) / 2
        return intrinsic_reward.cpu().numpy()

    def _update_mask(self, n):
        """Float 0/1 mask of length ``n``; each entry is 1 w.p. ``update_proportion``."""
        return (torch.rand(n).to(self.device) < self.update_proportion).float()

    def _masked_mean(self, values, mask):
        """Mean of ``values`` over masked-in entries (denominator floored at 1)."""
        return (values * mask).sum() / torch.max(
            mask.sum(), torch.Tensor([1]).to(self.device))

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        """Run ``self.epoch`` epochs of joint PPO + RND-predictor minibatch updates.

        Args:
            s_batch: states fed to the actor-critic network.
            target_ext_batch: targets for the extrinsic value head.
            target_int_batch: targets for the intrinsic value head.
            y_batch: integer actions taken during the rollout.
            adv_batch: advantages used by the clipped PPO objective.
            next_obs_batch: observations fed to the RND module.
            old_policy: sequence of rollout-time policy logits; stacked and
                permuted on the first two axes, then flattened to
                (-1, output_size) (presumably (step, env, action) ->
                (env, step, action) -- confirm against the caller).
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        forward_mse = nn.MSELoss(reduction='none')

        # Log-probabilities of the taken actions under the rollout policy.
        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1, self.output_size).to(self.device)
            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)

        n_minibatches = len(s_batch) // self.batch_size
        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(n_minibatches):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

                # RND predictor loss, averaged over a random subset of samples.
                predict_next_state_feature, target_next_state_feature = self.rnd(
                    next_obs_batch[sample_idx])
                forward_loss = forward_mse(
                    predict_next_state_feature,
                    target_next_state_feature.detach()).mean(-1)
                mask = self._update_mask(len(forward_loss))
                forward_loss = self._masked_mean(forward_loss, mask)

                # Clipped PPO objective with two value heads.
                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])
                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]
                actor_loss = -torch.min(surr1, surr2).mean()

                critic_ext_loss = F.mse_loss(value_ext.sum(1), target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1), target_int_batch[sample_idx])
                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = (actor_loss + 0.5 * critic_loss
                        - self.ent_coef * entropy + forward_loss)
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) + list(self.rnd.predictor.parameters()))
                self.optimizer.step()
def main():
    """Train a PPO agent with RND intrinsic rewards on an Atari environment.

    Reads the configuration from ``get_args()``, starts ``args.num_worker``
    ``AtariEnvironment`` worker processes connected through pipes, warms up
    the observation normaliser with random actions, and then loops forever:
    rollout -> intrinsic/extrinsic targets -> ``train_model`` update. Every
    ``args.save_interval`` updates it checkpoints the policy and both RND
    networks into ``args.save_dir``.
    """
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')

    # A throw-away env is created only to read the observation/action spaces.
    env = gym.make(args.env_name)
    input_size = env.observation_space.shape  # 4
    output_size = env.action_space.n  # 2
    if 'Breakout' in args.env_name:
        # NOTE(review): one action is dropped for Breakout; the reason is not
        # visible here (presumably the worker remaps actions) -- confirm.
        output_size -= 1
    env.close()

    is_render = False
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_path = os.path.join(args.save_dir, args.env_name + '.model')
    predictor_path = os.path.join(args.save_dir, args.env_name + '.pred')
    target_path = os.path.join(args.save_dir, args.env_name + '.target')

    writer = SummaryWriter(log_dir=args.log_dir)

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    # This filter only ever accumulates *intrinsic* rewards (see Step 2), so
    # it must discount with the intrinsic gamma -- the same one used for the
    # intrinsic targets in make_train_data below.
    discounted_reward = RewardForwardFilter(args.int_gamma)

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    rnd = RNDModel(input_size, output_size)
    model = model.to(device)
    rnd = rnd.to(device)
    optimizer = optim.Adam(list(model.parameters()) + list(rnd.predictor.parameters()),
                           lr=args.lr)

    if args.load_model:
        # map_location=device also lets GPU-saved checkpoints load on CPU.
        model.load_state_dict(torch.load(model_path, map_location=device))
        # The predictor and target are saved next to the policy below; restore
        # them too so the intrinsic reward keeps the definition the intrinsic
        # value head was trained on. Checkpoints without these files are still
        # accepted (policy-only resume, as before).
        if os.path.exists(predictor_path):
            rnd.predictor.load_state_dict(torch.load(predictor_path, map_location=device))
        if os.path.exists(target_path):
            rnd.target.load_state_dict(torch.load(target_path, map_location=device))

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(args.num_worker):
        parent_conn, child_conn = Pipe()
        work = AtariEnvironment(args.env_name, is_render, idx, child_conn,
                                sticky_action=args.sticky_action,
                                p=args.sticky_action_prob,
                                max_episode_steps=args.max_episode_steps)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([args.num_worker, 4, 84, 84])

    sample_env_index = 0  # Sample Environment index to log
    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # normalize observation: feed random-action frames into obs_rms
    print('Initializes observation normalization...')
    next_obs = []
    for _ in range(args.num_step * args.pre_obs_norm_steps):
        actions = np.random.randint(0, output_size, size=(args.num_worker, ))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            next_state, reward, done, realdone, log_reward = parent_conn.recv()
            # Channel 3 = last of the four stacked frames.
            next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

        if len(next_obs) % (args.num_step * args.num_worker) == 0:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = []

    print('Training...')
    while True:
        total_state, total_reward, total_done = [], [], []
        total_action, total_int_reward, total_next_obs = [], [], []
        total_ext_values, total_int_values, total_action_probs = [], [], []
        global_step += (args.num_worker * args.num_step)
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(args.num_step):
            actions, value_ext, value_int, action_probs = get_action(
                model, device, np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                next_state, reward, done, real_done, log_reward = parent_conn.recv()
                next_states.append(next_state)
                rewards.append(reward)
                dones.append(done)
                real_dones.append(real_done)
                log_rewards.append(log_reward)
                next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)

            # total reward = int reward + ext Reward
            intrinsic_reward = compute_intrinsic_reward(
                rnd, device,
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_index]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_action_probs.append(action_probs)

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_index]
            sample_step += 1
            if real_dones[sample_env_index]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall, sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall, global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value (bootstrap for the final rollout step)
        _, value_ext, value_int, _ = get_action(model, device, np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape(
            [-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose(
            [1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_action_probs = np.vstack(total_action_probs)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([
            discounted_reward.update(reward_per_step)
            for reward_per_step in total_int_reward.T
        ])
        mean, std, count = (np.mean(total_reward_per_env),
                            np.std(total_reward_per_env),
                            len(total_reward_per_env))
        reward_rms.update_from_moments(mean, std**2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi',
                          np.sum(total_int_reward) / args.num_worker, sample_episode)
        writer.add_scalar('data/int_reward_per_rollout',
                          np.sum(total_int_reward) / args.num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # logging Max action probability
        writer.add_scalar('data/max_prob', total_logging_action_probs.max(1).mean(),
                          sample_episode)

        # Step 3. make target and advantage
        # extrinsic reward calculate
        ext_target, ext_adv = make_train_data(total_reward, total_done, total_ext_values,
                                              args.ext_gamma, args.gae_lambda,
                                              args.num_step, args.num_worker,
                                              args.use_gae)

        # intrinsic reward calculate
        # None Episodic: all "done" flags are zero for the intrinsic stream
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values, args.int_gamma,
                                              args.gae_lambda, args.num_step,
                                              args.num_worker, args.use_gae)

        # add ext adv and int adv
        total_adv = int_adv * args.int_coef + ext_adv * args.ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        train_model(args, device, output_size, model, rnd, optimizer,
                    np.float32(total_state) / 255., ext_target, int_target,
                    total_action, total_adv,
                    ((total_next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5),
                    total_action_probs)

        if global_step % (args.num_worker * args.num_step * args.save_interval) == 0:
            print('Now Global Step :{}'.format(global_step))
            torch.save(model.state_dict(), model_path)
            torch.save(rnd.predictor.state_dict(), predictor_path)
            torch.save(rnd.target.state_dict(), target_path)
class GenerativeAgent(object):
    """PPO agent whose curiosity signal comes from a VAE.

    The actor-critic network and the VAE share a single Adam optimizer. The
    intrinsic reward is the distance between the VAE representation of an
    observation and the representation of its VAE reconstruction.
    """

    def __init__(self, input_size, output_size, num_env, num_step, gamma,
                 history_size=4, lam=0.95, learning_rate=1e-4, ent_coef=0.01,
                 clip_grad_norm=0.5, epoch=3, batch_size=128, ppo_eps=0.1,
                 update_proportion=0.25, use_gae=True, use_cuda=False,
                 use_noisy_net=False, hidden_dim=512):
        self.model = CnnActorCriticNetwork(input_size, output_size,
                                           use_noisy_net, history_size)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        # NOTE(review): stored but not read in this class; global_grad_norm_
        # below only receives the parameter list.
        self.clip_grad_norm = clip_grad_norm
        # Probability that a sample contributes to the VAE loss.
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.vae = VAE(input_size, z_dim=hidden_dim)
        self.optimizer = optim.Adam(
            list(self.model.parameters()) + list(self.vae.parameters()),
            lr=learning_rate)

        self.vae = self.vae.to(self.device)
        self.model = self.model.to(self.device)

    def reconstruct(self, state):
        """Return the VAE reconstruction of a single ``state`` as a NumPy array.

        A batch axis is added before the VAE call and removed afterwards.
        """
        state = torch.Tensor(state).to(self.device).float()
        with torch.no_grad():
            reconstructed = self.vae(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.cpu().numpy()

    def get_action(self, state):
        """Sample one action per row of ``state`` from the current policy.

        Returns:
            (actions as a NumPy int array, squeezed extrinsic values,
             squeezed intrinsic values, detached policy logits tensor).
        """
        state = torch.Tensor(state).to(self.device).float()
        # Rollout-time inference only: no autograd graph is needed.
        with torch.no_grad():
            policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).cpu().numpy()
        action = self.random_choice_prob_index(action_prob)
        return (action,
                value_ext.cpu().numpy().squeeze(),
                value_int.cpu().numpy().squeeze(),
                policy.detach())

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        """Draw one index per slice of ``p`` along ``axis`` by inverse-CDF sampling."""
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        """Return the per-row curiosity bonus for ``obs`` as a NumPy array.

        The bonus is half the squared L2 distance (summed over dim 1) between
        ``vae.representation(obs)`` and the representation of the VAE's
        reconstruction of ``obs``.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        with torch.no_grad():
            embedding = self.vae.representation(obs)
            reconstructed_embedding = self.vae.representation(self.vae(obs)[0])
        intrinsic_reward = (embedding - reconstructed_embedding).pow(2).sum(1) / 2
        return intrinsic_reward.cpu().numpy()

    def _update_mask(self, n):
        """Float 0/1 mask of length ``n``; each entry is 1 w.p. ``update_proportion``."""
        return (torch.rand(n).to(self.device) < self.update_proportion).float()

    def _masked_mean(self, values, mask):
        """Mean of ``values`` over masked-in entries (denominator floored at 1)."""
        return (values * mask).sum() / torch.max(
            mask.sum(), torch.Tensor([1]).to(self.device))

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        """Run ``self.epoch`` epochs of joint PPO + VAE minibatch updates.

        Returns:
            (recon_losses, kld_losses): NumPy arrays holding the masked
            per-minibatch VAE reconstruction (MSE) and KL losses.
        """
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        reconstruction_loss = nn.MSELoss(reduction='none')

        # Log-probabilities of the taken actions under the rollout policy.
        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1, self.output_size).to(self.device)
            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)

        recon_losses = np.array([])
        kld_losses = np.array([])

        n_minibatches = len(s_batch) // self.batch_size
        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(n_minibatches):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

                # VAE loss: per-sample MSE over all non-batch axes + KL term.
                gen_next_state, mu, logvar = self.vae(next_obs_batch[sample_idx])
                d = len(gen_next_state.shape)
                recon_loss = reconstruction_loss(
                    gen_next_state,
                    next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))
                kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                mask = self._update_mask(len(recon_loss))
                recon_loss = self._masked_mean(recon_loss, mask)
                kld_loss = self._masked_mean(kld_loss, mask)

                recon_losses = np.append(recon_losses, recon_loss.detach().cpu().numpy())
                kld_losses = np.append(kld_losses, kld_loss.detach().cpu().numpy())

                # Clipped PPO objective with two value heads.
                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])
                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]
                actor_loss = -torch.min(surr1, surr2).mean()

                critic_ext_loss = F.mse_loss(value_ext.sum(1), target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1), target_int_batch[sample_idx])
                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = (actor_loss + 0.5 * critic_loss - self.ent_coef * entropy
                        + recon_loss + kld_loss)
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) + list(self.vae.parameters()))
                self.optimizer.step()

        return recon_losses, kld_losses

    def train_just_vae(self, s_batch, next_obs_batch):
        """Train only the VAE (SSIM reconstruction + KL) for ``self.epoch`` epochs.

        ``s_batch`` is only used for its length (the number of samples).

        Returns:
            (recon_losses, kld_losses): NumPy arrays of masked per-minibatch losses.
        """
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)
        sample_range = np.arange(len(s_batch))

        recon_losses = np.array([])
        kld_losses = np.array([])

        n_minibatches = len(s_batch) // self.batch_size
        for _ in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(n_minibatches):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

                gen_next_state, mu, logvar = self.vae(next_obs_batch[sample_idx])
                # Negated SSIM so that higher similarity means lower loss
                # (size_average=False presumably yields one value per image).
                recon_loss = -1 * pytorch_ssim.ssim(gen_next_state,
                                                    next_obs_batch[sample_idx],
                                                    size_average=False)
                kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                mask = self._update_mask(len(recon_loss))
                recon_loss = self._masked_mean(recon_loss, mask)
                kld_loss = self._masked_mean(kld_loss, mask)

                recon_losses = np.append(recon_losses, recon_loss.detach().cpu().numpy())
                kld_losses = np.append(kld_losses, kld_loss.detach().cpu().numpy())

                # NOTE(review): self.optimizer also holds the policy parameters.
                # If zero_grad() leaves their .grad as zeros (not None), Adam's
                # stale momentum still moves the policy here; confirm against
                # the installed PyTorch's zero_grad(set_to_none=...) default.
                self.optimizer.zero_grad()
                loss = recon_loss + kld_loss
                loss.backward()
                global_grad_norm_(list(self.vae.parameters()))
                self.optimizer.step()

        return recon_losses, kld_losses