def worker(mix_k):
    """Train one mixture model with ``mix_k`` components.

    A fresh copy of the hyper-parameters is loaded from the module-level
    ``hparams_dir``. The log root is redirected to a ``GenMM-K<k>``
    sub-directory when ``Mixture.naive`` is set, otherwise ``LatMM-K<k>``.
    Both ``Device.glow[0]`` and ``Device.data`` are set to the next device
    drawn from the module-level ``device_iter``. One preview batch is saved,
    then the model is built and trained on the module-level ``dataset_ins``.

    Args:
        mix_k: number of mixture components, written to
            ``Mixture.num_component``.
    """
    local_hparams = JsonConfig(hparams_dir)
    local_hparams.Mixture.num_component = mix_k
    model = "GenMM-K{}" if local_hparams.Mixture.naive else "LatMM-K{}"
    local_hparams.Dir.log_root = os.path.join(local_hparams.Dir.log_root,
                                              model.format(mix_k))
    this_device = next(device_iter)
    local_hparams.Device.glow[0] = this_device
    local_hparams.Device.data = this_device
    print("Dir: {} and device: {}".format(local_hparams.Dir.log_root,
                                          this_device))

    # Peek at one shuffled training batch and store it next to the logs.
    tmp_dataloader = torch.utils.data.DataLoader(
        dataset_ins, batch_size=64, shuffle=True, num_workers=2)
    # Batches are assumed to be (images, labels) pairs; [0] keeps the images.
    img = next(iter(tmp_dataloader))[0]
    # exist_ok avoids a check-then-create race if several workers run at once.
    os.makedirs(local_hparams.Dir.log_root, exist_ok=True)
    # +0.5 presumably undoes a -0.5 centring in the dataset transform — confirm.
    vutils.save_image(
        img.add(0.5),
        os.path.join(local_hparams.Dir.log_root, "img_under_evaluation.png"))

    built = build(local_hparams, True)
    trainer = Trainer(**built, dataset=dataset_ins, hparams=local_hparams)
    trainer.train()
def load_classifier(net_name):
    """Load one classifier graph per class label.

    For every label in ``range(hparams.Data.num_classes)``, the function:

    * points the module-level ``hparams.Infer.pre_trained`` at
      ``os.path.join(hparams.Dir.classifier_dir, net_name).format(label)``;
    * builds the graph in non-training mode;
    * switches its component to ``eval()``.

    The joined path is therefore expected to contain a ``{}`` placeholder.
    Without one, every label loads the same checkpoint.

    Side effect: the global ``hparams.Infer.pre_trained`` is left set to the
    last label's path.

    Args:
        net_name: path (relative to ``hparams.Dir.classifier_dir``) of the
            per-label checkpoint template.

    Returns:
        list: the built graphs, indexed by class label.
    """
    path_template = os.path.join(hparams.Dir.classifier_dir, net_name)
    classifiers = []
    for label in range(hparams.Data.num_classes):
        print("[Loading classifer: {}]".format(label))
        hparams.Infer.pre_trained = path_template.format(label)
        graph = build(hparams, False)["graph"]
        # Presumably a torch module: eval() puts it into evaluation mode.
        graph.get_component().eval()
        classifiers.append(graph)
    return classifiers
def worker(label):
    """Train the per-class model for a single class ``label``.

    Steps:

    1. Load a fresh configuration from the module-level ``hparams_dir`` and
       redirect its log root to a ``classfier<label>`` sub-directory.
    2. Load the class subset ``<dataset_root>/classSets/subset<label>`` via
       ``load_obj``.
    3. Save one preview batch of that subset.
    4. If ``<per-label log_root>/<Data.dataset>.json`` does not exist yet,
       dump a second fresh configuration whose ``Data.dataset_root`` has
       "separate" replaced by "all".
    5. Build and train the model.

    Args:
        label: class index whose subset is trained.
    """
    local_hparams = JsonConfig(hparams_dir)
    local_hparams.Dir.log_root = os.path.join(local_hparams.Dir.log_root,
                                              "classfier{}".format(label))
    # load the subset data of the label
    dataset = load_obj(
        os.path.join(dataset_root, "classSets/" + "subset{}".format(label)))

    # Peek at one shuffled batch and store it next to the logs.
    tmp_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=True, num_workers=2)
    # Unlike the mixture worker, the whole batch is used; this subset is
    # presumably image-only (no labels) — confirm.
    img = next(iter(tmp_dataloader))
    # exist_ok avoids a check-then-create race if several workers run at once.
    os.makedirs(local_hparams.Dir.log_root, exist_ok=True)
    vutils.save_image(
        img.data.add(0.5),
        os.path.join(local_hparams.Dir.log_root, "img_under_evaluation.png"))

    # dump the json file for performance evaluation
    # NOTE(review): existence is checked under the per-label log_root, but the
    # dump goes to get_hparams.Dir.log_root (the base log root of a fresh
    # config). Unless something else writes the per-label file, the dump
    # repeats on every run — confirm which location is intended.
    if not os.path.exists(
            os.path.join(local_hparams.Dir.log_root,
                         local_hparams.Data.dataset + ".json")):
        get_hparams = JsonConfig(hparams_dir)
        data_dir = get_hparams.Data.dataset_root
        get_hparams.Data.dataset_root = data_dir.replace("separate", "all")
        get_hparams.dump(dir_path=get_hparams.Dir.log_root,
                         json_name=get_hparams.Data.dataset + ".json")

    ### build model and train
    built = build(local_hparams, True)
    # Report this worker's own log directory, not the global config's root.
    print(local_hparams.Dir.log_root)
    trainer = Trainer(**built, dataset=dataset, hparams=local_hparams)
    trainer.train()
# NOTE(review): fragment — `date` is assumed to already hold a timestamp string
# (e.g. str(datetime.datetime.now())) bound above this view; TODO confirm.
# Drop everything from the last ':' onward (the seconds field for such a
# timestamp), then remove '-' and ':' and turn ' ' into '_'.
date = date[:date.rfind(":")].replace("-", "")\
    .replace(":", "")\
    .replace(" ", "_")
# Each run logs into its own timestamped directory under Dir.log_root.
log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
print("log_dir:" + str(log_dir))
# An empty Infer.pre_trained path selects training; otherwise the
# generation branch below runs.
is_training = hparams.Infer.pre_trained == ""
data = dataset(hparams, is_training)
x_channels, cond_channels = data.n_channels()
# build graph
built = build(x_channels, cond_channels, hparams, is_training)
if is_training:
    # build trainer
    trainer = Trainer(**built, data=data, log_dir=log_dir, hparams=hparams)
    # train model
    trainer.train()
else:
    # Synthesize a lot of data.
    generator = Generator(data, built['data_device'], log_dir, hparams)
    # Sampling temperature: Infer.temperature if configured, else 1.
    # The generation code that uses `temp` continues past this view.
    if "temperature" in hparams.Infer:
        temp = hparams.Infer.temperature
    else:
        temp = 1
dataset = args["<dataset>"]
dataset_root = args["<dataset_root>"]
assert dataset in vision.Datasets, (
    "`{}` is not supported, use `{}`".format(dataset, vision.Datasets.keys()))
assert os.path.exists(dataset_root), (
    "Failed to find root dir `{}` of dataset.".format(dataset_root))
# `hparams` holds the path to the JSON config at this point (bound above this
# view); it is replaced by the parsed config right after the check.
assert os.path.exists(hparams), (
    "Failed to find hparams josn `{}`".format(hparams))
hparams = JsonConfig(hparams)
# Keep the registry key. `dataset` is rebound to the dataset class below, so
# any decision based on the dataset's name must use `dataset_name`.
dataset_name = dataset
dataset = vision.Datasets[dataset_name]

# set transform of dataset
print("-------dataset name: ", dataset)
if dataset_name == "rgb2nir":
    print("----------use rgb2nir preprocess ")
# Every dataset (rgb2nir included) currently uses the same preprocessing.
transform = transforms.Compose(
    [transforms.Resize(hparams.Data.resize), transforms.ToTensor()])

# build graph and dataset
built = build(hparams, True)
dataset = dataset(dataset_root, transform=transform)
# begin to train
trainer = Trainer(**built, dataset=dataset, hparams=hparams)
trainer.train()
# NOTE(review): fragment — the next two lines continue the body of an `if`
# whose header lies above this view (presumably "z_dir does not exist yet").
    os.makedirs(z_dir)
    generate_z = True
else:
    print("Load Z from {}".format(z_dir))
    generate_z = False
# Config path and dataset are hard-coded to CelebA here.
hparams = JsonConfig("hparams/celeba.json")
dataset = vision.Datasets["celeba"]
# set transform of dataset
transform = transforms.Compose([
    transforms.CenterCrop(hparams.Data.center_crop),
    transforms.Resize(hparams.Data.resize),
    transforms.ToTensor()
])
# build the graph in non-training mode and keep only the "graph" entry
graph = build(hparams, False)["graph"]
dataset = dataset(dataset_root, transform=transform)
# get Z: try to load one delta-Z array per class from z_dir
if not generate_z:
    # try to load
    try:
        delta_Z = []
        for i in range(hparams.Glow.y_classes):
            # "detla" is the spelling of the files on disk; keep it in sync
            # with whatever code writes them.
            z = np.load(os.path.join(z_dir, "detla_z_{}.npy".format(i)))
            delta_Z.append(z)
    except FileNotFoundError:
        # need to generate
        # NOTE(review): quit() below exits the process immediately, so the
        # generate_z = True assignment never takes effect. quit() also comes
        # from the `site` module (missing under `python -S`); sys.exit is the
        # portable form.
        generate_z = True
        print("Failed to load {} Z".format(hparams.Glow.y_classes))
        quit()
dataname = args["<dataset>"]
dataset_root = args["<dataset_root>"]
# Only these two modes are supported by the rest of the script.
mode = args["<mode>"]
assert mode in ["Generating", "Interpolation"]
#z_dir = args["<z_dir>"]
assert os.path.exists(dataset_root), (
    "Failed to find root dir `{}` of dataset.".format(dataset_root))
# `hparams` is still the JSON config path here (bound above this view); it is
# replaced by the parsed config a few lines below.
assert os.path.exists(hparams), (
    "Failed to find hparams josn `{}`".format(hparams))
# NOTE(review): output directory is hard-coded to MNIST regardless of
# `dataname`.
IMG_DIR = "pictures/mnist/"
if not os.path.exists(IMG_DIR):
    os.makedirs(IMG_DIR)
hparams = JsonConfig(hparams)
batch_size = hparams.Train.batch_size
# Build the model in non-training mode.
builded = build(hparams, False)
graph = builded["graph"]
# obtain current prior of each component in the mixture distribution
################# 1. do the generating work #####################
pk = builded["graph_prior"]
# Image name prefix: GenMM_K<k> for the naive mixture, LatMM_K<k> otherwise.
IMG_NAME = "GenMM_K{}".format(
    hparams.Mixture.num_component
) if hparams.Mixture.naive else "LatMM_K{}".format(
    hparams.Mixture.num_component)
if mode == "Generating":
    # for the_graph in graph:
    #     pk.append(the_graph['prior'])
    # `pk` presumably is a torch tensor (it has .numpy()) — confirm. Keep a
    # NumPy copy in `pknp`, and rebind `pk` to a plain tuple for printing.
    pknp = pk.numpy()
    pk = tuple(pk.numpy())
    print("The current model prior is: {}".format(pk))
# Timestamp to minute precision, e.g. "20240105_1234". This equals slicing
# str(datetime.now()) at its last ':' and stripping '-'/':'/' '.
date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
os.makedirs(log_dir, exist_ok=True)
print("log_dir:" + str(log_dir))
data = dataset(hparams)
x_channels, cond_channels = data.get_train_dataset().n_channels()
# An empty Infer.pre_trained path selects training; otherwise the graph is
# built in non-training mode and the generation branch below runs.
is_training = hparams.Infer.pre_trained == ""
# build graph
built = build(x_channels, cond_channels, hparams, is_training)
# build trainer
trainer = Trainer(**built, data=data, log_dir=log_dir, hparams=hparams)
if is_training:
    # train model
    trainer.train()
else:
    # generate from pre-trained model
    # Sampling temperature: Infer.temperature if configured, else 1.
    if "temperature" in hparams.Infer:
        temp = hparams.Infer.temperature
    else:
        temp = 1