Code example #1
File: benchmark_crabnet.py, Project: kaaiian/CrabNet
def load_model(mat_prop, classification, file_name, verbose=True):
    # Load up a saved network.
    model = Model(CrabNet(compute_device=compute_device).to(compute_device),
                  model_name=f'{mat_prop}', verbose=verbose)
    model.load_network(f'{mat_prop}.pth')

    # Check if this is a classification task
    if classification:
        model.classification = True

    # Load the data you want to predict with
    data = rf'data\benchmark_data\{mat_prop}\{file_name}'
    # data is reloaded to model.data_loader
    model.load_data(data, batch_size=2**9, train=False)
    return model
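For orientation, a minimal usage sketch of load_model, assuming the rest of benchmark_crabnet.py (imports and compute_device) is in scope; the property and file names below are borrowed from the other snippets on this page.

# Hedged usage sketch; 'mp_bulk_modulus' and 'test.csv' come from other snippets here
model = load_model('mp_bulk_modulus', classification=False,
                   file_name='test.csv', verbose=True)
output = model.predict(model.data_loader)  # predictions for the loaded CSV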
Code example #2
def save_test_results(mat_prop, classification_list):
    # Load up a saved network.
    model = Model(CrabNet(compute_device=compute_device).to(compute_device))
    model.load_network(f'{mat_prop}.pth')
    if mat_prop in classification_list:
        model.classification = True
    # Load the data you want to predict with
    test_data = rf'data\benchmark_data\{mat_prop}\test.csv'
    model.load_data(test_data)  # data is reloaded to model.data_loader
    output = model.predict(model.data_loader)  # predict the data saved here
    if model.classification:
        auc = roc_auc_score(output[0], output[1])
        print(f'\n{mat_prop} ROC AUC: {auc:0.3f}')
    else:
        print(f'\n{mat_prop} mae: {abs(output[0] - output[1]).mean():0.3f}')
    # save your predictions to a csv
    save_results(output, f'{mat_prop}_output.csv')
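A brief driver sketch for save_test_results, assuming the surrounding benchmark script is importable; both property names are taken from other snippets on this page and are treated here as regression tasks.

# Hedged driver loop; property names are illustrative
classification_list = []
for mat_prop in ['mp_bulk_modulus', 'aflow__Egap']:
    save_test_results(mat_prop, classification_list)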
Code example #3
def get_model(data_dir,
              mat_prop,
              classification=False,
              batch_size=None,
              transfer=None,
              verbose=True):
    # Get the TorchedCrabNet architecture loaded
    model = Model(CrabNet(compute_device=compute_device).to(compute_device),
                  model_name=f'{mat_prop}',
                  verbose=verbose)

    # Train network starting at pretrained weights
    if transfer is not None:
        model.load_network(f'{transfer}.pth')
        model.model_name = f'{mat_prop}'

    # Apply BCEWithLogitsLoss to model output if binary classification is True
    if classification:
        model.classification = True

    # Get the datafiles you will learn from
    train_data = f'{data_dir}/{mat_prop}/train.csv'
    # Verify the expected train/val CSVs exist before loading them
    val_data = f'{data_dir}/{mat_prop}/val.csv'
    if not (os.path.exists(train_data) and os.path.exists(val_data)):
        print('Please ensure you have train (train.csv) and validation data',
              f'(val.csv) in folder "{data_dir}/{mat_prop}"')

    # Load the train and validation data before fitting the network
    # Scale the batch size with the training-set size, clamped to [2**7, 2**12],
    # unless an explicit batch_size was passed in
    data_size = pd.read_csv(train_data).shape[0]
    if batch_size is None:
        batch_size = 2**round(np.log2(data_size) - 4)
        batch_size = min(max(batch_size, 2**7), 2**12)
    model.load_data(train_data, batch_size=batch_size, train=True)
    print(f'training with batchsize {model.batch_size} '
          f'(2**{np.log2(model.batch_size):0.3f})')
    model.load_data(val_data, batch_size=batch_size)

    # Set the number of epochs, decide if you want a loss curve to be plotted
    model.fit(epochs=40, losscurve=False)

    # Save the network (saved as f"{model_name}.pth")
    model.save_network()
    return model
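To make the batch-size heuristic concrete (assuming numpy is imported as np, as in the snippet): 10,000 training rows give log2(10000) of roughly 13.29, the exponent rounds to 9, and the batch size is 2**9 = 512; a 500-row set would give 2**5 = 32, which the lower clamp raises to 2**7 = 128.

# Worked example of the heuristic above, with illustrative sizes only
data_size = 10_000
batch_size = 2**round(np.log2(data_size) - 4)   # round(13.29 - 4) = 9 -> 512
batch_size = min(max(batch_size, 2**7), 2**12)  # stays 512; 500 rows would clamp to 128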
Code example #4
File: matbench_crabnet.py, Project: kaaiian/CrabNet
def get_model(mat_prop,
              i,
              classification=False,
              batch_size=None,
              transfer=None,
              verbose=True):
    # Get the TorchedCrabNet architecture loaded
    model = Model(CrabNet(compute_device=compute_device).to(compute_device),
                  model_name=f'{mat_prop}{i}',
                  verbose=verbose)

    # Train network starting at pretrained weights
    if transfer is not None:
        model.load_network(f'{transfer}.pth')
        model.model_name = f'{mat_prop}'

    # Apply BCEWithLogitsLoss to model output if binary classification is True
    if classification:
        model.classification = True

    # Get the datafiles you will learn from
    train_data = rf'data\matbench_cv\{mat_prop}\train{i}.csv'
    val_data = rf'data\matbench_cv\{mat_prop}\val{i}.csv'

    # Load the train and validation data before fitting the network
    # Scale the batch size with the training-set size, clamped to [2**7, 2**12],
    # unless an explicit batch_size was passed in
    data_size = pd.read_csv(train_data).shape[0]
    if batch_size is None:
        batch_size = 2**round(np.log2(data_size) - 4)
        batch_size = min(max(batch_size, 2**7), 2**12)
    # batch_size = 2**7
    model.load_data(train_data, batch_size=batch_size, train=True)
    print(f'training with batchsize {model.batch_size} '
          f'(2**{np.log2(model.batch_size):0.3f})')
    model.load_data(val_data, batch_size=batch_size)

    # Set the number of epochs, decide if you want a loss curve to be plotted
    model.fit(epochs=300, losscurve=False)

    # Save the network (saved as f"{model_name}.pth")
    model.save_network()
    return model
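Because get_model exposes a transfer argument, training can warm-start from a previously saved network; a hedged sketch follows, where both property names are borrowed from other snippets on this page and the pairing is purely illustrative.

# Hypothetical transfer-learning call; starts from the saved aflow__Egap.pth weights
model = get_model('steels_yield', i=0, classification=False,
                  transfer='aflow__Egap')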
Code example #5
def model(mat_prop, classification_list, simple=False):
    # Get the TorchedCrabNet architecture loaded
    model = Model(CrabNet(compute_device=compute_device).to(compute_device),
                  model_name=f'{mat_prop}')
    model.load_network(f'{mat_prop}.pth')
    model.model_name = f'{mat_prop}'

    if mat_prop in classification_list:
        model.classification = True

    dataset = rf'{data_dir}\{mat_prop}\train.csv'
    model.load_data(dataset,
                    batch_size=2**7)  # data is reloaded to model.data_loader

    model.model.eval()
    model.model.avg = False

    simple_tracker = {i: [] for i in range(119)}
    element_tracker = {i: [] for i in range(119)}
    composition_tracker = {}

    with torch.no_grad():
        for i, data in enumerate(tqdm(model.data_loader)):
            X, y, formula = data
            src, frac = X.squeeze(-1).chunk(2, dim=1)
            src = src.to(compute_device, dtype=torch.long, non_blocking=True)
            frac = frac.to(compute_device, dtype=data_type, non_blocking=True)
            y = y.to(compute_device, dtype=data_type, non_blocking=True)
            output = model.model.forward(src, frac)
            mask = (src == 0).unsqueeze(-1).repeat(1, 1, 1)
            prediction, uncertainty, prob = output.chunk(3, dim=-1)
            prediction = prediction * torch.sigmoid(prob)
            uncertainty = torch.exp(uncertainty) * model.scaler.std
            prediction = model.scaler.unscale(prediction)
            prediction = prediction * ~mask
            uncertainty = uncertainty * ~mask
            if model.classification:
                prediction = torch.sigmoid(prediction)
            for i in range(src.shape[0]):
                if any(prediction[i].cpu().numpy().ravel() < 0):
                    composition_tracker[formula[i]] = [
                        src[i].cpu().numpy(), frac[i].cpu().numpy(),
                        y[i].cpu().numpy(), prediction[i].cpu().numpy(),
                        uncertainty[i].cpu().numpy()
                    ]
                for j in range(src.shape[1]):
                    element_tracker[int(src[i][j])].append(
                        float(prediction[i][j]))
                    simple_tracker[int(src[i][j])].append(float(y[i]))

    def elem_view(element_tracker, plot=True):
        property_tracker = {}
        x_max = max([y[1] for y in model.data_loader.dataset])
        x_min = min([y[1] for y in model.data_loader.dataset])
        x_range = x_max - x_min
        x_min_buffer = 0.1 * x_range
        x_max_buffer = 0.1 * x_range
        for key in element_tracker.keys():
            data = element_tracker[key]
            if len(data) > 10:
                sum_prop = sum(data)
                mean_prop = sum_prop / len(data)
                prop = mean_prop
                property_tracker[all_symbols[key]] = prop
                if plot:
                    plt.figure(figsize=(4, 4))
                    hist_kws = {
                        'edgecolor': 'k',
                        'linewidth': 2,
                        'alpha': 1,
                        'facecolor': '#A1D884'
                    }
                    ax = sns.distplot(
                        data,
                        label=f'{all_symbols[key]}, n={len(data)}',
                        kde=False,
                        bins=np.arange(0, 500, 25),
                        hist_kws=hist_kws,
                        kde_kws={
                            'color': 'k',
                            'linewidth': 2
                        })

                    ax.axes.yaxis.set_visible(False)
                    plt.legend()
                    plt.xlim(x_min - x_min_buffer, x_max + x_max_buffer)
                    plt.xlabel('Bulk Modulus Contribution (GPa)')
                    plt.tick_params(axis='both', which='both', direction='in')

                    save_dir = f'figures/contributions/{mat_prop}/'
                    os.makedirs(save_dir, exist_ok=True)
                    plt.savefig(f'{save_dir}{all_symbols[key]}.png',
                                dpi=300,
                                bbox_inches='tight')
                    plt.show()
        return property_tracker

    if simple:
        property_tracker = elem_view(simple_tracker, plot=True)
    else:
        property_tracker = elem_view(element_tracker, plot=True)

    return property_tracker
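A minimal sketch of driving the per-element contribution routine above; it assumes the module-level globals it relies on (data_dir, all_symbols, compute_device, data_type) are defined and that a trained {mat_prop}.pth file exists. The property name is borrowed from a later snippet and is only illustrative.

# Hedged usage sketch for the contribution analysis
classification_list = []
contributions = model('mp_bulk_modulus', classification_list, simple=False)
# contributions maps element symbols to their mean predicted contribution
print(sorted(contributions.items(), key=lambda kv: kv[1], reverse=True)[:5])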
Code example #6
    'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta',
    'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At',
    'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk',
    'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt',
    'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
]

classification = False
batch_size = None
transfer = None
verbose = True
# Get the TorchedCrabNet architecture loaded
mat_prop = 'steels_yield'
file_name = rf'data\matbench_cv\{mat_prop}\val0.csv'

model = Model(CrabNet(compute_device=compute_device).to(compute_device),
              model_name=f'{mat_prop}',
              verbose=verbose)

batch_size = 1
edm_loader = EDM_CsvLoader(csv_data=file_name, batch_size=batch_size)
data_loader = edm_loader.get_data_loaders(inference=True)


# %%
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.mat2vec = True
        if self.mat2vec:
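The Embedder snippet above is cut off inside its mat2vec branch, so the embedding-loading code is not shown. Purely as a hedged illustration, and not the repository's implementation, pretrained element vectors are often copied into an nn.Embedding weight along the following lines; the PretrainedEmbedder name, CSV path handling, and layout are all assumptions.

import pandas as pd
import torch
import torch.nn as nn


class PretrainedEmbedder(nn.Module):
    # Illustrative stand-in, not CrabNet's Embedder: copies a pretrained
    # element-embedding table (one row per element) into nn.Embedding
    def __init__(self, vocab_size, d_model, embedding_csv=None):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        if embedding_csv is not None:
            # Assumed CSV layout: element rows starting at index 1, d_model feature columns
            table = torch.tensor(pd.read_csv(embedding_csv, index_col=0).values,
                                 dtype=torch.float)
            with torch.no_grad():
                self.embed.weight[1:1 + table.shape[0]] = table

    def forward(self, src):
        return self.embed(src)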
Code example #7
    'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt',
    'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
]

color = [
    '#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33',
    '#a65628', '#f781bf'
]

classification_list = []

num = ''
mat_prop = 'aflow__Egap'

# Get the TorchedCrabNet architecture loaded
model = Model(CrabNet().to(compute_device), model_name=f'{mat_prop}')
model.load_network(f'{mat_prop}{num}.pth')
model.model_name = f'{mat_prop}{num}'

if mat_prop in classification_list:
    model.classification = True

mat_prop = 'aflow__Egap'
test_data = rf'data\benchmark_data\{mat_prop}\train.csv'
# test_data = rf'data\matbench_cv\{mat_prop}\train{num}.csv'

model.load_data(test_data,
                batch_size=2**0)  # data is reloaded to model.data_loader

len_dataset = len(model.data_loader.dataset)
Code example #8
from utils.get_compute_device import get_compute_device
from benchmark_crabnet import load_model, get_results

import torch
from torch import nn

from utils.utils import CONSTANTS

compute_device = get_compute_device()


# %%
mat_prop = 'mp_bulk_modulus'
crabnet_params = {'d_model': 512, 'N': 3, 'heads': 4}

model = Model(CrabNet(**crabnet_params,
                      compute_device=compute_device).to(compute_device))
model.load_network(f'{mat_prop}.pth')

# Load the data you want to predict with
test_data = rf'data\benchmark_data\{mat_prop}\train.csv'
model.load_data(test_data)  # data is reloaded to model.data_loader
output = model.predict(model.data_loader)  # predict the data saved here


# %%
class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)
Code example #9
File: Paper_FIG_2.py, Project: kaaiian/CrabNet
    def __call__(self, module, module_in, module_out):
        module_out = [out.detach().cpu() for out in module_out]
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


save_output = SaveOutput()

for data, in_frac in zip(datas, in_fracs):

    # Create a model
    model = Model(
        CrabNet(**torchnet_params,
                compute_device=compute_device).to(compute_device))
    model.load_network(f'{mat_prop}.pth')
    hook_handles = []

    # Insert forward hooks into model
    for layer in model.model.modules():
        if isinstance(layer, torch.nn.modules.activation.MultiheadAttention):
            # print('isinstance')
            handle = layer.register_forward_hook(save_output)
            hook_handles.append(handle)

    model.load_data(data)  # data is reloaded to model.data_loader

    save_output.clear()
    output = model.predict(model.data_loader)
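Once predict has run, the hook outputs captured above live in save_output.outputs; below is a short hedged sketch, continuing inside the per-dataset loop, of inspecting them and removing the hooks afterwards.

    # Inspect the captured MultiheadAttention outputs for this dataset (sketch)
    for module_out in save_output.outputs:
        print([tuple(t.shape) for t in module_out if torch.is_tensor(t)])

    # Unregister the forward hooks before moving to the next dataset
    for handle in hook_handles:
        handle.remove()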