Example #1
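A top-level wrapper: it runs self.feat_extractor with the externally supplied state, passing the hint mask minus its own leading character, and flattens the output to (batch, -1).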
    def forward(self,
                x,
                params=None,
                buffers=None,
                bn_training=None,
                do_training=None,
                no_grad_hint_mask=None):
        if params is None:
            params = [p for p in self.parameters()]
        if buffers is None:
            buffers = [p for p in utils.get_buffers(self)]
        if no_grad_hint_mask is None:
            no_grad_hint_mask = '0' * len([m for m in self.modules()])

        if bn_training is None:
            bn_training = [self.training] * self.get_num_batchnorms()
        if do_training is None:
            do_training = self.training

        out = self.feat_extractor(x,
                                  params=params,
                                  buffers=buffers,
                                  bn_training=bn_training,
                                  do_training=do_training,
                                  no_grad_hint_mask=no_grad_hint_mask[1:])
        return out.view(out.size(0), -1)
Example #2
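A sequential container: the flat params, buffers, bn_training flags, and hint mask are unmerged into per-child slices, and the input is threaded through the children in order.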
    def forward(self,
                x_input,
                params=None,
                buffers=None,
                bn_training=None,
                do_training=None,
                no_grad_hint_mask=None):
        if params is None:
            params = [p for p in self.parameters()]
        if buffers is None:
            buffers = [p for p in get_buffers(self)]
        if no_grad_hint_mask is None:
            no_grad_hint_mask = '0' * len([m for m in self.modules()])

        if bn_training is None:
            bn_training = [self.training] * self.get_num_batchnorms()
        if do_training is None:
            do_training = self.training

        zipped = list(
            zip(self.children(), self.unmerge_parameters(params),
                self.unmerge_buffers(buffers),
                self.unmerge_bn_training(bn_training),
                self.unmerge_no_grad_hint_mask(no_grad_hint_mask)))

        x = x_input
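        # Each child gets its own slice of params/buffers/flags/mask.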
        for b, p, pb, bt, nghm in zipped:
            x = b(x,
                  params=p,
                  buffers=pb,
                  bn_training=bt,
                  do_training=do_training,
                  no_grad_hint_mask=nghm)

        return x
Example #3
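A leaf convolution: F.conv2d is applied with externally supplied ("hybrid") parameters, wrapped in torch.no_grad() when the module's single mask character asks for it.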
    def forward(self,
                x_input,
                params=None,
                buffers=None,
                bn_training=None,
                do_training=None,
                no_grad_hint_mask=None):
        if params is None:
            params = [p for p in self.parameters()]
        if buffers is None:
            buffers = [p for p in get_buffers(self)]
        if no_grad_hint_mask is None:
            no_grad_hint_mask = '0' * len([m for m in self.modules()])

        if bn_training is None:
            bn_training = [self.training] * self.get_num_batchnorms()
        if do_training is None:
            do_training = self.training

        assert len(no_grad_hint_mask) == 1
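        # Any mask character other than '0' runs this op without grad tracking.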
        with (NoopContext if no_grad_hint_mask == '0' else torch.no_grad)():
            x = F.conv2d(x_input,
                         self.get_hybrid_param(params, 'weight'),
                         bias=(self.get_hybrid_param(params, 'bias')
                               if self.has_hybrid_param('bias') else None),
                         stride=self.stride,
                         padding=self.padding,
                         dilation=self.dilation,
                         groups=self.groups)

        assert len(bn_training) == 0
        return x
Example #4
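A leaf batch norm: the same pattern as Example #3, except it also reads running statistics from the supplied buffers and consumes exactly one bn_training flag.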
    def forward(self,
                x_input,
                params=None,
                buffers=None,
                bn_training=None,
                do_training=None,
                no_grad_hint_mask=None):
        if params is None:
            params = [p for p in self.parameters()]
        if buffers is None:
            buffers = [p for p in get_buffers(self)]
        if no_grad_hint_mask is None:
            no_grad_hint_mask = '0' * len([m for m in self.modules()])

        if bn_training is None:
            bn_training = [self.training] * self.get_num_batchnorms()
        if do_training is None:
            do_training = self.training

        assert len(no_grad_hint_mask) == 1
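        # bn_training.pop(0) consumes this layer's single BatchNorm mode flag.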
        with (NoopContext if no_grad_hint_mask == '0' else torch.no_grad)():
            x = F.batch_norm(x_input,
                             self.get_hybrid_buffer(buffers, 'running_mean'),
                             self.get_hybrid_buffer(buffers, 'running_var'),
                             weight=self.get_hybrid_param(params, 'weight'),
                             bias=self.get_hybrid_param(params, 'bias'),
                             training=bn_training.pop(0),
                             momentum=self.momentum,
                             eps=self.eps)

        assert len(bn_training) == 0
        return x
Example #5
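The unmerge helper for buffers: it splits a flat buffer list into per-child sublists sized by how many buffers each child owns. The params, bn_training, and mask unmergers presumably follow the same scheme.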
    def unmerge_buffers(self, buffers):
        children = list(self.children())
        if buffers is None:
            return [None] * len(children)
        buffers = list(buffers)
        nums_buffers = [len(list(get_buffers(b))) for b in children]
        assert sum(nums_buffers) == len(buffers)
        # Slice the flat buffer list into one sublist per child.
        unmerged_buffers = []
        for n in nums_buffers:
            unmerged_buffers.append(buffers[:n])
            buffers = buffers[n:]
        return unmerged_buffers
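The same slicing scheme can serve params, bn_training, and the hint mask. A runnable toy version on plain lists (the sizes are made up for illustration):

    nums_buffers = [2, 0, 3]    # buffers owned by each child, in order
    buffers = list(range(5))    # flat, merged buffer list
    unmerged = []
    for n in nums_buffers:
        unmerged.append(buffers[:n])
        buffers = buffers[n:]
    assert unmerged == [[0, 1], [], [2, 3, 4]]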
Example #6
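A residual block: both children receive the same input with their own slices of state, and the skip-path and conv-path outputs are summed.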
    def forward(self,
                x,
                params=None,
                buffers=None,
                bn_training=None,
                do_training=None,
                no_grad_hint_mask=None):
        if params is None:
            params = [p for p in self.parameters()]
        if buffers is None:
            buffers = [p for p in utils.get_buffers(self)]
        if no_grad_hint_mask is None:
            no_grad_hint_mask = '0' * len([m for m in self.modules()])

        if bn_training is None:
            bn_training = [self.training] * self.get_num_batchnorms()
        if do_training is None:
            do_training = self.training

        zipped = list(
            zip(self.children(), self.unmerge_parameters(params),
                self.unmerge_buffers(buffers),
                self.unmerge_bn_training(bn_training),
                self.unmerge_no_grad_hint_mask(no_grad_hint_mask)))
        assert len(zipped) == 2

        b, p, pb, bt, nghm = zipped[0]
        assert b is self.conv_block
        conv_block_result = self.conv_block(x,
                                            params=p,
                                            buffers=pb,
                                            bn_training=bt,
                                            do_training=do_training,
                                            no_grad_hint_mask=nghm)
        b, p, pb, bt, nghm = zipped[1]
        assert b is self.skip_layer
        skip_layer_result = self.skip_layer(x,
                                            params=p,
                                            buffers=pb,
                                            bn_training=bt,
                                            do_training=do_training,
                                            no_grad_hint_mask=nghm)
        return skip_layer_result + conv_block_result
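Taken together, the examples implement a functional forward: every parameter, buffer, and mode flag can be supplied from outside rather than read off the module. One common use of this pattern is a differentiable inner-loop update (e.g., MAML-style fast weights). A minimal, self-contained sketch using F.linear in place of the classes above (the 0.01 step size is an arbitrary assumption):

    import torch
    import torch.nn.functional as F

    # A plain tensor standing in for a module parameter.
    weight = torch.randn(4, 3, requires_grad=True)
    x, y = torch.randn(2, 3), torch.randn(2, 4)

    loss = F.mse_loss(F.linear(x, weight), y)
    # create_graph=True keeps the update differentiable, so an outer loss
    # can later backpropagate through this inner step.
    (grad,) = torch.autograd.grad(loss, [weight], create_graph=True)
    fast_weight = weight - 0.01 * grad
    out = F.linear(x, fast_weight)  # forward pass with the updated weight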