def forward(self,
            batch: Union[List[str], List[Chem.Mol], BatchMolGraph],
            features_batch: List[np.ndarray] = None,
            atom_descriptors_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    """
    Encodes a batch of molecules.

    :param batch: A list of SMILES, a list of RDKit molecules, or a
                  :class:`~chemprop.features.featurization.BatchMolGraph`.
    :param features_batch: A list of numpy arrays containing additional features.
    :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
    :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
    """
    if type(batch) != BatchMolGraph:
        if self.atom_descriptors == 'feature':
            batch = mol2graph(batch, atom_descriptors_batch)
        else:
            batch = mol2graph(batch)

    if self.atom_descriptors == 'descriptor':
        output = self.encoder.forward(batch, features_batch, atom_descriptors_batch)
    else:
        output = self.encoder.forward(batch, features_batch)

    return output

def forward(self,
            batch: Union[List[List[str]], List[List[Chem.Mol]], BatchMolGraph],
            features_batch: List[np.ndarray] = None,
            atom_descriptors_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    """
    Encodes a batch of molecules.

    :param batch: A list of lists of SMILES, a list of lists of RDKit molecules, or a
                  :class:`~chemprop.features.featurization.BatchMolGraph`.
    :param features_batch: A list of numpy arrays containing additional features.
    :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
    :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
    """
    if type(batch[0]) != BatchMolGraph:
        # TODO: handle atom_descriptors_batch with multiple molecules per input
        if self.atom_descriptors == 'feature':
            if len(batch[0]) > 1:
                raise NotImplementedError('Atom descriptors are currently only supported with one molecule '
                                          'per input (i.e., number_of_molecules = 1).')

            batch = [mol2graph(b, atom_descriptors_batch) for b in batch]
        else:
            batch = [mol2graph(b) for b in batch]

    if self.use_input_features:
        features_batch = torch.from_numpy(np.stack(features_batch)).float().to(self.device)

        if self.features_only:
            return features_batch

    if self.atom_descriptors == 'descriptor':
        if len(batch) > 1:
            raise NotImplementedError('Atom descriptors are currently only supported with one molecule '
                                      'per input (i.e., number_of_molecules = 1).')

        encodings = [enc(ba, atom_descriptors_batch) for enc, ba in zip(self.encoder, batch)]
    else:
        encodings = [enc(ba) for enc, ba in zip(self.encoder, batch)]

    output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)

    if self.use_input_features:
        if len(features_batch.shape) == 1:
            features_batch = features_batch.view(1, -1)

        output = torch.cat([output, features_batch], dim=1)

    return output

def forward(self,
            batch: Union[List[str], BatchMolGraph],
            features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    """
    Encodes a batch of molecular SMILES strings.

    :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input is True).
    :param features_batch: A list of ndarrays containing additional features.
    :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
    """
    if not self.graph_input:  # if features only, batch won't even be used
        g_batch = mol2graph(batch, self.args)
    else:
        g_batch = batch  # batch is already a BatchMolGraph

    output = self.encoder.forward(g_batch, features_batch)

    # Append a per-molecule element-count feature vector: for each of the first 60 atomic numbers,
    # count the matching atoms and rescale the count to roughly [-1, 1] using the dataset-wide
    # statistics in delaney_atoms_info.
    batch_len = len(batch)
    padding = torch.zeros((batch_len, 60))
    for i in range(batch_len):
        mol = Chem.MolFromSmiles(batch[i])
        for atom_id in range(0, 60):
            if delaney_atoms_info[atom_id] == 0:
                continue
            at = Chem.MolFromSmarts('[#{}]'.format(atom_id + 1))
            padding[i, atom_id] = 2 * float(len(mol.GetSubstructMatches(at))) / \
                float(delaney_atoms_info[atom_id]) - 1.0
    output = torch.cat((output, padding), dim=1)

    return output

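# Illustration of the element-count feature appended in the variant above: for each atomic number,
# atoms are counted with a SMARTS query such as '[#6]' and the count is rescaled to roughly [-1, 1]
# using a dataset-level normalizer (delaney_atoms_info in the code above). This is a standalone
# sketch; 'max_counts' below is a hypothetical stand-in for that normalizer, not part of the code above.
from rdkit import Chem

mol = Chem.MolFromSmiles('CCO')
carbon = Chem.MolFromSmarts('[#6]')               # SMARTS query for atomic number 6 (carbon)
n_carbons = len(mol.GetSubstructMatches(carbon))  # 2 matches for ethanol
max_counts = {6: 10}                              # hypothetical per-element normalizer
scaled = 2 * n_carbons / max_counts[6] - 1.0      # maps counts in [0, max] to [-1, 1]
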
def forward(self,
            batch: Union[List[str], BatchMolGraph],
            features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    if not self.graph_input:  # if features only, batch won't even be used
        batch = mol2graph(batch, self.args)

    output = self.encoder.forward(batch, features_batch)

    return output

def viz_attention(self,
                  viz_dir: str,
                  batch: Union[List[str], BatchMolGraph],
                  features_batch: List[np.ndarray] = None):
    """
    Visualizes attention weights for a batch of molecular SMILES strings.

    :param viz_dir: Directory in which to save visualized attention weights.
    :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input).
    :param features_batch: A list of ndarrays containing additional features.
    """
    if not self.graph_input:
        batch = mol2graph(batch, self.args)

    self.encoder.forward(batch, features_batch, viz_dir=viz_dir)

def forward(self,
            batch: Union[List[str], BatchMolGraph],
            features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    """
    Encodes a batch of molecular SMILES strings.

    :param batch: A list of SMILES strings or a BatchMolGraph (if self.graph_input is True).
    :param features_batch: A list of ndarrays containing additional features.
    :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
    """
    if not self.graph_input:  # if features only, batch won't even be used
        batch = mol2graph(batch, self.args)

    output = self.encoder.forward(batch, features_batch)  # batch is molecular graph

    return output

def forward(self,
            batch: Union[List[str], List[Chem.Mol], BatchMolGraph],
            features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    """
    Encodes a batch of molecular SMILES strings.

    :param batch: A list of SMILES strings, a list of RDKit molecules, or a BatchMolGraph.
    :param features_batch: A list of ndarrays containing additional features.
    :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
    """
    if type(batch) != BatchMolGraph:
        batch = mol2graph(batch)

    output = self.encoder.forward(batch, features_batch)

    return output

def predict(model: nn.Module,
            data: MoleculeDataset,
            args: Namespace,
            scaler: StandardScaler = None,
            bert_save_memory: bool = False,
            logger: logging.Logger = None) -> List[List[float]]:
    """
    Makes predictions on a dataset using a model.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param args: Arguments.
    :param scaler: A StandardScaler object fit on the training targets.
    :param bert_save_memory: Store unused predictions as None to avoid unnecessary memory use.
    :param logger: Logger.
    :return: A list of lists of predictions. The outer list is examples while the inner list is tasks.
    """
    model.eval()

    preds = []
    if args.dataset_type == 'bert_pretraining':
        features_preds = []

    if args.maml:
        num_iters, iter_step = data.num_tasks() * args.maml_batches_per_epoch, 1
        full_targets = []
    else:
        num_iters, iter_step = len(data), args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters, iter_step, exit_queue, True))
        batch_process.start()
        currently_loaded_batches = []

    for i in trange(0, num_iters, iter_step):
        if args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args, seed=0)
            mol_batch = task_test_data
            smiles_batch, features_batch, targets_batch = task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            targets = torch.Tensor(targets_batch).unsqueeze(1)
            if args.cuda:
                targets = targets.cuda()
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(0)
            else:
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features()

        # Run model
        if args.dataset_type == 'bert_pretraining':
            batch = mol2graph(smiles_batch, args)
            batch.bert_mask(mol_batch.mask())
        else:
            batch = smiles_batch

        if args.maml:  # TODO: refactor with train loop
            model.zero_grad()
            intermediate_preds = model(batch, features_batch)
            loss = get_loss_func(args)(intermediate_preds, targets)
            loss = loss.sum() / len(batch)
            grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
            theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
            theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
            for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
                theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, targets_batch = task_test_data.smiles(), task_test_data.features(), task_test_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            with torch.no_grad():
                batch_preds = model_prime(smiles_batch, features_batch)
            full_targets.extend([[t] for t in targets_batch])
        else:
            with torch.no_grad():
                if args.parallel_featurization:
                    previous_graph_input_mode = model.encoder.graph_input
                    model.encoder.graph_input = True  # force model to accept already processed input
                    batch_preds = model(featurized_mol_batch, features_batch)
                    model.encoder.graph_input = previous_graph_input_mode
                else:
                    batch_preds = model(batch, features_batch)

            if args.dataset_type == 'bert_pretraining':
                if batch_preds['features'] is not None:
                    features_preds.extend(batch_preds['features'].data.cpu().numpy())
                batch_preds = batch_preds['vocab']

            if args.dataset_type == 'kernel':
                batch_preds = batch_preds.view(int(batch_preds.size(0) / 2), 2, batch_preds.size(1))
                batch_preds = model.kernel_output_layer(batch_preds)

        batch_preds = batch_preds.data.cpu().numpy()

        # Inverse scale if regression
        if scaler is not None:
            batch_preds = scaler.inverse_transform(batch_preds)

        if args.dataset_type == 'regression_with_binning':
            batch_preds = batch_preds.reshape((batch_preds.shape[0], args.num_tasks, args.num_bins))
            indices = np.argmax(batch_preds, axis=2)
            preds.extend(indices.tolist())
        else:
            batch_preds = batch_preds.tolist()
            if args.dataset_type == 'bert_pretraining' and bert_save_memory:
                for atom_idx, mask_val in enumerate(mol_batch.mask()):
                    if mask_val != 0:
                        batch_preds[atom_idx] = None  # not going to predict, so save some memory when passing around
            preds.extend(batch_preds)

    if args.dataset_type == 'regression_with_binning':
        preds = args.bin_predictions[np.array(preds)].tolist()

    if args.dataset_type == 'bert_pretraining':
        preds = {
            'features': features_preds if len(features_preds) > 0 else None,
            'vocab': preds
        }

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    if args.maml:
        # Return the task targets here to guarantee alignment;
        # there's probably no reasonable scenario where we'd use MAML directly to predict something that's actually unknown.
        return preds, full_targets

    return preds

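# Illustration of the 'regression_with_binning' post-processing in predict() above: per-task bin scores
# are reshaped to (num_molecules, num_tasks, num_bins), the highest-scoring bin index is taken per task,
# and indices are later mapped back to values via args.bin_predictions. Standalone sketch with toy shapes;
# the linspace below is a hypothetical stand-in for args.bin_predictions.
import numpy as np

num_molecules, num_tasks, num_bins = 2, 3, 4
batch_preds = np.random.rand(num_molecules, num_tasks * num_bins)
batch_preds = batch_preds.reshape((num_molecules, num_tasks, num_bins))
indices = np.argmax(batch_preds, axis=2)             # (num_molecules, num_tasks) bin indices
bin_predictions = np.linspace(-1.0, 1.0, num_bins)   # hypothetical bin-center values
values = bin_predictions[indices]                    # real-valued predictions per molecule and task
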
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case, data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH, f, protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0

    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch else len(data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters, args.batch_size, exit_queue, args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(train_smiles)  # want to recompute every batch
            mol_batch = [MoleculeDataset(d[i:i + args.batch_size]) for d in data]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()

            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
            theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
            theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
            for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
                theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(target_batch['features']) if target_batch['features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(args.class_weights[task_num][targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)

            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(preds[:, task, :], targets[:, task]) * class_weights[:, task] * mask[:, task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2, preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask

                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight - 1)) \
                           / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight - 1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(), task_test_data.features(), [t[task_idx] for t in task_test_data.targets()]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(smiles_batch)  # TODO: check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            if args.adjust_weight_decay:
                current_pnorm = compute_pnorm(model)
                if current_pnorm < args.pnorm_target:
                    for i in range(len(optimizer.param_groups)):
                        optimizer.param_groups[i]['weight_decay'] = max(0, optimizer.param_groups[i]['weight_decay'] - args.adjust_weight_decay_step)
                else:
                    for i in range(len(optimizer.param_groups)):
                        optimizer.param_groups[i]['weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles, args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch, test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles, args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr) for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug("D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter

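# Illustration of the masked, class-weighted loss reduction used in train() above: targets that are
# None are zero-filled, their mask entries are 0, and the summed loss is normalized by the number of
# observed entries. Standalone sketch for a toy binary-classification batch; BCEWithLogitsLoss with
# reduction='none' is one plausible choice of loss_func, not necessarily the one used by this code.
import torch
import torch.nn as nn

loss_func = nn.BCEWithLogitsLoss(reduction='none')
preds = torch.randn(4, 2)                                          # num_molecules x num_tasks logits
target_batch = [[1.0, None], [0.0, 1.0], [None, 0.0], [1.0, 1.0]]  # missing targets as None
mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
class_weights = torch.ones(targets.shape)                          # uniform weights (no class balancing)
loss = loss_func(preds, targets) * class_weights * mask
loss = loss.sum() / mask.sum()
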
def forward(self,
            batch: Union[List[List[str]], List[List[Chem.Mol]], List[List[Tuple[Chem.Mol, Chem.Mol]]], List[BatchMolGraph]],
            features_batch: List[np.ndarray] = None,
            atom_descriptors_batch: List[np.ndarray] = None,
            atom_features_batch: List[np.ndarray] = None,
            bond_features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
    """
    Encodes a batch of molecules.

    :param batch: A list of lists of SMILES, a list of lists of RDKit molecules, or a
                  list of :class:`~chemprop.features.featurization.BatchMolGraph`.
                  The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
                  the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
    :param features_batch: A list of numpy arrays containing additional features.
    :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
    :param atom_features_batch: A list of numpy arrays containing additional atom features.
    :param bond_features_batch: A list of numpy arrays containing additional bond features.
    :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
    """
    if type(batch[0]) != BatchMolGraph:
        # Group first molecules, second molecules, etc. for mol2graph
        batch = [[mols[i] for mols in batch] for i in range(len(batch[0]))]

        # TODO: handle atom_descriptors_batch with multiple molecules per input
        if self.atom_descriptors == 'feature':
            if len(batch) > 1:
                raise NotImplementedError('Atom/bond descriptors are currently only supported with one molecule '
                                          'per input (i.e., number_of_molecules = 1).')

            batch = [
                mol2graph(
                    mols=b,
                    atom_features_batch=atom_features_batch,
                    bond_features_batch=bond_features_batch,
                    overwrite_default_atom_features=self.overwrite_default_atom_features,
                    overwrite_default_bond_features=self.overwrite_default_bond_features
                )
                for b in batch
            ]
        elif bond_features_batch is not None:
            if len(batch) > 1:
                raise NotImplementedError('Atom/bond descriptors are currently only supported with one molecule '
                                          'per input (i.e., number_of_molecules = 1).')

            batch = [
                mol2graph(
                    mols=b,
                    bond_features_batch=bond_features_batch,
                    overwrite_default_atom_features=self.overwrite_default_atom_features,
                    overwrite_default_bond_features=self.overwrite_default_bond_features
                )
                for b in batch
            ]
        else:
            batch = [mol2graph(b) for b in batch]

    if self.use_input_features:
        features_batch = torch.from_numpy(np.stack(features_batch)).float().to(self.device)

        if self.features_only:
            return features_batch

    if self.atom_descriptors == 'descriptor':
        if len(batch) > 1:
            raise NotImplementedError('Atom descriptors are currently only supported with one molecule '
                                      'per input (i.e., number_of_molecules = 1).')

        encodings = [enc(ba, atom_descriptors_batch) for enc, ba in zip(self.encoder, batch)]
    else:
        if not self.reaction_solvent:
            encodings = [enc(ba) for enc, ba in zip(self.encoder, batch)]
        else:
            encodings = []
            for ba in batch:
                if ba.is_reaction:
                    encodings.append(self.encoder(ba))
                else:
                    encodings.append(self.encoder_solvent(ba))

    output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)

    if self.use_input_features:
        if len(features_batch.shape) == 1:
            features_batch = features_batch.view(1, -1)

        output = torch.cat([output, features_batch], dim=1)

    return output

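# Illustration of the per-component concatenation at the end of the multi-molecule forward above:
# each encoder produces a (num_datapoints, hidden_size) tensor, one per molecule in the datapoint,
# and the components are concatenated along dim 1. Standalone sketch with random tensors standing
# in for the real encoder outputs.
from functools import reduce
import torch

num_datapoints, hidden_size, number_of_molecules = 4, 8, 2
encodings = [torch.randn(num_datapoints, hidden_size) for _ in range(number_of_molecules)]
output = reduce(lambda x, y: torch.cat((x, y), dim=1), encodings)
assert output.shape == (num_datapoints, hidden_size * number_of_molecules)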