def __init__(self, config):
    super().__init__()
    self.detach_head = False
    self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
    if self.summary_type == "attn":
        raise NotImplementedError

    self.summary = Identity()
    if hasattr(config, "summary_use_proj") and config.summary_use_proj:
        if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)

    self.flatten = nn.Flatten()
def build_model(self):
    blocks = self.model_config['block']
    seq_model = []
    for block in blocks:
        block_type = block.pop('block_type')
        block_obj = self._build_block(block_type, **block)
        if isinstance(block_obj, PoolingFilter):
            if len(seq_model) > 0:
                self.read_level_encoder = Sequential(*seq_model)
            else:
                self.read_level_encoder = None
            self.pooling_filter = block_obj
            seq_model = []
        else:
            seq_model.append(block_obj)

    if (self.read_level_encoder is None) and (self.pooling_filter is None):
        self.read_level_encoder = Sequential(*seq_model)
        self.pooling_filter = Identity()
        self.decoder = Identity()
    else:
        if len(seq_model) == 0:
            self.decoder = Identity()
        else:
            self.decoder = Sequential(*seq_model)
def __init__(self, config):
    super(SequenceSummary, self).__init__()
    self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
    if self.summary_type == 'attn':
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = Identity()
    if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
        if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
def __init__(self, config):
    super().__init__()
    self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
    if self.summary_type == "attn":
        # We should use a standard multi-head attention module with absolute positional embedding for that.
        # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
        # We can probably just use the multi-head attention module of PyTorch >=1.1.0
        raise NotImplementedError

    self.summary = Identity()
    if hasattr(config, "summary_use_proj") and config.summary_use_proj:
        if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
            print(f"num_class: {config.num_labels}")
            num_classes = config.num_labels
        else:
            print(f"num_class here: {config.hidden_size}")
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64,
             n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False,
             lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False,
             post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False,
             use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0,
             reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None,
             return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False,
             absolute_position_emb = False, rotary_emb = True, axial_position_shape = None,
             n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
    super().__init__()
    emb_dim = default(emb_dim, dim)
    self.max_seq_len = max_seq_len

    self.token_emb = nn.Embedding(num_tokens, emb_dim)
    self.to_model_dim = Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)

    self.pos_emb = Always(0)
    self.layer_pos_emb = Always(None)

    if rotary_emb:
        self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
    elif absolute_position_emb:
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
    elif fixed_position_emb:
        self.pos_emb = FixedPositionalEmbedding(emb_dim)
    else:
        axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
        self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)

    self.reformer = Reformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head,
                             bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks,
                             attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie,
                             lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation,
                             ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0.,
                             layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head,
                             use_scale_norm = use_scale_norm, use_rezero = use_rezero,
                             use_full_attn = use_full_attn, full_attn_thres = full_attn_thres,
                             reverse_thres = reverse_thres, num_mem_kv = num_mem_kv,
                             one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads,
                             pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
    self.norm = nn.LayerNorm(dim)

    if return_embeddings:
        self.out = Identity()
        return

    self.out = nn.Sequential(
        nn.Linear(dim, emb_dim) if emb_dim != dim else Identity(),
        nn.Linear(emb_dim, num_tokens) if not weight_tie_embedding else MatrixMultiply(self.token_emb.weight, transpose=True, normalize=True)
    )
def get_layer2layer_bridges(cls, config):
    num_layers = config.num_hidden_layers
    if not hasattr(config, 'layer2layer_bridges') or not config.layer2layer_bridges \
            or config.layer2layer_bridges not in ['adapters', 'model_parallelism']:
        return nn.ModuleList([Identity() for i in range(num_layers)])

    if config.layer2layer_bridges == 'adapters':
        hidden_size = config.hidden_size
        if hasattr(config, 'layer2layer_bridges_config'):
            adapter_config = AdapterConfig(**config.layer2layer_bridges_config)
        else:
            adapter_config = AdapterConfig()
        return nn.ModuleList([Adapter(hidden_size, adapter_config) for i in range(num_layers)])
    elif config.layer2layer_bridges == 'model_parallelism':
        if hasattr(config, 'layer2layer_bridges_config'):
            devices = config.layer2layer_bridges_config['devices']
        else:
            devices = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
        if not devices:
            return nn.ModuleList([Identity() for i in range(num_layers)])

        layer_groups = num_layers // len(devices)
        modules = []
        current_device_num = 0
        for i in range(num_layers - 1):
            if (i + 1) % layer_groups == 0:
                current_device_num += 1
                if current_device_num >= len(devices):
                    current_device_num = 0
                modules.append(SwitchDevice(devices[current_device_num]))
            else:
                modules.append(Identity())
        modules.append(SwitchDevice(devices[0]))  # back to the first device at the end
        return nn.ModuleList(modules)
def __init__(self, config):
    super(SequenceSummary, self).__init__()
    self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
    if self.summary_type == 'attn':
        raise NotImplementedError

    self.summary = Identity()
    if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
        if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
            num_classes = config.num_labels
        else:
            num_classes = config.hidden_size
        self.summary = nn.Linear(config.hidden_size, num_classes)

    self.activation = Identity()
    if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
        self.activation = nn.Tanh()

    self.first_dropout = Identity()
    if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
        self.first_dropout = nn.Dropout(config.summary_first_dropout)

    self.last_dropout = Identity()
    if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
def __new__(cls, n_outputs, *, batch_norm=True):
    layers = [
        ('conv_0', Conv2d(4, 32, 8, 4, bias=not batch_norm)),
        ('bnorm0', BatchNorm2d(32, affine=True)) if batch_norm else None,
        ('relu_0', ReLU()),
        ('tap_0', Identity()),  # no-op tap for the viewer
        # ('drop_1', Dropout(p=0.5)),
        ('conv_1', Conv2d(32, 64, 4, 2, bias=not batch_norm)),
        ('bnorm1', BatchNorm2d(64, affine=True)) if batch_norm else None,
        ('relu_1', ReLU()),
        ('tap_1', Identity()),  # no-op tap for the viewer
        # ('drop_2', Dropout(p=0.2)),
        ('conv_2', Conv2d(64, 64, 3, 1, bias=not batch_norm)),
        ('bnorm2', BatchNorm2d(64, affine=True)) if batch_norm else None,
        ('relu_2', ReLU()),
        ('tap_2', Identity()),  # no-op tap for the viewer
        ('flat_3', Flatten(1, -1)),
        ('drop_3', Dropout(p=0.5)),
        ('dense3', Linear(64 * 7 * 7, 512, bias=True)),
        ('relu_3', ReLU()),
        ('drop_4', Dropout(p=0.2)),
        ('dense4', Linear(512, n_outputs, bias=True)),
    ]
    # filter out `None`s and build a sequential network
    return Sequential(OrderedDict(list(filter(None, layers))))
def create_cutmix_net():
    model = ResNet('imagenet', 101, 1000)
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(pretrained_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.module.avgpool = Identity()
    model.module.fc = Identity()
    return model
def __init__(self, layers, activation, activation_last=None, batch_norm=False,
             initialize=True, input_shape=None, output_shape=None, *args, **kwargs):
    """
    Args:
        layers (list): list of hidden layer sizes
        activation (str): activation function for the MLP
        activation_last (str): activation function for the MLP last layer
        batch_norm (bool): use batch normalization
        initialize (bool): apply weight initialization
        input_shape (int): optional input size prepended to ``layers``
        output_shape (int): optional output size appended to ``layers``
        *args: Variable length argument list
        **kwargs: Arbitrary keyword arguments
    """
    super(MLPBlock, self).__init__(*args, **kwargs)
    if input_shape is not None:
        layers = [input_shape] + layers
    if output_shape is not None:
        layers = layers + [output_shape]

    _layers = []
    for i, node in enumerate(layers):
        if i == len(layers) - 1:
            break
        _layers.append(Linear(layers[i], layers[i + 1]))
        if batch_norm:
            _layers.append(BatchNorm1d(layers[i + 1]))
        if i == len(layers) - 2:
            if activation_last is None or activation_last == 'Identity':
                _layers.append(Identity())
            else:
                _layers.append(getattr(act, activation_last)())
        else:
            if activation == 'Identity':
                _layers.append(Identity())
            else:
                _layers.append(getattr(act, activation)())
    self._layers = Sequential(*_layers)

    if initialize:
        self.apply(self._init_weights)
def __init__(self, n_actions=4, n_channels=4):
    super().__init__()
    self.phi = Sequential(
        # f_32 k_3 s_2 p_1
        Conv2d(n_channels, 32, 3, stride=2, padding=1, bias=True),
        ELU(),
        Conv2d(32, 32, 3, stride=2, padding=1, bias=True),
        ELU(),
        Conv2d(32, 32, 3, stride=2, padding=1, bias=True),
        ELU(),
        Conv2d(32, 32, 3, stride=2, padding=1, bias=True),
        ELU(),
        Identity(),  # tap
        Flatten(1, -1),
    )
    self.gee = torch.nn.Sequential(
        Linear(2 * 32 * 3 * 3, 256, bias=True),
        ReLU(),
        Linear(256, n_actions, bias=True),
    )
    self.eff = torch.nn.Sequential(
        Linear(32 * 3 * 3 + n_actions, 256, bias=True),
        ReLU(),
        Linear(256, 32 * 3 * 3, bias=True),
    )
    self.n_actions, self.n_emb_dim = n_actions, 32 * 3 * 3
def __init__(self, shape, list_zernike_ft, list_zernike_direct, padding_coeff=0.,
             deformation='single', features=None):
    # Here we define the type of model we want to use, the number of Zernike
    # polynomials, and whether we want to apply a deformation.
    super(Aberration, self).__init__()

    # Check whether the model is given the lists of Zernike polynomials to use
    # or simply the total number to use.
    if type(list_zernike_direct) not in [list, np.ndarray]:
        list_zernike_direct = range(0, list_zernike_direct)
    if type(list_zernike_ft) not in [list, np.ndarray]:
        list_zernike_ft = range(0, list_zernike_ft)

    self.nxy = shape

    # padding layer, to have a good FFT resolution (requires cropping after the IFFT)
    padding = int(padding_coeff * self.nxy)
    self.pad = ZeroPad2d(padding)

    # scaling x, y
    if deformation == 'single':
        self.deformation = ComplexDeformation()
    elif deformation == 'scaling':
        self.deformation = ComplexScaling()
    else:
        self.deformation = Identity()

    self.zernike_ft = Sequential(*(ComplexZernike(j=j + 1) for j in list_zernike_ft))
    self.zernike_direct = Sequential(*(ComplexZernike(j=j + 1) for j in list_zernike_direct))
def __init__(
    self,
    in_channels: int,
    hidden_channels: int,
    out_channels: int,
    num_layers: int,
    dropout: float = 0.0,
    batch_norm: bool = True,
    relu_last: bool = False,
):
    super(MLP, self).__init__()

    self.lins = ModuleList()
    self.lins.append(Linear(in_channels, hidden_channels))
    for _ in range(num_layers - 2):
        self.lins.append(Linear(hidden_channels, hidden_channels))
    self.lins.append(Linear(hidden_channels, out_channels))

    self.batch_norms = ModuleList()
    for _ in range(num_layers - 1):
        norm = BatchNorm1d(hidden_channels) if batch_norm else Identity()
        self.batch_norms.append(norm)

    self.dropout = dropout
    self.relu_last = relu_last
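# Hypothetical construction of the MLP above: three Linear layers (16 -> 32 -> 32 -> 4)
# with a BatchNorm1d (or Identity) after each hidden layer. The forward pass is not part
# of the original snippet, so only instantiation is shown here.
mlp = MLP(in_channels=16, hidden_channels=32, out_channels=4, num_layers=3, dropout=0.5)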
def __init__(self, layers_conv2d=None, initialize=True, *args, **kwargs):
    """
    Args:
        layers_conv2d (list(tuple(str, dict))): configs of conv2d layers,
            a list of tuple(op_name, op_args).
        *args: Variable length argument list
        **kwargs: Arbitrary keyword arguments
    """
    super(Conv2DBlock, self).__init__(*args, **kwargs)
    from copy import copy

    _layers = []
    conv2d_args = {"stride": 1, "padding": 0, "activation": 'ReLU'}
    maxpooling2d_args = {"kernel_size": 2, "stride": 2}
    for layer, args in layers_conv2d:
        if layer == 'conv2d':
            layer_args = copy(conv2d_args)
            layer_args.update(args)
            activation = layer_args.pop('activation')
            _layers.append(Conv2d(**layer_args))
            if activation == 'Identity':
                _layers.append(Identity())
            else:
                _layers.append(getattr(act, activation)())
        elif layer == 'maxpooling2d':
            layer_args = copy(maxpooling2d_args)
            layer_args.update(args)
            _layers.append(MaxPool2d(**layer_args))
        else:
            raise ValueError(f"{layer} is not implemented")
    self._layers = Sequential(*_layers)

    if initialize:
        self.apply(self._init_weights)
def test_amp_and_parallel_for_scalar_models(test_output_dirs: TestOutputDirectories,
                                            execution_mode: ModelExecutionMode,
                                            use_mixed_precision: bool) -> None:
    """
    Tests the mixed precision flag and data parallelism for scalar models.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine"
    config = ClassificationModelForTesting()
    config.use_mixed_precision = use_mixed_precision
    model = DummyScalarModel(
        expected_image_size_zyx=config.expected_image_size_zyx,
        activation=Identity())
    model.use_mixed_precision = use_mixed_precision
    model_and_info = ModelAndInfo(model=model, model_execution_mode=execution_mode)
    # This is the same logic spelt out in update_model_for_multiple_gpus:
    # execution_mode == ModelExecutionMode.TRAIN or (not use_model_parallel), which is always True in our case
    use_data_parallel = True
    model_and_info = model_util.update_model_for_multiple_gpus(model_and_info, config)
    if use_data_parallel:
        assert isinstance(model_and_info.model, DataParallelModel)
    data_loaders = config.create_data_loaders()
    gradient_scaler = GradScaler() if use_mixed_precision else None
    train_val_parameters: TrainValidateParameters = TrainValidateParameters(
        model=model_and_info.model,
        data_loader=data_loaders[execution_mode],
        in_training_mode=execution_mode == ModelExecutionMode.TRAIN,
        gradient_scaler=gradient_scaler,
        dataframe_loggers=MetricsDataframeLoggers(Path(test_output_dirs.root_dir)),
        summary_writers=SummaryWriters(train=None, val=None)  # type: ignore
    )
    training_steps = ModelTrainingStepsForScalarModel(config, train_val_parameters)
    sample = list(data_loaders[execution_mode])[0]
    model_input = get_scalar_model_inputs_and_labels(config, model, sample)
    logits, posteriors, loss = training_steps._compute_model_output_and_loss(model_input)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
        assert posteriors.dtype == torch.float16
        # BCEWithLogitsLoss outputs float32, even with float16 arguments
        assert loss.dtype == torch.float32
    else:
        assert first_logit.dtype == torch.float32
        assert posteriors.dtype == torch.float32
        assert loss.dtype == torch.float32
    # Verify that the forward pass does not throw. It would, for example, if it fails to gather tensors
    # or to convert float16 to float32.
    _, _, _ = training_steps._compute_model_output_and_loss(model_input)
def __init__(
    self,
    input_num_filters: int,
    increase_dim: bool = False,
    projection: bool = False,
    last: bool = False,
):
    super().__init__()
    self.last: bool = last

    if increase_dim:
        first_stride = (2, 2)
        out_num_filters = input_num_filters * 2
    else:
        first_stride = (1, 1)
        out_num_filters = input_num_filters

    self.direct = Sequential(
        conv3x3(input_num_filters, out_num_filters, stride=first_stride),
        batch_norm(out_num_filters),
        ReLU(True),
        conv3x3(out_num_filters, out_num_filters, stride=(1, 1)),
        batch_norm(out_num_filters),
    )

    self.shortcut: Module
    # add shortcut connections
    if increase_dim:
        if projection:
            # projection shortcut, as option B in the paper
            self.shortcut = Sequential(
                Conv2d(
                    input_num_filters,
                    out_num_filters,
                    kernel_size=(1, 1),
                    stride=(2, 2),
                    bias=False,
                ),
                batch_norm(out_num_filters),
            )
        else:
            # identity shortcut, as option A in the paper
            self.shortcut = Sequential(
                IdentityShortcut(lambda x: x[:, :, ::2, ::2]),
                ConstantPad3d(
                    (0, 0, 0, 0, out_num_filters // 4, out_num_filters // 4),
                    0.0,
                ),
            )
    else:
        self.shortcut = Identity()
def get_activation(name):
    """Get activation function by name."""
    return {
        'relu': ReLU(),
        'sigmoid': Sigmoid(),
        'tanh': Tanh(),
        'identity': Identity()
    }[name]
def get_activation(name):
    """Get activation function by name."""
    return {
        "relu": ReLU(),
        "sigmoid": Sigmoid(),
        "tanh": Tanh(),
        "identity": Identity(),
    }[name]
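# A minimal usage sketch for the activation factory above (illustrative, not from the
# original source); it assumes the usual `from torch.nn import ReLU, Sigmoid, Tanh, Identity`
# imports that the factory relies on.
import torch

act = get_activation("relu")
x = torch.tensor([-1.0, 0.0, 2.0])
print(act(x))  # tensor([0., 0., 2.])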
def _init_model(self, in_features, num_classes: int) -> None:
    self.model = Sequential(
        SNNDense(in_features, 1024, dropout=False),
        SNNDense(1024, 256),
        SNNDense(256, 64),
        SNNDense(64, 16),
        SNNDense(16, num_classes, activation=Identity()),
    ).to(self.device)
def _get_features_net(self, args):
    if args.features_net == 'resnet':
        net = torchvision.models.resnet18(pretrained=True)
        net.fc = Identity()
    elif args.features_net == 'adm':
        net = FeatureExtractor(args.in_shape, args.feature_depth, args.dropout, False)
    return net
def __init__(self, num_classes, model='resnet101', pretrained=False):
    super().__init__()
    assert model in [
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'densenet121', 'densenet169'
    ]
    if model == 'resnet18':
        self._model = models.resnet18(pretrained=pretrained)
        fc_in_features = 512
    elif model == 'resnet34':
        self._model = models.resnet34(pretrained=pretrained)
        fc_in_features = 512
    elif model == 'resnet50':
        self._model = models.resnet50(pretrained=pretrained)
        fc_in_features = 2048
    elif model == 'resnet101':
        self._model = models.resnet101(pretrained=pretrained)
        fc_in_features = 2048
    elif model == 'resnet152':
        self._model = models.resnet152(pretrained=pretrained)
        fc_in_features = 2048
    elif model == 'densenet121':
        self._model = models.densenet121(pretrained=pretrained, drop_rate=0)
        fc_in_features = 1024
    elif model == 'densenet169':
        self._model = models.densenet169(pretrained=pretrained, drop_rate=0)
        fc_in_features = 1664
    else:
        assert False

    if 'resnet' in model:
        self._model.fc = Identity()
    elif 'densenet' in model:
        self._model.classifier = self._model.fc = Identity()

    self._fc = torch.nn.Linear(in_features=fc_in_features, out_features=num_classes)
    self.T = nn.Parameter(torch.tensor(1.0))
    self.N = 25
    self.p = 0.5
def __init__(self, models: list, last_layer: Module = Identity(),
             final_layer: Module = Sequential(AdaptiveAvgPool2d((1, 1)), Flatten())):
    super().__init__(len(models))
    self.core = Sequential(*models)
    self.final = final_layer
    self.last = last_layer
def __init__(self, in_features: int = 1, *args: Any, **kwargs: Any):
    super().__init__(*args, **kwargs)
    self.l_rate = 1e-1
    activation = Identity()
    layers = [
        torch.nn.Linear(in_features=in_features, out_features=1, bias=True),
        activation
    ]
    self.model = torch.nn.Sequential(*layers)  # type: ignore
def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
    """
    This is norm -> nonlin -> conv -> norm -> nonlin -> conv.
    :param in_planes:
    :param out_planes:
    :param props:
    :param stride:
    """
    super().__init__()
    self.kernel_size = kernel_size
    props['conv_op_kwargs']['stride'] = 1
    self.stride = stride
    self.props = props
    self.out_planes = out_planes
    self.in_planes = in_planes

    if stride is not None:
        kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
        kwargs_conv1['stride'] = stride
    else:
        kwargs_conv1 = props['conv_op_kwargs']

    self.norm1 = props['norm_op'](in_planes, **props['norm_op_kwargs'])
    self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])
    self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size,
                                  padding=[(i - 1) // 2 for i in kernel_size],
                                  **kwargs_conv1)

    if props['dropout_op_kwargs']['p'] != 0:
        self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
    else:
        self.dropout = Identity()

    self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
    self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])
    self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size,
                                  padding=[(i - 1) // 2 for i in kernel_size],
                                  **props['conv_op_kwargs'])

    if (self.stride is not None and any(i != 1 for i in self.stride)) or (in_planes != out_planes):
        stride_here = stride if stride is not None else 1
        self.downsample_skip = nn.Sequential(
            props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False))
    else:
        self.downsample_skip = None
def genActivation(actType: Union[str, None], params: dict = None):
    if actType is None:
        layer = Identity()
    else:
        if params is None:
            layer = getattr(Activations, actType)()
        else:
            layer = getattr(Activations, actType)(**params)
    return layer
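# Hedged usage sketch for genActivation (not from the original source); it assumes
# `Activations` is an alias for `torch.nn`, so the string names resolve to standard modules.
relu = genActivation('ReLU')                                 # -> nn.ReLU()
leaky = genActivation('LeakyReLU', {'negative_slope': 0.1})  # -> nn.LeakyReLU(0.1)
noop = genActivation(None)                                   # -> Identity()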
def __init__(self, in_features: int = 1, *args, **kwargs) -> None:  # type: ignore
    super().__init__(in_features=in_features, *args, **kwargs)  # type: ignore
    self.l_rate = 1e-1
    self.dataset_split = ModelExecutionMode.TRAIN
    activation = Identity()
    layers = [
        torch.nn.Linear(in_features=in_features, out_features=1, bias=True),
        activation
    ]
    self.model = torch.nn.Sequential(*layers)  # type: ignore
def __init__(self, dim_features, no_experts, dim_target, config):
    super().__init__(dim_features, 0, dim_target, config)
    self.no_experts = no_experts
    self.output_type = config['output_type']
    self.hidden_units = config['expert_hidden_units']
    self.dim_target = dim_target

    if config['aggregation'] == 'sum':
        self.aggregate = global_add_pool
    elif config['aggregation'] == 'mean':
        self.aggregate = global_mean_pool
    elif config['aggregation'] == 'max':
        self.aggregate = global_max_pool
    elif config['aggregation'] is None:  # for node classification
        self.aggregate = None

    if 'binomial' in self.output_type:
        assert dim_target == 1, "Implementation works with a single dim regression problem for now"
        self.output_activation = Identity()  # emulate Bernoulli
        self.node_transform = Identity()
        self.final_transform = Linear(dim_features, self.no_experts * 2, bias=False)
    elif 'gaussian' in self.output_type:
        self.output_activation = Identity()  # emulate Gaussian (needs the variance as well)
        # Need independent parameters for the variance
        if self.hidden_units > 0:
            self.node_transform = Sequential(
                Linear(dim_features, self.hidden_units * self.no_experts * 2), ReLU())
            self.final_transform = Linear(self.hidden_units * self.no_experts * 2,
                                          self.no_experts * 2 * dim_target)
        else:
            self.node_transform = Identity()
            self.final_transform = Linear(dim_features, self.no_experts * 2 * dim_target)
    else:
        raise NotImplementedError(
            f'Output type {self.output_type} unrecognized, use binomial or gaussian.')
def __init__(
    self,
    channel_list: Optional[Union[List[int], int]] = None,
    *,
    in_channels: Optional[int] = None,
    hidden_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: Optional[int] = None,
    dropout: float = 0.,
    act: str = "relu",
    batch_norm: bool = True,
    act_first: bool = False,
    act_kwargs: Optional[Dict[str, Any]] = None,
    batch_norm_kwargs: Optional[Dict[str, Any]] = None,
    plain_last: bool = True,
    bias: bool = True,
    relu_first: bool = False,
):
    super().__init__()

    act_first = act_first or relu_first  # Backward compatibility.
    batch_norm_kwargs = batch_norm_kwargs or {}

    if isinstance(channel_list, int):
        in_channels = channel_list

    if in_channels is not None:
        assert num_layers >= 1
        channel_list = [hidden_channels] * (num_layers - 1)
        channel_list = [in_channels] + channel_list + [out_channels]

    assert isinstance(channel_list, (tuple, list))
    assert len(channel_list) >= 2
    self.channel_list = channel_list

    self.dropout = dropout
    self.act = activation_resolver(act, **(act_kwargs or {}))
    self.act_first = act_first
    self.plain_last = plain_last

    self.lins = torch.nn.ModuleList()
    iterator = zip(channel_list[:-1], channel_list[1:])
    for in_channels, out_channels in iterator:
        self.lins.append(Linear(in_channels, out_channels, bias=bias))

    self.norms = torch.nn.ModuleList()
    iterator = channel_list[1:-1] if plain_last else channel_list[1:]
    for hidden_channels in iterator:
        if batch_norm:
            norm = BatchNorm1d(hidden_channels, **batch_norm_kwargs)
        else:
            norm = Identity()
        self.norms.append(norm)

    self.reset_parameters()
def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
    """Instantiate submodules that are used in the forward pass."""
    super().__init__()

    self.conv1 = Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
    self.conv2 = Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
    self.linear1 = Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)

    if C_in == C_hid:
        self.shortcut = Identity()
    else:
        self.shortcut = Conv2d(C_in, C_hid, kernel_size=1, stride=1)
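# A possible forward pass for the block above (an assumption, not part of the original
# snippet): two convolutions with a residual connection through `shortcut`, then the linear head.
# Shapes line up because conv1/conv2 preserve the spatial size and the shortcut maps C_in -> C_hid.
import torch.nn.functional as F

def forward(self, x):
    out = F.relu(self.conv1(x))
    out = self.conv2(out)
    out = F.relu(out + self.shortcut(x))  # Identity() when C_in == C_hid, 1x1 conv otherwise
    return self.linear1(out.flatten(start_dim=1))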
def __init__(
    self,
    in_edge_feats: int,
    in_node_feats: int,
    in_global_feats: int,
    out_edge_feats: int,
    out_node_feats: int,
    out_global_feats: int,
    batch_norm: bool,
    dropout: float,
):
    super().__init__()

    in_feats = in_node_feats + in_edge_feats + in_global_feats
    self.edge_fn = Sequential(
        Linear(in_feats, out_edge_feats),
        Dropout(p=dropout) if dropout > 0 else Identity(),
        BatchNorm1d(out_edge_feats) if batch_norm else Identity(),
        ReLU(),
    )

    in_feats = in_node_feats + out_edge_feats + in_global_feats
    self.node_fn = Sequential(
        Linear(in_feats, out_node_feats),
        Dropout(p=dropout) if dropout > 0 else Identity(),
        BatchNorm1d(out_node_feats) if batch_norm else Identity(),
        ReLU(),
    )

    in_feats = out_node_feats + out_edge_feats + in_global_feats
    self.global_fn = Sequential(
        Linear(in_feats, out_global_feats),
        Dropout(p=dropout) if dropout > 0 else Identity(),
        BatchNorm1d(out_global_feats) if batch_norm else Identity(),
        ReLU(),
    )