Пример #1
0
    def worker(mix_k):
        """Train one mixture model with `mix_k` components on its own device.

        Reads the base hyper-parameters from `hparams_dir` (closure), sets the
        mixture size, routes the model to the next device from `device_iter`,
        saves a sample grid of the training data for eyeballing, then builds
        and trains the model.
        """
        local_hparams = JsonConfig(hparams_dir)
        local_hparams.Mixture.num_component = mix_k
        # Naive mixture -> "GenMM", latent mixture -> "LatMM".
        model = "GenMM-K{}" if local_hparams.Mixture.naive else "LatMM-K{}"
        local_hparams.Dir.log_root = os.path.join(local_hparams.Dir.log_root,
                                                  model.format(mix_k))
        # Round-robin device assignment from the shared iterator.
        this_device = next(device_iter)
        local_hparams.Device.glow[0] = this_device
        local_hparams.Device.data = this_device
        print("Dir: {} and device: {}".format(local_hparams.Dir.log_root,
                                              this_device))

        # NOTE(review): the original guarded this with a local `peeked` flag
        # that was always False, so the branch always executed; the dead flag
        # is removed here without changing behavior.
        tmp_dataloader = torch.utils.data.DataLoader(dataset_ins,
                                                     batch_size=64,
                                                     shuffle=True,
                                                     num_workers=int(2))
        img = next(iter(tmp_dataloader))[0]

        if not os.path.exists(local_hparams.Dir.log_root):
            os.makedirs(local_hparams.Dir.log_root)
        # Peek the training data set; add(0.5) presumably undoes a -0.5
        # normalization applied by the dataset -- TODO confirm at the loader.
        vutils.save_image(
            img.add(0.5),
            os.path.join(local_hparams.Dir.log_root,
                         "img_under_evaluation.png"))

        built = build(local_hparams, True)
        trainer = Trainer(**built, dataset=dataset_ins, hparams=local_hparams)
        trainer.train()
Пример #2
0
def load_classifier(net_name):
    """Load one pre-trained classifier graph per class label.

    Args:
        net_name: sub-directory name under `hparams.Dir.classifier_dir`;
            the resulting path is expected to contain a `{}` placeholder
            that is filled with the class label.

    Returns:
        List of loaded graphs (in eval mode), indexed by class label.

    Note: mutates the module-level `hparams.Infer.pre_trained` as a side
    effect of driving `build` -- kept from the original implementation.
    """
    classifiers = []
    CLASSIFIER_DIR = os.path.join(hparams.Dir.classifier_dir, net_name)

    for the_label in range(hparams.Data.num_classes):
        # Typo fixed in the log message ("classifer" -> "classifier").
        print("[Loading classifier: {}]".format(the_label))
        hparams.Infer.pre_trained = CLASSIFIER_DIR.format(the_label)
        built = build(hparams, False)
        # Put the component network into eval mode for inference.
        built["graph"].get_component().eval()
        classifiers.append(built["graph"])
    return classifiers
Пример #3
0
    def worker(label):
        """Train a per-label classifier on the data subset for `label`."""
        # Load the subset data of the label.
        local_hparams = JsonConfig(hparams_dir)

        # NOTE(review): "classfier{}" is misspelled but kept byte-identical --
        # it is an on-disk directory name other tooling may depend on.
        local_hparams.Dir.log_root = os.path.join(local_hparams.Dir.log_root,
                                                  "classfier{}".format(label))
        dataset = load_obj(
            os.path.join(dataset_root,
                         "classSets/" + "subset{}".format(label)))

        # Peek the training data (the original wrapped this in a pointless
        # `if True:` block, removed here).
        tmp_dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=64,
                                                     shuffle=True,
                                                     num_workers=int(2))
        img = next(iter(tmp_dataloader))

        if not os.path.exists(local_hparams.Dir.log_root):
            os.makedirs(local_hparams.Dir.log_root)

        vutils.save_image(
            img.data.add(0.5),
            os.path.join(local_hparams.Dir.log_root,
                         "img_under_evaluation.png"))

        # Dump the json file for performance evaluation.
        # Bug fix: the original dumped to get_hparams.Dir.log_root while the
        # existence check looked in local_hparams.Dir.log_root, so the guard
        # never suppressed a re-dump. Dump into the checked directory.
        if not os.path.exists(
                os.path.join(local_hparams.Dir.log_root,
                             local_hparams.Data.dataset + ".json")):
            get_hparams = JsonConfig(hparams_dir)
            data_dir = get_hparams.Data.dataset_root
            # Evaluation runs against the full ("all") dataset split.
            get_hparams.Data.dataset_root = data_dir.replace("separate", "all")
            get_hparams.dump(dir_path=local_hparams.Dir.log_root,
                             json_name=get_hparams.Data.dataset + ".json")

        ### build model and train
        built = build(local_hparams, True)

        # Bug fix: log the directory actually used by this worker
        # (the original printed the global `hparams.Dir.log_root`).
        print(local_hparams.Dir.log_root)
        trainer = Trainer(**built, dataset=dataset, hparams=local_hparams)
        trainer.train()
Пример #4
0
    # NOTE(review): fragment -- `date`, `hparams`, `dataset`, `build`,
    # `Trainer` and `Generator` are defined above/outside this view.
    # Build a filesystem-safe timestamp (drop seconds, '-'/':' removed,
    # space -> '_'), e.g. "20240131_1200".
    date = date[:date.rfind(":")].replace("-", "")\
                                 .replace(":", "")\
                                 .replace(" ", "_")
    log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    print("log_dir:" + str(log_dir))

    # An empty pre_trained path means we train from scratch; otherwise we
    # synthesize from the pre-trained checkpoint.
    is_training = hparams.Infer.pre_trained == ""

    data = dataset(hparams, is_training)
    x_channels, cond_channels = data.n_channels()

    # build graph
    built = build(x_channels, cond_channels, hparams, is_training)

    if is_training:
        # build trainer
        trainer = Trainer(**built, data=data, log_dir=log_dir, hparams=hparams)

        # train model
        trainer.train()
    else:
        # Synthesize a lot of data.
        generator = Generator(data, built['data_device'], log_dir, hparams)
        # Sampling temperature defaults to 1 when not configured.
        if "temperature" in hparams.Infer:
            temp = hparams.Infer.temperature
        else:
            temp = 1
Пример #5
0
    # Fragment: `args` (docopt-style dict) and `hparams` (json path string)
    # come from above this view.
    dataset_name = args["<dataset>"]
    dataset_root = args["<dataset_root>"]
    assert dataset_name in vision.Datasets, (
        "`{}` is not supported, use `{}`".format(dataset_name,
                                                 vision.Datasets.keys()))
    assert os.path.exists(dataset_root), (
        "Failed to find root dir `{}` of dataset.".format(dataset_root))
    assert os.path.exists(hparams), (
        # Typo fixed in the message ("josn" -> "json").
        "Failed to find hparams json `{}`".format(hparams))
    hparams = JsonConfig(hparams)
    dataset = vision.Datasets[dataset_name]
    # set transform of dataset

    print("-------dataset name: ", dataset_name)
    # Bug fix: the original compared the dataset *class* (already looked up
    # from vision.Datasets) against the string "rgb2nir", which was always
    # False. Compare the name instead. Both branches currently build the
    # identical transform, so only the extra log line differs.
    if dataset_name == "rgb2nir":
        print("----------use rgb2nir preprocess ")
        transform = transforms.Compose(
            [transforms.Resize(hparams.Data.resize),
             transforms.ToTensor()])
    else:
        transform = transforms.Compose(
            [transforms.Resize(hparams.Data.resize),
             transforms.ToTensor()])

    # build graph and dataset
    built = build(hparams, True)
    dataset = dataset(dataset_root, transform=transform)
    # begin to train
    trainer = Trainer(**built, dataset=dataset, hparams=hparams)
    trainer.train()
Пример #6
0
        # NOTE(review): fragment -- the `if` header testing `z_dir` is above
        # this view; this branch creates the dir and flags Z for generation.
        os.makedirs(z_dir)
        generate_z = True
    else:
        print("Load Z from {}".format(z_dir))
        generate_z = False

    hparams = JsonConfig("hparams/celeba.json")
    dataset = vision.Datasets["celeba"]
    # set transform of dataset
    transform = transforms.Compose([
        transforms.CenterCrop(hparams.Data.center_crop),
        transforms.Resize(hparams.Data.resize),
        transforms.ToTensor()
    ])
    # build
    graph = build(hparams, False)["graph"]
    dataset = dataset(dataset_root, transform=transform)

    # get Z
    if not generate_z:
        # try to load
        try:
            # One delta-Z array per class; any missing file falls through to
            # the except below.
            delta_Z = []
            for i in range(hparams.Glow.y_classes):
                # NOTE(review): "detla_z" is misspelled on disk; kept as-is
                # because the writer side presumably uses the same name --
                # confirm before renaming.
                z = np.load(os.path.join(z_dir, "detla_z_{}.npy".format(i)))
                delta_Z.append(z)
        except FileNotFoundError:
            # need to generate
            generate_z = True
            print("Failed to load {} Z".format(hparams.Glow.y_classes))
            quit()
Пример #7
0
    # Fragment: `args` and `hparams` (json path string) come from above this
    # view; the "Generating" branch continues below it.
    dataname = args["<dataset>"]
    dataset_root = args["<dataset_root>"]
    mode = args["<mode>"]
    assert mode in ["Generating", "Interpolation"]
    #z_dir = args["<z_dir>"]
    assert os.path.exists(dataset_root), (
        "Failed to find root dir `{}` of dataset.".format(dataset_root))
    assert os.path.exists(hparams), (
        # Typo fixed in the message ("josn" -> "json").
        "Failed to find hparams json `{}`".format(hparams))
    IMG_DIR = "pictures/mnist/"
    if not os.path.exists(IMG_DIR):
        os.makedirs(IMG_DIR)
    hparams = JsonConfig(hparams)

    batch_size = hparams.Train.batch_size
    builded = build(hparams, False)
    graph = builded["graph"]
    # obtain current prior of each component in the mixture distribution
    #################  1. do the generating work #####################
    pk = builded["graph_prior"]
    # Output name encodes mixture type (naive GenMM vs. latent LatMM) and K.
    IMG_NAME = "GenMM_K{}".format(
        hparams.Mixture.num_component
    ) if hparams.Mixture.naive else "LatMM_K{}".format(
        hparams.Mixture.num_component)

    if mode == "Generating":
        # for the_graph in graph:
        #     pk.append(the_graph['prior'])
        # Convert the prior tensor to a plain tuple for printing.
        # (removed the unused local `pknp` from the original)
        pk = tuple(pk.numpy())
        print("The current model prior is: {}".format(pk))
Пример #8
0
    # Fragment: `hparams`, `dataset`, `build` and `Trainer` come from above
    # this view; the generation branch continues below it.
    date = str(datetime.datetime.now())
    # Build a filesystem-safe timestamp (drop seconds, '-'/':' removed,
    # space -> '_'), e.g. "20240131_1200".
    date = date[:date.rfind(":")].replace("-", "")\
                                 .replace(":", "")\
                                 .replace(" ", "_")
    log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    print("log_dir:" + str(log_dir))
    data = dataset(hparams)
    x_channels, cond_channels = data.get_train_dataset().n_channels()

    # An empty pre_trained path means training from scratch. Compute the
    # flag once instead of re-testing the string three times (consistent
    # with the sibling training script).
    is_training = hparams.Infer.pre_trained == ""

    # build graph
    built = build(x_channels, cond_channels, hparams, is_training)

    # build trainer
    trainer = Trainer(**built, data=data, log_dir=log_dir, hparams=hparams)
    if is_training:
        # train model
        trainer.train()
    else:
        # generate from pre-trained model
        # Sampling temperature defaults to 1 when not configured.
        if "temperature" in hparams.Infer:
            temp = hparams.Infer.temperature
        else:
            temp = 1