Example no. 1
import os
import numpy as np
import torch

# read_from_json and get_model are helpers from the surrounding script.
def export_model(args):
    # Restore the stored training arguments and rebuild the architecture;
    # the zero atomref placeholder covers atomic numbers up to 99.
    jsonpath = os.path.join(args.modelpath, "args.json")
    train_args = read_from_json(jsonpath)
    model = get_model(train_args, atomref=np.zeros((100, 1)))
    # Load the best weights, then pickle the complete model object.
    model.load_state_dict(torch.load(os.path.join(args.modelpath, "best_model")))
    torch.save(model, args.destpath)
Example no. 2
# Variant without the atomref placeholder; imports as in Example no. 1.
def export_model(args):
    jsonpath = os.path.join(args.modelpath, 'args.json')
    train_args = read_from_json(jsonpath)
    model = get_model(train_args)
    model.load_state_dict(
        torch.load(os.path.join(args.modelpath, 'best_model')))

    torch.save(model, args.destpath)
Example no. 3
# Variant that also takes basis/statistics arguments; map_location='cpu'
# lets a GPU-trained checkpoint be restored on a CPU-only machine.
def export_model(args, basisdef, orbital_energies, mean, stddev):
    jsonpath = os.path.join(args.modelpath, 'args.json')
    train_args = read_from_json(jsonpath)
    model = get_model(train_args, basisdef, orbital_energies, mean, stddev,
                      False)
    model.load_state_dict(
        torch.load(os.path.join(args.modelpath, 'best_model'),
                   map_location='cpu'))

    torch.save(model, args.destpath)
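
All three variants follow the same export pattern: read the stored args.json, rebuild the architecture, restore the best_model weights, and pickle the complete model to args.destpath. Below is a minimal sketch of a command-line driver for the first two variants; the argparse setup is an assumption, and only the attribute names modelpath and destpath are taken from the snippets above.

import argparse

parser = argparse.ArgumentParser(description='Export a trained model')
parser.add_argument('modelpath', help='directory with args.json and best_model')
parser.add_argument('destpath', help='output path for the exported model')
export_model(parser.parse_args())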
Example no. 4
def load_model(modelpath, cuda=True):
    """
    Load a trained model from its directory and prepare it for simulations with ASE.

    Args:
        modelpath (str): Path to model directory.
        cuda (bool): Use cuda (default=True).

    Returns:
        object: Model class specified in molecular_dynamics. Contains the model, model type and device.
    """
    # Load stored arguments
    argspath = os.path.join(modelpath, 'args.json')
    args = read_from_json(argspath)

    # Reconstruct model based on arguments
    if args.model == 'schnet':
        representation = SchNet(args.features, args.features, args.interactions,
                                args.cutoff, args.num_gaussians)
        atomwise_output = Energy(args.features, return_force=True, create_graph=True)
    elif args.model == 'wacsf':
        # Build HDNN model with weighted or standard Behler symmetry functions
        mode = ('weighted', 'Behler')[args.behler]
        # Convert element symbols to atomic numbers
        elements = frozenset(atomic_numbers[i] for i in sorted(args.elements))
        representation = BehlerSFBlock(args.radial, args.angular, zetas=set(args.zetas), cutoff_radius=args.cutoff,
                                       centered=args.centered, crossterms=args.crossterms, mode=mode,
                                       elements=elements)
        representation = StandardizeSF(representation, cuda=args.cuda)
        atomwise_output = ElementalEnergy(representation.n_symfuncs, n_hidden=args.n_nodes, n_layers=args.n_layers,
                                          return_force=True, create_graph=True, elements=elements)
    else:
        raise ValueError(f'Unknown model class: {args.model}')

    model = AtomisticModel(representation, atomwise_output)

    # Load trained parameters; map_location lets a GPU-trained model be
    # restored on a CPU-only machine.
    device = torch.device("cuda" if cuda else "cpu")
    model.load_state_dict(
        torch.load(os.path.join(modelpath, 'best_model'),
                   map_location=device))
    model = model.to(device)

    # Store into model wrapper for calculator
    ml_model = Model(model, args.model, device)

    return ml_model
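
A hedged usage sketch for load_model follows; the directory path is hypothetical, and requesting CUDA only when a GPU is actually present avoids a failed device transfer.

# The wrapper bundles the restored AtomisticModel, the model type string
# ('schnet' or 'wacsf'), and the torch device for use by the ASE calculator.
ml_model = load_model('/path/to/modeldir', cuda=torch.cuda.is_available())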
Example no. 5
def main(args):
    # store (or load) the arguments used for training
    argparse_dict = vars(args)
    jsonpath = os.path.join(args.modelpath, 'args.json')

    if args.mode == 'train':
        if args.overwrite and os.path.exists(args.modelpath):
            rmtree(args.modelpath)
            logging.info('existing model will be overwritten...')

        if not os.path.exists(args.modelpath):
            os.makedirs(args.modelpath)

        to_json(jsonpath, argparse_dict)

        spk.utils.set_random_seed(args.seed)
        train_args = args
    else:
        train_args = read_from_json(jsonpath)

    # will download QM9 if necessary; collect_triples is required for wACSF angular functions
    logging.info('QM9 will be loaded...')
    qm9 = QM9(args.datapath,
              download=True,
              properties=[train_args.property],
              collect_triples=args.model == 'wacsf',
              remove_uncharacterized=train_args.remove_uncharacterized)
    atomref = qm9.get_atomref(train_args.property)

    # split the dataset into train, validation, and test sets
    split_path = os.path.join(args.modelpath, 'split.npz')
    if args.mode == 'train':
        if args.split_path is not None:
            copyfile(args.split_path, split_path)
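
The snippet ends after copying a pre-existing split file into the model directory. For reference, here is a sketch of what such a split.npz file holds: three index arrays, one per subset. The key names follow SchNetPack's split files but may differ between versions, so treat them as an assumption. This toy example writes a split for a 100-point dataset.

import numpy as np

idx = np.random.permutation(100)
np.savez('split.npz',
         train_idx=idx[:80], val_idx=idx[80:90], test_idx=idx[90:])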
Example no. 6
def main(args):
    # set device (cpu or gpu)
    device = torch.device('cuda' if args.cuda else 'cpu')

    # store (or load) arguments
    argparse_dict = vars(args)
    jsonpath = os.path.join(args.modelpath, 'args.json')

    if args.mode == 'train':
        # overwrite existing model if desired
        if args.overwrite and os.path.exists(args.modelpath):
            rmtree(args.modelpath)
            logging.info('existing model will be overwritten...')

        # create model directory if it does not exist
        if not os.path.exists(args.modelpath):
            os.makedirs(args.modelpath)

        # get latest checkpoint of pre-trained model if a path was provided
        if args.pretrained_path is not None:
            model_chkpt_path = os.path.join(args.modelpath, 'checkpoints')
            pretrained_chkpt_path = os.path.join(args.pretrained_path,
                                                 'checkpoints')
            if os.path.exists(model_chkpt_path) \
                    and len(os.listdir(model_chkpt_path)) > 0:
                logging.info(
                    f'found existing checkpoints in model directory '
                    f'({model_chkpt_path}), please use --overwrite or choose '
                    f'empty model directory to start from a pre-trained '
                    f'model...')
                logging.warning(
                    f'will ignore pre-trained model and start from latest '
                    f'checkpoint at {model_chkpt_path}...')
                args.pretrained_path = None
            else:
                logging.info(
                    f'fetching latest checkpoint from pre-trained model at '
                    f'{pretrained_chkpt_path}...')
                if not os.path.exists(pretrained_chkpt_path):
                    logging.warning(
                        f'did not find checkpoints of pre-trained model, '
                        f'will train from scratch...')
                    args.pretrained_path = None
                else:
                    chkpt_files = [
                        f for f in os.listdir(pretrained_chkpt_path)
                        if f.startswith("checkpoint")
                    ]
                    if len(chkpt_files) == 0:
                        logging.warning(
                            f'did not find checkpoints of pre-trained '
                            f'model, will train from scratch...')
                        args.pretrained_path = None
                    else:
                        epoch = max([
                            int(f.split(".")[0].split("-")[-1])
                            for f in chkpt_files
                        ])
                        chkpt = os.path.join(
                            pretrained_chkpt_path,
                            "checkpoint-" + str(epoch) + ".pth.tar")
                        if not os.path.exists(model_chkpt_path):
                            os.makedirs(model_chkpt_path)
                        copyfile(
                            chkpt,
                            os.path.join(model_chkpt_path,
                                         f'checkpoint-{epoch}.pth.tar'))

        # store arguments for training in model directory
        to_json(jsonpath, argparse_dict)
        train_args = args

        # set seed
        spk.utils.set_random_seed(args.seed)
    else:
        # load arguments used for training from model directory
        train_args = read_from_json(jsonpath)

    # load data for training/evaluation
    if args.mode in ['train', 'eval']:
        # find correct data class
        assert train_args.dataset_name in dataset_name_to_class_mapping, \
            f'Could not find data class for dataset {train_args.dataset_name}. ' \
            f'Please specify a correct dataset name!'
        dataclass = dataset_name_to_class_mapping[train_args.dataset_name]

        # load the dataset
        logging.info(f'{train_args.dataset_name} will be loaded...')
        subset = None
        if train_args.subset_path is not None:
            logging.info(f'Using subset from {train_args.subset_path}')
            subset = np.load(train_args.subset_path)
            subset = [int(i) for i in subset]
        if issubclass(dataclass, DownloadableAtomsData):
            data = dataclass(args.datapath,
                             subset=subset,
                             precompute_distances=args.precompute_distances,
                             download=(args.mode == 'train'))
        else:
            data = dataclass(args.datapath,
                             subset=subset,
                             precompute_distances=args.precompute_distances)

        # split the dataset into train, validation, and test sets
        split_path = os.path.join(args.modelpath, 'split.npz')
        if args.mode == 'train':
            if args.split_path is not None:
                copyfile(args.split_path, split_path)

        logging.info('create splits...')
        data_train, data_val, data_test = data.create_splits(
            *train_args.split, split_file=split_path)

        logging.info('load data...')
        types = sorted(dataclass.available_atom_types)
        max_type = types[-1]
        # set up collate function according to args
        collate = lambda x: \
            collate_atoms(x,
                          all_types=types + [max_type+1],
                          start_token=max_type+2,
                          draw_samples=args.draw_random_samples,
                          label_width_scaling=train_args.label_width_factor,
                          max_dist=train_args.max_distance,
                          n_bins=train_args.num_distance_bins)

        train_loader = spk.data.AtomsLoader(data_train,
                                            batch_size=args.batch_size,
                                            sampler=RandomSampler(data_train),
                                            num_workers=4,
                                            pin_memory=True,
                                            collate_fn=collate)
        val_loader = spk.data.AtomsLoader(data_val,
                                          batch_size=args.batch_size,
                                          num_workers=2,
                                          pin_memory=True,
                                          collate_fn=collate)

    # construct the model
    if args.mode == 'train' or args.checkpoint >= 0:
        model = get_model(train_args, parallelize=args.parallel)
    logging.info(f'running on {device}')

    # load model or checkpoint for evaluation or generation
    if args.mode in ['eval', 'generate']:
        if args.checkpoint < 0:  # load best model
            logging.info('restoring best model')
            model = torch.load(os.path.join(args.modelpath,
                                            'best_model')).to(device)
        else:
            logging.info(f'restoring checkpoint {args.checkpoint}')
            chkpt = os.path.join(
                args.modelpath, 'checkpoints',
                'checkpoint-' + str(args.checkpoint) + '.pth.tar')
            state_dict = torch.load(chkpt)
            model.load_state_dict(state_dict['model'], strict=True)

    # execute training, evaluation, or generation
    if args.mode == 'train':
        logging.info("training...")
        train(args, model, train_loader, val_loader, device)
        logging.info("...training done!")

    elif args.mode == 'eval':
        logging.info("evaluating...")
        test_loader = spk.data.AtomsLoader(data_test,
                                           batch_size=args.batch_size,
                                           num_workers=2,
                                           pin_memory=True,
                                           collate_fn=collate)
        with torch.no_grad():
            evaluate(args, model, train_loader, val_loader, test_loader,
                     device)
        logging.info("... done!")

    elif args.mode == 'generate':
        logging.info(f'generating {args.amount_gen} molecules...')
        generated = generate(args, train_args, model, device)
        gen_path = os.path.join(args.modelpath, 'generated/')
        if not os.path.exists(gen_path):
            os.makedirs(gen_path)
        # get untaken filename and store results
        file_name = os.path.join(gen_path, args.file_name)
        if os.path.isfile(file_name + '.mol_dict'):
            expand = 0
            while True:
                expand += 1
                new_file_name = f'{file_name}_{expand}'
                if not os.path.isfile(new_file_name + '.mol_dict'):
                    file_name = new_file_name
                    break
        with open(file_name + '.mol_dict', 'wb') as f:
            pickle.dump(generated, f)
        logging.info('...done!')
    else:
        logging.error(f'Unknown mode: {args.mode}')
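
A minimal entry point for a script like this might look as follows. This is a sketch only: the real script defines many more flags (seed, batch_size, overwrite, pretrained_path, and so on) than are shown here.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=['train', 'eval', 'generate'])
    parser.add_argument('datapath')
    parser.add_argument('modelpath')
    parser.add_argument('--cuda', action='store_true')
    # further arguments omitted; see the flags accessed in main() above
    main(parser.parse_args())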