Example #1
import sys
sys.path.insert(0, './')
from rlf import PPO
from rlf import run_policy
from tests.test_run_settings import TestRunSettings
from rlf.policies.actor_critic.dist_actor_critic import DistActorCritic


class PPORunSettings(TestRunSettings):
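    """Minimal PPO run: a distributional actor-critic policy updated with PPO."""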
    def get_policy(self):
        return DistActorCritic()

    def get_algo(self):
        return PPO()


if __name__ == "__main__":
    run_policy(PPORunSettings())
Example #2
        # Snippet truncated above: `args` and `run_settings` are defined by the
        # enclosing script. Choose a prefix for the saved visualizations.
        save_prefix = args.load_embeddings_file if args.load_embeddings_file is not None else ''
        if args.gt_embs:
            save_prefix = save_prefix.split('_')[0] + '_engineered'

        # Visualize embeddings three times: over all actions, then restricted
        # to the train and test action sets.
        vis_embs(args.dist_mem,
                 args.emb_mem,
                 args.num_distributions,
                 args.exp_type,
                 True,
                 save_prefix=save_prefix,
                 args=args,
                 use_idx=None)
        vis_embs(args.dist_mem,
                 args.emb_mem,
                 args.num_distributions,
                 args.exp_type,
                 True,
                 save_prefix=save_prefix + '_train',
                 args=args,
                 use_idx=args.env_interface.train_action_set)
        vis_embs(args.dist_mem,
                 args.emb_mem,
                 args.num_distributions,
                 args.exp_type,
                 True,
                 save_prefix=save_prefix + '_test',
                 args=args,
                 use_idx=args.env_interface.test_action_set)
    else:
        run_policy(run_settings)
Example #3
# NOTE: this snippet's imports and class header were truncated. The header is
# reconstructed from Example #8, which builds the same RegActorCritic;
# get_actor_head and the missing imports are assumed to match that example.
class HerRunSettings(TestRunSettings):
    def get_policy(self):
        hidden_size = 256  # assumed; the truncated code set this value
        return RegActorCritic(
                get_actor_fn=lambda _, i_shape: MLPBase(
                    i_shape[0],
                    False, (hidden_size, hidden_size),
                    weight_init=reg_init,
                    get_activation=lambda: nn.ReLU()),
                get_actor_head_fn=get_actor_head,
                get_critic_fn=lambda _, i_shape, a_space:
                TwoLayerMlpWithAction(i_shape[0], (hidden_size, hidden_size),
                                      a_space.shape[0],
                                      weight_init=reg_init,
                                      get_activation=lambda: nn.ReLU()),
                get_critic_head_fn=lambda hidden_dim: nn.Linear(hidden_dim, 1),
                use_goal=True)

    def get_algo(self):
        pass_kwargs = {}
        if self.base_args.use_her:
            pass_kwargs['create_storage_buff_fn'] = create_her_storage_buff
        if 'BitFlip' in self.base_args.env_name:
            return QLearning(**pass_kwargs)
        else:
            return DDPG(**pass_kwargs)

    def get_add_args(self, parser):
        super().get_add_args(parser)
        parser.add_argument('--use-her', default=True, type=str2bool)


if __name__ == "__main__":
    run_policy(HerRunSettings())
Example #4
import sys
sys.path.insert(0, './')

from rlf.algos.hier.option_critic import OptionCritic
from rlf import run_policy
from tests.test_run_settings import TestRunSettings
from rlf.policies.options_policy import OptionsPolicy


class OptionCriticRunSettings(TestRunSettings):
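    """Option-critic: an OptionsPolicy trained with the OptionCritic updater."""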
    def get_policy(self):
        return OptionsPolicy()

    def get_algo(self):
        return OptionCritic()


if __name__ == "__main__":
    run_policy(OptionCriticRunSettings())
Example #5
import sys
sys.path.insert(0, './')
from rlf import QLearning
from rlf import DQN
from tests.test_run_settings import TestRunSettings
from rlf import run_policy
import rlf.envs.neuron_poker


class DqnRunSettings(TestRunSettings):
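    """DQN value network trained with Q-learning.

    The otherwise-unused rlf.envs.neuron_poker import registers the poker
    environments as a side effect.
    """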
    def get_policy(self):
        return DQN()

    def get_algo(self):
        return QLearning()

    def get_add_args(self, parser):
        super().get_add_args(parser)


if __name__ == "__main__":
    run_policy(DqnRunSettings())
Example #6
from functools import partial

from rlf import run_policy
# SAC was not imported in the original snippet; it is assumed to be exported
# from the package root like PPO and GAIL in the other examples.
from rlf import SAC
from rlf.policies.actor_critic.dist_actor_q import DistActorQ, get_sac_actor, get_sac_critic
from tests.test_run_settings import TestRunSettings


class SACRunSettings(TestRunSettings):
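    """SAC with actor and critic built by the rlf SAC helpers; width set by --hidden-dim."""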
    def get_policy(self):
        return DistActorQ(
            get_critic_fn=partial(get_sac_critic,
                                  hidden_dim=self.base_args.hidden_dim),
            get_actor_fn=partial(get_sac_actor,
                                 hidden_dim=self.base_args.hidden_dim))

    def get_algo(self):
        return SAC()

    def get_add_args(self, parser):
        super().get_add_args(parser)
        parser.add_argument('--hidden-dim', type=int, default=1024)


if __name__ == "__main__":
    run_policy(SACRunSettings())
Example #7
import sys
sys.path.insert(0, './')
from rlf import run_policy
from tests.test_run_settings import TestRunSettings
from rlf.policies.actor_critic.dist_actor_q import DistActorQ, get_sac_actor, get_sac_critic
from rlf.algos.il.sqil import SQIL
from functools import partial


class SqilRunSettings(TestRunSettings):
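    """SQIL: imitation via soft Q-learning, reusing the SAC actor/critic (hidden dim 256)."""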
    def get_policy(self):
        return DistActorQ(get_critic_fn=partial(get_sac_critic,
                                                hidden_dim=256),
                          get_actor_fn=partial(get_sac_actor, hidden_dim=256))

    def get_algo(self):
        return SQIL()


if __name__ == "__main__":
    run_policy(SqilRunSettings())
Example #8
# NOTE: the original snippet omitted its imports. Known paths are taken from
# the other examples; the homes of RegActorCritic, TwoLayerMlpWithAction,
# reg_init, and DDPG are inferred and may need adjusting.
import torch.nn as nn

from rlf import DDPG, run_policy
from rlf.policies.actor_critic.reg_actor_critic import RegActorCritic
from rlf.rl.model import MLPBase, TwoLayerMlpWithAction, reg_init
from tests.test_run_settings import TestRunSettings


def get_actor_head(hidden_dim, action_dim):
    return nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh())

class DDPGRunSettings(TestRunSettings):
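    """DDPG with a Tanh actor head; Pendulum gets smaller (128-unit) hidden layers."""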
    def get_policy(self):
        if 'Pendulum' in self.base_args.env_name:
            hidden_size = 128
        else:
            hidden_size = 256
        return RegActorCritic(
                get_actor_fn=lambda _, i_shape: MLPBase(
                    i_shape[0], False, (hidden_size, hidden_size),
                    weight_init=reg_init,
                    get_activation=lambda: nn.ReLU()),
                get_actor_head_fn=get_actor_head,
                get_critic_fn=lambda _, i_shape, a_space: TwoLayerMlpWithAction(
                    i_shape[0], (hidden_size, hidden_size), a_space.shape[0],
                    weight_init=reg_init,
                    get_activation=lambda: nn.ReLU()),
                get_critic_head_fn=lambda hidden_dim: nn.Linear(hidden_dim, 1))

    def get_algo(self):
        return DDPG()

if __name__ == "__main__":
    run_policy(DDPGRunSettings())
Example #9
import sys
sys.path.insert(0, './')
from rlf import run_policy
from rlf.policies.tabular.action_value_policy import ActionValuePolicy
from rlf.algos.tabular.bandit_algos import SimpleBanditAlgo
from tests.test_run_settings import TestRunSettings
import gym_bandits


class BanditRunSettings(TestRunSettings):
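    """Multi-armed bandit baseline: an action-value policy with a simple bandit
    update; the gym_bandits import registers the bandit environments."""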
    def get_policy(self):
        return ActionValuePolicy()

    def get_algo(self):
        return SimpleBanditAlgo()


if __name__ == "__main__":
    run_policy(BanditRunSettings())
Example #10
# NOTE: the original snippet omitted several imports. Known paths are taken
# from the other examples; QTable, TabularTdMethods, and TabularMcMethods live
# in rlf's tabular modules, whose exact import lines were truncated here.
from rlf import run_policy
from rlf.rl.loggers.plt_logger import PltLogger
from rlf.rl.loggers.wb_logger import WbLogger
from tests.test_run_settings import TestRunSettings


class ClassicAlgRunSettings(TestRunSettings):
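    """Tabular Q-table trained with TD or Monte Carlo updates, chosen via --alg-type."""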
    def get_policy(self):
        return QTable()

    def get_algo(self):
        if self.base_args.alg_type == 'td':
            return TabularTdMethods()
        elif self.base_args.alg_type == 'mc':
            return TabularMcMethods()
        else:
            raise ValueError(f"Unrecognized option {self.base_args.alg_type}")

    def get_add_args(self, parser):
        super().get_add_args(parser)
        parser.add_argument('--alg-type', type=str, default='td')
        parser.add_argument('--wb', action='store_true', default=False)

    def get_logger(self):
        if self.base_args.wb:
            return WbLogger()
        else:
            return PltLogger(['avg_r'], '# Updates', ['Reward'],
                             ['Cliff Walker'])


if __name__ == "__main__":
    run_policy(ClassicAlgRunSettings())
Example #11
import sys
sys.path.insert(0, './')
from rlf import run_policy
from tests.test_run_settings import TestRunSettings
from rlf.policies.actor_critic.dist_actor_critic import DistActorCritic
from rlf.rl.model import MLPBase
import torch.nn as nn
from rlf import GAIL


def get_discrim():
    # Returns the discriminator network together with its input size (400).
    return nn.Sequential(nn.Linear(400, 300), nn.Tanh(), nn.Linear(300, 1)), 400


class GaifoSRunSettings(TestRunSettings):
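    """GAIL with the custom 400-300 Tanh discriminator above; actor and critic
    are (400, 300) MLPs."""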
    def get_policy(self):
        return DistActorCritic(
            get_actor_fn=lambda _, i_shape: MLPBase(i_shape[0], False,
                                                    (400, 300)),
            get_critic_fn=lambda _, i_shape, a_shape: MLPBase(
                i_shape[0], False, (400, 300)))

    def get_algo(self):
        return GAIL(get_discrim=get_discrim)


if __name__ == "__main__":
    run_policy(GaifoSRunSettings())
Example #12
import sys
sys.path.insert(0, './')
from rlf import run_policy
from rlf.policies.tabular.tabular_policy import TabularPolicy
from rlf.algos.tabular.policy_iteration import PolicyIteration
from rlf.algos.tabular.value_iteration import ValueIteration
from tests.test_run_settings import TestRunSettings
from rlf.args import str2bool
from rlf.rl.loggers.plt_logger import PltLogger


class PolicyIterRunSettings(TestRunSettings):
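    """Tabular dynamic programming: --value-iter switches from policy iteration
    to value iteration."""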
    def get_policy(self):
        return TabularPolicy()

    def get_algo(self):
        if self.base_args.value_iter:
            return ValueIteration()
        else:
            return PolicyIteration()

    def get_add_args(self, parser):
        super().get_add_args(parser)
        parser.add_argument('--value-iter', default=False, type=str2bool)

    def get_logger(self):
        return PltLogger(['eval_train_r'], '# Updates', ['Reward'],
                         ['Frozen Lake'])


if __name__ == "__main__":
    run_policy(PolicyIterRunSettings())
Example #13
# NOTE: this snippet's imports and the start of DefaultRunSettings.get_algo
# were truncated. The class header and the 'ddpg' branch are reconstructed;
# the base class is assumed, and the missing imports are not shown.
class DefaultRunSettings(TestRunSettings):
    def get_algo(self):
        alg = self.base_args.alg
        if alg == 'ddpg':
            return DDPG()
        elif alg == 'bc':
            return BehavioralCloning()
        elif alg == 'gail':
            return GAIL()
        elif alg == 'gaifo':
            return GAIFO()
        elif alg == 'rnd':
            return BaseAlgo()
        elif alg == 'bco':
            return BehavioralCloningFromObs()
        else:
            raise ValueError('Unrecognized alg for optimizer')

    def get_logger(self):
        if self.base_args.no_wb:
            return BaseLogger()
        else:
            return WbLogger()

    def get_config_file(self):
        return './tests/config.yaml'

    def get_add_args(self, parser):
        parser.add_argument('--alg')
        parser.add_argument('--no-wb', default=False, action='store_true')
        parser.add_argument('--env-name')


if __name__ == "__main__":
    run_policy(DefaultRunSettings())
Example #14
import sys
sys.path.insert(0, './')
from rlf import BehavioralCloningFromObs
from rlf import BasicPolicy
from rlf import run_policy
from tests.test_run_settings import TestRunSettings
from rlf.rl.model import MLPBase


class BcoRunSettings(TestRunSettings):
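    """Behavioral cloning from observations with a (400, 300) MLP policy,
    optionally stochastic via --stoch-policy."""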
    def get_policy(self):
        return BasicPolicy(is_stoch=self.base_args.stoch_policy,
                           get_base_net_fn=lambda i_shape, recurrent: MLPBase(
                               i_shape[0], False, (400, 300)))

    def get_algo(self):
        return BehavioralCloningFromObs()

    def get_add_args(self, parser):
        super().get_add_args(parser)
        parser.add_argument('--stoch-policy',
                            default=False,
                            action='store_true')


if __name__ == "__main__":
    run_policy(BcoRunSettings())
Example #15
import sys
sys.path.insert(0, './')
from rlf import run_policy
from rlf.algos.off_policy.soft_qlearning import SoftQLearning
from rlf.policies.svgd_policy import SVGDPolicy
from tests.test_run_settings import TestRunSettings


class SoftQLearningRunSettings(TestRunSettings):
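    """Soft Q-learning with an SVGD-based sampling policy."""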
    def get_policy(self):
        return SVGDPolicy()

    def get_algo(self):
        return SoftQLearning()


if __name__ == "__main__":
    run_policy(SoftQLearningRunSettings())
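
All fifteen examples follow the same pattern: subclass TestRunSettings, override get_policy and get_algo, optionally register extra flags in get_add_args or swap the logger in get_logger, then hand the settings object to run_policy. A minimal sketch of that skeleton, assuming the rlf imports behave as in Example #1 (the class name and --my-flag below are hypothetical):

import sys
sys.path.insert(0, './')
from rlf import PPO, run_policy
from rlf.policies.actor_critic.dist_actor_critic import DistActorCritic
from tests.test_run_settings import TestRunSettings


class MyRunSettings(TestRunSettings):
    """Hypothetical experiment combining an existing policy and algorithm."""

    def get_policy(self):
        # Any policy class from rlf.policies works here.
        return DistActorCritic()

    def get_algo(self):
        return PPO()

    def get_add_args(self, parser):
        # Call super() so the base settings' flags stay registered.
        super().get_add_args(parser)
        parser.add_argument('--my-flag', default=False, action='store_true')


if __name__ == "__main__":
    run_policy(MyRunSettings())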