def get_csgnet_encoder(config_path="config_synthetic.yml", num_draws=400):
    """Build CSGNet, load its pretrained weights, and return only the encoder.

    NOTE(review): this function used to be named ``get_csgnet``. A later
    ``def get_csgnet`` in this module redefined that name, so this version could
    never be called. It has been renamed so it is reachable. The other
    ``get_csgnet`` is unaffected.

    Args:
        config_path: Config file read via ``read_config.Config``. Its
            ``pretrain_modelpath`` names the checkpoint to load.
        num_draws: Value passed as ``num_draws`` to ``ImitateJoint``. The default
            of 400 matches the value originally hard-coded here.

    Returns:
        ``imitate_net.encoder`` of the loaded model (presumably the ``Encoder``
        passed in as ``encoder=``; confirm against ``ImitateJoint``).
    """
    config = read_config.Config(config_path)

    # Encoder
    encoder_net = Encoder(config.encoder_drop).to(device)
    imitate_net = ImitateJoint(
        hd_sz=config.hidden_size,
        input_size=config.input_size,
        encoder=encoder_net,
        mode=config.mode,
        num_draws=num_draws,
        canvas_shape=config.canvas_shape,
    ).to(device)

    print("pre loading model")
    # torch.load unpickles the checkpoint, so only point pretrain_modelpath at
    # trusted files.
    pretrained_dict = torch.load(config.pretrain_modelpath, map_location=device)

    # Partial load: keep only checkpoint entries whose names exist in this
    # model. Extra keys are dropped; absent ones keep their fresh initialisation.
    imitate_net_dict = imitate_net.state_dict()
    imitate_net_dict.update(
        {k: v for k, v in pretrained_dict.items() if k in imitate_net_dict}
    )
    imitate_net.load_state_dict(imitate_net_dict)
    return imitate_net.encoder
def get_blank_csgnet():
    """Return a freshly initialised CSGNet (``ImitateJoint``) on ``device``.

    No checkpoint is loaded. The hyper-parameters come from
    ``config_synthetic.yml``, and ``num_draws`` is fixed at 400.
    """
    cfg = read_config.Config("config_synthetic.yml")

    # The encoder is moved to the device before being wrapped by the model.
    encoder = Encoder(cfg.encoder_drop).to(device)
    model = ImitateJoint(
        hd_sz=cfg.hidden_size,
        input_size=cfg.input_size,
        encoder=encoder,
        mode=cfg.mode,
        num_draws=400,
        canvas_shape=cfg.canvas_shape,
    )
    return model.to(device)
def get_csgnet(config_path="config_synthetic.yml", terminals_path="terminals.txt"):
    """Build CSGNet, load pretrained weights, and make every parameter trainable.

    The number of draw actions is the number of lines in the terminal-symbol
    file, rather than a hard-coded constant.

    Args:
        config_path: Config file read via ``read_config.Config``. Its
            ``pretrain_modelpath`` names the checkpoint to load.
        terminals_path: Text file listing the grammar's terminal symbols, one
            per line.

    Returns:
        ``(encoder_net, imitate_net)``: the encoder and the full model built
        around it, both moved to the module-level ``device``.
    """
    config = read_config.Config(config_path)

    # Encoder
    encoder_net = Encoder(config.encoder_drop).to(device)

    # Load the terminal symbols of the grammar. rstrip("\n") is used rather
    # than slicing off the last character, because slicing would truncate the
    # final terminal when the file does not end with a newline.
    with open(terminals_path, "r", encoding="utf-8") as file:
        unique_draw = [line.rstrip("\n") for line in file]

    imitate_net = ImitateJoint(
        hd_sz=config.hidden_size,
        input_size=config.input_size,
        encoder=encoder_net,
        mode=config.mode,
        num_draws=len(unique_draw),
        canvas_shape=config.canvas_shape,
    ).to(device)

    print("pre loading model")
    # torch.load unpickles the checkpoint, so only point pretrain_modelpath at
    # trusted files.
    pretrained_dict = torch.load(config.pretrain_modelpath, map_location=device)

    # Partial load: keep only checkpoint entries whose names exist in this
    # model. Extra keys are dropped; absent ones keep their fresh initialisation.
    imitate_net_dict = imitate_net.state_dict()
    imitate_pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(imitate_pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)

    # Unfreeze everything for fine-tuning. The encoder is also handled
    # explicitly in case ImitateJoint does not register it as a submodule
    # (NOTE(review): confirm; if it does, this second call is a no-op).
    imitate_net.requires_grad_(True)
    encoder_net.requires_grad_(True)

    return encoder_net, imitate_net
plt.savefig("best_lest/" + "{}.png".format(batch_idx * config.batch_size + j), transparent=0) plt.close("all") # with open("best_st_expressions.txt", "w") as file: # for e in pred_expressions: # file.write(f"{e}\n") # break return CDs / (config.test_size // config.batch_size) config = read_config.Config("config_synthetic.yml") device = torch.device("cuda") encoder_net = Encoder(config.encoder_drop) encoder_net = encoder_net.to(device) imitate_net = ImitateJoint(hd_sz=config.hidden_size, input_size=config.input_size, encoder=encoder_net, mode=config.mode, num_draws=400, canvas_shape=config.canvas_shape) imitate_net = imitate_net.to(device) try: pretrained_dict = torch.load("imitate_27.pth", map_location=device) except Exception as e: print(e) imitate_net_dict = imitate_net.state_dict() pretrained_dict = { k: v