Ejemplo n.º 1
0
 def forward(self, inputs, params=None):
     """Extract features, flatten them per sample, then classify.

     Args:
         inputs: batch handed to ``self.features``; dim 0 is kept as the
             batch dimension when the features are flattened.
         params: optional parameter dict; its ``'features'`` and
             ``'classifier'`` sub-dicts are routed to the matching submodules.

     Returns:
         The output of ``self.classifier`` on the flattened features.
     """
     embedded = self.features(inputs, params=get_subdict(params, 'features'))
     # Collapse everything after the batch dimension into a single axis.
     flat = embedded.view((embedded.size(0), -1))
     return self.classifier(flat, params=get_subdict(params, 'classifier'))
Ejemplo n.º 2
0
 def forward(self, inputs, pos_edge_index, neg_edge_index, params=None):
     """Score positive and negative candidate edges with a shared pair model.

     Positive edges come first in the output, followed by negative edges.

     Args:
         inputs: node-feature tensor indexed along dim 0 by node id.
         pos_edge_index: edge index whose row 0 / row 1 hold endpoint ids.
         neg_edge_index: same layout as ``pos_edge_index``; concatenated
             after it along the last dimension.
         params: optional parameter dict; ``'linear_a'`` and ``'linear'``
             sub-dicts are routed to the matching layers.

     Returns:
         ``sigmoid`` scores clamped to ``[1e-8, 1 - 1e-8]``.
     """
     edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)

     def project(endpoint_ids):
         gathered = torch.index_select(inputs, 0, endpoint_ids)
         # NOTE(review): both endpoints use ``linear_a``; confirm this weight
         # sharing is intended and not a typo for a separate ``linear_b``.
         return self.linear_a(gathered, params=get_subdict(params, 'linear_a'))

     pair = torch.cat((project(edge_index[0]), project(edge_index[1])), 1)
     scores = self.linear(pair, params=get_subdict(params, 'linear'))
     # Keep probabilities strictly inside (0, 1), presumably so a later log()
     # (e.g. in a BCE-style loss) stays finite.
     return torch.clamp(torch.sigmoid(scores), min=1e-8, max=1 - 1e-8)
Ejemplo n.º 3
0
 def forward(self, x, params=None):
     """Run the four residual stages, optional average pool, and classifier.

     Args:
         x: input batch; dim 0 is kept as the batch dimension.
         params: optional parameter dict; ``'layer1'``..``'layer4'`` and
             ``'classifier'`` sub-dicts are routed to the matching modules.

     Returns:
         ``(features, logits)``. ``features`` is the flattened stage output
         (dropout NOT applied); ``logits`` is the classifier output on the
         dropout-applied features.
     """
     out = x
     for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
         stage = getattr(self, stage_name)
         out = stage(out, params=get_subdict(params, stage_name))
     if self.keep_avg_pool:
         out = self.avgpool(out)
     features = out.view((out.size(0), -1))
     dropped = self.dropout(features)
     logits = self.classifier(dropped, params=get_subdict(params, 'classifier'))
     return features, logits
Ejemplo n.º 4
0
    def forward(self, model_input, params=None):
        """Run the coarse-to-fine (4x -> 2x -> 1x) pipeline on coordinates.

        The 4x-downsampled, encoded coordinates go through ``net4x``. Its
        output is concatenated (last dim) with the encoded 2x-downsampled
        coordinates and fed to ``net2x``. That output is concatenated with
        the encoded full-resolution coordinates and fed to ``net1x``.

        Args:
            model_input: when ``self.use_meta`` is true, a mapping whose
                ``'coords'`` entry holds the coordinates; otherwise the
                coordinates themselves.
            params: optional parameter dict, defaulting to
                ``self.named_parameters()``. Only consulted when
                ``self.use_meta`` is true, where its ``'net4x'``, ``'net2x'``
                and ``'net1x'`` sub-dicts go to the matching networks.

        Returns:
            With ``self.use_meta``: ``{'model_in': coords, 'model_out': out}``,
            where ``coords`` is a detached copy of the input coordinates with
            ``requires_grad`` enabled. Otherwise the raw ``net1x`` output.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        if self.use_meta:
            # Fresh leaf tensor that requires grad, so the output can be
            # differentiated w.r.t. the coordinates.
            coords_org = model_input['coords'].clone().detach().requires_grad_(True)
        else:
            coords_org = model_input
        coords = coords_org

        # Per-scale input encoders. Modes not handled below fall back to the
        # identity; previously they left encoding1x/2x/4x unbound and the call
        # crashed with UnboundLocalError.
        identity = torch.nn.Identity()
        encoding1x = encoding2x = encoding4x = identity
        if self.mode == 'rbf':
            encoding1x = encoding2x = encoding4x = self.rbf_layer
        elif self.mode in ['nerf', 'positional', 'gauss']:
            if self.ff_dims is None:
                encoding1x = encoding2x = encoding4x = self.positional_encoding
            else:
                encoding1x = self.positional_encoding1x
                encoding2x = self.positional_encoding2x
                encoding4x = self.positional_encoding4x

        def run(net, net_input, name):
            # In meta mode the network receives its own parameter sub-dict.
            if self.use_meta:
                return net(net_input, get_subdict(params, name))
            return net(net_input)

        input4x = encoding4x(self.image_downsampling4x(coords))
        output4x = run(self.net4x, input4x, 'net4x')

        input2x = torch.cat(
            [encoding2x(self.image_downsampling2x(coords)), output4x], dim=-1)
        output2x = run(self.net2x, input2x, 'net2x')

        input1x = torch.cat([encoding1x(coords), output2x], dim=-1)
        output = run(self.net1x, input1x, 'net1x')

        if self.use_meta:
            return {'model_in': coords_org, 'model_out': output}
        return output
Ejemplo n.º 5
0
 def forward(self, coords, params=None, **kwargs):
     """Squeeze ``coords``, run ``self.net``, and restore a leading dim.

     Args:
         coords: input tensor; ALL size-1 dimensions are removed before the
             network call.
         params: optional parameter dict, defaulting to
             ``self.named_parameters()``; its ``'net'`` sub-dict is used.
         **kwargs: accepted for interface compatibility and ignored.

     Returns:
         The network output with a new leading dimension of size 1.
     """
     all_params = OrderedDict(self.named_parameters()) if params is None else params
     squeezed = coords.squeeze()
     result = self.net(squeezed, params=get_subdict(all_params, 'net'))
     return result.unsqueeze(0)
Ejemplo n.º 6
0
    def forward(self, model_input, params=None):
        """Apply ``self.net`` to ``model_input['coords']``.

        ``requires_grad_(True)`` is called on the caller's tensor in place
        (no clone/detach). As a result, the returned ``'model_in'`` is the
        very same tensor object as ``model_input['coords']``.

        Args:
            model_input: mapping with a ``'coords'`` tensor.
            params: optional parameter dict; passed through ``get_subdict``
                as-is (``None`` is not replaced by the module's own
                parameters here).

        Returns:
            ``{'model_in': coords, 'model_out': output}``.
        """
        # Enables gradients w.r.t. the coordinates; mutates the input tensor.
        coords_org = model_input['coords'].requires_grad_(True)
        output = self.net(coords_org, get_subdict(params, 'net'))
        return {'model_in': coords_org, 'model_out': output}
Ejemplo n.º 7
0
    def forward_with_activations(self, coords, params=None, retain_grad=False):
        """Run the network and return every intermediate activation.

        Args:
            coords: input tensor; a detached copy with ``requires_grad``
                enabled is used as the network input.
            params: optional parameter dict, defaulting to
                ``self.named_parameters()``. ``BatchLinear`` sublayers get the
                ``'net.<i>'`` -> ``'<j>'`` sub-dict.
            retain_grad: when true, ``retain_grad()`` is called on every
                intermediate tensor.

        Returns:
            OrderedDict starting with ``'input'``, followed by one entry per
            sublayer keyed ``'<sublayer class>_<outer layer index>'``.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        current = coords.clone().detach().requires_grad_(True)
        activations = OrderedDict(input=current)
        for layer_idx, layer in enumerate(self.net):
            layer_params = get_subdict(params, f'net.{layer_idx}')
            for sub_idx, sublayer in enumerate(layer):
                if isinstance(sublayer, BatchLinear):
                    current = sublayer(
                        current, params=get_subdict(layer_params, f'{sub_idx}'))
                else:
                    current = sublayer(current)

                if retain_grad:
                    current.retain_grad()
                # NOTE(review): the key uses the outer layer index only, so two
                # sublayers of the same class in one layer overwrite each other.
                activations[f'{sublayer.__class__}_{layer_idx}'] = current
        return activations
Ejemplo n.º 8
0
 def forward(self, input, params=None):
     """Feed ``input`` through every child module in registration order.

     ``MetaModule`` children receive the ``params`` sub-dict named after
     them; plain ``nn.Module`` children are called without parameters.

     Raises:
         TypeError: if a child is neither a ``MetaModule`` nor an
             ``nn.Module`` (e.g. a ``None`` placeholder).
     """
     output = input
     for child_name, child in self._modules.items():
         if isinstance(child, MetaModule):
             output = child(output, params=get_subdict(params, child_name))
             continue
         if not isinstance(child, nn.Module):
             raise TypeError(
                 'The module must be either a torch module '
                 '(inheriting from `nn.Module`), or a `MetaModule`. '
                 'Got type: `{0}`'.format(type(child)))
         output = child(output)
     return output
Ejemplo n.º 9
0
    def forward(self, x, params=None):
        """Three conv -> bn -> relu stages, max-pool, then optional dropout.

        Args:
            x: input batch; ``size(2)`` of the pooled output is used as the
                feature-map size for DropBlock.
            params: optional parameter dict; ``'conv1'``..``'conv3'`` and
                ``'bn1'``..``'bn3'`` sub-dicts go to the matching layers.

        Returns:
            The pooled (and possibly dropped-out) feature map.
        """
        # Incremented on every call, training or eval; drives the DropBlock
        # schedule below.
        self.num_batches_tracked += 1

        out = x
        for stage in (1, 2, 3):
            conv = getattr(self, 'conv%d' % stage)
            out = conv(out, params=get_subdict(params, 'conv%d' % stage))
            bn = getattr(self, 'bn%d' % stage)
            out = bn(out, params=get_subdict(params, 'bn%d' % stage))
            out = getattr(self, 'relu%d' % stage)(out)

        out = self.maxpool(out)

        # `not ... > 0` mirrors the original comparison exactly.
        if not self.drop_rate > 0:
            return out

        if self.drop_block == True:  # noqa: E712 - equality kept on purpose
            feat_size = out.size(2)
            # Keep rate decays linearly with the call counter from 1.0 and is
            # floored at 1 - drop_rate (reached after 20 * 2000 calls).
            # NOTE(review): 20 * 2000 presumably means epochs * iterations.
            decayed = 1.0 - self.drop_rate / (20 * 2000) * self.num_batches_tracked
            keep_rate = max(decayed, 1.0 - self.drop_rate)
            # Presumably the DropBlock per-position seed probability.
            gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (
                feat_size - self.block_size + 1)**2
            return self.DropBlock(out, gamma=gamma)

        # Plain dropout, applied in place on the pooled tensor.
        return F.dropout(out,
                         p=self.drop_rate,
                         training=self.training,
                         inplace=True)
    def _forward(self, x, params=None):
        if params is None:
            p = OrderedDict(self.named_parameters())
        else:
            p = get_subdict(params, self.w.id)
            if len(p) == 0:
                p = OrderedDict(self.named_parameters())

        output = torch.matmul(x[0], p['w'].transpose(-1, -2))
        if self.b is not None:
            output += p['b'].unsqueeze(0).expand_as(output)

        return output
Ejemplo n.º 11
0
    def forward(self, model_input, params=None):
        """Optionally downsample/encode ``model_input['coords']``, then run ``self.net``.

        Args:
            model_input: mapping with a ``'coords'`` tensor.
            params: optional parameter dict, defaulting to
                ``self.named_parameters()``; its ``'net'`` sub-dict is used.

        Returns:
            ``{'model_in': coords, 'model_out': output}``. ``coords`` is a
            detached copy of the input coordinates with ``requires_grad``
            enabled, taken before any downsampling or encoding.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        # Fresh leaf tensor so gradients w.r.t. the coordinates can be taken.
        coords_org = model_input['coords'].clone().detach().requires_grad_(True)

        net_input = coords_org
        if self.image_downsampling.downsample:
            net_input = self.image_downsampling(net_input)

        # Mode-specific encoding; any other mode passes coordinates through.
        if self.mode == 'rbf':
            net_input = self.rbf_layer(net_input)
        elif self.mode == 'nerf':
            net_input = self.positional_encoding(net_input)

        prediction = self.net(net_input, get_subdict(params, 'net'))
        return {'model_in': coords_org, 'model_out': prediction}
Ejemplo n.º 12
0
    def forward(self, coords, params=None, **kwargs):
        """Run ``self.net`` on ``coords``.

        Args:
            coords: network input, passed through unchanged.
            params: optional parameter dict; when given, only its ``'net'``
                sub-dict is forwarded, otherwise ``None`` is forwarded.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Whatever ``self.net`` returns.
        """
        net_params = None if params is None else get_subdict(params, 'net')
        return self.net(coords, params=net_params)
Ejemplo n.º 13
0
 def forward(self, inputs, batch, params=None):
     """Node MLP -> ReLU -> per-graph sum pooling -> output linear layer.

     Args:
         inputs: node features.
         batch: per-node graph assignment passed to ``global_add_pool``
             (presumably the torch_geometric batch vector — confirm).
         params: optional parameter dict; ``'linear1'`` and ``'linear2'``
             sub-dicts go to the matching layers.

     Returns:
         Output of ``self.linear2`` on the pooled features.
     """
     hidden = F.relu(self.linear1(inputs, params=get_subdict(params, 'linear1')))
     pooled = global_add_pool(hidden, batch)
     return self.linear2(pooled, params=get_subdict(params, 'linear2'))
Ejemplo n.º 14
0
 def forward(self, inputs, params=None):
     """Apply ``self.linear`` using the ``'linear'`` sub-dict of ``params``."""
     linear_params = get_subdict(params, 'linear')
     return self.linear(inputs, params=linear_params)