nd_to_sample, ed_to_sample) fake_imgs_tensor = combine_images_maps(gen_mks, given_nds, given_eds, \ nd_to_sample, ed_to_sample) # Save images save_image(real_imgs_tensor, "./exps/{}/{}_real.png".format(exp_folder, batches_done), \ nrow=12, normalize=False) save_image(fake_imgs_tensor, "./exps/{}/{}_fake.png".format(exp_folder, batches_done), \ nrow=12, normalize=False) return # Configure data loader rooms_path = '/home/nelson/Workspace/autodesk/' fp_dataset_train = FloorplanGraphDataset(rooms_path, transforms.Normalize(mean=[0.5], std=[0.5]), target_set=opt.target_set) fp_loader = torch.utils.data.DataLoader(fp_dataset_train, batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu, collate_fn=floorplan_collate_fn) fp_dataset_test = FloorplanGraphDataset(rooms_path, transforms.Normalize(mean=[0.5], std=[0.5]), target_set=opt.target_set, split='eval') fp_loader_test = torch.utils.data.DataLoader(fp_dataset_test, batch_size=8, shuffle=True,
# Initialize generator and discriminator generator = Generator() generator.load_state_dict(torch.load(checkpoint), strict=False) generator = generator.eval() # Initialize variables cuda = True if torch.cuda.is_available() else False if cuda: generator.cuda() rooms_path = '../' # Initialize dataset iterator fp_dataset_test = FloorplanGraphDataset(rooms_path, transforms.Normalize(mean=[0.5], std=[0.5]), target_set=target_set, split=phase) fp_loader = torch.utils.data.DataLoader(fp_dataset_test, batch_size=opt.batch_size, shuffle=False, collate_fn=floorplan_collate_fn, num_workers=1) # Optimizers Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # Generate state def gen_state(curr_fixed_nodes_state, prev_fixed_nodes_state, sample, initial_state):
# Create output dir os.makedirs(opt.out, exist_ok=True) # Initialize generator and discriminator model = Generator() model.load_state_dict(torch.load(opt.checkpoint), strict=True) model = model.eval() # Initialize variables if torch.cuda.is_available(): model.cuda() # initialize dataset iterator fp_dataset_test = FloorplanGraphDataset(opt.data_path, transforms.Normalize(mean=[0.5], std=[0.5]), split='test') fp_loader = torch.utils.data.DataLoader(fp_dataset_test, batch_size=opt.batch_size, shuffle=False, collate_fn=floorplan_collate_fn) # optimizers Tensor = torch.cuda.FloatTensor if torch.cuda.is_available( ) else torch.FloatTensor # run inference def _infer(graph, model, prev_state=None): # configure input to the network z, given_masks_in, given_nds, given_eds = _init_input(graph, prev_state)