Example No. 1
                player.env.render()
        if player.starter and player.flag:
            player = player_start(player)
        else:
            player.flag = False
        if player.done and not player.flag:
            player.model.load_state_dict(saved_state)
            player.cx = Variable(torch.zeros(1, 512), volatile=True)
            player.hx = Variable(torch.zeros(1, 512), volatile=True)
            player.flag = False
        elif not player.flag:
            player.cx = Variable(player.cx.data, volatile=True)
            player.hx = Variable(player.hx.data, volatile=True)
            player.flag = False
        if not player.flag:
            player, reward = player_act(player, train=False)
            reward_sum += reward

        if not player.done:
            if player.current_life > player.info['ale.lives']:
                player.flag = True
                player.current_life = player.info['ale.lives']
            else:
                player.current_life = player.info['ale.lives']
                player.flag = False

        if player.done:
            player.flag = True
            player.current_life = 0
            num_tests += 1
            reward_total_sum += reward_sum
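
This fragment evaluates the agent with the pre-0.4 PyTorch API, where Variable(..., volatile=True) disables gradient tracking during inference. On PyTorch 0.4 and later the same effect is obtained with torch.no_grad(); the sketch below is a hypothetical helper, not part of the repository, showing the equivalent way to reset the 512-unit LSTM state for evaluation.

import torch

def reset_eval_state(player, hidden_size=512):
    # Modern replacement for Variable(torch.zeros(1, 512), volatile=True):
    # tensors created under no_grad do not track gradients.
    with torch.no_grad():
        player.cx = torch.zeros(1, hidden_size)
        player.hx = torch.zeros(1, hidden_size)
    return player
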
Example No. 2
def train(rank, args, shared_model, optimizer, env_conf):
    torch.manual_seed(args.seed + rank)

    env = atari_env(args.env, env_conf)
    model = A3Clstm(env.observation_space.shape[0], env.action_space)

    if optimizer is None:
        # build an optimizer over the shared model's parameters when the
        # caller did not pass one in
        if args.optimizer == 'RMSprop':
            optimizer = optim.RMSprop(shared_model.parameters(), lr=args.lr)
        elif args.optimizer == 'Adam':
            optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    env.seed(args.seed + rank)
    state = env.reset()
    player = Agent(model, env, args, state)
    player.state = torch.from_numpy(state).float()
    player.model.train()
    epoch = 0
    while True:

        # synchronise this worker's local model with the shared parameters
        player.model.load_state_dict(shared_model.state_dict())
        if player.done:
            # reset the LSTM hidden and cell state at episode boundaries
            player.cx = Variable(torch.zeros(1, 512))
            player.hx = Variable(torch.zeros(1, 512))
            if player.starter:
                player = player_start(player, train=True)
        else:
            player.cx = Variable(player.cx.data)
            player.hx = Variable(player.hx.data)

        for step in range(args.num_steps):

            player = player_act(player, train=True)

            if player.done:
                break

            if player.current_life > player.info['ale.lives']:
                player.flag = True
                player.current_life = player.info['ale.lives']
            else:
                player.current_life = player.info['ale.lives']
                player.flag = False
            if args.count_lives:
                if player.flag:
                    player.done = True
                    break

            if player.starter and player.flag:
                player = player_start(player, train=True)
            if player.done:
                break

        if player.done:
            player.eps_len = 0
            player.current_life = 0
            state = player.env.reset()
            player.state = torch.from_numpy(state).float()
            player.flag = False

        # bootstrap the return from the critic's value estimate unless the
        # episode terminated
        R = torch.zeros(1, 1)
        if not player.done:
            value, _, _ = player.model(
                (Variable(player.state.unsqueeze(0)), (player.hx, player.cx)))
            R = value.data

        player.values.append(Variable(R))
        policy_loss = 0
        value_loss = 0
        R = Variable(R)
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(player.rewards))):
            R = args.gamma * R + player.rewards[i]
            advantage = R - player.values[i]
            value_loss += 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = player.rewards[i] + args.gamma * \
                player.values[i + 1].data - player.values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - player.log_probs[i] * \
                Variable(gae) - 0.01 * player.entropies[i]

        optimizer.zero_grad()

        (policy_loss + value_loss).backward()

        # copy this worker's gradients into the shared model before stepping
        ensure_shared_grads(player.model, shared_model)
        optimizer.step()
        player.values = []
        player.log_probs = []
        player.rewards = []
        player.entropies = []
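
The backward pass above combines two estimators: a discounted n-step return R drives the value (critic) loss, while Generalized Advantage Estimation (GAE) weights the policy gradient, with args.tau acting as the GAE lambda. The standalone sketch below is a hypothetical helper (plain tensors instead of Variables) that isolates the GAE recursion computed inside the loop.

import torch

def gae_advantages(rewards, values, gamma=0.99, tau=1.00):
    # values must hold one more entry than rewards: the final element is
    # the bootstrap value R appended before the backward pass.
    gae = torch.zeros(1, 1)
    advantages = [None] * len(rewards)
    for i in reversed(range(len(rewards))):
        delta_t = rewards[i] + gamma * values[i + 1] - values[i]
        gae = gae * gamma * tau + delta_t
        advantages[i] = gae
    return advantages
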
Example No. 3
def test(args, shared_model, env_conf):
    log = {}
    setup_logger('{}_log'.format(args.env),
                 r'{0}{1}_log'.format(args.log_dir, args.env))
    log['{}_log'.format(args.env)] = logging.getLogger('{}_log'.format(
        args.env))
    d_args = vars(args)
    for k in d_args.keys():
        log['{}_log'.format(args.env)].info('{0}: {1}'.format(k, d_args[k]))

    torch.manual_seed(args.seed)
    env = atari_env(args.env, env_conf)
    model = A3Clstm(env.observation_space.shape[0], env.action_space)

    state = env.reset()
    reward_sum = 0
    start_time = time.time()
    num_tests = 0
    reward_total_sum = 0
    player = Agent(model, env, args, state)
    player.state = torch.from_numpy(state).float()
    player.model.eval()
    while True:
        if player.starter and player.flag:
            player = player_start(player)
        else:
            player.flag = False
        if player.done and not player.flag:
            player.model.load_state_dict(shared_model.state_dict())
            player.cx = Variable(torch.zeros(1, 512), volatile=True)
            player.hx = Variable(torch.zeros(1, 512), volatile=True)
            player.flag = False
        elif not player.flag:
            player.cx = Variable(player.cx.data, volatile=True)
            player.hx = Variable(player.hx.data, volatile=True)
            player.flag = False
        if not player.flag:
            player, reward = player_act(player, train=False)
            reward_sum += reward

        if not player.done:
            if player.current_life > player.info['ale.lives']:
                player.flag = True
                player.current_life = player.info['ale.lives']
            else:
                player.current_life = player.info['ale.lives']
                player.flag = False

        if player.done:
            num_tests += 1
            player.current_life = 0
            player.flag = True
            reward_total_sum += reward_sum
            reward_mean = reward_total_sum / num_tests
            log['{}_log'.format(args.env)].info(
                "Time {0}, episode reward {1}, episode length {2}, reward mean {3:.4f}"
                .format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - start_time)),
                    reward_sum, player.eps_len, reward_mean))

            if reward_sum > args.save_score_level:
                player.model.load_state_dict(shared_model.state_dict())
                state_to_save = player.model.state_dict()
                torch.save(state_to_save,
                           '{0}{1}.dat'.format(args.save_model_dir, args.env))

            reward_sum = 0
            player.eps_len = 0
            state = player.env.reset()
            time.sleep(60)  # pause between evaluation episodes
            player.state = torch.from_numpy(state).float()
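
The train and test functions above are designed to run in separate processes that share a single model held in shared memory, but the launcher itself is not part of this excerpt. The sketch below is an assumption based only on the signatures shown here (rank, args, shared_model, optimizer, env_conf) and the usual torch.multiprocessing pattern for A3C.

import torch.multiprocessing as mp

def launch(args, shared_model, optimizer, env_conf, num_workers=4):
    # one evaluation process plus num_workers training processes, all
    # updating the same shared_model parameters
    processes = []
    p = mp.Process(target=test, args=(args, shared_model, env_conf))
    p.start()
    processes.append(p)
    for rank in range(num_workers):
        p = mp.Process(target=train,
                       args=(rank, args, shared_model, optimizer, env_conf))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()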