def get_csgnet(config_path="config_synthetic.yml", terminals_path="terminals.txt"):
    """Build the CSGNet encoder and imitation network and load pretrained weights.

    Reads the experiment configuration from ``config_path`` and the grammar's
    terminal symbols (one per line) from ``terminals_path``. It then constructs
    the ``Encoder`` and ``ImitateJoint`` networks and moves them to ``device``.
    Every checkpoint entry at ``config.pretrain_modelpath`` whose key also
    exists in the network's state dict is copied into the network. Finally,
    all parameters of both networks are marked as requiring gradients.

    Args:
        config_path: Path of the YAML config passed to ``read_config.Config``.
            Defaults to the previously hard-coded ``"config_synthetic.yml"``.
        terminals_path: Path of the text file listing the grammar's terminal
            symbols, one per line. Defaults to the previously hard-coded
            ``"terminals.txt"``.

    Returns:
        A ``(encoder_net, imitate_net)`` tuple.
    """
    config = read_config.Config(config_path)

    # Encoder
    encoder_net = Encoder(config.encoder_drop)
    encoder_net = encoder_net.to(device)

    # Load the terminal symbols of the grammar. Strip only a trailing newline:
    # slicing off the last character unconditionally would also remove a real
    # symbol character on a final line that has no terminating newline.
    with open(terminals_path, "r") as file:
        unique_draw = [line.rstrip("\n") for line in file]

    imitate_net = ImitateJoint(
        hd_sz=config.hidden_size,
        input_size=config.input_size,
        encoder=encoder_net,
        mode=config.mode,
        num_draws=len(unique_draw),
        canvas_shape=config.canvas_shape,
    )
    imitate_net = imitate_net.to(device)

    print("pre loading model")
    # map_location remaps the checkpoint's tensors onto the current device.
    pretrained_dict = torch.load(config.pretrain_modelpath, map_location=device)
    imitate_net_dict = imitate_net.state_dict()
    # Keep only checkpoint entries whose keys the current network has. Keys the
    # checkpoint lacks keep the network's current values via the update below.
    imitate_pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(imitate_pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)

    for param in imitate_net.parameters():
        param.requires_grad = True
    for param in encoder_net.parameters():
        param.requires_grad = True
    return (encoder_net, imitate_net)
if config.preload_model: print("pre loading model") pretrained_dict = torch.load(config.pretrain_modelpath) imitate_net_dict = imitate_net.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in imitate_net_dict } imitate_net_dict.update(pretrained_dict) imitate_net.load_state_dict(imitate_net_dict) for param in imitate_net.parameters(): param.requires_grad = True for param in encoder_net.parameters(): param.requires_grad = True generator = Generator() reinforce = Reinforce(unique_draws=unique_draw) if config.optim == "sgd": optimizer = optim.SGD( [para for para in imitate_net.parameters() if para.requires_grad], weight_decay=config.weight_decay, momentum=0.9, lr=config.lr, nesterov=False) elif config.optim == "adam": optimizer = optim.Adam( [para for para in imitate_net.parameters() if para.requires_grad], weight_decay=config.weight_decay,