Example #1
    def __init__(self,
                 num_classes,
                 image_dim,
                 transforms,
                 patch_size,
                 model=None,
                 apply_transforms=True,
                 init='rand'):
        super().__init__()
        if init == 'const':
            initialization = ch.zeros(num_classes, 3, image_dim,
                                      image_dim) + 0.5
        elif init == 'rand':
            initialization = ch.normal(mean=0.5,
                                       std=0.1,
                                       size=(num_classes, 3, image_dim,
                                             image_dim))
        else:
            raise ValueError('Unknown initialization for the booster.')

        self.patches = nn.Parameter(initialization, requires_grad=True)
        self.patch_size = patch_size
        self.image_dim = image_dim
        self.aff_transformer = RandomAffine(**transforms,
                                            return_transform=True).cuda()
        self.apply_transforms = apply_transforms
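
A rough usage sketch for the constructor above (the surrounding class, its name `PatchBooster`, and the transform ranges are assumptions; the snippet only shows `__init__`):

    import torch as ch
    from kornia.augmentation import RandomAffine

    # Hypothetical instantiation; `PatchBooster` stands in for the class
    # this __init__ belongs to.
    booster = PatchBooster(num_classes=10, image_dim=224,
                           transforms={'degrees': (-20, 20), 'translate': (0.2, 0.2)},
                           patch_size=50)
    labels = ch.tensor([0, 1, 2])
    patches = booster.patches[labels]               # one learnable patch per label
    warped, tfm = booster.aff_transformer(patches)  # return_transform=True yields (output, matrix)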
Example #2

def get_transform():

    transform = nn.Sequential(
        ZeroPad2d(150),
        RandomAffine(degrees=(-20, 20),
                     translate=(0.25, 0.25),
                     scale=(1.1, 1.5)),
    )

    def transform_fn(image, batch_size):
        b_image = image.repeat(batch_size, 1, 1, 1)
        return transform(b_image)

    return transform_fn
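
A quick usage sketch (shapes are illustrative): the factory pads each side by 150 pixels and then applies a random affine, so a 224x224 input comes back as a (batch_size, 3, 524, 524) tensor.

    import torch

    transform_fn = get_transform()
    image = torch.rand(1, 3, 224, 224)         # single image, (1, C, H, W)
    batch = transform_fn(image, batch_size=8)  # repeated to (8, C, H, W), padded, then warped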
Example #3
    def __init__(self, num_classes, tex_size, image_size, 
                 batch_size, *, num_gpus, num_texcoords,
                 render_options, forward_render=False, init='const',
                 debug=False, custom_file=None, corruptions=None):
        super().__init__()
        if init == 'const':
            initialization = ch.zeros(num_classes, 3, tex_size, tex_size) + 0.5
        elif init == 'rand':
            initialization = ch.normal(mean=0.5, std=0.1, size=(num_classes, 3, tex_size, tex_size))
        else:
            raise ValueError('Unknown initialization for the booster.') 

        # ctx = get_context('spawn')
        ctx = get_context('fork')
        self.in_q, self.out_q = ctx.Queue(), ctx.Queue()

        effective_batch = (batch_size // num_gpus + 1) * num_gpus 
        self.dones_sh_mem = ch.zeros(effective_batch).bool().share_memory_()
        self.texture_sh_mem = ch.zeros(effective_batch, tex_size*tex_size*3).share_memory_()
        self.render_sh_mem = ch.zeros(effective_batch, image_size, image_size, 4).share_memory_()
        self.uv_map_sh_mem = ch.zeros(effective_batch, image_size, image_size, 4).share_memory_()
        render_info = {
            "image_size": image_size,
            "samples": render_options['samples'],
            "scale_range": (render_options['min_zoom'], render_options['max_zoom']),
            "light_range": (render_options['min_light'], render_options['max_light'])
        }
        args = (SCENE_DICT(custom_file), render_info, self.in_q, self.out_q, 
            self.dones_sh_mem, self.texture_sh_mem, self.render_sh_mem, self.uv_map_sh_mem)
        gpus_to_use = cycle(['0', '1', '2', '3', '4'])  # TODO: derive this list from num_gpus instead of hardcoding
        Image.fromarray(np.zeros((tex_size, tex_size, 3)).astype(np.uint8)).save('/base_texture.jpg')
        for _ in range(num_texcoords):
            os.environ['CUDA_VISIBLE_DEVICES'] = next(gpus_to_use)
            p = ctx.Process(target=render.render, args=args)
            p.start()

        self.num_texcoords = num_texcoords
        self.textures = nn.Parameter(initialization, requires_grad=True)
        self.translate_tx = RandomAffine(degrees=0., p=1.0,
            translate=(0.4, 0.4), return_transform=True)
        self.tex_size = tex_size
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_gpus = num_gpus
        self.debug = debug
        self.render_forward = forward_render
        self.corruptions = corruptions
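
Note that the effective batch is rounded up to the next multiple of num_gpus so each renderer process gets an equal slice of the shared-memory tensors; as written, the rounding over-allocates by a full group even when batch_size already divides evenly:

    batch_size, num_gpus = 32, 5
    effective_batch = (batch_size // num_gpus + 1) * num_gpus
    assert effective_batch == 35
    assert (30 // 5 + 1) * 5 == 35  # an exact multiple still gains an extra group of 5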
Example #4
    def __init__(self,
                 num_classes,
                 image_dim,
                 transforms,
                 patch_size,
                 apply_transforms=True,
                 detector='pyzbar'):
        super().__init__()
        to_tensor = tvt.ToTensor()
        self.detector = detector
        if detector == 'cv2':
            self.qrCodeDetector = cv2.QRCodeDetector()

        self.patches = [to_tensor(qrcode.make(str(i), border=1).convert('RGB').resize((patch_size, patch_size), Image.NEAREST)) \
                            for i in range(num_classes)]
        self.patches = ch.stack(self.patches, 0)
        self.patches = pad_tensor(self.patches, image_dim).cuda()

        for i, p in enumerate(self.patches):
            qrcode_image = 255 * np.uint8(p.permute(1, 2, 0).cpu().numpy())
            if detector == 'cv2':
                # OpenCV's QR detector is unreliable compared to pyzbar
                decoded_class, _, _ = self.qrCodeDetector.detectAndDecode(
                    qrcode_image)
            elif detector == 'pyzbar':
                decoded_class = pyzbar.decode(qrcode_image)[0].data.decode(
                    "utf-8")
            else:
                raise ValueError('Unknown QRCode detector.')
            assert decoded_class == str(i), f'The decoded QR code for class {i} is wrong. ' \
                                            'Probably something weird happened during processing.'

        self.patch_size = patch_size
        self.image_dim = image_dim
        self.aff_transformer = RandomAffine(**transforms,
                                            return_transform=True,
                                            resample='nearest').cuda()
        self.apply_transforms = apply_transforms
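
The per-class assertion above amounts to a QR round trip; a minimal standalone check (requires the qrcode and pyzbar packages) looks like:

    import numpy as np
    import qrcode
    from pyzbar import pyzbar

    img = qrcode.make('7', border=1).convert('RGB')                 # encode the class label
    decoded = pyzbar.decode(np.array(img))[0].data.decode('utf-8')  # decode it back
    assert decoded == '7'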
Example #5

    def rotate(self, degree: float, p: float = 1.0) -> TransformType:
        return RandomAffine(degrees=(degree, degree), p=p)
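
Because both ends of the degrees range are the same value, the helper produces a fixed-angle rotation that still fires with probability p; a minimal sketch:

    import torch
    from kornia.augmentation import RandomAffine

    aug = RandomAffine(degrees=(90., 90.), p=0.5)  # exactly 90 degrees, applied half the time
    out = aug(torch.rand(4, 3, 32, 32))            # roughly half the batch is rotated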
Example #6
    def test_param(self, degrees, translate, scale, shear, resample,
                   align_corners, return_transform, same_on_batch, device,
                   dtype):

        _degrees = degrees if isinstance(degrees, (int, float, list, tuple)) else \
            nn.Parameter(degrees.clone().to(device=device, dtype=dtype))
        _translate = translate if isinstance(translate, (int, float, list, tuple)) else \
            nn.Parameter(translate.clone().to(device=device, dtype=dtype))
        _scale = scale if isinstance(scale, (int, float, list, tuple)) else \
            nn.Parameter(scale.clone().to(device=device, dtype=dtype))
        _shear = shear if isinstance(shear, (int, float, list, tuple)) else \
            nn.Parameter(shear.clone().to(device=device, dtype=dtype))

        torch.manual_seed(0)
        input = torch.randint(255, (2, 3, 10, 10), device=device,
                              dtype=dtype) / 255.
        aug = RandomAffine(_degrees,
                           _translate,
                           _scale,
                           _shear,
                           resample,
                           align_corners=align_corners,
                           return_transform=return_transform,
                           same_on_batch=same_on_batch,
                           p=1.)

        if return_transform:
            output, _ = aug(input)
        else:
            output = aug(input)

        if len(list(aug.parameters())) != 0:
            mse = nn.MSELoss()
            opt = torch.optim.SGD(aug.parameters(), lr=10)
            loss = mse(output, torch.ones_like(output) * 2)
            loss.backward()
            opt.step()

        if not isinstance(degrees, (int, float, list, tuple)):
            assert isinstance(aug.degrees, torch.Tensor)
            # Check whether the parameter was updated
            if resample == 'nearest' and aug.degrees.is_cuda:
                # grid_sample in nearest mode on a CUDA device returns nan rather than 0
                pass
            elif resample == 'nearest' or torch.all(aug.degrees._grad == 0.):
                # grid_sample will return grad = 0 for resample nearest
                # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
                assert (degrees.to(device=device, dtype=dtype) -
                        aug.degrees.data).sum() == 0
            else:
                assert (degrees.to(device=device, dtype=dtype) -
                        aug.degrees.data).sum() != 0
        if not isinstance(translate, (int, float, list, tuple)):
            assert isinstance(aug.translate, torch.Tensor)
            # Check whether the parameter was updated
            if resample == 'nearest' and aug.translate.is_cuda:
                # grid_sample in nearest mode on a CUDA device returns nan rather than 0
                pass
            elif resample == 'nearest' or torch.all(aug.translate._grad == 0.):
                # grid_sample will return grad = 0 for resample nearest
                # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
                assert (translate.to(device=device, dtype=dtype) -
                        aug.translate.data).sum() == 0
            else:
                assert (translate.to(device=device, dtype=dtype) -
                        aug.translate.data).sum() != 0
        if not isinstance(scale, (int, float, list, tuple)):
            assert isinstance(aug.scale, torch.Tensor)
            # Check whether the parameter was updated
            if resample == 'nearest' and aug.scale.is_cuda:
                # grid_sample in nearest mode on a CUDA device returns nan rather than 0
                pass
            elif resample == 'nearest' or torch.all(aug.scale._grad == 0.):
                # grid_sample will return grad = 0 for resample nearest
                # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
                assert (scale.to(device=device, dtype=dtype) -
                        aug.scale.data).sum() == 0
            else:
                assert (scale.to(device=device, dtype=dtype) -
                        aug.scale.data).sum() != 0
        if not isinstance(shear, (int, float, list, tuple)):
            assert isinstance(aug.shear, torch.Tensor)
            # Check whether the parameter was updated
            if resample == 'nearest' and aug.shear.is_cuda:
                # grid_sample in nearest mode on a CUDA device returns nan rather than 0
                pass
            elif resample == 'nearest' or torch.all(aug.shear._grad == 0.):
                # grid_sample will return grad = 0 for resample nearest
                # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
                assert (shear.to(device=device, dtype=dtype) -
                        aug.shear.data).sum() == 0
            else:
                assert (shear.to(device=device, dtype=dtype) -
                        aug.shear.data).sum() != 0
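
The pattern this test exercises — wrapping the augmentation ranges in nn.Parameter so gradients can flow back into them — can be reproduced in isolation (a sketch assuming a Kornia version whose RandomAffine accepts tensor arguments, as the test itself does):

    import torch
    import torch.nn as nn
    from kornia.augmentation import RandomAffine

    degrees = nn.Parameter(torch.tensor([-10., 10.]))
    aug = RandomAffine(degrees, p=1.)
    out = aug(torch.rand(2, 3, 10, 10))
    out.mean().backward()
    print(degrees.grad)  # non-zero with bilinear resampling; zero for 'nearest'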
Example #7
    def compute_point_loss(self, loss_name, images, logits, points):
        if loss_name == 'toponet':
            if self.first_time:
                self.first_time = False
                self.vgg = nn.DataParallel(lanenet.VGG().cuda(1), list(range(1,4)))
                self.vgg.train()
            points = points[:,None]
            images_flip = flips.Hflip()(images)
            logits_flip = self.model_base(images_flip)
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            
            logits_flip_vgg = self.vgg(torch.sigmoid(logits_flip.cuda(1)))
            logits_vgg = self.vgg(torch.sigmoid(logits.cuda(1)))
            loss += self.exp_dict["model"]["loss_weight"] * torch.mean(torch.abs(flips.Hflip()(logits_flip_vgg)-logits_vgg)).cuda(0)

            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'multiscale_cons_point_loss':
            logits, features = logits
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            logits_flip, features_flip = self.model_base(flips.Hflip()(images), return_features=True)

            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            for f, f_flip in zip(features, features_flip):
                loss += torch.mean(torch.abs(flips.Hflip()(f_flip)-f)) * self.exp_dict["model"]["loss_weight"]
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                # logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                # if 1:
                #     pf = points_flip.clone()
                #     pf[pf==1] = 2
                #     pf[pf==0] = 1
                #     pf[pf==255] = 0
                #     lcfcn_loss.save_tmp('tmp.png', flips.Hflip()(images[[0]]), logits_flip[[0]], 3, pf[[0]])
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == "affine_cons_point_loss":
            points = points[:,None]
            if np.random.randint(2) == 0:
                images_flip = flips.Hflip()(images)
                flipped = True
            else:
                images_flip = images
                flipped = False
            batch_size, C, height, width = logits.shape
            random_affine = RandomAffine(degrees=2, translate=None, scale=(0.85, 1), shear=[-2, 2], return_transform=True)
            images_aff, transform = random_affine(images_flip)
            logits_aff = self.model_base(images_aff)
            
            # hu.save_image('tmp1.png', images_aff[0])
            itransform = transform.inverse()
            logits_aff = kornia.geometry.transform.warp_affine(logits_aff, itransform[:,:2, :], dsize=(height, width))
            if flipped:
                logits_aff = flips.Hflip()(logits_aff)
            # hu.save_image('tmp2.png', images_recovered[0])

            
            # logits_flip = self.model_base(flips.Hflip()(images))

            loss = torch.mean(torch.abs(logits_aff-logits))
            points_aff = kornia.geometry.transform.warp_affine(points.float(), itransform[:,:2, :], dsize=(height, width), flags="nearest").long()
            # points_aff = points
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')
                if flipped:
                    points_aff = flips.Hflip()(points_aff)
                # logits_flip =  self.model_base(flips.Hflip()(images))
                ind = points_aff!=255
                loss += F.binary_cross_entropy_with_logits(logits_aff[ind], 
                                        points_aff[ind].float().cuda(), 
                                        reduction='mean')
        elif loss_name == "elastic_cons_point_loss":
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            B, C, H, W = images.shape
            # ELASTIC TRANSFORM
            def norm_grid(grid):
                grid -= grid.min()
                grid /= grid.max()
                grid = (grid - 0.5) * 2
                return grid
            grid_x, grid_y = torch.meshgrid(torch.arange(H), torch.arange(W))
            grid_x = grid_x.float().cuda()
            grid_y = grid_y.float().cuda()
            sigma=4
            alpha=34
            indices = torch.stack([grid_y, grid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous()
            indices = norm_grid(indices)
            dx = gaussian_filter((np.random.rand(H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha
            dy = gaussian_filter((np.random.rand(H, W) * 2 - 1), sigma, mode="constant", cval=0) * alpha
            dx = torch.from_numpy(dx).cuda().float()
            dy = torch.from_numpy(dy).cuda().float()
            dgrid_x = grid_x + dx
            dgrid_y = grid_y + dy
            dgrid_y = norm_grid(dgrid_y)
            dgrid_x = norm_grid(dgrid_x)
            dindices = torch.stack([dgrid_y, dgrid_x], -1).view(1, H, W, 2).expand(B, H, W, 2).contiguous()
            dindices0 = dindices.permute(0, 3, 1, 2).contiguous().view(B*2, H, W)
            indices0 = indices.permute(0, 3, 1, 2).contiguous().view(B*2, H, W)
            iindices = torch.bmm(indices0, dindices0.pinverse()).view(B, 2, H, W).permute(0, 2, 3, 1)
            # indices_im = indices.permute(0, 3, 1, 2)
            # iindices = F.grid_sample(indices_im, dindices).permute(0, 2, 3, 1)
            aug = F.grid_sample(images, dindices)
            iaug = F.grid_sample(aug,iindices)


            # logits_aff = self.model_base(images_aff)
            # inv_transform = transform.inverse()
            
            import pylab
            def save_im(image, name):
                _images_aff = image.data.cpu().numpy()
                _images_aff -= _images_aff.min()
                _images_aff /= _images_aff.max()
                _images_aff *= 255
                _images_aff = _images_aff.transpose((1,2,0))
                pylab.imsave(name, _images_aff.astype('uint8'))
            save_im(aug[0], 'tmp1.png')
            save_im(iaug[0], 'tmp2.png')
            # Debug-only branch: no consistency loss is defined here yet, so
            # fall back to zero to keep `return loss` below from raising NameError.
            loss = 0.


        elif loss_name == 'lcfcn_loss':
            loss = 0.
  
            for lg, pt in zip(logits, points):
                loss += lcfcn_loss.compute_loss((pt==1).long(), lg.sigmoid())

                # loss += lcfcn_loss.compute_binary_lcfcn_loss(l[None], 
                #         p[None].long().cuda())

        elif loss_name == 'point_loss':
            points = points[:,None]
            ind = points!=255
            # self.vis_on_batch(batch, savedir_image='tmp.png')

            # POINT LOSS
            # loss = ut.joint_loss(logits, points[:,None].float().cuda(), ignore_index=255)
            # print(points[ind].sum())
            if ind.sum() == 0:
                loss = 0.
            else:
                loss = F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')
                                        
            # print(points[ind].sum().item(), float(loss))
        elif loss_name == 'att_point_loss':
            points = points[:,None]
            ind = points!=255

            loss = 0.
            if ind.sum() != 0:
                loss = F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                logits_flip = self.model_base(flips.Hflip()(images))
                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        elif loss_name == 'cons_point_loss':
            points = points[:,None]
            
            logits_flip = self.model_base(flips.Hflip()(images))
            loss = torch.mean(torch.abs(flips.Hflip()(logits_flip)-logits))
            
            ind = points!=255
            if ind.sum() != 0:
                loss += F.binary_cross_entropy_with_logits(logits[ind], 
                                        points[ind].float().cuda(), 
                                        reduction='mean')

                points_flip = flips.Hflip()(points)
                ind = points_flip!=255
                loss += F.binary_cross_entropy_with_logits(logits_flip[ind], 
                                        points_flip[ind].float().cuda(), 
                                        reduction='mean')

        return loss 
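
Most branches above share one core idea: predictions on a transformed image, mapped back through the inverse transform, should agree with predictions on the original. The horizontal-flip case reduces to a few lines (a sketch; `model` is a placeholder, and Hflip is assumed importable as in the snippet's `flips` module):

    import torch
    from kornia.geometry.transform import Hflip  # the snippet's flips.Hflip

    def flip_consistency_loss(model, images):
        logits = model(images)
        logits_flip = model(Hflip()(images))
        # Un-flip the mirrored prediction and penalize disagreement.
        return torch.mean(torch.abs(Hflip()(logits_flip) - logits))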
Example #8
    def __init__(
            self,
            image_shape,
            output_size,
            n_atoms,
            dueling,
            jumps,
            spr,
            augmentation,
            target_augmentation,
            eval_augmentation,
            dynamics_blocks,
            norm_type,
            noisy_nets,
            aug_prob,
            classifier,
            imagesize,
            time_offset,
            local_spr,
            global_spr,
            momentum_encoder,
            shared_encoder,
            distributional,
            dqn_hidden_size,
            momentum_tau,
            renormalize,
            renormalize_type,
            q_l1_type,
            dropout,
            final_classifier,
            model_rl,
            noisy_nets_std,
            residual_tm,
            pred_hidden_ratio,
            encoder_type,
            transition_type,
            conv_proj_channel,
            proj_hidden_size,
            gru_input_size,
            gru_proj_size,
            ln_ratio,
            use_maxpool=False,
            channels=None,  # None uses default.
            kernel_sizes=None,
            strides=None,
            paddings=None,
            framestack=4,
    ):
        """Instantiates the neural network according to arguments; network defaults
        stored within this method."""
        super().__init__()
        self.noisy = noisy_nets
        self.time_offset = time_offset
        self.aug_prob = aug_prob
        self.classifier_type = classifier

        self.distributional = distributional
        n_atoms = 1 if not self.distributional else n_atoms
        self.dqn_hidden_size = dqn_hidden_size

        self.transforms = []
        self.eval_transforms = []

        self.uses_augmentation = False
        for aug in augmentation:
            if aug == "affine":
                transformation = RandomAffine(5, (.14, .14), (.9, 1.1), (-5, 5))
                eval_transformation = nn.Identity()
                self.uses_augmentation = True
            elif aug == "crop":
                transformation = RandomCrop((84, 84))
                # Crashes if aug-prob not 1: use CenterCrop((84, 84)) or Resize((84, 84)) in that case.
                eval_transformation = CenterCrop((84, 84))
                self.uses_augmentation = True
                imagesize = 84
            elif aug == "rrc":
                transformation = RandomResizedCrop((100, 100), (0.8, 1))
                eval_transformation = nn.Identity()
                self.uses_augmentation = True
            elif aug == "blur":
                transformation = GaussianBlur2d((5, 5), (1.5, 1.5))
                eval_transformation = nn.Identity()
                self.uses_augmentation = True
            elif aug == "shift":
                transformation = nn.Sequential(nn.ReplicationPad2d(4), RandomCrop((84, 84)))
                eval_transformation = nn.Identity()
            elif aug == "intensity":
                transformation = Intensity(scale=0.05)
                eval_transformation = nn.Identity()
            elif aug == "none":
                transformation = eval_transformation = nn.Identity()
            else:
                raise NotImplementedError()
            self.transforms.append(transformation)
            self.eval_transforms.append(eval_transformation)

        self.dueling = dueling
        f, c = image_shape[:2]
        in_channels = np.prod(image_shape[:2])

        if encoder_type == 'conv2d':
            self.conv = Conv2dModel(
                in_channels=in_channels,
                channels=[32, 64, 64],
                kernel_sizes=[8, 4, 3],
                strides=[4, 2, 1],
                paddings=[0, 0, 0],
                use_maxpool=False,
                dropout=dropout,
                conv_proj_channel=conv_proj_channel,
            )
        elif encoder_type == 'resnet18':
            self.conv = resnet18()
        else:
            raise NotImplementedError

        fake_input = torch.zeros(1, f*c, imagesize, imagesize)
        fake_output = self.conv(fake_input)


        self.hidden_size = fake_output.shape[1]
        self.pixels = fake_output.shape[-1]*fake_output.shape[-2]
        print("Spatial latent size is {}".format(fake_output.shape[1:]))

        if proj_hidden_size:
            self.conv_proj = nn.Sequential(
                nn.Flatten(1, -1),
                nn.Linear(self.hidden_size * self.pixels, proj_hidden_size),
                nn.LayerNorm(proj_hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        else:
            self.conv_proj = nn.Identity()

        self.jumps = jumps
        self.model_rl = model_rl
        self.use_spr = spr
        self.target_augmentation = target_augmentation
        self.eval_augmentation = eval_augmentation
        self.num_actions = output_size
        self.transition_type = transition_type

        if dueling:
            self.head = DQNDistributionalDuelingHeadModel(self.hidden_size,
                                                          output_size,
                                                          hidden_size=self.dqn_hidden_size,
                                                          pixels=self.pixels,
                                                          noisy=self.noisy,
                                                          n_atoms=n_atoms,
                                                          std_init=noisy_nets_std,
                                                          proj_hidden_size=proj_hidden_size)
        else:
            self.head = DQNDistributionalHeadModel(self.hidden_size,
                                                   output_size,
                                                   hidden_size=self.dqn_hidden_size,
                                                   pixels=self.pixels,
                                                   noisy=self.noisy,
                                                   n_atoms=n_atoms,
                                                   std_init=noisy_nets_std)

        if self.jumps > 0:
            repr_size = proj_hidden_size if proj_hidden_size else (self.pixels * self.hidden_size)

            if transition_type == 'gru':
                self.dynamics_model = GRUModel(
                    input_size = gru_input_size,
                    repr_size = repr_size,
                    proj_size = gru_proj_size,
                    num_layers = 1,
                    num_actions = self.num_actions,
                    renormalize=renormalize,
                    renormalize_type=renormalize_type,
                    dropout=dropout
                )
            else:
                self.dynamics_model = TransitionModel(channels=self.hidden_size,
                                                      num_actions=output_size,
                                                      pixels=self.pixels,
                                                      hidden_size=self.hidden_size,
                                                      limit=1,
                                                      blocks=dynamics_blocks,
                                                      norm_type=norm_type,
                                                      renormalize=renormalize,
                                                      residual=residual_tm)
        else:
            self.dynamics_model = nn.Identity()

        self.renormalize = renormalize
        self.renormalize_type = renormalize_type
        self.ln_ratio = ln_ratio

        if renormalize_type == 'train_ln':
            self.renormalize_ln = nn.LayerNorm(repr_size)
        else:
            self.renormalize_ln = nn.Identity()

        if self.use_spr:
            self.local_spr = local_spr
            self.global_spr = global_spr
            self.momentum_encoder = momentum_encoder
            self.momentum_tau = momentum_tau
            self.shared_encoder = shared_encoder
            assert not (self.shared_encoder and self.momentum_encoder)

            # in case someone tries something silly like --local-spr 2
            self.num_sprs = int(bool(self.local_spr)) + \
                            int(bool(self.global_spr))

            if self.local_spr:
                self.local_final_classifier = nn.Identity()
                if self.classifier_type == "mlp":
                    self.local_classifier = nn.Sequential(nn.Linear(self.hidden_size,
                                                                    self.hidden_size),
                                                          nn.BatchNorm1d(self.hidden_size),
                                                          nn.ReLU(),
                                                          nn.Linear(self.hidden_size,
                                                                    self.hidden_size))
                elif self.classifier_type == "bilinear":
                    self.local_classifier = nn.Linear(self.hidden_size, self.hidden_size)
                elif self.classifier_type == "none":
                    self.local_classifier = nn.Identity()
                if final_classifier == "mlp":
                    self.local_final_classifier = nn.Sequential(nn.Linear(self.hidden_size, 2*self.hidden_size),
                                                                nn.BatchNorm1d(2*self.hidden_size),
                                                                nn.ReLU(),
                                                                nn.Linear(2*self.hidden_size,
                                                                    self.hidden_size))
                elif final_classifier == "linear":
                    self.local_final_classifier = nn.Linear(self.hidden_size, self.hidden_size)
                else:
                    self.local_final_classifier = nn.Identity()

                self.local_target_classifier = self.local_classifier
            else:
                self.local_classifier = self.local_target_classifier = nn.Identity()
            if self.global_spr:
                self.global_final_classifier = nn.Identity()
                if self.classifier_type == "mlp":
                    self.global_classifier = nn.Sequential(
                                                nn.Flatten(-3, -1),
                                                nn.Linear(self.pixels*self.hidden_size, 512),
                                                nn.BatchNorm1d(512),
                                                nn.ReLU(),
                                                nn.Linear(512, 256)
                                                )
                    self.global_target_classifier = self.global_classifier
                    global_spr_size = 256
                elif self.classifier_type == "q_l1":
                    self.global_classifier = QL1Head(self.head, dueling=dueling, type=q_l1_type)
                    global_spr_size = self.global_classifier.out_features
                    self.global_target_classifier = self.global_classifier
                elif self.classifier_type == "q_l2":
                    self.global_classifier = nn.Sequential(self.head, nn.Flatten(-2, -1))
                    self.global_target_classifier = self.global_classifier
                    global_spr_size = 256
                elif self.classifier_type == "bilinear":
                    self.global_classifier = nn.Sequential(nn.Flatten(-3, -1),
                                                           nn.Linear(self.hidden_size*self.pixels,
                                                                     self.hidden_size*self.pixels))
                    self.global_target_classifier = nn.Flatten(-3, -1)
                elif self.classifier_type == "none":
                    self.global_classifier = nn.Flatten(-3, -1)
                    self.global_target_classifier = nn.Flatten(-3, -1)

                    global_spr_size = self.hidden_size*self.pixels
                if final_classifier == "mlp":
                    global_final_hidden_size = int(global_spr_size * pred_hidden_ratio)
                    self.global_final_classifier = nn.Sequential(
                        nn.Linear(global_spr_size, global_final_hidden_size),
                        nn.BatchNorm1d(global_final_hidden_size),
                        nn.ReLU(),
                        nn.Linear(global_final_hidden_size, global_spr_size)
                    )
                elif final_classifier == "linear":
                    self.global_final_classifier = nn.Sequential(
                        nn.Linear(global_spr_size, global_spr_size),
                    )
                elif final_classifier == "none":
                    self.global_final_classifier = nn.Identity()
            else:
                self.global_classifier = self.global_target_classifier = nn.Identity()

            if self.momentum_encoder:
                self.target_encoder = copy.deepcopy(self.conv)
                self.target_encoder_proj = copy.deepcopy(self.conv_proj)
                self.target_renormalize_ln = copy.deepcopy(self.renormalize_ln)
                self.global_target_classifier = copy.deepcopy(self.global_target_classifier)
                self.local_target_classifier = copy.deepcopy(self.local_target_classifier)
                for param in (list(self.target_encoder.parameters())
                            + list(self.target_encoder_proj.parameters())
                            + list(self.target_renormalize_ln.parameters())
                            + list(self.global_target_classifier.parameters())
                            + list(self.local_target_classifier.parameters())):
                    param.requires_grad = False

            elif not self.shared_encoder:
                # Use a separate target encoder on the last frame only.
                self.global_target_classifier = copy.deepcopy(self.global_target_classifier)
                self.local_target_classifier = copy.deepcopy(self.local_target_classifier)
                if self.stack_actions:
                    input_size = c - 1
                else:
                    input_size = c
                self.target_encoder = Conv2dModel(in_channels=input_size,
                                                  channels=[32, 64, 64],
                                                  kernel_sizes=[8, 4, 3],
                                                  strides=[4, 2, 1],
                                                  paddings=[0, 0, 0],
                                                  use_maxpool=False,
                                                  )

            elif self.shared_encoder:
                self.target_encoder = self.conv

        print("Initialized model with {} parameters".format(count_parameters(self)))