Code example #1 (score: 0)
    def __init__(self, override_hps, dat_file, ckpt_file):
        """
        Build an inference-mode model and data source from a checkpoint.

        override_hps: hyperparameter overrides applied on top of the
            checkpoint's stored hps
        dat_file: path to the preprocessed data file
        ckpt_file: path to the checkpoint to restore from
        """
        ckpt = t.load(ckpt_file)
        if 'hps' in ckpt:
            hps = hparams.Hyperparams(**ckpt['hps'])
        else:
            # Fix: older checkpoints store no hps; previously `hps` was
            # left unbound here and the next line raised NameError.
            hps = hparams.Hyperparams()
        hps.update(override_hps)

        if hps.global_model == 'autoencoder':
            self.model = ae.AutoEncoder(hps)
        elif hps.global_model == 'mfcc_inverter':
            self.model = mi.MfccInverter(hps)
        else:
            # Fix: previously fell through, leaving self.model unbound.
            raise ValueError(
                'Unknown global_model: {}'.format(hps.global_model))

        # Drop bookkeeping entries whose sizes depend on the
        # training-time window configuration — presumably they are
        # rebuilt by override() below (TODO confirm).
        sub_state = {
            k: v
            for k, v in ckpt['model_state_dict'].items()
            if '_lead' not in k and 'left_wing_size' not in k
        }
        self.model.load_state_dict(sub_state, strict=False)
        # Inference processes one window per batch.
        self.model.override(n_win_batch=1)

        self.data = data.DataProcessor(hps,
                                       dat_file,
                                       self.model.mfcc,
                                       slice_size=None,
                                       train_mode=False)

        # Assigned later by whoever moves the model to a device.
        self.device = None
Code example #2 (score: 0)
File: chassis.py — Project: entn-at/ae-wavenet
    def __init__(self, mode, opts):
        """
        Build the training chassis: model, data source, optimizer and
        device-specific step/iterator functions.

        mode: 'new' to initialize from scratch; any other value restores
            state from opts.ckpt_file
        opts: parsed options; must provide learning_rate_steps/rates,
            global_model, hwtype, ckpt_template, dat_file, etc.
        """
        print('Initializing model and data source...', end='', file=stderr)
        stderr.flush()
        self.learning_rates = dict(
            zip(opts.learning_rate_steps, opts.learning_rate_rates))
        self.opts = opts

        if mode == 'new':
            torch.manual_seed(opts.random_seed)

            # Initialize data
            dataset = data.Slice(opts)
            dataset.load_data(opts.dat_file)
            opts.training = True
            if opts.global_model == 'autoencoder':
                model = ae.AutoEncoder(opts, dataset)
            elif opts.global_model == 'mfcc_inverter':
                model = mi.MfccInverter(opts, dataset)
            else:
                # Fix: previously fell through and raised NameError on
                # `model.post_init` below for an unrecognized value.
                raise ValueError(
                    'Unknown global_model: {}'.format(opts.global_model))

            # Model and dataset hold references to each other; both
            # post_init calls must happen before building the optimizer.
            model.post_init(dataset)
            dataset.post_init(model)
            # NOTE(review): assumes step 0 appears in
            # opts.learning_rate_steps — confirm against CLI defaults.
            optim = torch.optim.Adam(params=model.parameters(),
                                     lr=self.learning_rates[0])
            self.state = checkpoint.State(0, model, dataset, optim)
            self.start_step = self.state.step

        else:
            self.state = checkpoint.State()
            self.state.load(opts.ckpt_file, opts.dat_file)
            self.start_step = self.state.step
            stderr.flush()

        if self.state.model.bn_type == 'vae':
            self.anneal_schedule = dict(
                zip(opts.bn_anneal_weight_steps, opts.bn_anneal_weight_vals))

        self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
        self.quant = None
        self.target = None
        self.softmax = torch.nn.Softmax(1)  # input to this is (B, Q, N)

        if self.opts.hwtype == 'GPU':
            self.device = torch.device('cuda')
            self.data_loader = self.state.data_loader
            self.data_loader.set_target_device(self.device)
            self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn))
            self.data_iter = GPULoaderIter(iter(self.data_loader))
        else:
            # TPU path; imports deferred because torch_xla is only
            # installed on TPU hosts.
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            self.device = xm.xla_device()
            self.data_loader = pl.ParallelLoader(self.state.data_loader,
                                                 [self.device])
            self.data_iter = TPULoaderIter(self.data_loader, self.device)
            self.optim_step_fn = (lambda: xm.optimizer_step(
                self.state.optim, optimizer_args={'closure': self.loss_fn}))

        self.state.init_torch_generator()
        print('Done.', file=stderr)
        stderr.flush()
Code example #3 (score: 0)
    def __init__(self,
                 override_hps,
                 dat_file,
                 train_mode=True,
                 ckpt_file=None,
                 num_replicas=1,
                 rank=0):
        """
        Initialize total state: model, data processor, optimizer and RNG
        bookkeeping, either fresh or restored from a checkpoint.

        override_hps: hyperparameter overrides applied last, on top of
            either the checkpoint's stored hps or the defaults
        dat_file: path to the preprocessed data file
        train_mode: passed through to the DataProcessor
        ckpt_file: optional checkpoint to restore model/optim/data position
        num_replicas, rank: distributed-training shard configuration
        """
        if ckpt_file is not None:
            ckpt = t.load(ckpt_file)
            if 'hps' in ckpt:
                hps = hparams.Hyperparams(**ckpt['hps'])
            else:
                # Fix: older checkpoints store no hps; previously `hps`
                # was left unbound here and hps.update raised NameError.
                hps = hparams.Hyperparams()
        else:
            hps = hparams.Hyperparams()
        hps.update(override_hps)

        t.manual_seed(hps.random_seed)

        if hps.global_model == 'autoencoder':
            self.model = ae.AutoEncoder(hps)
        elif hps.global_model == 'mfcc_inverter':
            self.model = mi.MfccInverter(hps)
        else:
            # Fix: previously fell through, leaving self.model unbound.
            raise ValueError(
                'Unknown global_model: {}'.format(hps.global_model))

        slice_size = self.model.get_input_size(hps.n_win_batch)

        self.data = data.DataProcessor(hps,
                                       dat_file,
                                       self.model.mfcc,
                                       slice_size,
                                       train_mode,
                                       start_epoch=0,
                                       start_step=0,
                                       num_replicas=num_replicas,
                                       rank=rank)

        self.model.override(hps.n_win_batch)

        if ckpt_file is None:
            self.optim = t.optim.Adam(params=self.model.parameters(),
                                      lr=hps.learning_rate_rates[0])
            self.optim_step = 0

        else:
            # Drop bookkeeping entries whose sizes depend on the
            # checkpoint's window configuration — presumably rebuilt by
            # override() above (TODO confirm).
            sub_state = {
                k: v
                for k, v in ckpt['model_state_dict'].items()
                if '_lead' not in k and 'left_wing_size' not in k
            }
            self.model.load_state_dict(sub_state, strict=False)
            if 'epoch' in ckpt:
                self.data.dataset.set_pos(ckpt['epoch'], ckpt['step'])
            else:
                # Legacy checkpoints store a single global step; recover
                # (epoch, step) from the dataset length.
                global_step = ckpt['step']
                epoch = global_step // len(self.data.dataset)
                step = global_step % len(self.data.dataset)
                self.data.dataset.set_pos(epoch, step)

            self.optim = t.optim.Adam(self.model.parameters())
            self.optim.load_state_dict(ckpt['optim'])
            self.optim_step = ckpt['optim_step']

        # Assigned later by whoever moves the model to a device.
        self.device = None
        # Snapshot RNG states so they can be saved with checkpoints.
        self.torch_rng_state = t.get_rng_state()
        if t.cuda.is_available():
            self.torch_cuda_rng_states = t.cuda.get_rng_state_all()
        else:
            self.torch_cuda_rng_states = None

        self.hps = hps