Beispiel #1
0
def get_blank_csgnet():
    """Build an ImitateJoint network without loading any checkpoint.

    Settings are read from ``config_synthetic.yml``. An ``Encoder`` is
    created, moved to the module-level ``device``, and passed to an
    ``ImitateJoint`` configured with ``num_draws=400``.

    Returns:
        The ``ImitateJoint`` instance after being moved to ``device``.
    """
    cfg = read_config.Config("config_synthetic.yml")

    # The encoder is placed on the device before the joint network wraps it.
    enc = Encoder(cfg.encoder_drop).to(device)

    net = ImitateJoint(
        hd_sz=cfg.hidden_size,
        input_size=cfg.input_size,
        encoder=enc,
        mode=cfg.mode,
        num_draws=400,
        canvas_shape=cfg.canvas_shape,
    )
    return net.to(device)
Beispiel #2
0
def get_csgnet():
    """Build an ImitateJoint network, load pretrained weights, return its encoder.

    Settings are read from ``config_synthetic.yml``. The checkpoint at
    ``config.pretrain_modelpath`` is loaded onto the module-level ``device``;
    only entries whose keys already exist in the network's state dict are
    applied.

    Returns:
        The ``encoder`` attribute of the loaded ``ImitateJoint`` network.
    """
    cfg = read_config.Config("config_synthetic.yml")

    # Encoder
    enc = Encoder(cfg.encoder_drop).to(device)

    net = ImitateJoint(
        hd_sz=cfg.hidden_size,
        input_size=cfg.input_size,
        encoder=enc,
        mode=cfg.mode,
        num_draws=400,
        canvas_shape=cfg.canvas_shape,
    ).to(device)

    print("pre loading model")
    checkpoint = torch.load(cfg.pretrain_modelpath, map_location=device)
    state = net.state_dict()
    # Overwrite only the entries the model already has; checkpoint keys that
    # the model does not know are ignored, and no new keys are added.
    for key, tensor in checkpoint.items():
        if key in state:
            state[key] = tensor
    net.load_state_dict(state)

    return net.encoder
Beispiel #3
0
def get_csgnet():
    """Build the encoder and ImitateJoint network and load pretrained weights.

    Settings are read from ``config_synthetic.yml``. The number of draws
    passed to ``ImitateJoint`` is the number of lines in ``terminals.txt``.
    The checkpoint at ``config.pretrain_modelpath`` is loaded onto the
    module-level ``device``; only entries whose keys already exist in the
    network's state dict are applied. Afterwards ``requires_grad`` is set to
    ``True`` on every parameter of both networks.

    Returns:
        A tuple ``(encoder_net, imitate_net)``.
    """
    config = read_config.Config("config_synthetic.yml")

    # Encoder
    encoder_net = Encoder(config.encoder_drop)
    encoder_net = encoder_net.to(device)

    # Load the terminal symbols of the grammar, one per line. Only the line
    # terminator is removed, so a last line without a trailing newline keeps
    # its final character (slicing with [0:-1] would have cut it off).
    with open("terminals.txt", "r") as file:
        unique_draw = [line.rstrip("\n") for line in file]

    imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                               input_size=config.input_size,
                               encoder=encoder_net,
                               mode=config.mode,
                               num_draws=len(unique_draw),
                               canvas_shape=config.canvas_shape)
    imitate_net = imitate_net.to(device)

    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath,
                                 map_location=device)
    imitate_net_dict = imitate_net.state_dict()
    # Keep only checkpoint entries the current model has a slot for; model
    # entries missing from the checkpoint keep their current values.
    imitate_pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(imitate_pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)

    # Make every parameter trainable.
    for param in imitate_net.parameters():
        param.requires_grad = True

    for param in encoder_net.parameters():
        param.requires_grad = True

    return (encoder_net, imitate_net)
Beispiel #4
0
logger.info(config.config)

# CNN encoder
encoder_net = Encoder(config.encoder_drop)
encoder_net.cuda()

# Load the terminal symbols of the grammar, one per line. Only the line
# terminator is removed, so a last line without a trailing newline keeps its
# final character (slicing with [0:-1] would have cut it off).
with open("terminals.txt", "r") as file:
    unique_draw = [line.rstrip("\n") for line in file]

# RNN decoder; the number of draws equals the number of terminal symbols.
imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                           input_size=config.input_size,
                           encoder=encoder_net,
                           mode=config.mode,
                           num_draws=len(unique_draw),
                           canvas_shape=config.canvas_shape)
imitate_net.cuda()
imitate_net.epsilon = config.eps

if config.preload_model:
    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath)
    imitate_net_dict = imitate_net.state_dict()
    # Keep only checkpoint entries the current model has a slot for; model
    # entries missing from the checkpoint keep their current values.
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)
Beispiel #5
0
    3: [25000, 50 * proportion],
    5: [100000, 500 * proportion],
    7: [150000, 500 * proportion],
    9: [250000, 500 * proportion],
    11: [350000, 1000 * proportion],
    13: [350000, 1000 * proportion]
}
#dataset_sizes = {k: [x // 100 for x in v] for k, v in dataset_sizes.items()}

# Data generator built from the label paths, batch size and canvas shape.
generator = MixedGenerateData(data_labels_paths=data_labels_paths,
                              batch_size=config.batch_size,
                              canvas_shape=config.canvas_shape)

# The number of draws is taken from the generator's own terminal list.
imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                           input_size=config.input_size,
                           encoder=encoder_net,
                           mode=config.mode,
                           num_draws=len(generator.unique_draw),
                           canvas_shape=config.canvas_shape)
imitate_net.cuda()

# Load the terminal symbols from disk, one per line. Only the line terminator
# is removed, so a last line without a trailing newline keeps its final
# character (slicing with [0:-1] would have cut it off).
with open("terminals.txt", "r") as file:
    unique_draw = [line.rstrip("\n") for line in file]

# if config.preload_model:
#     print("pre loading model")
#     pretrained_dict = torch.load(config.pretrain_modelpath)
#     imitate_net_dict = imitate_net.state_dict()
#     pretrained_dict = {
#         k: v
Beispiel #6
0
                plt.close("all")
            # with open("best_st_expressions.txt", "w") as file:
            #     for e in pred_expressions:
            #         file.write(f"{e}\n")
            # break

    return CDs / (config.test_size // config.batch_size)


# Build the network on the GPU from the synthetic configuration.
config = read_config.Config("config_synthetic.yml")
device = torch.device("cuda")
encoder_net = Encoder(config.encoder_drop)
encoder_net = encoder_net.to(device)
imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                           input_size=config.input_size,
                           encoder=encoder_net,
                           mode=config.mode,
                           num_draws=400,
                           canvas_shape=config.canvas_shape)
imitate_net = imitate_net.to(device)

try:
    pretrained_dict = torch.load("imitate_27.pth", map_location=device)
except Exception as e:
    # Best-effort load: report the failure and continue with an empty
    # checkpoint. Without this fallback the name would be unbound and the
    # filtering below would fail with a NameError that hides the real error.
    print(e)
    pretrained_dict = {}
imitate_net_dict = imitate_net.state_dict()
# Keep only checkpoint entries the current model has a slot for; with an
# empty checkpoint the model's current weights are left unchanged.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items() if k in imitate_net_dict
}
imitate_net_dict.update(pretrained_dict)
imitate_net.load_state_dict(imitate_net_dict)
Beispiel #7
0
    3: [30000, 50 * proportion],
    5: [110000, 500 * proportion],
    7: [170000, 500 * proportion],
    9: [270000, 500 * proportion],
    11: [370000, 1000 * proportion],
    13: [370000, 1000 * proportion]
}
# Shrink every size entry by a factor of 100 (integer floor division).
dataset_sizes = {
    key: [count // 100 for count in sizes]
    for key, sizes in dataset_sizes.items()
}

# Data generator built from the label paths, batch size and canvas shape.
generator = MixedGenerateData(
    data_labels_paths=data_labels_paths,
    batch_size=config.batch_size,
    canvas_shape=config.canvas_shape,
)

# The number of draws is taken from the generator's terminal list.
imitate_net = ImitateJoint(
    hd_sz=config.hidden_size,
    input_size=config.input_size,
    encoder=encoder_net,
    mode=config.mode,
    num_draws=len(generator.unique_draw),
    canvas_shape=config.canvas_shape,
)

imitate_net.cuda()
if config.preload_model:
    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath)
    imitate_net_dict = imitate_net.state_dict()
    # Drop checkpoint entries the current model has no slot for.
    pretrained_dict = {
        name: weights
        for name, weights in pretrained_dict.items()
        if name in imitate_net_dict
    }
    imitate_net_dict.update(pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)
Beispiel #8
0
# NOTE(review): this default-constructed encoder is replaced a few lines
# below and never used in between. It is kept because removing it could
# change random-number consumption during initialisation; confirm and delete.
encoder_net = Encoder()
encoder_net.cuda()

# CNN encoder
encoder_net = Encoder(config.encoder_drop)
encoder_net.cuda()

# Load the terminal symbols of the grammar, one per line. Only the line
# terminator is removed, so a last line without a trailing newline keeps its
# final character (slicing with [0:-1] would have cut it off).
with open("terminals.txt", "r") as file:
    unique_draw = [line.rstrip("\n") for line in file]

# The number of draws equals the number of terminal symbols.
imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                           input_size=config.input_size,
                           encoder=encoder_net,
                           mode=config.mode,
                           num_draws=len(unique_draw),
                           canvas_shape=config.canvas_shape)

imitate_net.cuda()
imitate_net.epsilon = 0

test_size = 3000
# This is to find top-1 performance.
paths = [config.pretrain_modelpath]
save_viz = False
for p in paths:
    print(p, flush=True)
    config.pretrain_modelpath = p

    image_path = "data/cad/predicted_images/{}/top_1_prediction/images/".format(
# Size pairs per integer key; the second entry of each pair scales with
# `proportion`. Presumably [train, test] counts per program length -- confirm.
dataset_sizes = {
    3: [30000, 50 * proportion],
    5: [110000, 500 * proportion],
    7: [170000, 500 * proportion],
    9: [270000, 500 * proportion],
    11: [370000, 1000 * proportion],
    13: [370000, 1000 * proportion],
}

# Data generator built from the label paths, batch size and canvas shape.
generator = MixedGenerateData(
    data_labels_paths=data_labels_paths,
    batch_size=config.batch_size,
    canvas_shape=config.canvas_shape,
)

# The number of draws is taken from the generator's terminal list.
imitate_net = ImitateJoint(
    hd_sz=config.hidden_size,
    input_size=config.input_size,
    encoder=encoder_net,
    mode=config.mode,
    num_draws=len(generator.unique_draw),
    canvas_shape=config.canvas_shape,
)

imitate_net.cuda()
if config.preload_model:
    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath)
    imitate_net_dict = imitate_net.state_dict()
    # Drop checkpoint entries the current model has no slot for.
    pretrained_dict = {
        name: weights
        for name, weights in pretrained_dict.items()
        if name in imitate_net_dict
    }
    imitate_net_dict.update(pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)