Esempio n. 1
0
def main_cs2sg_one(csnum: int,
                   hs_indices: List[int],
                   model_struct: tuple,
                   num_epoch: int = 1):
    """Train a space-group classifier for one crystal system and write guesses.

    Builds the model with ``base_model.get_cs2sg(*model_struct)`` and trains it
    with ``network.validate_train_loop`` using Adam (lr=0.01), a StepLR
    schedule (step_size=1, gamma=0.75) and cross-entropy loss. The data come
    from ``data/actual/spacegroup_list_<n>.txt`` for every space group ``n`` in
    ``crystal.spacegroup_number_range(csnum)``. Afterwards, every path in
    ``data/actual/crystal_list_<csnum>.txt`` is classified and appended to the
    guess lists under ``data/guess/``.

    Args:
        csnum: crystal-system number to train for.
        hs_indices: column indices forwarded to the dataset and to the guess
            writer (presumably high-symmetry points -- confirm).
        model_struct: positional arguments unpacked into
            ``base_model.get_cs2sg``.
        num_epoch: number of epochs passed to the training loop.
    """
    print("crystal system: {}".format(csnum))
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_gpu else "cpu")

    # nn.Module.to returns the module itself, so chaining is equivalent.
    model = base_model.get_cs2sg(*model_struct).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=0.75)
    criterion = torch.nn.CrossEntropyLoss()

    spacegroup_files = [
        "data/actual/spacegroup_list_{}.txt".format(sg)
        for sg in crystal.spacegroup_number_range(csnum)
    ]
    # The trailing 0.1 is assumed to be a validation fraction -- TODO confirm
    # against data_loader.SetCs2Sg.
    dataset = data_loader.SetCs2Sg(csnum, hs_indices, spacegroup_files, 0.1)
    train_loader, valid_loader = data_loader.get_valid_train_loader(
        dataset, 32)

    network.validate_train_loop(device, model, optimizer, scheduler, criterion,
                                valid_loader, train_loader, num_epoch)

    crystal_list_path = "data/actual/crystal_list_{}.txt".format(csnum)
    data_processing.append_guess_spacegroup_in_crystal_list_files(
        device, model, csnum, hs_indices, crystal_list_path, "data/guess/")
Esempio n. 2
0
def main_crys2sg_one(crysnum: int, num_epoch: int = 1):
    """Train a crystal -> space-group classifier for one crystal system.

    The model is ``base_model.get_crys2sg(1200, 128, 128, num_classes)``.
    ``num_classes`` is derived from ``margins`` below. Training uses Adam
    (lr=0.01), StepLR (step_size=1, gamma=0.75) and cross-entropy loss on
    ``data/actual/spacegroup_list_<n>.txt`` for every ``n`` in
    ``crystal.spacegroup_number_range(crysnum)``. The trained model then
    classifies the paths listed in ``data/guess/crystal_list_<crysnum>.txt``
    and writes its guesses under ``data/guess/``.

    Args:
        crysnum: crystal-system number (1-based index into ``margins``).
        num_epoch: number of epochs passed to the training loop.
    """
    print("crystal system: {}".format(crysnum))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # These appear to be the last space-group number of each crystal system,
    # so consecutive differences give the number of groups per system.
    margins = [2, 15, 74, 142, 167, 194, 230]
    # One output more than the group count (the original marked this +1 as
    # deliberate; the reason is not visible here -- confirm).
    if crysnum > 1:
        num_classes = margins[crysnum - 1] - margins[crysnum - 2] + 1
    else:
        num_classes = 3
    model = base_model.get_crys2sg(1200, 128, 128, num_classes).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=0.75)
    criterion = torch.nn.CrossEntropyLoss()

    spacegroup_files = [
        "data/actual/spacegroup_list_{}.txt".format(sg)
        for sg in crystal.spacegroup_number_range(crysnum)
    ]
    dataset = data_loader.SetCrys2Sg(spacegroup_files, crysnum, 0.1)
    train_loader, valid_loader = data_loader.get_valid_train_loader(
        dataset, 32)

    network.validate_train_loop(device, model, optimizer, scheduler, criterion,
                                valid_loader, train_loader, num_epoch)

    # NOTE(review): five arguments are passed here, whereas the version of
    # append_guess_spacegroup_in_crystal_list_files shown alongside this code
    # also takes hs_indices -- confirm which revision this call pairs with.
    guess_list_path = "data/guess/crystal_list_{}.txt".format(crysnum)
    data_processing.append_guess_spacegroup_in_crystal_list_files(
        device, model, crysnum, guess_list_path, "data/guess/")
Esempio n. 3
0
def append_guess_spacegroup_in_crystal_list_files(device, model, csnum,
                                                  hs_indices, in_list_path,
                                                  out_list_dir):
    """Predict a space group for each listed file and append it to a guess list.

    For every path in ``in_list_path``:
      1. Load the JSON file.
      2. Keep the columns ``hs_indices`` of its ``"bands"`` array, flatten the
         result and feed it to ``model``.
      3. Take the arg-max of the output, plus
         ``1 + crystal.spacegroup_index_lower(csnum)``, as the predicted
         space-group number.
      4. Append the path to ``spacegroup_list_<sgnum>.txt`` inside
         ``out_list_dir``. Predictions outside
         ``crystal.spacegroup_number_range(csnum)`` are dropped.

    The model is run in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards.

    Args:
        device: torch device the input tensor is moved to before inference.
        model: torch module; assumed to map a flat 1-D float tensor to a 1-D
            score vector (TODO confirm against base_model).
        csnum: crystal-system number used to offset and validate predictions.
        hs_indices: column indices selected from each ``"bands"`` array.
        in_list_path: text file listing one JSON path per line. An empty
            file is a no-op.
        out_list_dir: directory whose guess lists are appended to.
    """
    if os.stat(in_list_path).st_size == 0:
        return
    # dtype=str sizes the strings to the longest path (the former "U60"
    # silently truncated longer paths). ndmin=1 keeps a one-line list
    # iterable instead of collapsing it to a 0-d array.
    file_paths = numpy.loadtxt(in_list_path, dtype=str, ndmin=1)
    total = len(file_paths)
    # Loop-invariant: hoisted out of the per-file loop.
    sg_offset = 1 + crystal.spacegroup_index_lower(csnum)
    valid_sgnums = crystal.spacegroup_number_range(csnum)

    was_training = model.training
    # Inference: disable dropout and use running batch-norm statistics.
    model.eval()
    try:
        # No autograd graph is needed just to take an arg-max.
        with torch.no_grad():
            for i, file_path in enumerate(file_paths, start=1):
                with open(file_path, "r") as file:
                    data_json = json.load(file)
                bands = numpy.array(data_json["bands"])
                data_input_np = bands[:, hs_indices].flatten()
                data_input = torch.from_numpy(data_input_np).float()
                output = model(data_input.to(device))
                sgnum = torch.argmax(output, dim=0).item() + sg_offset
                if sgnum in valid_sgnums:
                    out_path = os.path.join(
                        out_list_dir, "spacegroup_list_{}.txt".format(sgnum))
                    with open(out_path, "a") as file_out:
                        file_out.write(file_path + "\n")
                # Report progress for every file, including skipped ones.
                print("\r\tcreate guess list: {}/{}".format(i, total), end="")
    finally:
        model.train(was_training)
    print("\rcreate guess list: {}".format(total))
Esempio n. 4
0
    # Prepare actual data (re-run every time the dataset changes):
    # data_processing.create_valid_list_files(100, "data/input_data_test06/", "data/actual/valid_list.txt")  # <<<<
    # data_processing.create_actual_crystal_list_files("data/actual/valid_list.txt", "data/actual/")
    # data_processing.create_actual_spacegroup_list_files("data/actual/valid_list.txt", "data/actual/")

    # Generate guess data.
    # main_bs2sg(num_epoch=30)
    # main_bs2crys(num_epoch=15)
    main_crys2sg_all(num_epoch=20)

    # Analyse results.
    # analysis.print_result(range(1, 8), "data/guess/", "data/actual/", "crystal_list_{}.txt", plt=True)  # into crys
    # analysis.print_result(range(1, 231), "data/guess/", "data/actual/", "spacegroup_list_{}.txt")  # into sg
    for c in range(1, 8):  # per crystal system: crystal-level, then space-group-level
        print("crystal system {}".format(c))
        print("\nbandstructure to crystal system result")
        analysis.print_result([c], "data/guess/", "data/actual/",
                              "crystal_list_{}.txt")
        print("\ncrystal to spacegroup system result")
        analysis.print_result(crystal.spacegroup_number_range(c),
                              "data/guess/", "data/actual/",
                              "spacegroup_list_{}.txt")
        print("\n")

    torch.cuda.empty_cache()

    # Audible "done" signal. winsound only exists on Windows, so skip the beep
    # elsewhere instead of raising ImportError after the whole run finished.
    duration = 500  # milliseconds
    freq = 200  # Hz
    try:
        import winsound
    except ImportError:
        winsound = None
    if winsound is not None:
        winsound.Beep(freq, duration)