# Experimental PTCDA scans recorded with a CO tip and with a Xe tip. Each .npz
# provides the AFM data under 'data' and the scan size under 'lengthX'/'lengthY'.
data1 = np.load('./data/PTCDA_CO.npz')
X1 = data1['data']
afm_dim1 = (data1['lengthX'], data1['lengthY'])

data2 = np.load('./data/PTCDA_Xe.npz')
X2 = data2['data']
afm_dim2 = (data2['lengthX'], data2['lengthY'])

# Both scans are preprocessed together and fed to one two-input (n_in=2) model,
# so they must share the same scan size. Raise explicitly: an `assert` would be
# stripped when Python runs with -O, silently skipping this check.
if afm_dim1 != afm_dim2:
    raise ValueError(f'CO and Xe scan sizes differ: {afm_dim1} != {afm_dim2}')
afm_dim = afm_dim1

# [None] adds a new leading axis to each scan (presumably the batch axis).
X_exp = apply_preprocessing_exp([X1[None], X2[None]], afm_dim)

# Load model for sim
# NOTE(review): the last input dimension is hard-coded to 10 here, whereas other
# cells derive it as afmulator.scan_dim[2] - afmulator.df_steps; confirm they agree.
input_shape = afmulator.scan_dim[:2] + (10, )
model_sim = ESUNet(n_in=2, n_out=2, input_shape=input_shape, last_relu=[False, True])
load_pretrained_weights(model_sim, tip_type='CO-Xe')

# Load model for exp. ESUNet is built for a fixed input_shape and the experimental
# images differ in size from the simulations, so a second instance with the same
# CO-Xe weights is needed.
input_shape = X_exp[0].shape[1:]
model_exp = ESUNet(n_in=2, n_out=2, input_shape=input_shape, last_relu=[False, True])
load_pretrained_weights(model_exp, tip_type='CO-Xe')

# Get predictions
pred_sim = model_sim.predict(X_sim)
pred_exp = model_exp.predict(X_exp)
# BCB scan with a CO tip. The .npz provides the AFM data under 'data' and the
# scan size under 'lengthX'/'lengthY'; [None] adds a leading axis before
# preprocessing.
data_bcb = np.load('./data/BCB_CO.npz')
X_bcb = data_bcb['data']
afm_dim_bcb = (data_bcb['lengthX'], data_bcb['lengthY'])
X_bcb = apply_preprocessing_bcb([X_bcb[None]], afm_dim_bcb)

# Load PTCDA data and preprocess
data_ptcda = np.load('./data/PTCDA_CO.npz')
X_ptcda = data_ptcda['data']
afm_dim_ptcda = (data_ptcda['lengthX'], data_ptcda['lengthY'])
X_ptcda = apply_preprocessing_ptcda([X_ptcda[None]], afm_dim_ptcda)

# Load model for simulations (single-input CO model; the last input dimension is
# scan_dim[2] - df_steps).
input_shape = afmulator.scan_dim[:2] + (afmulator.scan_dim[2] - afmulator.df_steps, )
model_sim = ESUNet(n_in=1, n_out=2, input_shape=input_shape, last_relu=[False, True])
load_pretrained_weights(model_sim, tip_type='CO')

# ESUNet is built for a fixed input_shape, so each experimental image size gets
# its own model instance carrying the same CO weights.

# Load model for BCB
model_bcb = ESUNet(n_in=1, n_out=2, input_shape=X_bcb[0].shape[1:], last_relu=[False, True])
load_pretrained_weights(model_bcb, tip_type='CO')

# Load model for PTCDA
model_ptcda = ESUNet(n_in=1, n_out=2, input_shape=X_ptcda[0].shape[1:], last_relu=[False, True])
# Load the CO weights like the other CO-tip models; without this the PTCDA model
# would predict with its un-trained (freshly constructed) weights.
load_pretrained_weights(model_ptcda, tip_type='CO')
# Simulation generator for the Xe-Cl tip pair. iZPPs gives the probe-particle
# atomic numbers (54 = Xe, 17 = Cl); Qs/QZs are per-tip charge parameters passed
# straight through to Trainer (presumably multipole charges and their z offsets
# -- confirm against Trainer).
trainer_Xe_Cl = Trainer(
    afmulator,
    aux_maps,
    molecules,
    batch_size=1,
    distAbove=5.2,
    iZPPs=[54, 17],  # Xe, Cl
    Qs=[[30, -60, 30, 0], [-0.3, 0, 0, 0]],
    QZs=[[0.1, 0, -0.1, 0], [0, 0, 0, 0]],
)

# Both two-tip models below consume simulated data of the same shape: the first
# two scan dimensions plus scan_dim[2] - df_steps as the last dimension.
input_shape = afmulator.scan_dim[:2] + (afmulator.scan_dim[2] - afmulator.df_steps, )

# Cl-CO model (last_relu=[False, True]: final ReLU only on the second output).
model_Cl_CO = ESUNet(n_in=2, n_out=2, input_shape=input_shape, last_relu=[False, True])
load_pretrained_weights(model_Cl_CO, tip_type='Cl-CO')

# Xe-Cl model (last_relu=[False, False]: no final ReLU on either output).
model_Xe_Cl = ESUNet(n_in=2, n_out=2, input_shape=input_shape, last_relu=[False, False])
load_pretrained_weights(model_Xe_Cl, tip_type='Xe-Cl')

# Figure for plotting each molecule: total width is the sum of the two column
# width ratios, and height grows by 6 per molecule.
width_ratios = [6, 12]
fig = plt.figure(figsize=(sum(width_ratios), 6 * len(molecules)))
model_dir = './model'  # Directory where all output files are saved to
pred_dir = os.path.join(model_dir, 'predictions/')  # Where to save predictions
checkpoint_dir = os.path.join(model_dir, 'checkpoints/')  # Where to save model checkpoints
log_path = os.path.join(model_dir, 'training.log')  # Where to save loss history during training
history_plot_path = os.path.join(model_dir, 'loss_history.png')  # Where to plot loss history during training
optimizer_path = os.path.join(model_dir, 'optimizer_state.npz')  # Where to save optimizer state
descriptors = ['ES', 'Height_Map']  # Labels for outputting information

# Create output folders. exist_ok=True replaces the check-then-create pattern,
# which can race, and makes re-running this cell a no-op.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
# NOTE(review): pred_dir is not created here; confirm it is created before
# predictions are written into it.

# Define model. Keep the line matching the tip combination being trained.
model = ESUNet(n_in=2, n_out=2, input_shape=input_shape, last_relu=[False, True], labels=descriptors)  # CO-Xe, Cl-CO
# model = ESUNet(n_in=2, n_out=2, input_shape=input_shape, last_relu=[False, False], labels=descriptors)  # Cl-Xe
# model = ESUNet(n_in=1, n_out=2, input_shape=input_shape, last_relu=[False, True], labels=descriptors)  # CO

# NOTE(review): `lr` and `decay` are legacy Keras argument names; newer tf.keras
# releases expect `learning_rate` and reject `decay` on non-legacy optimizers.
# Confirm against the installed Keras version.
optimizer = optimizers.Adam(lr=0.001, decay=1e-5)
model.compile(optimizer, 'mse', loss_weights=loss_weights)
model.summary()

# Setup data loading: one Loader per dataset split directory under data_dir.
train_loader = Loader(os.path.join(data_dir, 'train/'))
val_loader = Loader(os.path.join(data_dir, 'val/'))
test_loader = Loader(os.path.join(data_dir, 'test/'))

# Setup callbacks: weight checkpoints named by epoch, a CSV loss log that is
# appended to across runs, and a loss-history plot read from that log.
checkpointer = ModelCheckpoint(os.path.join(checkpoint_dir, 'weights_{epoch:d}.h5'), save_weights_only=True)
logger = CSVLogger(log_path, append=True)
plotter = HistoryPlotter(log_path, history_plot_path, descriptors)