Exemplo n.º 1
0
    def __init__(self):
        """Create the model and network factories and register their builders.

        Every model builder takes the network as its first argument and
        discards any additional keyword arguments. Every network builder
        discards its keyword arguments and returns a new builder instance.
        """
        self.model_factory = object_factory.ObjectFactory()
        # The lambdas look up their target class only when they are called,
        # so nothing from ``models`` is resolved at registration time.
        model_entries = (
            ('discrete_a2c',
             lambda network, **kwargs: models.ModelA2C(network)),
            ('multi_discrete_a2c',
             lambda network, **kwargs: models.ModelA2CMultiDiscrete(network)),
            ('continuous_a2c',
             lambda network, **kwargs: models.ModelA2CContinuous(network)),
            ('continuous_a2c_logstd',
             lambda network, **kwargs: models.ModelA2CContinuousLogStd(network)),
            ('soft_actor_critic',
             lambda network, **kwargs: models.ModelSACContinuous(network)),
        )
        for model_name, make_model in model_entries:
            self.model_factory.register_builder(model_name, make_model)
        # The 'dqn' model registration is currently disabled:
        #self.model_factory.register_builder('dqn', lambda network, **kwargs : models.AtariDQN(network))

        self.network_factory = object_factory.ObjectFactory()
        # NOTE(review): presumably installs externally registered builders
        # before the ones below are added -- confirm against
        # ObjectFactory.set_builders.
        self.network_factory.set_builders(NETWORK_REGISTRY)
        network_entries = (
            ('actor_critic',
             lambda **kwargs: network_builder.A2CBuilder()),
            ('resnet_actor_critic',
             lambda **kwargs: network_builder.A2CResnetBuilder()),
            ('rnd_curiosity',
             lambda **kwargs: network_builder.RNDCuriosityBuilder()),
            ('soft_actor_critic',
             lambda **kwargs: network_builder.SACBuilder()),
        )
        for net_name, make_builder in network_entries:
            self.network_factory.register_builder(net_name, make_builder)
Exemplo n.º 2
0
    def __init__(self):
        """Build the model and network factories and fill in their builders.

        Model builders accept the network positionally and ignore all other
        keyword arguments; network builders ignore their keyword arguments
        and return a newly constructed builder object.
        """
        self.model_factory = object_factory.ObjectFactory()
        add_model = self.model_factory.register_builder

        # Discrete-action models.
        add_model('discrete_a2c',
                  lambda network, **kwargs: models.ModelA2C(network))
        add_model('discrete_a2c_lstm',
                  lambda network, **kwargs: models.LSTMModelA2C(network))

        # Continuous-action models.
        add_model('continuous_a2c',
                  lambda network, **kwargs: models.ModelA2CContinuous(network))
        add_model('continuous_a2c_logstd',
                  lambda network, **kwargs:
                  models.ModelA2CContinuousLogStd(network))
        add_model('continuous_a2c_lstm',
                  lambda network, **kwargs:
                  models.LSTMModelA2CContinuous(network))
        add_model('continuous_a2c_lstm_logstd',
                  lambda network, **kwargs:
                  models.LSTMModelA2CContinuousLogStd(network))

        add_model('dqn', lambda network, **kwargs: models.AtariDQN(network))

        self.network_factory = object_factory.ObjectFactory()
        add_network = self.network_factory.register_builder
        add_network('actor_critic',
                    lambda **kwargs: network_builder.A2CBuilder())
        add_network('dqn', lambda **kwargs: network_builder.DQNBuilder())
Exemplo n.º 3
0
        def __init__(self, **kwargs):
            """Initialise the module and its activation/initializer factories.

            ``kwargs`` are forwarded unchanged to ``nn.Module.__init__``.
            Activation builders forward their keyword arguments to the
            matching ``nn`` layer, except ``'None'``, which always yields a
            plain ``nn.Identity()``. Initializer builders pass an init
            function plus their keyword arguments to ``_create_initializer``,
            except ``'default'``, which yields ``nn.Identity()``.
            """
            nn.Module.__init__(self, **kwargs)

            self.activations_factory = object_factory.ObjectFactory()
            add_activation = self.activations_factory.register_builder
            add_activation('relu', lambda **kwargs: nn.ReLU(**kwargs))
            add_activation('tanh', lambda **kwargs: nn.Tanh(**kwargs))
            add_activation('sigmoid', lambda **kwargs: nn.Sigmoid(**kwargs))
            add_activation('elu', lambda **kwargs: nn.ELU(**kwargs))
            add_activation('selu', lambda **kwargs: nn.SELU(**kwargs))
            # 'swish' is served by nn.SiLU.
            add_activation('swish', lambda **kwargs: nn.SiLU(**kwargs))
            add_activation('gelu', lambda **kwargs: nn.GELU(**kwargs))
            add_activation('softplus', lambda **kwargs: nn.Softplus(**kwargs))
            # Keyword arguments are deliberately dropped for the no-op entry.
            add_activation('None', lambda **kwargs: nn.Identity())

            self.init_factory = object_factory.ObjectFactory()
            add_init = self.init_factory.register_builder
            #self.init_factory.register_builder('normc_initializer', lambda **kwargs : normc_initializer(**kwargs))
            add_init(
                'const_initializer',
                lambda **kwargs: _create_initializer(nn.init.constant_, **kwargs))
            add_init(
                'orthogonal_initializer',
                lambda **kwargs: _create_initializer(nn.init.orthogonal_, **kwargs))
            add_init(
                'glorot_normal_initializer',
                lambda **kwargs: _create_initializer(nn.init.xavier_normal_, **kwargs))
            add_init(
                'glorot_uniform_initializer',
                lambda **kwargs: _create_initializer(nn.init.xavier_uniform_, **kwargs))
            add_init(
                'variance_scaling_initializer',
                lambda **kwargs: _create_initializer(
                    torch_ext.variance_scaling_initializer, **kwargs))
            add_init(
                'random_uniform_initializer',
                lambda **kwargs: _create_initializer(nn.init.uniform_, **kwargs))
            add_init(
                'kaiming_normal',
                lambda **kwargs: _create_initializer(nn.init.kaiming_normal_, **kwargs))
            # 'orthogonal' maps to the same init function as
            # 'orthogonal_initializer' above.
            add_init(
                'orthogonal',
                lambda **kwargs: _create_initializer(nn.init.orthogonal_, **kwargs))
            add_init('default', lambda **kwargs: nn.Identity())
Exemplo n.º 4
0
    def __init__(self, algo_observer=None):
        """Register the algorithm and player builders and store the observer.

        Args:
            algo_observer: Observer object kept on ``self.algo_observer``.
                When ``None`` (the default) a ``DefaultAlgoObserver`` is
                created instead.
        """
        self.algo_factory = object_factory.ObjectFactory()
        self.algo_factory.register_builder(
            'a2c_continuous',
            lambda **kwargs: a2c_continuous.A2CAgent(**kwargs))
        self.algo_factory.register_builder(
            'a2c_discrete',
            lambda **kwargs: a2c_discrete.DiscreteA2CAgent(**kwargs))
        self.algo_factory.register_builder(
            'sac', lambda **kwargs: sac_agent.SACAgent(**kwargs))
        #self.algo_factory.register_builder('dqn', lambda **kwargs : dqnagent.DQNAgent(**kwargs))

        self.player_factory = object_factory.ObjectFactory()
        self.player_factory.register_builder(
            'a2c_continuous',
            lambda **kwargs: players.PpoPlayerContinuous(**kwargs))
        self.player_factory.register_builder(
            'a2c_discrete',
            lambda **kwargs: players.PpoPlayerDiscrete(**kwargs))
        self.player_factory.register_builder(
            'sac', lambda **kwargs: players.SACPlayer(**kwargs))
        #self.player_factory.register_builder('dqn', lambda **kwargs : players.DQNPlayer(**kwargs))

        # Compare against None explicitly: a truthiness test would silently
        # discard a caller-supplied observer that happens to evaluate as
        # falsy (for example one that defines __len__ or __bool__).
        if algo_observer is None:
            algo_observer = DefaultAlgoObserver()
        self.algo_observer = algo_observer

        # NOTE: this sets a flag on the torch.backends.cudnn module itself,
        # so it applies process-wide rather than to this instance only.
        torch.backends.cudnn.benchmark = True
Exemplo n.º 5
0
    def __init__(self, algo_observer=None):
        """Set up the algorithm/player factories, builders and observer.

        Args:
            algo_observer: Stored as-is on ``self.algo_observer``; no default
                observer is substituted when it is ``None``.
        """
        self.algo_factory = object_factory.ObjectFactory()
        algo_entries = (
            ('a2c_continuous',
             lambda **kwargs: a2c_continuous.A2CAgent(**kwargs)),
            ('a2c_discrete',
             lambda **kwargs: a2c_discrete.DiscreteA2CAgent(**kwargs)),
        )
        for algo_name, make_agent in algo_entries:
            self.algo_factory.register_builder(algo_name, make_agent)
        #self.algo_factory.register_builder('dqn', lambda **kwargs : dqnagent.DQNAgent(**kwargs))

        self.player_factory = object_factory.ObjectFactory()
        player_entries = (
            ('a2c_continuous',
             lambda **kwargs: players.PpoPlayerContinuous(**kwargs)),
            ('a2c_discrete',
             lambda **kwargs: players.PpoPlayerDiscrete(**kwargs)),
        )
        for player_name, make_player in player_entries:
            self.player_factory.register_builder(player_name, make_player)
        #self.player_factory.register_builder('dqn', lambda **kwargs : players.DQNPlayer(**kwargs))

        self.model_builder = model_builder.ModelBuilder()
        self.network_builder = network_builder.NetworkBuilder()

        self.algo_observer = algo_observer

        # NOTE: assigns a flag on the torch.backends.cudnn module, so the
        # setting is shared by the whole process.
        torch.backends.cudnn.benchmark = True
Exemplo n.º 6
0
 def __init__(self):
     """Create the network factory and register the built-in builders.

     Each registered builder ignores its keyword arguments and returns a
     newly constructed builder object from ``network_builder``.
     """
     self.network_factory = object_factory.ObjectFactory()
     # NOTE(review): presumably installs externally registered builders
     # first -- confirm against ObjectFactory.set_builders.
     self.network_factory.set_builders(NETWORK_REGISTRY)

     add_network = self.network_factory.register_builder
     add_network('actor_critic',
                 lambda **kwargs: network_builder.A2CBuilder())
     add_network('resnet_actor_critic',
                 lambda **kwargs: network_builder.A2CResnetBuilder())
     add_network('rnd_curiosity',
                 lambda **kwargs: network_builder.RNDCuriosityBuilder())
     add_network('soft_actor_critic',
                 lambda **kwargs: network_builder.SACBuilder())
Exemplo n.º 7
0
    def __init__(self):
        """Create the algorithm and player factories plus the two builders.

        Every registered builder forwards its keyword arguments to the agent
        or player class it constructs. ``self.sess`` starts out as ``None``.
        """
        self.algo_factory = object_factory.ObjectFactory()
        algo_entries = (
            ('a2c_continuous',
             lambda **kwargs: a2c_continuous.A2CAgent(**kwargs)),
            ('a2c_discrete',
             lambda **kwargs: a2c_discrete.A2CAgent(**kwargs)),
            ('dqn',
             lambda **kwargs: dqnagent.DQNAgent(**kwargs)),
        )
        for algo_name, make_agent in algo_entries:
            self.algo_factory.register_builder(algo_name, make_agent)

        self.player_factory = object_factory.ObjectFactory()
        player_entries = (
            ('a2c_continuous',
             lambda **kwargs: players.PpoPlayerContinuous(**kwargs)),
            ('a2c_discrete',
             lambda **kwargs: players.PpoPlayerDiscrete(**kwargs)),
            ('dqn',
             lambda **kwargs: players.DQNPlayer(**kwargs)),
        )
        for player_name, make_player in player_entries:
            self.player_factory.register_builder(player_name, make_player)

        self.model_builder = model_builder.ModelBuilder()
        self.network_builder = network_builder.NetworkBuilder()
        self.sess = None
Exemplo n.º 8
0
 def __init__(self):
     """Create the model factory, register its builders, add a NetworkBuilder.

     Each model builder takes the network as its first argument and
     discards any other keyword arguments.
     """
     self.model_factory = object_factory.ObjectFactory()
     # NOTE(review): presumably installs externally registered builders
     # first -- confirm against ObjectFactory.set_builders.
     self.model_factory.set_builders(MODEL_REGISTRY)

     # The lambdas resolve their class from ``models`` only when called.
     model_entries = (
         ('discrete_a2c',
          lambda network, **kwargs: models.ModelA2C(network)),
         ('multi_discrete_a2c',
          lambda network, **kwargs: models.ModelA2CMultiDiscrete(network)),
         ('continuous_a2c',
          lambda network, **kwargs: models.ModelA2CContinuous(network)),
         ('continuous_a2c_logstd',
          lambda network, **kwargs: models.ModelA2CContinuousLogStd(network)),
         ('soft_actor_critic',
          lambda network, **kwargs: models.ModelSACContinuous(network)),
         ('central_value',
          lambda network, **kwargs: models.ModelCentralValue(network)),
     )
     for model_name, make_model in model_entries:
         self.model_factory.register_builder(model_name, make_model)

     self.network_builder = NetworkBuilder()