Example #1
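A training-chassis constructor that resolves the requested TPU/GPU/CPU hardware, restores a Checkpoint, builds a per-device data loader, and opens a TensorBoard SummaryWriter on the master replica only.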
    def __init__(self, device, index, hps, dat_file):
        self.is_tpu = (hps.hw in ('TPU', 'TPU-single'))
        if self.is_tpu:
            num_replicas = xm.xrt_world_size()
            rank = xm.get_ordinal()
        elif hps.hw in ('GPU', 'CPU'):
            if hps.hw == 'GPU' and not t.cuda.is_available():
                raise RuntimeError('GPU requested but not available')
            num_replicas = 1
            rank = 0
        else:
            raise ValueError(f'Chassis: Invalid device "{hps.hw}" requested')

        self.replica_index = index
        # True for the single GPU/CPU process, or for the master TPU replica.
        self.is_master = not self.is_tpu or xm.is_master_ordinal()

        self.state = ckpt.Checkpoint(hps,
                                     dat_file,
                                     train_mode=True,
                                     ckpt_file=hps.get('ckpt_file', None),
                                     num_replicas=num_replicas,
                                     rank=rank)

        hps = self.state.hps
        if self.is_master:
            print('Hyperparameters:\n', file=stderr)
            print('\n'.join(f'{k} = {v}' for k, v in hps.items()), file=stderr)

        self.learning_rates = dict(
            zip(hps.learning_rate_steps, hps.learning_rate_rates))

        if self.state.model.bn_type == 'vae':
            self.anneal_schedule = dict(
                zip(hps.bn_anneal_weight_steps, hps.bn_anneal_weight_vals))

        self.ckpt_path = util.CheckpointPath(hps.ckpt_template,
                                             self.is_master)

        self.softmax = t.nn.Softmax(1)  # input to this is (B, Q, N)
        self.hw = hps.hw

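        # Route batches through a device-aware iterator; each TPU replica
        # reads from its own per-device loader shard.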
        if hps.hw == 'GPU':
            self.device_loader = GPULoaderIter(self.state.data.loader, device)
        else:
            para_loader = pl.ParallelLoader(self.state.data.loader, [device])
            self.device_loader = para_loader.per_device_loader(device)
            self.num_devices = xm.xrt_world_size()
        self.state.to(device)

        self.state.init_torch_generator()

        if self.is_master:
            self.writer = SummaryWriter(log_dir=hps.log_dir)
        else:
            self.writer = None
Example #2
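A trainer constructor with 'new' and 'resume' modes: it builds an AutoEncoder or MfccInverter over a data.Slice dataset (or restores both from a checkpoint), then wires device-specific data iterators and optimizer step functions for GPU or TPU.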
    def __init__(self, mode, opts):
        print('Initializing model and data source...', end='', file=stderr)
        stderr.flush()
        self.learning_rates = dict(
            zip(opts.learning_rate_steps, opts.learning_rate_rates))
        self.opts = opts

        if mode == 'new':
            torch.manual_seed(opts.random_seed)

            # Initialize data
            dataset = data.Slice(opts)
            dataset.load_data(opts.dat_file)
            opts.training = True
            if opts.global_model == 'autoencoder':
                model = ae.AutoEncoder(opts, dataset)
            elif opts.global_model == 'mfcc_inverter':
                model = mi.MfccInverter(opts, dataset)
            else:
                raise ValueError(
                        f'Unknown global_model "{opts.global_model}"')

            model.post_init(dataset)
            dataset.post_init(model)
            optim = torch.optim.Adam(params=model.parameters(),
                                     lr=self.learning_rates[0])
            self.state = checkpoint.State(0, model, dataset, optim)
            self.start_step = self.state.step

        else:
            self.state = checkpoint.State()
            self.state.load(opts.ckpt_file, opts.dat_file)
            self.start_step = self.state.step

        if self.state.model.bn_type == 'vae':
            self.anneal_schedule = dict(
                zip(opts.bn_anneal_weight_steps, opts.bn_anneal_weight_vals))

        self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
        self.quant = None
        self.target = None
        self.softmax = torch.nn.Softmax(1)  # input to this is (B, Q, N)

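        # GPU: move batches to CUDA directly and step the optimizer in-process;
        # TPU: lazily import torch_xla and step through xm.optimizer_step.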
        if self.opts.hwtype == 'GPU':
            self.device = torch.device('cuda')
            self.data_loader = self.state.data_loader
            self.data_loader.set_target_device(self.device)
            self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn))
            self.data_iter = GPULoaderIter(iter(self.data_loader))
        else:
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            self.device = xm.xla_device()
            self.data_loader = pl.ParallelLoader(self.state.data_loader,
                                                 [self.device])
            self.data_iter = TPULoaderIter(self.data_loader, self.device)
            self.optim_step_fn = (lambda: xm.optimizer_step(
                self.state.optim, optimizer_args={'closure': self.loss_fn}))

        self.state.init_torch_generator()
        print('Done.', file=stderr)
        stderr.flush()
Example #3
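A complete main() training driver: it parses 'new'/'resume' command-line modes, constructs or restores the model state, and runs the training loop with learning-rate scheduling, periodic VQ-VAE embedding reinitialization, progress reporting, and checkpointing.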
def main():
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    print('Command line: ', ' '.join(sys.argv), file=stderr)
    stderr.flush()

    mode = sys.argv[1]
    del sys.argv[1]
    if mode == 'new':
        cold_parser = parse_tools.cold_parser()
        opts = parse_tools.two_stage_parse(cold_parser)
    elif mode == 'resume':
        resume_parser = parse_tools.resume_parser()
        opts = resume_parser.parse_args()

    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
        print('Using GPU', file=stderr)
    else:
        opts.device = torch.device('cpu')
        print('Using CPU', file=stderr)
    stderr.flush()

    ckpt_path = util.CheckpointPath(opts.ckpt_template)
    learning_rates = dict(
        zip(opts.learning_rate_steps, opts.learning_rate_rates))

    # Construct model
    if mode == 'new':
        # Initialize model
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        data_source = data.Slice(opts.index_file_prefix,
                                 opts.max_gpu_data_bytes, opts.n_batch)

        dec_params['n_speakers'] = data_source.num_speakers()

        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params,
                               opts.n_sam_per_slice)
        optim = torch.optim.Adam(params=model.parameters(),
                                 lr=learning_rates[0])
        state = checkpoint.State(0, model, data_source, optim)

    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        state.model.set_slice_size(opts.n_sam_per_slice)
        print('Restored model, data, and optim from {}'.format(opts.ckpt_file),
              file=stderr)
        #print('Data state: {}'.format(state.data), file=stderr)
        #print('Model state: {}'.format(state.model.checksum()))
        #print('Optim state: {}'.format(state.optim_checksum()))
        stderr.flush()

    start_step = state.step

    print('Model input size: {}'.format(state.model.input_size), file=stderr)
    stderr.flush()

    # Set this to zero to print a logging header in resume mode as well.
    netmisc.set_print_iter(0)

    state.data.init_geometry(state.model.preprocess.rf, state.model)

    state.to(device=opts.device)

    # Initialize metrics and the batch generator
    metrics = ae.Metrics(state)
    batch_gen = state.data.batch_slice_gen_fn()

    # Start training
    print('Training parameters used:', file=stderr)
    pprint(opts, stderr)

    state.init_torch_generator()

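    # Training loop: apply scheduled learning rates, periodically
    # reinitialize VQ embeddings, step the optimizer, report, checkpoint.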
    while state.step < opts.max_steps:
        if state.step in learning_rates:
            state.update_learning_rate(learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'

        if state.step in (1, 10, 50, 100, 300,
                          500) and state.model.bn_type == 'vqvae':
            print('Reinitializing embed with current distribution',
                  file=stderr)
            stderr.flush()
            state.model.init_vq_embed(batch_gen)

        metrics.update(batch_gen)
        loss = metrics.state.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # (Optional) dump per-parameter gradient statistics for the encoder
        # here when debugging training instabilities.

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            current_stats = {
                'step': state.step,
                'loss': loss,
                'tprb_m': avg_prob_target,
                'pk_d_m': avg_peak_dist
            }
            if state.model.bn_type == 'vqvae':
                current_stats.update(state.model.objective.metrics)

            netmisc.print_metrics(current_stats, 1000000)
            stderr.flush()

        state.step += 1

        # Checkpointing
        if ((state.step % opts.save_interval == 0 and state.step != start_step)
                or (mode == 'new' and state.step == 1)):
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)
            stderr.flush()
Example #4
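A variant of the constructor in Example #2 that assembles the AutoEncoder from prefixed option groups (pre_, enc_, bn_, dec_) instead of passing opts directly.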
    def __init__(self, mode, opts):
        print('Initializing model and data source...', end='', file=stderr)
        stderr.flush()
        self.learning_rates = dict(
            zip(opts.learning_rate_steps, opts.learning_rate_rates))
        self.opts = opts

        if mode == 'new':
            torch.manual_seed(opts.random_seed)
            pre_par = parse_tools.get_prefixed_items(vars(opts), 'pre_')
            enc_par = parse_tools.get_prefixed_items(vars(opts), 'enc_')
            bn_par = parse_tools.get_prefixed_items(vars(opts), 'bn_')
            dec_par = parse_tools.get_prefixed_items(vars(opts), 'dec_')

            # Initialize data
            jprob = dec_par.pop('jitter_prob')
            dataset = data.Slice(opts.n_batch, opts.n_win_batch, jprob,
                    pre_par['sample_rate'], pre_par['mfcc_win_sz'],
                    pre_par['mfcc_hop_sz'], pre_par['n_mels'],
                    pre_par['n_mfcc'])
            dataset.load_data(opts.dat_file)
            dec_par['n_speakers'] = dataset.num_speakers()
            model = ae.AutoEncoder(pre_par, enc_par, bn_par, dec_par,
                    dataset.num_mel_chan(), training=True)
            model.post_init(dataset)
            dataset.post_init(model)
            optim = torch.optim.Adam(params=model.parameters(),
                                     lr=self.learning_rates[0])
            self.state = checkpoint.State(0, model, dataset, optim)
            self.start_step = self.state.step

        else:
            self.state = checkpoint.State()
            self.state.load(opts.ckpt_file, opts.dat_file)
            self.start_step = self.state.step

        self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
        self.quant = None
        self.target = None
        self.softmax = torch.nn.Softmax(1) # input to this is (B, Q, N)

        if self.opts.hwtype == 'GPU':
            self.device = torch.device('cuda')
            self.data_loader = self.state.data_loader
            self.data_loader.set_target_device(self.device)
            self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn))
            self.data_iter = GPULoaderIter(iter(self.data_loader))
        else:
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            self.device = xm.xla_device()
            self.data_loader = pl.ParallelLoader(self.state.data_loader,
                                                 [self.device])
            self.data_iter = TPULoaderIter(self.data_loader, self.device)
            self.optim_step_fn = (lambda: xm.optimizer_step(
                self.state.optim, optimizer_args={'closure': self.loss_fn}))

        self.state.init_torch_generator()
        print('Done.', file=stderr)
        stderr.flush()
Example #5
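An earlier main() training driver built around WavSlices data; it recreates the Adam optimizer at scheduled learning-rate steps and prints tab-separated loss statistics at a fixed progress interval.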
def main():
    if len(sys.argv) == 1 or sys.argv[1] not in ('new', 'resume'):
        print(parse_tools.top_usage, file=stderr)
        return

    mode = sys.argv[1]
    del sys.argv[1]
    if mode == 'new':
        opts = parse_tools.two_stage_parse(parse_tools.cold)
    elif mode == 'resume':
        opts = parse_tools.resume.parse_args()  

    if not opts.disable_cuda and torch.cuda.is_available():
        opts.device = torch.device('cuda')
    else:
        opts.device = torch.device('cpu')

    ckpt_path = util.CheckpointPath(opts.ckpt_template)

    # Construct model
    if mode == 'new':
        # Initialize model
        pre_params = parse_tools.get_prefixed_items(vars(opts), 'pre_')
        enc_params = parse_tools.get_prefixed_items(vars(opts), 'enc_')
        bn_params = parse_tools.get_prefixed_items(vars(opts), 'bn_')
        dec_params = parse_tools.get_prefixed_items(vars(opts), 'dec_')

        # Initialize data
        sample_catalog = D.parse_sample_catalog(opts.sam_file)
        data = D.WavSlices(sample_catalog, pre_params['sample_rate'],
                opts.frac_permutation_use, opts.requested_wav_buf_sz)
        dec_params['n_speakers'] = data.num_speakers()

        model = ae.AutoEncoder(pre_params, enc_params, bn_params, dec_params)
        print('Initializing model parameters', file=stderr)
        model.initialize_weights()

        # Construct overall state
        state = checkpoint.State(0, model, data)

    else:
        state = checkpoint.State()
        state.load(opts.ckpt_file)
        print('Restored model and data from {}'.format(opts.ckpt_file), file=stderr)

    state.model.set_geometry(opts.n_sam_per_slice)

    state.data.set_geometry(opts.n_batch, state.model.input_size,
            state.model.output_size)

    state.model.to(device=opts.device)

    # Collect model parameters for the optimizer; build metrics and the
    # batch generator
    model_params = state.model.parameters()
    metrics = ae.Metrics(state.model, None)
    batch_gen = state.data.batch_slice_gen_fn()

    # Start training
    print('Starting training...', file=stderr)
    print("Step\tLoss\tAvgProbTarget\tPeakDist\tAvgMax", file=stderr)
    stderr.flush()

    learning_rates = dict(
        zip(opts.learning_rate_steps, opts.learning_rate_rates))
    start_step = state.step
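    # When resuming between scheduled rate changes, rebuild the optimizer
    # with the most recently scheduled learning rate.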
    if start_step not in learning_rates:
        ref_step = util.greatest_lower_bound(opts.learning_rate_steps, start_step)
        metrics.optim = torch.optim.Adam(params=model_params,
                lr=learning_rates[ref_step])

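    # Training loop: recreate the optimizer at scheduled rate changes,
    # then step, report progress, and checkpoint.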
    while state.step < opts.max_steps:
        if state.step in learning_rates:
            metrics.optim = torch.optim.Adam(params=model_params,
                    lr=learning_rates[state.step])
        # do 'pip install --upgrade scipy' if you get 'FutureWarning: ...'
        metrics.update(batch_gen)
        loss = metrics.optim.step(metrics.loss)
        avg_peak_dist = metrics.peak_dist()
        avg_max = metrics.avg_max()
        avg_prob_target = metrics.avg_prob_target()

        # Progress reporting
        if state.step % opts.progress_interval == 0:
            fmt = "{}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}"
            print(fmt.format(state.step, loss, avg_prob_target, avg_peak_dist,
                avg_max), file=stderr)
            stderr.flush()

        # Checkpointing
        if state.step % opts.save_interval == 0 and state.step != start_step:
            ckpt_file = ckpt_path.path(state.step)
            state.save(ckpt_file)
            print('Saved checkpoint to {}'.format(ckpt_file), file=stderr)

        state.step += 1
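
The loader wrappers GPULoaderIter and TPULoaderIter used in Examples #1, #2, and #4 are defined elsewhere in their projects (note the call sites even differ slightly in signature). Below is a rough, minimal sketch of what such wrappers might look like; the constructor signatures and batch structure here are illustrative assumptions, not the original implementations.

import torch

class GPULoaderIter:
    """Iterates a data loader, moving each batch onto a CUDA device."""
    def __init__(self, loader_iter, device=torch.device('cuda')):
        self.loader_iter = loader_iter
        self.device = device

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.loader_iter)
        # Move a single tensor, or each tensor of a tuple/list batch.
        if isinstance(batch, (list, tuple)):
            return type(batch)(x.to(self.device) for x in batch)
        return batch.to(self.device)


class TPULoaderIter:
    """Iterates a torch_xla ParallelLoader on one XLA device."""
    def __init__(self, parallel_loader, device):
        # per_device_loader yields batches already placed on `device`.
        self.per_device = iter(parallel_loader.per_device_loader(device))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.per_device)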