Example #1
def get_functions(args):
    ''' Based on the algorithm type, return a tuple of (train fn, test fn, model class, env factory). '''
    #TODO add make_model 
    import importlib
    import envs
    from test import test
    algo_module = importlib.import_module('algorithms.{}'.format(args.algo))
    model_module = importlib.import_module('models.{}'.format(args.arch))
    if args.env_name == 'CartPole-v0':
        make_env = lambda: envs.basic_env(args.env_name, stacked=1)
    elif args.arch == 'lstm_universe':
        make_env = lambda: envs.atari_env(args.env_name, side=42, stacked=1)
    elif args.arch == 'lstm_nature':
        make_env = lambda: envs.atari_env(args.env_name, side=84, stacked=1)
    else:
        raise ValueError('net architecture not known')
    
    return (algo_module.train, test, model_module.Net, make_env)
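A minimal usage sketch for the factory above, assuming args is an argparse namespace with algo, arch, and env_name set; the variable names and the Net constructor signature are illustrative guesses, not part of the original listing:

train_fn, test_fn, Net, make_env = get_functions(args)
env = make_env()  # build the environment via the selected factory
model = Net(env.observation_space.shape[0], env.action_space)  # assumed Net signature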
Example #2
def train(rank, args, shared_model, optimizer, env_conf):
    torch.manual_seed(args.seed + rank)

    env = atari_env(args.env_name, env_conf)
    model = A3Clstm(env.observation_space.shape[0], env.action_space)
    _ = env.reset()
    action = env.action_space.sample()
    _, _, _, info = env.step(action)
    start_lives = info['ale.lives']

    if optimizer is None:
        if args.optimizer == 'RMSprop':
            optimizer = optim.RMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()
    env.seed(args.seed + rank)
    state = env.reset()
    state = torch.from_numpy(state).float()
    done = True
    episode_length = 0
    while True:
        episode_length += 1
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = Variable(torch.zeros(1, 512))
            hx = Variable(torch.zeros(1, 512))
        else:
            cx = Variable(cx.data)
            hx = Variable(hx.data)

        values = []
        log_probs = []
        rewards = []
        entropies = []

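        # Roll out up to args.num_steps transitions (or until the episode ends),
        # recording value estimates, log-probabilities, rewards and entropies.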
        for step in range(args.num_steps):

            value, logit, (hx, cx) = model(
                (Variable(state.unsqueeze(0)), (hx, cx)))
            prob = F.softmax(logit)
            log_prob = F.log_softmax(logit)
            entropy = -(log_prob * prob).sum(1)
            entropies.append(entropy)

            action = prob.multinomial().data
            log_prob = log_prob.gather(1, Variable(action))

            state, reward, done, info = env.step(action.numpy())
            done = done or episode_length >= args.max_episode_length
            if args.count_lives:
                if start_lives > info['ale.lives']:
                    done = True
            reward = max(min(reward, 1), -1)

            if done:
                episode_length = 0
                state = env.reset()

            state = torch.from_numpy(state).float()
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:

            value, _, _ = model((Variable(state.unsqueeze(0)), (hx, cx)))
            R = value.data

        values.append(Variable(R))
        policy_loss = 0
        value_loss = 0
        R = Variable(R)
        gae = torch.zeros(1, 1)
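        # Walk the rollout backwards: R accumulates the discounted (bootstrapped)
        # return for the value loss, while gae accumulates the generalized
        # advantage estimate that weights the policy gradient.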
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * Variable(gae) - 0.01 * entropies[i]

        optimizer.zero_grad()

        (policy_loss + 0.5 * value_loss).backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 40)

        ensure_shared_grads(model, shared_model)
        optimizer.step()
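ensure_shared_grads is called above but not defined in this listing. A typical A3C-style definition copies the worker's gradients onto the shared model so the (possibly shared) optimizer can apply them; a sketch:

def ensure_shared_grads(model, shared_model):
    ''' Point the shared model's grads at the local worker's grads. '''
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None:
            # Another worker already populated the shared grads this step.
            return
        shared_param._grad = param.grad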
Example #3
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Training settings
# Implemented multiprocessing using locks, but it was not beneficial; Hogwild
# training was far superior.

if __name__ == '__main__':
    args = parser.parse_args()
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(args.seed)

    setup_json = read_config(args.env_config)
    env_conf = setup_json["Default"]
    for i in setup_json.keys():
        if i in args.env_name:
            env_conf = setup_json[i]
    env = atari_env(args.env_name, env_conf)
    shared_model = A3Clstm(env.observation_space.shape[0], env.action_space)
    if args.load:
        saved_state = torch.load('{0}{1}.dat'.format(args.load_model_dir,
                                                     args.env_name))
        shared_model.load_state_dict(saved_state)
    shared_model.share_memory()

    if args.shared_optimizer:
        if args.optimizer == 'RMSprop':
            optimizer = shared_optim.SharedRMSprop(shared_model.parameters(),
                                                   lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = shared_optim.SharedAdam(shared_model.parameters(),
                                                lr=args.lr)
        optimizer.share_memory()
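The __main__ block above prepares the shared model and optimizer but stops before spawning workers. A Hogwild-style launch (as the comment at the top of this example describes) typically uses torch.multiprocessing along the following lines; args.workers is an assumed flag for the number of training processes, and the train/test signatures follow the other examples:

import torch.multiprocessing as mp

if not args.shared_optimizer:
    optimizer = None  # train() builds a per-worker optimizer (see Example #2)

processes = []

# One evaluation process; the rank value passed to test() is arbitrary.
p = mp.Process(target=test, args=(args.workers, args, shared_model, env_conf))
p.start()
processes.append(p)

# One Hogwild training process per worker, all updating shared_model.
for rank in range(args.workers):
    p = mp.Process(target=train,
                   args=(rank, args, shared_model, optimizer, env_conf))
    p.start()
    processes.append(p)

for p in processes:
    p.join()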
Example #4
setup_json = read_config(args.env_config)
env_conf = setup_json["Default"]
# Select the env-specific config, falling back to "Default" (as in Example #3).
for i in setup_json.keys():
    if i in args.env_name:
        env_conf = setup_json[i]
torch.set_default_tensor_type('torch.FloatTensor')

saved_state = torch.load('{0}{1}.dat'.format(args.load_model_dir,
                                             args.env_name),
                         map_location=lambda storage, loc: storage)

done = True

log = {}
setup_logger('{}_mon_log'.format(args.env_name),
             r'{0}{1}_mon_log'.format(args.log_dir, args.env_name))
log['{}_mon_log'.format(args.env_name)] = logging.getLogger(
    '{}_mon_log'.format(args.env_name))

env = atari_env("{}".format(args.env_name), env_conf)
model = A3Clstm(env.observation_space.shape[0], env.action_space)
model.eval()

env = gym.wrappers.Monitor(env, "{}_monitor".format(args.env_name), force=True)
num_tests = 0
reward_total_sum = 0
for i_episode in range(args.num_episodes):
    state = env.reset()
    episode_length = 0
    reward_sum = 0
    while True:
        if args.render:
            if i_episode % args.render_freq == 0:
                env.render()
        if done:
Example #5
def test(rank, args, shared_model, env_conf):
    log = {}
    setup_logger('{}_log'.format(args.env_name),
                 r'{0}{1}_log'.format(args.log_dir, args.env_name))
    log['{}_log'.format(args.env_name)] = logging.getLogger('{}_log'.format(
        args.env_name))
    d_args = vars(args)
    for k in d_args.keys():
        log['{}_log'.format(args.env_name)].info('{0}: {1}'.format(
            k, d_args[k]))

    torch.manual_seed(args.seed)
    env = atari_env(args.env_name, env_conf)
    model = A3Clstm(env.observation_space.shape[0], env.action_space)
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state).float()
    reward_sum = 0
    done = True
    start_time = time.time()
    episode_length = 0
    num_tests = 0
    reward_total_sum = 0
    while True:
        episode_length += 1
        # Sync with the shared model
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = Variable(torch.zeros(1, 512), volatile=True)
            hx = Variable(torch.zeros(1, 512), volatile=True)
        else:
            cx = Variable(cx.data, volatile=True)
            hx = Variable(hx.data, volatile=True)

        value, logit, (hx, cx) = model((Variable(state.unsqueeze(0),
                                                 volatile=True), (hx, cx)))
        prob = F.softmax(logit)
        action = prob.max(1)[1].data.numpy()
        state, reward, done, _ = env.step(action[0, 0])
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        if done:
            num_tests += 1
            reward_total_sum += reward_sum
            reward_mean = reward_total_sum / num_tests
            log['{}_log'.format(args.env_name)].info(
                "Time {0}, episode reward {1}, episode length {2}, reward mean {3:.4f}"
                .format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - start_time)),
                    reward_sum, episode_length, reward_mean))
            if reward_sum > args.save_score_level:
                model.load_state_dict(shared_model.state_dict())
                state_to_save = model.state_dict()
                torch.save(
                    state_to_save, '{0}{1}.dat'.format(args.save_model_dir,
                                                       args.env_name))

            reward_sum = 0
            episode_length = 0
            state = env.reset()
            time.sleep(60)

        state = torch.from_numpy(state).float()