Example #1
File: mpn.py Project: yxgu2353/chemprop
    def forward(
            self,
            batch: Union[List[str], List[Chem.Mol], BatchMolGraph],
            features_batch: List[np.ndarray] = None,
            atom_descriptors_batch: List[np.ndarray] = None
    ) -> torch.FloatTensor:
        """
        Encodes a batch of molecules.

        :param batch: A list of SMILES, a list of RDKit molecules, or a
                      :class:`~chemprop.features.featurization.BatchMolGraph`.
        :param features_batch: A list of numpy arrays containing additional features.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
        :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
        """
        if type(batch) != BatchMolGraph:
            if self.atom_descriptors == 'feature':
                batch = mol2graph(batch, atom_descriptors_batch)
            else:
                batch = mol2graph(batch)

        if self.atom_descriptors == 'descriptor':
            output = self.encoder.forward(batch, features_batch,
                                          atom_descriptors_batch)
        else:
            output = self.encoder.forward(batch, features_batch)

        return output
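
A quick note on the input handling above: when the batch is not already a BatchMolGraph, forward() builds one with mol2graph. A minimal sketch of that conversion, assuming a chemprop 1.x install where mol2graph and BatchMolGraph are importable from chemprop.features (paths and signatures may differ in other versions):

from chemprop.features import BatchMolGraph, mol2graph

smiles = ['CCO', 'c1ccccc1']
graph = mol2graph(smiles)                # the same conversion forward() performs for SMILES/Mol input
assert isinstance(graph, BatchMolGraph)  # passing `graph` to forward() therefore skips re-featurization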
Example #2
File: mpn.py Project: z-linlinlin/chemprop
    def forward(
            self,
            batch: Union[List[List[str]], List[List[Chem.Mol]], BatchMolGraph],
            features_batch: List[np.ndarray] = None,
            atom_descriptors_batch: List[np.ndarray] = None
    ) -> torch.FloatTensor:
        """
        Encodes a batch of molecules.

        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
                      :class:`~chemprop.features.featurization.BatchMolGraph`.
        :param features_batch: A list of numpy arrays containing additional features.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
        :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
        """
        if type(batch[0]) != BatchMolGraph:
            # TODO: handle atom_descriptors_batch with multiple molecules per input
            if self.atom_descriptors == 'feature':
                if len(batch[0]) > 1:
                    raise NotImplementedError(
                        'Atom descriptors are currently only supported with one molecule '
                        'per input (i.e., number_of_molecules = 1).')

                batch = [mol2graph(b, atom_descriptors_batch) for b in batch]
            else:
                batch = [mol2graph(b) for b in batch]

        if self.use_input_features:
            features_batch = torch.from_numpy(
                np.stack(features_batch)).float().to(self.device)

            if self.features_only:
                return features_batch

        if self.atom_descriptors == 'descriptor':
            if len(batch) > 1:
                raise NotImplementedError(
                    'Atom descriptors are currently only supported with one molecule '
                    'per input (i.e., number_of_molecules = 1).')

            encodings = [
                enc(ba, atom_descriptors_batch)
                for enc, ba in zip(self.encoder, batch)
            ]
        else:
            encodings = [enc(ba) for enc, ba in zip(self.encoder, batch)]

        output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)

        if self.use_input_features:
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(1, -1)

            output = torch.cat([output, features_batch], dim=1)

        return output
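
The final reduce/torch.cat step joins one encoding per molecule position along the feature dimension. A toy, self-contained PyTorch illustration with hypothetical sizes (not chemprop code):

from functools import reduce

import torch

# e.g. two MPN encoders (number_of_molecules = 2), a batch of 8 datapoints, hidden_size = 300
encodings = [torch.randn(8, 300), torch.randn(8, 300)]
output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)
print(output.shape)  # torch.Size([8, 600]) -> (num_datapoints, number_of_molecules * hidden_size)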
Example #3
    def forward(self,
                batch: Union[List[str], BatchMolGraph],
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular SMILES strings.

        :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input is True).
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        print(len(batch))
        if not self.graph_input:  # if features only, batch won't even be used
            g_batch = mol2graph(batch, self.args)
        else:
            g_batch = batch  # batch is already a BatchMolGraph

        output = self.encoder.forward(g_batch, features_batch)
        if True:  # always append per-element atom-count features scaled against delaney_atoms_info
            batch_len = len(batch)
            padding = torch.zeros((batch_len, 60))
            for i in range(batch_len):
                mol = Chem.MolFromSmiles(batch[i])
                for atom_id in range(0, 60):
                    if delaney_atoms_info[atom_id] == 0:
                        continue
                    at = Chem.MolFromSmarts('[#{}]'.format(atom_id + 1))
                    padding[i, atom_id] = 2 * float(len(mol.GetSubstructMatches(at))) / \
                                              float(delaney_atoms_info[atom_id]) - 1.0

            output = torch.cat((output, padding), dim=1)

        return output
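
The padding loop above counts atoms of each element with a SMARTS query and rescales the count to roughly [-1, 1] via 2 * count / total - 1. A self-contained RDKit sketch of that computation; the delaney_atoms_info values below are made-up stand-ins for the module-level table used in that project:

from rdkit import Chem

delaney_atoms_info = {5: 40, 7: 12}  # hypothetical maxima: index 5 -> [#6] (carbon), index 7 -> [#8] (oxygen)
mol = Chem.MolFromSmiles('CCO')
for atom_id, total in delaney_atoms_info.items():
    at = Chem.MolFromSmarts('[#{}]'.format(atom_id + 1))
    count = len(mol.GetSubstructMatches(at))
    feature = 2.0 * count / total - 1.0  # same normalization as in the snippet above
    print(atom_id + 1, count, round(feature, 3))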
Example #4
    def forward(self, batch: Union[List[str], BatchMolGraph],
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        if not self.graph_input:  # if features only, batch won't even be used
            batch = mol2graph(batch, self.args)

        output = self.encoder.forward(batch, features_batch)

        return output
Example #5
    def viz_attention(self,
                      viz_dir: str,
                      batch: Union[List[str], BatchMolGraph],
                      features_batch: List[np.ndarray] = None):
        """
        Visualizes attention weights for a batch of molecular SMILES strings

        :param viz_dir: Directory in which to save visualized attention weights.
        :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input).
        :param features_batch: A list of ndarrays containing additional features.
        """
        if not self.graph_input:
            batch = mol2graph(batch, self.args)

        self.encoder.forward(batch, features_batch, viz_dir=viz_dir)
Example #6
    def forward(self,
                batch: Union[List[str], BatchMolGraph],
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular SMILES strings.

        :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input is True).
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if not self.graph_input:  # if features only, batch won't even be used
            batch = mol2graph(batch, self.args)
        output = self.encoder.forward(batch, features_batch)  # batch is molecular graph

        return output
Example #7
    def forward(self,
                batch: Union[List[str], List[Chem.Mol], BatchMolGraph],
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular SMILES strings.

        :param batch: A list of SMILES strings, a list of RDKit molecules, or a BatchMolGraph.
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if type(batch) != BatchMolGraph:
            batch = mol2graph(batch)

        output = self.encoder.forward(batch, features_batch)

        return output
Example #8
def predict(model: nn.Module,
            data: MoleculeDataset,
            args: Namespace,
            scaler: StandardScaler = None,
            bert_save_memory: bool = False,
            logger: logging.Logger = None) -> List[List[float]]:
    """
    Makes predictions on a dataset using a model.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param args: Arguments.
    :param scaler: A StandardScaler object fit on the training targets.
    :param bert_save_memory: Store unused predictions as None to avoid unnecessary memory use.
    :param logger: Logger.
    :return: A list of lists of predictions. The outer list is over examples
    and the inner list is over tasks.
    """
    model.eval()

    preds = []
    if args.dataset_type == 'bert_pretraining':
        features_preds = []

    if args.maml:
        num_iters, iter_step = data.num_tasks() * args.maml_batches_per_epoch, 1
        full_targets = []
    else:
        num_iters, iter_step = len(data), args.batch_size
    
    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph, args=(batch_queue, data, args, num_iters, iter_step, exit_queue, True))
        batch_process.start()
        currently_loaded_batches = []

    for i in trange(0, num_iters, iter_step):
        if args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args, seed=0)
            mol_batch = task_test_data
            smiles_batch, features_batch, targets_batch = task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            targets = torch.Tensor(targets_batch).unsqueeze(1)
            if args.cuda:
                targets = targets.cuda()
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(0)
            else:
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features()

        # Run model
        if args.dataset_type == 'bert_pretraining':
            batch = mol2graph(smiles_batch, args)
            batch.bert_mask(mol_batch.mask())
        else:
            batch = smiles_batch
        
        if args.maml:  # TODO refactor with train loop
            model.zero_grad()
            intermediate_preds = model(batch, features_batch)
            loss = get_loss_func(args)(intermediate_preds, targets)
            loss = loss.sum() / len(batch)
            grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
            theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
            theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
            for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
                theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, targets_batch = task_test_data.smiles(), task_test_data.features(), task_test_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            with torch.no_grad():
                batch_preds = model_prime(smiles_batch, features_batch)
            full_targets.extend([[t] for t in targets_batch])
        else:
            with torch.no_grad():
                if args.parallel_featurization:
                    previous_graph_input_mode = model.encoder.graph_input
                    model.encoder.graph_input = True  # force model to accept already processed input
                    batch_preds = model(featurized_mol_batch, features_batch)
                    model.encoder.graph_input = previous_graph_input_mode
                else:
                    batch_preds = model(batch, features_batch)

                if args.dataset_type == 'bert_pretraining':
                    if batch_preds['features'] is not None:
                        features_preds.extend(batch_preds['features'].data.cpu().numpy())
                    batch_preds = batch_preds['vocab']
                
                if args.dataset_type == 'kernel':
                    batch_preds = batch_preds.view(int(batch_preds.size(0)/2), 2, batch_preds.size(1))
                    batch_preds = model.kernel_output_layer(batch_preds)

        batch_preds = batch_preds.data.cpu().numpy()

        if scaler is not None:
            batch_preds = scaler.inverse_transform(batch_preds)
        
        if args.dataset_type == 'regression_with_binning':
            batch_preds = batch_preds.reshape((batch_preds.shape[0], args.num_tasks, args.num_bins))
            indices = np.argmax(batch_preds, axis=2)
            preds.extend(indices.tolist())
        else:
            batch_preds = batch_preds.tolist()
            if args.dataset_type == 'bert_pretraining' and bert_save_memory:
                for atom_idx, mask_val in enumerate(mol_batch.mask()):
                    if mask_val != 0:
                        batch_preds[atom_idx] = None  # not going to predict, so save some memory when passing around
            preds.extend(batch_preds)
    
    if args.dataset_type == 'regression_with_binning':
        preds = args.bin_predictions[np.array(preds)].tolist()

    if args.dataset_type == 'bert_pretraining':
        preds = {
            'features': features_preds if len(features_preds) > 0 else None,
            'vocab': preds
        }

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    if args.maml:
        # return the task targets here to guarantee alignment;
        # there's probably no reasonable scenario where we'd use MAML directly to predict something that's actually unknown
        return preds, full_targets

    return preds
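
For the regression_with_binning branch, the post-processing takes an argmax over bins and then maps the bin indices back to values by indexing into args.bin_predictions. A toy numpy illustration with hypothetical bin centers (not chemprop code):

import numpy as np

num_molecules, num_tasks, num_bins = 3, 2, 4
batch_preds = np.random.rand(num_molecules, num_tasks * num_bins)       # stand-in for raw model output
batch_preds = batch_preds.reshape((num_molecules, num_tasks, num_bins))
indices = np.argmax(batch_preds, axis=2)                                 # (num_molecules, num_tasks) bin indices
bin_predictions = np.linspace(-2.0, 2.0, num_bins)                       # hypothetical bin centers
print(bin_predictions[indices])                                          # predicted value per molecule and task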
Example #9
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
    data must be a list of (chunk_path, memo_path) tuples.
    :param val_smiles: Validation SMILES strings without targets.
    :param test_smiles: Test SMILES strings without targets, used in the adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch \
            else len(data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = \
                task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
            theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
            theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
            for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
                theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(target_batch['features']) \
                    if target_batch['features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(args.class_weights[task_num][targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    # for some reason cross entropy doesn't support multi target
                    loss += loss_func(preds[:, task, :], targets[:, task]) * class_weights[:, task] * mask[:, task]
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                                / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(), task_test_data.features(), \
                [t[task_idx] for t in task_test_data.targets()]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(smiles_batch)  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr)
                                for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(
                loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug(
                    "D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(
                        d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
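
The MAML branches above build adapted parameters theta_prime = theta - maml_lr * grad with torch.autograd.grad, leaving the base model untouched until the outer update. A minimal, self-contained PyTorch sketch of that inner step on a toy linear model (not chemprop code; maml_lr = 0.01 is arbitrary):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = ((model(x) - y) ** 2).mean()

grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
theta_prime = {name: p - 0.01 * g for (name, p), g in zip(theta, grad)}  # adapted weights for the inner loop
print({name: tuple(p.shape) for name, p in theta_prime.items()})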
Example #10
    def forward(self,
                batch: Union[List[List[str]], List[List[Chem.Mol]], List[List[Tuple[Chem.Mol, Chem.Mol]]], List[BatchMolGraph]],
                features_batch: List[np.ndarray] = None,
                atom_descriptors_batch: List[np.ndarray] = None,
                atom_features_batch: List[np.ndarray] = None,
                bond_features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecules.

        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
                      list of :class:`~chemprop.features.featurization.BatchMolGraph`.
                      The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
                      the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
        :param features_batch: A list of numpy arrays containing additional features.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
        :param atom_features_batch: A list of numpy arrays containing additional atom features.
        :param bond_features_batch: A list of numpy arrays containing additional bond features.
        :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
        """
        if type(batch[0]) != BatchMolGraph:
            # Group first molecules, second molecules, etc for mol2graph
            batch = [[mols[i] for mols in batch] for i in range(len(batch[0]))]

            # TODO: handle atom_descriptors_batch with multiple molecules per input
            if self.atom_descriptors == 'feature':
                if len(batch) > 1:
                    raise NotImplementedError('Atom/bond descriptors are currently only supported with one molecule '
                                              'per input (i.e., number_of_molecules = 1).')

                batch = [
                    mol2graph(
                        mols=b,
                        atom_features_batch=atom_features_batch,
                        bond_features_batch=bond_features_batch,
                        overwrite_default_atom_features=self.overwrite_default_atom_features,
                        overwrite_default_bond_features=self.overwrite_default_bond_features
                    )
                    for b in batch
                ]
            elif bond_features_batch is not None:
                if len(batch) > 1:
                    raise NotImplementedError('Atom/bond descriptors are currently only supported with one molecule '
                                              'per input (i.e., number_of_molecules = 1).')

                batch = [
                    mol2graph(
                        mols=b,
                        bond_features_batch=bond_features_batch,
                        overwrite_default_atom_features=self.overwrite_default_atom_features,
                        overwrite_default_bond_features=self.overwrite_default_bond_features
                    )
                    for b in batch
                ]
            else:
                batch = [mol2graph(b) for b in batch]

        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float().to(self.device)

            if self.features_only:
                return features_batch

        if self.atom_descriptors == 'descriptor':
            if len(batch) > 1:
                raise NotImplementedError('Atom descriptors are currently only supported with one molecule '
                                          'per input (i.e., number_of_molecules = 1).')

            encodings = [enc(ba, atom_descriptors_batch) for enc, ba in zip(self.encoder, batch)]
        else:
            if not self.reaction_solvent:
                encodings = [enc(ba) for enc, ba in zip(self.encoder, batch)]
            else:
                encodings = []
                for ba in batch:
                    if ba.is_reaction:
                        encodings.append(self.encoder(ba))
                    else:
                        encodings.append(self.encoder_solvent(ba))

        output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)

        if self.use_input_features:
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(1, -1)

            output = torch.cat([output, features_batch], dim=1)

        return output
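
The regrouping near the top of this forward pass transposes the batch from [datapoint][molecule] to [molecule position][datapoint], so each mol2graph call receives all first molecules, then all second molecules, and so on. A toy illustration with placeholder SMILES:

batch = [['CCO', 'O'], ['c1ccccc1', 'N'], ['CC', 'Cl']]  # 3 datapoints, 2 molecules per datapoint
regrouped = [[mols[i] for mols in batch] for i in range(len(batch[0]))]
print(regrouped)  # [['CCO', 'c1ccccc1', 'CC'], ['O', 'N', 'Cl']]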