コード例 #1
0
class SimpleStochasticPolicy(Module):
    """Diagonal-Gaussian policy on top of a batch-normalised ReLU MLP.

    Keyword Args:
        linear_layers_size: hidden widths; one Linear + ReLU per entry
            (must be non-empty).
        input_dim: observation size (BatchNorm1d features, first Linear input).
        action_dim: width of the mean and log-variance heads.
    """

    def __init__(self, **kwargs):
        super().__init__()
        widths = kwargs['linear_layers_size']
        in_dim = kwargs['input_dim']

        # actor
        self.bn = BatchNorm1d(in_dim)
        # Chain in_dim -> widths[0] -> widths[1] -> ... -> widths[-1]; layers
        # are created in order so the random initialisation sequence is kept.
        fan_ins = [in_dim] + list(widths[:-1])
        self.linears = ModuleList(
            Linear(fan_in, fan_out) for fan_in, fan_out in zip(fan_ins, widths))
        out_dim = kwargs['action_dim']
        self.mu = Linear(widths[-1], out_dim)
        self.log_var = Linear(widths[-1], out_dim)

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser applied to every submodule
        # (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, input, action=None):
        """Return ``(action, log_prob, entropy)`` for a batch of observations.

        When ``action`` is None a new one is drawn with ``rsample`` (the
        reparameterised sample keeps gradients); otherwise the supplied action
        is scored. ``log_prob`` is summed over the last axis with keepdim;
        the entropy is returned per action dimension.
        """
        h = self.bn(input)
        for layer in self.linears:
            h = self.relu(layer(h))
        mean = self.tanh(self.mu(h))
        # Negated ReLU keeps the log-variance <= 0, i.e. std <= 1.
        log_variance = -self.relu(self.log_var(h))
        std = log_variance.exp().sqrt()
        # Small floor keeps the scale strictly positive.
        dist = Normal(mean, std + 1.0e-4)
        if action is None:
            action = dist.rsample()
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)
        return action, log_prob, dist.entropy()
コード例 #2
0
class SimpleDeterministicPolicy(Module):
    """Deterministic policy: batch-normalised ReLU MLP with a tanh output head.

    Keyword Args:
        linear_layers_size: hidden widths; one Linear + ReLU per entry
            (must be non-empty).
        input_dim: observation size.
        action_dim: output (action) size.
    """

    def __init__(self, **kwargs):
        super().__init__()
        widths = kwargs['linear_layers_size']
        in_dim = kwargs['input_dim']

        # actor
        self.bn = BatchNorm1d(in_dim)
        self.linears = ModuleList()
        prev = in_dim
        for width in widths:
            self.linears.append(Linear(prev, width))
            prev = width
        self.out = Linear(widths[-1], kwargs['action_dim'])

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, input, action=None):
        """Map observations to an action squashed into [-1, 1] by tanh.

        ``action`` is accepted but unused.
        """
        h = self.bn(input)
        for layer in self.linears:
            h = self.relu(layer(h))
        return self.tanh(self.out(h))
コード例 #3
0
class DeterministicLstmModelDynamics(Module):
    """Deterministic dynamics model with stacked LSTMs and residual heads.

    ``states`` is the concatenation ``[state, goal]`` on its last axis: the
    state head adds its output to ``states[..., :state_dim]`` and the goal
    head to ``states[..., state_dim:]``.

    Keyword Args:
        goal_dim: width of the goal part of ``states``.
        action_dim: width of ``actions``.
        state_dim: width of the state part of ``states``.
        lstm_layers_size: hidden size of each stacked LSTM (batch_first).
        linear_layers_size: widths of the ReLU MLP after the LSTMs.
    """

    def __init__(self, **kwargs):
        super(DeterministicLstmModelDynamics, self).__init__()
        self.__acoustic_state_dim = kwargs['goal_dim']
        self.__action_dim = kwargs['action_dim']
        self.__state_dim = kwargs['state_dim']
        self.__lstm_sizes = kwargs['lstm_layers_size']
        self.__linears_size = kwargs['linear_layers_size']

        input_size = self.__acoustic_state_dim + self.__state_dim + self.__action_dim
        self.__bn1 = torch.nn.BatchNorm1d(input_size)

        self.lstms = ModuleList(
            [LSTM(input_size, self.__lstm_sizes[0], batch_first=True)])
        self.lstms.extend([
            LSTM(self.__lstm_sizes[i - 1],
                 self.__lstm_sizes[i],
                 batch_first=True) for i in range(1, len(self.__lstm_sizes))
        ])
        # Per-layer recurrent state carried across forward() calls; a plain
        # Python list (not a buffer), cleared by reset_hidden_state().
        self.hiddens = [None] * len(self.__lstm_sizes)

        self.linears = ModuleList(
            [Linear(self.__lstm_sizes[-1], self.__linears_size[0])])
        self.linears.extend([
            Linear(self.__linears_size[i - 1], self.__linears_size[i])
            for i in range(1, len(self.__linears_size))
        ])

        self.goal = Linear(self.__linears_size[-1], kwargs['goal_dim'])

        self.state = Linear(self.__linears_size[-1], kwargs['state_dim'])

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, states, actions, hidden=None):
        """Predict next state and goal.

        Args:
            states: tensor whose last axis is ``state_dim + goal_dim``; 2-D or
                3-D (batch_first) as accepted by the LSTMs.
            actions: tensor with the same leading axes and ``action_dim``
                features.
            hidden: unused; recurrent state lives in ``self.hiddens``.

        Returns:
            ``(states_out, goals_out)``, each the head output plus the
            corresponding slice of ``states``.
        """
        x = torch.cat((states, actions), -1)
        original_dim = x.shape
        # BatchNorm1d expects (N, C): fold every leading axis into N,
        # normalise the features, then restore the original shape.
        x = self.__bn1(x.view(-1, original_dim[-1]))
        x = x.view(original_dim)

        # Each LSTM continues from the state stored by the previous call.
        # NOTE(review): the stored states are not detached, so backpropagating
        # through two consecutive calls without reset_hidden_state() would
        # reach an already-freed graph — confirm callers reset per batch.
        for i, lstm in enumerate(self.lstms):
            x, self.hiddens[i] = lstm(x, self.hiddens[i])

        for linear in self.linears:
            x = self.relu(linear(x))

        # Residual heads: slice the *feature* axis with `...` so this is
        # correct for 3-D (batch, seq, features) input as well; `[:, :d]`
        # would slice the sequence axis there.
        # predict state
        states_out = self.state(x) + states[..., :self.__state_dim]

        # predict goal
        goals_out = self.goal(x) + states[..., self.__state_dim:]

        return states_out, goals_out

    def reset_hidden_state(self):
        """Forget the LSTM states carried over from previous forward calls."""
        self.hiddens = [None] * len(self.__lstm_sizes)
コード例 #4
0
class SimpleStochasticModelDynamics(Module):
    """Stochastic one-step dynamics model with Gaussian state and goal heads.

    Keyword Args:
        goal_dim, state_dim, action_dim: feature sizes; the network input is
            ``cat(states, actions)`` of width goal_dim + state_dim + action_dim.
        linear_layers_size: hidden widths of the ReLU MLP (non-empty).
    """

    def __init__(self, **kwargs):
        super(SimpleStochasticModelDynamics, self).__init__()
        self.__acoustic_state_dim = kwargs['goal_dim']
        self.__action_dim = kwargs['action_dim']
        self.__state_dim = kwargs['state_dim']
        self.__linears_size = kwargs['linear_layers_size']

        input_size = self.__acoustic_state_dim + self.__state_dim + self.__action_dim
        self.__bn1 = torch.nn.BatchNorm1d(input_size)

        self.linears = ModuleList([Linear(input_size, self.__linears_size[0])])
        self.linears.extend([
            Linear(self.__linears_size[i - 1], self.__linears_size[i])
            for i in range(1, len(self.__linears_size))
        ])

        self.goal_mu = Linear(self.__linears_size[-1], kwargs['goal_dim'])
        self.goal_log_var = Linear(self.__linears_size[-1], kwargs['goal_dim'])

        self.state_mu = Linear(self.__linears_size[-1], kwargs['state_dim'])
        # Not used by forward() (the constant below replaces it).
        # NOTE(review): presumably kept so existing checkpoints still load.
        self.state_log_var = Linear(self.__linears_size[-1],
                                    kwargs['state_dim'])
        # Registered as a buffer so it follows .to()/.cuda() with the
        # parameters. A plain tensor attribute stayed on the CPU, so
        # Normal(...) in forward mixed devices. persistent=False keeps the
        # state_dict keys unchanged for existing checkpoints.
        self.register_buffer('state_log_var_const',
                             torch.zeros(kwargs['state_dim']),
                             persistent=False)

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, states, actions):
        """Sample next state and goal.

        Returns:
            ``(states, goals, state_log_prob, goal_log_prob, state_dists,
            goal_dists)``. Samples come from ``rsample`` (reparameterised) and
            log-probs are summed over the last axis with keepdim. The state
            std is fixed at ``exp(-2.5) + 1e-4``; the goal log-variance is
            predicted and capped at 0.
        """
        x = torch.cat((states, actions), -1)
        original_dim = x.shape
        # BatchNorm1d expects (N, C): fold leading axes, normalise, restore.
        x = self.__bn1(x.view(-1, original_dim[-1]))
        x = x.view(original_dim)

        for linear in self.linears:
            x = self.relu(linear(x))

        # predict state
        state_mu = self.tanh(self.state_mu(x))
        # Constant log-variance of -5 (relu of the zero buffer is 0).
        state_log_var = -5. - self.relu(self.state_log_var_const)
        state_sigmas = state_log_var.exp().sqrt()
        state_dists = Normal(state_mu, state_sigmas + 1.0e-4)
        states = state_dists.rsample()
        state_log_prob = state_dists.log_prob(states).sum(dim=-1, keepdim=True)

        # predict goal
        goal_mu = self.tanh(self.goal_mu(x))
        goal_log_var = -self.relu(self.goal_log_var(x))
        goal_sigmas = goal_log_var.exp().sqrt()
        goal_dists = Normal(goal_mu, goal_sigmas + 1.0e-4)
        goals = goal_dists.rsample()
        goal_log_prob = goal_dists.log_prob(goals).sum(dim=-1, keepdim=True)

        return states, goals, state_log_prob, goal_log_prob, state_dists, goal_dists
コード例 #5
0
class SimpleDDPGAgent(Module):
    """DDPG actor-critic with separate ReLU MLPs for the policy and Q-function.

    Keyword Args:
        hidden_size: hidden widths shared by actor and critic (non-empty).
        state_dim: observation size.
        action_dim: action size.
    """

    def __init__(self, **kwargs):
        super(SimpleDDPGAgent, self).__init__()

        hidden_size = kwargs['hidden_size']
        # actor
        self.actor_linears = ModuleList(
            [Linear(kwargs['state_dim'], hidden_size[0])])
        self.actor_linears.extend([
            Linear(hidden_size[i - 1], hidden_size[i])
            for i in range(1, len(hidden_size))
        ])
        self.action = Linear(hidden_size[-1], kwargs['action_dim'])

        # critic: input is the concatenated (state, action)
        self.critic_linears = ModuleList([
            Linear(kwargs['state_dim'] + kwargs['action_dim'], hidden_size[0])
        ])
        self.critic_linears.extend([
            Linear(hidden_size[i - 1], hidden_size[i])
            for i in range(1, len(hidden_size))
        ])
        self.q = Linear(hidden_size[-1], 1)

        self.relu = ReLU()
        self.sigmoid = Sigmoid()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def act(self, state):
        """Deterministic action in [-1, 1] (tanh) for ``state``."""
        x = state
        for layer in self.actor_linears:
            x = self.relu(layer(x))
        action = self.tanh(self.action(x))
        return action

    def Q(self, state, action):
        """Q-value of ``(state, action)`` with a trailing axis of size 1.

        Inputs are joined on the last (feature) axis, so unbatched 1-D inputs
        work as they do for ``act``; for 2-D input this equals ``dim=1``.
        """
        x = torch.cat([state, action], dim=-1)
        for layer in self.critic_linears:
            x = self.relu(layer(x))
        q = self.q(x)
        return q

    def get_actor_parameters(self):
        """List of actor parameters (hidden layers + action head)."""
        return list(self.actor_linears.parameters()) + list(
            self.action.parameters())

    def get_critic_parameters(self):
        """List of critic parameters (hidden layers + Q head)."""
        return list(self.critic_linears.parameters()) + list(
            self.q.parameters())
コード例 #6
0
class SimpleDeterministicModelDynamics(Module):
    """Deterministic one-step dynamics model: BatchNorm -> ReLU MLP -> tanh heads.

    Keyword Args:
        goal_dim, action_dim, state_dim: feature sizes; the network input is
            ``cat(states, actions)`` of width goal_dim + state_dim + action_dim.
        linear_layers_size: hidden widths of the ReLU MLP (non-empty).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.__acoustic_state_dim = kwargs['goal_dim']
        self.__action_dim = kwargs['action_dim']
        self.__state_dim = kwargs['state_dim']
        self.__linears_size = kwargs['linear_layers_size']

        input_size = self.__acoustic_state_dim + self.__state_dim + self.__action_dim
        self.__bn1 = torch.nn.BatchNorm1d(input_size)

        self.linears = ModuleList()
        fan_in = input_size
        for fan_out in self.__linears_size:
            self.linears.append(Linear(fan_in, fan_out))
            fan_in = fan_out

        last_width = self.__linears_size[-1]
        self.goal = Linear(last_width, kwargs['goal_dim'])
        self.state = Linear(last_width, kwargs['state_dim'])

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, states, actions):
        """Return ``(next_states, next_goals)``, both squashed into [-1, 1]."""
        joint = torch.cat((states, actions), -1)
        shape = joint.shape
        # BatchNorm1d wants (N, C): fold leading axes, normalise, unfold.
        h = self.__bn1(joint.view(-1, shape[-1])).view(shape)

        for linear in self.linears:
            h = self.relu(linear(h))

        next_states = self.tanh(self.state(h))
        next_goals = self.tanh(self.goal(h))
        return next_states, next_goals
コード例 #7
0
class MLPFeatureExtractor(FeatureExtractor):
    """Stack of Linear -> activation -> Dropout layers.

    Args:
        num_inputs: width of the input vector.
        layers: ``(units, activation)`` per layer; ``activation`` is a key of
            ``activations`` or a module instance.
        dropout: Dropout probability after every activation.
    """

    def __init__(self, num_inputs: int, layers: Sequence[LinearSpec], dropout=0.0):

        super(MLPFeatureExtractor, self).__init__()
        self._num_inputs = num_inputs

        self._extractors = ModuleList()
        last_units = num_inputs

        for units, activation in layers:
            self._extractors.extend([
                Linear(last_units, units),
                activations[activation] if isinstance(activation, str) else activation,
                Dropout(dropout)
            ])
            last_units = units

        self._num_features = last_units

    def forward(self, X):
        for extractor in self._extractors:
            X = extractor(X)
        return X

    def features(self, X):
        """Return the output of every ``Linear`` layer for input ``X``.

        ``X`` is propagated through the whole stack without gradients, and
        each ``Linear`` layer's output (before its activation) is collected
        in order. Feeding the raw input to every ``Linear`` instead would
        fail for any layer whose input width differs from ``num_inputs``.
        Dropout follows the module's current train/eval mode.
        """
        features = []
        with torch.no_grad():
            for extractor in self._extractors:
                X = extractor(X)
                if isinstance(extractor, Linear):
                    features.append(X)
        return features

    @property
    def num_features(self):
        return self._num_features

    @property
    def input_shape(self):
        return (self._num_inputs, )
コード例 #8
0
ファイル: model.py プロジェクト: vellamike/bonito
class Block(Module):
    """
    TCSConv, Batch Normalisation, Activation, Dropout

    Builds ``repeat`` TCSConv1d -> BatchNorm1d units; all but the last are
    followed by ``activation`` and Dropout. With ``residual=True`` a
    kernel-1 TCSConv1d -> BatchNorm1d branch applied to the block input is
    added to the main path before the final activation and Dropout.

    ``kernel_size``, ``stride`` and ``dilation`` may be ints (the defaults)
    or indexable values such as ``[33]``. Padding is derived from their
    first element, and the values are passed to TCSConv1d unchanged.
    NOTE(review): confirm TCSConv1d accepts plain ints for these.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 activation,
                 repeat=5,
                 kernel_size=1,
                 stride=1,
                 dilation=1,
                 dropout=0.0,
                 residual=False,
                 separable=False):

        super(Block, self).__init__()

        self.use_res = residual
        self.conv = ModuleList()

        _in_channels = in_channels
        # `[0]` was previously taken unconditionally, which raised TypeError
        # for the int defaults; _first accepts both forms.
        padding = self.get_padding(self._first(kernel_size),
                                   self._first(stride),
                                   self._first(dilation))

        # add the first n - 1 convolutions + activation
        for _ in range(repeat - 1):
            self.conv.extend(
                self.get_tcs(_in_channels,
                             out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             dilation=dilation,
                             padding=padding,
                             separable=separable))

            self.conv.extend(self.get_activation(activation, dropout))
            _in_channels = out_channels

        # add the last conv and batch norm
        self.conv.extend(
            self.get_tcs(_in_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         dilation=dilation,
                         padding=padding,
                         separable=separable))

        # add the residual connection (kernel 1, stride 1)
        if self.use_res:
            self.residual = Sequential(
                *self.get_tcs(in_channels, out_channels))

        # add the activation and dropout
        self.activation = Sequential(*self.get_activation(activation, dropout))

    def get_activation(self, activation, dropout):
        """Return the ``(activation, Dropout)`` pair placed after a conv unit."""
        return activation, Dropout(p=dropout)

    def get_padding(self, kernel_size, stride, dilation):
        """Return ``(kernel_size // 2) * dilation``.

        Raises:
            ValueError: if both ``stride`` and ``dilation`` exceed 1.
        """
        if stride > 1 and dilation > 1:
            raise ValueError(
                "Dilation and stride can not both be greater than 1")
        return (kernel_size // 2) * dilation

    @staticmethod
    def _first(value):
        """Return ``value[0]`` for an indexable value, else ``value`` itself."""
        try:
            return value[0]
        except TypeError:
            return value

    def get_tcs(self,
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                dilation=1,
                padding=0,
                bias=False,
                separable=False):
        """Return ``[TCSConv1d, BatchNorm1d]`` for one conv unit."""
        return [
            TCSConv1d(in_channels,
                      out_channels,
                      kernel_size,
                      stride=stride,
                      dilation=dilation,
                      padding=padding,
                      bias=bias,
                      separable=separable),
            BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)
        ]

    def forward(self, x):
        _x = x
        for layer in self.conv:
            _x = layer(_x)
        if self.use_res:
            _x = _x + self.residual(x)
        return self.activation(_x)
コード例 #9
0
class SimplePPOAgent(Module):
    """PPO actor-critic: Gaussian policy and value MLP, each with its own BatchNorm.

    Keyword Args:
        hidden_size: hidden widths shared by actor and critic (non-empty).
        state_dim: observation size.
        action_dim: action size.
    """

    def __init__(self, **kwargs):
        super(SimplePPOAgent, self).__init__()
        hidden_size = kwargs['hidden_size']
        # actor
        self.actor_bn = BatchNorm1d(kwargs['state_dim'])
        self.actor_linears = ModuleList(
            [Linear(kwargs['state_dim'], hidden_size[0])])
        self.actor_linears.extend([
            Linear(hidden_size[i - 1], hidden_size[i])
            for i in range(1, len(hidden_size))
        ])
        self.mu = Linear(hidden_size[-1], kwargs['action_dim'])
        self.log_var = Linear(hidden_size[-1], kwargs['action_dim'])

        # critic
        self.critic_bn = BatchNorm1d(kwargs['state_dim'])
        self.critic_linears = ModuleList(
            [Linear(kwargs['state_dim'], hidden_size[0])])
        self.critic_linears.extend([
            Linear(hidden_size[i - 1], hidden_size[i])
            for i in range(1, len(hidden_size))
        ])
        self.v = Linear(hidden_size[-1], 1)

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, state, action=None):
        """Actor-critic pass.

        Returns:
            ``(action, log_prob, entropy, v)``. ``action`` is drawn with
            ``sample()`` (no reparameterisation gradient) when not supplied.
            ``log_prob`` is summed over the last axis with keepdim,
            ``entropy`` is per action dimension, and ``v`` is the critic
            output.
        """
        x = state
        x = self.actor_bn(x)
        for layer in self.actor_linears:
            x = self.relu(layer(x))
        mu = self.tanh(self.mu(x))
        # Negated ReLU caps the log-variance at 0 (std <= 1) but leaves it
        # unbounded below.
        log_var = -self.relu(self.log_var(x))
        sigmas = log_var.exp().sqrt()
        # Floor the std: for very negative log_var, exp() underflows to 0 and
        # Normal would receive an invalid zero scale (infinite log_prob).
        dists = Normal(mu, sigmas + 1.0e-4)
        if action is None:
            action = dists.sample()
        log_prob = dists.log_prob(action).sum(dim=-1, keepdim=True)

        x = state
        x = self.critic_bn(x)
        for layer in self.critic_linears:
            x = self.relu(layer(x))
        v = self.v(x)
        return action, log_prob, dists.entropy(), v

    def get_actor_parameters(self):
        """List of actor parameters (BatchNorm, hidden layers, mu and log_var heads)."""
        return [
            *self.actor_bn.parameters(), *self.actor_linears.parameters(),
            *self.mu.parameters(), *self.log_var.parameters()
        ]

    def get_critic_parameters(self):
        """List of critic parameters (BatchNorm, hidden layers, value head)."""
        return [
            *self.critic_bn.parameters(), *self.critic_linears.parameters(),
            *self.v.parameters()
        ]
コード例 #10
0
class BlockDP(Module):
    """Conv block followed by learned, data-dependent pooling along time.

    A TCSConv1d -> BatchNorm1d -> GLU -> Dropout stack produces features.
    A small Conv1d predictor applied to ``cat([x, x * x], dim=1)`` emits two
    per-step signals through sigmoids: ``weights`` (scales each step's
    features) and ``moves`` (how far each step advances along the output
    axis). ``row_pool`` then splits each weighted feature row between the
    two integer output positions around the running sum of ``moves``.

    Args:
        in_channels: input channels. The predictor's first Conv1d takes 2
            channels and receives ``cat([x, x * x], dim=1)``, so this must be
            1 for the predictor to accept the input.
        out_channels: channels after GLU (the conv emits ``2 * out_channels``).
        activation: accepted but unused; ``get_activation`` always uses GLU.
        kernel_size: indexable; ``kernel_size[0] // 2`` is used as padding.
        norm: target used to renormalise ``moves`` while training.
            NOTE(review): the default ``None`` makes ``1 / self.norm_target``
            raise TypeError in training mode — callers must pass a number.
        prediction_size: hidden channels of the predictor.
        dropout: dropout probability after GLU.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 activation,
                 kernel_size,
                 norm=None,
                 prediction_size=32,
                 dropout=0.05):
        super(BlockDP, self).__init__()

        # One TCSConv1d + BatchNorm1d with 2 * out_channels outputs (for GLU).
        self.conv = ModuleList()
        self.conv.extend(
            self.get_tcs(in_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         padding=kernel_size[0] // 2,
                         separable=False))

        # Predictor: 2 input channels -> 2 output channels (channel 0 becomes
        # `weights`, channel 1 becomes `moves` in forward). Paddings 15 / 7
        # preserve the length for kernels 31 / 15 at stride 1.
        self.predictor = torch.nn.Sequential(
            torch.nn.Conv1d(2, prediction_size, 31, stride=1, padding=15),
            torch.nn.BatchNorm1d(prediction_size), torch.nn.SiLU(),
            torch.nn.Conv1d(prediction_size,
                            prediction_size,
                            15,
                            stride=1,
                            padding=7), torch.nn.BatchNorm1d(prediction_size),
            torch.nn.SiLU(),
            torch.nn.Conv1d(prediction_size, 2, 15, stride=1, padding=7))

        # Start the last predictor layer with near-zero weights and bias
        # -log(2) on both channels, so sigmoid(output) starts near 1/3.
        self.predictor[-1].weight.data *= 0.01
        self.predictor[-1].bias.data *= 0
        self.predictor[-1].bias.data[0] -= np.log(2)
        self.predictor[-1].bias.data[1] -= np.log(2)
        self.activation = Sequential(*self.get_activation(activation, dropout))
        self.norm_target = norm

        # EMA of the `moves` renormalisation factor: updated in training,
        # used as-is in eval.
        self.register_buffer('norm_mean', torch.ones(1))

    def get_activation(self, activation, dropout):
        # `activation` is ignored: GLU(dim=1) halves the channel axis, which
        # is why get_tcs emits out_channels * 2.
        return nn.GLU(dim=1), Dropout(p=dropout)

    def row_pool(self, features, moves, weights):
        """Pool one sample's feature rows onto positions given by cumsum(moves).

        Args:
            features: (T, C) rows for one sample (see forward).
            moves: (T,) step sizes; detached before cumsum, so no gradient
                reaches `moves` through the pooling positions or weights.
            weights: (T,) per-step scaling of the feature rows.

        Returns:
            (floor(sum(moves)) + 2, C) tensor. Each scaled row is split between
            floor(pos) and floor(pos) + 1 with linear-interpolation weights.
        """
        fw = features * weights.to(features.dtype).unsqueeze(1)

        # Output position of each step: running sum of the moves.
        poses = torch.cumsum(moves.detach(), 0)

        poses = poses.unsqueeze(1)

        floors = torch.floor(poses)
        ceils = floors + 1

        # w1 + w2 == 1: share of each row going to the floor / ceil slot.
        w1 = (1 - (poses - floors)).to(features.dtype)
        w2 = (1 - (ceils - poses)).to(features.dtype)

        # Output length = last ceil index + 1 (.item() syncs with the device).
        out = torch.zeros((int(ceils[-1].item()) + 1, features.shape[1]),
                          device=features.device,
                          dtype=features.dtype)

        out.index_add_(0, floors.to(torch.long).squeeze(1), w1 * fw)
        out.index_add_(0, ceils.to(torch.long).squeeze(1), w2 * fw)

        return out

    def get_tcs(self,
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                dilation=1,
                padding=0,
                bias=False,
                separable=False):
        # Conv + BatchNorm with doubled channels so GLU can halve them.
        return [
            TCSConv1d(in_channels,
                      out_channels * 2,
                      kernel_size,
                      stride=stride,
                      dilation=dilation,
                      padding=padding,
                      bias=bias,
                      separable=separable),
            BatchNorm1d(out_channels * 2, eps=1e-3, momentum=0.1)
        ]

    def forward(self, x):
        """Convolve, predict moves/weights, and pool each sample along time.

        Args:
            x: (B, C_in, T) input for Conv1d.

        Returns:
            x_evs: (B, out_channels, L) pooled features, zero-padded across the
                batch and right-padded on the last axis (see note below).
            lens: list of unpadded pooled lengths, one per sample.
            bmoves: sigmoid moves before renormalisation, (B, T).
            weights: sigmoid weights, (B, T).
        """
        _x = x
        for layer in self.conv:
            _x = layer(_x)
        _x = self.activation(_x)
        features = _x
        jumps_mat = self.predictor(torch.cat([x, x * x], dim=1))
        weights = torch.sigmoid(jumps_mat[:, 0, :])
        moves = torch.sigmoid(jumps_mat[:, 1, :])
        bmoves = moves
        if self.training:
            # Rescale so mean(moves) == 1 / norm_target, and update the EMA
            # of the factor used in eval.
            renorm = (1 / self.norm_target / moves.mean().detach()).detach()
            self.norm_mean.copy_(0.99 * self.norm_mean + 0.01 * renorm)
        else:
            renorm = self.norm_mean

        moves = moves * renorm

        # (B, C, T) -> (B, T, C) so each sample unbinds to (T, C) rows.
        features = features.permute((0, 2, 1))
        lens = []
        x_evs = []
        for f, m, w in zip(features.unbind(0), moves.unbind(0),
                           weights.unbind(0)):
            pooled = self.row_pool(f, m, w)
            x_evs.append(pooled)
            lens.append(pooled.shape[0])

        # batch_first=True: (B, L_max, C), then back to (B, C, L_max).
        x_evs = torch.nn.utils.rnn.pad_sequence(x_evs, True)
        x_evs = x_evs.permute(0, 2, 1)
        # NOTE(review): pads 3 - (L % 3) zeros, i.e. 1..3, so it adds 3 when L
        # is already a multiple of 3. If the intent is "round up to a multiple
        # of 3", (-L) % 3 would be exact — confirm before changing.
        x_evs = F.pad(x_evs, (0, 3 - (x_evs.shape[2] % 3)))
        #x_evs = self.activation(x_evs)

        return x_evs, lens, bmoves, weights
コード例 #11
0
class Merger(Component):
    """Combine the outputs of several components into one prediction.

    ``protocol`` (a ``MergeProtocol``) selects the rule used by ``forward``:
    AVERAGE, VOTING or MAX. The latest result is also stored in
    ``self.output``, which ``__init__`` registers as a buffer.
    """
    def protocol_average(self, x) -> torch.Tensor:
        """Element-wise mean of ``m(x)`` over all merged modules.

        The sum is rebuilt out of place on every call. Accumulating into
        ``self.output`` in place made each call after the first add onto the
        previous call's result, and it also mutated the first module's
        returned tensor.
        """
        total = None
        for m in self.merged_modules:
            out = m(x)
            total = out if total is None else total + out
        self.output = total / len(self.merged_modules)
        return self.output

    def protocol_voting(self, x) -> torch.Tensor:
        """One-hot (float) encoding of the per-sample majority class.

        Each module's ``argmax(dim=1)`` is one vote, and ``mode`` over the
        module axis picks the winner. Assumes every ``m(x)`` is 2-D
        ``(batch, n_classes)``.
        """
        warn(
            "Some tensor operations in the voting merging do not support CUDA")

        predictions = [m(x) for m in self.merged_modules]
        # Take the class count from the real batched output instead of an
        # extra forward pass on the single, unbatched sample x[0].
        n_classes = predictions[0].shape[1]
        class_predictions = torch.stack(
            [p.argmax(dim=1) for p in predictions], dim=1)  # (batch, n_modules)
        voting = class_predictions.mode().values
        # Returns a probability distribution
        self.output = one_hot(voting, num_classes=n_classes).float()
        return self.output

    def protocol_max(self, x) -> torch.Tensor:
        """Per sample, copy the output row of the most confident module.

        Outputs are stacked to ``(batch, n_classes, n_modules)``. The max over
        classes gives each module's top score per sample, and the argmax over
        modules selects whose row is copied.
        """
        warn("Max protocol is uneficient, and does not support CUDA")

        tuple_pred = tuple(m(x) for m in self.merged_modules)
        stack_red = torch.stack(tuple_pred, dim=2)
        max_prob_pred = torch.max(stack_red, dim=1).values
        valid_pd = max_prob_pred.argmax(dim=1)
        # Fill the output with the valid distribution
        self.output = torch.empty_like(tuple_pred[0])
        for i in torch.unique(valid_pd):
            idx = torch.where(valid_pd == i)[0]
            self.output[idx, :] = tuple_pred[i][idx, :]
        return self.output

    def __init__(self,
                 merged_modules: List[Component],
                 protocol=MergeProtocol.AVERAGE):
        """
        Raises:
            ValueError: if ``merged_modules`` is empty.
        """
        if len(merged_modules) == 0:
            raise ValueError("Number of merged components has to be > 0")

        super().__init__(p=1, t=0)
        self.protocol = protocol
        self.merged_modules = ModuleList(merged_modules)
        self.register_buffer("output", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the configured protocol.

        Raises:
            ValueError: for an unsupported protocol.
        """
        if self.protocol == MergeProtocol.AVERAGE:
            return self.protocol_average(x)
        elif self.protocol == MergeProtocol.VOTING:
            return self.protocol_voting(x)
        elif self.protocol == MergeProtocol.MAX:
            return self.protocol_max(x)
        else:
            raise ValueError("Merging protocol not supported")

    def get_merge_protocol(self):
        return self.protocol

    def get_merged_modules(self):
        return self.merged_modules

    def update_merge_protocol(self, p: MergeProtocol):
        self.protocol = p

    def add_classifier(self, c: Component):
        """Append one more component to the ensemble."""
        self.merged_modules.append(c)
コード例 #12
0
class SimpleDeterministicModelDynamicsDeltaPredict(Module):
    """Deterministic dynamics model that predicts residual (delta) updates.

    ``states`` holds ``state_dim`` state features followed by goal features.
    The state part is split into its first ``state_dim - 26`` features
    (called "artic" here) and its last 26 ("acoustic"):

    * artic delta: one Linear on [normalised artic features, actions];
    * acoustic delta: ReLU MLP on the normalised [state, action] input;
    * next state: ``tanh(state + cat(artic_delta, acoustic_delta))``;
    * next goal: ``tanh(goal + tanh(goal_head(mlp_features)))``.

    Keyword Args:
        goal_dim, action_dim, state_dim: feature sizes.
        linear_layers_size: hidden widths of the ReLU MLP (non-empty).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.__acoustic_state_dim = kwargs['goal_dim']
        self.__action_dim = kwargs['action_dim']
        self.__state_dim = kwargs['state_dim']
        # Hard-coded width of the acoustic tail of the state vector.
        self.__acoustic_dim = 26
        self.__linears_size = kwargs['linear_layers_size']

        # Goal features are not part of the network input.
        input_size = self.__state_dim + self.__action_dim
        self.__bn1 = torch.nn.BatchNorm1d(input_size)

        self.drop = torch.nn.modules.Dropout(p=0.1)

        artic_dim = self.__state_dim - self.__acoustic_dim
        # Single Linear (no hidden layer): [artic, action] -> artic delta.
        self.artic_state_0 = Linear(artic_dim + self.__action_dim, artic_dim)

        self.linears = ModuleList()
        fan_in = input_size
        for fan_out in self.__linears_size:
            self.linears.append(Linear(fan_in, fan_out))
            fan_in = fan_out

        last_width = self.__linears_size[-1]
        self.goal = Linear(last_width, kwargs['goal_dim'])
        self.acoustic_state = Linear(last_width, self.__acoustic_dim)
        # Not used by forward(). NOTE(review): presumably kept so existing
        # checkpoints still load — confirm before removing.
        self.state = Linear(self.__state_dim, self.__state_dim)

        self.relu = ReLU()
        self.tanh = Tanh()

        # Project-level initialiser (original note: xavier uniform init).
        self.apply(init_weights)

    def forward(self, states, actions):
        """Return ``(out_states, out_goals)`` for 2-D ``(batch, features)`` input."""
        state_part = states[:, :self.__state_dim]
        goal_part = states[:, self.__state_dim:]

        h = torch.cat((state_part, actions), -1)
        shape = h.shape
        h = self.__bn1(h.view(-1, shape[-1])).view(shape)
        h = self.drop(h)

        # Articulatory branch reads the normalised artic and action columns.
        artic_dim = self.__state_dim - self.__acoustic_dim
        artic_in = torch.cat((h[:, :artic_dim], h[:, -self.__action_dim:]), -1)
        artic_delta = self.artic_state_0(artic_in)

        # Acoustic branch: full MLP over the normalised input.
        for linear in self.linears:
            h = self.relu(linear(h))
        acoustic_delta = self.acoustic_state(h)

        state_delta = torch.cat((artic_delta, acoustic_delta), -1)
        out_states = self.tanh(state_part + state_delta)

        goal_delta = self.tanh(self.goal(h))
        out_goals = self.tanh(goal_part + goal_delta)
        return out_states, out_goals
コード例 #13
0
class Conv2dFeatureExtractor(FeatureExtractor):
    """Stack of Conv2d -> activation -> Dropout2d layers with a flattened output.

    Args:
        num_channels, width, height: input image geometry.
        layers: ``(num_kernels, kernel_size, stride, activation)`` per layer;
            ``activation`` is a key of ``activations`` or a module instance.
        dropout: Dropout2d probability after every activation.
    """

    def __init__(self, num_channels: int, width: int, height: int, layers: Sequence[Conv2dSpec], dropout=0.0):

        super(Conv2dFeatureExtractor, self).__init__()

        self._num_channels = num_channels
        self._width = width
        self._height = height

        self._extractors = ModuleList()
        last_kernels = num_channels

        for num_kernels, kernel_size, stride, activation in layers:
            self._extractors.extend([
                Conv2d(last_kernels, num_kernels, kernel_size, stride),
                activations[activation] if isinstance(activation, str) else activation,
                Dropout2d(dropout)
            ])
            last_kernels = num_kernels
            # Track the spatial size so num_features is known up front.
            width, height = compute_output((width, height), kernel_size, stride)

        self._num_features = last_kernels * width * height

    def forward(self, X):
        for extractor in self._extractors:
            X = extractor(X)
        # NOTE(review): if `flatten` is torch.flatten (start_dim=0) this also
        # folds the batch axis; confirm it flattens from dim 1 so each sample
        # yields num_features values.
        return flatten(X)

    @property
    def num_features(self):
        return self._num_features

    @property
    def input_shape(self):
        return self._num_channels, self._width, self._height

    def feature_maps(self, X):
        """Return every 2-D map produced by each Conv2d layer, moved to the CPU.

        ``X`` runs through the full stack without gradients. After each Conv2d
        its output is reshaped to ``(-1, rows, cols)`` and the maps are
        collected in order.
        """
        feature_maps = []
        with torch.no_grad():
            for extractor in self._extractors:
                X = extractor(X)
                if isinstance(extractor, Conv2d):
                    rows, cols = X.shape[-2], X.shape[-1]
                    feature_maps.extend(X.view(-1, rows, cols).cpu())
        return feature_maps

    def plot_feature_maps(self, X):
        for feature_map in self.feature_maps(X):
            plt.imshow(feature_map)
            plt.show()

    @property
    def filters(self):
        """Every 2-D kernel slice of every Conv2d weight, detached on the CPU."""
        filters = []
        for extractor in self._extractors:
            if isinstance(extractor, Conv2d):
                weights = extractor.weight.detach().cpu()
                rows, cols = weights.shape[-2], weights.shape[-1]
                filters.extend(weights.view(-1, rows, cols))
        return filters

    def plot_filters(self):
        for kernel in self.filters:
            plt.imshow(kernel)
            plt.show()