Beispiel #1
0
    def __init__(self, opts, dataset):
        """Gather constructor arguments from parsed options, then initialize.

        Args:
            opts: object whose attributes hold the parsed options; it is
                read through ``vars(opts)``.
            dataset: object providing ``num_speakers()``.
        """
        all_opts = vars(opts)

        # Items selected with the 'dec_' prefix; the speaker count is taken
        # from the dataset and added under 'n_speakers'.
        decoder_args = parse_tools.get_prefixed_items(all_opts, 'dec_')
        decoder_args['n_speakers'] = dataset.num_speakers()

        # Stored on the instance before _initialize() is invoked.
        self.init_args = dict(
            dec_params=decoder_args,
            mi_params=parse_tools.get_prefixed_items(all_opts, 'mi_'),
        )
        self._initialize()
Beispiel #2
0
    def __init__(self, opts, dataset):
        """Gather constructor arguments from options and dataset, then initialize.

        Args:
            opts: object whose attributes hold the parsed options; read
                through ``vars(opts)`` and ``opts.training``.
            dataset: object providing ``num_speakers()`` and
                ``num_mel_chan()``.
        """
        all_opts = vars(opts)

        # One sub-dictionary per option prefix.
        encoder_args = parse_tools.get_prefixed_items(all_opts, 'enc_')
        bottleneck_args = parse_tools.get_prefixed_items(all_opts, 'bn_')
        decoder_args = parse_tools.get_prefixed_items(all_opts, 'dec_')

        # The speaker count comes from the dataset, not from the options.
        decoder_args['n_speakers'] = dataset.num_speakers()

        # Stored on the instance before _initialize() is invoked.
        self.init_args = dict(
            enc_params=encoder_args,
            bn_params=bottleneck_args,
            dec_params=decoder_args,
            n_mel_chan=dataset.num_mel_chan(),
            training=opts.training,
        )
        self._initialize()
Beispiel #3
0
 def __init__(self, opts):
     """Gather batching and preprocessing arguments, then initialize.

     Args:
         opts: object whose attributes hold the parsed options. Reads
             ``n_batch``, ``n_win_batch`` and ``jitter_prob`` directly and
             the items selected with the 'pre_' prefix via ``vars(opts)``.
     """
     preproc = parse_tools.get_prefixed_items(vars(opts), 'pre_')

     # Batching settings are read straight from the options object.
     collected = {
         'batch_size': opts.n_batch,
         'window_batch_size': opts.n_win_batch,
         'jitter_prob': opts.jitter_prob,
     }

     # Preprocessing settings are copied over under their own key names.
     for key in ('sample_rate', 'mfcc_win_sz', 'mfcc_hop_sz', 'n_mels',
                 'n_mfcc'):
         collected[key] = preproc[key]

     self.init_args = collected
     self._initialize()
Beispiel #4
0
def main():
    """Command-line entry point: build or restore training state and train.

    The first command-line argument selects the mode:

    * ``new``    -- parse the full option set, construct the data source,
      model and Adam optimizer, and start from step 0.
    * ``resume`` -- parse the resume options and restore the state from
      ``opts.ckpt_file``.

    Any other (or missing) mode prints ``parse_tools.top_usage`` to stderr
    and returns. All progress messages are written to stderr. Training runs
    until ``state.step`` reaches ``opts.max_steps``; checkpoints are written
    to paths produced from ``opts.ckpt_template``.
    """
    if len(sys.argv) < 2 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    print('Command line: ', ' '.join(sys.argv), file=stderr)
    stderr.flush()

    # Remove the mode word so the mode-specific parser only sees its own
    # arguments.
    mode = sys.argv[1]
    del sys.argv[1]
    if mode == 'new':
        cold_parser = parse_tools.cold_parser()
        opts = parse_tools.two_stage_parse(cold_parser)
    else:
        # The guard above guarantees mode == 'resume' here.
        resume_parser = parse_tools.resume_parser()
        opts = resume_parser.parse_args()

    # Use CUDA only when it is both permitted and available.
    use_cuda = not opts.disable_cuda and torch.cuda.is_available()
    opts.device = torch.device('cuda' if use_cuda else 'cpu')
    print('Using GPU' if use_cuda else 'Using CPU', file=stderr)
    stderr.flush()

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Maps a step number to the learning rate that takes effect at that step.
    learning_rates = dict(
        zip(opts.learning_rate_steps, opts.learning_rate_rates))

    # Construct model
    if mode == 'new':
        # Split the flat option namespace into per-component dictionaries.
        all_opts = vars(opts)
        pre_params = parse_tools.get_prefixed_items(all_opts, 'pre_')
        enc_params = parse_tools.get_prefixed_items(all_opts, 'enc_')
        bn_params = parse_tools.get_prefixed_items(all_opts, 'bn_')
        dec_params = parse_tools.get_prefixed_items(all_opts, 'dec_')

        # Initialize data
        data_source = data.Slice(opts.index_file_prefix,
                                 opts.max_gpu_data_bytes, opts.n_batch)

        # The speaker count comes from the data, not from the options.
        dec_params['n_speakers'] = data_source.num_speakers()

        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params,
                               opts.n_sam_per_slice)
        # Requires the schedule to contain an entry for step 0.
        optim = torch.optim.Adam(params=model.parameters(),
                                 lr=learning_rates[0])
        state = checkpoint.State(0, model, data_source, optim)

    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        state.model.set_slice_size(opts.n_sam_per_slice)
        print('Restored model, data, and optim from {}'.format(opts.ckpt_file),
              file=stderr)
        stderr.flush()

    # Remembered so the checkpoint condition below can refer to the step the
    # run started from.
    start_step = state.step

    print('Model input size: {}'.format(state.model.input_size), file=stderr)
    stderr.flush()

    # set this to zero if you want to print out a logging header in resume
    # mode as well
    netmisc.set_print_iter(0)

    state.data.init_geometry(state.model.preprocess.rf, state.model)
    state.to(device=opts.device)

    # Metrics wraps the state; it supplies the loss closure handed to the
    # optimizer and the statistics reported below.
    metrics = ae.Metrics(state)
    batch_gen = state.data.batch_slice_gen_fn()

    # Start training
    print('Training parameters used:', file=stderr)
    pprint(opts, stderr)

    state.init_torch_generator()

    while state.step < opts.max_steps:
        # Apply a scheduled learning-rate change when its step is reached.
        if state.step in learning_rates:
            state.update_learning_rate(learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'

        # For the vqvae bottleneck, re-seed the embedding at a few fixed
        # early steps.
        if state.step in (1, 10, 50, 100, 300,
                          500) and state.model.bn_type == 'vqvae':
            print('Reinitializing embed with current distribution',
                  file=stderr)
            stderr.flush()
            state.model.init_vq_embed(batch_gen)

        metrics.update(batch_gen)
        loss = metrics.state.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        # NOTE(review): the result of avg_max() is not reported anywhere;
        # the call is kept only in case it has side effects -- confirm and
        # remove if it is pure.
        metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            current_stats = {
                'step': state.step,
                'loss': loss,
                'tprb_m': avg_prob_target,
                'pk_d_m': avg_peak_dist
            }
            if state.model.bn_type == 'vqvae':
                current_stats.update(state.model.objective.metrics)

            netmisc.print_metrics(current_stats, 1000000)
            stderr.flush()

        state.step += 1

        # Checkpointing: every save_interval steps, plus once right after
        # the very first step of a fresh run.
        if ((state.step % opts.save_interval == 0 and state.step != start_step)
                or (mode == 'new' and state.step == 1)):
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)
            stderr.flush()
Beispiel #5
0
    def __init__(self, mode, opts):
        """Build or restore the training state and wire up data loading.

        Args:
            mode: ``'new'`` constructs dataset, model and optimizer from
                ``opts``; any other value restores them from
                ``opts.ckpt_file`` / ``opts.dat_file``.
            opts: parsed options object. ``opts.hwtype == 'GPU'`` selects the
                CUDA path; any other value selects the torch_xla path.
        """
        print('Initializing model and data source...', end='', file=stderr)
        stderr.flush()

        # Maps a step number to the learning rate scheduled for that step.
        schedule = zip(opts.learning_rate_steps, opts.learning_rate_rates)
        self.learning_rates = dict(schedule)
        self.opts = opts

        if mode != 'new':
            # Restore everything from the checkpoint and data files.
            self.state = checkpoint.State()
            self.state.load(opts.ckpt_file, opts.dat_file)
            self.start_step = self.state.step
            stderr.flush()
        else:
            torch.manual_seed(opts.random_seed)

            # Split the flat option namespace into per-component dicts.
            opt_items = vars(opts)
            pre_par, enc_par, bn_par, dec_par = (
                parse_tools.get_prefixed_items(opt_items, prefix)
                for prefix in ('pre_', 'enc_', 'bn_', 'dec_'))

            # Initialize data. 'jitter_prob' is removed from the decoder
            # options and passed to the data source instead.
            jitter_prob = dec_par.pop('jitter_prob')
            preproc_keys = ('sample_rate', 'mfcc_win_sz', 'mfcc_hop_sz',
                            'n_mels', 'n_mfcc')
            dataset = data.Slice(opts.n_batch, opts.n_win_batch, jitter_prob,
                                 *[pre_par[key] for key in preproc_keys])
            dataset.load_data(opts.dat_file)

            dec_par['n_speakers'] = dataset.num_speakers()
            model = ae.AutoEncoder(pre_par, enc_par, bn_par, dec_par,
                                   dataset.num_mel_chan(), training=True)

            # Model and dataset each get a post-construction hook that
            # receives the other object.
            model.post_init(dataset)
            dataset.post_init(model)

            # Requires the schedule to contain an entry for step 0.
            optim = torch.optim.Adam(params=model.parameters(),
                                     lr=self.learning_rates[0])
            self.state = checkpoint.State(0, model, dataset, optim)
            self.start_step = self.state.step

        self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
        self.quant = None
        self.target = None
        # Softmax over dim 1; original author's note: input is (B, Q, N).
        self.softmax = torch.nn.Softmax(1)

        if self.opts.hwtype == 'GPU':
            def gpu_step():
                # Plain optimizer step with the loss closure.
                return self.state.optim.step(self.loss_fn)

            self.device = torch.device('cuda')
            self.data_loader = self.state.data_loader
            self.data_loader.set_target_device(self.device)
            self.optim_step_fn = gpu_step
            self.data_iter = GPULoaderIter(iter(self.data_loader))
        else:
            # Imported lazily so torch_xla is only needed on this path.
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl

            def xla_step():
                # Optimizer step routed through torch_xla, with the loss
                # closure passed as an optimizer argument.
                return xm.optimizer_step(
                    self.state.optim,
                    optimizer_args={'closure': self.loss_fn})

            self.device = xm.xla_device()
            self.data_loader = pl.ParallelLoader(self.state.data_loader,
                                                 [self.device])
            self.data_iter = TPULoaderIter(self.data_loader, self.device)
            self.optim_step_fn = xla_step

        self.state.init_torch_generator()
        print('Done.', file=stderr)
        stderr.flush()
Beispiel #6
0
def main():
    """Command-line entry point: build or restore model and data, then train.

    The first command-line argument selects the mode:

    * ``new``    -- parse the full option set, build the wav-slice data
      source and the autoencoder, and initialize the model weights.
    * ``resume`` -- parse the resume options and restore model and data
      from ``opts.ckpt_file``.

    Any other (or missing) mode prints ``parse_tools.top_usage`` to stderr
    and returns. Training runs until ``state.step`` reaches
    ``opts.max_steps``; progress lines and checkpoint notices go to stderr.
    """
    if len(sys.argv) < 2 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    # Remove the mode word so the mode-specific parser only sees its own
    # arguments.
    mode = sys.argv[1]
    del sys.argv[1]
    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold)
    else:
        # The guard above guarantees mode == 'resume' here.
        opts = parse_tools.resume.parse_args()

    # Use CUDA only when it is both permitted and available.
    use_cuda = not opts.disable_cuda and torch.cuda.is_available()
    opts.device = torch.device('cuda' if use_cuda else 'cpu')

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Construct model
    if mode == 'new':
        # Split the flat option namespace into per-component dictionaries.
        all_opts = vars(opts)
        pre_params = parse_tools.get_prefixed_items(all_opts, 'pre_')
        enc_params = parse_tools.get_prefixed_items(all_opts, 'enc_')
        bn_params = parse_tools.get_prefixed_items(all_opts, 'bn_')
        dec_params = parse_tools.get_prefixed_items(all_opts, 'dec_')

        # Initialize data
        sample_catalog = D.parse_sample_catalog(opts.sam_file)
        wav_data = D.WavSlices(sample_catalog, pre_params['sample_rate'],
                               opts.frac_permutation_use,
                               opts.requested_wav_buf_sz)
        # The speaker count comes from the data, not from the options.
        dec_params['n_speakers'] = wav_data.num_speakers()

        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params)
        print('Initializing model parameters', file=stderr)
        model.initialize_weights()

        # Construct overall state
        state = checkpoint.State(0, model, wav_data)

    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        print('Restored model and data from {}'.format(opts.ckpt_file),
              file=stderr)

    state.model.set_geometry(opts.n_sam_per_slice)

    state.data.set_geometry(opts.n_batch, state.model.input_size,
                            state.model.output_size)

    state.model.to(device=opts.device)

    # Module.parameters() returns a one-shot iterator. The optimizer is
    # rebuilt at every learning-rate change, so the parameters must be
    # materialized once; otherwise the second Adam(...) call would receive
    # an exhausted iterator and fail with an empty parameter list.
    model_params = list(state.model.parameters())
    metrics = ae.Metrics(state.model, None)
    batch_gen = state.data.batch_slice_gen_fn()

    # Start training
    print('Starting training...', file=stderr)
    print("Step\tLoss\tAvgProbTarget\tPeakDist\tAvgMax", file=stderr)
    stderr.flush()

    # Maps a step number to the learning rate that takes effect at that step.
    learning_rates = dict(
        zip(opts.learning_rate_steps, opts.learning_rate_rates))
    start_step = state.step

    # When starting between two schedule entries, use the rate of the
    # closest earlier entry. If start_step is itself a schedule entry, the
    # loop below creates the optimizer on its first iteration.
    if start_step not in learning_rates:
        ref_step = util.greatest_lower_bound(opts.learning_rate_steps,
                                             start_step)
        metrics.optim = torch.optim.Adam(params=model_params,
                                         lr=learning_rates[ref_step])

    while state.step < opts.max_steps:
        if state.step in learning_rates:
            # NOTE(review): building a fresh Adam instance also discards the
            # optimizer's accumulated state; confirm this is intended.
            metrics.optim = torch.optim.Adam(params=model_params,
                                             lr=learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        metrics.update(batch_gen)
        loss = metrics.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            fmt = "{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            print(fmt.format(state.step, loss, avg_prob_target, avg_peak_dist,
                             avg_max), file=stderr)
            stderr.flush()

        # Checkpointing: skip the step the run started from.
        if state.step % opts.save_interval == 0 and state.step != start_step:
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)

        state.step += 1