예제 #1
0
파일: dga_detector.py 프로젝트: zeblok/clx
    def load_model(self, file_path, map_location=None):
        """Load a previously saved model and hand it to the parent detector.

        The file is read with ``torch.load`` and is expected to contain a dict
        with the keys ``"input_size"``, ``"hidden_size"``, ``"output_size"``,
        ``"n_layers"`` and ``"state_dict"``. A fresh ``RNNClassifier`` is built
        from the size/layer entries, the saved weights are restored into it,
        and the model is passed to the parent class's ``_load_model``.
        Per the original documentation, the parent is what sets the CUDA
        parameters; its implementation is outside this view.

        :param file_path: File path of the model to be loaded.
        :type file_path: string
        :param map_location: Optional ``torch.load`` ``map_location`` used to
            remap stored tensors, e.g. ``"cpu"`` to open a file that was saved
            on a GPU on a machine without CUDA. ``None`` (the default) keeps
            plain ``torch.load`` behaviour.
        :type map_location: str, torch.device, dict or callable, optional
        :raises KeyError: If one of the expected keys is missing from the file.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        model_dict = torch.load(file_path, map_location=map_location)
        model = RNNClassifier(
            model_dict["input_size"],
            model_dict["hidden_size"],
            model_dict["output_size"],
            model_dict["n_layers"],
        )
        model.load_state_dict(model_dict["state_dict"])
        # Device / optimizer setup is delegated to the parent class.
        super()._load_model(model)
예제 #2
0
    def load_checkpoint(self, file_path, map_location=None):
        """Load a previously saved model checkpoint and hand it to the parent detector.

        The checkpoint is read with ``torch.load`` and is expected to be a dict
        with the keys ``"input_size"``, ``"hidden_size"``, ``"output_size"``,
        ``"n_layers"`` and ``"state_dict"``. A fresh ``RNNClassifier`` is built
        from the size/layer entries, the saved weights are restored into it,
        and the model is passed to the parent class's ``leverage_model``.
        Per the original documentation, the parent is what sets the CUDA
        parameters; its implementation is outside this view.

        :param file_path: File path of the model checkpoint to be loaded.
        :type file_path: string
        :param map_location: Optional ``torch.load`` ``map_location`` used to
            remap stored tensors, e.g. ``"cpu"`` to open a checkpoint that was
            saved on a GPU on a machine without CUDA. ``None`` (the default)
            keeps plain ``torch.load`` behaviour.
        :type map_location: str, torch.device, dict or callable, optional
        :raises KeyError: If one of the expected keys is missing from the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(file_path, map_location=map_location)
        model = RNNClassifier(
            checkpoint["input_size"],
            checkpoint["hidden_size"],
            checkpoint["output_size"],
            checkpoint["n_layers"],
        )
        model.load_state_dict(checkpoint["state_dict"])
        # Device / optimizer setup is delegated to the parent class.
        super().leverage_model(model)
예제 #3
0
 def init_model(self, char_vocab=128, hidden_size=100, n_domain_type=2, n_layers=3):
     """Create a fresh RNNClassifier for training, unless a model already exists.

     When ``self.model`` is already set, nothing happens. Otherwise a new
     ``RNNClassifier`` is constructed from the arguments and handed to
     ``self.leverage_model``. That method is defined outside this view; it
     presumably prepares the model for scaled / parallel execution, as the
     original documentation described. Confirm against its implementation.

     :param char_vocab: Vocabulary size; the default of 128 covers the ASCII characters.
     :type char_vocab: int
     :param hidden_size: Hidden size of the network.
     :type hidden_size: int
     :param n_domain_type: Number of domain types (output classes).
     :type n_domain_type: int
     :param n_layers: Number of network layers.
     :type n_layers: int
     """
     # Guard clause: never replace a model that is already loaded/initialised.
     if self.model is not None:
         return
     classifier = RNNClassifier(char_vocab, hidden_size, n_domain_type, n_layers)
     self.leverage_model(classifier)
예제 #4
0
 def init_model(self, char_vocab=128, hidden_size=100, n_domain_type=2, n_layers=3):
     """Create a fresh RNNClassifier and pass it to ``self.leverage_model``.

     Does nothing when ``self.model`` is already set.

     :param char_vocab: Vocabulary size.
     :type char_vocab: int
     :param hidden_size: Hidden size of the network.
     :type hidden_size: int
     :param n_domain_type: Number of domain types (output classes).
     :type n_domain_type: int
     :param n_layers: Number of network layers.
     :type n_layers: int
     """
     if self.model is not None:
         return
     classifier = RNNClassifier(char_vocab, hidden_size, n_domain_type, n_layers)
     self.leverage_model(classifier)