model.train_from_graphs(train_graphs, train_targets, val_graphs, val_targets, epochs=EPOCHS, verbose=2, initial_epoch=0, callbacks=callbacks) # 6. Model testing ## load the best model with lowest validation error files = glob("./callback/*.hdf5") best_model = sorted(files, key=os.path.getctime)[-1] model.load_weights(best_model) model.save_model("best_model.hdf5") def evaluate(test_graphs, test_targets): """ Evaluate the test errors using test_graphs and test_targets Args: test_graphs (list): list of graphs test_targets (list): list of target properties Returns: mean absolute errors """ test_data = model.graph_converter.get_flat_data(test_graphs, test_targets)
def train():
    """Train MEGNet crystal-graph models with k-fold cross validation.

    CLI-driven entry point. Parses arguments, builds crystal graphs from
    structures (or loads pre-computed graphs), splits the data into
    ``args.cv`` folds, and for each fold listed in ``args.k_folds`` trains
    one model — optionally warm-starting from the best checkpoint of a
    previous run under ``args.warm_start``.

    Side effects:
        Writes ``graphs.pkl.gzip`` (when graphs are computed fresh),
        per-fold ``train_idx.npy``/``val_idx.npy``/``test_idx.npy`` files,
        and model checkpoints under ``<output_path>/kfold_<n>/model``.
    """
    # Parse args
    args = parse_args()
    radius = args.radius
    n_works = args.n_works
    warm_start = args.warm_start
    output_path = args.output_path
    graph_file = args.graph_file
    prop_col = args.property
    learning_rate = args.learning_rate
    embedding_file = args.embedding_file
    k_folds = list(map(int, args.k_folds.split(",")))
    print("args is : {}".format(args))
    print("Local devices are : {}, \n\n Available gpus are : {}".format(
        device_lib.list_local_devices(),
        K.tensorflow_backend._get_available_gpus()))

    # Prepare output path.
    # exist_ok=True already tolerates a pre-existing directory, so the former
    # `if not os.path.exists(...)` pre-check was redundant and race-prone.
    os.makedirs(output_path, exist_ok=True)

    # Get a crystal graph converter with cutoff radius (Angstrom); bond
    # distances are expanded on a Gaussian basis of 100 centers, width 0.5.
    cg = CrystalGraph(
        bond_convertor=GaussianDistance(np.linspace(0, radius + 1, 100), 0.5),
        cutoff=radius,
    )

    if graph_file is not None:
        # Load pre-computed graph data. Entries are None for structures that
        # failed conversion, so keep only indices with a usable graph.
        with gzip.open(graph_file, "rb") as f:
            valid_graph_dict = pickle.load(f)
        idx_list = list(range(len(valid_graph_dict)))
        valid_idx_list = [
            idx for idx, graph in valid_graph_dict.items() if graph is not None
        ]
    else:
        # Load structure data.
        with gzip.open(args.input_file, "rb") as f:
            df = pd.DataFrame(pickle.load(f))[["structure", prop_col]]
        idx_list = list(range(len(df)))

        # Load embedding data for transfer learning.
        if embedding_file is not None:
            with open(embedding_file) as json_file:
                embedding_data = json.load(json_file)

        # Calculate and cache valid graphs; record failures as None so the
        # saved dict keeps one entry per input structure.
        valid_idx_list = []
        valid_graph_dict = {}
        for idx in idx_list:
            try:
                graph = cg.convert(df["structure"].iloc[idx])
                if embedding_file is not None:
                    # Replace integer atom species with pre-trained embeddings.
                    graph["atom"] = [embedding_data[i] for i in graph["atom"]]
                valid_graph_dict[idx] = {
                    "graph": graph,
                    "target": df[prop_col].iloc[idx],
                }
                valid_idx_list.append(idx)
            except RuntimeError:
                # Structure could not be converted to a graph.
                valid_graph_dict[idx] = None

        # Save graphs for later reuse via --graph_file.
        with gzip.open(os.path.join(output_path, "graphs.pkl.gzip"),
                       "wb") as f:
            pickle.dump(valid_graph_dict, f)

    is_classification = args.type == "classification"

    # Split data; fixed random_state makes folds reproducible across runs.
    kf = KFold(n_splits=args.cv, random_state=18012019, shuffle=True)
    for fold, (train_val_idx, test_idx) in enumerate(kf.split(idx_list)):
        print(fold)
        if fold not in k_folds:
            continue
        fold_output_path = os.path.join(output_path, "kfold_{}".format(fold))
        fold_model_path = os.path.join(fold_output_path, "model")
        os.makedirs(fold_model_path, exist_ok=True)
        # 25% of the train+val portion becomes the validation set.
        train_idx, val_idx = train_test_split(train_val_idx,
                                              test_size=0.25,
                                              random_state=18012019,
                                              shuffle=True)

        # Keep only indices whose graph conversion succeeded, and save them.
        valid_train_idx = sorted(set(train_idx) & set(valid_idx_list))
        valid_val_idx = sorted(set(val_idx) & set(valid_idx_list))
        valid_test_idx = sorted(set(test_idx) & set(valid_idx_list))
        np.save(os.path.join(fold_output_path, "train_idx.npy"),
                valid_train_idx)
        np.save(os.path.join(fold_output_path, "val_idx.npy"), valid_val_idx)
        np.save(os.path.join(fold_output_path, "test_idx.npy"),
                valid_test_idx)

        # Prepare training graphs.
        train_graphs = [valid_graph_dict[i]["graph"] for i in valid_train_idx]
        train_targets = [
            valid_graph_dict[i]["target"] for i in valid_train_idx
        ]

        # Prepare validation graphs.
        val_graphs = [valid_graph_dict[i]["graph"] for i in valid_val_idx]
        val_targets = [valid_graph_dict[i]["target"] for i in valid_val_idx]

        # Normalize targets or not. The scaler is fit on train only and is
        # passed to train_from_graphs so predictions can be inverse-scaled.
        if args.normalize:
            y_scaler = StandardScaler()
            train_targets = y_scaler.fit_transform(
                np.array(train_targets).reshape(-1, 1)).ravel()
            val_targets = y_scaler.transform(
                np.array(val_targets).reshape((-1, 1))).ravel()
        else:
            y_scaler = None

        # Initialize model.
        if warm_start is None:
            # Set up a fresh model.
            if learning_rate is None:
                learning_rate = 1e-3
            model = MEGNetModel(
                100,
                2,
                nblocks=args.n_blocks,
                nvocal=95,
                npass=args.n_pass,
                lr=learning_rate,
                loss=args.loss,
                graph_convertor=cg,
                is_classification=is_classification,
                # 16 is the embedding feature width when a pre-trained
                # embedding file replaces the integer species input.
                nfeat_node=None if embedding_file is None else 16,
            )
            initial_epoch = 0
        else:
            # Collect checkpoint files from the previous run for this fold.
            model_list = [
                m_file for m_file in os.listdir(
                    os.path.join(warm_start, "kfold_{}".format(fold),
                                 "model")) if m_file.endswith(".hdf5")
            ]
            # Checkpoint filenames embed the validation metric in field 3.
            # Take the best checkpoint (last after sorting): highest metric
            # for classification (ascending sort), lowest error for
            # regression (descending sort).
            model_list.sort(
                key=lambda m_file: float(
                    m_file.split("_")[3].replace(".hdf5", "")),
                reverse=not is_classification,
            )
            model_file = os.path.join(warm_start, "kfold_{}".format(fold),
                                      "model", model_list[-1])

            # If no learning rate was given, recover it from the saved
            # optimizer state of the full serialized model.
            if learning_rate is None:
                full_model = load_model(
                    model_file,
                    custom_objects={
                        "softplus2": softplus2,
                        "Set2Set": Set2Set,
                        "mean_squared_error_with_scale":
                        mean_squared_error_with_scale,
                        "MEGNetLayer": MEGNetLayer,
                    },
                )
                learning_rate = K.get_value(full_model.optimizer.lr)

            # Set up model and restore the checkpoint weights.
            model = MEGNetModel(
                100,
                2,
                nblocks=args.n_blocks,
                nvocal=95,
                npass=args.n_pass,
                lr=learning_rate,
                loss=args.loss,
                graph_convertor=cg,
                is_classification=is_classification,
                nfeat_node=None if embedding_file is None else 16,
            )
            model.load_weights(model_file)
            # Filename field 2 stores the epoch the checkpoint was saved at.
            initial_epoch = int(model_list[-1].split("_")[2])
            print("warm start from : {}, \nlearning_rate is {}.".format(
                model_file, learning_rate))

        # Train.
        model.train_from_graphs(
            train_graphs,
            train_targets,
            val_graphs,
            val_targets,
            batch_size=args.batch_size,
            epochs=args.max_epochs,
            verbose=2,
            initial_epoch=initial_epoch,
            use_multiprocessing=n_works > 1,
            workers=n_works,
            dirname=fold_model_path,
            y_scaler=y_scaler,
            save_best_only=args.save_best_only,
        )