Example #1
 def __init__(self,
              get_atomic_attributes,
              node_attributes,
              filename,
              cols_to_read,
              delimiter=',',
              get_bond_attributes=None,
              edge_attributes=None):
     super(GraphDataset, self).__init__()
     assert (get_bond_attributes is None) == (edge_attributes is None)
     data_set = read_smiles_property_file(filename, cols_to_read, delimiter)
     data = data_set[0]
     target = data_set[1:]
     clean_smiles, clean_idx = sanitize_smiles(data)
     target = np.array(target).T
     max_size = 0
     for sm in clean_smiles:
         mol = Chem.MolFromSmiles(sm)
         if mol.GetNumAtoms() > max_size:
             max_size = mol.GetNumAtoms()
     self.target = target[clean_idx, :]
     self.graphs = []
     self.node_feature_matrix = []
     self.adj_matrix = []
     for sm in clean_smiles:
         graph = Graph(sm, max_size, get_atomic_attributes,
                       get_bond_attributes)
         self.node_feature_matrix.append(
             graph.get_node_feature_matrix(node_attributes, max_size))
         if get_bond_attributes is None:
             self.adj_matrix.append(graph.adj_matrix)
         else:
             self.adj_matrix.append(
                 graph.get_edge_attr_adj_matrix(edge_attributes, max_size))
     self.num_features = self.node_feature_matrix[0].shape[1]
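A minimal usage sketch for this constructor. The file name, column indices, and the atom featurizer below are illustrative assumptions, not values taken from the example; the schema expected in node_attributes is consumed by Graph.get_node_feature_matrix and is not shown here.

# Hypothetical featurizer: maps an RDKit atom to a dict of attribute values.
def get_atomic_attributes(atom):
    return {'atomic_num': atom.GetAtomicNum(),
            'valence': atom.GetTotalValence()}

node_attributes = {}  # placeholder: the real schema depends on the Graph helper

dataset = GraphDataset(get_atomic_attributes, node_attributes,
                       filename='molecules.csv',  # assumed: SMILES in column 0, label in column 1
                       cols_to_read=[0, 1])
print(dataset.num_features, dataset.target.shape)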
Example #2
 def __init__(self,
              filename,
              cols_to_read,
              delimiter=',',
              tokens=None,
              pad=True,
              tokenize=True,
              augment=False,
              flip=True):
     super(SmilesDataset, self).__init__()
     self.tokenize = tokenize
     data = read_smiles_property_file(filename, cols_to_read, delimiter)
     smiles = data[0]
     clean_smiles, clean_idx = sanitize_smiles(smiles)
     if len(data) > 1:
         target = np.array(data[1:], dtype='float').T
         self.target = target[clean_idx]
     else:
         self.target = None
     if augment:
         clean_smiles, self.target = augment_smiles(clean_smiles,
                                                    self.target)
     if pad:
         clean_smiles, self.length = pad_sequences(clean_smiles)
     tokens, self.token2idx, self.num_tokens = get_tokens(
         clean_smiles, tokens)
     if tokenize:
         clean_smiles, self.tokens = seq2tensor(clean_smiles, tokens, flip)
     self.data = clean_smiles
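A usage sketch under the same input assumptions: a delimited file with SMILES strings in the first requested column and one numeric property after it. The path is hypothetical.

dataset = SmilesDataset(filename='train.csv',  # hypothetical file
                        cols_to_read=[0, 1],
                        delimiter=',',
                        pad=True,
                        tokenize=True)
# dataset.data holds the padded, tokenized sequences, dataset.target the property
# values aligned with the sanitized SMILES, dataset.token2idx the vocabulary from get_tokens.
print(dataset.num_tokens, len(dataset.data), dataset.target.shape)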
Example #3
 def __init__(self, filename, cols_to_read, features, delimiter=',', tokens=None):
     super(VanillaDataset, self).__init__()
     data = read_smiles_property_file(filename, cols_to_read, delimiter)
     smiles = data[0]
     target = np.array(data[1], dtype='float')
     clean_smiles, clean_idx = sanitize_smiles(smiles)
     self.target = target[clean_idx]
Example #4
    def from_smiles_file(cls, get_atomic_attributes, node_attributes, filename,
                         cols_to_read, delimiter=',', get_bond_attributes=None,
                         edge_attributes=None):
        data_set = read_smiles_property_file(filename, cols_to_read, delimiter)
        data = data_set[0]
        target = np.array(data_set[1:]).squeeze()

        clean_smiles, clean_idx = sanitize_smiles(data)
        clean_mols = [Chem.MolFromSmiles(smiles) for smiles in clean_smiles]

        clean_target = target[clean_idx]

        return cls(get_atomic_attributes, node_attributes, clean_mols, clean_target,
                   get_bond_attributes, edge_attributes)
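A sketch of calling this alternative constructor, assuming it is defined on a GraphDataset-like class and reusing the hypothetical featurizer and attribute schema from the sketch under Example #1; the file is again an assumption.

dataset = GraphDataset.from_smiles_file(get_atomic_attributes,     # hypothetical featurizer
                                        node_attributes,           # placeholder attribute schema
                                        filename='molecules.csv',  # assumed CSV
                                        cols_to_read=[0, 1])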
Example #5
 def __init__(self,
              filename,
              tokenized=False,
              cols_to_read=None,
              delimiter=',',
              mol_tokens=None,
              prot_tokens=None,
              pad=True):
     super(SmilesProteinDataset, self).__init__()
     if not tokenized:
         data = read_smiles_property_file(filename, cols_to_read, delimiter)
         smiles = data[0]
         proteins = np.array(data[1])
         target = np.array(data[2], dtype='float')
         clean_smiles, clean_idx = sanitize_smiles(smiles)
         self.target = target[clean_idx]
         proteins = list(proteins[clean_idx])
         if pad:
             clean_smiles, self.mol_lengths = pad_sequences(clean_smiles)
             proteins, self.prot_lengths = pad_sequences(proteins)
         self.mol_tokens, self.mol_token2idx, self.mol_num_tokens = \
             get_tokens(clean_smiles, mol_tokens)
         self.prot_tokens, self.prot_token2idx, self.prot_num_tokens = \
             get_tokens(proteins, prot_tokens)
         # seq2tensor returns (tensor, tokens); only the encoded sequences are kept here
         clean_smiles, _ = seq2tensor(clean_smiles, self.mol_tokens)
         proteins, _ = seq2tensor(proteins, self.prot_tokens)
         self.molecules = clean_smiles
         self.proteins = proteins
     else:
         with open(filename, 'rb') as f:
             data = pickle.load(f)
         self.mol_tokens = data['smiles_tokens']
         self.prot_tokens = data['proteins_tokens']
         self.mol_num_tokens = len(data['smiles_tokens'])
         self.prot_num_tokens = len(data['proteins_tokens'])
         self.molecules = data['smiles']
         self.proteins = data['proteins']
         self.target = data['labels']
     assert len(self.molecules) == len(self.proteins)
     assert len(self.molecules) == len(self.target)
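With tokenized=True the constructor only reads a pickle that already contains the keys accessed above. A sketch of writing such a file; the path and all contents are placeholders.

import pickle

data = {
    'smiles_tokens': ['C', 'N', 'O', '(', ')', '=', '1', ' '],  # placeholder vocabulary
    'proteins_tokens': ['A', 'C', 'D', 'E', 'G', ' '],          # placeholder vocabulary
    'smiles': [],    # pre-encoded molecule sequences
    'proteins': [],  # pre-encoded protein sequences, same length as 'smiles'
    'labels': [],    # targets, same length as 'smiles'
}
with open('smiles_protein_tokenized.pkl', 'wb') as f:  # hypothetical path
    pickle.dump(data, f)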
Example #6
    def __init__(self,
                 get_atomic_attributes,
                 node_attributes,
                 filename,
                 cols_to_read,
                 delimiter=',',
                 get_bond_attributes=None,
                 edge_attributes=None,
                 restrict_min_atoms=-1,
                 restrict_max_atoms=-1,
                 kekulize=True,
                 file_format="smi",
                 addHs=False,
                 has_3D=False,
                 allowed_atoms=None,
                 return_smiles=False,
                 **kwargs):
        super(GraphDataset, self).__init__()
        assert (get_bond_attributes is None) == (edge_attributes is None)
        self.return_smiles = return_smiles
        self.restrict_min_atoms = restrict_min_atoms
        self.restrict_max_atoms = restrict_max_atoms
        self.kekulize = kekulize
        self.addHs = addHs
        self.has_3D = has_3D

        if file_format == "pickled":
            with open(filename, "rb") as f:
                data = pickle.load(f)

            # this cleanup must be consistent with sanitize_smiles
            mn, mx = restrict_min_atoms, restrict_max_atoms
            indices = [
                i for i, n in enumerate(data["num_atoms_all"])
                if (n >= mn or mn < 0) and (n <= mx or mx < 0)
            ]
            data = {
                key: value[indices] if isinstance(value, np.ndarray) else
                [value[i] for i in indices]
                for key, value in data.items()
            }

            self.num_atoms_all = data["num_atoms_all"]
            self.target = data["target"]
            self.smiles = data["smiles"]
        elif file_format == "smi":
            data_set = read_smiles_property_file(filename, cols_to_read,
                                                 delimiter)
            data = data_set[0]
            if len(cols_to_read) == 1:
                target = None
            else:
                target = data_set[1:]
            clean_smiles, clean_idx, num_atoms, max_len = sanitize_smiles(
                data,
                min_atoms=restrict_min_atoms,
                max_atoms=restrict_max_atoms,
                return_num_atoms=True,
                return_max_len=True)
            self.max_len = max_len
            if target is not None:
                target = np.asarray(target, dtype=float).T
            clean_smiles = [clean_smiles[i] for i in clean_idx]
            num_atoms = [num_atoms[i] for i in clean_idx]
            self.clean_idx = clean_idx
            if target is not None:
                self.target = target[clean_idx, :]
            else:
                self.target = None
            self.smiles = clean_smiles
            self.num_atoms_all = num_atoms
        else:
            raise NotImplementedError()

        self.max_size = max(self.num_atoms_all)
        self.node_attributes = node_attributes
        self.edge_attributes = edge_attributes
        self.get_atomic_attributes = get_atomic_attributes
        self.get_bond_attributes = get_bond_attributes
Example #7
    def __init__(self,
                 get_atomic_attributes,
                 node_attributes,
                 filename,
                 cols_to_read,
                 delimiter=',',
                 get_bond_attributes=None,
                 edge_attributes=None,
                 restrict_min_atoms=-1,
                 restrict_max_atoms=-1,
                 kekulize=True,
                 file_format="smi",
                 addHs=False,
                 has_3D=False,
                 allowed_atoms=None,
                 return_smiles=False,
                 **kwargs):
        super(GraphDataset, self).__init__()
        assert (get_bond_attributes is None) == (edge_attributes is None)
        self.return_smiles = return_smiles
        self.restrict_min_atoms = restrict_min_atoms
        self.restrict_max_atoms = restrict_max_atoms
        self.kekulize = kekulize
        self.addHs = addHs
        self.has_3D = has_3D

        if file_format == "pickled":
            with open(kwargs["pickled"], "rb") as f:
                data = pickle.load(f)

            # this cleanup must be consistent with sanitize_smiles
            mn, mx = restrict_min_atoms, restrict_max_atoms
            indices = [
                i for i, n in enumerate(data["num_atoms_all"])
                if (n >= mn or mn < 0) and (n <= mx or mx < 0)
            ]
            data = {
                key: value[indices] if isinstance(value, np.ndarray) else
                [value[i] for i in indices]
                for key, value in data.items()
            }

            self.num_atoms_all = data["num_atoms_all"]
            self.target = data["target"]
            self.smiles = data["smiles"]
        elif file_format == "smi":
            data_set = read_smiles_property_file(filename, cols_to_read,
                                                 delimiter)
            data = data_set[0]
            if len(cols_to_read) == 1:
                target = None
            else:
                target = data_set[1:]
            clean_smiles, clean_idx, num_atoms, max_len = sanitize_smiles(
                data,
                min_atoms=restrict_min_atoms,
                max_atoms=restrict_max_atoms,
                return_num_atoms=True,
                return_max_len=True)
            self.max_len = max_len
            if target is not None:
                target = np.asarray(target, dtype=float).T
            clean_smiles = [clean_smiles[i] for i in clean_idx]
            num_atoms = [num_atoms[i] for i in clean_idx]
            self.clean_idx = clean_idx
            if target is not None:
                self.target = target[clean_idx, :]
            else:
                self.target = None
            self.smiles = clean_smiles
            self.num_atoms_all = num_atoms
        elif file_format == "sdf":
            filenames = []
            os.chdir("/home/Work/data/enamine_hll-500/")
            for file in glob.glob("*.sdf"):
                filenames.append(file)
            self.num_atoms_all = []
            smiles = []
            rd_mols = []
            for f in [filenames[10]]:
                print(f)
                supplier = Chem.SDMolSupplier(f, False, False)
                n = len(supplier)
                for i in range(n):
                    mol = supplier[i]
                    anum = [(a.GetAtomicNum() in allowed_atoms.keys())
                            for a in mol.GetAtoms()]
                    if sum(anum) == len(anum):
                        # use a separate name so the supplier length computed above is not clobbered
                        num_atoms = mol.GetNumAtoms()
                        x_coord = []
                        y_coord = []
                        z_coord = []
                        for k in range(num_atoms):
                            pos = mol.GetConformer().GetAtomPosition(k)
                            x_coord.append(pos.x)
                            y_coord.append(pos.y)
                            z_coord.append(pos.z)
                        if np.linalg.norm(z_coord, ord=2) > 1.0:
                            rd_mols.append(mol)
                            smiles.append(Chem.MolToSmiles(mol))
                            self.num_atoms_all.append(num_atoms)
            self.smiles = smiles
            self.rd_mols = rd_mols
            self.target = np.ones(len(self.smiles))
        else:
            raise NotImplementedError()

        self.max_size = max(self.num_atoms_all)
        self.node_attributes = node_attributes
        self.edge_attributes = edge_attributes
        self.get_atomic_attributes = get_atomic_attributes
        self.get_bond_attributes = get_bond_attributes
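With file_format="pickled" the constructor expects a dictionary holding at least "num_atoms_all", "target" and "smiles". A sketch of building such a file from a SMILES list; the molecules, labels, and path are placeholders.

import pickle

import numpy as np
from rdkit import Chem

smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']  # placeholder molecules
num_atoms_all = [Chem.MolFromSmiles(s).GetNumAtoms() for s in smiles]
data = {
    'smiles': smiles,
    'num_atoms_all': num_atoms_all,
    'target': np.zeros((len(smiles), 1)),  # placeholder labels
}
with open('graphs.pkl', 'wb') as f:  # hypothetical path
    pickle.dump(data, f)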
Example #8
    def forward(self, input, target=None):

        log_policy = input["log_policy"]
        sizes = input["sizes"]
        trajectories = input["smiles"]
        adj = input["adj"]
        classes = input["classes"]

        device = log_policy.device
        len_trajectory = max(sizes)
        batch_size = len(sizes)

        if self.critic is not None:
            # Current convention is to run critic only on valid molecules
            # Others receive zero reward from the critic

            clean_smiles, clean_idx = sanitize_smiles(trajectories, allowed_tokens=self.tokens, logging="none")
            clean_smiles = [clean_smiles[i] for i in clean_idx]

            with torch.no_grad():
                clean_rewards = self.reward_fn(clean_smiles, self.critic, self.tokens, device, self.fn)

            rewards = torch.zeros((batch_size, *clean_rewards.shape[1:]),
                                  dtype=clean_rewards.dtype,
                                  device=clean_rewards.device)
            clean_idx = torch.tensor(clean_idx, device=clean_rewards.device)
            rewards.index_copy_(0, clean_idx, clean_rewards)

            rewards = rewards.view(batch_size, 1)
            discounts = torch.pow(self.gamma, torch.arange(len_trajectory, device=device, dtype=torch.float))
            discounts = discounts.view(1, len_trajectory)
            discounted_rewards = rewards * discounts

            discounted_rewards = pack_padded_sequence(discounted_rewards, sizes, batch_first=True).data

        sanitize_smiles(trajectories, allowed_tokens=self.tokens, logging="info")

        if self.max_atom_bonds is not None:

            structure_reward = torch.zeros((batch_size, len_trajectory),
                                           dtype=log_policy.dtype,
                                           device=log_policy.device)
            for i in range(batch_size):
                atom_bonds = torch.from_numpy(adj[i]).sum(dim=0)
                cl = torch.cat([torch.tensor([0], dtype=torch.long), classes[i]])
                max_atom_bonds = torch.tensor(self.max_atom_bonds)
                max_atom_bonds = max_atom_bonds[cl]

                # structure_reward[i, :sizes[i]] = \
                #     (atom_bonds <= max_atom_bonds).to(
                #         dtype=torch.float, device=device)
                structure_reward[i, :sizes[i]] = \
                    -15. * (atom_bonds > max_atom_bonds).to(
                        dtype=torch.float, device=device)

            structure_reward = pack_padded_sequence(structure_reward, sizes, batch_first=True).data

            if self.critic is not None:
                discounted_rewards += structure_reward
            else:
                discounted_rewards = structure_reward

        loss = -discounted_rewards * log_policy
        loss = loss.mean()
        if self.enable_supervised_loss:
            loss = loss + input["loss"]
        return loss
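The reward shaping above multiplies one scalar reward per trajectory by gamma^t and then packs the padded matrix the same way log_policy is packed. A standalone sketch of that arithmetic; gamma, the sizes, and the rewards are made-up values.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

gamma = 0.97
sizes = [5, 3]                          # trajectory lengths, longest first
rewards = torch.tensor([[1.0], [0.0]])  # one terminal reward per trajectory
discounts = torch.pow(gamma, torch.arange(max(sizes), dtype=torch.float)).view(1, -1)
discounted_rewards = rewards * discounts  # shape (batch_size, len_trajectory)
packed = pack_padded_sequence(discounted_rewards, sizes, batch_first=True).data
print(discounted_rewards)
print(packed.shape)  # flattened like the packed log_policy: (sum(sizes),)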
Example #9
def main():
    parser = argparse.ArgumentParser(description='Experiment parameters')
    parser.add_argument("--use_cuda",
                        default=torch.cuda.is_available(),
                        help="Whether to train on GPU")
    parser.add_argument("--config_file",
                        required=True,
                        help="Path to the configuration file")
    parser.add_argument(
        "--mode",
        default='train',
        help="Could be \"train\", \"eval\", \"train_eval\", \"predict\"")
    parser.add_argument('--continue_learning',
                        dest='continue_learning',
                        action='store_true',
                        help="whether to continue learning")
    parser.add_argument("--force_checkpoint",
                        dest="force_checkpoint",
                        default="",
                        help="Full path to a pretrained snapshot "
                        "(e.g. useful for knowledge transfer or)")
    parser.add_argument('--dist-backend',
                        default='nccl',
                        type=str,
                        help='distributed backend')
    parser.add_argument('--seed',
                        default=None,
                        type=int,
                        help='seed for initializing training')
    parser.add_argument('--workers',
                        default=0,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--random_seed',
                        default=0,
                        type=int,
                        metavar='N',
                        help='random_seed (default: 0)')
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--copy_config_file",
                        action="store_true",
                        help="Copy config file to logdir (useful in training)")

    args, unknown = parser.parse_known_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) \
        if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method='env://')
        print('Distributed process with rank {:d} initialized'.format(
            args.local_rank))

    cudnn.benchmark = True

    if args.mode not in ['train', 'eval', 'train_eval', 'infer', 'predict']:
        raise ValueError("Mode has to be one of "
                         "['train', 'eval', 'train_eval', 'infer', 'predict']")
    config_module = runpy.run_path(args.config_file)

    model_config = config_module.get('model_params', None)
    if model_config is None:
        raise ValueError('model_params dictionary has to be '
                         'defined in the config file')
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    model_config['use_cuda'] = args.use_cuda
    model_object = config_module.get('model', None)
    if model_object is None:
        raise ValueError('model class has to be defined in the config file')

    # After the config is read, try to overwrite some of its properties
    # with command line arguments that were passed to the script.
    parser_unk = argparse.ArgumentParser()
    for pm, value in flatten_dict(model_config).items():
        if type(value) == int or type(value) == float or \
                isinstance(value, string_types):
            parser_unk.add_argument('--' + pm, default=value, type=type(value))
        elif type(value) == bool:
            parser_unk.add_argument('--' + pm,
                                    default=value,
                                    type=ast.literal_eval)

    config_update = parser_unk.parse_args(unknown)
    nested_update(model_config, nest_dict(vars(config_update)))

    # checking that everything is correct with log directory
    logdir = model_config['logdir']
    ckpt_dir = os.path.join(logdir, 'checkpoint')

    if args.force_checkpoint:
        assert not args.continue_learning, \
            "force_checkpoint and continue_learning are " \
            "mutually exclusive flags"
        checkpoint = args.force_checkpoint
        assert os.path.isfile(checkpoint), "{} is not a file".format(
            checkpoint)
        cur_epoch = 0
    elif args.mode in ['eval', 'infer', 'predict'] or args.continue_learning:
        checkpoint = get_latest_checkpoint(ckpt_dir)
        if checkpoint is None:
            raise IOError("Failed to find model checkpoint under "
                          "{}. Can't load the model".format(ckpt_dir))
        cur_epoch = int(os.path.basename(checkpoint).split("_")[-1]) + 1
    else:
        checkpoint = None
        cur_epoch = 0

    if not os.path.exists(logdir):
        comm.mkdir(logdir)
        print('Directory {} created'.format(logdir))
    elif os.path.isfile(logdir):
        raise IOError("There is a file with the same name as \"logdir\" "
                      "parameter. You should change the log directory path "
                      "or delete the file to continue.")

    if not os.path.exists(ckpt_dir):
        comm.mkdir(ckpt_dir)
        print('Directory {} created'.format(ckpt_dir))
    elif os.path.isdir(ckpt_dir) and os.listdir(ckpt_dir) != []:
        if not args.continue_learning and args.mode not in [
                'eval', 'infer', 'predict'
        ]:
            raise IOError("Log directory is not empty. If you want to "
                          "continue learning, you should provide "
                          "\"--continue_learning\" flag")

    doprint = comm.is_main_process()
    tofile = os.path.join(logdir, "log.txt")
    logger = setup_textlogger("openchem", doprint, tofile)
    msg = "Running with config:\n"
    for k, v in sorted(flatten_dict(model_config).items()):
        msg += ("{}:\t{}\n".format(k, v)).expandtabs(50)
    logger.info("Running on {:d} GPUs".format(comm.get_world_size()))
    logger.info("Logging directory is set to {}".format(logdir))
    logger.info(msg)
    if args.copy_config_file:
        shutil.copy(args.config_file, logdir)

    train_config = copy.deepcopy(model_config)
    eval_config = copy.deepcopy(model_config)

    if args.mode == 'train' or args.mode == 'train_eval':
        if 'train_params' in config_module:
            nested_update(train_config,
                          copy.deepcopy(config_module['train_params']))
    if args.mode in ['eval', 'train_eval', 'infer', 'predict']:
        if 'eval_params' in config_module:
            nested_update(eval_config,
                          copy.deepcopy(config_module['eval_params']))

    if args.mode == "train" or args.mode == "train_eval":
        train_dataset = copy.deepcopy(model_config['train_data_layer'])
        if model_config['task'] == 'classification':
            train_dataset.target = train_dataset.target.reshape(-1)
        if args.distributed:
            train_sampler = DistributedSampler(train_dataset)
        else:
            train_sampler = None
        train_loader = create_loader(train_dataset,
                                     batch_size=model_config['batch_size'],
                                     shuffle=(train_sampler is None),
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=train_sampler)
    else:
        train_loader = None

    if args.mode == "predict" and (
            'predict_data_layer' not in model_config.keys()
            or model_config['predict_data_layer'] is None):
        raise IOError("When model is run in 'predict' mode, "
                      "prediction data layer must be specified")

    if args.mode == "predict":
        predict_dataset = copy.deepcopy(model_config['predict_data_layer'])
        predict_loader = create_loader(predict_dataset,
                                       batch_size=model_config['batch_size'],
                                       shuffle=False,
                                       num_workers=1,
                                       pin_memory=True)
    else:
        predict_loader = None

    if args.mode in ["eval", "train_eval"
                     ] and ('val_data_layer' not in model_config.keys()
                            or model_config['val_data_layer'] is None):
        raise IOError("When model is run in 'eval' or 'train_eval' modes, "
                      "validation data layer must be specified")

    if args.mode in ["eval", "train_eval"]:
        val_dataset = copy.deepcopy(model_config['val_data_layer'])
        if model_config['task'] == 'classification':
            val_dataset.target = val_dataset.target.reshape(-1)
        val_loader = create_loader(val_dataset,
                                   batch_size=model_config['batch_size'],
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True)
    else:
        val_loader = None

    model_config['train_loader'] = train_loader
    model_config['val_loader'] = val_loader
    model_config['predict_loader'] = predict_loader

    # create model
    model = model_object(params=model_config)

    if args.use_cuda:
        model = model.to('cuda')

    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    else:
        model = DataParallel(model)

    if checkpoint is not None:
        logger.info("Loading model from {}".format(checkpoint))
        weights = torch.load(checkpoint, map_location=torch.device("cpu"))
        model.load_state_dict(weights)
    else:
        logger.info("Starting training from scratch")
    if args.mode in ["train", "train_eval"]:
        logger.info("Training is set up from epoch {:d}".format(cur_epoch))

    criterion, optimizer, lr_scheduler = build_training(model, model_config)

    if args.mode == 'train':
        fit(model,
            lr_scheduler,
            train_loader,
            optimizer,
            criterion,
            model_config,
            eval=False,
            cur_epoch=cur_epoch)
    elif args.mode == 'train_eval':
        fit(model,
            lr_scheduler,
            train_loader,
            optimizer,
            criterion,
            model_config,
            eval=True,
            val_loader=val_loader,
            cur_epoch=cur_epoch)
    elif args.mode == "eval":
        evaluate(model, val_loader, criterion)
    elif args.mode == "predict":
        predict(model, predict_loader)
    elif args.mode == "infer":
        comm.synchronize()
        start_time = time.time()

        #if comm.get_world_size() > 1:
        #    seed = comm.get_rank() * 10000
        #    random.seed(seed)
        #    np.random.seed(seed)
        #    torch.manual_seed(seed)
        #    torch.cuda.manual_seed_all(seed)

        model.eval()
        smiles = []

        with torch.no_grad():
            for i in range(1):
                batch_smiles = model(None, batch_size=1024)
                smiles.extend(batch_smiles)
                print("Iteration {:d}: {:d} smiles".format(
                    i + 1, len(batch_smiles)))

        if comm.get_world_size() > 1:
            path = os.path.join(
                logdir, "debug_smiles_{:d}.txt".format(comm.get_rank()))
            with open(path, "w") as f:
                for s in smiles:
                    f.write(s + "\n")

            comm.synchronize()

            if not comm.is_main_process():
                return

            smiles = []
            for i in range(comm.get_world_size()):
                path = os.path.join(logdir, "debug_smiles_{:d}.txt".format(i))
                with open(path) as f:
                    smiles_local = f.readlines()
                os.remove(path)

                smiles_local = [s.rstrip() for s in smiles_local]
                smiles.extend(smiles_local)

        path = os.path.join(logdir, "debug_smiles.txt")
        with open(path, "w") as f:
            for s in smiles:
                f.write(s + "\n")

        print("Generated {:d} molecules in {:.1f} seconds".format(
            len(smiles),
            time.time() - start_time))

        eval_metrics = model_config['eval_metrics']
        score = eval_metrics(None, smiles)
        qed_score = metrics.qed(smiles)
        logger.info("Eval metrics = {:.2f}".format(score))
        logger.info("QED score = {:.2f}".format(qed_score))

        smiles, idx = sanitize_smiles(smiles, logging="info")
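The launcher loads the config file with runpy and expects it to define model and model_params (plus optional train_params / eval_params). A skeleton of such a file with placeholder values, listing only keys that main() actually reads; a real config would supply an OpenChem model class and dataset instances.

# my_config.py -- skeleton consumed via runpy.run_path; every value below is a placeholder
model = None  # must be a model class; main() instantiates it as model_object(params=model_config)

model_params = {
    'task': 'classification',
    'logdir': './logs/demo',        # checkpoints and log.txt are written here
    'batch_size': 128,
    'train_data_layer': None,       # a dataset instance in a real config
    'val_data_layer': None,         # required for 'eval' / 'train_eval'
    'eval_metrics': None,           # callable used in 'infer' mode
}

train_params = {}  # optional overrides merged in 'train' / 'train_eval' mode
eval_params = {}   # optional overrides merged in eval-like modes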