Example #1

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from torch.distributions import MultivariateNormal

# NOTE: the names below are project-local building blocks used throughout
# this listing; these import paths are an assumption, not part of the
# original source.
from networks import (MultiProbabilisticVanillaNN, ProbabilisticVanillaNN,
                      VanillaNN)
from utils import batch_sampler, metrics_calculator, mll

class CNPBasic(nn.Module):
    """
    The Neural Process + FiLM: a model for chemical data imputation.
    """
    def __init__(self, in_dim, out_dim, z_dim, n_properties, encoder_dims,
                 decoder_dims):
        """

        :param in_dim: (int) dimensionality of the input x
        :param out_dim: (int) dimensionality of the target variable y
        :param z_dim: (int) dimensionality of the embedding / context vector r
        :param n_properties: (int) the number of unknown properties. Adrenergic = 5; Kinase = 159.
        :param encoder_dims: (list of ints) architecture of the encoder NN.
        :param decoder_dims: (list of ints) architecture of the decoder NN.
        """

        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.z_dim = z_dim
        self.d_dim = in_dim - n_properties
        self.n_properties = n_properties
        self.encoder = VanillaNN(in_dim=self.d_dim + self.n_properties,
                                 out_dim=self.z_dim,
                                 hidden_dims=encoder_dims)
        self.decoder = MultiProbabilisticVanillaNN(
            in_dim=self.z_dim,
            out_dim=1,
            n_properties=self.n_properties,
            hidden_dims=decoder_dims,
            restrict_var=False)
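
    # Architecture note: the encoder maps the concatenated descriptors and
    # zero-filled properties to a latent code z; the decoder emits one
    # Gaussian (mean, variance) head per property.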

    def train_model(self,
                    x,
                    epochs,
                    batch_size,
                    file,
                    print_freq=50,
                    x_test=None,
                    means=None,
                    stds=None,
                    lr=0.001):
        """
        :param x:
        :param epochs:
        :param batch_size:
        :param file:
        :param print_freq:
        :param x_test:
        :param means:
        :param stds:
        :param lr:
        :return:
        """
        self.means = means
        self.stds = stds
        self.dir_name = os.path.dirname(file.name)
        self.file_start = os.path.splitext(os.path.basename(file.name))[0]

        optimiser = optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr)

        self.epoch = 0
        for epoch in range(epochs):
            self.epoch = epoch
            optimiser.zero_grad()

            # Select a batch
            batch_idx = torch.randperm(x.shape[0])[:batch_size]
            x_batch = x[batch_idx, ...]  # [batch_size, x.shape[1]]
            target_batch = x_batch[:, -self.n_properties:]

            # Mask of the properties that are missing
            mask_batch = torch.isnan(x_batch[:, -self.n_properties:])

            # Form the context mask by additionally hiding a random subset of
            # the observed properties; the hidden ones become the targets the
            # decoder must reconstruct.
            mask_context = mask_batch.clone()
            batch_properties = [
                torch.where(~mask_batch[i, ...])[0]
                for i in range(mask_batch.shape[0])
            ]

            for i, properties in enumerate(batch_properties):
                # Randomly choose how many (and which) of this molecule's
                # observed properties to hide from the context.
                ps = np.random.choice(properties.numpy(),
                                      size=np.random.randint(
                                          low=0, high=properties.shape[0] + 1),
                                      replace=False)
                mask_context[i, ps] = True
            # Zero-fill every masked property (both truly missing and held
            # out) so the encoder input carries no target information.
            input_batch = x_batch.clone()
            input_batch[:, -self.n_properties:][mask_context] = 0.0

            z = self.encoder(input_batch)

            # mus_y / vars_y are per-property lists: entry p holds the
            # predictive mean and variance for the rows where property p is
            # observed (selected via mask_batch).
            mus_y, vars_y = self.decoder.forward(z, mask_batch)

            likelihood_term = 0

            for p in range(self.n_properties):
                target = target_batch[:, p][~mask_batch[:, p]]
                mu_y = mus_y[p].squeeze(1)
                var_y = vars_y[p].squeeze(1)

                # Gaussian log-density of each observed target:
                # log N(y | mu, var) = -0.5*log(2*pi) - 0.5*log(var)
                #                      - 0.5*(y - mu)^2 / var
                ll = (-0.5 * np.log(2 * np.pi) - 0.5 * torch.log(var_y) - 0.5 *
                      ((target - mu_y)**2 / var_y))

                likelihood_term += torch.sum(ll)

            # Average the log-likelihood over all observed property values;
            # training minimises its negative.
            likelihood_term /= torch.sum(~mask_batch)

            loss = -likelihood_term

            if (epoch % print_freq == 0) and (epoch > 0):
                file.write('\n Epoch {} Loss: {:4.4f} LL: {:4.4f}'.format(
                    epoch, loss.item(), likelihood_term.item()))

                r2_scores, mlls, rmses = self.metrics_calculator(x, test=False)
                r2_scores = np.array(r2_scores)
                mlls = np.array(mlls)
                rmses = np.array(rmses)

                file.write('\n R^2 score (train): {:.3f} +/- {:.3f}'.format(
                    np.mean(r2_scores), np.std(r2_scores)))
                file.write('\n MLL (train): {:.3f} +/- {:.3f}'.format(
                    np.mean(mlls), np.std(mlls)))
                file.write('\n RMSE (train): {:.3f} +/- {:.3f} \n'.format(
                    np.mean(rmses), np.std(rmses)))
                file.flush()

                if x_test is not None:
                    r2_scores, mlls, rmses = self.metrics_calculator(x_test,
                                                                     test=True)
                    r2_scores = np.array(r2_scores)
                    mlls = np.array(mlls)
                    rmses = np.array(rmses)

                    file.write('\n R^2 score (test): {:.3f} +/- {:.3f}'.format(
                        np.mean(r2_scores), np.std(r2_scores)))
                    file.write('\n MLL (test): {:.3f} +/- {:.3f}'.format(
                        np.mean(mlls), np.std(mlls)))
                    file.write('\n RMSE (test): {:.3f} +/- {:.3f} \n'.format(
                        np.mean(rmses), np.std(rmses)))
                    file.flush()

                    if (self.epoch % 500) == 0 and (self.epoch > 0):
                        path_to_save = os.path.join(
                            self.dir_name,
                            '{}_{}'.format(self.file_start, self.epoch))
                        np.save(path_to_save + '_r2_scores.npy', r2_scores)
                        np.save(path_to_save + '_mll_scores.npy', mlls)
                        np.save(path_to_save + '_rmse_scores.npy', rmses)

            loss.backward()

            optimiser.step()

    def metrics_calculator(self, x, test=True):
        mask = torch.isnan(x[:, -self.n_properties:])
        r2_scores = []
        mlls = []
        rmses = []

        # Leave-one-property-out evaluation: hide property p in the context
        # and predict it from the remaining observed properties.
        for p in range(self.n_properties):
            p_idx = torch.where(~mask[:, p])[0]
            x_p = x[p_idx]

            input_p = x_p.clone()
            input_p[:, -self.n_properties:][mask[p_idx]] = 0.0
            input_p[:, (-self.n_properties + p)] = 0.0

            # Predict only property p: mask every other property in the decoder.
            mask_p = torch.ones_like(mask[p_idx, :])
            mask_p[:, p] = False
            z = self.encoder(input_p)
            predict_mean, predict_var = self.decoder.forward(z, mask_p)

            predict_mean = predict_mean[p].reshape(-1).detach()
            predict_std = (predict_var[p]**0.5).reshape(-1).detach()

            target = x_p[:, (-self.n_properties + p)]

            if (self.means is not None) and (self.stds is not None):
                predict_mean = (
                    predict_mean.numpy() * self.stds[-self.n_properties + p] +
                    self.means[-self.n_properties + p])
                predict_std = predict_std.numpy() * self.stds[
                    -self.n_properties + p]
                target = (target.numpy() * self.stds[-self.n_properties + p] +
                          self.means[-self.n_properties + p])
                r2_scores.append(r2_score(target, predict_mean))
                mlls.append(mll(predict_mean, predict_std**2, target))
                rmses.append(np.sqrt(mean_squared_error(target, predict_mean)))

                path_to_save = os.path.join(
                    self.dir_name, '{}_{}'.format(self.file_start, p))

                if (self.epoch % 500) == 0 and (self.epoch > 0):
                    if test:
                        np.save(path_to_save + '_mean.npy', predict_mean)
                        np.save(path_to_save + '_std.npy', predict_std)
                        np.save(path_to_save + '_target.npy', target)

            else:
                r2_scores.append(r2_score(target.numpy(),
                                          predict_mean.numpy()))
                mlls.append(mll(predict_mean.numpy(), predict_std.numpy()**2,
                                target.numpy()))
                rmses.append(
                    np.sqrt(
                        mean_squared_error(target.numpy(),
                                           predict_mean.numpy())))
        return r2_scores, mlls, rmses
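

# A minimal usage sketch for CNPBasic, added for illustration. The sizes,
# the synthetic data and the log-file name below are assumptions chosen for
# the demo, not values from the original listing.
def _demo_cnp_basic():
    n_properties = 5                 # e.g. the Adrenergic setting
    d_dim = 16                       # assumed descriptor dimensionality
    x_demo = torch.randn(128, d_dim + n_properties)
    # Hide ~30% of the property values to simulate missing data.
    hide = torch.rand(128, n_properties) < 0.3
    x_demo[:, -n_properties:][hide] = float('nan')

    model = CNPBasic(in_dim=d_dim + n_properties, out_dim=1, z_dim=32,
                     n_properties=n_properties, encoder_dims=[64, 64],
                     decoder_dims=[64, 64])
    with open('./cnp_basic_log.txt', 'w') as log_file:
        model.train_model(x_demo, epochs=200, batch_size=32, file=log_file)
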
class CNP:
    """
    The Conditional Neural Process model.
    """
    def __init__(self,
                 x_dim,
                 y_dim,
                 r_dim,
                 encoder_dims,
                 decoder_dims,
                 encoder_non_linearity=F.relu,
                 decoder_non_linearity=F.relu):
        """

        :param x_dim: (int) Dimensionality of x, the input to the CNP
        :param y_dim: (int) Dimensionality of y, the target.
        :param r_dim: (int) Dimensionality of the deterministic embedding, r.
        :param encoder_dims: (list of ints) Architecture of the encoder network.
        :param decoder_dims: (list of ints) Architecture of the decoder network.
        :param encoder_non_linearity: Non-linear activation function to apply after each linear transformation,
                                in the encoder network e.g. relu or tanh.
        :param decoder_non_linearity: Non-linear activation function to apply after each linear transformation,
                                in the decoder network e.g. relu or tanh.
        :param lr: (float) Optimiser learning rate.
        """
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.r_dim = r_dim

        self.encoder = VanillaNN((x_dim + y_dim), r_dim, encoder_dims,
                                 encoder_non_linearity)
        self.decoder = ProbabilisticVanillaNN(
            (x_dim + r_dim), y_dim, decoder_dims, decoder_non_linearity)

    def forward(self, x_context, y_context, x_target, batch_size):
        """

        :param x_context: (torch tensor of dimensions [batch_size*n_context, x_dim])
        :param y_context: (torch tensor of dimensions [batch_size*n_context, y_dim])
        :param x_target: (torch tensor of dimensions [batch_size*n_target, x_dim])
        :param batch_size: (int) number of functions per batch; x_target.shape[0]
                           must be a multiple of batch_size.
        :return: mu_y, var_y: (both torch tensors of dimensions [batch_size*n_target, y_dim])
        """
        assert x_target.shape[0] % batch_size == 0
        assert len(
            x_context.shape
        ) == 2, 'Input must be of shape [batch_size*n_context, x_dim].'
        assert len(
            y_context.shape
        ) == 2, 'Input must be of shape [batch_size*n_context, y_dim].'
        assert len(
            x_target.shape
        ) == 2, 'Input must be of shape [batch_size*n_target, x_dim].'

        n_target = x_target.shape[0] // batch_size

        r = self.encoder.forward(
            torch.cat((x_context, y_context),
                      dim=-1).float())  # [batch_size*n_context, r_dim]

        r = r.view(batch_size, -1,
                   self.r_dim)  # [batch_size, n_context, r_dim]
        r = torch.mean(r, dim=1).reshape(-1, self.r_dim)  # [batch_size, r_dim]

        r = torch.repeat_interleave(r, n_target,
                                    dim=0)  # [batch_size*n_target, r_dim]

        mu_y, var_y = self.decoder.forward(
            torch.cat((x_target.float(), r),
                      dim=-1))  # [batch_size*n_target, y_dim] x2

        return mu_y, var_y

    def train(self,
              x,
              y,
              x_test=None,
              y_test=None,
              x_scaler=None,
              y_scaler=None,
              batch_size=10,
              lr=0.001,
              epochs=3000,
              print_freq=100,
              VERBOSE=True,
              dataname=None):
        """

        :param x: [n_functions, [n_train, x_dim]]
        :param y: [n_functions, [n_train, y_dim]]
        :param lr:
        :param iterations:
        :return:
        """
        self.optimiser = optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr)

        for epoch in range(epochs):
            self.optimiser.zero_grad()

            # Sample the function from the set of functions
            x_context, y_context, x_target, y_target = batch_sampler(
                x, y, batch_size)

            # Make a forward pass through the CNP to obtain a distribution over the target set.
            mu_y, var_y = self.forward(
                x_context, y_context, x_target,
                batch_size)  # [batch_size*n_target, y_dim] x2

            log_ps = MultivariateNormal(
                mu_y, torch.diag_embed(var_y)).log_prob(y_target.float())

            # Calculate the loss function.
            loss = -torch.mean(log_ps)
            self.losslogger = loss

            if epoch % print_freq == 0:
                print('Epoch {:.0f}: Loss = {:.5f}'.format(epoch, loss.item()))

                if VERBOSE:
                    metrics_calculator(self, 'cnp', x, y, x_test, y_test,
                                       dataname, epoch, x_scaler, y_scaler)

            loss.backward()
            self.optimiser.step()
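

# A minimal usage sketch for CNP, added for illustration: fits the model to
# a toy family of randomly shifted 1-D sine functions. The sizes and data
# are assumptions; batch_sampler is the project-local helper imported above,
# whose exact context/target split is also an assumption here.
def _demo_cnp():
    n_functions, n_points = 50, 30
    x_fns = [torch.linspace(-3, 3, n_points).unsqueeze(-1)
             for _ in range(n_functions)]
    y_fns = [torch.sin(x + torch.randn(1)) for x in x_fns]

    model = CNP(x_dim=1, y_dim=1, r_dim=32,
                encoder_dims=[64, 64], decoder_dims=[64, 64])
    # VERBOSE=False skips the external metrics_calculator call.
    model.train(x_fns, y_fns, batch_size=10, epochs=500, print_freq=100,
                VERBOSE=False)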