def get_csgnet(num_draws=400):
    """Build a CSGNet and return its CNN encoder, initialized from the
    pretrained checkpoint named in ``config_synthetic.yml``.

    Constructs the encoder and the ImitateJoint (encoder + RNN decoder)
    model on ``device``, loads the checkpoint keeping only the keys that
    exist in the freshly built model, and returns the encoder sub-module.

    Args:
        num_draws: size of the decoder's output vocabulary. Defaults to
            400 (the value previously hard-coded here); pass the length
            of the terminals list to match a grammar file instead.

    Returns:
        The encoder sub-module of the weight-loaded ImitateJoint model.
    """
    config = read_config.Config("config_synthetic.yml")
    # CNN encoder
    encoder_net = Encoder(config.encoder_drop)
    encoder_net = encoder_net.to(device)
    # RNN decoder wrapping the encoder
    imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                               input_size=config.input_size,
                               encoder=encoder_net,
                               mode=config.mode,
                               num_draws=num_draws,
                               canvas_shape=config.canvas_shape)
    imitate_net = imitate_net.to(device)
    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath,
                                 map_location=device)
    imitate_net_dict = imitate_net.state_dict()
    # Keep only checkpoint entries whose names exist in the current model,
    # so a partially matching checkpoint still loads without error.
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)
    return imitate_net.encoder
def get_blank_csgnet():
    """Construct an untrained CSGNet (ImitateJoint) model on ``device``.

    No checkpoint is loaded: the returned model carries freshly
    initialized weights.
    """
    cfg = read_config.Config("config_synthetic.yml")
    # CNN feature extractor
    cnn = Encoder(cfg.encoder_drop).to(device)
    # RNN program decoder built on top of the encoder; 400 matches the
    # vocabulary size used by the pretrained synthetic models.
    model = ImitateJoint(hd_sz=cfg.hidden_size,
                         input_size=cfg.input_size,
                         encoder=cnn,
                         mode=cfg.mode,
                         num_draws=400,
                         canvas_shape=cfg.canvas_shape)
    return model.to(device)
def get_csgnet():
    """Build CSGNet with the vocabulary size taken from ``terminals.txt``.

    Loads the pretrained checkpoint (keeping only keys present in the
    freshly built model) and marks every parameter trainable.

    NOTE(review): another ``get_csgnet`` with a hard-coded vocabulary
    size also appears in this module; whichever is defined later shadows
    the other — confirm which one callers expect.

    Returns:
        Tuple ``(encoder_net, imitate_net)``, both moved to ``device``.
    """
    cfg = read_config.Config("config_synthetic.yml")

    # CNN encoder
    cnn = Encoder(cfg.encoder_drop).to(device)

    # Grammar terminal symbols; drop the trailing newline of each line.
    with open("terminals.txt", "r") as handle:
        unique_draw = [line[0:-1] for line in handle.readlines()]

    # RNN decoder sized to the terminal vocabulary.
    decoder = ImitateJoint(hd_sz=cfg.hidden_size,
                           input_size=cfg.input_size,
                           encoder=cnn,
                           mode=cfg.mode,
                           num_draws=len(unique_draw),
                           canvas_shape=cfg.canvas_shape).to(device)

    print("pre loading model")
    checkpoint = torch.load(cfg.pretrain_modelpath, map_location=device)
    current_state = decoder.state_dict()
    # Keep only checkpoint tensors whose names exist in the fresh model.
    current_state.update(
        {k: v for k, v in checkpoint.items() if k in current_state})
    decoder.load_state_dict(current_state)

    # Ensure everything is trainable for subsequent fine-tuning.
    for p in decoder.parameters():
        p.requires_grad = True
    for p in cnn.parameters():
        p.requires_grad = True

    return (cnn, decoder)
# Setup Tensorboard logger configure("log/tensorboard/{}".format(model_name), flush_secs=5) # Setup logger logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') file_handler = logging.FileHandler('log/logger/{}.log'.format(model_name), mode='w') file_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.info(config.config) # CNN encoder encoder_net = Encoder(config.encoder_drop) encoder_net.cuda() # Load the terminals symbols of the grammar with open("terminals.txt", "r") as file: unique_draw = file.readlines() for index, e in enumerate(unique_draw): unique_draw[index] = e[0:-1] # RNN decoder imitate_net = ImitateJoint(hd_sz=config.hidden_size, input_size=config.input_size, encoder=encoder_net, mode=config.mode, num_draws=len(unique_draw), canvas_shape=config.canvas_shape)
import os import numpy as np import torch from torch.autograd.variable import Variable from src.Models.models import Encoder from src.Models.models import ImitateJoint from src.Models.models import ParseModelOutput from src.utils import read_config from src.utils.generators.mixed_len_generator import MixedGenerateData from src.utils.train_utils import prepare_input_op, chamfer config = read_config.Config("config_synthetic.yml") model_name = config.pretrain_modelpath.split("/")[-1][0:-4] encoder_net = Encoder() encoder_net.cuda() data_labels_paths = { 3: "data/synthetic/one_op/expressions.txt", 5: "data/synthetic/two_ops/expressions.txt", 7: "data/synthetic/three_ops/expressions.txt", 9: "data/synthetic/four_ops/expressions.txt", 11: "data/synthetic/five_ops/expressions.txt", 13: "data/synthetic/six_ops/expressions.txt" } # first element of list is num of training examples, and second is number of # testing examples. proportion = config.proportion # proportion is in percentage. vary from [1, 100]. dataset_sizes = { 3: [30000, 50 * proportion],