Пример #1
0
    def __init__(self, **kwargs):
        """Create the activation, initializer and regularizer factories.

        Each factory is an ``object_factory.ObjectFactory`` that is filled via
        ``register_builder`` with a string key and a builder callable taking
        keyword arguments only. The ``**kwargs`` accepted by this method are
        not read anywhere in its body.
        """
        # Activation builders ignore their keyword arguments and return the
        # TensorFlow function object itself; the 'None' key builds ``None``.
        self.activations_factory = object_factory.ObjectFactory()
        activation_builders = {
            'relu': lambda **_ignored: tf.nn.relu,
            'tanh': lambda **_ignored: tf.nn.tanh,
            'sigmoid': lambda **_ignored: tf.nn.sigmoid,
            'elu': lambda **_ignored: tf.nn.elu,
            'selu': lambda **_ignored: tf.nn.selu,
            'softplus': lambda **_ignored: tf.nn.softplus,
            'None': lambda **_ignored: None,
        }
        for key, builder in activation_builders.items():
            self.activations_factory.register_builder(key, builder)

        # Initializer builders forward every keyword argument to the
        # underlying initializer; the 'None' key builds ``None``.
        self.init_factory = object_factory.ObjectFactory()
        initializer_builders = {
            'normc_initializer':
                lambda **options: normc_initializer(**options),
            'const_initializer':
                lambda **options: tf.constant_initializer(**options),
            'orthogonal_initializer':
                lambda **options: tf.orthogonal_initializer(**options),
            'glorot_normal_initializer':
                lambda **options: tf.glorot_normal_initializer(**options),
            'glorot_uniform_initializer':
                lambda **options: tf.glorot_uniform_initializer(**options),
            'variance_scaling_initializer':
                lambda **options: tf.variance_scaling_initializer(**options),
            'random_uniform_initializer':
                lambda **options: tf.random_uniform_initializer(**options),
            'None': lambda **_ignored: None,
        }
        for key, builder in initializer_builders.items():
            self.init_factory.register_builder(key, builder)

        # Regularizer builders forward every keyword argument to the matching
        # ``tf.contrib.layers`` regularizer; the 'None' key builds ``None``.
        self.regularizer_factory = object_factory.ObjectFactory()
        regularizer_builders = {
            'l1_regularizer':
                lambda **options: tf.contrib.layers.l1_regularizer(**options),
            'l2_regularizer':
                lambda **options: tf.contrib.layers.l2_regularizer(**options),
            'l1l2_regularizer':
                lambda **options: tf.contrib.layers.l1l2_regularizer(**options),
            'None': lambda **_ignored: None,
        }
        for key, builder in regularizer_builders.items():
            self.regularizer_factory.register_builder(key, builder)
Пример #2
0
    def __init__(self):
        """Create and populate the model and network factories.

        Model builders take a ``network`` argument (extra keyword arguments
        are accepted and ignored) and wrap it in the matching ``models``
        class. Network builders ignore their keyword arguments and return a
        new builder instance from ``network_builder``.
        """
        self.model_factory = object_factory.ObjectFactory()
        model_builders = {
            'discrete_a2c':
                lambda network, **_ignored: models.ModelA2C(network),
            'discrete_a2c_lstm':
                lambda network, **_ignored: models.LSTMModelA2C(network),
            'continuous_a2c':
                lambda network, **_ignored: models.ModelA2CContinuous(network),
            'continuous_a2c_logstd':
                lambda network, **_ignored:
                models.ModelA2CContinuousLogStd(network),
            'continuous_a2c_lstm':
                lambda network, **_ignored:
                models.LSTMModelA2CContinuous(network),
            'continuous_a2c_lstm_logstd':
                lambda network, **_ignored:
                models.LSTMModelA2CContinuousLogStd(network),
            'dqn':
                lambda network, **_ignored: models.AtariDQN(network),
        }
        for key, builder in model_builders.items():
            self.model_factory.register_builder(key, builder)

        self.network_factory = object_factory.ObjectFactory()
        network_builders = {
            'actor_critic': lambda **_ignored: network_builder.A2CBuilder(),
            'dqn': lambda **_ignored: network_builder.DQNBuilder(),
        }
        for key, builder in network_builders.items():
            self.network_factory.register_builder(key, builder)
Пример #3
0
        def __init__(self, **kwargs):
            """Initialise the module and its activation/initializer factories.

            ``**kwargs`` are passed straight to ``nn.Module.__init__``. Two
            ``object_factory.ObjectFactory`` instances are then filled via
            ``register_builder`` with string keys and keyword-only builders.
            """
            nn.Module.__init__(self, **kwargs)

            # Activation builders construct a fresh torch module from the
            # given keyword arguments; 'None' always builds ``nn.Identity()``
            # and ignores any arguments.
            self.activations_factory = object_factory.ObjectFactory()
            activation_builders = {
                'relu': lambda **options: nn.ReLU(**options),
                'tanh': lambda **options: nn.Tanh(**options),
                'sigmoid': lambda **options: nn.Sigmoid(**options),
                'elu': lambda **options: nn.ELU(**options),
                'selu': lambda **options: nn.SELU(**options),
                'softplus': lambda **options: nn.Softplus(**options),
                'None': lambda **_ignored: nn.Identity(),
            }
            for key, builder in activation_builders.items():
                self.activations_factory.register_builder(key, builder)

            # Initializer builders hand an init function plus the keyword
            # arguments to ``_create_initializer``; 'None' builds ``None``.
            # NOTE: 'normc_initializer' is not registered in this factory.
            self.init_factory = object_factory.ObjectFactory()
            initializer_builders = {
                'const_initializer':
                    lambda **options: _create_initializer(
                        nn.init.constant_, **options),
                'orthogonal_initializer':
                    lambda **options: _create_initializer(
                        nn.init.orthogonal_, **options),
                'glorot_normal_initializer':
                    lambda **options: _create_initializer(
                        nn.init.xavier_normal_, **options),
                'glorot_uniform_initializer':
                    lambda **options: _create_initializer(
                        nn.init.xavier_uniform_, **options),
                'variance_scaling_initializer':
                    lambda **options: _create_initializer(
                        torch_ext.variance_scaling_initializer, **options),
                'random_uniform_initializer':
                    lambda **options: _create_initializer(
                        nn.init.uniform_, **options),
                'kaiming_normal':
                    lambda **options: _create_initializer(
                        nn.init.kaiming_normal_, **options),
                'None': lambda **_ignored: None,
            }
            for key, builder in initializer_builders.items():
                self.init_factory.register_builder(key, builder)