def test_assert_different_length_batch_generation():
    """Batched fast generation with per-item lengths must match one-by-one generation.

    Each of the ``n_batch`` utterances gets its own target length (drawn in
    ``[max_len // 2, max_len - 1)`` and sorted ascending); the outputs of
    ``batch_fast_generate`` are compared element-wise against ``fast_generate``
    run separately on each utterance, using deterministic argmax decoding.
    """
    n_batch = 4
    max_len = 32
    x = np.random.randint(0, 256, size=(n_batch, 1))
    h = np.random.randn(n_batch, 28, max_len)
    length_list = sorted(np.random.randint(max_len // 2, max_len - 1, n_batch))

    with torch.no_grad():
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # reference outputs: every utterance generated on its own
        single_outputs = []
        for x_item, h_item, n_samples in zip(x, h, length_list):
            one_x = torch.from_numpy(np.expand_dims(x_item, 0)).long()
            one_h = torch.from_numpy(np.expand_dims(h_item, 0)).float()
            single_outputs.append(
                net.fast_generate(one_x, one_h, n_samples, 1, "argmax"))

        # candidate outputs: all utterances in a single batched call
        batched_outputs = net.batch_fast_generate(
            torch.from_numpy(x).long(),
            torch.from_numpy(h).float(),
            length_list, 1, "argmax")

        for expected, actual in zip(single_outputs, batched_outputs):
            np.testing.assert_array_equal(expected, actual)
def gpu_decode(feat_list, gpu):
    """Synthesize waveforms for ``feat_list`` on GPU ``gpu`` and write them to disk.

    Builds a WaveNet from the enclosing ``config``, restores the weights stored
    under ``"model"`` in ``args.checkpoint`` and writes one 16-bit PCM file
    ``<args.outdir>/<feat_id>.wav`` per feature file. Batched generation is used
    when ``args.batch_size > 1``, otherwise utterances are generated one by one.

    Side effects: makes ``gpu`` the current CUDA device and disables autograd
    globally for the calling process.
    """
    torch.cuda.set_device(gpu)
    torch.set_grad_enabled(False)

    # The network itself only upsamples when it owns an upsampling layer.
    model_upsampling_factor = (
        config.upsampling_factor if config.use_upsampling_layer else 0)
    model = WaveNet(
        n_quantize=config.n_quantize,
        n_aux=config.n_aux,
        n_resch=config.n_resch,
        n_skipch=config.n_skipch,
        dilation_depth=config.dilation_depth,
        dilation_repeat=config.dilation_repeat,
        kernel_size=config.kernel_size,
        upsampling_factor=model_upsampling_factor)
    # identity map_location: tensors are loaded on CPU, the model moves to GPU below
    checkpoint = torch.load(
        args.checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    model.cuda()

    generator = decode_generator(
        feat_list,
        batch_size=args.batch_size,
        feature_type=config.feature_type,
        wav_transform=wav_transform,
        feat_transform=feat_transform,
        upsampling_factor=config.upsampling_factor,
        use_upsampling_layer=config.use_upsampling_layer,
        use_speaker_code=config.use_speaker_code)

    def save_wav(feat_id, samples):
        # de-quantize the generated samples and store them as 16-bit PCM
        wav = decode_mu_law(samples, config.n_quantize)
        sf.write(args.outdir + "/" + feat_id + ".wav", wav, args.fs, "PCM_16")
        logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))

    if args.batch_size <= 1:
        for feat_id, (x, h, n_samples) in generator:
            logging.info("decoding %s (length = %d)" % (feat_id, n_samples))
            save_wav(feat_id, model.fast_generate(x, h, n_samples, args.intervals))
        return

    for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
        logging.info("decoding start")
        samples_list = model.batch_fast_generate(
            batch_x, batch_h, n_samples_list, args.intervals)
        for feat_id, samples in zip(feat_ids, samples_list):
            save_wav(feat_id, samples)
def gpu_decode(feat_list, gpu):
    """Synthesize waveforms for ``feat_list`` on GPU ``gpu`` and write them to disk.

    All work runs inside ``torch.cuda.device(gpu)``. The WaveNet is built from the
    enclosing ``config``, its weights (key ``"model"`` of ``args.checkpoint``)
    are loaded straight onto ``gpu``, and one 16-bit PCM file
    ``<args.outdir>/<feat_id>.wav`` is written per feature file. Batched
    generation is used when ``args.batch_size > 1``.

    NOTE(review): autograd is not disabled anywhere in this function; confirm
    that the generate methods do not build a graph during decoding.
    """
    with torch.cuda.device(gpu):
        model = WaveNet(
            n_quantize=config.n_quantize,
            n_aux=config.n_aux,
            n_resch=config.n_resch,
            n_skipch=config.n_skipch,
            dilation_depth=config.dilation_depth,
            dilation_repeat=config.dilation_repeat,
            kernel_size=config.kernel_size,
            upsampling_factor=config.upsampling_factor)
        # deserialize every tensor directly onto the target GPU
        checkpoint = torch.load(
            args.checkpoint,
            map_location=lambda storage, loc: storage.cuda(gpu))
        model.load_state_dict(checkpoint["model"])
        model.eval()
        model.cuda()
        torch.backends.cudnn.benchmark = True

        generator = decode_generator(
            feat_list,
            batch_size=args.batch_size,
            wav_transform=wav_transform,
            feat_transform=feat_transform,
            use_speaker_code=config.use_speaker_code,
            upsampling_factor=config.upsampling_factor)

        def save_wav(feat_id, samples):
            # de-quantize the generated samples and store them as 16-bit PCM
            wav = decode_mu_law(samples, config.n_quantize)
            sf.write(args.outdir + "/" + feat_id + ".wav", wav, args.fs, "PCM_16")
            logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))

        if args.batch_size <= 1:
            for feat_id, (x, h, n_samples) in generator:
                logging.info("decoding %s (length = %d)" % (feat_id, n_samples))
                save_wav(feat_id, model.fast_generate(x, h, n_samples, args.intervals))
            return

        for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
            logging.info("decoding start")
            samples_list = model.batch_fast_generate(
                batch_x, batch_h, n_samples_list, args.intervals)
            for feat_id, samples in zip(feat_ids, samples_list):
                save_wav(feat_id, samples)
def test_generate():
    """Smoke-test generate, fast_generate and batch_fast_generate with random sampling."""
    n_batch = 2
    x = np.random.randint(0, 256, size=(n_batch, 1))
    h = np.random.randn(n_batch, 28, 32)
    n_samples = h.shape[-1] - 1

    with torch.no_grad():
        net = WaveNet(256, 28, 16, 32, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # one utterance at a time, through both single-utterance code paths
        for x_item, h_item in zip(x, h):
            single_x = torch.from_numpy(np.expand_dims(x_item, 0)).long()
            single_h = torch.from_numpy(np.expand_dims(h_item, 0)).float()
            for generate_fn in (net.generate, net.fast_generate):
                generate_fn(single_x, single_h, n_samples, 1, "sampling")

        # all utterances in one batched call
        net.batch_fast_generate(
            torch.from_numpy(x).long(),
            torch.from_numpy(h).float(),
            [n_samples] * n_batch, 1, "sampling")
def test_forward():
    """Check the WaveNet forward output shape with and without upsampling.

    Four configurations are exercised: kernel size 2 and 3, each without an
    upsampling layer (auxiliary features at sample rate) and with upsampling
    factor 10 (auxiliary features at 1/10 of the sample rate). For every model
    the first element of the forward output must have one row per input sample
    and one column per quantization level (256).
    """
    # get batch
    generator = sine_generator(100)
    batch = next(generator)
    batch_input = batch.view(1, -1)
    n_samples = batch_input.size(1)

    # auxiliary features at full rate and at 1/10 rate
    aux_full_rate = torch.rand(1, 28, n_samples).float()
    aux_low_rate = torch.rand(1, 28, n_samples // 10).float()

    # (kernel_size, upsampling_factor or None, auxiliary features)
    configs = [
        (2, None, aux_full_rate),
        (3, None, aux_full_rate),  # previously built with kernel size 2 by mistake
        (2, 10, aux_low_rate),
        (3, 10, aux_low_rate),
    ]
    for kernel_size, upsampling_factor, batch_aux in configs:
        if upsampling_factor is None:
            net = WaveNet(256, 28, 32, 128, 10, 1, kernel_size)
        else:
            net = WaveNet(256, 28, 32, 128, 10, 1, kernel_size, upsampling_factor)
        net.apply(initialize)
        net.eval()
        y = net(batch_input, batch_aux)[0]
        assert y.size(0) == n_samples
        assert y.size(1) == 256
def main(args):
    """Encode audio files once and re-synthesize them with every selected decoder.

    The encoder is loaded from the first checkpoint matching
    ``<args.checkpoint.name>_*.pth`` whose id is in ``args.decoders``. One
    decoder is built per matching checkpoint and wrapped in either
    ``WavenetGenerator`` (``args.py``) or ``NVWavenetGenerator``. Inputs (.wav
    or .h5) are mu-law encoded, encoded in batches of ``args.batch_size``, and
    decoded in chunks of ``args.split_size`` along the last axis. Results are
    written next to the originals (``args.output_next_to_orig``) or under
    ``args.output``.

    Note: when ``args.sample_len`` is unset it is set to the length of the first
    loaded file, so every later file is truncated to that length. This keeps
    ``torch.stack`` on the inputs valid.

    Raises:
        ValueError: if an input file is neither .wav nor .h5.
    """
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if args.py:
            decoder = WavenetGenerator(decoder, args.batch_size, wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder, args.rate * (args.split_size // 20), args.batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    # exactly one output destination must be selected
    assert args.output_next_to_orig ^ (args.output is not None)

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        # presumably skips files previously written by save() as <stem>_<decoder>.wav
        file_paths = [f for f in file_paths if '_' not in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            # read the samples and release the HDF5 handle immediately
            with h5py.File(file_path, 'r') as h5_file:
                raw = h5_file['wav'][:]
            data = utils.mu_law(raw / (2**15))
            # trim to a whole multiple of args.rate samples
            remainder = data.shape[-1] % args.rate
            if remainder != 0:
                data = data[:-remainder]
            assert data.shape[-1] % args.rate == 0
        else:
            raise ValueError(f'Unsupported filetype {file_path}')

        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        # undo the mu-law encoding and write the waveform to its destination
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            save_audio(wav.squeeze(), filepath.parent / f'{filepath.stem}_{decoder_ix}.wav', rate=args.rate)
        else:
            save_audio(wav.squeeze(),
                       args.output / str(extract_id(args.model)) / str(args.update) / filepath.with_suffix('.wav').name,
                       rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, args.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                # drops only the local alias; `decoders` still references it
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
def test_assert_fast_generation():
    """generate, fast_generate and batch_fast_generate must agree under argmax decoding.

    Four model variants are checked: kernel size 2 and 3, each without an
    upsampling layer and with upsampling factor 10. (Previously the
    "upsampling + kernel size 3" case was built with kernel size 2.)
    """

    def _build(*wavenet_args):
        # construct, initialize and switch a WaveNet to eval mode
        net = WaveNet(*wavenet_args)
        net.apply(initialize)
        net.eval()
        return net

    def _check_consistency(net, x, h, length):
        # Assert the three generation paths of `net` yield identical samples.
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation must reproduce the per-utterance results
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * len(x), 1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

    batch = 2
    with torch.no_grad():
        # without upsampling: auxiliary features already at sample rate
        x = np.random.randint(0, 256, size=(batch, 1))
        h = np.random.randn(batch, 28, 32)
        length = h.shape[-1] - 1
        for kernel_size in (2, 3):
            _check_consistency(_build(256, 28, 4, 4, 10, 3, kernel_size), x, h, length)

        # with upsampling: auxiliary features at 1/upsampling_factor of sample rate
        upsampling_factor = 10
        x = np.random.randint(0, 256, size=(batch, 1))
        h = np.random.randn(batch, 28, 3)
        length = h.shape[-1] * upsampling_factor - 1
        for kernel_size in (2, 3):
            _check_consistency(
                _build(256, 28, 4, 4, 10, 3, kernel_size, upsampling_factor),
                x, h, length)
class Trainer:
    """Fine-tune WaveNet decoders on top of a frozen, pre-trained encoder.

    Encoder weights always come from ``args.checkpoint``. The encoder is never
    optimized: only decoder parameters are given to the optimizer, the encoder
    is kept in eval mode and its codes are computed without gradients. With
    ``args.continue_training`` the decoder/optimizer states and the starting
    epoch are restored as well.
    """

    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            # args.pth (written by train()) holds [args, last_finished_epoch]
            checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)
            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # only the decoder is optimized; the encoder stays frozen
        self.model_optimizer = optim.Adam(self.decoder.parameters(), lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(states['model_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        # set the scheduler's epoch counter to start_epoch, then step once
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        """Return the reconstruction loss on one validation batch and record it."""
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Take one decoder optimization step and return the batch loss."""
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well. The encoder is frozen, so its codes
        # are computed without building an autograd graph.
        with torch.no_grad():
            z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Run ``args.epoch_len`` training batches, cycling over the datasets."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Run ceil(epoch_len / 10) validation batches without gradients."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Format each meter's epoch summary with four decimals, comma-separated."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def train(self):
        """Train for ``args.epochs`` epochs, saving best and last checkpoints."""
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save encoder/decoder/optimizer states to ``expPath / filename``."""
        save_path = self.expPath / filename

        # The encoder is never updated (not in the optimizer, kept in eval
        # mode), so its current weights equal those loaded from
        # args.checkpoint; no need to re-read that file on every save.
        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
class Trainer:
    """Adversarially train a shared encoder, per-domain WaveNet decoders and a domain discriminator.

    Each batch first updates the discriminator to predict the domain
    (``dset_num``) from the latent code. It then updates encoder + decoder to
    reconstruct the audio while fooling the discriminator (negated
    cross-entropy weighted by ``args.d_lambda``).
    """

    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']
        ), "Number of datasets must match number of nodes"

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        # discriminator accuracy during training (mirrors eval_d_right)
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            # args.pth (written by train()) holds [args, last_finished_epoch]
            checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
        # the decoder is wrapped in DataParallel in both modes
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(
            chain(self.encoder.parameters(), self.decoder.parameters()),
            lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        # set the scheduler's epoch counter to start_epoch, then step once
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    @staticmethod
    def _domain_targets(dset_num, batch_size):
        """Return a CUDA LongTensor of ``batch_size`` copies of ``dset_num``."""
        return torch.tensor([dset_num] * batch_size).long().cuda()

    def eval_batch(self, x, x_aug, dset_num):
        """Compute and record the validation loss (recon + d_lambda * D loss) of one batch."""
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)

        # fraction of items whose domain the discriminator predicts correctly
        z_classification = torch.max(z_logits, dim=1)[1]
        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        discriminator_right = F.cross_entropy(
            z_logits, self._domain_targets(dset_num, x.size(0))).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
            recon_loss.mean().data.item()

        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Take one discriminator step then one encoder/decoder step; return the G loss."""
        x, x_aug = x.float(), x_aug.float()
        targets = self._domain_targets(dset_num, x.size(0))

        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
        discriminator_right = F.cross_entropy(z_logits, targets).mean()

        # Record D accuracy, the same statistic eval_batch puts in eval_d_right.
        # Without this, loss_d_right is reset and reported every epoch but
        # never receives a value.
        z_classification = torch.max(z_logits, dim=1)[1]
        z_accuracy = (z_classification == dset_num).float().mean()
        self.loss_d_right.add(z_accuracy.data.item())

        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        z = self.encoder(x_aug)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)
        discriminator_wrong = -F.cross_entropy(z_logits, targets).mean()

        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)

        # zero_grad also clears encoder gradients left over from the D step
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Run ``args.epoch_len`` training batches (one fixed dataset per rank when distributed)."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Run ceil(epoch_len / 10) validation batches without gradients."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Format each meter's epoch summary with four decimals, comma-separated."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        """Train for ``args.epochs`` epochs, saving best and last checkpoints."""
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save all model and optimizer states to ``expPath / filename``."""
        save_path = self.expPath / filename

        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'discriminator_state': self.discriminator.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
                'd_optimizer_state': self.d_optimizer.state_dict()
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
def _is_prunable_weight(name):
    """Return True if parameter *name* is a weight outside the BN layers.

    This is the filter ``main`` applies before initialising, analysing or
    visualising weights. The second-to-last dotted component must not be
    ``bn``/``bn2``/``bn3``, and the last component must not be ``bias``.
    *name* must contain at least one dot. Names produced by an
    ``nn.DataParallel`` model start with ``module.``, so they always do.
    """
    parts = name.split(".")
    return parts[-2] not in ("bn", "bn2", "bn3") and parts[-1] != "bias"


def main():
    """Configure ``cfg`` from the command line, build data and model, then train.

    Several flags select one-shot modes that terminate the process with
    ``sys.exit()`` instead of training:
    ``--find_pattern``, ``--vis_mask``, ``--vis_pattern``, ``--test_acc``
    and ``--test_acc_cmodel``.
    The model is also exported to ``wavenet.onnx`` on every run.
    """
    import sys

    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    if args.find_pattern == True:
        # Both options are underscore-separated lists of numbers.
        shape_fields = args.find_pattern_shape.split('_')
        para_fields = args.find_pattern_para.split('_')
        cfg.find_pattern_num = 16
        cfg.find_pattern_shape = [int(shape_fields[0]), int(shape_fields[1])]
        cfg.find_zero_threshold = float(para_fields[0])
        cfg.find_score_threshold = int(para_fields[1])
        # Skip configurations whose score threshold is not below the pattern area.
        if cfg.find_pattern_shape[0] * cfg.find_pattern_shape[1] <= cfg.find_score_threshold:
            sys.exit()

    if args.skip_exist == True and os.path.exists(cfg.workdir):
        sys.exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n'
          f'    pretrained: {cfg.load_from}, \n'
          f'    batch_size: {cfg.batch_size}, \n'
          f'    lr : {cfg.lr}, \n'
          f'    epochs : {cfg.epochs}, \n'
          f'    sparse : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train / validation data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              num_workers=4, shuffle=True, pin_memory=True)
    vctk_val = VCTK(cfg, 'val')
    # The C-model accuracy test evaluates with batch size 1.
    val_batch_size = 1 if args.test_acc_cmodel == True else cfg.batch_size
    val_loader = DataLoader(vctk_val, batch_size=val_batch_size,
                            num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    # Xavier-initialise every non-BN, non-bias weight. The init works in
    # place on the parameters, so no state_dict round-trip is needed.
    for name, para in model.named_parameters():
        if _is_prunable_weight(name):
            nn.init.xavier_normal_(para, gain=1.0)

    os.makedirs(os.path.join(cfg.workdir, 'weights'), exist_ok=True)
    os.makedirs(cfg.vis_dir, exist_ok=True)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        os.makedirs(cfg.vis_dir, exist_ok=True)
    model.train()

    best_path = cfg.workdir + '/weights/best.pth'
    if cfg.resume and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path), strict=True)
        print("loading", best_path)
        cfg.load_from = best_path

    # Load pretrained weights when available; accuracy testing requires them.
    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=True)
        print("loading", cfg.load_from)
    elif args.test_acc == True:
        print("Error: model file not exists, ", cfg.load_from)
        sys.exit()

    # Export the model to ONNX (batch axis dynamic).
    print("exporting onnx ...")
    model.eval()
    dummy_input = torch.randn(1, 40, 720, requires_grad=True).cuda()
    torch.onnx.export(model.module,                 # model being run
                      dummy_input,                  # example model input
                      "wavenet.onnx",               # output file
                      export_params=True,           # store trained weights in the file
                      opset_version=10,             # ONNX opset to target
                      do_constant_folding=True,     # fold constants during export
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},   # variable-length axes
                                    'output': {0: 'batch_size'}})

    if os.path.exists(args.load_from_h5):
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # Convert the loaded (numpy) parameter values to tensors first.
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {k: torch.Tensor(v) for k, v in pretrained_dict.items()}
        # Overwrite the matching entries, then load the merged state.
        model_dict.update(new_pre_dict)
        model.load_state_dict(model_dict)

    if args.find_pattern == True:
        for name, raw_w in model.named_parameters():
            if not _is_prunable_weight(name):
                continue
            # Only 128x128 weight matrices are analysed.
            if raw_w.dim() >= 2 and raw_w.size(0) == 128 and raw_w.size(1) == 128:
                patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                    = find_pattern_by_similarity(raw_w,
                                                 cfg.find_pattern_num,
                                                 cfg.find_pattern_shape,
                                                 cfg.find_zero_threshold,
                                                 cfg.find_score_threshold)
                pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict \
                    = pattern_curve_analyse(raw_w.shape,
                                            cfg.find_pattern_shape,
                                            patterns,
                                            pattern_match_num,
                                            pattern_coo_nnz,
                                            pattern_nnz,
                                            pattern_inner_nnz)
                write_pattern_curve_analyse(
                    os.path.join(cfg.work_root, args.save_pattern_count_excel),
                    cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para,
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                    pattern_inner_nnz,
                    pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict)
        sys.exit()

    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')
    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        fields = args.pattern_para.split('_')
        pattern_num = int(fields[0])
        pattern_shape = [int(fields[1]), int(fields[2])]
        pattern_nnz = int(fields[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
    elif cfg.sparse_mode == 'coo_pruning':
        fields = args.coo_para.split('_')
        cfg.coo_shape = [int(fields[0]), int(fields[1])]
        cfg.coo_nnz = int(fields[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
    elif cfg.sparse_mode == 'ptcoo_pruning':
        # Every field of this mode comes from --ptcoo_para.
        fields = args.ptcoo_para.split('_')
        cfg.pattern_num = int(fields[0])
        cfg.pattern_shape = [int(fields[1]), int(fields[2])]
        cfg.pt_nnz = int(fields[3])
        cfg.coo_nnz = int(fields[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')
    elif cfg.sparse_mode == 'find_retrain':
        fields = args.find_retrain_para.split('_')
        cfg.pattern_num = int(fields[0])
        cfg.pattern_shape = [int(fields[1]), int(fields[2])]
        cfg.pattern_nnz = int(fields[3])
        cfg.coo_num = float(fields[4])
        cfg.layer_or_model_wise = str(fields[5])
        print(f'find_retrain {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pattern_nnz} {cfg.coo_num} {cfg.layer_or_model_wise}')
    elif cfg.sparse_mode == 'hcgs_pruning':
        print(args.hcgs_para)
        fields = args.hcgs_para.split('_')
        cfg.block_shape = [int(fields[0]), int(fields[1])]
        cfg.reserve_num1 = int(fields[2])
        cfg.reserve_num2 = int(fields[3])
        print(f'hcgs_pruning {cfg.reserve_num1}/8 {cfg.reserve_num2}/16')
        cfg.hcgs_mask = generate_hcgs_mask(model, cfg.block_shape, cfg.reserve_num1, cfg.reserve_num2)

    if args.vis_mask == True:
        # Save a 0/1 mask (1 where the weight is non-zero) for every weight.
        for name, raw_w in model.named_parameters():
            if _is_prunable_weight(name):
                mask = torch.where(raw_w == 0, torch.zeros_like(raw_w), torch.ones_like(raw_w))
                vis.save_visualized_mask(mask, name)
        sys.exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        sys.exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        sys.exit()

    if args.test_acc_cmodel == True:
        f1, val_loss, tps, preds, poses = test_acc_cmodel(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        sys.exit()

    # train
    # `scheduler` must exist before the train() call. This is the Adam
    # optimizer from the originally commented-out definition; presumably
    # train() steps it -- confirm against train().
    scheduler = torch.optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
class AutoencoderTrainer:
    """Train the autoencoder for the first step of training.

    Builds an ``Encoder``, a ``WaveNet`` decoder, a ``ZDiscriminator`` and
    one ``Dataset`` per entry of ``args.data``. Each batch alternates two steps:

    - a discriminator step, which predicts the dataset index from the
      latent code of ``x``;
    - an autoencoder step, which reconstructs ``x`` from the code of
      ``x_aug`` while pushing the discriminator loss up.

    Checkpoints are written under
    ``args.checkpoint / 'Autoencoder' / args.exp_name``.
    """

    def __init__(self, args):
        self.args = args
        self.data = [Dataset(args, domain_path) for domain_path in args.data]
        self.expPath = args.checkpoint / 'Autoencoder' / args.exp_name
        self.expPath.mkdir(parents=True, exist_ok=True)
        self.logger = train_logger(self.args, self.expPath)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # seed
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        # modules
        self.encoder = Encoder(args).to(self.device)
        self.decoder = WaveNet(args).to(self.device)
        self.discriminator = ZDiscriminator(args).to(self.device)

        # distributed: wrap each module in DDP on the current CUDA device
        if args.world_size > 1:
            cuda_device = torch.cuda.current_device()
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder, device_ids=[cuda_device], output_device=cuda_device)
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator, device_ids=[cuda_device], output_device=cuda_device)
            self.decoder = torch.nn.parallel.DistributedDataParallel(
                self.decoder, device_ids=[cuda_device], output_device=cuda_device)

        # losses
        self.reconstruction_loss = [
            LossManager(f'train reconstruction {i}') for i in range(len(self.data))
        ]
        self.discriminator_loss = LossManager('train discriminator')
        self.total_loss = LossManager('train total')
        self.reconstruction_val = [
            LossManager(f'validation reconstruction {i}') for i in range(len(self.data))
        ]
        self.discriminator_val = LossManager('validation discriminator')
        self.total_val = LossManager('validation total')

        # optimizers
        self.autoenc_optimizer = optim.Adam(
            chain(self.encoder.parameters(), self.decoder.parameters()), lr=args.lr)
        self.discriminator_optimizer = optim.Adam(
            self.discriminator.parameters(), lr=args.lr)

        # resume training
        if args.resume:
            self._resume()
        else:
            self.start_epoch = 0

        # The autoencoder LR is multiplied by args.lr_decay on every
        # lr_manager.step(), and train() calls step() once per epoch.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.autoenc_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    @staticmethod
    def _unwrap(module):
        """Return the inner module of a (Distributed)DataParallel wrapper, else *module*."""
        if isinstance(module, (torch.nn.parallel.DistributedDataParallel,
                               torch.nn.DataParallel)):
            return module.module
        return module

    def _last_model_name(self, epoch):
        """File name that train() uses for the per-epoch "last model" checkpoint."""
        if self.args.save_model:
            return f'lastmodel_{epoch}_rank_{self.args.rank}.pth'
        return f'lastmodel_{self.args.rank}.pth'

    def _dataset_for_batch(self, batch_num):
        """Dataset index for *batch_num*: one fixed dataset per rank when distributed."""
        if self.args.world_size > 1:
            return self.args.rank
        return batch_num % self.args.n_datasets

    def _domain_targets(self, dataset_no, batch_size):
        """Discriminator targets: *dataset_no* repeated *batch_size* times, on self.device."""
        return torch.full((batch_size,), dataset_no, dtype=torch.long, device=self.device)

    def _resume(self):
        """Restore the start epoch, the weights and, optionally, the optimizer states.

        Reads ``args.pth`` (``[args, epoch]``, written by train()) and the
        matching last-model checkpoint written by save_model().
        """
        checkpoint_args = torch.load(self.expPath / 'args.pth')
        last_epoch = checkpoint_args[-1]
        self.start_epoch = last_epoch + 1
        checkpoint_state_file = self.expPath / self._last_model_name(last_epoch)
        states = torch.load(checkpoint_state_file, map_location=self.device)
        # save_model() stores unwrapped state dicts, so load into the unwrapped modules.
        self._unwrap(self.encoder).load_state_dict(states['encoder_state'])
        self._unwrap(self.decoder).load_state_dict(states['decoder_state'])
        self._unwrap(self.discriminator).load_state_dict(states['discriminator_state'])
        if self.args.load_optimizer:
            self.autoenc_optimizer.load_state_dict(states['autoenc_optimizer_state'])
            self.discriminator_optimizer.load_state_dict(states['d_optimizer_state'])
        self.logger.info('Loaded checkpoint parameters')

    def train_epoch(self, epoch):
        """Run one training epoch of ``epoch_length // batch_size`` batches."""
        # modules
        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        # losses
        for lm in self.reconstruction_loss:
            lm.reset()
        self.discriminator_loss.reset()
        self.total_loss.reset()

        total_batches = self.args.epoch_length // self.args.batch_size
        with tqdm(total=total_batches, desc=f'Train epoch {epoch}') as train_enum:
            for batch_num in range(total_batches):
                dataset_no = self._dataset_for_batch(batch_num)
                x, x_aug = next(self.data[dataset_no].train_iter)
                x = x.to(self.device).float()
                x_aug = x_aug.to(self.device).float()
                targets = self._domain_targets(dataset_no, x.size(0))

                # Train discriminator: predict which dataset the code of x came from.
                z = self.encoder(x)
                z_logits = self.discriminator(z)
                discriminator_loss = F.cross_entropy(z_logits, targets).mean()
                loss = discriminator_loss * self.args.d_weight
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if self.args.grad_clip is not None:
                    clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)
                self.discriminator_optimizer.step()
                self.discriminator_loss.add(discriminator_loss.item())

                # Train autoencoder: reconstruct x from the code of x_aug while
                # maximising the discriminator loss (note the negation).
                z = self.encoder(x_aug)
                y = self.decoder(x, z)
                z_logits = self.discriminator(z)
                discriminator_loss = -F.cross_entropy(z_logits, targets).mean()
                reconstruction_loss = cross_entropy_loss(y, x)
                self.reconstruction_loss[dataset_no].add(
                    reconstruction_loss.data.cpu().numpy().mean())
                loss = reconstruction_loss.mean() + self.args.d_weight * discriminator_loss
                self.autoenc_optimizer.zero_grad()
                loss.backward()
                if self.args.grad_clip is not None:
                    clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
                    clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
                self.autoenc_optimizer.step()
                self.total_loss.add(loss.item())

                train_enum.set_description(
                    f'Train (loss: {loss.item():.2f}) epoch {epoch}')
                train_enum.update()

    def validate_epoch(self, epoch):
        """Evaluate on ``epoch_length // batch_size // 10`` validation batches."""
        # modules
        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        # losses
        for lm in self.reconstruction_val:
            lm.reset()
        self.discriminator_val.reset()
        self.total_val.reset()

        total_batches = self.args.epoch_length // self.args.batch_size // 10
        with tqdm(total=total_batches) as valid_enum, torch.no_grad():
            for batch_num in range(total_batches):
                dataset_no = self._dataset_for_batch(batch_num)
                # Validation encodes the clean x; the augmented sample is unused.
                x, _ = next(self.data[dataset_no].valid_iter)
                x = x.to(self.device).float()
                targets = self._domain_targets(dataset_no, x.size(0))

                z = self.encoder(x)
                y = self.decoder(x, z)
                z_logits = self.discriminator(z)
                z_classification = torch.max(z_logits, dim=1)[1]
                z_accuracy = (z_classification == dataset_no).float().mean()
                self.discriminator_val.add(z_accuracy.item())

                discriminator_right = F.cross_entropy(z_logits, targets).mean()
                recon_loss = cross_entropy_loss(y, x)
                self.reconstruction_val[dataset_no].add(
                    recon_loss.data.cpu().numpy().mean())
                # Same weighting (d_weight) as the training objective.
                total_loss = (discriminator_right.item() * self.args.d_weight
                              + recon_loss.mean().item())
                self.total_val.add(total_loss)

                valid_enum.set_description(
                    f'Test (loss: {total_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    def train(self):
        """Train from ``start_epoch`` up to ``args.epochs``, checkpointing every epoch."""
        best_loss = float('inf')
        for epoch in range(self.start_epoch, self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.validate_epoch(epoch)

            train_losses = [self.reconstruction_loss, self.discriminator_loss]
            val_losses = [self.reconstruction_val, self.discriminator_val]
            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Validation loss (%s)',
                epoch, train_losses, val_losses)

            mean_loss = self.total_val.epoch_mean()
            if mean_loss < best_loss:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_loss = mean_loss

            self.save_model(self._last_model_name(epoch))
            # _resume() reads [args, epoch] from this file; write it once (rank 0).
            if self.args.rank == 0:
                torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.lr_manager.step()
            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save unwrapped module and optimizer state dicts to ``expPath / filename``."""
        save_to = self.expPath / filename
        torch.save(
            {
                'encoder_state': self._unwrap(self.encoder).state_dict(),
                'decoder_state': self._unwrap(self.decoder).state_dict(),
                'discriminator_state': self._unwrap(self.discriminator).state_dict(),
                'autoenc_optimizer_state': self.autoenc_optimizer.state_dict(),
                'd_optimizer_state': self.discriminator_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_to)
        self.logger.debug(f'Saved model to {save_to}')
# Training driver: fit on random windows taken from the first 70% of `data`,
# then inspect the model's outputs on the final training minibatch.
batch_size = 32
# Rows with index < i_max_training form the training split.
i_max_training = int(data.shape[0] * 0.7)
# Exclusive upper bound for random window starts inside the training split.
# NOTE(review): the extra "- 1" presumably leaves room for a one-step-ahead
# target -- confirm against `batcher`.
i_start_max_training = i_max_training - LEN_INPUT - 1
# i_start_max_training = i_max_training - LEN_INPUT - OFFSET_TARGET - LEN_TARGET
SEED = None  # None -> RandomState is unseeded, so runs are not reproducible
rng = np.random.RandomState(SEED)
N_STEP = 10000
for i_step in range(N_STEP):
    # batch_size random window starts in [0, i_start_max_training).
    inputs_start = rng.randint(i_start_max_training, size=batch_size)
    batch_xs, batch_ys = batcher(inputs_start)
    step(batch_xs, batch_ys, i_step)
# Evaluate on batch_xs / batch_ys left over from the last training step.
encoder.eval()
sign_l, mag_l = forward(Variable(batch_xs))
# Map sigmoid(sign logits) from (0, 1) to (-1, 1).
sign_h = F.sigmoid(sign_l).data.numpy() * 2 - 1
mag_h = mag_l.abs().data.numpy()
# Targets split the same way: sign as -1/+1 (zero maps to -1), magnitude as |y|.
sign_t = batch_ys.gt(0.).float().numpy() * 2 - 1
mag_t = batch_ys.abs().numpy()
# NOTE(review): `_next_xh` is never assigned here (the line below is commented
# out), so the following statement raises NameError unless `_next_xh` is
# defined elsewhere. Also, `forward` is unpacked into two values above, so its
# result has no `.data`/`.percent_resid`; confirm what `_next_xh` should be.
# _next_xh = forward(Variable(batch_xs))
next_xh = _next_xh.data.numpy()[0][0]
next_xt = batch_ys.numpy()[0][0]
p_resid_hat = var_to_numpy(_next_xh.percent_resid)[0][0]
# First difference of the target, left-padded with 0 to keep its length.
resid_true = np.concatenate([[0], np.diff(next_xt)])
class Finetuner:
    """Fine-tune a pretrained encoder / WaveNet decoder on reconstruction loss only.

    The encoder is frozen entirely. Only the last ``args.decoder_update``
    entries of ``decoder.layers`` stay trainable, and no discriminator is
    used. Checkpoints are written to ``checkpoints/<args.expName>``.
    """

    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(args.data)
        self.modelPath = Path('checkpoints') / args.expName
        self.logger = create_output_dir(args, self.modelPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.start_epoch = 0

        # Find the pretrained checkpoint among "<checkpoint.name>_*.pth" whose
        # extract_id(...) appears in args.decoder.
        candidates = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
        matches = [c for c in candidates if extract_id(c) in args.decoder]
        if not matches:
            raise FileNotFoundError(
                f'no checkpoint {args.checkpoint.name}_*.pth in {args.checkpoint.parent} '
                f'with an id in {args.decoder}')
        checkpoint = matches[0]
        model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]
        # Load the checkpoint once, on CPU, and reuse it for both modules.
        states = torch.load(checkpoint, map_location='cpu')

        # Encoder: restore and freeze completely.
        self.encoder = Encoder(model_args)
        self.encoder.load_state_dict(states['encoder_state'])
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Decoder: restore and freeze all but the last args.decoder_update layers.
        self.decoder = WaveNet(model_args)
        self.decoder.load_state_dict(states['decoder_state'])
        self._freeze_decoder_layers(self.decoder, args.decoder_update)

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.step()

    @staticmethod
    def _freeze_decoder_layers(decoder, n_trainable):
        """Disable gradients for every entry of ``decoder.layers`` except the last *n_trainable*.

        ``n_trainable == 0`` freezes all layers. Slicing with ``[:-0]`` would
        instead select nothing and leave every layer trainable.
        """
        n_frozen = max(len(decoder.layers) - n_trainable, 0)
        for layer in decoder.layers[:n_frozen]:
            for param in layer.parameters():
                param.requires_grad = False

    def _dataset_index(self, batch_num):
        """Return the dataset index for *batch_num* (one fixed dataset per rank when distributed).

        Raises:
            ValueError: if distributed and ``rank >= n_datasets``.
        """
        if self.args.distributed:
            if self.args.rank >= self.args.n_datasets:
                raise ValueError("No. of workers must be equal to #dataset")
            return self.args.rank
        return batch_num % self.args.n_datasets

    def train_batch(self, x, x_aug, dset_num):
        """Run one reconstruction-only step (no discriminator) and return the loss.

        ``x_aug`` is encoded, the decoder reconstructs ``x`` from that code,
        and the mean ``cross_entropy_loss`` is minimised.
        """
        x = x.float()
        x_aug = x_aug.float()
        z = self.encoder(x_aug)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        loss = recon_loss.mean()

        self.model_optimizer.zero_grad()
        loss.backward()
        self.model_optimizer.step()

        self.loss_total.add(loss.item())
        return loss.item()

    def train_epoch(self, epoch):
        """Run ``args.epoch_len`` training batches (3 when ``args.short``)."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        # The encoder is frozen; keep it in eval mode so that any dropout or
        # batch-norm layers neither act stochastically nor update running stats.
        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                dset_num = self._dataset_index(batch_num)
                x, x_aug = next(self.data[dset_num].train_iter)
                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def eval_batch(self, x, x_aug, dset_num):
        """Return the mean reconstruction loss of *x* encoded from itself (x_aug unused)."""
        x = x.float()
        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().item()
        self.eval_total.add(total_loss)
        return total_loss

    def evaluate_epoch(self, epoch):
        """Run ``ceil(args.epoch_len / 10)`` validation batches (10 when ``args.short``)."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = self._dataset_index(batch_num)
                x, x_aug = next(self.data[dset_num].valid_iter)
                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Join each meter's ``summarize_epoch()`` value as ``'x.xxxx, y.yyyy'``."""
        return ', '.join('{:.4f}'.format(meter.summarize_epoch()) for meter in meters)

    def train_losses(self):
        return self.format_losses(self.losses_recon)

    def eval_losses(self):
        return self.format_losses(self.evals_recon)

    def finetune(self):
        """Fine-tune for ``args.epochs`` epochs, saving best/last checkpoints each epoch."""
        best_eval = float('inf')

        for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.modelPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save the unwrapped encoder/decoder and optimizer state to ``modelPath / filename``."""
        save_path = self.modelPath / filename

        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')