def __init__(
    self,
    in_features: int,
    out_features: int,
    policy_type: str = None,
    out_activation: nn.Module = None
):
    """Create a linear head followed by an optional policy module.

    Args:
        in_features: size of the features fed into the linear head.
        out_features: action dimensionality handled by the policy.
        policy_type: "categorical", "gauss", "real_nvp", "logits" or None.
            "logits" and None both give a bare linear head and are stored
            as ``self.policy_type == "logits"``.
        out_activation: activation applied after the head's linear layer.

    Raises:
        NotImplementedError: if the policy module is of an unknown type.
    """
    super().__init__()
    assert policy_type in [
        "categorical", "gauss", "real_nvp", "logits", None
    ]

    # @TODO: refactor
    layer_fn = nn.Linear
    activation_fn = nn.ReLU
    squashing_fn = nn.Tanh
    bias = True

    # The Gauss and Real NVP policies consume 2 * out_features head outputs.
    policy_net = None
    head_size = out_features
    if policy_type == "categorical":
        policy_net = CategoricalPolicy()
    elif policy_type == "gauss":
        head_size = 2 * out_features
        policy_net = GaussPolicy(squashing_fn)
    elif policy_type == "real_nvp":
        head_size = 2 * out_features
        policy_net = RealNVPPolicy(
            action_size=out_features,
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            squashing_fn=squashing_fn,
            bias=bias
        )
    else:
        policy_type = "logits"
    self.policy_type = policy_type

    # head_net is registered before policy_net (parameter order).
    self.head_net = SequentialNet(
        hiddens=[in_features, head_size],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True
    )
    self.head_net.apply(outer_init)
    self.policy_net = policy_net

    supported = (CategoricalPolicy, GaussPolicy, RealNVPPolicy)
    if policy_net is None:
        # no policy: pass the head output through unchanged
        self._policy_fn = lambda *args: args[0]
    elif isinstance(policy_net, supported):
        self._policy_fn = policy_net.forward
    else:
        raise NotImplementedError
def get_from_params(
    cls,
    observation_net_params=None,
    aggregation_net_params=None,
    main_net_params=None,
) -> "StateNet":
    """Instantiate the network from SequentialNet parameter dicts.

    Args:
        observation_net_params: kwargs for the observation SequentialNet
            (required).
        aggregation_net_params: must be None; aggregation is not
            implemented yet.
        main_net_params: kwargs for the main SequentialNet.

    Returns:
        ``cls(main_net=..., observation_net=...)``.
    """
    assert observation_net_params is not None
    assert aggregation_net_params is None, "Lama is not implemented yet"

    # construction order: observation_net first, then main_net
    observation_net = SequentialNet(**observation_net_params)
    return cls(
        main_net=SequentialNet(**main_net_params),
        observation_net=observation_net,
    )
def __init__(self, action_size, layer_fn, activation_fn=nn.ReLU, bias=True,
             parity="odd"):
    """Conditional affine coupling layer of a Real NVP bijector.

    Original paper: https://arxiv.org/abs/1605.08803
    Adaptation to RL: https://arxiv.org/abs/1804.02808

    Args:
        action_size: action dimensionality.
        layer_fn: layer class (or registry name) for all sub-networks.
        activation_fn: hidden activation (or registry name).
        bias: bias flag of the hidden ("prenet") layers; the output
            layers always use a bias.
        parity: "odd" copies ``action_size // 2`` entries; any other value
            copies ``action_size - action_size // 2`` entries.

    State embeddings are expected to have size ``action_size * 2``. The
    scale and translation networks each have one hidden layer of
    ``action_size`` units.
    """
    super().__init__()
    layer_fn = MODULES.get_if_str(layer_fn)
    activation_fn = MODULES.get_if_str(activation_fn)

    self.parity = parity
    half = action_size // 2
    self.copy_size = half if parity == "odd" else action_size - half

    prenet_in = action_size * 2 + self.copy_size
    transformed_size = action_size - self.copy_size

    def _prenet():
        # state embedding + copied actions -> hidden layer
        return SequentialNet(
            hiddens=[prenet_in, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias)

    def _outnet():
        # hidden layer -> one value per transformed action entry
        return SequentialNet(
            hiddens=[action_size, transformed_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True)

    self.scale_prenet = _prenet()
    self.scale_net = _outnet()
    self.translation_prenet = _prenet()
    self.translation_net = _outnet()

    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
    for net, init_fn in (
        (self.scale_prenet, inner_init),
        (self.scale_net, outer_init),
        (self.translation_prenet, inner_init),
        (self.translation_net, outer_init),
    ):
        net.apply(init_fn)
def get_from_params(
    cls,
    image_size: int = None,
    encoder_params: Dict = None,
    embedding_net_params: Dict = None,
    heads_params: Dict = None,
) -> "MultiHeadNet":
    """Assemble a MultiHeadNet whose encoder comes from the model registry.

    Args:
        image_size: accepted for interface compatibility; unused here.
        encoder_params: registry kwargs. The "model" key names the
            registered model; the remaining keys are passed to it.
        embedding_net_params: SequentialNet kwargs plus an
            "input_channels" key, which is prepended to "hiddens".
        heads_params: mapping head name -> output size of a linear head.

    The input dicts are deep-copied and left unmodified.
    """
    encoder_kwargs = deepcopy(encoder_params)
    embedding_kwargs = deepcopy(embedding_net_params)
    head_sizes = deepcopy(heads_params)

    model_name = encoder_kwargs.pop('model')
    encoder_net = registry.MODELS.get_instance(model_name, **encoder_kwargs)

    input_channels = embedding_kwargs.pop('input_channels')
    embedding_kwargs["hiddens"].insert(0, input_channels)
    embedding_net = SequentialNet(**embedding_kwargs)
    # embedding width == last SequentialNet layer size
    emb_size = embedding_kwargs["hiddens"][-1]

    head_nets = nn.ModuleDict({
        name: nn.Linear(emb_size, size, bias=True)
        for name, size in head_sizes.items()
    })
    return cls(
        encoder_net=encoder_net,
        embedding_net=embedding_net,
        head_nets=head_nets,
    )
def get_from_params(
    cls,
    image_size: int = None,
    encoder_params: Dict = None,
    embedding_net_params: Dict = None,
    heads_params: Dict = None,
) -> "MultiHeadNet":
    """Assemble a MultiHeadNet on top of a ResnetEncoder.

    Args:
        image_size: spatial size of the ``(3, image_size, image_size)``
            probe input used to measure the encoder's output size.
        encoder_params: kwargs for ResnetEncoder.
        embedding_net_params: SequentialNet kwargs; the encoder output
            size is prepended to "hiddens".
        heads_params: mapping head name -> output size of a linear head.

    The input dicts are deep-copied and left unmodified.
    """
    encoder_kwargs = deepcopy(encoder_params)
    embedding_kwargs = deepcopy(embedding_net_params)
    head_sizes = deepcopy(heads_params)

    encoder_net = ResnetEncoder(**encoder_kwargs)
    probe_shape = (3, image_size, image_size)
    probe_output = utils.get_network_output(encoder_net, probe_shape)
    embedding_kwargs["hiddens"].insert(0, probe_output.nelement())

    embedding_net = SequentialNet(**embedding_kwargs)
    emb_size = embedding_kwargs["hiddens"][-1]

    head_nets = nn.ModuleDict({
        name: nn.Linear(emb_size, size, bias=True)
        for name, size in head_sizes.items()
    })
    return cls(
        encoder_net=encoder_net,
        embedding_net=embedding_net,
        head_nets=head_nets,
    )
def __init__(self, enc, n_cls, hiddens, emb_size,
             activation_fn=torch.nn.ReLU, norm_fn=None, bias=True,
             dropout=None):
    """Encoder -> embedding SequentialNet -> linear classifier.

    Args:
        enc: encoder module, stored as ``self.encoder``.
        n_cls: number of outputs of the linear head.
        hiddens: list of embedding-net sizes; ``emb_size`` is appended as
            the final width.
        emb_size: embedding width, also the head's input size.
        activation_fn, norm_fn, bias, dropout: forwarded to SequentialNet.
    """
    super().__init__()
    self.encoder = enc

    emb_net_kwargs = {
        "hiddens": hiddens + [emb_size],
        "activation_fn": activation_fn,
        "norm_fn": norm_fn,
        "bias": bias,
        "dropout": dropout,
    }
    self.emb_net = SequentialNet(**emb_net_kwargs)
    self.head = nn.Linear(emb_size, n_cls, bias=True)
def __init__(self, encoder_params, embedding_net_params, heads_params):
    """ResNet encoder -> embedding SequentialNet -> named linear heads.

    Args:
        encoder_params: kwargs for ResnetEncoder.
        embedding_net_params: SequentialNet kwargs. The encoder's
            ``out_features`` is prepended to "hiddens" when it is not None.
        heads_params: mapping head name -> output size.

    The input dicts are deep-copied, so callers' configs are not modified.
    """
    super().__init__()
    embedding_kwargs = deepcopy(embedding_net_params)
    head_sizes = deepcopy(heads_params)

    encoder = ResnetEncoder(**deepcopy(encoder_params))
    self.encoder_net = encoder
    self.enc_size = encoder.out_features

    if self.enc_size is not None:
        embedding_kwargs["hiddens"].insert(0, self.enc_size)
    self.embedding_net = SequentialNet(**embedding_kwargs)
    self.emb_size = embedding_kwargs["hiddens"][-1]

    self.heads = nn.ModuleDict({
        name: nn.Linear(self.emb_size, size, bias=True)
        for name, size in head_sizes.items()
    })
def create_from_params(
    cls,
    state_shape,
    observation_hiddens=None,
    head_hiddens=None,
    layer_fn=nn.Linear,
    activation_fn=nn.ReLU,
    dropout=None,
    norm_fn=None,
    bias=True,
    layer_order=None,
    residual=False,
    out_activation=None,
    history_aggregation_type=None,
    lama_poolings=None,
    **kwargs
):
    """Build a critic network from flat hyper-parameters.

    Pipeline: observation_net (optional) -> aggregation_net (LamaPooling
    when ``history_aggregation_type == "lama_obs"``, else none) ->
    main_net -> head_net.

    Args:
        state_shape: int or tuple. Only 1D/2D shapes are supported; a 2D
            shape is ``[history_len, obs_shape]``.
        observation_hiddens: hidden sizes of the observation net; empty or
            None skips it.
        head_hiddens: main_net hidden sizes followed by the output size.
            The last element is head_net's output; must be non-empty.
        layer_fn, activation_fn, norm_fn, out_activation: modules or
            registry names.
        dropout, bias, layer_order, residual: forwarded to SequentialNet.
        history_aggregation_type: None or "lama_obs".
        lama_poolings: poolings for LamaPooling.

    Returns:
        ``cls(...)`` with ``policy_net=None``.

    Raises:
        NotImplementedError: for 3D/4D (image) or other unsupported shapes.
    """
    assert len(kwargs) == 0
    # hack to prevent cycle imports
    from catalyst.contrib.registry import Registry

    observation_hiddens = observation_hiddens or []
    head_hiddens = head_hiddens or []
    assert len(head_hiddens) > 0, \
        "head_hiddens must contain at least the output size"

    layer_fn = Registry.name2nn(layer_fn)
    activation_fn = Registry.name2nn(activation_fn)
    norm_fn = Registry.name2nn(norm_fn)
    out_activation = Registry.name2nn(out_activation)
    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

    if isinstance(state_shape, int):
        state_shape = (state_shape, )

    if len(state_shape) in [1, 2]:
        # linear case: one observation or several one
        # state_shape like [history_len, obs_shape]
        # @TODO: handle lama/rnn correctly
        if not history_aggregation_type:
            state_size = reduce(lambda x, y: x * y, state_shape)
        else:
            # aggregation works per observation, so drop history axis
            state_size = reduce(lambda x, y: x * y, state_shape[1:])

        if len(observation_hiddens) > 0:
            observation_net = SequentialNet(
                hiddens=[state_size] + observation_hiddens,
                layer_fn=layer_fn,
                dropout=dropout,
                activation_fn=activation_fn,
                norm_fn=norm_fn,
                bias=bias,
                layer_order=layer_order,
                residual=residual)
            observation_net.apply(inner_init)
            obs_out = observation_hiddens[-1]
        else:
            observation_net = None
            obs_out = state_size

    elif len(state_shape) in [3, 4]:
        # cnn case: one image or several one @TODO
        raise NotImplementedError
    else:
        raise NotImplementedError

    assert obs_out

    if history_aggregation_type == "lama_obs":
        aggregation_net = LamaPooling(
            features_in=obs_out, poolings=lama_poolings)
        aggregation_out = aggregation_net.features_out
    else:
        aggregation_net = None
        aggregation_out = obs_out

    # All but the last head size go into main_net.
    main_hiddens = [aggregation_out] + head_hiddens[:-1]
    main_net = SequentialNet(
        hiddens=main_hiddens,
        layer_fn=layer_fn,
        dropout=dropout,
        activation_fn=activation_fn,
        norm_fn=norm_fn,
        bias=bias,
        layer_order=layer_order,
        residual=residual)
    main_net.apply(inner_init)

    # @TODO: place for memory network
    # head_net consumes main_net's output width. That is head_hiddens[-2]
    # when main_net has layers, or aggregation_out when head_hiddens holds
    # only the output size.
    head_net = SequentialNet(
        hiddens=[main_hiddens[-1], head_hiddens[-1]],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True)
    head_net.apply(outer_init)

    critic_net = cls(
        observation_net=observation_net,
        aggregation_net=aggregation_net,
        main_net=main_net,
        head_net=head_net,
        policy_net=None)

    return critic_net
def __init__(self, in_features: int, out_features: int,
             policy_type: str = None, out_activation: nn.Module = None):
    """Create a linear head followed by an optional policy module.

    Args:
        in_features: size of the features fed into the linear head.
        out_features: action dimensionality handled by the policy.
        policy_type: "categorical", "bernoulli", "diagonal-gauss",
            "squashing-gauss", "real-nvp", "logits" or None. "logits" and
            None both give a bare linear head stored as "logits".
        out_activation: activation after the head's linear layer. For
            "squashing-gauss" and "real-nvp" it is passed to the policy as
            the squashing function, and the head itself gets no activation.
            Must be None for "categorical" and "bernoulli".
    """
    super().__init__()
    assert policy_type in [
        "categorical", "bernoulli", "diagonal-gauss", "squashing-gauss",
        "real-nvp", "logits", None
    ]

    # @TODO: refactor
    layer_fn = nn.Linear
    activation_fn = nn.ReLU
    squashing_fn = out_activation
    bias = True

    policy_net = None
    head_size = out_features
    if policy_type in ("categorical", "bernoulli"):
        assert out_activation is None
        if policy_type == "categorical":
            policy_net = CategoricalPolicy()
        else:
            policy_net = BernoulliPolicy()
    elif policy_type == "diagonal-gauss":
        head_size = out_features * 2
        policy_net = DiagonalGaussPolicy()
    elif policy_type in ("squashing-gauss", "real-nvp"):
        # squashing moves into the policy; the head stays linear
        out_activation = None
        head_size = out_features * 2
        if policy_type == "squashing-gauss":
            policy_net = SquashingGaussPolicy(squashing_fn)
        else:
            policy_net = RealNVPPolicy(
                action_size=out_features,
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                squashing_fn=squashing_fn,
                bias=bias)
    else:
        policy_type = "logits"
    self.policy_type = policy_type

    self.head_net = SequentialNet(
        hiddens=[in_features, head_size],
        layer_fn={
            "module": layer_fn,
            "bias": True
        },
        activation_fn=out_activation,
        norm_fn=None,
    )
    self.head_net.apply(outer_init)
    self.policy_net = policy_net

    if policy_net is None:
        # no policy: pass the head output through unchanged
        self._policy_fn = lambda *args: args[0]
    else:
        self._policy_fn = policy_net.forward
def create_from_params(
    cls,
    state_shape,
    action_size,
    observation_hiddens=None,
    head_hiddens=None,
    layer_fn=nn.Linear,
    activation_fn=nn.ReLU,
    dropout=None,
    norm_fn=None,
    bias=True,
    layer_order=None,
    residual=False,
    out_activation=None,
    observation_aggregation=None,
    lama_poolings=None,
    policy_type=None,
    squashing_fn=nn.Tanh,
    **kwargs
):
    """Build an actor network from flat hyper-parameters.

    Pipeline: observation_net (optional) -> aggregation_net (LamaPooling
    when ``observation_aggregation == "lama_obs"``, else none) ->
    main_net -> head_net -> policy_net (optional).

    Args:
        state_shape: int or tuple. Only 1D/2D shapes are supported; a 2D
            shape is ``[history_len, obs_shape]``.
        action_size: action dimensionality.
        observation_hiddens: hidden sizes of the observation net; empty or
            None skips it.
        head_hiddens: hidden sizes of main_net. When empty, head_net reads
            the aggregation output directly.
        layer_fn, activation_fn, norm_fn, out_activation, squashing_fn:
            modules or registry names.
        dropout, bias, layer_order, residual: forwarded to SequentialNet.
        observation_aggregation: None or "lama_obs".
        lama_poolings: poolings for LamaPooling.
        policy_type: "gauss", "real_nvp" or None (plain head output).

    Returns:
        ``cls(...)`` built from the sub-networks above.

    Raises:
        NotImplementedError: for 3D/4D (image) or other unsupported shapes.
    """
    assert len(kwargs) == 0

    observation_hiddens = observation_hiddens or []
    head_hiddens = head_hiddens or []

    layer_fn = MODULES.get_if_str(layer_fn)
    activation_fn = MODULES.get_if_str(activation_fn)
    norm_fn = MODULES.get_if_str(norm_fn)
    out_activation = MODULES.get_if_str(out_activation)
    # resolved like the other module arguments so config strings work here
    squashing_fn = MODULES.get_if_str(squashing_fn)
    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

    if isinstance(state_shape, int):
        state_shape = (state_shape,)

    if len(state_shape) in [1, 2]:
        # linear case: one observation or several one
        # state_shape like [history_len, obs_shape]
        # @TODO: handle lama/rnn correctly
        if not observation_aggregation:
            observation_size = reduce(lambda x, y: x * y, state_shape)
        else:
            # aggregation works per observation, so drop history axis
            observation_size = reduce(lambda x, y: x * y, state_shape[1:])

        if len(observation_hiddens) > 0:
            observation_net = SequentialNet(
                hiddens=[observation_size] + observation_hiddens,
                layer_fn=layer_fn,
                dropout=dropout,
                activation_fn=activation_fn,
                norm_fn=norm_fn,
                bias=bias,
                layer_order=layer_order,
                residual=residual
            )
            observation_net.apply(inner_init)
            obs_out = observation_hiddens[-1]
        else:
            observation_net = None
            obs_out = observation_size

    elif len(state_shape) in [3, 4]:
        # cnn case: one image or several one @TODO
        raise NotImplementedError
    else:
        raise NotImplementedError

    assert obs_out

    if observation_aggregation == "lama_obs":
        aggregation_net = LamaPooling(
            features_in=obs_out,
            poolings=lama_poolings
        )
        aggregation_out = aggregation_net.features_out
    else:
        aggregation_net = None
        aggregation_out = obs_out

    main_hiddens = [aggregation_out] + head_hiddens
    main_net = SequentialNet(
        hiddens=main_hiddens,
        layer_fn=layer_fn,
        dropout=dropout,
        activation_fn=activation_fn,
        norm_fn=norm_fn,
        bias=bias,
        layer_order=layer_order,
        residual=residual
    )
    main_net.apply(inner_init)

    # @TODO: place for memory network

    # The Gauss and Real NVP policies consume 2 * action_size head outputs.
    if policy_type == "gauss":
        head_size = action_size * 2
        policy_net = GaussPolicy(squashing_fn)
    elif policy_type == "real_nvp":
        head_size = action_size * 2
        policy_net = RealNVPPolicy(
            action_size=action_size,
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            squashing_fn=squashing_fn,
            norm_fn=None,
            bias=bias
        )
    else:
        head_size = action_size
        policy_net = None

    # head_net consumes main_net's output width (aggregation_out when
    # head_hiddens is empty).
    head_net = SequentialNet(
        hiddens=[main_hiddens[-1], head_size],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True
    )
    head_net.apply(outer_init)

    actor_net = cls(
        observation_net=observation_net,
        aggregation_net=aggregation_net,
        main_net=main_net,
        head_net=head_net,
        policy_net=policy_net
    )

    return actor_net
def __init__(
    self,
    encoder,
    num_classes,
    feature_net_hiddens=None,
    emb_net_hiddens=None,
    activation_fn=torch.nn.ReLU,
    norm_fn=None,
    bias=True,
    dropout=None,
    consensus=None,
    kernel_size=1,
    feature_net_skip_connection=False,
    early_consensus=True
):
    """Temporal-Segment-Network style head on top of a per-frame encoder.

    Args:
        encoder: per-frame encoder; must expose ``out_features``.
        num_classes: output size of the final linear head.
        feature_net_hiddens: width (single int) of an optional Conv1d
            feature net; None disables it.
        emb_net_hiddens: width of the embedding net built when
            ``early_consensus`` is True.
        activation_fn: activation class or registry name.
        norm_fn: normalization forwarded to SequentialNet.
        bias: currently unused.
        dropout: dropout probability on encoder features; None means no
            dropout.
        consensus: a name or list of names among "attention", "avg",
            "max"; required.
        kernel_size: Conv1d kernel size, one of 1, 3, 5.
        feature_net_skip_connection: concatenate the encoder features to
            the feature-net output; requires a feature net.
        early_consensus: build an embedding net over the concatenated
            consensus outputs before the head.
    """
    super().__init__()
    assert consensus is not None
    assert kernel_size in [1, 3, 5]

    consensus = consensus if isinstance(consensus, list) else [consensus]
    self.consensus = consensus

    self.encoder = encoder
    # nn.Dropout rejects p=None, so the default maps to p=0 (no dropout).
    self.dropout = nn.Dropout(dropout if dropout is not None else 0.0)
    self.feature_net_skip_connection = feature_net_skip_connection
    self.early_consensus = early_consensus

    nonlinearity = registry.MODULES.get_if_str(activation_fn)
    inner_init = create_optimal_inner_init(nonlinearity=nonlinearity)
    # padding that keeps the sequence length for each allowed kernel size
    kernel2pad = {1: 0, 3: 1, 5: 2}

    def layer_fn(in_features, out_features, bias=True):
        # Conv1d layer factory handed to SequentialNet
        return nn.Conv1d(
            in_features,
            out_features,
            bias=bias,
            kernel_size=kernel_size,
            padding=kernel2pad[kernel_size]
        )

    if feature_net_hiddens is not None:
        self.feature_net = SequentialNet(
            hiddens=[encoder.out_features] + [feature_net_hiddens],
            layer_fn=layer_fn,
            norm_fn=norm_fn,
            activation_fn=activation_fn,
        )
        self.feature_net.apply(inner_init)
        out_features = feature_net_hiddens
    else:
        # if no feature net, then no need of skip connection
        # (nothing to skip)
        assert not self.feature_net_skip_connection
        self.feature_net = lambda x: x
        out_features = encoder.out_features

    # Input channels to consensus function
    # (also to embedding net multiplied by len(consensus))
    if self.feature_net_skip_connection:
        in_channels = out_features + encoder.out_features
    else:
        in_channels = out_features

    # Keys are processed in sorted order; names other than "attention",
    # "avg" and "max" are ignored.
    consensus_fn = OrderedDict()
    for key in sorted(consensus):
        if key == "attention":
            # NOTE(review): the conv has out_channels=1, so Softmax(dim=1)
            # normalizes over a size-1 axis and always yields 1. Weighting
            # frames likely needs dim=-1. Confirm before changing: this
            # alters outputs of already-trained models.
            self.attn = nn.Sequential(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=1,
                    kernel_size=kernel_size,
                    padding=kernel2pad[kernel_size],
                    bias=True
                ),
                nn.Softmax(dim=1)
            )

            def self_attn_fn(x):
                x_a = x.transpose(1, 2)
                x_attn = (self.attn(x_a) * x_a)
                x_attn = x_attn.transpose(1, 2)
                x_attn = x_attn.mean(1, keepdim=True)
                return x_attn

            consensus_fn["attention"] = self_attn_fn
        elif key == "avg":
            consensus_fn[key] = lambda x: x.mean(1, keepdim=True)
        elif key == "max":
            consensus_fn[key] = lambda x: x.max(1, keepdim=True)[0]

    if self.early_consensus:
        out_features = emb_net_hiddens
        self.emb_net = SequentialNet(
            hiddens=[in_channels * len(consensus_fn), emb_net_hiddens],
            layer_fn=nn.Linear,
            norm_fn=norm_fn,
            activation_fn=activation_fn,
        )
        self.emb_net.apply(inner_init)
    elif self.feature_net_skip_connection:
        out_features = out_features + self.encoder.out_features

    self.head = nn.Linear(out_features, num_classes, bias=True)

    if 'attention' in consensus:
        self.attn.apply(outer_init)
    self.head.apply(outer_init)

    self.consensus_fn = consensus_fn
class CouplingLayer(nn.Module):
    """Conditional affine coupling layer of a Real NVP bijector.

    Original paper: https://arxiv.org/abs/1605.08803
    Adaptation to RL: https://arxiv.org/abs/1804.02808
    """

    def __init__(self, action_size, layer_fn, activation_fn=nn.ReLU,
                 bias=True, parity="odd"):
        """
        Args:
            action_size: action dimensionality.
            layer_fn: layer class (or registry name) for the sub-networks.
            activation_fn: hidden activation (or registry name).
            bias: bias flag of the hidden ("prenet") layers; the output
                layers always use a bias.
            parity: "odd" copies the first ``action_size // 2`` action
                entries; any other value copies the last
                ``action_size - action_size // 2`` entries. The remaining
                entries are transformed.

        State embeddings are expected to have size ``action_size * 2``.
        The scale and translation networks each have one hidden layer of
        ``action_size`` units.
        """
        super().__init__()
        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)

        self.parity = parity
        half = action_size // 2
        self.copy_size = half if parity == "odd" else action_size - half

        prenet_in = action_size * 2 + self.copy_size
        transformed_size = action_size - self.copy_size

        def _prenet():
            return SequentialNet(
                hiddens=[prenet_in, action_size],
                layer_fn=layer_fn,
                activation_fn=activation_fn,
                norm_fn=None,
                bias=bias)

        def _outnet():
            return SequentialNet(
                hiddens=[action_size, transformed_size],
                layer_fn=layer_fn,
                activation_fn=None,
                norm_fn=None,
                bias=True)

        self.scale_prenet = _prenet()
        self.scale_net = _outnet()
        self.translation_prenet = _prenet()
        self.translation_net = _outnet()

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        for net, init_fn in (
            (self.scale_prenet, inner_init),
            (self.scale_net, outer_init),
            (self.translation_prenet, inner_init),
            (self.translation_net, outer_init),
        ):
            net.apply(init_fn)

    def forward(self, action, state_embedding, log_pi):
        """Transform part of ``action`` conditioned on the copied part.

        Returns ``(new_action, log_pi - sum(log_scale))``, where
        ``log_scale`` is the scale network output.
        """
        k = self.copy_size
        if self.parity == "odd":
            kept, to_transform = action[:, :k], action[:, k:]
        else:
            kept, to_transform = action[:, -k:], action[:, :-k]

        conditioner_in = torch.cat((state_embedding, kept), dim=1)
        shift = self.translation_net(self.translation_prenet(conditioner_in))
        log_scale = self.scale_net(self.scale_prenet(conditioner_in))
        transformed = shift + to_transform * torch.exp(log_scale)

        if self.parity == "odd":
            pieces = (kept, transformed)
        else:
            pieces = (transformed, kept)
        new_action = torch.cat(pieces, dim=1)

        return new_action, log_pi - log_scale.sum(dim=1)
def get_from_params(
    cls,
    backbone_params: Dict = None,
    neck_params: Dict = None,
    heads_params: Dict = None,
) -> "GenericModel":
    """Build a GenericModel: backbone -> optional neck -> optional heads.

    Args:
        backbone_params: must contain "model_name", a key of
            ``pretrainedmodels.__dict__``. Optional keys: "requires_grad"
            (default False), applied to the whole model, and "pretrained"
            (default True), which loads "imagenet" weights.
        neck_params: SequentialNet kwargs; the backbone feature size is
            prepended to "hiddens". Falsy -> no neck.
        heads_params: mapping head name -> spec. An int spec builds a single
            ``nn.Linear`` with that many outputs. A dict spec builds
            SequentialNet kwargs with the embedding size prepended to
            "hiddens". None -> no heads.

    Returns:
        the constructed model (also printed to stdout).

    Raises:
        NotImplementedError: if "model_name" is not in pretrainedmodels.
        TypeError: if a head spec is neither an int nor a dict.
    """
    backbone_params_ = deepcopy(backbone_params)
    neck_params_ = deepcopy(neck_params)
    heads_params_ = deepcopy(heads_params)

    requires_grad = backbone_params_.pop("requires_grad", False)
    pretrained = backbone_params_.pop("pretrained", True)

    # TODO: EfficientNet backbones are not supported yet.
    if backbone_params_["model_name"] in pretrainedmodels.__dict__:
        model_name = backbone_params_.pop("model_name")
        backbone = pretrainedmodels.__dict__[model_name](
            num_classes=1000,
            pretrained="imagenet" if pretrained else None)
        enc_size = backbone.last_linear.in_features
    else:
        raise NotImplementedError("This model not yet implemented")

    # the original classifier is replaced by the neck/heads built below
    del backbone.last_linear

    neck = None
    if neck_params_:
        neck_params_["hiddens"].insert(0, enc_size)
        emb_size = neck_params_["hiddens"][-1]
        neck = SequentialNet(**neck_params_)
    else:
        emb_size = enc_size

    if heads_params_ is not None:
        head_kwargs_ = {}
        for head, params in heads_params_.items():
            # dispatch on each head's own spec, not on the container dict
            if isinstance(params, int):
                head_kwargs_[head] = nn.Linear(emb_size, params, bias=True)
            elif isinstance(params, dict):
                params["hiddens"].insert(0, emb_size)
                head_kwargs_[head] = SequentialNet(**params)
            else:
                raise TypeError(
                    f"Head {head!r}: expected int or dict params, "
                    f"got {type(params).__name__}")
        heads = nn.ModuleDict(head_kwargs_)
    else:
        heads = None

    model = cls(backbone=backbone, neck=neck, heads=heads)
    utils.set_requires_grad(model, requires_grad)
    print(model)
    return model
def test_config2():
    """Hydra with heads only (no encoder, no embedders), inline config."""

    def _linear(bias):
        return {"module": "Linear", "bias": bias}

    config2 = {
        "in_features": 16,
        "heads_params": {
            "head1": {"hiddens": [2], "layer_fn": _linear(True)},
            "_head2": {
                "_hidden": {"hiddens": [16], "layer_fn": _linear(False)},
                "head2_1": {
                    "hiddens": [32],
                    "layer_fn": _linear(True),
                    "normalize_output": True
                },
                "_head2_2": {
                    "_hidden": {
                        "hiddens": [16, 16, 16],
                        "layer_fn": _linear(False),
                    },
                    "head2_2_1": {
                        "hiddens": [32],
                        "layer_fn": _linear(True),
                        "normalize_output": False,
                    },
                },
            },
        },
    }

    hydra = Hydra.get_from_params(**config2)

    config2_ = copy.deepcopy(config2)
    _pop_normalization(config2_)
    heads_params = config2_["heads_params"]

    # each head consumes the 16-dim input as its first layer size
    for path in (
        ("head1",),
        ("_head2", "_hidden"),
        ("_head2", "head2_1"),
        ("_head2", "_head2_2", "_hidden"),
        ("_head2", "_head2_2", "head2_2_1"),
    ):
        node = heads_params
        for key in path:
            node = node[key]
        node["hiddens"].insert(0, 16)

    def _wrap(params, normalize=False):
        modules = [("net", SequentialNet(**params))]
        if normalize:
            modules.append(("normalize", Normalize()))
        return nn.Sequential(OrderedDict(modules))

    head2 = heads_params["_head2"]
    head2_2 = head2["_head2_2"]
    net = nn.ModuleDict({
        "encoder": nn.Sequential(),
        "heads": nn.ModuleDict({
            "head1": _wrap(heads_params["head1"]),
            "_head2": nn.ModuleDict({
                "_hidden": _wrap(head2["_hidden"]),
                "head2_1": _wrap(head2["head2_1"], normalize=True),
                "_head2_2": nn.ModuleDict({
                    "_hidden": _wrap(head2_2["_hidden"]),
                    "head2_2_1": _wrap(head2_2["head2_2_1"]),
                }),
            }),
        }),
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    assert hydra.embedders == {}

    input_ = torch.rand(1, 16)
    output_kv = hydra(input_)
    for key in ("features", "embeddings"):
        assert (input_ == output_kv[key]).sum().item() == 16

    expected_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), expected_keys)

    # no embedders configured, so any target keyword must fail
    for bad_targets in (
        {"target1": torch.ones(1, 2).long()},
        {"target2": torch.ones(1, 2).long()},
        {
            "target1": torch.ones(1, 2).long(),
            "target2": torch.ones(1, 2).long()
        },
    ):
        with pytest.raises(KeyError):
            hydra(input_, **bad_targets)

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    for position, key in enumerate(("features", "embeddings")):
        assert (output_tuple[position] == output_kv[key]).sum().item() == 16
class TSN(nn.Module):
    """Temporal-Segment-Network style video classifier.

    Frames are encoded independently, optionally refined by a Conv1d
    feature net over the frame axis, combined by consensus functions, and
    classified by a linear head followed by a sigmoid.
    """

    def __init__(
        self,
        encoder,
        num_classes,
        feature_net_hiddens=None,
        emb_net_hiddens=None,
        activation_fn=torch.nn.ReLU,
        norm_fn=None,
        bias=True,
        dropout=None,
        consensus=None,
        kernel_size=1,
        feature_net_skip_connection=False,
        early_consensus=True
    ):
        """
        Args:
            encoder: per-frame encoder; must expose ``out_features``.
            num_classes: output size of the final linear head.
            feature_net_hiddens: width (single int) of an optional Conv1d
                feature net; None disables it.
            emb_net_hiddens: width of the embedding net built when
                ``early_consensus`` is True.
            activation_fn: activation class or registry name.
            norm_fn: normalization forwarded to SequentialNet.
            bias: currently unused.
            dropout: dropout probability on encoder features; None means
                no dropout.
            consensus: a name or list of names among "attention", "avg",
                "max"; required. With late consensus only the first entry
                is used ("avg" or "attention").
            kernel_size: Conv1d kernel size, one of 1, 3, 5.
            feature_net_skip_connection: concatenate the encoder features
                to the feature-net output; requires a feature net.
            early_consensus: combine frames before the head (True) or
                combine per-frame head outputs (False).
        """
        super().__init__()
        assert consensus is not None
        assert kernel_size in [1, 3, 5]

        consensus = consensus if isinstance(consensus, list) else [consensus]
        self.consensus = consensus

        self.encoder = encoder
        # nn.Dropout rejects p=None, so the default maps to p=0 (no dropout).
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.0)
        self.feature_net_skip_connection = feature_net_skip_connection
        self.early_consensus = early_consensus

        nonlinearity = registry.MODULES.get_if_str(activation_fn)
        inner_init = create_optimal_inner_init(nonlinearity=nonlinearity)
        # padding that keeps the sequence length for each allowed kernel size
        kernel2pad = {1: 0, 3: 1, 5: 2}

        def layer_fn(in_features, out_features, bias=True):
            # Conv1d layer factory handed to SequentialNet
            return nn.Conv1d(
                in_features,
                out_features,
                bias=bias,
                kernel_size=kernel_size,
                padding=kernel2pad[kernel_size]
            )

        if feature_net_hiddens is not None:
            self.feature_net = SequentialNet(
                hiddens=[encoder.out_features] + [feature_net_hiddens],
                layer_fn=layer_fn,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.feature_net.apply(inner_init)
            out_features = feature_net_hiddens
        else:
            # if no feature net, then no need of skip connection
            # (nothing to skip)
            assert not self.feature_net_skip_connection
            self.feature_net = lambda x: x
            out_features = encoder.out_features

        # Input channels to consensus function
        # (also to embedding net multiplied by len(consensus))
        if self.feature_net_skip_connection:
            in_channels = out_features + encoder.out_features
        else:
            in_channels = out_features

        # Keys are processed in sorted order; names other than "attention",
        # "avg" and "max" are ignored.
        consensus_fn = OrderedDict()
        for key in sorted(consensus):
            if key == "attention":
                # NOTE(review): the conv has out_channels=1, so
                # Softmax(dim=1) normalizes over a size-1 axis and always
                # yields 1. Weighting frames likely needs dim=-1. Confirm
                # before changing: this alters trained-model outputs.
                self.attn = nn.Sequential(
                    nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=1,
                        kernel_size=kernel_size,
                        padding=kernel2pad[kernel_size],
                        bias=True
                    ),
                    nn.Softmax(dim=1)
                )

                def self_attn_fn(x):
                    x_a = x.transpose(1, 2)
                    x_attn = (self.attn(x_a) * x_a)
                    x_attn = x_attn.transpose(1, 2)
                    x_attn = x_attn.mean(1, keepdim=True)
                    return x_attn

                consensus_fn["attention"] = self_attn_fn
            elif key == "avg":
                consensus_fn[key] = lambda x: x.mean(1, keepdim=True)
            elif key == "max":
                consensus_fn[key] = lambda x: x.max(1, keepdim=True)[0]

        if self.early_consensus:
            out_features = emb_net_hiddens
            self.emb_net = SequentialNet(
                hiddens=[in_channels * len(consensus_fn), emb_net_hiddens],
                layer_fn=nn.Linear,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.emb_net.apply(inner_init)
        elif self.feature_net_skip_connection:
            out_features = out_features + self.encoder.out_features

        self.head = nn.Linear(out_features, num_classes, bias=True)

        if 'attention' in consensus:
            self.attn.apply(outer_init)
        self.head.apply(outer_init)

        self.consensus_fn = consensus_fn

    def forward(self, input):
        """Classify a batch of frame sequences.

        ``input`` is reshaped as ``(bs, frames, ch, h, w)``. A 4-D input is
        treated as a single frame per sample. Returns sigmoid scores of
        shape ``(bs, num_classes)``.
        """
        if len(input.shape) < 5:
            input = input.unsqueeze(1)
        bs, fl, ch, h, w = input.shape
        x = input.view(-1, ch, h, w)
        x = self.encoder(x)
        x = self.dropout(x)
        identity = x

        # in simple case feature_net is identity mapping
        x = x.view(bs, fl, -1)
        x = x.transpose(1, 2)  # Conv1d expects (batch, channels, frames)
        x = self.feature_net(x)
        x = x.transpose(1, 2).contiguous()
        x = x.view(bs * fl, -1)

        if self.feature_net_skip_connection:
            x = torch.cat([identity, x], dim=-1)

        if self.early_consensus:
            x = x.view(bs, fl, -1)
            c_list = [c_fn(x) for c_fn in self.consensus_fn.values()]
            x = torch.cat(c_list, dim=1)
            x = x.view(bs, -1)
            x = self.emb_net(x)

        x = self.head(x)

        if not self.early_consensus:
            x = x.view(bs, fl, -1)

            if self.consensus[0] == "avg":
                x = x.mean(1, keepdim=False)
            elif self.consensus[0] == "attention":
                identity = identity.view(bs, fl, -1)
                x_a = identity.transpose(1, 2)
                x_ = x.transpose(1, 2)
                x_attn = (self.attn(x_a) * x_)
                x_attn = x_attn.transpose(1, 2)
                x = x_attn.sum(1, keepdim=False)

        x = torch.sigmoid(x)  # with bce loss
        return x
def test_config4():
    """Hydra from config4.yml: in_features is mandatory, heads only."""
    config_path = Path(__file__).absolute().parent / "config4.yml"
    config4 = utils.load_config(config_path)["model_params"]

    # without in_features the factory must refuse to build
    with pytest.raises(AssertionError):
        Hydra.get_from_params(**config4)

    config4["in_features"] = 16
    hydra = Hydra.get_from_params(**config4)

    config4_ = copy.deepcopy(config4)
    _pop_normalization(config4_)
    heads_params = config4_["heads_params"]

    # each head consumes the 16-dim input as its first layer size
    for path in (
        ("head1",),
        ("_head2", "_hidden"),
        ("_head2", "head2_1"),
        ("_head2", "_head2_2", "_hidden"),
        ("_head2", "_head2_2", "head2_2_1"),
    ):
        node = heads_params
        for key in path:
            node = node[key]
        node["hiddens"].insert(0, 16)

    def _wrap(params, normalize=False):
        modules = [("net", SequentialNet(**params))]
        if normalize:
            modules.append(("normalize", Normalize()))
        return nn.Sequential(OrderedDict(modules))

    head2 = heads_params["_head2"]
    head2_2 = head2["_head2_2"]
    net = nn.ModuleDict({
        "encoder": nn.Sequential(),
        "heads": nn.ModuleDict({
            "head1": _wrap(heads_params["head1"]),
            "_head2": nn.ModuleDict({
                "_hidden": _wrap(head2["_hidden"]),
                "head2_1": _wrap(head2["head2_1"], normalize=True),
                "_head2_2": nn.ModuleDict({
                    "_hidden": _wrap(head2_2["_hidden"]),
                    "head2_2_1": _wrap(head2_2["head2_2_1"]),
                }),
            }),
        }),
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    assert hydra.embedders == {}

    input_ = torch.rand(1, 16)
    output_kv = hydra(input_)
    for key in ("features", "embeddings"):
        assert (input_ == output_kv[key]).sum().item() == 16

    expected_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), expected_keys)

    # no embedders configured, so any target keyword must fail
    for bad_targets in (
        {"target1": torch.ones(1, 2).long()},
        {"target2": torch.ones(1, 2).long()},
        {
            "target1": torch.ones(1, 2).long(),
            "target2": torch.ones(1, 2).long()
        },
    ):
        with pytest.raises(KeyError):
            hydra(input_, **bad_targets)

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    for position, key in enumerate(("features", "embeddings")):
        assert (output_tuple[position] == output_kv[key]).sum().item() == 16
def test_config3():
    """Hydra from config3.yml: encoder, embedders and nested heads."""
    config_path = Path(__file__).absolute().parent / "config3.yml"
    config3 = utils.load_config(config_path)["model_params"]
    hydra = Hydra.get_from_params(**config3)

    config3_ = copy.deepcopy(config3)
    _pop_normalization(config3_)
    encoder_params = config3_["encoder_params"]
    heads_params = config3_["heads_params"]

    # each head consumes the 16-dim embedding as its first layer size
    for path in (
        ("head1",),
        ("_head2", "_hidden"),
        ("_head2", "head2_1"),
        ("_head2", "_head2_2", "_hidden"),
        ("_head2", "_head2_2", "head2_2_1"),
    ):
        node = heads_params
        for key in path:
            node = node[key]
        node["hiddens"].insert(0, 16)

    def _wrap(params, normalize=False):
        modules = [("net", SequentialNet(**params))]
        if normalize:
            modules.append(("normalize", Normalize()))
        return nn.Sequential(OrderedDict(modules))

    def _embedder(normalize):
        modules = [
            ("embedding", nn.Embedding(embedding_dim=16, num_embeddings=2)),
        ]
        if normalize:
            modules.append(("normalize", Normalize()))
        return nn.Sequential(OrderedDict(modules))

    head2 = heads_params["_head2"]
    head2_2 = head2["_head2_2"]
    net = nn.ModuleDict({
        "encoder": SequentialNet(**encoder_params),
        "embedders": nn.ModuleDict({
            "target1": _embedder(normalize=True),
            "target2": _embedder(normalize=False),
        }),
        "heads": nn.ModuleDict({
            "head1": _wrap(heads_params["head1"]),
            "_head2": nn.ModuleDict({
                "_hidden": _wrap(head2["_hidden"]),
                "head2_1": _wrap(head2["head2_1"], normalize=True),
                "_head2_2": nn.ModuleDict({
                    "_hidden": _wrap(head2_2["_hidden"]),
                    "head2_2_1": _wrap(head2_2["head2_2_1"]),
                }),
            }),
        }),
    })

    for part in ("encoder", "heads", "embedders"):
        _check_named_parameters(getattr(hydra, part), net[part])

    input_ = torch.rand(1, 16)
    output_kv = hydra(input_)
    assert (input_ == output_kv["features"]).sum().item() == 16

    base_keys = [
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    ]
    _check_lists(output_kv.keys(), base_keys)

    # each provided target adds its own "<target>_embeddings" output
    for targets, extra_keys in (
        ({"target1": torch.ones(1, 2).long()}, ["target1_embeddings"]),
        ({"target2": torch.ones(1, 2).long()}, ["target2_embeddings"]),
        (
            {
                "target1": torch.ones(1, 2).long(),
                "target2": torch.ones(1, 2).long()
            },
            ["target1_embeddings", "target2_embeddings"],
        ),
    ):
        output_kv = hydra(input_, **targets)
        _check_lists(output_kv.keys(), base_keys + extra_keys)

    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    for position, key in enumerate(("features", "embeddings")):
        assert (output_tuple[position] == output_kv[key]).sum().item() == 16
def test_config1():
    """Check a Hydra built from an inline config against a hand-assembled reference.

    The config describes a 16 -> 16 encoder, a tree of heads (``head1`` and
    the nested ``_head2`` group) and two embedders (``target1``,
    ``target2``). A reference ``nn.ModuleDict`` mirroring that tree is
    compared parameter-by-parameter with the Hydra via
    ``_check_named_parameters``. The keys returned by ``hydra(...)`` are
    compared via ``_check_lists`` with and without the optional targets.
    """
    config1 = {
        "encoder_params": {
            "hiddens": [16, 16],
            "layer_fn": {"module": "Linear", "bias": False},
            "norm_fn": "LayerNorm",
        },
        "heads_params": {
            "head1": {
                "hiddens": [2],
                "layer_fn": {"module": "Linear", "bias": True},
            },
            "_head2": {
                "_hidden": {
                    "hiddens": [16],
                    "layer_fn": {"module": "Linear", "bias": False},
                },
                "head2_1": {
                    "hiddens": [32],
                    "layer_fn": {"module": "Linear", "bias": True},
                    "normalize_output": True,
                },
                "_head2_2": {
                    "_hidden": {
                        "hiddens": [16, 16, 16],
                        "layer_fn": {"module": "Linear", "bias": False},
                    },
                    "head2_2_1": {
                        "hiddens": [32],
                        "layer_fn": {"module": "Linear", "bias": True},
                        "normalize_output": False,
                    },
                },
            },
        },
        "embedders_params": {
            "target1": {"num_embeddings": 2, "normalize_output": True},
            "target2": {"num_embeddings": 2, "normalize_output": False},
        },
    }
    hydra = Hydra.get_from_params(**config1)

    # Work on a deep copy so ``config1`` itself is not mutated.
    # NOTE(review): ``_pop_normalization`` presumably strips the
    # ``normalize_output`` flags, which the reference below models explicitly
    # with ``Normalize`` layers -- confirm against its definition.
    reference_params = copy.deepcopy(config1)
    _pop_normalization(reference_params)
    encoder_cfg = reference_params["encoder_params"]
    heads_cfg = reference_params["heads_params"]
    head2_cfg = heads_cfg["_head2"]
    head2_2_cfg = head2_cfg["_head2_2"]

    # Every head's first layer takes a 16-wide input (the encoder's last
    # hidden size), so prepend 16 to each head's layer sizes.
    for head_cfg in (
        heads_cfg["head1"],
        head2_cfg["_hidden"],
        head2_cfg["head2_1"],
        head2_2_cfg["_hidden"],
        head2_2_cfg["head2_2_1"],
    ):
        head_cfg["hiddens"].insert(0, 16)

    def named_sequential(*named_modules):
        # nn.Sequential whose children are registered under the given names.
        return nn.Sequential(OrderedDict(named_modules))

    # Modules are instantiated in the same order as the equivalent nested
    # literal would create them: encoder, embedders, then heads depth-first.
    encoder = SequentialNet(**encoder_cfg)
    embedders = nn.ModuleDict({
        "target1": named_sequential(
            ("embedding", nn.Embedding(embedding_dim=16, num_embeddings=2)),
            ("normalize", Normalize()),
        ),
        "target2": named_sequential(
            ("embedding", nn.Embedding(embedding_dim=16, num_embeddings=2)),
        ),
    })
    head1 = named_sequential(("net", SequentialNet(**heads_cfg["head1"])))
    head2_hidden = named_sequential(
        ("net", SequentialNet(**head2_cfg["_hidden"])),
    )
    head2_1 = named_sequential(
        ("net", SequentialNet(**head2_cfg["head2_1"])),
        ("normalize", Normalize()),
    )
    head2_2 = nn.ModuleDict({
        "_hidden": named_sequential(
            ("net", SequentialNet(**head2_2_cfg["_hidden"])),
        ),
        "head2_2_1": named_sequential(
            ("net", SequentialNet(**head2_2_cfg["head2_2_1"])),
        ),
    })
    heads = nn.ModuleDict({
        "head1": head1,
        "_head2": nn.ModuleDict({
            "_hidden": head2_hidden,
            "head2_1": head2_1,
            "_head2_2": head2_2,
        }),
    })
    net = nn.ModuleDict({
        "encoder": encoder,
        "embedders": embedders,
        "heads": heads,
    })

    _check_named_parameters(hydra.encoder, net["encoder"])
    _check_named_parameters(hydra.heads, net["heads"])
    _check_named_parameters(hydra.embedders, net["embedders"])

    input_ = torch.rand(1, 16)
    output_kv = hydra(input_)
    # All 16 values under "features" must equal the input.
    assert (input_ == output_kv["features"]).sum().item() == 16

    base_keys = (
        "features",
        "embeddings",
        "head1",
        "_head2/",
        "_head2/head2_1",
        "_head2/_head2_2/",
        "_head2/_head2_2/head2_2_1",
    )
    _check_lists(output_kv.keys(), list(base_keys))

    # Each target passed in adds a "<target>_embeddings" entry to the output.
    for target_names in (("target1",), ("target2",), ("target1", "target2")):
        targets = {name: torch.ones(1, 2).long() for name in target_names}
        output_kv = hydra(input_, **targets)
        expected_keys = list(base_keys) + [
            name + "_embeddings" for name in target_names
        ]
        _check_lists(output_kv.keys(), expected_keys)

    # forward_tuple must agree with the dict output of the last call above.
    output_tuple = hydra.forward_tuple(input_)
    assert len(output_tuple) == 5
    assert (output_tuple[0] == output_kv["features"]).sum().item() == 16
    assert (output_tuple[1] == output_kv["embeddings"]).sum().item() == 16