def _optimizer(c: Configs):
    """
    Build the optimizer configuration for `c.model`.

    The returned `OptimizerConfigs` is given the model's parameters,
    the optimizer name `'Adam'`, and `d_model` copied from `c`.
    """
    opt_conf = OptimizerConfigs()

    # Parameters to optimize come straight from the model.
    opt_conf.parameters = c.model.parameters()
    # Optimizer is selected by name.
    opt_conf.optimizer = 'Adam'
    # Forward the model dimensionality from the experiment configs.
    opt_conf.d_model = c.d_model

    return opt_conf
예제 #2
0
def _generator_optimizer(c: Configs):
    """
    Build the optimizer configuration for `c.generator`.

    Uses `'Adam'` with a learning rate of `2.5e-4` and betas of
    `(0.5, 0.999)` over the generator's parameters.
    """
    conf = OptimizerConfigs()
    conf.optimizer = 'Adam'
    conf.parameters = c.generator.parameters()
    conf.learning_rate = 2.5e-4
    # Setting the exponential decay rate for the first moment of the
    # gradient, $\beta_1$, to `0.5` is important.
    # The default of `0.9` fails.
    conf.betas = (0.5, 0.999)
    return conf
예제 #3
0
파일: train_model.py 프로젝트: codeaudit/nn
def transformer_optimizer(c: Configs):
    """
    Build a configurable optimizer for `c.model`.

    The returned `OptimizerConfigs` is given the model's parameters,
    `d_model` taken from `c.transformer`, and the optimizer name `'Noam'`.

    Parameters like learning rate can be changed by passing a dictionary
    when starting the experiment.
    """
    opt_conf = OptimizerConfigs()

    # Parameters to optimize come straight from the model.
    opt_conf.parameters = c.model.parameters()
    # Model dimensionality is read from the transformer configs.
    opt_conf.d_model = c.transformer.d_model
    # Optimizer is selected by name.
    opt_conf.optimizer = 'Noam'

    return opt_conf
예제 #4
0
    def init(self):
        """
        Set up the optimizer configs, the training/validation data loaders
        and the list of state modules on this object.
        """
        # Configurable optimizer: parameters like learning rate can be
        # changed by passing a dictionary when starting the experiment.
        opt_conf = OptimizerConfigs()
        opt_conf.parameters = self.model.parameters()
        opt_conf.d_model = self.transformer.d_model
        opt_conf.optimizer = 'Noam'
        self.optimizer = opt_conf

        def sequential_loader(split_text):
            # Both loaders share the dataset, batch size and sequence
            # length; only the text split differs.
            return SequentialDataLoader(text=split_text,
                                        dataset=self.text,
                                        batch_size=self.batch_size,
                                        seq_len=self.seq_len)

        # Sequential data loader over the training text.
        self.train_loader = sequential_loader(self.text.train)

        # Sequential data loader over the validation text.
        self.valid_loader = sequential_loader(self.text.valid)

        # `self.accuracy` is the only state module listed.
        self.state_modules = [self.accuracy]