Exemple #1
0
    def embed(self, loader, df):
        """Return the latent embeddings of the molecules listed in ``df``.

        Args:
            loader: Loader object used to convert SMILES into batches of
                model inputs. ``loader.dataset.pass_dataset`` and
                ``loader.get_data`` are called on it.
            df: Dataframe with a 'can' column containing the SMILES to embed.

        Returns:
            Numpy array of embeddings, (N_molecules, latent_size).

        Note:
            The per-batch embeddings are joined with ``torch.cat``, so the
            loader is expected to yield at least one batch.
        """
        loader.dataset.pass_dataset(df)
        # get_data() returns three loaders; only the last one is iterated here.
        _, _, test_loader = loader.get_data()

        # Latent embeddings, one tensor per batch.
        z_all = []

        # Inference only: no gradients are needed.
        with torch.no_grad():
            for graph, _, _, _ in test_loader:
                graph = send_graph_to_device(graph, self.device)

                # mean_only=True is forwarded to the encoder.
                z = self.encode(graph, mean_only=True)  # z_shape = N * l_size
                # Collect on CPU so the final conversion to numpy is possible.
                z_all.append(z.cpu())

        return torch.cat(z_all, dim=0).numpy()
Exemple #2
0
        # Running count of training steps; incremented once per batch below.
        total_steps = 0
    beta = args.beta
    tf_proba = args.tf_init

    for epoch in range(1, args.epochs + 1):
        print(f'Starting epoch {epoch}')
        model.train()
        # Per-epoch loss totals, zeroed at the start of every epoch
        # (NOTE(review): accumulation presumably happens past this excerpt).
        epoch_train_rec, epoch_train_kl, epoch_train_pmse, epoch_train_amse = 0, 0, 0, 0

        for batch_idx, (graph, smiles, p_target,
                        a_target) in enumerate(train_loader):

            total_steps += 1  # count training steps

            # Move the batch to the training device.
            smiles = smiles.to(device)
            graph = send_graph_to_device(graph, device)
            if use_props:
                # Reshaped to one column per entry of `properties`.
                p_target = p_target.to(device).view(-1, len(properties))
            if use_affs:
                a_target = a_target.to(device)

            # Forward pass; the third returned value is not used here.
            mu, logv, _, out_smi, out_p, out_a = model(graph,
                                                       smiles,
                                                       tf=tf_proba)

            # Compute loss terms : change according to multitask setting
            rec, kl = VAELoss(out_smi, smiles, mu, logv)

            if not use_affs and not use_props:  # VAE only
                # Neither auxiliary target is enabled: use zero placeholders.
                pmse, amse = torch.tensor(0), torch.tensor(0)
Exemple #3
0
    def step(self, input_type, x, w, batch_size=32):
        """Train the model for ``self.n_epochs`` epochs on samples x, weighted by w.

        Args:
            input_type: 'smiles' or 'selfies'; selects which dataset method
                receives the samples (validity checks and format conversions
                are different for the two formats).
            x: Samples to train on, in the format named by ``input_type``.
            w: Weights of the samples in ``x``.
            batch_size: Mini-batch size passed to the DataLoader (default 32).
                Incomplete final batches are dropped (``drop_last=True``).

        Raises:
            ValueError: If ``input_type`` is neither 'smiles' nor 'selfies'.
        """
        # Validate first so an unknown type cannot silently train on whatever
        # the dataset already contained.
        if input_type not in ('smiles', 'selfies'):
            raise ValueError(
                f"input_type must be 'smiles' or 'selfies', got {input_type!r}")

        if input_type == 'smiles':
            self.dataset.pass_smiles_list(x, w)
        else:
            self.dataset.pass_selfies_list(x, w)

        train_loader = DataLoader(dataset=self.dataset,
                                  shuffle=True,
                                  batch_size=batch_size,
                                  num_workers=self.processes,
                                  collate_fn=collate_block,
                                  drop_last=True)
        # Training loop
        total_steps = 0
        for epoch in range(self.n_epochs):
            print(f'Starting epoch {epoch}')
            self.model.train()

            for batch_idx, (graph, smiles, w_i) in enumerate(train_loader):
                total_steps += 1  # count training steps

                smiles = smiles.to(self.device)
                graph = send_graph_to_device(graph, self.device)
                w_i = w_i.to(self.device)

                # Forward pass; only mu, logv and the decoder output are used.
                mu, logv, _, out_smi, _, _ = self.model(
                    graph, smiles, tf=self.teacher_forcing)

                # Compute CbAS loss with samples weights
                loss = CbASLoss(out_smi, smiles, mu, logv, w_i, self.beta)
                if batch_idx == 0:
                    # Predicted token index = argmax over dim 1 of the output.
                    _, out_chars = torch.max(out_smi.detach(), dim=1)
                    print(f'CbAS Loss at batch 0 : {loss.item()}')

                    # 1 where prediction and target agree, 0 otherwise
                    # (clamped); the mean is scaled by 100, so the printed
                    # value is a percentage.
                    differences = 1. - torch.abs(out_chars - smiles)
                    differences = torch.clamp(differences, min=0.,
                                              max=1.).double()
                    quality = 100. * torch.mean(differences)
                    quality = quality.detach().cpu()
                    print(
                        'fraction of correct characters at reconstruction : ',
                        quality.item())

                self.optimizer.zero_grad()
                loss.backward()
                clip.clip_grad_norm_(self.model.parameters(), self.clip_grads)
                # Drop the reference to the loss before the optimizer update.
                del loss
                self.optimizer.step()

                # Annealing LR every `anneal_iter` training steps
                if total_steps % self.anneal_iter == 0:
                    self.scheduler.step()
                    # Read the rate from the scheduler's optimizer:
                    # scheduler.get_lr() is not meant to be called outside of
                    # scheduler.step() and can report a wrong value.
                    current_lr = self.scheduler.optimizer.param_groups[0]['lr']
                    print("learning rate: %.6f" % current_lr)

                if batch_idx == 0 and self.debug:
                    # Show a few decoded reconstructions for a quick check.
                    decoded = self.model.probas_to_smiles(out_smi)
                    decoded = [decoder(s) for s in decoded]
                    print(decoded[:5])

        # Update weights at 'save_model_weights' :
        print(
            f'Finished training after {total_steps} optimizer steps. Saving search model weights'
        )
        # Move the model to CPU for the dump, then back to the working device.
        self.model.cpu()
        self.dump()
        self.model.to(self.device)