return


# Initialize the generator from a saved checkpoint and switch it to
# inference mode (eval() changes the behavior of layers such as dropout and
# batch-norm).
generator = Generator()
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; load_state_dict copies the weights into the freshly built
# (CPU) module, which is moved to the GPU below when one is available.
generator.load_state_dict(torch.load(opt.checkpoint, map_location='cpu'))
generator.eval()

# Select the device: use CUDA whenever it is available.
cuda = torch.cuda.is_available()
if cuda:
    generator.cuda()

# Configure data loader
# NOTE(review): absolute, machine-specific dataset path - consider making it
# configurable (e.g. via a command-line option).
rooms_path = '/local-scratch4/nnauata/autodesk/FloorplanDataset/'
fp_dataset = FloorplanGraphDataset(rooms_path,
                                   transforms.Normalize(mean=[0.5], std=[0.5]),
                                   target_set=opt.target_set, split='eval')
# shuffle=False keeps the evaluation order deterministic across runs.
fp_loader = torch.utils.data.DataLoader(fp_dataset,
                                        batch_size=opt.batch_size,
                                        shuffle=False,
                                        num_workers=0,
                                        collate_fn=floorplan_collate_fn)
# len(fp_loader) is the exact number of batches (ceil(len / batch_size),
# since drop_last defaults to False); `len(dataset) // batch_size + 1` would
# overcount by one whenever the dataset size divides evenly.
fp_iter = tqdm(fp_loader, total=len(fp_loader))

# Tensor constructor matching the selected device; presumably used to cast
# batch data inside the sampling loop - confirm against its uses.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Generate samples
graphs = []
for i, batch in enumerate(fp_iter):

    # Unpack batch
Exemplo n.º 2
0
# Make sure the experiment output folder exists (no error if it already does).
os.makedirs(opt.exp_folder, exist_ok=True)

# Initialize the generator from the saved checkpoint.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; the weights are copied into the freshly built (CPU)
# module, which is moved to the GPU below when one is available.
# NOTE(review): generator.eval() is not called here, so layers such as
# dropout/batch-norm stay in training mode - confirm this is intended.
generator = Generator()
generator.load_state_dict(torch.load(checkpoint, map_location='cpu'))

# Select the device: use CUDA whenever it is available.
cuda = torch.cuda.is_available()
if cuda:
    generator.cuda()
# NOTE(review): absolute, machine-specific dataset path.
rooms_path = '/local-scratch4/nnauata/autodesk/FloorplanDataset/'

# Initialize dataset iterator (`checkpoint`, `target_set` and `phase` are
# defined outside this fragment).
fp_dataset_test = FloorplanGraphDataset(rooms_path,
                                        transforms.Normalize(mean=[0.5],
                                                             std=[0.5]),
                                        target_set=target_set,
                                        split=phase)
# shuffle=False keeps the evaluation order deterministic across runs.
fp_loader = torch.utils.data.DataLoader(fp_dataset_test,
                                        batch_size=opt.batch_size,
                                        shuffle=False,
                                        collate_fn=floorplan_collate_fn)
# Tensor constructor matching the selected device; presumably used to cast
# batch data later in the script - confirm against its uses.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ------------
#  Vectorize
# ------------
# Presumably a running counter over generated samples - confirm against usage.
globalIndex = 0
# Accumulator for output images; presumably filled later in the script.
final_images = []
# Hard-coded target graph selector; the meaning of 6 is not visible here.
target_graph = [6]