Exemplo n.º 1
0
    def __init__(self, parser_params, device, device_ids=None):
        """Build the base network, the fine-tune network and its optimizer.

        If ``parser_params.LSTM_pretrained`` is set, the base network
        (``self.network``) is initialised from that state-dict checkpoint
        and all of its parameters are frozen.  A second network of the same
        architecture (``self.network2``) is always built fresh, and it is
        the only one given an optimizer here (``self.optimizer2``).

        :param parser_params: parser args. Reads ``num_hidden``
            (comma-separated ints, one per layer), ``seq_length``,
            ``batch_size``, ``patch_size``, ``model_name``,
            ``LSTM_pretrained``, ``img_size``, ``img_channel``,
            ``filter_size``, ``stride`` and ``lr``.
        :param device: device every wrapped network is moved to.
        :param device_ids: GPU ids handed to ``DataParallel``; defaults to
            ``[0, 1, 2]`` (the previously hard-coded value).
        :raises ValueError: if ``parser_params.model_name`` is not one of
            the supported networks.
        """
        if device_ids is None:
            device_ids = [0, 1, 2]

        num_hidden = [int(x) for x in parser_params.num_hidden.split(',')]

        self.seq_length = parser_params.seq_length
        self.batch_size = parser_params.batch_size
        self.patch_size = parser_params.patch_size
        self.num_layers = len(num_hidden)
        networks_map = {'BiLSTM': BiLSTM.RNN, 'Bi-LSTM3': BiLSTM3.RNN}

        if parser_params.model_name not in networks_map:
            raise ValueError('Name of network unknown {}'.format(
                parser_params.model_name))
        network_cls = networks_map[parser_params.model_name]

        def build_network():
            # Both networks share exactly the same architecture/hyper-params.
            return network_cls(
                self.num_layers, num_hidden, parser_params.seq_length,
                parser_params.patch_size, parser_params.batch_size,
                parser_params.img_size, parser_params.img_channel,
                parser_params.filter_size, parser_params.stride)

        def wrap(network):
            # Multi-GPU wrapper, then move the whole module to `device`.
            wrapped = DataParallel(network, device_ids=device_ids)
            wrapped.to(device)
            return wrapped

        if parser_params.LSTM_pretrained:
            print("Loading LSTM pretrained model...")
            network = build_network()
            # Load onto CPU first: load_state_dict copies the values into the
            # (still CPU-resident) model anyway, so this only avoids
            # materialising the checkpoint on the GPU it was saved from.
            # NOTE(review): assumes the checkpoint was saved from an
            # unwrapped model (no 'module.' key prefix) - TODO confirm.
            network.load_state_dict(
                torch.load(parser_params.LSTM_pretrained, map_location='cpu'))
            self.network = wrap(network)

            # Freeze the pretrained weights.
            for param in self.network.parameters():
                param.requires_grad = False
        else:
            self.network = wrap(build_network())

        # Fine tune network: same architecture, freshly initialised.
        self.network2 = wrap(build_network())

        self.criterion = Criterion.Loss()
        self.optimizer2 = Adam(self.network2.parameters(), lr=parser_params.lr)
Exemplo n.º 2
0
    def __init__(self, parser_params, device, device_ids=None):
        """Build the LSTM network, the ``CA`` residual module and its optimizer.

        If ``parser_params.LSTM_pretrained`` is set, ``self.network`` is the
        object unpickled from that file, and all of its parameters are
        frozen.  Otherwise a fresh network is built and wrapped in
        ``DataParallel``.  Only ``self.CA`` is given an optimizer here
        (``self.optimizer_CA``).

        :param parser_params: parser args. Reads ``num_hidden``
            (comma-separated ints, one per layer), ``seq_length``,
            ``batch_size``, ``patch_size``, ``CA_patch_size``,
            ``model_name``, ``LSTM_pretrained``, ``img_size``,
            ``img_channel``, ``filter_size``, ``stride``, ``n_resgroups``,
            ``n_resblocks`` and ``lr``.
        :param device: device both networks are moved to.
        :param device_ids: GPU ids handed to ``DataParallel``; defaults to
            ``[0, 1]`` (the previously hard-coded value).
        :raises ValueError: if ``parser_params.model_name`` is not one of
            the supported networks.
        """
        if device_ids is None:
            device_ids = [0, 1]

        num_hidden = [int(x) for x in parser_params.num_hidden.split(',')]

        self.seq_length = parser_params.seq_length
        self.batch_size = parser_params.batch_size
        self.patch_size = parser_params.patch_size
        self.num_layers = len(num_hidden)
        self.CA_patch_size = parser_params.CA_patch_size
        networks_map = {'BiLSTM': BiLSTM.RNN, 'Bi-LSTM3': BiLSTM3.RNN}

        if parser_params.model_name not in networks_map:
            raise ValueError('Name of network unknown {}'.format(
                parser_params.model_name))

        if parser_params.LSTM_pretrained:
            print("Loading LSTM pretrained model...")
            # NOTE(review): this unpickles a whole model object, which can
            # execute arbitrary code - only load trusted checkpoint files.
            # On PyTorch >= 2.6, torch.load defaults to weights_only=True and
            # rejects full-model pickles; weights_only=False may be required.
            # Unlike the fresh branch below, the result is not wrapped in
            # DataParallel here - presumably the file was saved from an
            # already-wrapped network; TODO confirm.
            self.network = torch.load(parser_params.LSTM_pretrained)
            # Freeze the pretrained weights.
            for param in self.network.parameters():
                param.requires_grad = False
        else:
            network_cls = networks_map[parser_params.model_name]
            self.network = DataParallel(
                network_cls(
                    self.num_layers, num_hidden, parser_params.seq_length,
                    parser_params.patch_size, parser_params.batch_size,
                    parser_params.img_size, parser_params.img_channel,
                    parser_params.filter_size, parser_params.stride),
                device_ids=device_ids)

        self.network.to(device)

        # The CA module's channel count is img_channel * CA_patch_size**2
        # (presumably each CA_patch_size x CA_patch_size patch is folded into
        # channels - confirm against ResModule.RES).
        self.CA = ResModule.RES(n_resgroups=parser_params.n_resgroups,
                                n_resblocks=parser_params.n_resblocks,
                                n_channel=parser_params.img_channel *
                                (self.CA_patch_size**2))
        self.CA = DataParallel(self.CA, device_ids=device_ids)
        self.CA.to(device)

        self.criterion = Criterion.Loss()
        self.optimizer_CA = Adam(self.CA.parameters(), lr=parser_params.lr)