torch.manual_seed(opts.seed)
env = create_atari_env(opts.env_name, opts)
trained_model = ActorCritic(env.observation_space.shape[0], env.action_space)

# Load a pre-trained model according to the fine-tuning setting.
if opts.ft_setting == 'full-ft':
    if opts.env_name == 'BreakoutDeterministic-v4':
        fname = './agent/trained_model/breakout/11000.pth.tar'
    elif opts.env_name == 'PongDeterministic-v4':
        fname = './agent/trained_model/pong/4000.pth.tar'
    else:
        sys.exit('Only support Break or Pong')
    if os.path.isfile(fname):
        # trained_model is still on the CPU at this point; mapping the
        # checkpoint to the CPU lets a GPU-saved file load on any host.
        checkpoint = torch.load(fname, map_location='cpu')
        trained_model.load_state_dict(checkpoint['state_dict'])
        for param in trained_model.parameters():
            param.requires_grad = True
        print(f"{fname}\n Model was loaded successfully")
    else:
        # Previously a missing file was skipped silently, leaving a randomly
        # initialised model under a "full-ft" run; make that visible.
        print(f"WARNING: pre-trained checkpoint {fname} not found; "
              "continuing with a randomly initialised model")

max_iter = config['max_iter']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader
trainer = RL_VAEGAN(config)
trainer.cuda()

# Setup output folders
output_directory = opts.output_path + '/output/' + opts.env_name
checkpoint_directory, result_directory = prepare_sub_folder(output_directory)
class ActorCriticAgent():
    """Actor-critic agent with a factored discrete action (movement x turn).

    ``policy_net`` is the network being trained. ``target_net`` is a copy
    used for acting (``run_AC``) and evaluation (``test_model``). After
    construction it only receives weights through ``update_target_net``.
    Two Adam optimizers share the ``fc`` trunk: ``optimizer`` updates the
    actor heads (``head_a_m``, ``head_a_t``) and ``optimizer2`` updates the
    critic head (``head_v``).
    """

    def __init__(self):
        # Use the GPU when one is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.policy_net = ActorCritic().to(device).double()
        self.target_net = ActorCritic().to(device).double()
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.lr = 1e-5
        # Actor optimizer: movement head, turn head and the shared fc trunk.
        self.optimizer = optim.Adam([
            {"params": self.policy_net.head_a_m.parameters()},
            {"params": self.policy_net.head_a_t.parameters()},
            {"params": self.policy_net.fc.parameters()},
        ], lr=self.lr)
        # Critic optimizer: value head and the same shared fc trunk.
        self.optimizer2 = optim.Adam([
            {"params": self.policy_net.head_v.parameters()},
            {"params": self.policy_net.fc.parameters()},
        ], lr=self.lr)
        self.memory = ReplayMemory(100000)

    def preprocess(self, state: RobotState):
        """Wrap ``state.scan`` as a float64 tensor with a leading batch dim of 1."""
        return torch.tensor([state.scan]).double()  # 1x2xseq

    def run_AC(self, tensor_state):
        """Run ``target_net`` without gradients and return its two action heads.

        Returns the first batch row of the movement and turn outputs as
        numpy arrays. The value output is discarded.
        """
        with torch.no_grad():
            self.target_net.eval()
            a_m, a_t, v = self.target_net(tensor_state.to(self.device))
        a_m = a_m.cpu().numpy()[0]  # left, ahead, right
        a_t = a_t.cpu().numpy()[0]  # turn left, stay, right
        return a_m, a_t

    def decode_action(self, a_m, a_t, state, mode):
        """Convert the two head outputs into an ``Action``.

        Args:
            a_m: movement scores for (left, ahead, right).
            a_t: turn scores for (turn left, stay, turn right).
            state: current state; ``state.detect`` sets ``action.shoot``.
            mode: "max_probability" takes the argmax of each head; "sample"
                renormalises each array in place and samples an index from it.

        Raises:
            ValueError: if ``mode`` is not one of the two supported values.
        """
        if mode == "max_probability":
            a_m = np.argmax(a_m)
            a_t = np.argmax(a_t)
        elif mode == "sample":
            a_m /= a_m.sum()
            a_m = np.random.choice(range(3), p=a_m)
            a_t /= a_t.sum()
            a_t = np.random.choice(range(3), p=a_t)
        else:
            # Without this check, the array-vs-int comparisons below fail with
            # an opaque "truth value of an array is ambiguous" ValueError.
            raise ValueError(f"Unknown action decoding mode: {mode!r}")

        action = Action()
        if a_m == 0:  # left
            action.v_n = -1.0
        elif a_m == 1:  # ahead
            action.v_t = +1.0
        elif a_m == 2:  # right
            action.v_n = +1.0

        if a_t == 0:  # left
            action.angular = +1.0
        elif a_t == 1:  # stay
            action.angular = 0.0
        elif a_t == 2:  # right
            action.angular = -1.0

        action.shoot = +1.0 if state.detect else 0.0
        return action

    def select_action(self, state, mode):
        """Preprocess ``state``, query ``target_net`` and decode an ``Action``."""
        tensor_state = self.preprocess(state).to(self.device)
        a_m, a_t = self.run_AC(tensor_state)
        return self.decode_action(a_m, a_t, state, mode)

    def push(self, state, next_state, action, reward):
        """Store one transition in replay memory.

        ``ReplayMemory.push`` takes its arguments in (state, action,
        next_state, reward) order, which differs from this method's order.
        """
        self.memory.push(state, action, next_state, reward)

    def make_state_map(self, state):
        """Concatenate a sequence of state tensors along dim 0 as float64."""
        return torch.cat(state, dim=0).double()  # batchx2xseq

    def _collate_transitions(self, transitions):
        """Stack a list of ``Transition`` records into batched float64 tensors.

        Returns (state_batch, action_batch, reward_batch, next_state_batch).
        Shared by ``sample_memory`` and the offline ``DataLoader``.
        """
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043):
        # a list of Transitions becomes one Transition of tuples.
        batch = Transition(*zip(*transitions))
        state_batch = self.make_state_map(batch.state).double()
        next_state_batch = self.make_state_map(batch.next_state).double()
        action_batch = torch.tensor(batch.action).double()
        reward_batch = torch.tensor(batch.reward).double()
        return state_batch, action_batch, reward_batch, next_state_batch

    def sample_memory(self, is_test=False):
        """Sample BATCH_SIZE transitions from replay memory and collate them."""
        transitions = self.memory.sample(BATCH_SIZE, is_test)
        return self._collate_transitions(transitions)

    @staticmethod
    def _match_value_shape(reward_batch, value_eval):
        """Shape ``reward_batch`` like ``value_eval`` before they are combined.

        Scalar rewards collate into a 1-D tensor. If the critic returns a
        (batch, 1) column, ``reward - value`` would silently broadcast to
        (batch, batch), and the losses would be computed on the wrong
        values. In that case the rewards are unsqueezed to (batch, 1).
        """
        if (reward_batch.dim() == 1 and value_eval.dim() == 2
                and value_eval.shape[1] == 1):
            return reward_batch.unsqueeze(1)
        return reward_batch

    @staticmethod
    def _policy_objective(a_m, a_t, action_batch, td_error):
        """Return mean(log pi(a|s) * advantage) for the factored action.

        Column 0 of ``action_batch`` indexes ``a_m`` and column 1 indexes
        ``a_t``. The advantage is detached, so the critic's estimate acts as
        a constant in this term.
        """
        # gather() needs an index with the same rank as its input, so slice
        # 0:1 / 1:2 (shape (batch, 1)) rather than [:, 0] / [:, 1].
        prob_m = a_m.gather(1, action_batch[:, 0:1].long())
        prob_t = a_t.gather(1, action_batch[:, 1:2].long())
        log_prob = torch.log(prob_m * prob_t + 1e-6)
        # Put the per-sample advantages in the same (batch, 1) column layout,
        # so the product stays elementwise instead of broadcasting.
        advantage = td_error.detach().reshape(log_prob.shape)
        return torch.mean(log_prob * advantage)

    @staticmethod
    def _optimizer_params(optimizer):
        """List every parameter managed by ``optimizer``, in param-group order."""
        return [p for group in optimizer.param_groups for p in group['params']]

    @staticmethod
    def _step_with_grads(optimizer, params, grads):
        """Install precomputed ``grads`` on ``params`` and take one optimizer step."""
        optimizer.zero_grad()
        for param, grad in zip(params, grads):
            if grad is not None:
                param.grad = grad
        optimizer.step()

    def optimize_once(self, data):
        """Run one actor-critic update on a collated batch.

        ``data`` is (state_batch, action_batch, reward_batch,
        next_state_batch). The reward is used directly as the critic target,
        and ``next_state_batch`` is not used. ``optimizer2`` is stepped with
        the MSE critic loss. Then ``optimizer`` is stepped with
        ``-policy_objective + smooth_l1(value, reward)``.

        Returns:
            The actor loss as a Python float.
        """
        state_batch, action_batch, reward_batch, _ = data
        device = self.device
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)

        self.policy_net.train()
        a_m, a_t, value_eval = self.policy_net(state_batch)
        reward_batch = self._match_value_shape(reward_batch, value_eval)

        ### Critic ###
        td_error = reward_batch - value_eval
        critic_loss = nn.MSELoss()(value_eval, reward_batch)

        ### Actor ###
        actor_loss = (-self._policy_objective(a_m, a_t, action_batch, td_error)
                      + F.smooth_l1_loss(value_eval, reward_batch))

        # Both losses backpropagate through the shared fc trunk. Calling
        # optimizer2.step() before the actor backward pass would change
        # weights the retained graph still needs, and autograd raises
        # "modified by an inplace operation". So compute both gradient sets
        # first, then apply the two updates in the original order.
        critic_params = self._optimizer_params(self.optimizer2)
        actor_params = self._optimizer_params(self.optimizer)
        critic_grads = torch.autograd.grad(
            critic_loss, critic_params, retain_graph=True, allow_unused=True)
        actor_grads = torch.autograd.grad(
            actor_loss, actor_params, allow_unused=True)
        self._step_with_grads(self.optimizer2, critic_params, critic_grads)
        self._step_with_grads(self.optimizer, actor_params, actor_grads)
        # Possible gradient clamping, kept disabled as in the original:
        # for param in self.policy_net.parameters():
        #     if param.grad is not None:
        #         param.grad.data.clamp_(-1, 1)
        return actor_loss.item()

    def optimize_online(self):
        """Take one update from a replay sample.

        Returns the loss, or None while memory holds fewer than BATCH_SIZE
        transitions.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        data = self.sample_memory()
        return self.optimize_once(data)

    def test_model(self):
        """Compute the actor loss of ``target_net`` on a sample drawn with is_test=True.

        Returns None while memory holds fewer than BATCH_SIZE transitions.
        Otherwise returns the negated policy objective as a float. No
        parameters are updated.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        state_batch, action_batch, reward_batch, _ = self.sample_memory(True)
        device = self.device
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)
        with torch.no_grad():
            self.target_net.eval()
            a_m, a_t, value_eval = self.target_net(state_batch)
            reward_batch = self._match_value_shape(reward_batch, value_eval)
            ### Critic ###
            td_error = reward_batch - value_eval
            ### Actor ###
            loss = -self._policy_objective(a_m, a_t, action_batch, td_error)
        return loss.item()

    def save_model(self, file_path):
        """Save ``policy_net``'s state_dict to ``file_path``."""
        torch.save(self.policy_net.state_dict(), file_path)

    def save_memory(self, file_path):
        """Serialize the replay memory object to ``file_path`` with torch.save."""
        torch.save(self.memory, file_path)

    def load_model(self, file_path):
        """Load ``policy_net`` weights from ``file_path``, mapped to ``self.device``.

        ``target_net``, which ``select_action`` uses, is not refreshed here.
        Call ``update_target_net`` if acting should use the loaded weights.
        """
        self.policy_net.load_state_dict(
            torch.load(file_path, map_location=self.device))
        # FIXME: at startup, load the previously obtained parameters directly
        # as prior experience
        # self.update_target_net()

    def load_memory(self, file_path):
        """Replace the replay memory with the object stored at ``file_path``.

        NOTE(review): torch.load unpickles arbitrary objects, so only load
        trusted files. Newer PyTorch releases default to weights_only=True,
        which rejects non-tensor objects like this one. Confirm against the
        installed version.
        """
        self.memory = torch.load(file_path)

    def optimize_offline(self, num_epoch):
        """Run ``num_epoch`` shuffled passes over ``memory.main_memory``.

        Returns the actor loss of the last batch, or None when no batch was
        processed (``num_epoch <= 0`` or empty memory).
        """
        dataloader = DataLoader(self.memory.main_memory,
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                collate_fn=self._collate_transitions,
                                num_workers=0,
                                pin_memory=True)
        loss = None
        for _ in range(num_epoch):
            for data in dataloader:
                loss = self.optimize_once(data)
        return loss

    def decay_LR(self, decay):
        """Multiply the learning rate by ``decay`` and apply it to both optimizers.

        The existing Adam optimizers are kept, so their moment estimates
        survive. Only each param group's ``lr`` changes.
        """
        self.lr *= decay
        for optimizer in (self.optimizer, self.optimizer2):
            for group in optimizer.param_groups:
                group['lr'] = self.lr

    def update_target_net(self):
        """Copy ``policy_net``'s current weights into ``target_net``."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
def train(rank, args, shared_model, counter, lock, optimizer=None):
    """Run one A3C worker that rolls out whole episodes and updates ``shared_model``.

    Args:
        rank: worker index; offsets the torch and environment seeds.
        args: run options (reads env_name, seed, lr, gamma, tau,
            entropy_coef, value_loss_coef, max_grad_norm,
            max_episode_length).
        shared_model: model synced into the local copy at each epoch start.
        counter: shared step counter, incremented under ``lock`` per env step.
        lock: lock guarding ``counter``.
        optimizer: optimizer to step after each episode. Defaults to
            Adam(shared_model.parameters(), lr=args.lr).

    Each epoch runs until the episode ends. That includes the forced end at
    ``args.max_episode_length`` steps. Every 1000 epochs the local model is
    checkpointed, and the reward and loss histories are pickled.
    """
    print('Train with A3C')
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name, args)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    output_directory = 'outputs/' + args.env_name
    checkpoint_directory, result_directory = prepare_sub_folder(
        output_directory)
    print(f'checkpoint directory {checkpoint_directory}')
    time.sleep(10)

    state = env.reset()
    state = torch.from_numpy(state)
    done = True

    episode_length = 0
    total_step = 0
    rewards_ep = []
    policy_loss_ep = []
    value_loss_ep = []
    for epoch in range(100000000):
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())

        values = []
        log_probs = []
        rewards = []
        entropies = []

        # Roll out until the episode ends: one epoch is one episode.
        is_Terminal = False
        while not is_Terminal:
            episode_length += 1
            total_step += 1
            value, logit = model(state.unsqueeze(0))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(1, keepdim=True)
            entropies.append(entropy)

            action = prob.multinomial(num_samples=1).detach()
            log_prob = log_prob.gather(1, action)

            state, reward, done, _ = env.step(action.numpy())
            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 1), -1)  # clip reward to [-1, 1]

            with lock:
                counter.value += 1

            if done:
                print(
                    f'epoch {epoch} - steps {total_step} - total rewards {np.sum(rewards) + reward}'
                )
                # Reset to 0 (it was 1) so the next episode's printed step
                # count equals its actual number of steps.
                total_step = 0
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                rewards_ep.append(np.sum(rewards))
                is_Terminal = True

        # Bootstrap value for a truncated rollout. The loop above only exits
        # on done, so R currently always stays zero.
        R = torch.zeros(1, 1)
        if not done:
            value, _ = model(state.unsqueeze(0))
            R = value.detach()
        values.append(R)

        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * values[i + 1] - values[i]
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = (policy_loss - log_probs[i] * gae.detach()
                           - args.entropy_coef * entropies[i])

        optimizer.zero_grad()
        # The optimizer tracks shared_model's parameters (the default one is
        # built over shared_model.parameters()). Clear the local model's
        # gradients too, so this backward pass does not add onto the
        # previous update's gradients.
        model.zero_grad()
        policy_loss_ep.append(policy_loss.detach().numpy()[0, 0])
        value_loss_ep.append(value_loss.detach().numpy()[0, 0])
        (policy_loss + args.value_loss_coef * value_loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)
        optimizer.step()

        if epoch % 1000 == 0:
            # NOTE(review): this saves the local copy synced at the start of
            # the epoch, not shared_model after optimizer.step(); confirm
            # that is intended.
            torch.save({'state_dict': model.state_dict()},
                       f'{checkpoint_directory}/{epoch}.pth.tar')
            histories = (
                ('rewards', rewards_ep),
                ('policy_loss', policy_loss_ep),
                ('value_loss', value_loss_ep),
            )
            for suffix, history in histories:
                with open(f'{result_directory}/{epoch}_{suffix}.pkl',
                          'wb') as f:
                    pickle.dump(history, f)

        # NOTE(review): episode_length is reset to 0 whenever an episode ends,
        # and every epoch ends with an episode, so this break never fires and
        # the final save below is effectively unreachable.
        if episode_length >= 10000000:
            break

    torch.save({
        'state_dict': model.state_dict(),
    }, f'{checkpoint_directory}/Last.pth.tar')