Exemplo n.º 1
0
def algo_init(args):
    """Build the actor network and its Adam optimizer.

    Observation/action shapes are taken from ``args["obs_shape"]`` and
    ``args["action_shape"]`` when both are set and truthy; otherwise they are
    looked up from ``args["task"]`` via ``get_env_shape`` and written back into
    ``args`` so later code sees the resolved shapes.

    Args:
        args: config mapping. Keys read here: ``seed``, ``obs_shape``,
            ``action_shape``, ``task`` (fallback only), ``actor_layers``,
            ``actor_features``, ``device`` and ``actor_lr``.

    Returns:
        ``{"actor": {"net": actor, "opt": actor_optim}}``.

    Raises:
        NotImplementedError: if neither both shapes nor a ``task`` are given.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    # ``.get`` lets a config that omits the shape keys fall through to the
    # ``task`` lookup instead of failing with KeyError. ``max_action`` is not
    # needed by this actor, so it is neither read nor queried from the env.
    if args.get("obs_shape") and args.get("action_shape"):
        obs_shape, action_shape = args["obs_shape"], args["action_shape"]
    elif "task" in args:
        from offlinerl.utils.env import get_env_shape
        obs_shape, action_shape = get_env_shape(args['task'])
        args["obs_shape"], args["action_shape"] = obs_shape, action_shape
    else:
        raise NotImplementedError(
            "algo_init needs either 'obs_shape' and 'action_shape' "
            "or a 'task' in args")

    net_a = Net(layer_num=args['actor_layers'],
                state_shape=obs_shape,
                hidden_layer_size=args['actor_features'])

    actor = TanhGaussianPolicy(preprocess_net=net_a,
                               action_shape=action_shape,
                               hidden_layer_size=args['actor_features'],
                               conditioned_sigma=True).to(args['device'])

    actor_optim = torch.optim.Adam(actor.parameters(), lr=args['actor_lr'])

    return {
        "actor": {"net": actor, "opt": actor_optim},
    }
Exemplo n.º 2
0
def algo_init(args):
    """Create the critic Q-network and its Adam optimizer.

    Shapes come from ``args["obs_shape"]`` / ``args["action_shape"]`` when both
    are truthy; otherwise they are resolved from ``args["task"]`` with
    ``get_env_shape`` and stored back into ``args``.

    An ``int`` ``obs_shape`` selects ``Conv_Q``; any other ``obs_shape`` selects
    ``FC_Q`` fed with ``np.prod(obs_shape)`` inputs.

    Args:
        args: config mapping; reads ``seed``, ``obs_shape``, ``action_shape``,
            ``task`` (fallback only), ``device`` and ``optimizer_parameters``
            (forwarded as keyword arguments to ``optim.Adam``).

    Returns:
        ``{"critic": {"net": critic, "opt": critic_opt}}``.

    Raises:
        NotImplementedError: if the shapes are not both set and there is no
            ``task`` to derive them from.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    # Guard clause: fill in missing shapes from the task, or give up.
    if not (args["obs_shape"] and args["action_shape"]):
        if "task" not in args:
            raise NotImplementedError
        from offlinerl.utils.env import get_env_shape
        args["obs_shape"], args["action_shape"] = get_env_shape(args['task'])

    n_actions = args["action_shape"]
    if isinstance(args["obs_shape"], int):
        # Hard-coded (4, 84, 84) input; Conv_Q only receives its first
        # dimension. NOTE(review): the integer obs_shape value itself is unused.
        input_dims = (4, 84, 84)
        q_net = Conv_Q(input_dims[0], n_actions)
    else:
        q_net = FC_Q(np.prod(args["obs_shape"]), n_actions)
    critic = q_net.to(args['device'])

    critic_opt = optim.Adam(critic.parameters(),
                            **args["optimizer_parameters"])

    return {"critic": {"net": critic, "opt": critic_opt}}
Exemplo n.º 3
0
def algo_init(args):
    """Build the VAE, the Jitter perturbation model and a twin-Q critic.

    Shapes and ``max_action`` come from ``args`` when ``obs_shape`` and
    ``action_shape`` are both truthy; otherwise they are derived from
    ``args["task"]`` (and the shapes are written back into ``args``).

    Args:
        args: config mapping; reads ``seed``, ``obs_shape``, ``action_shape``,
            ``max_action`` / ``task``, ``vae_features``, ``vae_layers``,
            ``vae_lr``, ``jitter_features``, ``jitter_layers``, ``phi``,
            ``jitter_lr``, ``value_features``, ``value_layers``,
            ``critic_lr`` and ``device``.

    Returns:
        Dict with ``"vae"``, ``"jitter"`` and ``"critic"`` entries, each a
        ``{"net": ..., "opt": ...}`` pair; the critic ``net`` is ``[q1, q2]``
        sharing one optimizer.

    Raises:
        NotImplementedError: if shapes are missing and no ``task`` is given.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    explicit_shapes = bool(args["obs_shape"] and args["action_shape"])
    if not explicit_shapes and "task" not in args:
        raise NotImplementedError

    if explicit_shapes:
        obs_shape, action_shape = args["obs_shape"], args["action_shape"]
        max_action = args["max_action"]
    else:
        from offlinerl.utils.env import get_env_shape, get_env_action_range
        obs_shape, action_shape = get_env_shape(args['task'])
        max_action, _ = get_env_action_range(args["task"])
        args["obs_shape"], args["action_shape"] = obs_shape, action_shape

    vae = VAE(obs_shape, action_shape, args['vae_features'],
              args['vae_layers'], max_action).to(args['device'])
    vae_optim = torch.optim.Adam(vae.parameters(), lr=args['vae_lr'])

    jitter = Jitter(obs_shape, action_shape, args['jitter_features'],
                    args['jitter_layers'], max_action,
                    args['phi']).to(args['device'])
    jitter_optim = torch.optim.Adam(jitter.parameters(), lr=args['jitter_lr'])

    def make_q_net():
        # Q(s, a) -> scalar; input is the concatenated obs/action size.
        return MLP(obs_shape + action_shape,
                   1,
                   args['value_features'],
                   args['value_layers'],
                   hidden_activation='relu').to(args['device'])

    # Built one after the other (left to right), same order as q1 then q2.
    q1, q2 = make_q_net(), make_q_net()
    critic_optim = torch.optim.Adam([*q1.parameters(), *q2.parameters()],
                                    lr=args['critic_lr'])

    return {
        "vae": {"net": vae, "opt": vae_optim},
        "jitter": {"net": jitter, "opt": jitter_optim},
        "critic": {"net": [q1, q2], "opt": critic_optim},
    }
Exemplo n.º 4
0
def algo_init(args):
    """Build the ensemble transition model, actor, entropy temperature and critics.

    Shapes come from ``args`` when both ``obs_shape`` and ``action_shape`` are
    truthy; otherwise they are derived from ``args["task"]`` and written back.

    Args:
        args: config mapping; reads ``seed``, ``obs_shape``, ``action_shape``,
            ``task`` (fallback only), ``hidden_layer_size``,
            ``transition_layers``, ``transition_init_num``, ``transition_lr``,
            ``hidden_layers``, ``actor_lr`` and ``device``.

    Returns:
        Dict with ``"transition"``, ``"actor"``, ``"log_alpha"`` and
        ``"critic"`` entries (``{"net": ..., "opt": ...}``); the critic net is
        ``[q1, q2]`` under a single optimizer.

    Raises:
        NotImplementedError: if shapes are missing and no ``task`` is given.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    explicit_shapes = bool(args["obs_shape"] and args["action_shape"])
    if not explicit_shapes and "task" not in args:
        raise NotImplementedError

    if explicit_shapes:
        obs_shape, action_shape = args["obs_shape"], args["action_shape"]
    else:
        from offlinerl.utils.env import get_env_shape
        obs_shape, action_shape = get_env_shape(args['task'])
        args["obs_shape"], args["action_shape"] = obs_shape, action_shape

    width = args['hidden_layer_size']
    depth = args['hidden_layers']

    transition = EnsembleTransition(
        obs_shape, action_shape, width,
        args['transition_layers'], args['transition_init_num'],
    ).to(args['device'])
    # Fixed weight decay on the transition model only.
    transition_optim = torch.optim.Adam(transition.parameters(),
                                        lr=args['transition_lr'],
                                        weight_decay=0.000075)

    policy_body = Net(layer_num=depth,
                      state_shape=obs_shape,
                      hidden_layer_size=width)
    actor = TanhGaussianPolicy(preprocess_net=policy_body,
                               action_shape=action_shape,
                               hidden_layer_size=width,
                               conditioned_sigma=True).to(args['device'])
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args['actor_lr'])

    # Learnable log-temperature, initialised to zero.
    log_alpha = torch.zeros(1, requires_grad=True, device=args['device'])
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=args["actor_lr"])

    def make_q_net():
        return MLP(obs_shape + action_shape, 1, width, depth,
                   norm=None, hidden_activation='swish').to(args['device'])

    q1, q2 = make_q_net(), make_q_net()
    # NOTE(review): the critics reuse ``actor_lr``; confirm this is intended.
    critic_optim = torch.optim.Adam([*q1.parameters(), *q2.parameters()],
                                    lr=args['actor_lr'])

    return {
        "transition": {"net": transition, "opt": transition_optim},
        "actor": {"net": actor, "opt": actor_optim},
        "log_alpha": {"net": log_alpha, "opt": alpha_optimizer},
        "critic": {"net": [q1, q2], "opt": critic_optim},
    }
Exemplo n.º 5
0
def algo_init(args):
    """Build actor, twin critics and (optionally) entropy / Lagrange multipliers.

    Shapes come from ``args`` when both ``obs_shape`` and ``action_shape`` are
    truthy; otherwise they are derived from ``args["task"]`` and written back.

    Args:
        args: config mapping; reads ``seed``, ``obs_shape``, ``action_shape``,
            ``task`` (fallback only), ``layer_num``, ``hidden_layer_size``,
            ``actor_lr``, ``critic_lr``, ``use_automatic_entropy_tuning``,
            ``target_entropy`` (only when tuning), ``lagrange_thresh`` and
            ``device``.

    Returns:
        Dict with ``"actor"``, ``"critic1"`` and ``"critic2"`` entries
        (``{"net": ..., "opt": ...}``), plus:

        * ``"log_alpha"`` (with ``"target_entropy"``) only when
          ``use_automatic_entropy_tuning`` is truthy;
        * ``"log_alpha_prime"`` only when ``lagrange_thresh >= 0``.

    Raises:
        NotImplementedError: if shapes are missing and no ``task`` is given.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    # ``.get`` lets configs without the shape keys fall back to ``task``.
    if args.get("obs_shape") and args.get("action_shape"):
        obs_shape, action_shape = args["obs_shape"], args["action_shape"]
    elif "task" in args:
        from offlinerl.utils.env import get_env_shape
        obs_shape, action_shape = get_env_shape(args['task'])
        args["obs_shape"], args["action_shape"] = obs_shape, action_shape
    else:
        raise NotImplementedError(
            "algo_init needs either 'obs_shape' and 'action_shape' "
            "or a 'task' in args")

    net_a = Net(layer_num=args['layer_num'],
                state_shape=obs_shape,
                hidden_layer_size=args['hidden_layer_size'])

    actor = TanhGaussianPolicy(
        preprocess_net=net_a,
        action_shape=action_shape,
        hidden_layer_size=args['hidden_layer_size'],
        conditioned_sigma=True,
    ).to(args['device'])

    actor_optim = optim.Adam(actor.parameters(), lr=args['actor_lr'])

    net_c1 = Net(layer_num=args['layer_num'],
                 state_shape=obs_shape,
                 action_shape=action_shape,
                 concat=True,
                 hidden_layer_size=args['hidden_layer_size'])
    critic1 = Critic(
        preprocess_net=net_c1,
        hidden_layer_size=args['hidden_layer_size'],
    ).to(args['device'])
    critic1_optim = optim.Adam(critic1.parameters(), lr=args['critic_lr'])

    net_c2 = Net(layer_num=args['layer_num'],
                 state_shape=obs_shape,
                 action_shape=action_shape,
                 concat=True,
                 hidden_layer_size=args['hidden_layer_size'])
    critic2 = Critic(
        preprocess_net=net_c2,
        hidden_layer_size=args['hidden_layer_size'],
    ).to(args['device'])
    critic2_optim = optim.Adam(critic2.parameters(), lr=args['critic_lr'])

    nets = {
        "actor": {
            "net": actor,
            "opt": actor_optim
        },
        "critic1": {
            "net": critic1,
            "opt": critic1_optim
        },
        "critic2": {
            "net": critic2,
            "opt": critic2_optim
        },
    }

    # The temperature variables only exist when automatic tuning is on, so the
    # "log_alpha" entry is added inside this branch; building it
    # unconditionally would reference unbound names when tuning is off.
    if args["use_automatic_entropy_tuning"]:
        # NOTE(review): truthiness check means an explicit target_entropy of 0
        # falls back to -prod(action_shape); confirm that is intended.
        if args["target_entropy"]:
            target_entropy = args["target_entropy"]
        else:
            target_entropy = -np.prod(args["action_shape"]).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=args['device'])
        alpha_optimizer = optim.Adam(
            [log_alpha],
            lr=args["actor_lr"],
        )
        nets["log_alpha"] = {
            "net": log_alpha,
            "opt": alpha_optimizer,
            "target_entropy": target_entropy
        }

    # A non-negative threshold enables the learnable Lagrange multiplier.
    if args["lagrange_thresh"] >= 0:
        log_alpha_prime = torch.zeros(1,
                                      requires_grad=True,
                                      device=args['device'])
        alpha_prime_optimizer = optim.Adam(
            [log_alpha_prime],
            lr=args["critic_lr"],
        )
        nets["log_alpha_prime"] = {
            "net": log_alpha_prime,
            "opt": alpha_prime_optimizer
        }

    return nets
Exemplo n.º 6
0
def algo_init(args):
    """Build the VAE, the actor (latent-perturbation or plain) and twin critics.

    Shapes and ``max_action`` come from ``args`` when both ``obs_shape`` and
    ``action_shape`` are truthy; otherwise they are derived from
    ``args["task"]`` (shapes are written back into ``args``). The VAE latent
    size is twice the action size.

    Args:
        args: config mapping; reads ``seed``, ``obs_shape``, ``action_shape``,
            ``max_action`` / ``task``, ``vae_hidden_size``, ``vae_lr``,
            ``latent``, ``phi`` (latent branch), ``layer_num``,
            ``hidden_layer_size``, ``actor_lr``, ``critic_lr`` and ``device``.

    Returns:
        Dict with ``"vae"``, ``"actor"``, ``"critic1"`` and ``"critic2"``
        entries, each a ``{"net": ..., "opt": ...}`` pair.

    Raises:
        NotImplementedError: if shapes are missing and no ``task`` is given.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    explicit_shapes = bool(args["obs_shape"] and args["action_shape"])
    if not explicit_shapes and "task" not in args:
        raise NotImplementedError

    if explicit_shapes:
        obs_shape, action_shape = args["obs_shape"], args["action_shape"]
        max_action = args["max_action"]
    else:
        from offlinerl.utils.env import get_env_shape, get_env_action_range
        obs_shape, action_shape = get_env_shape(args['task'])
        max_action, _ = get_env_action_range(args["task"])
        args["obs_shape"], args["action_shape"] = obs_shape, action_shape

    latent_dim = action_shape * 2
    vae = VAE(state_dim=obs_shape,
              action_dim=action_shape,
              latent_dim=latent_dim,
              max_action=max_action,
              hidden_size=args["vae_hidden_size"]).to(args['device'])
    vae_opt = optim.Adam(vae.parameters(), lr=args["vae_lr"])

    if not args["latent"]:
        # Plain actor whose output size is latent_dim.
        actor_body = Net(layer_num=args["layer_num"],
                         state_shape=obs_shape,
                         hidden_layer_size=args["hidden_layer_size"])
        actor = Actor(preprocess_net=actor_body,
                      action_shape=latent_dim,
                      max_action=max_action,
                      hidden_layer_size=args["hidden_layer_size"]).to(args['device'])
    else:
        actor = ActorPerturbation(obs_shape,
                                  action_shape,
                                  latent_dim,
                                  max_action,
                                  max_latent_action=2,
                                  phi=args['phi']).to(args['device'])
    actor_opt = optim.Adam(actor.parameters(), lr=args["actor_lr"])

    def build_critic():
        # Q(s, a) critic over the concatenated observation/action input,
        # returned together with its own optimizer.
        body = Net(layer_num=args['layer_num'],
                   state_shape=obs_shape,
                   action_shape=action_shape,
                   concat=True,
                   hidden_layer_size=args['hidden_layer_size'])
        q = Critic(preprocess_net=body,
                   hidden_layer_size=args['hidden_layer_size']).to(args['device'])
        return q, optim.Adam(q.parameters(), lr=args['critic_lr'])

    critic1, critic1_opt = build_critic()
    critic2, critic2_opt = build_critic()

    return {
        "vae": {"net": vae, "opt": vae_opt},
        "actor": {"net": actor, "opt": actor_opt},
        "critic1": {"net": critic1, "opt": critic1_opt},
        "critic2": {"net": critic2, "opt": critic2_opt},
    }