コード例 #1
0
ファイル: adam.py プロジェクト: yulkang/vel
    def instantiate(self, model: Model) -> torch.optim.Adam:
        """Create a ``torch.optim.Adam`` optimizer over ``model``'s parameters.

        Two modes are supported:

        * ``self.layer_groups`` is truthy. The model's layer groups are
          converted to optimizer parameter groups with
          ``mu.to_parameter_groups``. ``self.lr`` and ``self.weight_decay`` may
          each be a single value shared by all groups, or a sequence with one
          value per group. For a sequence, entry ``i`` is written into group
          ``i`` and the first entry is also the optimizer-wide default. Values
          in this mode go through ``float`` so numeric strings (e.g. ``"1e-3"``)
          are accepted.
        * Otherwise, every parameter with ``requires_grad`` set is optimized,
          and the hyper-parameters are passed to Adam unchanged.

        Args:
            model: the model whose parameters will be optimized.

        Returns:
            The configured ``torch.optim.Adam`` instance.
        """
        if not self.layer_groups:
            trainable = filter(lambda p: p.requires_grad, model.parameters())

            return torch.optim.Adam(trainable,
                                    lr=self.lr,
                                    betas=self.betas,
                                    eps=self.eps,
                                    weight_decay=self.weight_decay,
                                    amsgrad=self.amsgrad)

        # ``collections.Sequence`` was removed in Python 3.10; the ABC is only
        # available from ``collections.abc`` now.
        from collections.abc import Sequence

        parameters = mu.to_parameter_groups(model.get_layer_groups())

        defaults = {}
        for key, value in (('lr', self.lr), ('weight_decay', self.weight_decay)):
            # A ``str`` is also a Sequence, but "1e-3" is a single scalar value,
            # not a list of per-group values, so it must not be iterated.
            if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
                # NOTE(review): a sequence longer than the number of groups fails
                # on the index below. With a shorter one, the remaining groups
                # presumably fall back to the optimizer default (first entry).
                # Confirm this is intended.
                for idx, group_value in enumerate(value):
                    parameters[idx][key] = float(group_value)
                defaults[key] = float(value[0])
            else:
                defaults[key] = float(value)

        return torch.optim.Adam(parameters,
                                lr=defaults['lr'],
                                betas=self.betas,
                                eps=self.eps,
                                weight_decay=defaults['weight_decay'],
                                amsgrad=self.amsgrad)
コード例 #2
0
    def instantiate(self, model: Model) -> torch.optim.SGD:
        """Create a ``torch.optim.SGD`` optimizer over ``model``'s parameters.

        With ``self.layer_groups`` falsy, only parameters whose
        ``requires_grad`` flag is set are handed to the optimizer. Otherwise the
        model's layer groups are converted into parameter groups via
        ``mu.to_parameter_groups``. The hyper-parameters (``lr``, ``momentum``,
        ``dampening``, ``weight_decay``, ``nesterov``) are forwarded to SGD
        exactly as stored on ``self``.

        Args:
            model: the model whose parameters will be optimized.

        Returns:
            The configured ``torch.optim.SGD`` instance.
        """
        if not self.layer_groups:
            param_source = (p for p in model.parameters() if p.requires_grad)
        else:
            param_source = mu.to_parameter_groups(model.get_layer_groups())

        optimizer = torch.optim.SGD(
            param_source,
            lr=self.lr,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
            nesterov=self.nesterov,
        )
        return optimizer