Example #1
0
def perturb_model(args, model, random_seed):
    """
    Build two perturbed copies of ``model`` for antithetic sampling.

    A Gaussian noise tensor ``eps`` is drawn per parameter (seeded with
    ``random_seed`` so workers can reproduce it); one copy receives
    ``+args.sigma * eps`` and the other ``-args.sigma * eps``.  The input
    ``model`` itself is left untouched.

    Returns a two-element list ``[perturbed_model, antithetic_model]``.
    """
    perturbed = ES(args.small_net)
    mirrored = ES(args.small_net)
    snapshot = model.state_dict()
    perturbed.load_state_dict(snapshot)
    mirrored.load_state_dict(snapshot)
    # Seed so the same random_seed regenerates the identical perturbation.
    np.random.seed(random_seed)
    for (_, param), (_, anti_param) in zip(perturbed.es_params(),
                                           mirrored.es_params()):
        noise = np.random.normal(0, 1, param.size())
        shift = torch.from_numpy(args.sigma * noise).float()
        param += shift
        anti_param -= shift
    return [perturbed, mirrored]
Example #2
0
                    help='Silence print statements during training')
# Evaluation-only mode: render rollouts with the current policy, no training.
parser.add_argument('--test',
                    help='Just render the env, no training',
                    action='store_true')

if __name__ == '__main__':
    args = parser.parse_args()
    # Perturbations are evaluated in antithetic (+eps / -eps) pairs, so the
    # population size must be even.  An explicit check survives `python -O`,
    # unlike `assert`.
    if args.n % 2 != 0:
        raise ValueError('--n must be an even number')
    # The small MLP only supports the classic-control envs listed below;
    # fall back to CartPole-v1 for anything else.
    if args.small_net and args.env_name not in [
            'CartPole-v0', 'CartPole-v1', 'MountainCar-v0'
    ]:
        args.env_name = 'CartPole-v1'
        print('Switching env to CartPole')

    env = create_atari_env(args.env_name)
    chkpt_dir = 'checkpoints/%s/' % args.env_name
    # exist_ok avoids the exists()/makedirs() race of the check-then-create
    # idiom.
    os.makedirs(chkpt_dir, exist_ok=True)
    synced_model = ES(env.observation_space.shape[0], env.action_space,
                      args.small_net)
    # Evolution strategies never backpropagate; disable autograd bookkeeping
    # on the shared parameters.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        state_dict = torch.load(args.restore)
        synced_model.load_state_dict(state_dict)

    if args.test:
        render_env(args, synced_model, env)
    else:
        train_loop(args, synced_model, env, chkpt_dir)
Example #3
0
File: main.py  Project: lisun97/pytorch-es
                    action='store_true',
                    help='Use simple MLP on CartPole')
# Boolean training/evaluation flags for the ES driver.
parser.add_argument('--variable-ep-len',
                    help="Change max episode length during training",
                    action='store_true')
parser.add_argument('--silent',
                    help='Silence print statements during training',
                    action='store_true')
parser.add_argument('--test',
                    help='Just render the env, no training',
                    action='store_true')

if __name__ == '__main__':
    args = parser.parse_args()
    # Perturbations are evaluated in antithetic (+eps / -eps) pairs, so the
    # population size must be even.  An explicit check survives `python -O`,
    # unlike `assert`.
    if args.n % 2 != 0:
        raise ValueError('--n must be an even number')

    chkpt_dir = 'checkpoints/%s/' % args.env_name
    # exist_ok avoids the exists()/makedirs() race of check-then-create.
    os.makedirs(chkpt_dir, exist_ok=True)
    synced_model = ES(args.small_net)
    # Evolution strategies never backpropagate; disable autograd bookkeeping
    # on the shared parameters.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        synced_model.load_state_dict(torch.load(args.restore))

    if args.test:
        render_env(args, synced_model)
    else:
        train_loop(args, synced_model, chkpt_dir)
Example #4
0
                    help='Silence print statements during training')
# Evaluation flag plus a hard cap on the number of ES updates.
parser.add_argument('--test',
                    help='Just render the env, no training',
                    action='store_true')
parser.add_argument('--max-gradient-updates',
                    help='maximum number of updates',
                    metavar='MGU',
                    type=int,
                    default=100000)

if __name__ == '__main__':
    args = parser.parse_args()
    # Perturbations are evaluated in antithetic (+eps / -eps) pairs, so the
    # population size must be even.  An explicit check survives `python -O`,
    # unlike `assert`.
    if args.n % 2 != 0:
        raise ValueError('--n must be an even number')

    chkpt_dir = 'checkpoints/'
    # exist_ok avoids the exists()/makedirs() race of check-then-create.
    os.makedirs(chkpt_dir, exist_ok=True)

    env = TicTacToeEnv()
    synced_model = ES(env.observation_space, env.action_space)
    # Evolution strategies never backpropagate; disable autograd bookkeeping
    # on the shared parameters.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        synced_model.load_state_dict(torch.load(args.restore))

    if args.test:
        render_env(synced_model)
    else:
        train_loop(args, synced_model, chkpt_dir)
Example #5
0
def perturb_model(args, model, random_seed, env):
    """
    Build two perturbed copies of ``model`` for antithetic sampling.

    A single flat Gaussian noise vector ``eps`` (scaled by ``args.sigma``
    and seeded with ``random_seed`` so a worker can reconstruct the exact
    perturbation from the seed alone) is added to one copy and subtracted
    from the other.  The input ``model`` is left untouched.

    Returns a two-element list ``[new_model, anti_model]``.
    """
    new_model = ES(env.observation_space, env.action_space,
                   use_a3c_net=args.a3c_net,
                   use_virtual_batch_norm=args.virtual_batch_norm)
    anti_model = ES(env.observation_space, env.action_space,
                    use_a3c_net=args.a3c_net,
                    use_virtual_batch_norm=args.virtual_batch_norm)
    new_model.load_state_dict(model.state_dict())
    anti_model.load_state_dict(model.state_dict())
    np.random.seed(random_seed)
    # One flat noise vector covering every trainable parameter; the model is
    # responsible for scattering it across its tensors.
    eps = args.sigma * np.random.normal(0.0, 1.0, size=model.count_parameters())
    new_model.adjust_es_params(add=eps)
    anti_model.adjust_es_params(add=-eps)
    return [new_model, anti_model]
Example #6
0
    # Build the (possibly frame-stacked, no-op-initialized, resized) Atari
    # environment from the CLI options.
    env = create_atari_env(args.env_name,
                           frame_stack_size=args.stack_images,
                           noop_init=args.noop_init,
                           image_dim=args.image_dim)

    # set checkpoint directory (an explicit --checkpoint-dir flag wins over
    # the per-environment default)
    if args.checkpoint_dir:
        chkpt_dir = args.checkpoint_dir
    else:
        chkpt_dir = 'checkpoints/%s/' % args.env_name
    if not os.path.exists(chkpt_dir):
        os.makedirs(chkpt_dir)

    # instantiate model (and restore if needed)
    synced_model = ES(env.observation_space,
                      env.action_space,
                      use_a3c_net=args.a3c_net,
                      use_virtual_batch_norm=args.virtual_batch_norm)
    # Evolution strategies never backpropagate, so autograd bookkeeping is
    # disabled on the shared parameters.
    for param in synced_model.parameters():
        param.requires_grad = False
    if args.restore:
        state_dict = torch.load(args.restore)
        synced_model.load_state_dict(state_dict)

    # compute batch for virtual batch normalization
    # NOTE(review): args.virtual_batch_norm doubles as both the enable flag
    # and the batch size below — confirm it is an int option, not a bool.
    if args.virtual_batch_norm and not args.test:
        # print('Computing batch for virtual batch normalization')
        virtual_batch = gather_for_virtual_batch_norm(
            env, batch_size=args.virtual_batch_norm)
        virtual_batch = torchify(virtual_batch, unsqueeze=False)
    else:
        virtual_batch = None