def main_cs2sg_one(csnum: int, hs_indices: List[int], model_struct: tuple, num_epoch: int = 1):
    """Train the crystal-system -> space-group network for one crystal system
    and append its predictions to the ``data/guess/`` space-group lists.

    Builds ``base_model.get_cs2sg(*model_struct)`` and trains it with Adam
    (lr=0.01), a StepLR schedule (step_size=1, gamma=0.75) and cross-entropy
    loss. The data comes from the ``data/actual/spacegroup_list_<sg>.txt``
    files of every space group in ``crystal.spacegroup_number_range(csnum)``.
    The model is then applied to every entry of
    ``data/actual/crystal_list_<csnum>.txt``.

    Args:
        csnum: crystal-system number.
        hs_indices: indices forwarded to the dataset and to the guessing step
            (used there to select columns of each ``"bands"`` array).
        model_struct: positional arguments unpacked into
            ``base_model.get_cs2sg``.
        num_epoch: number of epochs passed to the training loop.
    """
    print("crystal system: {}".format(csnum))
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build the network and move it before creating the optimizer, so the
    # optimizer tracks the parameters on the target device.
    net = base_model.get_cs2sg(*model_struct).to(dev)
    opt = torch.optim.Adam(net.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.75)
    loss_fn = torch.nn.CrossEntropyLoss()

    sg_list_files = [
        "data/actual/spacegroup_list_{}.txt".format(sg)
        for sg in crystal.spacegroup_number_range(csnum)
    ]
    # 0.1 is forwarded as-is; presumably the validation fraction - confirm
    # against data_loader.SetCs2Sg.
    dataset = data_loader.SetCs2Sg(csnum, hs_indices, sg_list_files, 0.1)
    train_loader, valid_loader = data_loader.get_valid_train_loader(dataset, 32)

    network.validate_train_loop(dev, net, opt, sched, loss_fn,
                                valid_loader, train_loader, num_epoch)

    data_processing.append_guess_spacegroup_in_crystal_list_files(
        dev, net, csnum, hs_indices,
        "data/actual/crystal_list_{}.txt".format(csnum),
        "data/guess/",
    )
def main_crys2sg_one(crysnum: int, num_epoch: int = 1):
    """Train the crystal -> space-group network for one crystal system and
    append its predictions to the ``data/guess/`` space-group lists.

    The network is ``base_model.get_crys2sg(1200, 128, 128, out_size)``.
    ``out_size`` is one more than the number of space groups in *crysnum*,
    derived from the ``margins`` table below; it is 3 for system 1. Training
    uses Adam (lr=0.01), a StepLR schedule (step_size=1, gamma=0.75) and
    cross-entropy loss, with batch size 32. The model is then applied to
    ``data/guess/crystal_list_<crysnum>.txt``.

    Args:
        crysnum: crystal-system number; the ``margins`` indexing assumes 1..7.
        num_epoch: number of epochs passed to the training loop.
    """
    print("crystal system: {}".format(crysnum))
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Last space-group number of crystal systems 1..7.
    margins = [2, 15, 74, 142, 167, 194, 230]
    if crysnum > 1:
        out_size = margins[crysnum - 1] - margins[crysnum - 2] + 1
    else:
        out_size = 3
    # Earlier configuration, kept for reference:
    # get_crys2sg(360, 100, 100, <same size without the +1, or 2 for system 1>)
    net = base_model.get_crys2sg(1200, 128, 128, out_size).to(dev)
    opt = torch.optim.Adam(net.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.75)
    loss_fn = torch.nn.CrossEntropyLoss()

    sg_list_files = [
        "data/actual/spacegroup_list_{}.txt".format(sg)
        for sg in crystal.spacegroup_number_range(crysnum)
    ]
    # 0.1 is forwarded as-is; presumably the validation fraction - confirm.
    dataset = data_loader.SetCrys2Sg(sg_list_files, crysnum, 0.1)
    train_loader, valid_loader = data_loader.get_valid_train_loader(dataset, 32)

    network.validate_train_loop(dev, net, opt, sched, loss_fn,
                                valid_loader, train_loader, num_epoch)

    # NOTE(review): this call passes 5 positional arguments. An
    # append_guess_spacegroup_in_crystal_list_files taking
    # (device, model, csnum, hs_indices, in_list_path, out_list_dir) would
    # raise TypeError here; confirm which data_processing signature is in use.
    data_processing.append_guess_spacegroup_in_crystal_list_files(
        dev, net, crysnum,
        "data/guess/crystal_list_{}.txt".format(crysnum),
        "data/guess/",
    )
def append_guess_spacegroup_in_crystal_list_files(device, model, csnum, hs_indices, in_list_path, out_list_dir):
    """Classify every crystal listed in *in_list_path* into a space group and
    append its path to the matching guess list.

    Processing steps for each path in *in_list_path*:

    * Load the path as JSON.
    * Select columns *hs_indices* of its ``"bands"`` array.
    * Flatten the result and feed it to *model* as a 1-D float tensor.
    * Map the arg-max output index to a space-group number by adding
      ``1 + crystal.spacegroup_index_lower(csnum)``.

    Predictions outside ``crystal.spacegroup_number_range(csnum)`` are
    dropped. Accepted paths are appended to
    ``<out_list_dir>/spacegroup_list_<sgnum>.txt``.

    Args:
        device: torch device the input tensors are moved to (presumably the
            device *model* lives on - confirm at the call site).
        model: trained network, called on a 1-D float tensor.
        csnum: crystal-system number.
        hs_indices: column indices selected from each ``"bands"`` array.
        in_list_path: text file with one JSON path per line. An empty file
            is a no-op.
        out_list_dir: directory that receives the ``spacegroup_list_*.txt``
            files.
    """
    if os.stat(in_list_path).st_size == 0:
        return
    # dtype=str sizes the strings to the longest path; the former "U60"
    # silently truncated longer paths. ndmin=1 keeps a one-line list
    # iterable instead of returning a 0-d array.
    file_paths = numpy.loadtxt(in_list_path, dtype=str, ndmin=1)
    total = len(file_paths)

    # Both depend only on csnum, so compute them once rather than per file.
    # set() also makes repeated membership tests safe if the range helper
    # ever returns a one-shot iterator.
    index_offset = 1 + crystal.spacegroup_index_lower(csnum)
    valid_sgnums = set(crystal.spacegroup_number_range(csnum))

    # Inference only: no autograd graph needs to be recorded.
    with torch.no_grad():
        for i, file_path in enumerate(file_paths, start=1):
            print("\r\tcreate guess list: {}/{}".format(i, total), end="")
            with open(file_path, "r") as json_file:
                data_json = json.load(json_file)
            bands = numpy.array(data_json["bands"])
            data_input_np = bands[:, hs_indices].flatten()
            data_input = torch.from_numpy(data_input_np).float()
            output = model(data_input.to(device))
            sgnum = torch.max(output, 0)[1].item() + index_offset
            if sgnum not in valid_sgnums:
                continue
            out_path = os.path.join(out_list_dir, "spacegroup_list_{}.txt".format(sgnum))
            with open(out_path, "a") as file_out:
                file_out.write(file_path + "\n")
    print("\rcreate guess list: {}".format(total))
# Prepare the "actual" (ground-truth) list files.
# (Redo this every time the dataset changes.)
# data_processing.create_valid_list_files(100, "data/input_data_test06/", "data/actual/valid_list.txt")
# data_processing.create_actual_crystal_list_files("data/actual/valid_list.txt", "data/actual/")
# data_processing.create_actual_spacegroup_list_files("data/actual/valid_list.txt", "data/actual/")

# Generate the "guess" (predicted) list files.
# main_bs2sg(num_epoch=30)
# main_bs2crys(num_epoch=15)
main_crys2sg_all(num_epoch=20)

# Analyse the results.
# analysis.print_result(range(1, 8), "data/guess/", "data/actual/", "crystal_list_{}.txt", plt=True)  # into crys
# analysis.print_result(range(1, 231), "data/guess/", "data/actual/", "spacegroup_list_{}.txt")  # into sg
for c in range(1, 8):  # into crys and sg
    print("crystal system {}".format(c))
    print("\nbandstructure to crystal system result")
    analysis.print_result([c], "data/guess/", "data/actual/", "crystal_list_{}.txt")
    print("\ncrystal to spacegroup system result")
    analysis.print_result(crystal.spacegroup_number_range(c), "data/guess/", "data/actual/", "spacegroup_list_{}.txt")
    print("\n")

torch.cuda.empty_cache()

# Audible "done" signal. winsound exists only on Windows; importing it
# unconditionally crashed the script on other platforms after all the work
# had already finished.
duration = 500  # milliseconds
freq = 200  # Hz
try:
    import winsound
except ImportError:
    print("\a", end="", flush=True)  # portable terminal-bell fallback
else:
    winsound.Beep(freq, duration)