def __init__(self, model=None, dataset=None):
    """Wrap a model and its dataset, and snapshot the current torch RNG state.

    Args:
        model: model object to hold; may be None and filled in later.
        dataset: dataset handed straight to ``data.WavLoader``.

    Sets ``self.device`` to None. Records the CPU generator state in
    ``self.torch_rng_state``. Records the states of all CUDA devices in
    ``self.torch_cuda_rng_states``, or None when CUDA is not available.
    """
    self.model = model
    self.data_loader = data.WavLoader(dataset)
    self.device = None

    # Snapshot RNG state only after the loader exists, so anything its
    # construction does to the generators is reflected in the snapshot.
    self.torch_rng_state = t.get_rng_state()
    self.torch_cuda_rng_states = (
        t.cuda.get_rng_state_all() if t.cuda.is_available() else None
    )
def __init__(self, step=0, model=None, dataset=None, optim=None):
    """Bundle the model, data loader, optimizer, step and RNG state.

    Args:
        step: step counter to store (defaults to 0).
        model: model object to hold; may be None and filled in later.
        dataset: dataset handed straight to ``data.WavLoader``.
        optim: optimizer object, stored as-is.

    Sets ``self.device`` to None. Snapshots the CPU torch generator state.
    Also snapshots the per-device CUDA generator states, or None when CUDA is
    not available.
    """
    self.model = model
    self.data_loader = data.WavLoader(dataset)
    self.optim = optim
    self.step = step
    self.device = None

    # Capture generator state last, after all the objects above are built.
    self.torch_rng_state = torch.get_rng_state()
    have_cuda = torch.cuda.is_available()
    self.torch_cuda_rng_states = (
        torch.cuda.get_rng_state_all() if have_cuda else None
    )
def load(self, ckpt_file, dat_file, map_location=None):
    """Restore the model from a checkpoint and build an inference data loader.

    Args:
        ckpt_file: checkpoint path or file object accepted by ``torch.load``.
        dat_file: data file passed to ``dataset.load_data``.
        map_location: forwarded to ``torch.load``. Pass ``'cpu'`` to open a
            checkpoint written on a GPU host on a CPU-only machine. The
            default (None) keeps torch's default device mapping, which matches
            the previous behaviour.
    """
    # NOTE(security): the checkpoint carries pickled model/dataset objects that
    # are unpickled below, which can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    sinfo = torch.load(ckpt_file, map_location=map_location)

    # This is the required order for model and data init:
    # unpickle model -> build dataset and load data -> model.post_init
    # -> load weights -> dataset.post_init -> wrap in loader.
    self.model = pickle.loads(sinfo['model'])
    # Ignore the pickled dataset characteristics: the unpickled object is
    # wrapped in data.MfccInference for inference use.
    dataset = data.MfccInference(pickle.loads(sinfo['dataset']))
    dataset.load_data(dat_file)
    self.model.post_init(dataset)
    self.model.load_state_dict(sinfo['model_state_dict'])
    dataset.post_init(self.model)
    self.data_loader = data.WavLoader(dataset)
def load(self, ckpt_file, dat_file, map_location=None):
    """Resume training state from a checkpoint.

    Restores the model, dataset/loader, Adam optimizer state, step counter and
    the saved torch RNG states.

    Args:
        ckpt_file: checkpoint path or file object accepted by ``torch.load``.
        dat_file: data file passed to ``dataset.load_data``.
        map_location: forwarded to ``torch.load``. ``'cpu'`` allows resuming a
            GPU-written checkpoint on a CPU-only machine. The default (None)
            keeps the previous behaviour. If the restored RNG states are later
            fed to ``torch.set_rng_state`` / ``torch.cuda.set_rng_state_all``,
            they must be CPU tensors, so map to ``'cpu'`` rather than to a
            CUDA device.
    """
    # NOTE(security): the checkpoint carries pickled model/dataset objects that
    # are unpickled below, which can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    sinfo = torch.load(ckpt_file, map_location=map_location)

    # This is the required order for model and data init:
    # unpickle model -> load data -> model.post_init -> load weights
    # -> dataset.post_init -> wrap in loader.
    self.model = pickle.loads(sinfo['model'])
    dataset = pickle.loads(sinfo['dataset'])
    dataset.load_data(dat_file)
    self.model.post_init(dataset)
    self.model.load_state_dict(sinfo['model_state_dict'])
    dataset.post_init(self.model)
    self.data_loader = data.WavLoader(dataset)

    # The optimizer must be built over the final model parameters before its
    # saved state can be loaded into it.
    self.optim = torch.optim.Adam(self.model.parameters())
    self.optim.load_state_dict(sinfo['optim'])
    self.step = sinfo['step']

    # RNG states are stored on the object only; this method does not apply them.
    self.torch_rng_state = sinfo['rand_state']
    self.torch_cuda_rng_states = sinfo['cuda_rand_states']
def load(self, ckpt_file, dat_file, map_location=None):
    """Restore a model in inference mode and build an inference data loader.

    Args:
        ckpt_file: checkpoint path or file object accepted by ``t.load``.
        dat_file: data file passed to the ``data.MfccInference`` constructor.
        map_location: forwarded to ``t.load``. Pass ``'cpu'`` to open a
            checkpoint written on a GPU host on a CPU-only machine. The
            default (None) matches the previous behaviour.
    """
    # NOTE(security): the checkpoint carries pickled model/dataset objects that
    # are unpickled below, which can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    ckpt = t.load(ckpt_file, map_location=map_location)

    # This is the required order for model and data init
    self.model = pickle.loads(ckpt['model'])
    # win batch of 1 is inference mode
    self.model.override(n_win_batch=1)
    # Ignore the pickled dataset characteristics. The unpickled object and the
    # data file go straight to the inference dataset's constructor.
    dataset = data.MfccInference(pickle.loads(ckpt['dataset']), dat_file)
    self.model.post_init(dataset)

    # Skip checkpoint entries whose key contains any of these substrings.
    # NOTE(review): presumably these entries depend on the window batch size
    # overridden above and must be recomputed rather than restored. Confirm.
    skipped_substrings = ('_lead', 'left_wing_size')
    sub_state = {
        k: v
        for k, v in ckpt['model_state_dict'].items()
        if not any(s in k for s in skipped_substrings)
    }
    # strict=False: keys missing from, or unexpected in, sub_state are ignored.
    self.model.load_state_dict(sub_state, strict=False)
    dataset.post_init(self.model)
    self.data_loader = data.WavLoader(dataset)
def __init__(self, model=None, dataset=None):
    """Hold a model and a ``data.WavLoader`` built over ``dataset``.

    Args:
        model: model object to hold; may be None and filled in later.
        dataset: dataset handed straight to ``data.WavLoader``.

    ``self.device`` starts out as None.
    """
    self.model = model
    loader = data.WavLoader(dataset)
    self.data_loader = loader
    self.device = None