def __init__(
    self,
    in_features: int,
    out_features: int,
    policy_type: str = None,
    out_activation: nn.Module = None
):
    super().__init__()
    assert policy_type in [
        "categorical", "gauss", "real_nvp", "logits", None
    ]

    # @TODO: refactor
    layer_fn = nn.Linear
    activation_fn = nn.ReLU
    squashing_fn = nn.Tanh
    bias = True

    if policy_type == "categorical":
        head_size = out_features
        policy_net = CategoricalPolicy()
    elif policy_type == "gauss":
        head_size = out_features * 2
        policy_net = GaussPolicy(squashing_fn)
    elif policy_type == "real_nvp":
        head_size = out_features * 2
        policy_net = RealNVPPolicy(
            action_size=out_features,
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            squashing_fn=squashing_fn,
            bias=bias
        )
    else:
        head_size = out_features
        policy_net = None
        policy_type = "logits"

    self.policy_type = policy_type

    head_net = SequentialNet(
        hiddens=[in_features, head_size],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True
    )
    head_net.apply(outer_init)
    self.head_net = head_net

    self.policy_net = policy_net

    self._policy_fn = None
    if policy_net is None:
        self._policy_fn = lambda *args: args[0]
    elif isinstance(
        policy_net, (CategoricalPolicy, GaussPolicy, RealNVPPolicy)
    ):
        self._policy_fn = policy_net.forward
    else:
        raise NotImplementedError
def __init__(
    self,
    in_features: int,
    out_features: int,
    policy_type: str = None,
    out_activation: nn.Module = None
):
    super().__init__()
    assert policy_type in [
        "categorical", "bernoulli", "diagonal-gauss", "squashing-gauss",
        "real-nvp", "logits", None
    ]

    # @TODO: refactor
    layer_fn = nn.Linear
    activation_fn = nn.ReLU
    squashing_fn = out_activation
    bias = True

    if policy_type == "categorical":
        assert out_activation is None
        head_size = out_features
        policy_net = CategoricalPolicy()
    elif policy_type == "bernoulli":
        assert out_activation is None
        head_size = out_features
        policy_net = BernoulliPolicy()
    elif policy_type == "diagonal-gauss":
        head_size = out_features * 2
        policy_net = DiagonalGaussPolicy()
    elif policy_type == "squashing-gauss":
        out_activation = None
        head_size = out_features * 2
        policy_net = SquashingGaussPolicy(squashing_fn)
    elif policy_type == "real-nvp":
        out_activation = None
        head_size = out_features * 2
        policy_net = RealNVPPolicy(
            action_size=out_features,
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            squashing_fn=squashing_fn,
            bias=bias
        )
    else:
        head_size = out_features
        policy_net = None
        policy_type = "logits"

    self.policy_type = policy_type

    head_net = SequentialNet(
        hiddens=[in_features, head_size],
        layer_fn={"module": layer_fn, "bias": True},
        activation_fn=out_activation,
        norm_fn=None,
    )
    head_net.apply(outer_init)
    self.head_net = head_net

    self.policy_net = policy_net

    self._policy_fn = None
    if policy_net is not None:
        self._policy_fn = policy_net.forward
    else:
        self._policy_fn = lambda *args: args[0]
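# Why head_size = out_features * 2 for the gauss-style policies: the head
# emits a mean and a log_std per action dimension. Below is a minimal,
# self-contained sketch of what a squashing Gaussian policy might do with
# that output. This is an illustration only; the actual GaussPolicy /
# SquashingGaussPolicy implementations are not shown in this file.
import torch

out_features = 6
head_output = torch.randn(32, out_features * 2)    # produced by head_net
mean, log_std = head_output.chunk(2, dim=-1)       # (32, 6) each

dist = torch.distributions.Normal(mean, log_std.exp())
raw_action = dist.rsample()                        # reparameterized sample
action = torch.tanh(raw_action)                    # squashing_fn

# log-prob with the tanh change-of-variables correction
log_pi = dist.log_prob(raw_action).sum(dim=-1)
log_pi = log_pi - torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1)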
@classmethod
def create_from_params(
    cls,
    state_shape,
    action_size,
    observation_hiddens=None,
    head_hiddens=None,
    layer_fn=nn.Linear,
    activation_fn=nn.ReLU,
    dropout=None,
    norm_fn=None,
    bias=True,
    layer_order=None,
    residual=False,
    out_activation=None,
    observation_aggregation=None,
    lama_poolings=None,
    policy_type=None,
    squashing_fn=nn.Tanh,
    **kwargs
):
    assert len(kwargs) == 0
    observation_hiddens = observation_hiddens or []
    head_hiddens = head_hiddens or []

    layer_fn = MODULES.get_if_str(layer_fn)
    activation_fn = MODULES.get_if_str(activation_fn)
    norm_fn = MODULES.get_if_str(norm_fn)
    out_activation = MODULES.get_if_str(out_activation)
    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

    if isinstance(state_shape, int):
        state_shape = (state_shape,)

    if len(state_shape) in [1, 2]:
        # linear case: one observation or several stacked observations
        # state_shape is like [history_len, obs_shape]
        # @TODO: handle lama/rnn correctly
        if not observation_aggregation:
            observation_size = reduce(lambda x, y: x * y, state_shape)
        else:
            observation_size = reduce(lambda x, y: x * y, state_shape[1:])

        if len(observation_hiddens) > 0:
            observation_net = SequentialNet(
                hiddens=[observation_size] + observation_hiddens,
                layer_fn=layer_fn,
                dropout=dropout,
                activation_fn=activation_fn,
                norm_fn=norm_fn,
                bias=bias,
                layer_order=layer_order,
                residual=residual
            )
            observation_net.apply(inner_init)
            obs_out = observation_hiddens[-1]
        else:
            observation_net = None
            obs_out = observation_size
    elif len(state_shape) in [3, 4]:
        # cnn case: one image or several stacked images @TODO
        raise NotImplementedError
    else:
        raise NotImplementedError

    assert obs_out

    if observation_aggregation == "lama_obs":
        aggregation_net = LamaPooling(
            features_in=obs_out,
            poolings=lama_poolings
        )
        aggregation_out = aggregation_net.features_out
    else:
        aggregation_net = None
        aggregation_out = obs_out

    main_net = SequentialNet(
        hiddens=[aggregation_out] + head_hiddens,
        layer_fn=layer_fn,
        dropout=dropout,
        activation_fn=activation_fn,
        norm_fn=norm_fn,
        bias=bias,
        layer_order=layer_order,
        residual=residual
    )
    main_net.apply(inner_init)

    # @TODO: place for memory network

    if policy_type == "gauss":
        head_size = action_size * 2
        policy_net = GaussPolicy(squashing_fn)
    elif policy_type == "real_nvp":
        head_size = action_size * 2
        policy_net = RealNVPPolicy(
            action_size=action_size,
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            squashing_fn=squashing_fn,
            norm_fn=None,
            bias=bias
        )
    else:
        head_size = action_size
        policy_net = None

    head_net = SequentialNet(
        hiddens=[head_hiddens[-1], head_size],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True
    )
    head_net.apply(outer_init)

    actor_net = cls(
        observation_net=observation_net,
        aggregation_net=aggregation_net,
        main_net=main_net,
        head_net=head_net,
        policy_net=policy_net
    )

    return actor_net
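# Rough sketch of the shape flow that the actor create_from_params wires up,
# using plain nn.Sequential as a stand-in for SequentialNet. The sizes are
# illustrative assumptions, not library defaults: state_shape=(4, 24) without
# aggregation gives observation_size=96; observation_hiddens=[128],
# head_hiddens=[64, 64], action_size=6, policy_type="gauss" gives
# head_size=12 (a mean and a log_std per action dimension).
import torch
import torch.nn as nn

observation_net = nn.Sequential(nn.Linear(96, 128), nn.ReLU())
main_net = nn.Sequential(
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU()
)
head_net = nn.Linear(64, 12)

states = torch.randn(32, 4, 24).flatten(1)          # (batch, 96)
out = head_net(main_net(observation_net(states)))
assert out.shape == (32, 12)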
class CouplingLayer(nn.Module):
    def __init__(
        self,
        action_size,
        layer_fn,
        activation_fn=nn.ReLU,
        bias=True,
        parity="odd"
    ):
        """Conditional affine coupling layer used in the Real NVP Bijector.

        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808

        Important notes
        ---------------
        1. State embeddings are supposed to have size (action_size * 2).
        2. The scale and translation networks used in the Real NVP Bijector
           both have one hidden layer of (action_size) (activation_fn) units.
        3. Parity ("odd" or "even") determines which part of the input
           is copied and which part is transformed.
        """
        super().__init__()

        layer_fn = MODULES.get_if_str(layer_fn)
        activation_fn = MODULES.get_if_str(activation_fn)

        self.parity = parity
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2

        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias
        )
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True
        )

        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn=layer_fn,
            activation_fn=activation_fn,
            norm_fn=None,
            bias=bias
        )
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn=layer_fn,
            activation_fn=None,
            norm_fn=None,
            bias=True
        )

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)

    def forward(self, action, state_embedding, log_pi):
        if self.parity == "odd":
            action_copy = action[:, :self.copy_size]
            action_transform = action[:, self.copy_size:]
        else:
            action_copy = action[:, -self.copy_size:]
            action_transform = action[:, :-self.copy_size]

        x = torch.cat((state_embedding, action_copy), dim=1)

        t = self.translation_prenet(x)
        t = self.translation_net(t)

        s = self.scale_prenet(x)
        s = self.scale_net(s)

        out_transform = t + action_transform * torch.exp(s)

        if self.parity == "odd":
            action = torch.cat((action_copy, out_transform), dim=1)
        else:
            action = torch.cat((out_transform, action_copy), dim=1)

        # change of variables: subtract log|det J| of the affine transform
        log_det_jacobian = s.sum(dim=1)
        log_pi = log_pi - log_det_jacobian

        return action, log_pi
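# Simplified, self-contained illustration of the affine coupling idea that
# CouplingLayer implements. Assumptions: plain random tensors stand in for
# the conditioned scale_net/translation_net outputs, "odd" parity, no state
# embedding or pre-nets.
import torch

action = torch.randn(5, 4)
copy_size = 2
a_copy, a_trans = action[:, :copy_size], action[:, copy_size:]

s = torch.randn(5, 2) * 0.1   # stands in for the scale network output
t = torch.randn(5, 2) * 0.1   # stands in for the translation network output

out = torch.cat((a_copy, t + a_trans * torch.exp(s)), dim=1)

# the transform is invertible given the copied half ...
a_trans_rec = (out[:, copy_size:] - t) * torch.exp(-s)
assert torch.allclose(a_trans_rec, a_trans, atol=1e-6)

# ... and its log|det J| is simply s.sum(dim=1), which is exactly what
# CouplingLayer.forward subtracts from log_pi
log_det_jacobian = s.sum(dim=1)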
@classmethod
def create_from_params(
    cls,
    state_shape,
    observation_hiddens=None,
    head_hiddens=None,
    layer_fn=nn.Linear,
    activation_fn=nn.ReLU,
    dropout=None,
    norm_fn=None,
    bias=True,
    layer_order=None,
    residual=False,
    out_activation=None,
    history_aggregation_type=None,
    lama_poolings=None,
    **kwargs
):
    assert len(kwargs) == 0
    # hack to prevent cyclic imports
    from catalyst.contrib.registry import Registry

    observation_hiddens = observation_hiddens or []
    head_hiddens = head_hiddens or []

    layer_fn = Registry.name2nn(layer_fn)
    activation_fn = Registry.name2nn(activation_fn)
    norm_fn = Registry.name2nn(norm_fn)
    out_activation = Registry.name2nn(out_activation)
    inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

    if isinstance(state_shape, int):
        state_shape = (state_shape,)

    if len(state_shape) in [1, 2]:
        # linear case: one observation or several stacked observations
        # state_shape is like [history_len, obs_shape]
        # @TODO: handle lama/rnn correctly
        if not history_aggregation_type:
            state_size = reduce(lambda x, y: x * y, state_shape)
        else:
            state_size = reduce(lambda x, y: x * y, state_shape[1:])

        if len(observation_hiddens) > 0:
            observation_net = SequentialNet(
                hiddens=[state_size] + observation_hiddens,
                layer_fn=layer_fn,
                dropout=dropout,
                activation_fn=activation_fn,
                norm_fn=norm_fn,
                bias=bias,
                layer_order=layer_order,
                residual=residual
            )
            observation_net.apply(inner_init)
            obs_out = observation_hiddens[-1]
        else:
            observation_net = None
            obs_out = state_size
    elif len(state_shape) in [3, 4]:
        # cnn case: one image or several stacked images @TODO
        raise NotImplementedError
    else:
        raise NotImplementedError

    assert obs_out

    if history_aggregation_type == "lama_obs":
        aggregation_net = LamaPooling(
            features_in=obs_out,
            poolings=lama_poolings
        )
        aggregation_out = aggregation_net.features_out
    else:
        aggregation_net = None
        aggregation_out = obs_out

    main_net = SequentialNet(
        hiddens=[aggregation_out] + head_hiddens[:-1],
        layer_fn=layer_fn,
        dropout=dropout,
        activation_fn=activation_fn,
        norm_fn=norm_fn,
        bias=bias,
        layer_order=layer_order,
        residual=residual
    )
    main_net.apply(inner_init)

    # @TODO: place for memory network

    head_net = SequentialNet(
        hiddens=[head_hiddens[-2], head_hiddens[-1]],
        layer_fn=nn.Linear,
        activation_fn=out_activation,
        norm_fn=None,
        bias=True
    )
    head_net.apply(outer_init)

    critic_net = cls(
        observation_net=observation_net,
        aggregation_net=aggregation_net,
        main_net=main_net,
        head_net=head_net,
        policy_net=None
    )

    return critic_net
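# Shape-flow sketch for the critic variant above, with illustrative sizes
# (assumptions, not defaults) and nn.Sequential standing in for
# SequentialNet. Note the difference from the actor: main_net consumes
# head_hiddens[:-1] and head_net bridges the last pair, so head_hiddens
# needs at least two entries. With state_shape=(4, 24) and no aggregation,
# state_size=96; head_hiddens=[128, 64, 1] gives main_net over [96, 128, 64]
# and head_net over [64, 1], the value estimate.
import torch
import torch.nn as nn

main_net = nn.Sequential(
    nn.Linear(96, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU()
)
head_net = nn.Linear(64, 1)

states = torch.randn(32, 4, 24).flatten(1)
values = head_net(main_net(states))
assert values.shape == (32, 1)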
class TSN(nn.Module):
    def __init__(
        self,
        encoder,
        num_classes,
        feature_net_hiddens=None,
        emb_net_hiddens=None,
        activation_fn=torch.nn.ReLU,
        norm_fn=None,
        bias=True,
        dropout=None,
        consensus=None,
        kernel_size=1,
        feature_net_skip_connection=False,
        early_consensus=True
    ):
        super().__init__()
        assert consensus is not None
        assert kernel_size in [1, 3, 5]
        consensus = consensus if isinstance(consensus, list) else [consensus]
        self.consensus = consensus
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.feature_net_skip_connection = feature_net_skip_connection
        self.early_consensus = early_consensus

        nonlinearity = registry.MODULES.get_if_str(activation_fn)
        inner_init = create_optimal_inner_init(nonlinearity=nonlinearity)

        kernel2pad = {1: 0, 3: 1, 5: 2}

        def layer_fn(in_features, out_features, bias=True):
            return nn.Conv1d(
                in_features,
                out_features,
                bias=bias,
                kernel_size=kernel_size,
                padding=kernel2pad[kernel_size]
            )

        if feature_net_hiddens is not None:
            self.feature_net = SequentialNet(
                hiddens=[encoder.out_features] + [feature_net_hiddens],
                layer_fn=layer_fn,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.feature_net.apply(inner_init)
            out_features = feature_net_hiddens
        else:
            # with no feature net there is nothing to skip over,
            # so a skip connection makes no sense
            assert not self.feature_net_skip_connection
            self.feature_net = lambda x: x
            out_features = encoder.out_features

        # The differences start here.
        # Input channels to the consensus function
        # (also to the embedding net, multiplied by len(consensus))
        if self.feature_net_skip_connection:
            in_channels = out_features + encoder.out_features
        else:
            in_channels = out_features

        consensus_fn = OrderedDict()
        for key in sorted(consensus):
            if key == "attention":
                self.attn = nn.Sequential(
                    nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=1,
                        kernel_size=kernel_size,
                        padding=kernel2pad[kernel_size],
                        bias=True
                    ),
                    nn.Softmax(dim=1)
                )

                def self_attn_fn(x):
                    x_a = x.transpose(1, 2)
                    x_attn = (self.attn(x_a) * x_a)
                    x_attn = x_attn.transpose(1, 2)
                    x_attn = x_attn.mean(1, keepdim=True)
                    return x_attn

                consensus_fn["attention"] = self_attn_fn
            elif key == "avg":
                consensus_fn[key] = lambda x: x.mean(1, keepdim=True)
            elif key == "max":
                consensus_fn[key] = lambda x: x.max(1, keepdim=True)[0]

        # Not optimal, but much more understandable logic
        if self.early_consensus:
            out_features = emb_net_hiddens
            self.emb_net = SequentialNet(
                hiddens=[in_channels * len(consensus_fn), emb_net_hiddens],
                layer_fn=nn.Linear,
                norm_fn=norm_fn,
                activation_fn=activation_fn,
            )
            self.emb_net.apply(inner_init)
        else:
            if self.feature_net_skip_connection:
                out_features = out_features + self.encoder.out_features
            else:
                out_features = out_features

        self.head = nn.Linear(out_features, num_classes, bias=True)

        if "attention" in consensus:
            self.attn.apply(outer_init)
        self.head.apply(outer_init)

        self.consensus_fn = consensus_fn

    def forward(self, input):
        if len(input.shape) < 5:
            input = input.unsqueeze(1)
        bs, fl, ch, h, w = input.shape

        x = input.view(-1, ch, h, w)
        x = self.encoder(x)
        x = self.dropout(x)
        identity = x  # in the simple case, feature_net is an identity mapping

        x = x.view(bs, fl, -1)
        x = x.transpose(1, 2)
        x = self.feature_net(x)
        x = x.transpose(1, 2).contiguous()  # because of conv1d
        x = x.view(bs * fl, -1)

        if self.feature_net_skip_connection:
            x = torch.cat([identity, x], dim=-1)

        if self.early_consensus:
            x = x.view(bs, fl, -1)
            c_list = []
            for c_fn in self.consensus_fn.values():
                c_res = c_fn(x)
                c_list.append(c_res)
            x = torch.cat(c_list, dim=1)
            x = x.view(bs, -1)
            x = self.emb_net(x)

        x = self.head(x)

        if not self.early_consensus:
            x = x.view(bs, fl, -1)
            if self.consensus[0] == "avg":
                x = x.mean(1, keepdim=False)
            elif self.consensus[0] == "attention":
                identity = identity.view(bs, fl, -1)
                x_a = identity.transpose(1, 2)
                x_ = x.transpose(1, 2)
                x_attn = (self.attn(x_a) * x_)
                x_attn = x_attn.transpose(1, 2)
                x = x_attn.sum(1, keepdim=False)

        x = torch.sigmoid(x)  # paired with a BCE loss
        return x
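# Hedged usage sketch for TSN with the early-consensus path. Assumptions:
# the module's own dependencies (SequentialNet, registry, the init helpers)
# are importable, and ToyEncoder is a made-up per-frame encoder that only
# exists to expose the `out_features` attribute TSN expects.
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, out_features=64):
        super().__init__()
        self.out_features = out_features
        self.net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(3, out_features)
        )

    def forward(self, x):  # x: (batch * frames, ch, h, w)
        return self.net(x)

tsn = TSN(
    encoder=ToyEncoder(),
    num_classes=10,
    emb_net_hiddens=128,
    dropout=0.1,
    consensus=["avg", "max"],
    early_consensus=True
)
frames = torch.randn(2, 5, 3, 32, 32)   # (batch, frames, ch, h, w)
probs = tsn(frames)                      # (2, 10), sigmoid outputs for BCE
assert probs.shape == (2, 10)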