def __init__(self, planes, norm_layer, **kwargs):
    """Build the stem: first convolution, normalization, ReLU and max-pool.

    Args:
        planes: Number of output channels produced by ``conv1``.
        norm_layer: Callable that builds a normalization module when given a
            channel count; called once with ``planes``.
        **kwargs: Configuration options. ``cascaded`` is required (KeyError
            otherwise). Optional keys: ``time_bn`` (defaults to the value of
            ``cascaded``), ``imagenet`` (default False, selects the 7x7
            stride-2 stem) and ``tdl_mode`` (default 'OSD', read only when
            cascaded). When cascaded, the whole kwargs dict is forwarded to
            ``tdl.setup_tdl_kernel``.
    """
    super(HeadLayer, self).__init__()
    is_cascaded = kwargs['cascaded']
    self.cascaded = is_cascaded
    self.time_bn = kwargs.get('time_bn', is_cascaded)

    # The stem always takes 3 input channels (presumably RGB images).
    in_channels = 3
    if kwargs.get('imagenet', False):
        # Large-input stem: 7x7 kernel, stride 2.
        conv_geometry = dict(kernel_size=7, stride=2, padding=3)
    else:
        # Small-input stem: 3x3 kernel, stride 1, keeps spatial size.
        conv_geometry = dict(kernel_size=3, stride=1, padding=1)
    self.conv1 = nn.Conv2d(in_channels, planes, bias=False, **conv_geometry)

    # Submodules are registered in the same order as before so that
    # state_dict / parameters() ordering is unchanged.
    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Non-cascaded heads get no temporal-difference-line kernel.
    if not is_cascaded:
        return
    self.tdline = tdl.setup_tdl_kernel(kwargs.get('tdl_mode', 'OSD'), kwargs)
def __init__(self, inplanes, planes, stride=1, downsample=None,
             norm_layer=None, **kwargs):
    """Initialize basic block.

    Builds two 3x3 convolutions (via ``custom_ops.conv3x3``), each paired
    with a normalization layer, plus a shared ReLU. When cascaded, it also
    builds a TDL kernel.

    Args:
        inplanes: Number of input channels to ``conv1``.
        planes: Number of output channels of both convolutions.
        stride: Stride passed to the first convolution.
        downsample: Optional module stored as ``self.downsample``; presumably
            applied to the shortcut path in ``forward`` (not shown here).
        norm_layer: Callable that builds a normalization module from a channel
            count. Although the signature defaults it to None, it is required.
        **kwargs: Configuration options. ``cascaded`` is required (KeyError
            otherwise); ``time_bn`` defaults to ``cascaded``; ``tdl_mode``
            defaults to 'OSD' and is read only when cascaded. When cascaded,
            the whole kwargs dict is forwarded to ``tdl.setup_tdl_kernel``.

    Raises:
        TypeError: If ``norm_layer`` is None.
    """
    super(BasicBlock, self).__init__()
    if norm_layer is None:
        # Fail fast with a clear message. Previously this surfaced as
        # "'NoneType' object is not callable" after conv1 had been built.
        raise TypeError('BasicBlock requires a norm_layer callable, got None.')

    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', self.cascaded)
    self.downsample = downsample
    self.stride = stride

    # Setup ops (registration order kept stable for state_dict compatibility).
    self.conv1 = custom_ops.conv3x3(inplanes, planes, stride)
    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = custom_ops.conv3x3(planes, planes)
    self.bn2 = norm_layer(planes)

    # TDL: only cascaded blocks own a tdline kernel.
    if self.cascaded:
        tdl_mode = kwargs.get('tdl_mode', 'OSD')
        self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)
def __init__(self, inplanes, planes, stride=1, downsample=None,
             norm_layer=None, **kwargs):
    """Initialize bottleneck block.

    Builds a 1x1 -> 3x3 -> 1x1 convolution stack (via ``custom_ops``). Each
    convolution is paired with a normalization layer, and a shared ReLU is
    added. When cascaded, it also builds a TDL kernel.

    Args:
        inplanes: Number of input channels to ``conv1``.
        planes: Bottleneck width; the output has ``planes * self.expansion``
            channels (``expansion`` is a class attribute defined outside
            this view).
        stride: Stride applied by the 3x3 convolution.
        downsample: Optional module stored as ``self.downsample``; presumably
            applied to the shortcut path in ``forward`` (not shown here).
        norm_layer: Callable that builds a normalization module from a channel
            count. Although the signature defaults it to None, it is required.
        **kwargs: Configuration options. ``cascaded`` is required (KeyError
            otherwise); ``time_bn`` defaults to ``cascaded``; ``tdl_mode``
            defaults to 'OSD' and is read only when cascaded. When cascaded,
            the whole kwargs dict is forwarded to ``tdl.setup_tdl_kernel``.

    Raises:
        TypeError: If ``norm_layer`` is None.
    """
    super(Bottleneck, self).__init__()
    if norm_layer is None:
        # Fail fast with a clear message. Previously this surfaced as
        # "'NoneType' object is not callable" after conv1 had been built.
        raise TypeError('Bottleneck requires a norm_layer callable, got None.')

    # base_width is fixed at 64, so width == planes. The formula presumably
    # mirrors torchvision's wide-ResNet width scaling.
    base_width = 64
    width = int(planes * (base_width / 64.))
    self.downsample = downsample
    self.stride = stride
    self.cascaded = kwargs['cascaded']
    self.time_bn = kwargs.get('time_bn', self.cascaded)

    # Registration order kept stable for state_dict compatibility.
    self.conv1 = custom_ops.conv1x1(inplanes, width)
    self.bn1 = norm_layer(width)
    self.conv2 = custom_ops.conv3x3(width, width, stride)
    self.bn2 = norm_layer(width)
    self.conv3 = custom_ops.conv1x1(width, planes * self.expansion)
    self.bn3 = norm_layer(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)

    # TDL: only cascaded blocks own a tdline kernel.
    if self.cascaded:
        tdl_mode = kwargs.get('tdl_mode', 'OSD')
        self.tdline = tdl.setup_tdl_kernel(tdl_mode, kwargs)