def __init__(self, override_hps, dat_file, ckpt_file):
    ckpt = t.load(ckpt_file)
    if 'hps' in ckpt:
        hps = hparams.Hyperparams(**ckpt['hps'])
    else:
        # Fall back to defaults when the checkpoint predates saved hps
        hps = hparams.Hyperparams()
    hps.update(override_hps)

    if hps.global_model == 'autoencoder':
        self.model = ae.AutoEncoder(hps)
    elif hps.global_model == 'mfcc_inverter':
        self.model = mi.MfccInverter(hps)

    # Exclude keys tied to the training-time window configuration,
    # then load the remaining parameters non-strictly
    sub_state = {
        k: v for k, v in ckpt['model_state_dict'].items()
        if '_lead' not in k and 'left_wing_size' not in k
    }
    self.model.load_state_dict(sub_state, strict=False)
    # Force single-window batches for inference
    self.model.override(n_win_batch=1)

    self.data = data.DataProcessor(hps, dat_file, self.model.mfcc,
                                   slice_size=None, train_mode=False)
    self.device = None
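
# A self-contained sketch (toy modules, not the project's models) of the
# strict=False pattern used above: keys filtered out of the checkpoint
# simply appear in missing_keys instead of raising an error.
def _demo_partial_load():
    import torch
    src = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    dst = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    # Drop all parameters of the second layer, mirroring the key
    # filtering in the constructor above
    sub = {k: v for k, v in src.state_dict().items()
           if not k.startswith('1.')}
    result = dst.load_state_dict(sub, strict=False)
    assert all(k.startswith('1.') for k in result.missing_keys)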
def __init__(self, mode, opts):
    print('Initializing model and data source...', end='', file=stderr)
    stderr.flush()
    self.learning_rates = dict(
        zip(opts.learning_rate_steps, opts.learning_rate_rates))
    self.opts = opts

    if mode == 'new':
        torch.manual_seed(opts.random_seed)

        # Initialize data
        dataset = data.Slice(opts)
        dataset.load_data(opts.dat_file)
        opts.training = True
        if opts.global_model == 'autoencoder':
            model = ae.AutoEncoder(opts, dataset)
        elif opts.global_model == 'mfcc_inverter':
            model = mi.MfccInverter(opts, dataset)
        model.post_init(dataset)
        dataset.post_init(model)
        optim = torch.optim.Adam(params=model.parameters(),
                                 lr=self.learning_rates[0])
        self.state = checkpoint.State(0, model, dataset, optim)
        self.start_step = self.state.step
    else:
        self.state = checkpoint.State()
        self.state.load(opts.ckpt_file, opts.dat_file)
        self.start_step = self.state.step
        # print('Restored model, data, and optim from {}'.format(opts.ckpt_file), file=stderr)
        # print('Data state: {}'.format(state.data), file=stderr)
        # print('Model state: {}'.format(state.model.checksum()))
        # print('Optim state: {}'.format(state.optim_checksum()))
        stderr.flush()

    if self.state.model.bn_type == 'vae':
        self.anneal_schedule = dict(
            zip(opts.bn_anneal_weight_steps, opts.bn_anneal_weight_vals))

    self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
    self.quant = None
    self.target = None
    self.softmax = torch.nn.Softmax(1)  # input to this is (B, Q, N)

    if self.opts.hwtype == 'GPU':
        self.device = torch.device('cuda')
        self.data_loader = self.state.data_loader
        self.data_loader.set_target_device(self.device)
        self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn))
        self.data_iter = GPULoaderIter(iter(self.data_loader))
    else:
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl
        self.device = xm.xla_device()
        self.data_loader = pl.ParallelLoader(self.state.data_loader,
                                             [self.device])
        self.data_iter = TPULoaderIter(self.data_loader, self.device)
        self.optim_step_fn = (lambda: xm.optimizer_step(
            self.state.optim, optimizer_args={'closure': self.loss_fn}))

    self.state.init_torch_generator()
    print('Done.', file=stderr)
    stderr.flush()
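
# The GPU branch above hands self.loss_fn to Adam's closure form of
# step(); the TPU branch passes the same closure through
# xm.optimizer_step. A minimal sketch of that closure contract (toy
# model and data, not the project's loss_fn): the closure must zero
# the gradients, recompute the loss, backpropagate, and return the loss.
def _demo_closure_step():
    import torch
    model = torch.nn.Linear(4, 1)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    def loss_fn():
        optim.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    optim.step(loss_fn)  # analogous to self.optim_step_fn() above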
def __init__(self, override_hps, dat_file, train_mode=True, ckpt_file=None,
             num_replicas=1, rank=0):
    """ Initialize total state """
    if ckpt_file is not None:
        ckpt = t.load(ckpt_file)
        if 'hps' in ckpt:
            hps = hparams.Hyperparams(**ckpt['hps'])
        else:
            hps = hparams.Hyperparams()
    else:
        # No checkpoint: start from default hyperparameters
        hps = hparams.Hyperparams()
    hps.update(override_hps)
    t.manual_seed(hps.random_seed)

    if hps.global_model == 'autoencoder':
        self.model = ae.AutoEncoder(hps)
    elif hps.global_model == 'mfcc_inverter':
        self.model = mi.MfccInverter(hps)

    slice_size = self.model.get_input_size(hps.n_win_batch)
    self.data = data.DataProcessor(hps, dat_file, self.model.mfcc,
                                   slice_size, train_mode, start_epoch=0,
                                   start_step=0, num_replicas=num_replicas,
                                   rank=rank)
    self.model.override(hps.n_win_batch)

    if ckpt_file is None:
        self.optim = t.optim.Adam(params=self.model.parameters(),
                                  lr=hps.learning_rate_rates[0])
        self.optim_step = 0
    else:
        # Exclude keys tied to the training-time window configuration,
        # then load the remaining parameters non-strictly
        sub_state = {
            k: v for k, v in ckpt['model_state_dict'].items()
            if '_lead' not in k and 'left_wing_size' not in k
        }
        self.model.load_state_dict(sub_state, strict=False)
        if 'epoch' in ckpt:
            self.data.dataset.set_pos(ckpt['epoch'], ckpt['step'])
        else:
            # Legacy checkpoints store a single global step; recover
            # (epoch, step-within-epoch) from the dataset length
            global_step = ckpt['step']
            epoch = global_step // len(self.data.dataset)
            step = global_step % len(self.data.dataset)
            self.data.dataset.set_pos(epoch, step)
        self.optim = t.optim.Adam(self.model.parameters())
        self.optim.load_state_dict(ckpt['optim'])
        self.optim_step = ckpt['optim_step']
        # self.torch_rng_state = ckpt['rand_state']
        # self.torch_cuda_rng_states = ckpt['cuda_rand_states']

    self.device = None
    self.torch_rng_state = t.get_rng_state()
    if t.cuda.is_available():
        self.torch_cuda_rng_states = t.cuda.get_rng_state_all()
    else:
        self.torch_cuda_rng_states = None
    self.hps = hps
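
# A standalone restatement of the legacy-checkpoint arithmetic above
# (the function name is illustrative only): a single global step is
# split into (epoch, step-within-epoch) by floor division and modulo.
def _recover_position(global_step, dataset_len):
    epoch = global_step // dataset_len
    step = global_step % dataset_len
    return epoch, step

# e.g. with a 1000-element dataset, global step 2503 resumes at
# epoch 2, step 503: _recover_position(2503, 1000) == (2, 503)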