def forward(self, x, params=None, buffers=None, bn_training=None,
            do_training=None, no_grad_hint_mask=None):
    """Run ``self.feat_extractor`` on ``x`` and flatten the output per sample.

    Every optional argument left as ``None`` is derived from the module:

    * ``params``: ``self.parameters()`` collected into a list.
    * ``buffers``: ``utils.get_buffers(self)`` collected into a list.
    * ``no_grad_hint_mask``: a string with one ``'0'`` per entry of
      ``self.modules()``.
    * ``bn_training``: ``self.training`` repeated
      ``self.get_num_batchnorms()`` times.
    * ``do_training``: ``self.training``.

    Returns:
        The feature extractor's output viewed as ``(out.size(0), -1)``.
    """
    if params is None:
        params = list(self.parameters())
    if buffers is None:
        buffers = list(utils.get_buffers(self))
    if no_grad_hint_mask is None:
        no_grad_hint_mask = '0' * sum(1 for _ in self.modules())
    if bn_training is None:
        bn_training = [self.training] * self.get_num_batchnorms()
    if do_training is None:
        do_training = self.training

    # The first mask character is dropped before delegating; presumably it
    # belongs to this wrapper module itself -- TODO confirm.
    extractor_mask = no_grad_hint_mask[1:]
    features = self.feat_extractor(
        x,
        params=params,
        buffers=buffers,
        bn_training=bn_training,
        do_training=do_training,
        no_grad_hint_mask=extractor_mask,
    )
    batch_dim = features.size(0)
    return features.view(batch_dim, -1)
def forward(self, x_input, params=None, buffers=None, bn_training=None,
            do_training=None, no_grad_hint_mask=None):
    """Feed ``x_input`` through every child in order, each with its own slice.

    The flat ``params``, ``buffers``, ``bn_training`` and
    ``no_grad_hint_mask`` are split per child by the matching
    ``self.unmerge_*`` helpers; ``do_training`` is passed to every child
    unchanged. Arguments left as ``None`` fall back to values derived from
    the module (its parameters, ``get_buffers(self)``, an all-``'0'`` mask
    with one character per entry of ``self.modules()``, and
    ``self.training``).

    Returns:
        The output of the last child (``x_input`` itself if there are no
        children).
    """
    if params is None:
        params = list(self.parameters())
    if buffers is None:
        buffers = list(get_buffers(self))
    if no_grad_hint_mask is None:
        no_grad_hint_mask = '0' * sum(1 for _ in self.modules())
    if bn_training is None:
        bn_training = [self.training] * self.get_num_batchnorms()
    if do_training is None:
        do_training = self.training

    # Materialise the pairing up front so the children are fully enumerated
    # before any of them runs.
    per_child = list(zip(
        self.children(),
        self.unmerge_parameters(params),
        self.unmerge_buffers(buffers),
        self.unmerge_bn_training(bn_training),
        self.unmerge_no_grad_hint_mask(no_grad_hint_mask),
    ))

    activation = x_input
    for child, child_params, child_buffers, child_bn, child_mask in per_child:
        activation = child(
            activation,
            params=child_params,
            buffers=child_buffers,
            bn_training=child_bn,
            do_training=do_training,
            no_grad_hint_mask=child_mask,
        )
    return activation
def forward(self, x_input, params=None, buffers=None, bn_training=None,
            do_training=None, no_grad_hint_mask=None):
    """Apply a 2-D convolution using weights looked up from ``params``.

    The weight (and the bias, when ``self.has_hybrid_param('bias')`` is
    true) are fetched through ``self.get_hybrid_param``; stride, padding,
    dilation and groups come from the module's attributes.

    ``no_grad_hint_mask`` must be exactly one character: ``'0'`` runs the
    convolution inside ``NoopContext()``, anything else inside
    ``torch.no_grad()``. ``bn_training`` must be empty.

    Arguments left as ``None`` are derived from the module itself.

    Returns:
        The result of ``F.conv2d``.
    """
    if params is None:
        params = list(self.parameters())
    if buffers is None:
        buffers = list(get_buffers(self))
    if no_grad_hint_mask is None:
        no_grad_hint_mask = '0' * sum(1 for _ in self.modules())
    if bn_training is None:
        bn_training = [self.training] * self.get_num_batchnorms()
    if do_training is None:
        do_training = self.training

    assert len(no_grad_hint_mask) == 1
    grad_context = NoopContext if no_grad_hint_mask == '0' else torch.no_grad
    with grad_context():
        kernel = self.get_hybrid_param(params, 'weight')
        if self.has_hybrid_param('bias'):
            offset = self.get_hybrid_param(params, 'bias')
        else:
            offset = None
        result = F.conv2d(
            x_input,
            kernel,
            bias=offset,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
    # No batch-norm flags may be handed to this module.
    assert len(bn_training) == 0
    return result
def forward(self, x_input, params=None, buffers=None, bn_training=None,
            do_training=None, no_grad_hint_mask=None):
    """Apply batch normalisation using externally supplied params/buffers.

    Running statistics are fetched with ``self.get_hybrid_buffer`` and the
    affine weight/bias with ``self.get_hybrid_param``; ``momentum`` and
    ``eps`` come from the module's attributes.

    Args:
        x_input: Input passed straight to ``F.batch_norm``.
        params: Source for the ``'weight'`` and ``'bias'`` lookups; defaults
            to the module's own parameters.
        buffers: Source for the ``'running_mean'`` and ``'running_var'``
            lookups; defaults to ``get_buffers(self)``.
        bn_training: Sequence holding exactly one flag, forwarded as the
            ``training`` argument of ``F.batch_norm``. It is only read, never
            modified, so the same sequence can be reused across calls.
        do_training: Defaults to ``self.training``; not otherwise used here.
        no_grad_hint_mask: Exactly one character. ``'0'`` runs inside
            ``NoopContext()``, anything else inside ``torch.no_grad()``.

    Returns:
        The result of ``F.batch_norm``.

    Raises:
        IndexError: If ``bn_training`` is empty.
        AssertionError: If ``bn_training`` holds more than one flag or the
            mask is not exactly one character long.
    """
    if params is None:
        params = list(self.parameters())
    if buffers is None:
        buffers = list(get_buffers(self))
    if no_grad_hint_mask is None:
        no_grad_hint_mask = '0' * sum(1 for _ in self.modules())
    if bn_training is None:
        bn_training = [self.training] * self.get_num_batchnorms()
    if do_training is None:
        do_training = self.training

    assert len(no_grad_hint_mask) == 1

    # Index instead of pop(0): popping emptied the caller's list, so passing
    # the same flag list to a second call failed with IndexError.
    use_batch_stats = bn_training[0]
    assert len(bn_training) == 1

    grad_context = NoopContext if no_grad_hint_mask == '0' else torch.no_grad
    with grad_context():
        x = F.batch_norm(
            x_input,
            self.get_hybrid_buffer(buffers, 'running_mean'),
            self.get_hybrid_buffer(buffers, 'running_var'),
            weight=self.get_hybrid_param(params, 'weight'),
            bias=self.get_hybrid_param(params, 'bias'),
            training=use_batch_stats,
            momentum=self.momentum,
            eps=self.eps,
        )
    return x
def unmerge_buffers(self, buffers):
    """Split a flat sequence of buffers into one list per child module.

    Each child receives as many consecutive entries as ``get_buffers(child)``
    yields, in the order given by ``self.children()``.

    Args:
        buffers: Iterable of buffers for all children, or ``None``.

    Returns:
        A list with one entry per child: a list of that child's buffers, or
        ``None`` for every child when ``buffers`` is ``None``.

    Raises:
        AssertionError: If the number of supplied buffers does not equal the
            total the children expect.
    """
    children = self.children()
    if buffers is None:
        return [None for _ in children]

    flat = list(buffers)
    counts = [sum(1 for _ in get_buffers(child)) for child in children]
    assert sum(counts) == len(flat)

    groups = []
    start = 0
    for count in counts:
        stop = start + count
        groups.append(flat[start:stop])
        start = stop
    return groups
def forward(self, x, params=None, buffers=None, bn_training=None,
            do_training=None, no_grad_hint_mask=None):
    """Compute ``skip_layer(x) + conv_block(x)`` with per-branch arguments.

    The flat ``params``, ``buffers``, ``bn_training`` and
    ``no_grad_hint_mask`` are split per child with the ``self.unmerge_*``
    helpers. The module must have exactly two children, ``self.conv_block``
    first and ``self.skip_layer`` second; both receive the same input ``x``.

    Arguments left as ``None`` are derived from the module itself (its
    parameters, ``utils.get_buffers(self)``, an all-``'0'`` mask with one
    character per entry of ``self.modules()``, and ``self.training``).

    Returns:
        The skip-layer output plus the conv-block output.
    """
    if params is None:
        params = list(self.parameters())
    if buffers is None:
        buffers = list(utils.get_buffers(self))
    if no_grad_hint_mask is None:
        no_grad_hint_mask = '0' * sum(1 for _ in self.modules())
    if bn_training is None:
        bn_training = [self.training] * self.get_num_batchnorms()
    if do_training is None:
        do_training = self.training

    branches = list(zip(
        self.children(),
        self.unmerge_parameters(params),
        self.unmerge_buffers(buffers),
        self.unmerge_bn_training(bn_training),
        self.unmerge_no_grad_hint_mask(no_grad_hint_mask),
    ))
    assert len(branches) == 2
    main_entry, shortcut_entry = branches

    # Main branch: the first child has to be the conv block itself.
    module, branch_params, branch_buffers, branch_bn, branch_mask = main_entry
    assert module is self.conv_block
    conv_block_result = self.conv_block(
        x,
        params=branch_params,
        buffers=branch_buffers,
        bn_training=branch_bn,
        do_training=do_training,
        no_grad_hint_mask=branch_mask,
    )

    # Shortcut branch: the second child has to be the skip layer.
    module, branch_params, branch_buffers, branch_bn, branch_mask = shortcut_entry
    assert module is self.skip_layer
    skip_layer_result = self.skip_layer(
        x,
        params=branch_params,
        buffers=branch_buffers,
        bn_training=branch_bn,
        do_training=do_training,
        no_grad_hint_mask=branch_mask,
    )

    return skip_layer_result + conv_block_result