def main_cs2sg_one(csnum: int, hs_indices: List[int], model_struct: tuple,
                   num_epoch: int = 1):
    """Train and apply a space-group classifier for a single crystal system.

    The model is built by ``base_model.get_cs2sg(*model_struct)``. It is
    trained on the ``data/actual/spacegroup_list_{n}.txt`` files for every
    ``n`` in ``crystal.spacegroup_number_range(csnum)``. Afterwards
    ``data_processing.append_guess_spacegroup_in_crystal_list_files`` runs on
    ``data/actual/crystal_list_{csnum}.txt`` with output dir ``data/guess/``.

    Args:
        csnum: crystal-system number whose space groups are classified.
        hs_indices: indices forwarded to both the dataset and the guessing
            step (presumably high-symmetry-point indices — confirm).
        model_struct: positional arguments unpacked into
            ``base_model.get_cs2sg``.
        num_epoch: number of epochs for ``network.validate_train_loop``.
    """
    print("crystal system: {}".format(csnum))
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda_ok else "cpu")

    net = base_model.get_cs2sg(*model_struct).to(device)
    # Adam at lr=0.01; StepLR multiplies the LR by 0.75 on every scheduler step.
    opt = torch.optim.Adam(net.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.75)
    loss_fn = torch.nn.CrossEntropyLoss()

    sg_files = list(map("data/actual/spacegroup_list_{}.txt".format,
                        crystal.spacegroup_number_range(csnum)))
    # 0.1 is forwarded unchanged to the dataset (presumably a validation
    # fraction — confirm in data_loader); 32 is presumably the batch size.
    data = data_loader.SetCs2Sg(csnum, hs_indices, sg_files, 0.1)
    train_dl, valid_dl = data_loader.get_valid_train_loader(data, 32)

    network.validate_train_loop(device, net, opt, sched, loss_fn,
                                valid_dl, train_dl, num_epoch)

    crystal_file = "data/actual/crystal_list_{}.txt".format(csnum)
    data_processing.append_guess_spacegroup_in_crystal_list_files(
        device, net, csnum, hs_indices, crystal_file, "data/guess/")
def main_bs2cs(num_epoch: int = 1):
    """Train the band-structure -> crystal-system classifier and write guesses.

    Trains ``base_model.get_bs2cs()`` on the
    ``data/actual/crystal_list_{1..7}.txt`` files. It then calls
    ``data_processing.create_guess_list_files`` on
    ``data/actual/valid_list.txt`` with output dir ``data/guess/`` and the
    name pattern ``crystal_list_{}.txt``.

    Args:
        num_epoch: number of epochs for ``network.validate_train_loop``.
    """
    num_systems = 7
    # Index list shared by training and guessing. It is defined once so the
    # two call sites cannot drift apart. The original comment described these
    # as high-symmetry points in the Brillouin zone.
    # NOTE(review): that comment said "12 hs points", but this list has 11
    # entries — confirm whether an index is missing.
    hs_indices = [0, 1, 3, 4, 5, 7, 8, 13, 31, 34, 37]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = base_model.get_bs2cs()
    model = model.to(device)
    # Adam at lr=0.01; StepLR multiplies the LR by 0.75 on every scheduler step.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1,
                                                gamma=0.75)
    criterion = torch.nn.CrossEntropyLoss()

    # Each consumer gets its own list copy, exactly as the original built a
    # separate literal per call, so neither can observe mutations by the other.
    dataset = data_loader.SetBs2Cs(
        list(hs_indices),
        ["data/actual/crystal_list_{}.txt".format(i)
         for i in range(1, num_systems + 1)],
        0.1)  # forwarded unchanged; presumably the validation fraction
    train_loader, valid_loader = data_loader.get_valid_train_loader(
        dataset, 32)
    network.validate_train_loop(device, model, optimizer, scheduler,
                                criterion, valid_loader, train_loader,
                                num_epoch)

    data_processing.create_guess_list_files(
        device, model, num_systems, list(hs_indices),
        "data/actual/valid_list.txt", "data/guess/", "crystal_list_{}.txt")
def main_bs2sg(num_epoch: int = 1):
    """Train the band-structure -> space-group classifier over all 230 groups.

    Trains ``base_model.get_bs2sg()`` on the
    ``data/actual/spacegroup_list_{1..230}.txt`` files. It then calls
    ``data_processing.create_empty_list_files`` followed by
    ``data_processing.create_guess_list_files``. Both use output dir
    ``data/guess/`` and the name pattern ``spacegroup_list_{}.txt``, and the
    guesses are made for ``data/actual/valid_list.txt``.

    Args:
        num_epoch: number of epochs for ``network.validate_train_loop``.

    NOTE(review): ``main_bs2cs`` passes an index list to
    ``create_guess_list_files`` before the ``valid_list.txt`` path, but this
    call does not. One of the two call sites may be out of date with that
    helper's signature — confirm.
    """
    num_groups = 230
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda_ok else "cpu")

    net = base_model.get_bs2sg().to(device)
    # Adam at lr=0.01; StepLR multiplies the LR by 0.75 on every scheduler step.
    opt = torch.optim.Adam(net.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.75)
    loss_fn = torch.nn.CrossEntropyLoss()

    sg_files = list(map("data/actual/spacegroup_list_{}.txt".format,
                        range(1, num_groups + 1)))
    data = data_loader.SetBs2Sg(sg_files, 0.1)
    train_dl, valid_dl = data_loader.get_valid_train_loader(data, 32)
    network.validate_train_loop(device, net, opt, sched, loss_fn,
                                valid_dl, train_dl, num_epoch)

    # Presumably resets the guess files before they are filled below.
    out_dir, pattern = "data/guess/", "spacegroup_list_{}.txt"
    data_processing.create_empty_list_files(num_groups, out_dir, pattern)
    data_processing.create_guess_list_files(
        device, net, num_groups, "data/actual/valid_list.txt", out_dir,
        pattern)
# Cumulative space-group boundaries per crystal system, in order 1..7. These
# match the standard International Tables ranges (triclinic ends at 2,
# monoclinic at 15, ..., cubic at 230). Only consecutive differences are used.
_CRYS_SG_MARGINS = (2, 15, 74, 142, 167, 194, 230)


def _crys2sg_num_classes(crysnum: int) -> int:
    """Return the classifier output size used for crystal system *crysnum*.

    The value is the number of space groups in the system plus one, exactly
    as the previous inline expression computed it (3 for system 1).
    NOTE(review): the reason for the extra class is not visible here — TODO
    confirm (e.g. a label offset).

    Raises:
        ValueError: if ``crysnum`` is not in 1..7.
    """
    if not 1 <= crysnum <= len(_CRYS_SG_MARGINS):
        raise ValueError("crysnum must be in 1..{}, got {}".format(
            len(_CRYS_SG_MARGINS), crysnum))
    lower = _CRYS_SG_MARGINS[crysnum - 2] if crysnum > 1 else 0
    return _CRYS_SG_MARGINS[crysnum - 1] - lower + 1


def main_crys2sg_one(crysnum: int, num_epoch: int = 1):
    """Train and apply a space-group classifier for one crystal system.

    Trains ``base_model.get_crys2sg`` on the
    ``data/actual/spacegroup_list_{n}.txt`` files for every ``n`` in
    ``crystal.spacegroup_number_range(crysnum)``. It then calls
    ``data_processing.append_guess_spacegroup_in_crystal_list_files`` on
    ``data/guess/crystal_list_{crysnum}.txt`` with output dir
    ``data/guess/``.

    Args:
        crysnum: crystal-system number, 1..7.
        num_epoch: number of epochs for ``network.validate_train_loop``.

    Raises:
        ValueError: if ``crysnum`` is outside 1..7. Previously, values below
            1 silently built a 3-class model and values above 7 hit an
            IndexError.
    """
    num_classes = _crys2sg_num_classes(crysnum)
    print("crystal system: {}".format(crysnum))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The meaning of 1200/128/128 is defined in base_model.get_crys2sg
    # (presumably input size and hidden widths — confirm there).
    model = base_model.get_crys2sg(1200, 128, 128, num_classes)
    model = model.to(device)
    # Adam at lr=0.01; StepLR multiplies the LR by 0.75 on every scheduler step.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1,
                                                gamma=0.75)
    criterion = torch.nn.CrossEntropyLoss()
    dataset = data_loader.SetCrys2Sg(
        ["data/actual/spacegroup_list_{}.txt".format(i)
         for i in crystal.spacegroup_number_range(crysnum)],
        crysnum, 0.1)
    train_loader, valid_loader = data_loader.get_valid_train_loader(
        dataset, 32)
    network.validate_train_loop(device, model, optimizer, scheduler,
                                criterion, valid_loader, train_loader,
                                num_epoch)
    data_processing.append_guess_spacegroup_in_crystal_list_files(
        device, model, crysnum,
        "data/guess/crystal_list_{}.txt".format(crysnum), "data/guess/")
def train():
    """Build a space-group model, optionally load weights, and train it.

    The model is ``SpaceGroupNN`` when ``NET == "NN"`` and ``SpaceGroupCNN``
    otherwise. If ``MODEL_PATH`` is truthy, a state dict is loaded from it.
    Training then runs via ``network.validate_train_loop`` on
    ``data_loader.Cs2Sg(CRYSTAL_SYSTEM, 0.1)``.

    Configuration is read from names not defined in this function,
    presumably module-level settings elsewhere in the file: NET,
    OUT_FEATURES, N_FC, N_CNN, CHANNELS, MODEL_PATH, lr, CRYSTAL_SYSTEM,
    epoch, SAVE_PATH, FIGURE_PATH.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if NET == "NN":
        model = SpaceGroupNN(OUT_FEATURES, N_FC).to(device)
    else:
        model = SpaceGroupCNN(N_CNN, CHANNELS, OUT_FEATURES, N_FC).to(device)
    if MODEL_PATH:
        # Without map_location, torch.load restores tensors onto the device
        # they were saved from. A GPU-saved checkpoint then fails to load on
        # a CPU-only machine; mapping to `device` works in both directions.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.train()
    # StepLR multiplies the LR by 0.75 on every scheduler step.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1,
                                                gamma=0.75)
    criterion = torch.nn.CrossEntropyLoss()
    dataset = data_loader.Cs2Sg(CRYSTAL_SYSTEM, 0.1)
    train_loader, valid_loader = data_loader.get_valid_train_loader(
        dataset, 32)
    network.validate_train_loop(device, model, optimizer, scheduler,
                                criterion, valid_loader, train_loader, epoch,
                                SAVE_PATH, FIGURE_PATH)