Example #1
File: main.py Project: ColorlessBoy/SAC
def run(args):
    env = gym.make(args.env)

    device = torch.device(args.device)

    # 1. Set the random seeds for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    env.seed(args.seed)

    # 2. Create nets.
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    hidden_sizes = (256, 256)
    ac = ActorCritic(state_size, action_size, hidden_sizes).to(device)
    ac_target = ActorCritic(state_size, action_size, hidden_sizes).to(device)
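    # hard_update (a helper from this project) presumably synchronises the target network with the freshly created online network.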
    hard_update(ac, ac_target)

    # env_sampler = EnvSampler(env, max_episode_step=4000, capacity=1e6)
    env_sampler = EnvSampler2(env, gamma=args.gamma1, capacity=1e6)

    alg = SAC(ac,
              ac_target,
              gamma=args.gamma2,
              alpha=0.2,
              q_lr=1e-3,
              pi_lr=1e-3,
              target_lr=5e-3,
              device=device)

    def get_action(state):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        return ac_target.get_action(state)

    def get_mean_action(state):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        return ac_target.get_action(state, deterministic=True)

    start_time = time()
    for _ in range(args.start_steps):
        env_sampler.addSample()
    print("Warmup uses {}s.".format(time() - start_time))

    for step in range(1, args.total_steps + 1):
        env_sampler.addSample(get_action)

        if step % args.update_every == 0:
            for _ in range(args.update_every):
                batch = env_sampler.sample(args.batch_size)
                losses = alg.update(*batch)

        if step % args.test_every == 0:
            test_reward = env_sampler.test(get_mean_action)
            yield (step, test_reward, *losses)

    torch.save(ac.pi.state_dict(), './env_{}_pi_net.pth.tar'.format(args.env))
Example #2
def global_test(global_model, device, args, model_type, delay=0.03):
    world = args.world
    stage = args.stage
    env = create_env(world, stage)
    state = env.reset().to(device, dtype=torch.float)
    state = state.view(1, 1, 80, 80)
    done = True

    if (model_type == "LSTM"):
        model = ActorCritic_LSTM().to(device)
    else:
        model = ActorCritic().to(device)

    model.eval()
    model.load_state_dict(global_model.state_dict())

    while (True):
        if done:
            h_0 = torch.zeros((1, 512), dtype=torch.float)
            c_0 = torch.zeros((1, 512), dtype=torch.float)
        else:
            h_0 = h_0.detach()
            c_0 = c_0.detach()

        h_0 = h_0.to(device)
        c_0 = c_0.to(device)

        env.render()
        p, _, h_0, c_0 = model(state, h_0, c_0)
        policy = F.softmax(p, dim=1)
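        # Greedy evaluation: pick the most probable action rather than sampling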
        action = torch.argmax(policy)

        next_state, _, done, info = env.step(action.item())

        next_state = (next_state).to(device, dtype=torch.float)
        next_state = next_state.view(1, 1, 80, 80)

        state = next_state
        if (done):
            if (info['flag_get']):
                break
            state = env.reset()
            state = state.to(device)
            state = state.view(1, 1, 80, 80)
            model.load_state_dict(global_model.state_dict())
        time.sleep(delay)
    print('Successfully cleared {}-{}'.format(world, stage))
Example #3
    def __init__(self, nb_actions, learning_rate, gamma, hidden_size,
                 model_input_size, entropy_coeff_start, entropy_coeff_end,
                 entropy_coeff_anneal, continuous):

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.num_actions = nb_actions

        self.gamma = gamma

        self.continuous = continuous

        self.learning_rate = learning_rate

        self.entropy_coefficient_start = entropy_coeff_start
        self.entropy_coefficient_end = entropy_coeff_end
        self.entropy_coefficient_anneal = entropy_coeff_anneal

        self.step_no = 0
        if self.continuous:
            self.model = ActorCriticContinuous(hidden_size=hidden_size,
                                               inputs=model_input_size,
                                               outputs=nb_actions).to(
                                                   self.device)
        else:
            self.model = ActorCritic(hidden_size=hidden_size,
                                     inputs=model_input_size,
                                     outputs=nb_actions).to(self.device)

        self.hidden_size = hidden_size
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.learning_rate)

        self.loss_function = torch.nn.MSELoss()

        self.memory = []

        # Intrinsic Curiosity Module (ICM) for intrinsic exploration rewards
        self.ICM = ICM(model_input_size, nb_actions)
        self.ICM.train()
Example #4
File: worker.py Project: cyzhung/A3C
    def run(self):
        #self.global_model=self.global_model.to(self.device)
        if self.args.model_type == "LSTM":
            self.AC = ActorCritic_LSTM()
        else:
            self.AC = ActorCritic()

        #optimizer_to(self.optimizer, self.device)
        env = create_env(self.world, self.stage)
        state = env.reset()
        #state = state.reshape(1, 1, 80, 80)
        state = state.to(self.device, dtype=torch.float)
        #state = self.imageProcess(state)

        i_epoch = self.epoch

        done = True
        while True:
            if done:
                h_0 = torch.zeros((1, 512), dtype=torch.float)
                c_0 = torch.zeros((1, 512), dtype=torch.float)
            else:
                h_0 = h_0.detach()
                c_0 = c_0.detach()
                
            h_0 = h_0.to(self.device)
            c_0 = c_0.to(self.device)

            Timestamp = 50
            for i in range(Timestamp):
                env.render()

                p, value, h_0, c_0 = self.AC(state, h_0, c_0)

                policy = F.softmax(p, dim=1)
                log_prob = F.log_softmax(p, dim=1)
                entropy = -(policy * log_prob).sum(1, keepdim=True)

                m = Categorical(policy)
                action = m.sample()
                next_state, reward, done, info = env.step(action.item())
                #reward = reward / 15

                #next_state = next_state.view(1, 1, 80, 80)
                next_state = next_state.to(self.device, dtype=torch.float)

                #self.states.append(state)
                self.log_probs.append(log_prob[0, action])
                self.rewards.append(reward)
                self.values.append(value)
                self.entropies.append(entropy)

                state = next_state

                if done:
                    state = env.reset()
                    #state = state.reshape(1, 1, 80, 80)
                    state = state.to(self.device)
                    #state = self.imageProcess(state)
                    break

            """
            actor_loss=0
            critic_loss=0
            returns=[]
            R=0
            for reward in self.rewards[::-1]:
                R=reward+self.GAMMA*R
                returns.insert(0,R)
            """
            #td=torch.tensor([1],dtype=torch.float).to(device)
            
            R = torch.zeros((1, 1), dtype=torch.float)
            if not done:
                _, R, _, _ = self.AC(state, h_0, c_0)

            R = R.to(self.device)
            actor_loss = 0
            critic_loss = 0
            entropy_loss = 0
            advantage = torch.zeros((1, 1), dtype=torch.float).to(self.device)
            next_value = R

            # Walk the rollout backwards, accumulating the discounted advantage,
            # the return R, and the actor/critic/entropy losses.
            for log_prob, reward, value, entropy in list(zip(self.log_probs, self.rewards, self.values, self.entropies))[::-1]:
                advantage = advantage * self.GAMMA
                advantage = advantage + reward + self.GAMMA * next_value.detach() - value.detach()
                next_value = value
                actor_loss = actor_loss + (-log_prob * advantage)
                R = R * self.GAMMA + reward
                critic_loss = critic_loss + (R - value) ** 2 / 2
                entropy_loss = entropy_loss + entropy

            total_loss = actor_loss + critic_loss - 0.01 * entropy_loss

            # push_and_pull (helper defined elsewhere in the repo) presumably backpropagates
            # total_loss, pushes the local gradients to the shared global model, steps the
            # optimizer, and pulls the updated global weights back into the local network.
            push_and_pull(self.optimizer, self.AC, self.global_model, total_loss)

            #for name, parms in self.AC.named_parameters():
            #    print('-->name:', name, '-->requires_grad:', parms.requires_grad, '-->grad_value:', parms.grad)

            if i_epoch % 10 == 0:
                print(self.name + " | Episode %d | Actor loss: %f | Critic loss: %f | Total loss: %f"
                      % (i_epoch, actor_loss.item(), critic_loss.item(), total_loss.item()))
            

            """
            y.append(critic_loss.item())
            x.append(i_epoch)
            plt.plot(x,y) #畫線
            plt.show() #顯示繪製的圖形
            """                    
            i_epoch += 1

            del self.log_probs[:]
            del self.rewards[:]
            del self.values[:]
            del self.entropies[:]

            if self.save and i_epoch % 100 == 0:
                PATH = './model/{}/A3C_{}_{}.pkl'.format(self.level, self.level, self.args.model_type)
                torch.save({
                    'epoch': i_epoch,
                    'model_state_dict': self.global_model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': total_loss,
                    'type': self.args.model_type,
                }, PATH)

            if i_epoch == Max_epoch:
                return
Example #5
def actor_critic(agent_name,
                 multiple_agents=False,
                 load_agent=False,
                 n_episodes=300,
                 max_t=1000,
                 train_mode=True):
    """ Batch processed the states in a single forward pass with a single neural network
    Params
    ======
        multiple_agents (boolean): boolean for multiple agents
        PER (boolean): 
        n_episodes (int): maximum number of training episodes
        max_t (int): maximum number of timesteps per episode
    """
    start = time.time()
    device = get_device()
    env, env_info, states, state_size, action_size, brain_name, num_agents = initialize_env(
        multiple_agents, train_mode)
    states = torch.from_numpy(states).to(device).float()

    NUM_PROCESSES = num_agents

    # Scores is Episode Rewards
    scores = np.zeros(num_agents)
    scores_window = deque(maxlen=100)
    scores_episode = []

    actor_critic = ActorCritic(state_size, action_size, device).to(device)
    agent = A2C_ACKTR(agent_name,
                      actor_critic,
                      value_loss_coef=CRITIC_DISCOUNT,
                      entropy_coef=ENTROPY_BETA,
                      lr=LEARNING_RATE,
                      eps=EPS,
                      alpha=ALPHA,
                      max_grad_norm=MAX_GRAD_NORM,
                      acktr=False,
                      load_agent=load_agent)

    rollouts = SimpleRolloutStorage(NUM_STEPS, NUM_PROCESSES, state_size,
                                    action_size)
    rollouts.to(device)

    num_updates = NUM_ENV_STEPS // NUM_STEPS // NUM_PROCESSES
    # num_updates = NUM_ENV_STEPS // NUM_STEPS

    print("\n## Loaded environment and agent in {} seconds ##\n".format(
        round((time.time() - start), 2)))

    update_start = time.time()
    timesteps = 0
    episode = 0
    if load_agent != False:
        episode = agent.episode
    while True:
        """CAN INSERT LR DECAY HERE"""
        # if episode == MAX_EPISODES:
        #     return scores_episode

        # Adds noise to agents parameters to encourage exploration
        # agent.add_noise(PARAMETER_NOISE)

        for step in range(NUM_STEPS):
            step_start = time.time()

            # Sample actions
            with torch.no_grad():
                values, actions, action_log_probs, _ = agent.act(states)

            clipped_actions = np.clip(actions.cpu().numpy(), *ACTION_BOUNDS)
            env_info = env.step(clipped_actions)[
                brain_name]  # send the clipped actions to the environment
            next_states = env_info.vector_observations  # get the next state
            rewards = env_info.rewards  # get the reward
            rewards_tensor = np.array(env_info.rewards)
            rewards_tensor[rewards_tensor == 0] = NEGATIVE_REWARD
            rewards_tensor = torch.from_numpy(rewards_tensor).to(
                device).float().unsqueeze(1)
            dones = env_info.local_done
            masks = torch.from_numpy(1 - np.array(dones).astype(int)).to(
                device).float().unsqueeze(1)

            rollouts.insert(states, actions, action_log_probs, values,
                            rewards_tensor, masks, masks)

            next_states = torch.from_numpy(next_states).to(device).float()
            states = next_states
            scores += rewards
            # print(rewards)

            if timesteps % 100 == 0:
                print('\rTimestep {}\tScore: {:.2f}\tmin: {:.2f}\tmax: {:.2f}'.
                      format(timesteps, np.mean(scores), np.min(scores),
                             np.max(scores)),
                      end="")

            if np.any(dones):
                score = np.mean(scores)  # episode score, defined before logging/saving
                print(
                    '\rEpisode {}\tScore: {:.2f}\tAverage Score: {:.2f}\tMin Score: {:.2f}\tMax Score: {:.2f}'
                    .format(episode, score, np.mean(scores_window),
                            np.min(scores), np.max(scores)),
                    end="\n")
                update_csv(agent_name, episode, np.mean(scores_window),
                           np.max(scores))

                if episode % 20 == 0:
                    agent.save_agent(agent_name,
                                     score,
                                     episode,
                                     save_history=True)
                else:
                    agent.save_agent(agent_name, score, episode)

                episode += 1
                scores = np.zeros(num_agents)
                break

            timesteps += 1

        with torch.no_grad():
            next_values, _, _, _ = agent.act(next_states)

        rollouts.compute_returns(next_values, USE_GAE, GAMMA, GAE_LAMBDA)
        agent.update(rollouts)

        score = np.mean(scores)
        scores_window.append(score)  # save most recent score
        scores_episode.append(score)

    return scores_episode
Example #6
File: test.py Project: messiest/playground
def test(rank, args, shared_model, counter, device):
    # time.sleep(10.)

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    info_logger = setup_logger('info', log_dir, f'info.log')
    result_logger = setup_logger('results', log_dir, f'results.log')

    # torch.manual_seed(args.seed + rank)

    env = create_atari_environment(args.env_name)
    if args.record:
        if not os.path.exists(f'playback/{args.env_name}/'):
            os.makedirs(f'playback/{args.env_name}/{args.model_id}', exist_ok=True)
        env = gym.wrappers.Monitor(env, f'playback/{args.env_name}/{args.model_id}/', force=True)

    # env.seed(args.seed + rank)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    model = ActorCritic(observation_space, action_space)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True
    episode_length = 0
    actions = deque(maxlen=4000)
    start_time = time.time()
    for episode in count():
        episode_length += 1
        # shared model sync
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = torch.zeros(1, 512)
            hx = torch.zeros(1, 512)

        else:
            cx = cx.data
            hx = hx.data

        with torch.no_grad():
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))

        prob = F.softmax(logit, dim=-1)
        action = prob.max(-1, keepdim=True)[1]

        state, reward, done, info = env.step(action.item())

        reward_sum += reward

        info_log = {
            'id': args.model_id,
            'algorithm': args.algorithm,
            'greedy-eps': args.greedy_eps,
            'episode': episode,
            'total_episodes': counter.value,
            'episode_length': episode_length,
            'reward': reward_sum,
            'done': done,
        }
        info_logger.info(info_log)

        print(f"{emojize(':video_game:', use_aliases=True)} | ", end='\r')

        env.render()

        actions.append(action.item())

        if done:
            t = time.time() - start_time

            print(
                f"{emojize(':video_game:', use_aliases=True)} | " + \
                f"ID: {args.model_id}, " + \
                f"Total Episodes: {counter.value}, " + \
                f"Time: {time.strftime('%H:%M:%S', time.gmtime(t)):^9s}, " + \
                f"FPS: {episode_length/t: 6.2f}, " + \
                f"Reward: {reward_sum: 10.0f}",
                end='\r',
                flush=True,
            )

            result_logger.info(info_log)

            reward_sum = 0
            episode_length = 0
            actions.clear()
            time.sleep(args.reset_delay)
            state = env.reset()

        state = torch.from_numpy(state)
Example #7
from functools import partial
import torch
from torch import autograd, optim
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from tqdm import tqdm
from env import Env
from hyperparams import BACKTRACK_COEFF, BACKTRACK_ITERS, ON_POLICY_BATCH_SIZE as BATCH_SIZE, CONJUGATE_GRADIENT_ITERS, DAMPING_COEFF, DISCOUNT, HIDDEN_SIZE, KL_LIMIT, LEARNING_RATE, MAX_STEPS, TRACE_DECAY
from models import ActorCritic
from utils import plot

env = Env()
agent = ActorCritic(env.observation_space.shape[0], env.action_space.shape[0],
                    HIDDEN_SIZE)
critic_optimiser = optim.Adam(agent.critic.parameters(), lr=LEARNING_RATE)


def hessian_vector_product(d_kl, x):
    g = parameters_to_vector(
        autograd.grad(d_kl, agent.actor.parameters(), create_graph=True))
    return parameters_to_vector(
        autograd.grad((g * x.detach()).sum(),
                      agent.actor.parameters(),
                      retain_graph=True)) + DAMPING_COEFF * x


def conjugate_gradient(Ax, b):
    x = torch.zeros_like(b)
    r = b - Ax(x)  # Residual
    p = r  # Conjugate vector
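    # (The listing is truncated here. A standard conjugate-gradient completion is
    #  sketched below as an assumption; the small epsilon is an added safeguard,
    #  not taken from the project.)
    r_dot_old = torch.dot(r, r)
    for _ in range(CONJUGATE_GRADIENT_ITERS):
        Ap = Ax(p)
        alpha = r_dot_old / (torch.dot(p, Ap) + 1e-8)
        x += alpha * p
        r -= alpha * Ap
        r_dot_new = torch.dot(r, r)
        p = r + (r_dot_new / r_dot_old) * p
        r_dot_old = r_dot_new
    return x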
Example #8
def train(rank, args, shared_model, opt_ac, can_save, shared_obs_stats):
    best_result = -1000
    torch.manual_seed(args.seed + rank)
    torch.set_default_tensor_type('torch.DoubleTensor')
    num_inputs = args.feature
    num_actions = 9
    last_state = [1] * 48

    if args.render and can_save:
        env = RunEnv(visualize=True)
    else:
        env = RunEnv(visualize=False)

    #running_state = ZFilter((num_inputs,), clip=5)
    #running_reward = ZFilter((1,), demean=False, clip=10)
    episode_lengths = []

    PATH_TO_MODEL = '../models/' + str(args.bh)

    ac_net = ActorCritic(num_inputs, num_actions)

    start_time = time.time()

    for i_episode in count(1):
        memory = Memory()
        ac_net.load_state_dict(shared_model.state_dict())
        ac_net.zero_grad()

        num_steps = 0
        reward_batch = 0
        num_episodes = 0
        #Tot_loss = 0
        #Tot_num =
        while num_steps < args.batch_size:
            #state = env.reset()
            #print(num_steps)
            state = env.reset(difficulty=0)
            last_state = process_observation(state)
            state = process_observation(state)
            last_state, state = transform_observation(last_state, state)

            state = numpy.array(state)
            #global last_state
            #last_state,_ = update_observation(last_state,state)
            #last_state,state = update_observation(last_state,state)
            #print(state.shape[0])
            #print(state[41])
            state = Variable(torch.Tensor(state).unsqueeze(0))
            shared_obs_stats.observes(state)
            state = shared_obs_stats.normalize(state)
            state = state.data[0].numpy()
            #state = running_state(state)

            reward_sum = 0
            #timer = time.time()
            for t in range(10000):  # Don't infinite loop while learning
                #print(t)
                if args.use_sep_pol_val:
                    action = select_action(state)
                else:
                    action = select_action_actor_critic(state, ac_net)
                #print(action)
                action = action.data[0].numpy()
                if numpy.any(numpy.isnan(action)):
                    print(state)
                    print(action)
                    print('ERROR')
                    raise RuntimeError('action NaN problem')
                #print(action)
                #print("------------------------")
                #timer = time.time()

                BB = numpy.append(action, action)
                #print(BB)

                reward = 0
                if args.skip:
                    #env.step(action)
                    _, A, _, _ = env.step(BB)
                    reward += A
                    _, A, _, _ = env.step(BB)
                    reward += A

                next_state, A, done, _ = env.step(BB)
                reward += A
                next_state = process_observation(next_state)
                last_state, next_state = transform_observation(
                    last_state, next_state)

                next_state = numpy.array(next_state)
                reward_sum += reward
                #print('env:')
                #print(time.time()-timer)

                #last_state ,next_state = update_observation(last_state,next_state)
                #next_state = running_state(next_state)
                next_state = Variable(torch.Tensor(next_state).unsqueeze(0))
                shared_obs_stats.observes(next_state)
                next_state = shared_obs_stats.normalize(next_state)
                next_state = next_state.data[0].numpy()
                #print(next_state[41:82])

                mask = 1
                if done:
                    mask = 0

                memory.push(state, np.array([action]), mask, next_state,
                            reward)

                #if args.render:
                #    env.render()
                if done:
                    break

                state = next_state
            num_steps += (t - 1)
            num_episodes += 1

            reward_batch += reward_sum

        reward_batch /= num_episodes
        batch = memory.sample()

        #print('env:')
        #print(time.time()-timer)

        #timer = time.time()
        update_params_actor_critic(batch, args, shared_model, ac_net, opt_ac)
        #print('backpropagate:')
        #print(time.time()-timer)

        epoch = i_episode
        if (i_episode % args.log_interval == 0) and (rank == 0):

            print('TrainEpisode {}\tLast reward: {}\tAverage reward {:.2f}'.
                  format(i_episode, reward_sum, reward_batch))
            if reward_batch > best_result:
                best_result = reward_batch
                save_model(
                    {
                        'epoch': epoch,
                        'bh': args.bh,
                        'state_dict': ac_net.state_dict(),
                        'optimizer': opt_ac,
                        'obs': shared_obs_stats,
                    }, PATH_TO_MODEL, 'best')

            if epoch % 30 == 1:
                save_model(
                    {
                        'epoch': epoch,
                        'bh': args.bh,
                        'state_dict': ac_net.state_dict(),
                        'optimizer': opt_ac,
                        'obs': shared_obs_stats,
                    }, PATH_TO_MODEL, epoch)
Example #9
from functools import partial
import torch
from torch import autograd, optim
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from tqdm import tqdm
from env import Env
from hyperparams import BACKTRACK_COEFF, BACKTRACK_ITERS, ON_POLICY_BATCH_SIZE as BATCH_SIZE, CONJUGATE_GRADIENT_ITERS, DAMPING_COEFF, DISCOUNT, HIDDEN_SIZE, KL_LIMIT, LEARNING_RATE, MAX_STEPS, TRACE_DECAY
from models import ActorCritic
from utils import plot

env = Env()
agent = ActorCritic(HIDDEN_SIZE)
critic_optimiser = optim.Adam(agent.critic.parameters(), lr=LEARNING_RATE)


def hessian_vector_product(d_kl, x):
    g = parameters_to_vector(
        autograd.grad(d_kl, agent.actor.parameters(), create_graph=True))
    return parameters_to_vector(
        autograd.grad((g * x.detach()).sum(),
                      agent.actor.parameters(),
                      retain_graph=True)) + DAMPING_COEFF * x


def conjugate_gradient(Ax, b):
    x = torch.zeros_like(b)
    r = b - Ax(x)  # Residual
    p = r  # Conjugate vector
    r_dot_old = torch.dot(r, r)
Example #10
import torch
from torch import optim
from tqdm import tqdm
from env import Env
from hyperparams import ON_POLICY_BATCH_SIZE as BATCH_SIZE, DISCOUNT, HIDDEN_SIZE, INITIAL_POLICY_LOG_STD_DEV, LEARNING_RATE, MAX_STEPS, TRACE_DECAY, VALUE_EPOCHS
from models import ActorCritic
from utils import plot

env = Env()
agent = ActorCritic(env.observation_space.shape[0],
                    env.action_space.shape[0],
                    HIDDEN_SIZE,
                    initial_policy_log_std_dev=INITIAL_POLICY_LOG_STD_DEV)
actor_optimiser = optim.Adam(agent.actor.parameters(), lr=LEARNING_RATE)
critic_optimiser = optim.Adam(agent.critic.parameters(), lr=LEARNING_RATE)

state, done, total_reward, D = env.reset(), False, 0, []
pbar = tqdm(range(1, MAX_STEPS + 1), unit_scale=1, smoothing=0)
for step in pbar:
    # Collect set of trajectories D by running policy π in the environment
    policy, value = agent(state)
    action = policy.sample()
    log_prob_action = policy.log_prob(action)
    next_state, reward, done = env.step(action)
    total_reward += reward
    D.append({
        'state': state,
        'action': action,
        'reward': torch.tensor([reward]),
        'done': torch.tensor([done], dtype=torch.float32),
        'log_prob_action': log_prob_action,
Example #11
def play(args):
    env = create_mario_env(args.env_name, ACTIONS[args.move_set])

    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    model = ActorCritic(observation_space, action_space)

    checkpoint_file = \
        f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
    checkpoint = restore_checkpoint(checkpoint_file)
    assert args.env_name == checkpoint['env'], \
        f"This checkpoint is for a different environment: {checkpoint['env']}"
    args.model_id = checkpoint['id']

    print(f"Environment: {args.env_name}")
    print(f"      Agent: {args.model_id}")
    model.load_state_dict(checkpoint['model_state_dict'])

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True
    episode_length = 0
    start_time = time.time()
    for step in count():
        episode_length += 1

        # shared model sync
        if done:
            cx = torch.zeros(1, 512)
            hx = torch.zeros(1, 512)

        else:
            cx = cx.data
            hx = hx.data

        with torch.no_grad():
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))

        prob = F.softmax(logit, dim=-1)
        action = prob.max(-1, keepdim=True)[1]

        action_idx = action.item()
        action_out = ACTIONS[args.move_set][action_idx]
        state, reward, done, info = env.step(action_idx)
        reward_sum += reward

        print(
            f"{emojize(':mushroom:')} World {info['world']}-{info['stage']} | {emojize(':video_game:')}: [ {' + '.join(action_out):^13s} ] | ",
            end='\r',
        )

        env.render()

        if done:
            t = time.time() - start_time

            print(
                f"{emojize(':mushroom:')} World {info['world']}-{info['stage']} |" + \
                f" {emojize(':video_game:')}: [ {' + '.join(action_out):^13s} ] | " + \
                f"ID: {args.model_id}, " + \
                f"Time: {time.strftime('%H:%M:%S', time.gmtime(t)):^9s}, " + \
                f"Reward: {reward_sum: 10.2f}, " + \
                f"Progress: {(info['x_pos'] / 3225) * 100: 3.2f}%",
                end='\r',
                flush=True,
            )

            reward_sum = 0
            episode_length = 0
            time.sleep(args.reset_delay)
            state = env.reset()

        state = torch.from_numpy(state)
Example #12
                    default=5,
                    metavar='IE',
                    help='Imitation learning epochs')
parser.add_argument('--imitation-replay-size',
                    type=int,
                    default=1,
                    metavar='IRS',
                    help='Imitation learning trajectory replay size')
args = parser.parse_args()
torch.manual_seed(args.seed)
os.makedirs('results', exist_ok=True)

# Set up environment and models
env = CartPoleEnv()
env.seed(args.seed)
agent = ActorCritic(env.observation_space.shape[0], env.action_space.n,
                    args.hidden_size)
agent_optimiser = optim.RMSprop(agent.parameters(), lr=args.learning_rate)
if args.imitation:
    # Set up expert trajectories dataset
    expert_trajectories = torch.load('expert_trajectories.pth')
    expert_trajectories = {
        k: torch.cat([trajectory[k] for trajectory in expert_trajectories],
                     dim=0)
        for k in expert_trajectories[0].keys()
    }  # Flatten expert trajectories
    expert_trajectories = TransitionDataset(expert_trajectories)
    # Set up discriminator
    if args.imitation in ['AIRL', 'GAIL']:
        if args.imitation == 'AIRL':
            discriminator = AIRLDiscriminator(env.observation_space.shape[0],
                                              env.action_space.n,
Example #13
def main(args):
    print(f" Session ID: {args.uuid}")

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    args_logger = setup_logger('args', log_dir, f'args.log')
    env_logger = setup_logger('env', log_dir, f'env.log')

    if args.debug:
        debug.packages()
    os.environ['OMP_NUM_THREADS'] = "1"
    if torch.cuda.is_available():
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        devices = ",".join([str(i) for i in range(torch.cuda.device_count())])
        os.environ["CUDA_VISIBLE_DEVICES"] = devices

    args_logger.info(vars(args))
    env_logger.info(vars(os.environ))

    env = create_atari_environment(args.env_name)

    shared_model = ActorCritic(env.observation_space.shape[0],
                               env.action_space.n)

    if torch.cuda.is_available():
        shared_model = shared_model.cuda()

    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

    if args.load_model:  # TODO Load model before initializing optimizer
        checkpoint_file = f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
        checkpoint = restore_checkpoint(checkpoint_file)
        assert args.env_name == checkpoint['env'], \
            "Checkpoint is for different environment"
        args.model_id = checkpoint['id']
        args.start_step = checkpoint['step']
        print("Loading model from checkpoint...")
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")
        print(f"      Start: Step {args.start_step}")
        shared_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    else:
        print(f"Environment: {args.env_name}")
        print(f"      Agent: {args.model_id}")

    torch.manual_seed(args.seed)

    print(
        FontColor.BLUE + \
        f"CPUs:    {mp.cpu_count(): 3d} | " + \
        f"GPUs: {None if not torch.cuda.is_available() else torch.cuda.device_count()}" + \
        FontColor.END
    )

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    # Queue training processes
    num_processes = args.num_processes
    no_sample = args.non_sample  # count of non-sampling processes

    if args.num_processes > 1:
        num_processes = args.num_processes - 1

    samplers = num_processes - no_sample

    for rank in range(0, num_processes):
        device = 'cpu'
        if torch.cuda.is_available():
            device = 0  # TODO: Need to move to distributed to handle multigpu
        if rank < samplers:  # random action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device),
            )
        else:  # best action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device, False),
            )
        p.start()
        time.sleep(1.)
        processes.append(p)

    # Queue test process
    p = mp.Process(target=test,
                   args=(args.num_processes, args, shared_model, counter, 0))

    p.start()
    processes.append(p)

    for p in processes:
        p.join()
Example #14
def train(rank,
          args,
          shared_model,
          counter,
          lock,
          optimizer=None,
          device='cpu',
          select_sample=True):
    # torch.manual_seed(args.seed + rank)

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    loss_logger = setup_logger('loss', log_dir, f'loss.log')
    # action_logger = setup_logger('actions', log_dir, f'actions.log')

    text_color = FontColor.RED if select_sample else FontColor.GREEN
    print(
        text_color +
        f"Process: {rank: 3d} | {'Sampling' if select_sample else 'Decision'} | Device: {str(device).upper()}",
        FontColor.END)

    env = create_atari_environment(args.env_name)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    # env.seed(args.seed + rank)

    model = ActorCritic(observation_space, action_space)
    if torch.cuda.is_available():
        model = model.cuda()
        model.device = device

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    done = True

    for t in count(start=args.start_step):
        if t % args.save_interval == 0 and t > 0:
            save_checkpoint(shared_model, optimizer, args, t)

        # Sync shared model
        model.load_state_dict(shared_model.state_dict())

        if done:
            cx = torch.zeros(1, 512)
            hx = torch.zeros(1, 512)
        else:
            cx = cx.detach()
            hx = hx.detach()

        values = []
        log_probs = []
        rewards = []
        entropies = []

        episode_length = 0
        for step in range(args.num_steps):
            episode_length += 1

            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))

            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(-1, keepdim=True)
            entropies.append(entropy)

            reason = ''

            if select_sample:
                rand = random.random()
                epsilon = get_epsilon(t)
                if rand < epsilon and args.greedy_eps:
                    action = torch.randint(0, action_space, (1, 1))
                    reason = 'uniform'

                else:
                    action = prob.multinomial(1)
                    reason = 'multinomial'

            else:
                action = prob.max(-1, keepdim=True)[1]
                reason = 'choice'

            # action_logger.info({
            #     'rank': rank,
            #     'action': action.item(),
            #     'reason': reason,
            #     })

            if torch.cuda.is_available():
                action = action.cuda()
                value = value.cuda()

            log_prob = log_prob.gather(-1, action)

            # action_out = ACTIONS[args.move_set][action.item()]

            state, reward, done, info = env.step(action.item())

            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 50), -50)  # h/t @ArvindSoma

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
            R = value.data

        values.append(R)

        loss = gae(R, rewards, values, log_probs, entropies, args)

        loss_logger.info({
            'episode': t,
            'rank': rank,
            'sampling': select_sample,
            'loss': loss.item()
        })

        optimizer.zero_grad()

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)

        optimizer.step()
Example #15
def main():
    order_book_id_number = 10
    toy_data = create_toy_data(order_book_ids_number=order_book_id_number,
                               feature_number=20,
                               start="2019-05-01",
                               end="2019-12-12",
                               frequency="D")
    env = PortfolioTradingGym(data_df=toy_data,
                              sequence_window=5,
                              add_cash=True)
    env = Numpy(env)
    env = ch.envs.Logger(env, interval=1000)
    env = ch.envs.Torch(env)
    env = ch.envs.Runner(env)

    # create net
    action_size = env.action_space.shape[0]
    number_asset, seq_window, features_number = env.observation_space.shape
    input_size = features_number

    agent = ActorCritic(input_size=input_size,
                        hidden_size=HIDDEN_SIZE,
                        action_size=action_size)
    actor_optimiser = optim.Adam(agent.actor.parameters(), lr=LEARNING_RATE)
    critic_optimiser = optim.Adam(agent.critic.parameters(), lr=LEARNING_RATE)

    replay = ch.ExperienceReplay()

    for step in range(1, MAX_STEPS + 1):
        replay += env.run(agent, episodes=1)

        if len(replay) >= BATCH_SIZE:
            with torch.no_grad():
                advantages = pg.generalized_advantage(DISCOUNT, TRACE_DECAY,
                                                      replay.reward(),
                                                      replay.done(),
                                                      replay.value(),
                                                      torch.zeros(1))
                advantages = ch.normalize(advantages, epsilon=1e-8)
                returns = td.discount(DISCOUNT, replay.reward(), replay.done())
                old_log_probs = replay.log_prob()

            # initialised here for readability; recomputed after the first PPO epoch
            new_values = replay.value()
            new_log_probs = replay.log_prob()
            for epoch in range(PPO_EPOCHS):
                # Recalculate outputs for subsequent iterations
                if epoch > 0:
                    _, infos = agent(replay.state())
                    masses = infos['mass']
                    new_values = infos['value']
                    new_log_probs = masses.log_prob(
                        replay.action()).unsqueeze(-1)

                # Update the policy by maximising the PPO-Clip objective
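                # (PPO-Clip objective: maximise E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)],
                #  with r_t = exp(new_log_prob - old_log_prob); policy_loss presumably
                #  returns the negated surrogate so it can be minimised by gradient descent.)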
                policy_loss = ch.algorithms.ppo.policy_loss(
                    new_log_probs,
                    old_log_probs,
                    advantages,
                    clip=PPO_CLIP_RATIO)
                actor_optimiser.zero_grad()
                policy_loss.backward()
                actor_optimiser.step()

                # Fit value function by regression on mean-squared error
                value_loss = ch.algorithms.a2c.state_value_loss(
                    new_values, returns)
                critic_optimiser.zero_grad()
                value_loss.backward()
                critic_optimiser.step()

            replay.empty()
Example #16
def sac(args):
    #set seed if non default is entered
    if args.seed != -1:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    env, test_env = TorchEnv(args.env_name, args.max_ep_len), TorchEnv(
        args.env_name, args.max_ep_len)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_limit = env.action_space.high[0]
    # Create actor-critic module and target networks
    ac = ActorCritic(state_dim,
                     action_dim,
                     action_limit,
                     args.hidden_size,
                     args.gamma,
                     args.alpha,
                     device=args.device)
    ac_targ = deepcopy(ac)
    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False
    # List of parameters for both Q-networks (save this for convenience)
    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())
    # Experience buffer
    buffer = Buffer(state_dim,
                    action_dim,
                    buffer_size=args.buffer_size,
                    device=args.device)
    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=args.lr)
    q_optimizer = Adam(q_params, lr=args.lr)

    def update(data):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q = ac.compute_loss_q(data, ac_targ)
        loss_q.backward()
        q_optimizer.step()

        # Freeze Q-networks so you don't waste computational effort computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi = ac.compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize it at next DDPG step.
        for p in q_params:
            p.requires_grad = True

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use an in-place operations "mul_", "add_" to update target params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(args.polyak)
                p_targ.data.add_((1 - args.polyak) * p.data)

    def test_agent(deterministic=True):
        for j in range(args.num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == args.max_ep_len)):
                # Take deterministic actions at test time
                o, r, d = test_env.step(
                    ac.act(
                        torch.as_tensor(o,
                                        dtype=torch.float32).to(args.device),
                        deterministic))
                ep_ret += r
                ep_len += 1

    # Prepare for interaction with environment
    total_steps = args.steps_per_epoch * args.epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        # Until start_steps have elapsed, randomly sample actions from a uniform distribution for better exploration. Afterwards, use the learned policy.
        if t > args.start_steps:
            a = ac.act(torch.as_tensor(o, dtype=torch.float32).to(args.device))
        else:
            a = env.action_space.sample()
        # Step the env
        o2, r, d = env.step(a)
        if args.render_env:
            env.render()
        ep_ret += r
        ep_len += 1
        # Ignore the "done" signal if it comes from hitting the time horizon (that is, when it's an artificial terminal signal that isn't based on the agent's state)
        d = False if ep_len == args.max_ep_len else d
        # Store experience to replay buffer
        buffer.add(o, a, r, o2, d)
        o = o2
        # End of trajectory handling
        if d or (ep_len == args.max_ep_len):
            print("EPISODE REWARD: ", ep_ret)
            o, ep_ret, ep_len = env.reset(), 0, 0

        # Update handling
        if t >= args.update_after and t % args.update_every == 0:
            batch_generator = buffer.get_train_batches(args.batch_size)
            for j in range(args.update_every):
                #my_batch = my_buffer.get_train_batches(args.batch_size).__next__()
                try:
                    batch = next(batch_generator)
                except StopIteration:
                    # Generator exhausted: draw a fresh set of batches from the buffer
                    batch_generator = buffer.get_train_batches(args.batch_size)
                    batch = next(batch_generator)
                update(batch)

        # End of epoch handling
        if (t + 1) % args.steps_per_epoch == 0:
            epoch = (t + 1) // args.steps_per_epoch
            # Test the performance of the deterministic version of the agent.
            test_agent()
Example #17
config = pyglet.gl.Config(double_buffer=True)
window = pyglet.window.Window(config=config)

board_size = 40
offset = 20
width = window.width - offset
height = window.height - offset

board_unit = min(width // board_size, height // board_size)
x1_board = window.width // 2 - (board_size // 2) * board_unit
x2_board = x1_board + board_size * board_unit
y1_board = window.height // 2 - (board_size // 2) * board_unit
y2_board = y1_board + board_size * board_unit

env = Snake(2, board_size=board_size, terminal_step=100)
model = ActorCritic().eval()
model.load_state_dict(torch.load("weights.pt"))
state, invalid = env.reset()
body_coord = torch.nonzero(state[0, 3])
ids = []
dist, val = model(state[0:1], invalid[0:1])
ai_action = dist.sample().item()
q_values = dist.probs.tolist()[0]
batch = pyglet.graphics.Batch()


def draw_board(batch):
    batch.add(2 * 4, pyglet.gl.GL_LINES, None,
              ("v2i",
               (x1_board, y1_board, x2_board, y1_board, x2_board, y1_board,
                x2_board, y2_board, x2_board, y2_board, x1_board, y2_board,
Example #18
board_size = 40
offset = 20
width = window.width - offset
height = window.height - offset

board_unit = min(width // board_size, height // board_size)
x1_board = window.width // 2 - (board_size // 2 + 1) * board_unit
x2_board = x1_board + (board_size + 1) * board_unit
y1_board = window.height // 2 - (board_size // 2 + 1) * board_unit
y2_board = y1_board + (board_size + 1) * board_unit

print(x1_board, x2_board, y1_board, y2_board)

env = TrainEnvSingle()
game = env.game
model = ActorCritic()
model.load_state_dict(torch.load("weights.pt"))
model = model.eval()
state, invalid = env.reset()
dist, value = model(state, invalid)
q_values = dist.probs.tolist()[0]


def take_action(dt):
    pass


def reload_model(dt):
    global model
    model.load_state_dict(torch.load("weights.pt"))
    print("Reloaded model")
Example #19
def main(args):
    # create environment 
    env = gym.make(args.env)
    env.seed(args.seed)
    obs_dim = env.observation_space.shape[0]
    if isinstance(env.action_space, Discrete):
        discrete = True
        act_dim = env.action_space.n
    else:
        discrete = False
        act_dim = env.action_space.shape[0]

    # actor critic 
    ac = ActorCritic(obs_dim, act_dim, discrete).to(args.device)
    print('Number of parameters', count_vars(ac))

    # Set up experience buffer
    steps_per_epoch = int(args.steps_per_epoch)
    buf = PGBuffer(obs_dim, act_dim, discrete, steps_per_epoch, args)
    logs = defaultdict(lambda: [])
    writer = SummaryWriter(args_to_str(args))
    gif_frames = []

    # Set up function for computing policy loss
    def compute_loss_pi(batch):
        obs, act, psi, logp_old = batch['obs'], batch['act'], batch['psi'], batch['logp']
        pi, logp = ac.pi(obs, act)

        # Policy loss
        if args.loss_mode == 'vpg':
            # TODO (Task 2): implement vanilla policy gradient loss
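            # A minimal sketch (not the course's reference solution): REINFORCE-style
            # loss with psi as the per-step weight (advantage / return-to-go).
            loss_pi = -(logp * psi).mean()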
        elif args.loss_mode == 'ppo':
            # TODO (Task 4): implement clipped PPO loss
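            # A minimal sketch of the clipped surrogate loss; the 0.2 clip range is an
            # assumed value (the actual assignment may read it from args instead).
            ratio = torch.exp(logp - logp_old)
            clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2) * psi
            loss_pi = -torch.min(ratio * psi, clipped).mean()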
        else:
            raise Exception('Invalid loss_mode option', args.loss_mode)

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        pi_info = dict(kl=approx_kl, ent=ent)

        return loss_pi, pi_info

    # Set up function for computing value loss
    def compute_loss_v(batch):
        obs, ret = batch['obs'], batch['ret']
        v = ac.v(obs)
        # TODO: (Task 2): compute value function loss
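        # A minimal sketch: mean-squared error between predicted values and empirical returns.
        loss_v = ((v - ret) ** 2).mean()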
        return loss_v

    # Set up optimizers for policy and value function
    pi_optimizer = Adam(ac.pi.parameters(), lr=args.pi_lr)
    vf_optimizer = Adam(ac.v.parameters(), lr=args.v_lr)

    # Set up update function
    def update():
        batch = buf.get()

        # Get loss and info values before update
        pi_l_old, pi_info_old = compute_loss_pi(batch)
        pi_l_old = pi_l_old.item()
        v_l_old = compute_loss_v(batch).item()

        # Policy learning
        for i in range(args.train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, pi_info = compute_loss_pi(batch)
            loss_pi.backward()
            pi_optimizer.step()

        # Value function learning
        for i in range(args.train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(batch)
            loss_v.backward()
            vf_optimizer.step()

        # Log changes from update
        kl, ent = pi_info['kl'], pi_info_old['ent']
        logs['kl'] += [kl]
        logs['ent'] += [ent]
        logs['loss_v'] += [loss_v.item()]
        logs['loss_pi'] += [loss_pi.item()]

    # Prepare for interaction with environment
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    ep_count = 0  # just for logging purpose, number of episodes run
    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(args.epochs):
        for t in range(steps_per_epoch):
            a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32).to(args.device))

            next_o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            # save and log
            buf.store(o, a, r, v, logp)
            if ep_count % 100 == 0:
                frame = env.render(mode='rgb_array')
                # uncomment this line if you want to log to tensorboard (can be memory intensive)
                #gif_frames.append(frame)
                #gif_frames.append(PIL.Image.fromarray(frame).resize([64,64]))  # you can try this downsize version if you are resource constrained
                time.sleep(0.01)
            
            # Update obs (critical!)
            o = next_o

            timeout = ep_len == args.max_ep_len
            terminal = d or timeout
            epoch_ended = t==steps_per_epoch-1

            if terminal or epoch_ended:
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32).to(args.device))
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logs['ep_ret'] += [ep_ret]
                    logs['ep_len'] += [ep_len]
                    ep_count += 1

                o, ep_ret, ep_len = env.reset(), 0, 0

                # save a video to tensorboard so you can view later
                if len(gif_frames) != 0:
                    vid = np.stack(gif_frames)
                    vid_tensor = vid.transpose(0,3,1,2)[None]
                    writer.add_video('rollout', vid_tensor, epoch, fps=50)
                    gif_frames = []
                    writer.flush()
                    print('wrote video')

        # Perform VPG update!
        update()

        if epoch % 10 == 0:
            vals = {key: np.mean(val) for key, val in logs.items()}
            for key in vals:
                writer.add_scalar(key, vals[key], epoch)
            writer.flush()
            print('Epoch', epoch, vals)
            logs = defaultdict(lambda: [])
Example #20
                    action='store_true',
                    help='force two leg together')
parser.add_argument('--start-epoch', type=int, default=0, help='start-epoch')

if __name__ == '__main__':
    args = parser.parse_args()
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.manual_seed(args.seed)

    num_inputs = args.feature
    num_actions = 18

    traffic_light = TrafficLight()
    counter = Counter()

    ac_net = ActorCritic(num_inputs, num_actions)
    opt_ac = optim.Adam(ac_net.parameters(), lr=args.lr)

    shared_grad_buffers = Shared_grad_buffers(ac_net)
    shared_obs_stats = Shared_obs_stats(num_inputs)

    if args.resume:
        print("=> loading checkpoint ")
        checkpoint = torch.load('../../7.87.t7')
        #checkpoint = torch.load('../../best.t7')
        args.start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        ac_net.load_state_dict(checkpoint['state_dict'])
        opt_ac.load_state_dict(checkpoint['optimizer'])
        opt_ac.state = defaultdict(dict, opt_ac.state)
        #print(opt_ac)
Example #21

num_inputs = envs.observation_space
num_outputs = envs.action_space

# Hyper-parameters
NB_STEP = 128
UPDATE_EPOCH = 10
MINI_BATCH_SIZE = 512
SIZES = [64]
GAMMA = 0.99
LAMBDA = 0.95
EPSILON = 0.2
REWARD_THRESHOLD = 190

model = ActorCritic(num_inputs, num_outputs, SIZES)

frame_idx = 0
test_rewards = []
#env_render = False

state = envs.reset()
early_stop = False
PATH = "saved_models/model_ppo_pendulum.pt"

while not early_stop:

    log_probs = []
    values = []
    states = []
    actions = []
Example #22
def train(rank, args, traffic_light, counter, shared_model,
          shared_grad_buffers, shared_obs_stats, opt_ac):
    best_result = -1000
    torch.manual_seed(args.seed + rank)
    torch.set_default_tensor_type('torch.DoubleTensor')
    num_inputs = args.feature
    num_actions = 9
    last_state = [0] * 41
    last_v = [0] * 10
    #last_state = numpy.zeros(48)

    env = RunEnv(visualize=False)

    #running_state = ZFilter((num_inputs,), clip=5)
    #running_reward = ZFilter((1,), demean=False, clip=10)
    episode_lengths = []

    PATH_TO_MODEL = '../models/' + str(args.bh)

    ac_net = ActorCritic(num_inputs, num_actions)

    #running_state = ZFilter((num_inputs,), clip=5)

    start_time = time.time()

    for i_episode in range(args.start_epoch + 1, 999999):
        #print(shared_obs_stats.n[0])
        #print('hei')
        #if rank == 0:
        #    print(running_state.rs._n)

        signal_init = traffic_light.get()
        memory = Memory()
        ac_net.load_state_dict(shared_model.state_dict())

        num_steps = 0
        reward_batch = 0
        num_episodes = 0
        #Tot_loss = 0
        #Tot_num =
        while num_steps < args.batch_size:
            #state = env.reset()
            #print(num_steps)
            state = env.reset(difficulty=0)
            #state = numpy.array(state)

            last_state, last_v, state = process_observation(
                last_state, last_v, state)

            state = numpy.array(state)

            #state = running_state(state)

            state = Variable(torch.Tensor(state).unsqueeze(0))
            shared_obs_stats.observes(state)
            state = shared_obs_stats.normalize(state)
            state = state.data[0].numpy()

            #print(state)
            #return

            #print(AA)

            #print(type(AA))
            #print(type(state))
            #print(AA.shape)
            #print(state.shape)

            reward_sum = 0
            #timer = time.time()
            for t in range(10000):  # Don't infinite loop while learning
                #print(t)
                if args.use_sep_pol_val:
                    action = select_action(state)
                else:
                    action = select_action_actor_critic(state, ac_net)
                #print(action)
                action = action.data[0].numpy()
                if numpy.any(numpy.isnan(action)):
                    print(state)
                    print(action)
                    print(ac_net.affine1.weight)
                    print(ac_net.affine1.weight.data)
                    print('ERROR')
                    #action = select_action_actor_critic(state,ac_net)
                    #action = action.data[0].numpy()
                    #state = state + numpy.random.rand(args.feature)*0.001

                    raise RuntimeError('action NaN problem')
                #print(action)
                #print("------------------------")
                #timer = time.time()
                reward = 0
                if args.skip:
                    #env.step(action)
                    _, A, _, _ = env.step(action)
                    reward += A
                    _, A, _, _ = env.step(action)
                    reward += A
                BB = numpy.append(action, action)
                next_state, A, done, _ = env.step(BB)
                reward += A
                #print(next_state)
                #last_state = process_observation(state)
                last_state, last_v, next_state = process_observation(
                    last_state, last_v, next_state)

                next_state = numpy.array(next_state)
                #print(next_state)
                #print(next_state.shape)
                #return
                reward_sum += reward
                #print('env:')
                #print(time.time()-timer)

                #last_state ,next_state = update_observation(last_state,next_state)

                #next_state = running_state(next_state)

                next_state = Variable(torch.Tensor(next_state).unsqueeze(0))
                shared_obs_stats.observes(next_state)
                next_state = shared_obs_stats.normalize(next_state)
                next_state = next_state.data[0].numpy()

                #print(next_state[41:82])

                mask = 1
                if done:
                    mask = 0

                memory.push(state, np.array([action]), mask, next_state,
                            reward)

                #if args.render:
                #    env.render()
                if done:
                    break

                state = next_state
            num_steps += (t - 1)
            num_episodes += 1

            reward_batch += reward_sum

        reward_batch /= num_episodes
        batch = memory.sample()

        #print('env:')
        #print(time.time()-timer)

        #timer = time.time()
        update_params_actor_critic(batch, args, ac_net, opt_ac)
        shared_grad_buffers.add_gradient(ac_net)

        counter.increment()

        epoch = i_episode
        if (i_episode % args.log_interval == 0) and (rank == 0):

            print(
                'TrainEpisode {}\tTime {}\tLast reward: {}\tAverage reward {:.2f}'
                .format(
                    i_episode,
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - start_time)),
                    reward_sum, reward_batch))

            epoch = i_episode
            if reward_batch > best_result:
                best_result = reward_batch
                save_model(
                    {
                        'epoch': epoch,
                        'bh': args.bh,
                        'state_dict': shared_model.state_dict(),
                        'optimizer': opt_ac.state_dict(),
                        'obs': shared_obs_stats,
                    }, PATH_TO_MODEL, 'best')

            if epoch % 30 == 1:
                save_model(
                    {
                        'epoch': epoch,
                        'bh': args.bh,
                        'state_dict': shared_model.state_dict(),
                        'optimizer': opt_ac.state_dict(),
                        'obs': shared_obs_stats,
                    }, PATH_TO_MODEL, epoch)
        # wait for a new signal to continue
        while traffic_light.get() == signal_init:
            pass
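The worker above stores a mask of 0 at episode ends and 1 otherwise alongside each transition. The generic sketch below illustrates how such masks cut the discounted return at episode boundaries when the memory is later processed; it mirrors the usual pattern, not the project's update_params_actor_critic.

# Generic illustration of how a 0/1 mask resets the discounted return at
# episode boundaries; not the file's actual update code.
def discounted_returns(rewards, masks, gamma=0.99):
    returns = []
    running = 0.0
    for reward, mask in zip(reversed(rewards), reversed(masks)):
        running = reward + gamma * mask * running   # mask == 0 resets the return
        returns.insert(0, running)
    return returns

rewards = [1.0, 1.0, 1.0, 5.0]
masks = [1, 1, 0, 1]                                # third transition ends an episode
print(discounted_returns(rewards, masks))
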
Example #23
0

# create env
history, abbreviation = create_env_input()
env = PortfolioEnv(history, abbreviation)
#env = ch.envs.Logger(env, interval=20)
env = ch.envs.Torch(env)

# create net
action_size = env.action_space.shape[0]
number_asset, seq_window, features_all = env.observation_space.shape
assert action_size == number_asset + 1
input_size = features_all - 1

net = ActorCritic(input_size=input_size,
                  hidden_size=50,
                  action_size=action_size)
net_tgt = ActorCritic(input_size=input_size,
                      hidden_size=50,
                      action_size=action_size)
net_tgt.eval()
print(net_tgt)
net_tgt.load_state_dict(net.state_dict())

# create replay
replay = ch.ExperienceReplay()

# create loss function
criterion_mse = nn.MSELoss()

# create optimizer
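Example #23 keeps a frozen copy net_tgt of the online network but is cut off before showing how that copy is refreshed. The sketch below shows the two common options, a hard copy and Polyak averaging; the helper names, the tau value, and the linear stand-ins for the ActorCritic pair are illustrative assumptions.

# Sketch of the two common ways a target network is kept in sync.
import torch
import torch.nn as nn

def soft_update(target, source, tau=0.005):
    """Polyak averaging: target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * s_param)

def hard_update(target, source):
    target.load_state_dict(source.state_dict())

net = nn.Linear(8, 4)                     # stand-ins for the ActorCritic pair
net_tgt = nn.Linear(8, 4)
hard_update(net_tgt, net)                 # start from an exact copy
soft_update(net_tgt, net)                 # then drift slowly toward the online net
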
Example #24
0
    return args


if __name__ == '__main__':
    args = get_args()
    LEVEL = str(args.world) + '-' + str(args.stage)

    folder = './model/{}'.format(LEVEL)

    if not os.path.exists(folder):
        os.mkdir(folder)

    if args.model_type == "LSTM":
        global_model = ActorCritic_LSTM()
    else:
        global_model = ActorCritic()

    global_model.to(device)
    optimizer = SharedAdam(global_model.parameters(), lr=1e-4)

    PATH = './model/{}/A3C_{}_{}.pkl'.format(LEVEL, LEVEL, args.model_type)
    epoch = 1
    if args.load_model:
        if os.path.exists(PATH):
            print('Loaded Model')
            check_point = torch.load(PATH)
            global_model.load_state_dict(check_point['model_state_dict'])
            optimizer.load_state_dict(check_point['optimizer_state_dict'])
            epoch = check_point['epoch']

    global_model.share_memory()
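Once global_model.share_memory() is called, the usual next step in an A3C setup like Example #24 is to launch worker processes with torch.multiprocessing. The sketch below shows only that launch pattern; the worker signature and the stand-in linear model are assumptions, since the project's actual train/test workers are not shown here.

# Launch pattern for A3C-style workers sharing one model in memory.
import torch.multiprocessing as mp
import torch.nn as nn

def worker(rank, shared_model, start_epoch):
    # A real worker would copy shared_model into a local net, run episodes,
    # and apply gradients through a shared optimizer (e.g. SharedAdam).
    print('worker', rank, 'sees', sum(p.numel() for p in shared_model.parameters()),
          'shared parameters, starting at epoch', start_epoch)

if __name__ == '__main__':
    shared_model = nn.Linear(4, 2)        # stand-in for the ActorCritic global model
    shared_model.share_memory()
    processes = []
    for rank in range(4):
        p = mp.Process(target=worker, args=(rank, shared_model, 1))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
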
Example #25
0
from train import train_on_env
from models import ActorCritic
from constants import *

model = ActorCritic(use_conv=False, input_size=4)
if USE_CUDA: model.cuda()

train_on_env("CartPole-v0", model, 5000, 128)
Example #26
0
import torch
from torch import optim
from tqdm import tqdm
from env import Env
from models import ActorCritic
from utils import plot


max_steps, batch_size, discount, trace_decay = 100000, 16, 0.99, 0.97
env = Env()
agent = ActorCritic()
actor_optimiser = optim.Adam(list(agent.actor.parameters()) + [agent.policy_log_std], lr=3e-4)
critic_optimiser = optim.Adam(agent.critic.parameters(), lr=1e-3)


step, pbar = 0, tqdm(total=max_steps, smoothing=0)
while step < max_steps:
  # Collect set of trajectories D by running policy π in the environment
  D = [[] for _ in range(batch_size)]  # independent list per trajectory; [[]] * batch_size would alias one shared list
  for idx in range(batch_size):
    state, done, total_reward = env.reset(), False, 0
    while not done:
      policy, value = agent(state)
      action = policy.sample()
      log_prob_action = policy.log_prob(action)
      next_state, reward, done = env.step(action)
      step += 1
      pbar.update(1)
      total_reward += reward
      D[idx].append({'state': state, 'action': action, 'reward': torch.tensor([reward]), 'log_prob_action': log_prob_action, 'value': value})
      state = next_state
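Example #26 fills D with per-step state, action, reward, log_prob_action, and value, but is cut off before the update. The sketch below shows one generic way such a trajectory is turned into actor and critic losses, assuming plain Monte-Carlo returns rather than the file's actual continuation (which, given trace_decay, likely uses GAE).

# Generic actor-critic losses for one trajectory; a sketch, not the file's update.
import torch

def trajectory_losses(steps, discount=0.99):
    """steps: list of dicts with 'reward', 'log_prob_action', 'value' tensors."""
    returns, running = [], torch.tensor(0.0)
    for step in reversed(steps):
        running = step['reward'] + discount * running
        returns.insert(0, running)
    returns = torch.stack(returns).squeeze(-1)
    values = torch.stack([s['value'] for s in steps]).squeeze(-1)
    log_probs = torch.stack([s['log_prob_action'] for s in steps]).squeeze(-1)
    advantages = returns - values.detach()
    policy_loss = -(log_probs * advantages).mean()
    value_loss = (values - returns).pow(2).mean()
    return policy_loss, value_loss

# Toy trajectory with the same keys used when filling D above.
steps = [{'reward': torch.tensor([1.0]),
          'log_prob_action': torch.tensor(-0.5, requires_grad=True),
          'value': torch.tensor(0.3, requires_grad=True)} for _ in range(3)]
policy_loss, value_loss = trajectory_losses(steps)
(policy_loss + 0.5 * value_loss).backward()
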
Example #27
0
#lr               = 3e-4
#num_steps        = 20
#mini_batch_size  = 5
#ppo_epochs       = 4

for c in range(num_classes):

    print("Learning Policy for class:", c)

    envs = [
        make_env(num_features, blackbox_model, c, max_nodes, min_nodes)
        for i in range(num_envs)
    ]
    envs = SubprocVecEnv(envs)

    model = ActorCritic(num_features, embedding_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    max_frames = 5000
    frame_idx = 0
    test_rewards = []

    state = envs.reset()

    early_stop = False

    #Save mean rewards per episode
    env_0_mean_rewards = []
    env_0_rewards = []

    while frame_idx < max_frames and not early_stop:
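Example #27's commented hyper-parameters (ppo_epochs, mini_batch_size) indicate a PPO update follows the truncated rollout loop. For reference, the sketch below is the standard clipped-surrogate loss such loops optimize; the clip value of 0.2 and the toy numbers are assumptions, not the file's own ppo_update.

# Standard PPO clipped-surrogate loss.
import torch

def ppo_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

# Toy numbers: the second sample's ratio (~1.65) is clipped to 1.2.
old = torch.tensor([-1.0, -2.0, -0.5])
new = torch.tensor([-0.9, -1.5, -0.6])
adv = torch.tensor([1.0, 2.0, -1.0])
print(ppo_loss(new, old, adv))
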
Example #28
0
def test(rank, args, shared_model, opt_ac):
    best_result = -1000
    torch.manual_seed(args.seed + rank)
    torch.set_default_tensor_type('torch.DoubleTensor')
    num_inputs = args.feature
    num_actions = 9
    last_state = numpy.zeros(41)

    if args.render:
        env = RunEnv(visualize=True)
    else:
        env = RunEnv(visualize=False)

    running_state = ZFilter((num_inputs, ), clip=5)
    running_reward = ZFilter((1, ), demean=False, clip=10)
    episode_lengths = []

    PATH_TO_MODEL = '../models/' + str(args.bh)

    ac_net = ActorCritic(num_inputs, num_actions)

    start_time = time.time()

    for i_episode in count(1):
        memory = Memory()
        ac_net.load_state_dict(shared_model.state_dict())

        num_steps = 0
        reward_batch = 0
        num_episodes = 0
        while num_steps < args.batch_size:
            #state = env.reset()
            #print(num_steps)
            state = env.reset(difficulty=0)
            state = numpy.array(state)
            #global last_state
            #last_state = state
            #last_state,_ = update_observation(last_state,state)
            #last_state,state = update_observation(last_state,state)
            #print(state.shape[0])
            #print(state[41])
            state = running_state(state)

            reward_sum = 0
            for t in range(10000):  # Don't infinite loop while learning
                #print(t)
                #timer = time.time()
                if args.use_sep_pol_val:
                    action = select_action(state)
                else:
                    action = select_action_actor_critic(state, ac_net)

                #print(action)
                action = action.data[0].numpy()
                if numpy.any(numpy.isnan(action)):
                    print(action)
                    print('ERROR')
                    return
                #print('NN take:')
                #print(time.time()-timer)
                #print(action)
                #print("------------------------")

                #timer = time.time()
                if args.skip:
                    #env.step(action)
                    _, reward, _, _ = env.step(action)
                    reward_sum += reward
                next_state, reward, done, _ = env.step(action)
                next_state = numpy.array(next_state)
                reward_sum += reward

                #print('env take:')
                #print(time.time()-timer)

                #timer = time.time()

                #last_state ,next_state = update_observation(last_state,next_state)
                next_state = running_state(next_state)
                #print(next_state[41:82])

                mask = 1
                if done:
                    mask = 0

                #print('update take:')
                #print(time.time()-timer)

                #timer = time.time()

                memory.push(state, np.array([action]), mask, next_state,
                            reward)

                #print('memory take:')
                #print(time.time()-timer)

                #if args.render:
                #    env.render()
                if done:
                    break

                state = next_state

            num_steps += (t - 1)
            num_episodes += 1
            #print(num_episodes)
            reward_batch += reward_sum

        #print(num_episodes)
        reward_batch /= num_episodes
        batch = memory.sample()

        #update_params_actor_critic(batch,args,shared_model,ac_net,opt_ac)
        time.sleep(60)

        if i_episode % args.log_interval == 0:
            File = open(PATH_TO_MODEL + '/record.txt', 'a+')
            File.write("Time {}, episode reward {}, Average reward {}\n".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                reward_sum, reward_batch))
            File.close()
            #print('TestEpisode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
            #    i_episode, reward_sum, reward_batch))
            print("Time {}, episode reward {}, Average reward {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                reward_sum, reward_batch))
            #print('!!!!')

        epoch = i_episode
        if reward_batch > best_result:
            best_result = reward_batch
            save_model(
                {
                    'epoch': epoch,
                    'bh': args.bh,
                    'state_dict': shared_model.state_dict(),
                    'optimizer': opt_ac.state_dict(),
                }, PATH_TO_MODEL, 'best')

        if epoch % 30 == 1:
            save_model(
                {
                    'epoch': epoch,
                    'bh': args.bh,
                    'state_dict': shared_model.state_dict(),
                    'optimizer': opt_ac.state_dict(),
                }, PATH_TO_MODEL, epoch)
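The test worker above normalizes observations with ZFilter((num_inputs,), clip=5), i.e. a running mean/std filter with clipping. Below is a simplified stand-in for such a filter using Welford's online update; it is a sketch of the idea, not the project's ZFilter implementation.

# Simplified running observation normalizer (Welford mean/variance + clipping).
import numpy as np

class RunningNorm:
    def __init__(self, shape, clip=5.0, eps=1e-8):
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)
        self.count = 0
        self.clip = clip
        self.eps = eps

    def __call__(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)
        std = np.sqrt(self.m2 / max(self.count - 1, 1)) + self.eps
        return np.clip((x - self.mean) / std, -self.clip, self.clip)

running_state = RunningNorm((3,), clip=5)
for _ in range(5):
    obs = np.random.randn(3) * 10 + 2
    print(running_state(obs))
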
Example #29
0
def main(dataset,
         pretrain_max_epoch,
         max_epoch,
         learning_rate,
         weight_decay,
         max_pretrain_grad_norm,
         max_grad_norm,
         batch_size,
         embedding_size,
         rnn_input_size,
         rnn_hidden_size,
         hidden_size,
         bottleneck_size,
         entropy_penalty,
         gamma,
         alpha,
         nonlinear_func='tanh',
         value_weight=0.5,
         reward_function='exf1',
         label_order='freq2rare',
         input_dropout_prob=0.2,
         dropout_prob=0.5,
         num_layers=1,
         cv_fold=0,
         seed=None,
         fixed_label_seq_pretrain=False):

    data_loaders, configs = prepare_exp(dataset,
                                        max_epoch,
                                        learning_rate,
                                        weight_decay,
                                        batch_size,
                                        embedding_size,
                                        rnn_input_size,
                                        rnn_hidden_size,
                                        hidden_size,
                                        bottleneck_size,
                                        nonlinear_func=nonlinear_func,
                                        dropout_prob=dropout_prob,
                                        num_layers=num_layers,
                                        label_order=label_order,
                                        entropy_penalty=entropy_penalty,
                                        value_weight=value_weight,
                                        reward_function=reward_function,
                                        gamma=gamma,
                                        alpha=alpha,
                                        cv_fold=cv_fold,
                                        seed=seed)

    train_loader, sub_train_loader, valid_loader, test_loader = data_loaders
    opt_config, data_config, model_config = configs

    BOS_ID = train_loader.dataset.get_start_label_id()
    EOS_ID = train_loader.dataset.get_stop_label_id()
    is_sparse_data = train_loader.dataset.is_sparse_dataset()

    criterion = nn.NLLLoss(ignore_index=0, reduction='none')
    model = ActorCritic(model_config)
    if device.type == 'cuda':
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt_config['learning_rate'],
                           weight_decay=weight_decay)
    env = Environment(model_config)
    bipartition_eval_functions, ranking_evaluation_functions = load_evaluation_functions()

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    model_arch_info = 'emb_{}_rnn_{}_hid_{}_bot_{}_inpdp_{}_dp_{}_{}'.format(
        embedding_size, rnn_hidden_size, hidden_size, bottleneck_size,
        input_dropout_prob, dropout_prob, nonlinear_func)
    rl_info = 'alpha_{}_gamma_{}_vw_{}_reward_{}_ent_{}'.format(
        alpha, gamma, value_weight, reward_function, entropy_penalty)
    optim_info = 'lr_{}_decay_{}_norm_{}-{}_bs_{}_epoch_{}-{}_fold_{}'.format(
        learning_rate, weight_decay, max_pretrain_grad_norm, max_grad_norm,
        batch_size, pretrain_max_epoch, max_epoch, cv_fold)

    if fixed_label_seq_pretrain and max_epoch == 0:
        # baseline models
        summary_comment = '_'.join([
            current_time, 'baseline', label_order, model_arch_info, optim_info
        ])
    else:
        summary_comment = '_'.join(
            [current_time, 'proposed', model_arch_info, rl_info, optim_info])

    summary_log_dir = os.path.join('runs', dataset, summary_comment)
    bipartition_model_save_path = os.path.join(
        'models', dataset, summary_comment + '_bipartition.pth')
    ranking_model_save_path = os.path.join('models', dataset,
                                           summary_comment + '_ranking.pth')

    writer = SummaryWriter(log_dir=summary_log_dir)

    n_batches = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        n_batches += 1

    num_param_updates = 0
    best_bipartition_valid_score = -np.inf
    best_ranking_valid_score = -np.inf
    max_epoch = opt_config['max_epoch']

    input_dropout = nn.Dropout(p=input_dropout_prob)

    # pretrain or only supervised learning with a fixed label ordering
    for epoch in range(pretrain_max_epoch):

        print('==== {} ===='.format(epoch))
        avg_rewards = []
        for batch_idx, (data, targets) in enumerate(train_loader):

            data, targets = prepare_minibatch(
                data,
                targets,
                train_loader.dataset.get_feature_dim(),
                is_sparse_data,
                drop_EOS=label_order == 'mblp')
            batch_size = len(targets)
            assert data.shape[0] == batch_size, '{}\t{}'.format(
                data.shape[0], batch_size)

            data = input_dropout(data)

            if label_order != 'mblp':
                target_length = np.array(list(map(len, targets)))
                max_length = int(np.max(target_length))
                targets_ = np.zeros((max_length, batch_size), dtype=np.int64)

                for i in range(batch_size):
                    targets_[:len(targets[i]), i] = targets[i]

                targets = torch.tensor(targets_,
                                       dtype=torch.int64,
                                       device=device,
                                       requires_grad=False)
            else:
                max_target_length = np.max(np.array(list(map(len, targets))))
                max_sampling_steps = int(max_target_length * 1.5)

                env.clear_episode_temp_data()
                gen_actions_per_episode = []
                rewards_per_episode = []

                model = model.eval()
                prev_states = model.init_hidden(data, device)
                prev_actions = torch.tensor([BOS_ID] * batch_size,
                                            dtype=torch.int64,
                                            device=device)

                for t in range(
                        max_sampling_steps):  # no infinite loop while learning

                    model_outputs, states = model(data,
                                                  prev_actions,
                                                  prev_states,
                                                  state_value_grad=False)
                    gen_actions, _, done = env.step(model_outputs)

                    gen_actions_per_episode.append(
                        gen_actions.data.cpu().numpy())

                    if done:
                        break

                    prev_actions = gen_actions
                    prev_states = states

                # gen_actions_per_episode: (batch_size, max_trials) # cols can be smaller.
                gen_actions_per_episode = np.array(gen_actions_per_episode).T

                # sort labels according to model predictions
                targets_ = convert_labelset2seq(targets,
                                                gen_actions_per_episode,
                                                EOS_ID)
                targets = torch.tensor(targets_,
                                       dtype=torch.int64,
                                       device=device,
                                       requires_grad=False)

                del gen_actions_per_episode

            model = model.train()
            prev_states = model.init_hidden(data, device)
            prev_actions = torch.tensor([BOS_ID] * batch_size,
                                        dtype=torch.int64,
                                        device=device,
                                        requires_grad=False)
            dropout_masks = create_dropout_mask(
                model_config.dropout_prob, batch_size,
                model_config.embedding_size * 2, model_config.rnn_hidden_size)

            losses = []
            for t in range(targets.size(0)):  # no infinite loop while learning
                model_outputs, states = model(data,
                                              prev_actions,
                                              prev_states,
                                              dropout_masks=dropout_masks,
                                              state_value_grad=False)

                logits = model_outputs[0]
                log_probs = F.log_softmax(logits, dim=-1)
                target_t = targets[t]

                losses.append(criterion(log_probs, target_t))

                prev_actions = target_t
                prev_states = states

            # loss: (seq_len, batch_size)
            loss = torch.stack(losses, dim=0)
            loss = torch.sum(loss, dim=0).mean()

            optimizer.zero_grad()
            loss.backward()

            output_str = '{}/Before gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            # torch.nn.utils.clip_grad_value_(model.parameters(), 100)
            if max_pretrain_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_pretrain_grad_norm)

            output_str = '{}/After gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            optimizer.step()

            num_param_updates += 1

        results = evaluation(
            OrderedDict([('sub_train', sub_train_loader),
                         ('valid', valid_loader), ('test', test_loader)]),
            model, env, bipartition_eval_functions,
            ranking_evaluation_functions, model_config.max_trials)

        print_result_summary(results, writer, dataset, epoch)

        for split_name, scores in results.items():
            if split_name == 'valid':
                if scores['example f1 score'] > best_bipartition_valid_score:
                    best_bipartition_valid_score = scores['example f1 score']
                    save_model(epoch, model, optimizer,
                               bipartition_model_save_path)

                if scores['nDCG_k'][-1] > best_ranking_valid_score:
                    best_ranking_valid_score = scores['nDCG_k'][-1]
                    save_model(epoch, model, optimizer,
                               ranking_model_save_path)

    def update_alpha(epoch, xlimit=6, alpha_max=1):
        updated_alpha = 1 / (
            1 + float(np.exp(xlimit - 2 * xlimit / float(max_epoch) * epoch)))
        updated_alpha = min(updated_alpha, alpha_max)
        return updated_alpha

    del optimizer

    # joint learning
    rl_optimizer = optim.Adam(model.parameters(),
                              lr=opt_config['learning_rate'],
                              weight_decay=weight_decay)
    for epoch in range(max_epoch):
        if alpha == 'auto':
            alpha_e = update_alpha(epoch)
        else:
            assert float(alpha) >= 0 and float(alpha) <= 1
            alpha_e = float(alpha)

        print('==== {} ===='.format(epoch + pretrain_max_epoch))
        avg_rewards = []
        for batch_idx, (data, targets) in enumerate(train_loader):

            model = model.train()
            data, targets = prepare_minibatch(
                data, targets, train_loader.dataset.get_feature_dim(),
                is_sparse_data)
            batch_size = len(targets)
            assert data.shape[0] == batch_size, '{}\t{}'.format(
                data.shape[0], batch_size)

            data = input_dropout(data)

            dropout_masks = create_dropout_mask(
                model_config.dropout_prob, batch_size,
                model_config.embedding_size * 2, model_config.rnn_hidden_size)
            prev_states = model.init_hidden(data, device)
            prev_actions = torch.tensor([BOS_ID] * batch_size,
                                        dtype=torch.int64,
                                        device=device,
                                        requires_grad=False)

            max_target_length = np.max(np.array(list(map(len, targets))))
            max_sampling_steps = int(max_target_length * 1.5)

            env.clear_episode_temp_data()
            gen_actions_per_episode = []
            rewards_per_episode = []

            for t in range(
                    max_sampling_steps):  # no infinite loop while learning

                model_outputs, states = model(data,
                                              prev_actions,
                                              prev_states,
                                              dropout_masks=dropout_masks)
                gen_actions, rewards, done = env.step(model_outputs, targets)

                gen_actions_per_episode.append(gen_actions.data.cpu().numpy())
                rewards_per_episode.append(rewards)

                if done:
                    break

                prev_actions = gen_actions
                prev_states = states

            num_non_empty = np.array([len(t) > 0 for t in targets]).sum()
            r = np.stack(rewards_per_episode,
                         axis=1).sum(1).sum() / num_non_empty
            avg_rewards.append(r)

            ps_loss, adv_collection = calculate_loss(env, model_config)
            writer.add_scalar('{}/avg_advantages'.format(dataset),
                              adv_collection.mean().data.cpu().numpy(),
                              num_param_updates)

            # gen_actions_per_episode: (batch_size, max_trials) # cols can be smaller.
            gen_actions_per_episode = np.array(gen_actions_per_episode).T

            # sort labels according to model predictions
            targets_ = convert_labelset2seq(targets, gen_actions_per_episode,
                                            EOS_ID)
            targets = torch.tensor(targets_,
                                   dtype=torch.int64,
                                   device=device,
                                   requires_grad=False)

            del gen_actions_per_episode

            prev_states = model.init_hidden(data, device)
            prev_actions = torch.tensor([BOS_ID] * batch_size,
                                        dtype=torch.int64,
                                        device=device,
                                        requires_grad=False)
            dropout_masks = create_dropout_mask(
                model_config.dropout_prob, batch_size,
                model_config.embedding_size * 2, model_config.rnn_hidden_size)

            losses = []
            for t in range(targets.size(0)):  # no infinite loop while learning
                model_outputs, states = model(data,
                                              prev_actions,
                                              prev_states,
                                              dropout_masks=dropout_masks,
                                              state_value_grad=False)
                logits = model_outputs[0]
                log_probs = F.log_softmax(logits, dim=-1)
                target_t = targets[t]

                losses.append(criterion(log_probs, target_t))

                prev_actions = target_t
                prev_states = states

            # loss: (seq_len, batch_size)
            sup_loss = torch.stack(losses, dim=0).sum(0)

            loss = alpha_e * ps_loss + (1 - alpha_e) * sup_loss
            loss = loss.mean()

            rl_optimizer.zero_grad()
            loss.backward()

            output_str = '{}/Before gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            output_str = '{}/After gradient norm'.format(dataset)
            print_param_norm(model.parameters(), writer, output_str,
                             num_param_updates)

            rl_optimizer.step()

            num_param_updates += 1

        results = evaluation(
            OrderedDict([('sub_train', sub_train_loader),
                         ('valid', valid_loader), ('test', test_loader)]),
            model, env, bipartition_eval_functions,
            ranking_evaluation_functions, model_config.max_trials)

        print_result_summary(results, writer, dataset,
                             epoch + pretrain_max_epoch)

        for split_name, scores in results.items():
            if split_name == 'valid':
                if scores['example f1 score'] > best_bipartition_valid_score:
                    best_bipartition_valid_score = scores['example f1 score']
                    save_model(epoch + pretrain_max_epoch, model, rl_optimizer,
                               bipartition_model_save_path)

                if scores['nDCG_k'][-1] > best_ranking_valid_score:
                    best_ranking_valid_score = scores['nDCG_k'][-1]
                    save_model(epoch + pretrain_max_epoch, model, rl_optimizer,
                               ranking_model_save_path)

    writer.close()
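The joint-learning phase above mixes the policy-search loss and the supervised loss as loss = alpha_e * ps_loss + (1 - alpha_e) * sup_loss, with alpha_e following the sigmoid schedule in update_alpha. The standalone sketch below just prints that schedule so its shape is easy to see; max_epoch=20 is a made-up value, and max_epoch is passed explicitly instead of being captured from the enclosing scope.

# Standalone view of the sigmoid alpha schedule used for loss mixing above.
import numpy as np

def update_alpha(epoch, max_epoch, xlimit=6, alpha_max=1):
    updated_alpha = 1 / (1 + float(np.exp(xlimit - 2 * xlimit / float(max_epoch) * epoch)))
    return min(updated_alpha, alpha_max)

max_epoch = 20
for epoch in (0, 5, 10, 15, 20):
    print(epoch, round(update_alpha(epoch, max_epoch), 3))   # rises from ~0.002 to ~0.998
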