Example #1
    def transform(self, X, y):
        """Additional transformations on X and y.

        By default, they are cast to torch tensors. Override this if
        you want a different behavior.

        Note: If you use this in conjunction with pytorch's DataLoader,
        the latter will call the dataset for each row separately,
        which means that the incoming X and y each are single rows.

        """
        # pytorch DataLoader cannot deal with None so we use 0 as a
        # placeholder value. We only return a Tensor with one value
        # (as opposed to ``batchsz`` values) since the pytorch
        # DataLoader calls __getitem__ for each row in the batch
        # anyway, which results in a dummy ``y`` value for each row in
        # the batch.
        # FIXME: BatchScoring and EpochScoring with caching will get
        # this placeholder value as `y` since they are operating on
        # batch level. This might be easy to trip over, esp. since
        # this value may look meaningful.
        y = torch.Tensor([0]) if y is None else y

        return (
            to_tensor(X, device=self.device),
            to_tensor(y, device=self.device),
        )
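The docstring above describes the intended extension point: subclass the dataset and override ``transform``. A minimal sketch, assuming skorch's ``skorch.dataset.Dataset`` as the base class (the float32 cast is illustrative, not part of the original):

import numpy as np
import torch
from skorch.dataset import Dataset

class Float32Dataset(Dataset):
    def transform(self, X, y):
        # keep the 0 placeholder for a missing y, as in the base class
        y = torch.Tensor([0]) if y is None else y
        # cast X to float32 tensors instead of relying on the default dtype
        return (
            torch.as_tensor(np.asarray(X), dtype=torch.float32),
            torch.as_tensor(np.asarray(y)),
        )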
Example #2
def compute_amplitude_gradients_for_X(model, X):
    device = next(model.parameters()).device
    ffted = np.fft.rfft(X, axis=2)
    amps = np.abs(ffted)
    phases = np.angle(ffted)
    amps_th = to_tensor(amps.astype(np.float32),
                        device=device).requires_grad_(True)
    phases_th = to_tensor(phases.astype(np.float32),
                          device=device).requires_grad_(True)

    fft_coefs = amps_th.unsqueeze(-1) * torch.stack(
        (torch.cos(phases_th), torch.sin(phases_th)), dim=-1)
    fft_coefs = fft_coefs.squeeze(3)

    iffted = torch.irfft(fft_coefs, signal_ndim=1, signal_sizes=(X.shape[2], ))

    outs = model(iffted)

    n_filters = outs.shape[1]
    amp_grads_per_filter = np.full((n_filters, ) + ffted.shape,
                                   np.nan,
                                   dtype=np.float32)
    for i_filter in range(n_filters):
        mean_out = torch.mean(outs[:, i_filter])
        mean_out.backward(retain_graph=True)
        amp_grads = to_numpy(amps_th.grad.clone())
        amp_grads_per_filter[i_filter] = amp_grads
        amps_th.grad.zero_()
    assert not np.any(np.isnan(amp_grads_per_filter))
    return amp_grads_per_filter
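A hedged usage sketch for the function above. It assumes an older PyTorch (before 1.8, where ``torch.irfft`` still exists) and uses a toy model and random input (both illustrative) of shape (batch, channels, time):

import numpy as np
import torch

# toy 1D conv model: 3 input channels, 4 output filters
model = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=5)
X = np.random.randn(2, 3, 100).astype(np.float32)

grads = compute_amplitude_gradients_for_X(model, X)
# one gradient map per output filter: (n_filters,) + rfft(X).shape == (4, 2, 3, 51)
print(grads.shape)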
Example #3
    def transform(self, X, y):
        """Additional transformations on X and y.

        By default, they are cast to torch tensors. Override this if
        you want a different behavior.

        Note: If you use this in conjunction with pytorch's DataLoader,
        the latter will call the dataset for each row separately,
        which means that the incoming X and y each are single rows.

        """
        # pytorch DataLoader cannot deal with None so we use 0 as a
        # placeholder value. We only return a Tensor with one value
        # (as opposed to ``batchsz`` values) since the pytorch
        # DataLoader calls __getitem__ for each row in the batch
        # anyway, which results in a dummy ``y`` value for each row in
        # the batch.
        # FIXME: BatchScoring and EpochScoring with caching will get
        # this placeholder value as `y` since they are operating on
        # batch level. This might be easy to trip over, esp. since
        # this value may look meaningful.
        y = torch.Tensor([0]) if y is None else y

        return (
            to_tensor(X, device=self.device),
            to_tensor(y, device=self.device),
        )
Example #4
    def test_with_torch_tensors(self, cv_split_cls, data):
        data.X = to_tensor(data.X, device='cpu')
        data.y = to_tensor(data.y, device='cpu')
        m = self.num_samples // 5
        n = self.num_samples - m
        dataset_train, dataset_valid = cv_split_cls(5)(data)

        assert len(dataset_valid) == m
        assert len(dataset_train) == n
Example #5
    def test_with_torch_tensors_and_stratified(self, cv_split_cls, data):
        X = to_tensor(data[0], use_cuda=False)
        num_expected = self.num_samples // 4
        y = np.hstack([np.repeat([0, 0, 0], num_expected),
                       np.repeat([1], num_expected)])
        y = to_tensor(y, use_cuda=False)
        _, _, y_train, y_valid = cv_split_cls(5, stratified=True)(X, y)
        assert y_train.sum() == 0.8 * num_expected
        assert y_valid.sum() == 0.2 * num_expected
Example #6
    def test_with_torch_tensors(self, cv_split_cls, data):
        data.X = to_tensor(data.X, device='cpu')
        data.y = to_tensor(data.y, device='cpu')
        m = self.num_samples // 5
        n = self.num_samples - m
        dataset_train, dataset_valid = cv_split_cls(5)(data)

        assert len(dataset_valid) == m
        assert len(dataset_train) == n
Example #7
    def test_with_torch_tensors(self, cv_split_cls, data):
        data.X = to_tensor(data.X, use_cuda=False)
        data.y = to_tensor(data.y, use_cuda=False)
        m = self.num_samples // 5
        n = self.num_samples - m
        dataset_train, dataset_valid = cv_split_cls(5)(data)

        assert len(dataset_valid) == m
        assert len(dataset_train) == n
Example #8
    def test_with_torch_tensors(self, cv_split_cls, data):
        X, y = data
        X = to_tensor(X, use_cuda=False)
        y = to_tensor(y, use_cuda=False)
        m = self.num_samples // 5
        n = self.num_samples - m

        X_train, X_valid, y_train, y_valid = cv_split_cls(5)(X, y)
        assert len(X_train) == len(y_train) == n
        assert len(X_valid) == len(y_valid) == m
Example #9
    def test_sparse_tensor_not_accepted_raises(self, to_tensor, device):
        if device == 'cuda' and not torch.cuda.is_available():
            pytest.skip()

        inp = sparse.csr_matrix(np.zeros((5, 3)).astype(np.float32))
        with pytest.raises(TypeError) as exc:
            to_tensor(inp, device=device)

        msg = ("Sparse matrices are not supported. Set "
               "accept_sparse=True to allow sparse matrices.")
        assert exc.value.args[0] == msg
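The error message in the test above also points at the remedy: pass ``accept_sparse=True`` when sparse input is intended (compare the sparse-tensor test further below). A small sketch using ``skorch.utils.to_tensor``:

import numpy as np
from scipy import sparse
from skorch.utils import to_tensor

inp = sparse.csr_matrix(np.zeros((5, 3), dtype=np.float32))
t = to_tensor(inp, device='cpu', accept_sparse=True)  # converted to a sparse COO tensor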
Example #10
    def test_with_torch_tensors_and_stratified(self, cv_split_cls, data):
        num_expected = self.num_samples // 4
        data.X = to_tensor(data.X, device='cpu')
        y = np.hstack([np.repeat([0, 0, 0], num_expected),
                       np.repeat([1], num_expected)])
        data.y = to_tensor(y, device='cpu')

        dataset_train, dataset_valid = cv_split_cls(5, stratified=True)(data, y)
        y_train = data_from_dataset(dataset_train)[1]
        y_valid = data_from_dataset(dataset_valid)[1]

        assert y_train.sum() == 0.8 * num_expected
        assert y_valid.sum() == 0.2 * num_expected
Example #11
    def test_device_setting_cuda(self, to_tensor):
        x = np.ones((2, 3, 4))
        t = to_tensor(x, device='cpu')
        assert t.device.type == 'cpu'

        t = to_tensor(x, device='cuda')
        assert t.device.type.startswith('cuda')

        t = to_tensor(t, device='cuda')
        assert t.device.type.startswith('cuda')

        t = to_tensor(t, device='cpu')
        assert t.device.type == 'cpu'
Example #12
    def test_device_setting_cuda(self, to_tensor):
        x = np.ones((2, 3, 4))
        t = to_tensor(x, device='cpu')
        assert t.device.type == 'cpu'

        t = to_tensor(x, device='cuda')
        assert t.device.type.startswith('cuda')

        t = to_tensor(t, device='cuda')
        assert t.device.type.startswith('cuda')

        t = to_tensor(t, device='cpu')
        assert t.device.type == 'cpu'
Example #13
    def choose_r(self, X, y):
        if self.seed is not None:
            np.random.seed(self.seed)

        idx = np.arange(X.shape[0])
        idx_r = np.random.choice(idx,
                                 size=(int(self.reference_set_size_ratio *
                                           X.shape[0]), ),
                                 replace=False)
        self.xR = to_tensor(X[idx_r], device=self.device)
        self.yR = to_tensor(y[idx_r], device=self.device)

        return idx_r
Example #14
    def test_with_torch_tensors_and_stratified(self, cv_split_cls, data):
        num_expected = self.num_samples // 4
        data.X = to_tensor(data.X, device='cpu')
        y = np.hstack([np.repeat([0, 0, 0], num_expected),
                       np.repeat([1], num_expected)])
        data.y = to_tensor(y, device='cpu')

        dataset_train, dataset_valid = cv_split_cls(5, stratified=True)(data, y)
        y_train = data_from_dataset(dataset_train)[1]
        y_valid = data_from_dataset(dataset_valid)[1]

        assert y_train.sum() == 0.8 * num_expected
        assert y_valid.sum() == 0.2 * num_expected
Example #15
    def predict(self, X):
        self.module_.eval()

        output_size = self.module__output_size
        assert output_size % 2 == 0
        preds = torch.zeros((X.shape[0], self.sample_count, output_size))

        with torch.no_grad():
            for j in range(self.sample_count):
                preds[:, j] = to_tensor(
                    self.predict_proba(to_tensor(X, device=self.device)),
                    self.device)

        return combine_uncertainties(preds, output_size)
Example #16
    def predict(self, X):
        self.module_.eval()

        output_size = self.module__output_size
        assert output_size % 2 == 0
        output_dim = int(output_size / 2)

        pred = to_tensor(self.predict_proba(to_tensor(X, device=self.device)), self.device)

        mean = pred[..., :output_dim]
        # use softplus so the predicted std stays positive and numerically
        # stable, regardless of the network's output activation
        softplus = torch.nn.Softplus()
        std = softplus(pred[..., output_dim:])

        return np.stack([to_numpy(mean), to_numpy(std**2)], -1)
Example #17
    def predict(self, X):
        self.module_.eval()

        # batched predictions to keep memory usage in check
        pred_means = []
        pred_vars = []
        for idx in np.arange(0, X.shape[0], self.batch_size):
            predictive_means, predictive_variances = \
                self.module_.predict(to_tensor(X[idx: min(idx + self.batch_size, X.shape[0])], self.device))

            pred_means.append(predictive_means)
            pred_vars.append(predictive_variances)

        predictive_means = torch.cat(pred_means, dim=1)
        predictive_variances = torch.cat(pred_vars, dim=1)

        output_dim = int(self.module__output_size / 2)

        mean = predictive_means.mean(0)[..., :output_dim]
        epistemic_var = predictive_variances.mean(0)[..., :output_dim]
        softplus = torch.nn.Softplus()
        aleatoric_var = softplus(predictive_means.mean(0)[..., output_dim:])**2
        var = epistemic_var + aleatoric_var

        return np.stack(
            [to_numpy(mean),
             to_numpy(var),
             to_numpy(epistemic_var),
             to_numpy(aleatoric_var)],
            -1)
Example #18
    def train_step(self, Xi, yi, **fit_params):
        train_generator = fit_params.pop('train_generator', True)

        self.module_.train()
        self.critic_.train()

        self.optimizer_.zero_grad()
        self.critic_optimizer_.zero_grad()

        b = Xi.shape[0]
        real = to_tensor(Xi, self.device)
        generated = self.module_.generate(b)

        critic_loss = self.critic_.loss(real, generated.detach())
        critic_loss.backward()
        self.critic_optimizer_.step()

        distance = self.critic_.distance(real, generated)

        if train_generator:
            distance.backward()
            self.optimizer_.step()

        self.history.record_batch('critic_loss', critic_loss.item())

        return {
            'critic_loss': critic_loss,
            'distance': distance,
        }
Example #19
    def get_loss(self, y_pred, y_true, X=None, training=False):
        if isinstance(X, dict):
            X_X = X['X']
        else:
            X_X = X
        loss = super().get_loss(y_pred, y_true, X_X, training)
        if not training:
            return loss
        if X is not None and self.alpha is not None and self.lambda_bar != 0:
            ksi = self.infer(X, name=self.layer_name)
            alpha = to_tensor(X['alpha'], device=self.device).to(X_X.dtype)
            Z = to_tensor(X['Z'], device=self.device).to(X_X.dtype)
            reg_loss = self.mu * 0.5 * torch.norm(alpha - ksi + Z)
            reg_loss = reg_loss.mean(dim=0)
            loss += reg_loss
        return loss
Example #20
    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the loss for this batch.

        Parameters
        ----------
        y_pred : torch tensor
          Predicted target values

        y_true : torch tensor
          True target values.

        X : input data, compatible with skorch.dataset.Dataset
          By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

          If this doesn't work with your data, you have to pass a
          ``Dataset`` that can deal with the data.

        training : bool (default=False)
          Whether train mode should be used or not.

        """
        y_true = to_tensor(y_true, device=self.device)
        return self.criterion_(y_pred, y_true)
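Since ``get_loss`` only casts ``y_true`` and forwards to the criterion, custom losses are usually added by subclassing and calling ``super().get_loss``. A hedged sketch (the class name, the L1 penalty, and its weight are illustrative, not from the original):

from skorch import NeuralNetRegressor

class L1PenalizedNet(NeuralNetRegressor):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        # add a small L1 penalty over all module parameters
        l1 = sum(p.abs().sum() for p in self.module_.parameters())
        return loss + 1e-4 * l1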
Example #21
    def train_step(self, Xi, yi=None, **fit_params):

        # turn input data into tensor
        Xi = to_tensor(Xi, device=self.device)

        # create local variables for the generator and discriminator
        discriminator = self.module_.discriminator
        generator = self.module_.generator

        # forward the generator and obtain its data
        fake, latent_Xi, latent_fake = generator(Xi)

        # evaluate real and fake data with the discriminator
        prediction_real, features_real = discriminator(Xi)
        prediction_fake, features_fake = discriminator(fake.detach())

        # create label tensors: ones for real samples, zeros for fake ones
        # (targets for the discriminator loss)
        labels_real = ones_like(
            prediction_real, dtype=torch.float32, device=self.device)
        labels_fake = zeros_like(
            prediction_fake, dtype=torch.float32, device=self.device)

        # calculate generator loss
        adversarial_loss = self.module_.adversarial_loss(
            features_real, features_fake)
        contextual_loss = self.module_.contextual_loss(Xi, fake)
        encoder_loss = self.module_.encoder_loss(latent_Xi, latent_fake)
        generator_loss = self.module_.adversarial_weight * adversarial_loss + \
            self.module_.contextual_weight * contextual_loss + \
            self.module_.encoder_weight * encoder_loss

        # calculate discriminator loss
        discriminator_loss_real = self.module_.discriminator_loss(
            prediction_real, labels_real)
        discriminator_loss_fake = self.module_.discriminator_loss(
            prediction_fake, labels_fake)
        discriminator_loss = (discriminator_loss_real +
                              discriminator_loss_fake) * 0.5

        # set gradient of generator optimizer to zero and update generator weights
        self.generator_optimizer_.zero_grad()
        generator_loss.backward(retain_graph=True)
        self.generator_optimizer_.step()

        # set gradient of discriminator optimizer to zero and update discriminator weights
        self.discriminator_optimizer_.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer_.step()

        # record the different loss values in the models training history
        self.history.record_batch('generator_loss', generator_loss.item())
        self.history.record_batch('adversarial_loss', adversarial_loss.item())
        self.history.record_batch('contextual_loss', contextual_loss.item())
        self.history.record_batch('encoder_loss', encoder_loss.item())
        self.history.record_batch(
            'discriminator_loss', discriminator_loss.item())

        # return train loss for skorch
        return {'loss': generator_loss + discriminator_loss}
Example #22
    def distance(self, real):
        real = to_tensor(real, self.device)
        self.module_.eval()
        self.critic_.eval()
        with torch.no_grad():
            fake = self.module_.generate(real.shape[0])
            return self.critic_.distance(real, fake).item()
Example #23
    def describe_signature(self, df):
        """Describe the signature required for the given data.

        Pass the DataFrame to receive a description of the signature
        required for the module's forward method. The description
        consists of three parts:

        1. The names of the arguments that the forward method
        needs.
        2. The dtypes of the torch tensors passed to forward.
        3. The number of input units that are required for the
        corresponding argument. For the float parameter, this is just
        the number of dimensions of the tensor. For categorical
        parameters, it is the number of unique elements.

        Returns
        -------
        signature : dict
          Returns a dict with each key corresponding to one key
          required for the forward method. The values are dictionaries
          of two elements. The key "dtype" describes the torch dtype
          of the resulting tensor, the key "input_units" describes the
          required number of input units.

        """
        X_dict = self.fit_transform(df)
        signature = {}

        X = X_dict.get('X')
        if X is not None:
            signature['X'] = dict(
                dtype=to_tensor(X, device='cpu').dtype,
                input_units=X.shape[1],
            )

        for key, val in X_dict.items():
            if key == 'X':
                continue

            tensor = to_tensor(val, device='cpu')
            nunique = len(torch.unique(tensor))
            signature[key] = dict(
                dtype=tensor.dtype,
                input_units=nunique,
            )

        return signature
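A hedged usage sketch, assuming the method above belongs to ``skorch.helper.DataFrameTransformer`` (the column names and values are illustrative):

import pandas as pd
from skorch.helper import DataFrameTransformer

df = pd.DataFrame({
    'price': [10.0, 12.5, 9.9],                       # float column, folded into 'X'
    'color': pd.Categorical(['red', 'blue', 'red']),  # categorical column, own entry
})
print(DataFrameTransformer().describe_signature(df))
# expected along the lines of:
# {'X': {'dtype': torch.float32, 'input_units': 1},
#  'color': {'dtype': torch.int64, 'input_units': 2}}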
Example #24
    def get_loss(self, y_pred, y_true, X=None, training=False):
        y_true = to_tensor(y_true, device='cpu')
        loss_a = torch.abs(y_true.float() - y_pred[:, 1]).mean()
        loss_b = ((y_true.float() - y_pred[:, 1]) ** 2).mean()
        if training:
            self.history.record_batch('loss_a', to_numpy(loss_a))
            self.history.record_batch('loss_b', to_numpy(loss_b))
        return loss_a + loss_b
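Usage note (illustrative, not from the original): values recorded with ``record_batch`` can later be read back through skorch's history multi-indexing, e.g.:

# after net.fit(X, y) with the get_loss above:
#     net.history[-1, 'batches', -1, 'loss_a']   # last batch of the last epoch
#     net.history[:, 'batches', :, 'loss_b']     # all recorded values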
Example #25
    def score(self, X, y=None):
        X = to_tensor(X, device=self.device)

        # create local variables for the generator and discriminator
        discriminator = self.module_.discriminator
        generator = self.module_.generator

        # forward the generator and obtain its data
        fake, latent_X, latent_fake = generator(X)

        # evaluate real and fake data with the discriminator
        prediction_real, features_real = discriminator(X)
        prediction_fake, features_fake = discriminator(fake.detach())

        # create label tensors: ones for real samples, zeros for fake ones
        # (targets for the discriminator loss)
        labels_real = ones_like(
            prediction_real, dtype=torch.float32, device=self.device)
        labels_fake = zeros_like(
            prediction_fake, dtype=torch.float32, device=self.device)

        # calculate generator loss
        adversarial_loss = self.module_.adversarial_loss(
            features_real, features_fake)
        contextual_loss = self.module_.contextual_loss(X, fake)
        encoder_loss = self.module_.encoder_loss(latent_X, latent_fake)

        generator_loss = self.module_.adversarial_weight * adversarial_loss + \
            self.module_.contextual_weight * contextual_loss + \
            self.module_.encoder_weight * encoder_loss

        generator_loss = generator_loss / \
            (self.module_.adversarial_weight +
             self.module_.contextual_weight + self.module_.encoder_weight)

        # calculate discriminator loss
        discriminator_loss_real = self.module_.discriminator_loss(
            prediction_real, labels_real)

        discriminator_loss_fake = self.module_.discriminator_loss(
            prediction_fake, labels_fake)

        discriminator_loss = (discriminator_loss_real +
                              discriminator_loss_fake) * 0.5

        # calculate train loss
        train_loss = generator_loss + discriminator_loss

        # negate the losses so that GridSearchCV, which maximizes the
        # score, effectively prefers the lowest loss
        generator_loss = -1 * generator_loss.item()
        train_loss = -1 * train_loss.item()

        if discriminator_loss.item() < 1e-5:
            discriminator.apply(nets.weights_init)

        # return scores as dictionary
        return {'generator_loss': generator_loss, 'train_loss': train_loss}
Example #26
    def infer(self, x, **fit_params):
        x = to_tensor(x, device=self.device)
        if isinstance(x, dict):
            x_dict = self._merge_x_and_fit_params(x, fit_params)
            # set reference set
            x_dict['XR'] = self.xR
            x_dict['yR'] = self.yR
            return self.module_(**x_dict)
        return self.module_(x, XR=self.xR, yR=self.yR, **fit_params)
Example #27
    def test_sparse_tensor(self, to_tensor, device):
        if device == 'cuda' and not torch.cuda.is_available():
            pytest.skip()

        inp = sparse.csr_matrix(np.zeros((5, 3)).astype(np.float32))
        expected = torch.sparse_coo_tensor(size=(5, 3)).to(device)

        result = to_tensor(inp, device=device, accept_sparse=True)
        assert self.tensors_equal(result, expected)
Example #28
    def fit(self, X, y=None, **fit_params):
        """Initialize and fit the module.

        If the module was already initialized, calling ``fit`` will
        re-initialize it (unless ``warm_start`` is True).

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
          By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

          If this doesn't work with your data, you have to pass a
          ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
          The same data types as for ``X`` are supported. If your X is
          a Dataset that contains the target, ``y`` may be set to
          None.

        **fit_params : dict
          Additional parameters passed to the ``forward`` method of
          the module and to the ``self.train_split`` call.

        """
        if not self.warm_start or not self.initialized_:
            self.initialize()

        # set training data of the ExactGP module
        self.module_.set_train_data(
            inputs=to_tensor(X, device=self.device),
            targets=to_tensor(y, device=self.device),
            strict=False,
        )

        self.partial_fit(X, y, **fit_params)
        return self
Example #29
    def triplet_infer(self, x):
        """Perform a single inference step on a batch of data.

        Parameters
        ----------
        x : input data
          A batch of the input data.

        """
        x = to_tensor(x, device=self.device)
        return self.module_(x[0], x[1], x[2])
Example #30
def test_cropped_trial_epoch_scoring_none_x_test():
    dataset_train = None
    dataset_valid = None
    predictions = np.array(
        [
            [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
            [[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
        ]
    )
    y_true = [torch.tensor([0, 0]), torch.tensor([1, 1])]
    window_inds = [(
        torch.tensor([0, 0]),  # i_window_in_trials
        [None],  # won't be used
        torch.tensor([4, 4]),  # i_window_stops
    ),
        (
        torch.tensor([0, 0]),  # i_window_in_trials
        [None],  # won't be used
        torch.tensor([4, 4]),  # i_window_stops
    )]
    cropped_trial_epoch_scoring = CroppedTrialEpochScoring("accuracy")
    cropped_trial_epoch_scoring.initialize()
    cropped_trial_epoch_scoring.y_preds_ = [
        to_tensor(predictions[:2], device="cpu"),
        to_tensor(predictions[2:], device="cpu"),
    ]
    cropped_trial_epoch_scoring.y_trues_ = y_true
    cropped_trial_epoch_scoring.window_inds_ = window_inds

    mock_skorch_net = MockSkorchNet()
    mock_skorch_net.callbacks_ = [(
        "", cropped_trial_epoch_scoring)]
    output = cropped_trial_epoch_scoring.on_epoch_end(
        mock_skorch_net, dataset_train, dataset_valid
    )
    assert output is None
Example #31
    def transform(self, X, y):
        """Additional transformations on X and y.

        By default, they are cast to torch tensors. Override this if
        you want a different behavior.

        Note: If you use this in conjunction with pytorch's DataLoader,
        the latter will call the dataset for each row separately,
        which means that the incoming X and y each are single rows.

        """
        # pytorch DataLoader cannot deal with None so we use 0 as a
        # placeholder value. We only return a Tensor with one value
        # (as opposed to ``batchsz`` values) since the pytorch
        # DataLoader calls __getitem__ for each row in the batch
        # anyway, which results in a dummy ``y`` value for each row in
        # the batch.
        y = torch.Tensor([0]) if y is None else y

        return (
            to_tensor(X, use_cuda=self.use_cuda),
            to_tensor(y, use_cuda=self.use_cuda),
        )
Example #32
    def predict(self, X, samples=30):
        self.module_.eval()

        output_size = self.module__dim_y
        dxy = np.zeros((X.shape[0], samples, output_size))

        with torch.no_grad():
            for j in range(samples):
                dxy[:, j] = to_numpy(
                    self.module_.predict(to_tensor(X, device=self.device),
                                         self.xR, self.yR))

        mean, var = dxy.mean(axis=1), dxy.var(axis=1)

        r = np.stack([mean, var], -1)

        return r
Example #33
    def infer(self, x, **fit_params):
        """Perform a single inference step on a batch of data.

        Parameters
        ----------
        x : input data
          A batch of the input data.

        **fit_params : dict
          Additional parameters passed to the ``forward`` method of
          the module and to the ``self.train_split`` call.

        """
        x = to_tensor(x, device=self.device)
        if isinstance(x, dict):
            x_dict = self._merge_x_and_fit_params(x, fit_params)
            return self.module_(**x_dict)
        return self.module_(x, **fit_params)
Example #34
    def train_step(self, Xi, yi, **fit_params):
        train_generator = fit_params.pop('train_generator', True)

        self.module_.train()
        self.critic_.train()

        self.optimizer_.zero_grad()
        self.critic_optimizer_.zero_grad()

        b = Xi.shape[0]
        real = to_tensor(Xi, self.device)
        generated = self.module_.generate(b)

        y_real = self.critic_(real)
        y_generated = self.critic_(generated.detach())

        critic_loss = self.get_loss(y_real, torch.ones_like(y_real))
        critic_loss = critic_loss + self.get_loss(
            y_generated, torch.zeros_like(y_generated))
        critic_loss.backward()

        self.critic_optimizer_.step()

        if train_generator:
            y_generated = self.critic_(generated)
            generator_loss = self.get_loss(y_generated,
                                           torch.ones_like(y_generated))
            generator_loss.backward()

            self.optimizer_.step()

        distance = y_real.log().mean() + (1 - y_generated).log().mean()
        distance = distance / 2

        return {
            'critic_loss': critic_loss,
            'distance': distance,
        }