Example #1
    def __init__(self, input_shape, architecture_args, device='cuda',
                 revert_normalization=None, **kwargs):
        super(VAE, self).__init__()

        self.args = {
            'input_shape': input_shape,
            'architecture_args': architecture_args,
            'device': device,
            'class': 'VAE'
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.device = device
        self.revert_normalization = revert_normalization

        # used later
        self._vis_iters = 0

        # initialize the network
        self.hidden_shape = [None, self.architecture_args['hidden_dim']]

        self.decoder, _ = nn_utils.parse_feed_forward(args=self.architecture_args['decoder'],
                                                      input_shape=self.hidden_shape)
        self.decoder = self.decoder.to(self.device)

        self.encoder, _ = nn_utils.parse_feed_forward(args=self.architecture_args['encoder'],
                                                      input_shape=self.input_shape)

        self.encoder = self.encoder.to(self.device)
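
A quick instantiation sketch for the VAE above. The per-layer spec format here is an assumption: only the 'dim' key and the top-level 'hidden_dim'/'encoder'/'decoder' keys are visible in these snippets, and the real schema is whatever nn_utils.parse_feed_forward expects.

import torch

# Hypothetical architecture_args; the {'type': ..., 'dim': ...} layer-spec
# format is an assumed shape, not confirmed by the snippet.
architecture_args = {
    'hidden_dim': 64,
    'encoder': [{'type': 'fc', 'dim': 256}, {'type': 'fc', 'dim': 64}],
    'decoder': [{'type': 'fc', 'dim': 256}, {'type': 'fc', 'dim': 3 * 32 * 32}],
}

# input_shape must have exactly three entries (see the assert above),
# e.g. (channels, height, width) for CIFAR-sized images.
vae = VAE(input_shape=(3, 32, 32),
          architecture_args=architecture_args,
          device='cuda' if torch.cuda.is_available() else 'cpu')
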
Example #2
    def __init__(self,
                 input_shape,
                 architecture_args,
                 pretrained_arg=None,
                 device="cuda",
                 grad_weight_decay=0.0,
                 grad_l1_penalty=0.0,
                 lamb=1.0,
                 **kwargs):
        super(PredictGradOutputGeneralFormUseLabel, self).__init__(**kwargs)

        self.args = {
            "input_shape": input_shape,
            "architecture_args": architecture_args,
            "device": device,
            "pretrained_arg": pretrained_arg,
            "grad_weight_decay": grad_weight_decay,
            "grad_l1_penalty": grad_l1_penalty,
            "lamb": lamb,
            "class": "PredictGradOutputGeneralFormUseLabel",
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.pretrained_arg = pretrained_arg
        self.device = device
        self.grad_weight_decay = grad_weight_decay
        self.grad_l1_penalty = grad_l1_penalty
        self.lamb = lamb

        # initialize the network
        self.classifier, _ = nn_utils.parse_feed_forward(
            args=self.architecture_args["classifier"],
            input_shape=self.input_shape)
        self.classifier = self.classifier.to(self.device)
        self.num_classes = self.architecture_args["classifier"][-1]["dim"]

        if self.pretrained_arg is not None:
            self.q_base = pretrained_models.get_pretrained_model(
                self.pretrained_arg, self.input_shape, self.device)
            q_base_shape = self.q_base.output_shape
        else:
            self.q_base, q_base_shape = nn_utils.parse_feed_forward(
                args=self.architecture_args["q-base"],
                input_shape=self.input_shape)
            self.q_base = self.q_base.to(self.device)

        # NOTE: we want to use classifier parameters too
        # TODO: find a good parametrization
        self.q_top = torch.nn.Sequential(
            torch.nn.Linear(q_base_shape[-1] + 2 * self.num_classes, 128),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(128, self.num_classes),
        ).to(self.device)
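
The q_top input width q_base_shape[-1] + 2 * self.num_classes suggests that the gradient predictor consumes base features concatenated with two vectors of size num_classes, plausibly the softmax predictions and the one-hot labels given the "UseLabel" name. A hedged forward sketch under that assumption:

# Hypothetical forward pass; the concatenation order and the choice of
# (softmax predictions, one-hot labels) are inferred from the input width
# of q_top, not confirmed by the snippet.
def predict_grad(self, x, y):
    base_feats = self.q_base(x)                          # (B, q_base_dim)
    pred = torch.softmax(self.classifier(x), dim=1)      # (B, num_classes)
    y_onehot = torch.nn.functional.one_hot(y, self.num_classes).float()
    return self.q_top(torch.cat([base_feats, pred, y_onehot], dim=1))
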
Example #4
    def __init__(self, input_shape, architecture_args, pretrained_arg=None, device='cuda',
                 grad_weight_decay=0.0, grad_l1_penalty=0.0, lamb=1.0, **kwargs):
        super(PenalizeLastLayerFixedForm, self).__init__(**kwargs)

        self.args = {
            'input_shape': input_shape,
            'architecture_args': architecture_args,
            'pretrained_arg': pretrained_arg,
            'device': device,
            'grad_weight_decay': grad_weight_decay,
            'grad_l1_penalty': grad_l1_penalty,
            'lamb': lamb,
            'class': 'PenalizeLastLayerFixedForm'
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.pretrained_arg = pretrained_arg
        self.device = device
        self.grad_weight_decay = grad_weight_decay
        self.grad_l1_penalty = grad_l1_penalty
        self.lamb = lamb

        # initialize the network
        self.num_classes = self.architecture_args['classifier'][-1]['dim']
        self.last_layer_dim = self.architecture_args['classifier'][-2]['dim']

        self.classifier_base, _ = nn_utils.parse_feed_forward(args=self.architecture_args['classifier'][:-1],
                                                              input_shape=self.input_shape)
        self.classifier_base = self.classifier_base.to(self.device)
        self.classifier_last_layer = torch.nn.Linear(self.last_layer_dim,
                                                     self.num_classes,
                                                     bias=False).to(self.device)

        if self.pretrained_arg is not None:
            q_base = pretrained_models.get_pretrained_model(self.pretrained_arg, self.input_shape, self.device)

            # create the trainable part of the q_network
            q_top = torch.nn.Sequential(
                torch.nn.Linear(q_base.output_shape[-1], 128),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(128, self.num_classes)).to(self.device)

            self.q_network = torch.nn.Sequential(q_base, q_top)
        else:
            self.q_network, _ = nn_utils.parse_feed_forward(args=self.architecture_args['q-network'],
                                                            input_shape=self.input_shape)
            self.q_network = self.q_network.to(self.device)
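
Splitting the classifier into classifier_base plus a bias-free classifier_last_layer exposes the penultimate activations and the last weight matrix separately, which is what a last-layer penalty needs. A sketch of the forward pass implied by the split (not shown in the snippet):

# Forward pass implied by the base/last-layer split above.
def forward_logits(self, x):
    z = self.classifier_base(x)             # (B, last_layer_dim) penultimate features
    logits = self.classifier_last_layer(z)  # (B, num_classes), no bias term
    return z, logits
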
Example #5
    def __init__(self,
                 input_shape,
                 architecture_args,
                 pretrained_arg=None,
                 device='cuda',
                 loss_function='ce',
                 add_noise=False,
                 noise_type='Gaussian',
                 noise_std=0.0,
                 loss_function_param=None,
                 load_from=None,
                 **kwargs):
        super(StandardClassifier, self).__init__(**kwargs)

        self.args = {
            'input_shape': input_shape,
            'architecture_args': architecture_args,
            'pretrained_arg': pretrained_arg,
            'device': device,
            'loss_function': loss_function,
            'add_noise': add_noise,
            'noise_type': noise_type,
            'noise_std': noise_std,
            'loss_function_param': loss_function_param,
            'load_from': load_from,
            'class': 'StandardClassifier'
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.pretrained_arg = pretrained_arg
        self.device = device
        self.loss_function = loss_function
        self.add_noise = add_noise
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.loss_function_param = loss_function_param
        self.load_from = load_from

        # initialize the network
        self.repr_net = pretrained_models.get_pretrained_model(
            self.pretrained_arg, self.input_shape, self.device)
        self.repr_shape = self.repr_net.output_shape
        self.classifier, output_shape = nn_utils.parse_feed_forward(
            args=self.architecture_args['classifier'],
            input_shape=self.repr_shape)
        self.num_classes = output_shape[-1]
        self.classifier = self.classifier.to(self.device)
        self.grad_noise_class = nn_utils.get_grad_noise_class(
            standard_dev=noise_std, q_dist=noise_type)

        if self.load_from is not None:
            print("Loading the classifier model from {}".format(load_from))
            stored_net = utils.load(load_from, device='cpu')
            stored_net_params = dict(stored_net.classifier.named_parameters())
            for key, param in self.classifier.named_parameters():
                param.data = stored_net_params[key].data.to(self.device)
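
The load_from block copies stored parameters by name. A roughly equivalent state_dict-based sketch; note that state_dict() also carries buffers such as BatchNorm running statistics, which the named_parameters() loop above does not copy:

# Roughly equivalent to the loop above, via load_state_dict.
stored_net = utils.load(load_from, device='cpu')
self.classifier.load_state_dict(
    {k: v.to(self.device) for k, v in stored_net.classifier.state_dict().items()},
    strict=False)
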
Example #8
    def __init__(self,
                 input_shape,
                 architecture_args,
                 pretrained_arg=None,
                 device='cuda',
                 loss_function='ce',
                 add_noise=False,
                 noise_type='Gaussian',
                 noise_std=0.0,
                 loss_function_param=None,
                 **kwargs):
        super(StandardClassifierWithNoise, self).__init__(**kwargs)

        self.args = {
            'input_shape': input_shape,
            'architecture_args': architecture_args,
            'pretrained_arg': pretrained_arg,
            'device': device,
            'loss_function': loss_function,
            'add_noise': add_noise,
            'noise_type': noise_type,
            'noise_std': noise_std,
            'loss_function_param': loss_function_param,
            'class': 'StandardClassifierWithNoise'
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.pretrained_arg = pretrained_arg
        self.device = device
        self.loss_function = loss_function
        self.add_noise = add_noise
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.loss_function_param = loss_function_param

        # initialize the network
        self.repr_net = pretrained_models.get_pretrained_model(
            self.pretrained_arg, self.input_shape, self.device)
        self.repr_shape = self.repr_net.output_shape
        self.classifier, output_shape = nn_utils.parse_feed_forward(
            args=self.architecture_args['classifier'],
            input_shape=self.repr_shape)
        self.num_classes = output_shape[-1]
        self.classifier = self.classifier.to(self.device)
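
add_noise, noise_type, and noise_std are stored here but not consumed in this __init__; compare Example #5, where the same arguments feed nn_utils.get_grad_noise_class. A hypothetical sketch of the kind of perturbation these flags suggest (the real mechanism lives elsewhere in the class):

# Hypothetical input perturbation; the actual use of these flags is
# outside this __init__.
if self.add_noise and self.noise_type == 'Gaussian':
    x = x + self.noise_std * torch.randn_like(x)
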
Example #12
    def __init__(self,
                 input_shape,
                 architecture_args,
                 pretrained_arg=None,
                 device='cuda',
                 grad_weight_decay=0.0,
                 grad_l1_penalty=0.0,
                 lamb=1.0,
                 small_qtop=False,
                 sample_from_q=False,
                 **kwargs):
        super(PredictGradOutputFixedFormWithConfusion, self).__init__(**kwargs)

        self.args = {
            'input_shape': input_shape,
            'architecture_args': architecture_args,
            'pretrained_arg': pretrained_arg,
            'device': device,
            'grad_weight_decay': grad_weight_decay,
            'grad_l1_penalty': grad_l1_penalty,
            'lamb': lamb,
            'small_qtop': small_qtop,
            'sample_from_q': sample_from_q,
            'class': 'PredictGradOutputFixedFormWithConfusion'
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.pretrained_arg = pretrained_arg
        self.device = device
        self.grad_weight_decay = grad_weight_decay
        self.grad_l1_penalty = grad_l1_penalty
        self.lamb = lamb
        self.small_qtop = small_qtop
        self.sample_from_q = sample_from_q
        self.grad_replacement_class = nn_utils.get_grad_replacement_class(
            sample=self.sample_from_q,
            standard_dev=np.sqrt(1.0 / 2.0 / (self.lamb + 1e-12)))

        # initialize the network
        self.classifier, output_shape = nn_utils.parse_feed_forward(
            args=self.architecture_args['classifier'],
            input_shape=self.input_shape)
        self.classifier = self.classifier.to(self.device)
        self.num_classes = output_shape[-1]

        if self.pretrained_arg is not None:
            self.q_base = pretrained_models.get_pretrained_model(
                self.pretrained_arg, self.input_shape, self.device)
            q_base_shape = self.q_base.output_shape
        else:
            self.q_base, q_base_shape = nn_utils.parse_feed_forward(
                args=self.architecture_args['q-base'],
                input_shape=self.input_shape)
            self.q_base = self.q_base.to(self.device)

        if small_qtop:
            self.q_top = torch.nn.Linear(q_base_shape[-1],
                                         self.num_classes).to(self.device)
        else:
            self.q_top = torch.nn.Sequential(
                torch.nn.Linear(q_base_shape[-1], 128),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(128, self.num_classes)).to(self.device)

        # the confusion matrix trainable logits, (true, observed)
        Q_init = torch.zeros(size=(self.num_classes, self.num_classes),
                             device=self.device,
                             dtype=torch.float)
        Q_init += 1.0 * torch.eye(self.num_classes,
                                  device=self.device,
                                  dtype=torch.float)  # TODO: tune the constant
        self.Q_logits = torch.nn.Parameter(Q_init, requires_grad=True)
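
Q_logits parametrizes a trainable (true, observed) confusion matrix, initialized near the identity so that the assumed label noise starts out close to zero. A common decoding of such logits, sketched under the assumption of a row-wise softmax (clean_probs is a hypothetical (B, num_classes) tensor of clean predictions):

# Hypothetical use of Q_logits: each row i of Q is p(observed | true = i).
Q = torch.softmax(self.Q_logits, dim=1)   # (num_classes, num_classes), rows sum to 1
noisy_probs = clean_probs @ Q             # mix clean predictions through the confusion matrix
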
Example #13
    def __init__(self,
                 input_shape,
                 architecture_args,
                 pretrained_arg=None,
                 device="cuda",
                 grad_weight_decay=0.0,
                 grad_l1_penalty=0.0,
                 lamb=1.0,
                 sample_from_q=False,
                 q_dist="Gaussian",
                 loss_function="ce",
                 detach=True,
                 load_from=None,
                 warm_up=0,
                 **kwargs):
        super(PredictGradOutput, self).__init__(**kwargs)

        self.args = {
            "input_shape": input_shape,
            "architecture_args": architecture_args,
            "pretrained_arg": pretrained_arg,
            "device": device,
            "grad_weight_decay": grad_weight_decay,
            "grad_l1_penalty": grad_l1_penalty,
            "lamb": lamb,
            "sample_from_q": sample_from_q,
            "q_dist": q_dist,
            "loss_function": loss_function,
            "detach": detach,
            "load_from": load_from,
            "warm_up": warm_up,
            "class": "PredictGradOutput",
        }

        assert len(input_shape) == 3
        self.input_shape = [None] + list(input_shape)
        self.architecture_args = architecture_args
        self.pretrained_arg = pretrained_arg
        self.device = device
        self.grad_weight_decay = grad_weight_decay
        self.grad_l1_penalty = grad_l1_penalty
        self.lamb = lamb
        self.sample_from_q = sample_from_q
        self.q_dist = q_dist
        self.detach = detach
        self.loss_function = loss_function
        self.load_from = load_from
        self.warm_up = warm_up

        # lamb is the coefficient in front of the H(p,q) term. It controls the variance of predicted gradients.
        if self.q_dist == "Gaussian":
            self.grad_replacement_class = nn_utils.get_grad_replacement_class(
                sample=self.sample_from_q,
                standard_dev=np.sqrt(1.0 / 2.0 / (self.lamb + 1e-12)),
                q_dist=self.q_dist,
            )
        elif self.q_dist == "Laplace":
            self.grad_replacement_class = nn_utils.get_grad_replacement_class(
                sample=self.sample_from_q,
                standard_dev=np.sqrt(2.0) / (self.lamb + 1e-6),
                q_dist=self.q_dist,
            )
        elif self.q_dist == "dot":
            assert not self.sample_from_q
            self.grad_replacement_class = nn_utils.get_grad_replacement_class(
                sample=False)
        else:
            raise NotImplementedError()

        # initialize the network
        self.classifier, output_shape = nn_utils.parse_feed_forward(
            args=self.architecture_args["classifier"],
            input_shape=self.input_shape)
        self.classifier = self.classifier.to(self.device)
        self.num_classes = output_shape[-1]

        if self.pretrained_arg is not None:
            q_base = pretrained_models.get_pretrained_model(
                self.pretrained_arg, self.input_shape, self.device)

            # create the trainable part of the q_network
            q_top = torch.nn.Sequential(
                torch.nn.Linear(q_base.output_shape[-1], 128),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(128, self.num_classes),
            ).to(self.device)

            self.q_network = torch.nn.Sequential(q_base, q_top)
        else:
            self.q_network, _ = nn_utils.parse_feed_forward(
                args=self.architecture_args["q-network"],
                input_shape=self.input_shape)
            self.q_network = self.q_network.to(self.device)

            if self.load_from is not None:
                print("Loading the gradient predictor model from {}".format(
                    load_from))
                stored_net = utils.load(load_from, device="cpu")
                stored_net_params = dict(
                    stored_net.classifier.named_parameters())
                for key, param in self.q_network.named_parameters():
                    param.data = stored_net_params[key].data.to(self.device)

        self.q_loss = None
        if self.loss_function == "none":  # predicted gradient has general form
            self.q_loss = torch.nn.Sequential(
                torch.nn.Linear(2 * self.num_classes, 128),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(128, self.num_classes),
            ).to(self.device)
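
The standard deviations above make the comment about lamb concrete: the Gaussian branch uses std = sqrt(1 / (2 * lamb)), and the Laplace branch uses std = sqrt(2) / lamb, the standard deviation of a Laplace(0, b = 1/lamb) distribution. In both cases a larger H(p,q) coefficient means lower-variance gradient samples. A quick numeric check:

import numpy as np

# Larger lamb shrinks the sampling std in both branches.
for lamb in (0.5, 1.0, 4.0):
    gauss_std = np.sqrt(1.0 / 2.0 / lamb)   # Gaussian: sqrt(1 / (2*lamb))
    laplace_std = np.sqrt(2.0) / lamb       # Laplace(0, b=1/lamb): sqrt(2)*b
    print(f'lamb={lamb}: gauss_std={gauss_std:.3f}, laplace_std={laplace_std:.3f}')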