def maybe_load_checkpoint(model: Vicl, model_optimizer: LocalSgd,
                          moptim_scheduler: ExponentialLR, task: int,
                          models_dir: str):
    epoch = 0
    halo = Halo(text='Trying to load a checkpoint', spinner='dots').start()

    load_name = f'vicl-task-{task}-cp.pt'
    load_path = os.path.join(models_dir, load_name)

    try:
        checkpoint = torch.load(load_path, map_location=model.device())
    except Exception as e:
        halo.fail(f'No checkpoints found for this run: {e}')
    else:
        model.load_state(checkpoint['model'])
        model_optimizer.load_state_dict(checkpoint['model_optimizer'])
        moptim_scheduler.load_state_dict(checkpoint['moptim_scheduler'])
        epoch = checkpoint['epoch']
        halo.succeed(f'Found a checkpoint (epoch: {epoch})')

    return epoch
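# For reference, a minimal sketch of the save-side counterpart. The name
# maybe_save_checkpoint and the Vicl state accessor are assumptions (only
# load_state is visible above); the dict keys mirror exactly what
# maybe_load_checkpoint reads back.
def maybe_save_checkpoint(model, model_optimizer, moptim_scheduler,
                          task, epoch, models_dir):
    save_path = os.path.join(models_dir, f'vicl-task-{task}-cp.pt')
    torch.save({
        'model': model.state(),  # assumed accessor matching model.load_state
        'model_optimizer': model_optimizer.state_dict(),
        'moptim_scheduler': moptim_scheduler.state_dict(),
        'epoch': epoch,
    }, save_path)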
def main(rank, args):
    # Distributed setup
    if args.distributed:
        setup_distributed(rank, args.world_size)

    not_main_rank = args.distributed and rank != 0

    logging.info("Start time: %s", datetime.now())

    # Explicitly set seed to make sure models created in separate processes
    # start from same random weights and biases
    torch.manual_seed(args.seed)

    # Empty CUDA cache
    torch.cuda.empty_cache()

    # Change backend for flac files
    torchaudio.set_audio_backend("soundfile")

    # Transforms
    melkwargs = {
        "n_fft": args.win_length,
        "n_mels": args.n_bins,
        "hop_length": args.hop_length,
    }

    sample_rate_original = 16000

    if args.type == "mfcc":
        transforms = torch.nn.Sequential(
            torchaudio.transforms.MFCC(
                sample_rate=sample_rate_original,
                n_mfcc=args.n_bins,
                melkwargs=melkwargs,
            ),
        )
        num_features = args.n_bins
    elif args.type == "waveform":
        transforms = torch.nn.Sequential(UnsqueezeFirst())
        num_features = 1
    else:
        raise ValueError("Model type not supported")

    if args.normalize:
        transforms = torch.nn.Sequential(transforms, Normalize())

    augmentations = torch.nn.Sequential()
    if args.freq_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask),
        )
    if args.time_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
        )

    # Text preprocessing
    char_blank = "*"
    char_space = " "
    char_apostrophe = "'"
    labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
    language_model = LanguageModel(labels, char_blank, char_space)

    # Dataset
    training, validation = split_process_librispeech(
        [args.dataset_train, args.dataset_valid],
        [transforms, transforms],
        language_model,
        root=args.dataset_root,
        folder_in_archive=args.dataset_folder_in_archive,
    )

    # Decoder
    if args.decoder == "greedy":
        decoder = GreedyDecoder()
    else:
        raise ValueError("Selected decoder not supported")

    # Model
    model = Wav2Letter(
        num_classes=language_model.length,
        input_type=args.type,
        num_features=num_features,
    )

    if args.jit:
        model = torch.jit.script(model)

    if args.distributed:
        n = torch.cuda.device_count() // args.world_size
        devices = list(range(rank * n, (rank + 1) * n))
        model = model.to(devices[0])
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
    else:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
        model = model.to(devices[0], non_blocking=True)
        model = torch.nn.DataParallel(model)

    n = count_parameters(model)
    logging.info("Number of parameters: %s", n)

    # Optimizer
    if args.optimizer == "adadelta":
        optimizer = Adadelta(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            eps=args.eps,
            rho=args.rho,
        )
    elif args.optimizer == "sgd":
        optimizer = SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        # Note: Adam has no momentum argument (it uses betas internally);
        # passing one raises a TypeError, so only lr and weight_decay apply.
        optimizer = Adam(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adamw":
        # AdamW likewise takes no momentum argument.
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise ValueError("Selected optimizer not supported")

    if args.scheduler == "exponential":
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == "reduceonplateau":
        scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
    else:
        raise ValueError("Selected scheduler not supported")

    criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank],
                                 zero_infinity=False)

    # Data Loader
    collate_fn_train = collate_factory(model_length_function, augmentations)
    collate_fn_valid = collate_factory(model_length_function)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": True,
        "shuffle": True,
        "drop_last": True,
    }
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    loader_training = DataLoader(
        training,
        batch_size=args.batch_size,
        collate_fn=collate_fn_train,
        **loader_training_params,
    )
    loader_validation = DataLoader(
        validation,
        batch_size=args.batch_size,
        collate_fn=collate_fn_valid,
        **loader_validation_params,
    )

    # Setup checkpoint
    best_loss = 1.0
    load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)

    if args.distributed:
        torch.distributed.barrier()

    if load_checkpoint:
        logging.info("Checkpoint: loading %s", args.checkpoint)
        checkpoint = torch.load(args.checkpoint)

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

        logging.info("Checkpoint: loaded '%s' at epoch %s",
                     args.checkpoint, checkpoint["epoch"])
    else:
        logging.info("Checkpoint: not found")
        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            False,
            args.checkpoint,
            not_main_rank,
        )

    if args.distributed:
        torch.distributed.barrier()

    torch.autograd.set_detect_anomaly(False)

    for epoch in range(args.start_epoch, args.epochs):
        logging.info("Epoch: %s", epoch)

        train_one_epoch(
            model,
            criterion,
            optimizer,
            scheduler,
            loader_training,
            decoder,
            language_model,
            devices[0],
            epoch,
            args.clip_grad,
            not_main_rank,
            not args.reduce_lr_valid,
        )

        loss = evaluate(
            model,
            criterion,
            loader_validation,
            decoder,
            language_model,
            devices[0],
            epoch,
            not_main_rank,
        )

        if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            is_best,
            args.checkpoint,
            not_main_rank,
        )

    logging.info("End time: %s", datetime.now())

    if args.distributed:
        torch.distributed.destroy_process_group()
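# setup_distributed() is called at the top of main() but not defined in this
# file. A minimal sketch of such a helper, assuming environment-variable
# rendezvous; the address, port, and nccl backend are placeholders, not
# necessarily this script's actual configuration.
import os
import torch.distributed

def setup_distributed(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group("nccl", rank=rank,
                                         world_size=world_size)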
# Initialization
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay_lr)
scheduler = ExponentialLR(optimizer, gamma=args.decay_lr)

# Main training loop
best_loss = np.inf

# Resume training
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_loss = checkpoint['val_loss']
        epoch = checkpoint['epoch']
        print('Loading model: {}. Resuming from epoch: {}'.format(args.load_model, epoch))
    else:
        print('Model: {} not found'.format(args.load_model))

for epoch in range(args.epochs):
    v_loss = execute_graph(model, loader, optimizer, scheduler, epoch, use_cuda)

    if v_loss < best_loss:
        best_loss = v_loss
        print('Writing model checkpoint')
        state = {
            'epoch': epoch,
            'model': model.state_dict(),
            # remaining keys mirror the resume block above
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'val_loss': v_loss,
        }
        torch.save(state, 'checkpoint.pt')  # save path not shown in the source; filename assumed
def main() -> None:
    """Entrypoint.
    """
    config: Any = importlib.import_module(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data = tx.data.MonoTextData(config.train_data_hparams, device=device)
    val_data = tx.data.MonoTextData(config.val_data_hparams, device=device)
    test_data = tx.data.MonoTextData(config.test_data_hparams, device=device)

    iterator = tx.data.DataIterator(
        {"train": train_data, "valid": val_data, "test": test_data})

    opt_vars = {
        'learning_rate': config.lr_decay_hparams["init_lr"],
        'best_valid_nll': 1e100,
        'steps_not_improved': 0,
        'kl_weight': config.kl_anneal_hparams["start"]
    }

    decay_cnt = 0
    max_decay = config.lr_decay_hparams["max_decay"]
    decay_factor = config.lr_decay_hparams["decay_factor"]
    decay_ts = config.lr_decay_hparams["threshold"]

    save_dir = f"./models/{config.dataset}"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    suffix = f"{config.dataset}_{config.decoder_type}Decoder.ckpt"
    save_path = os.path.join(save_dir, suffix)

    # KL term annealing rate
    anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] *
                      (len(train_data) / config.batch_size))

    vocab = train_data.vocab
    model = VAE(train_data.vocab.size, config)
    model.to(device)

    start_tokens = torch.full((config.batch_size,),
                              vocab.bos_token_id,
                              dtype=torch.long).to(device)
    end_token = vocab.eos_token_id

    optimizer = tx.core.get_optimizer(params=model.parameters(),
                                      hparams=config.opt_hparams)
    scheduler = ExponentialLR(optimizer, decay_factor)

    def _run_epoch(epoch: int, mode: str, display: int = 10) \
            -> Tuple[Tensor, float]:
        iterator.switch_to_dataset(mode)

        if mode == 'train':
            model.train()
            opt_vars["kl_weight"] = min(1.0, opt_vars["kl_weight"] + anneal_r)
            kl_weight = opt_vars["kl_weight"]
        else:
            model.eval()
            kl_weight = 1.0

        step = 0
        start_time = time.time()
        num_words = 0
        nll_total = 0.

        avg_rec = tx.utils.AverageRecorder()
        for batch in iterator:
            ret = model(batch, kl_weight, start_tokens, end_token)
            if mode == "train":
                opt_vars["kl_weight"] = min(1.0, opt_vars["kl_weight"] + anneal_r)
                kl_weight = opt_vars["kl_weight"]
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add(
                [ret["nll"].item(), ret["kl_loss"].item(), ret["rc_loss"].item()],
                batch_size)

            if step % display == 0 and mode == 'train':
                nll = avg_rec.avg(0)
                klw = opt_vars["kl_weight"]
                KL = avg_rec.avg(1)
                rc = avg_rec.avg(2)
                log_ppl = nll_total / num_words
                ppl = math.exp(log_ppl)
                time_cost = time.time() - start_time
                print(f"{mode}: epoch {epoch}, step {step}, nll {nll:.4f}, "
                      f"klw {klw:.4f}, KL {KL:.4f}, rc {rc:.4f}, "
                      f"log_ppl {log_ppl:.4f}, ppl {ppl:.4f}, "
                      f"time_cost {time_cost:.1f}", flush=True)

            step += 1

        nll = avg_rec.avg(0)
        KL = avg_rec.avg(1)
        rc = avg_rec.avg(2)
        log_ppl = nll_total / num_words
        ppl = math.exp(log_ppl)
        print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
              f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
        return nll, ppl  # type: ignore

    @torch.no_grad()
    def _generate(start_tokens: torch.LongTensor,
                  end_token: int,
                  filename: Optional[str] = None):
        ckpt = torch.load(args.model)
        model.load_state_dict(ckpt['model'])
        model.eval()

        batch_size = train_data.batch_size
        dst = MultivariateNormalDiag(
            loc=torch.zeros(batch_size, config.latent_dims),
            scale_diag=torch.ones(batch_size, config.latent_dims))

        latent_z = dst.rsample().to(device)

        helper = model.decoder.create_helper(
            decoding_strategy='infer_sample',
            start_tokens=start_tokens,
            end_token=end_token)
        outputs = model.decode(
            helper=helper, latent_z=latent_z, max_decoding_length=100)

        sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

        if filename is None:
            fh = sys.stdout
        else:
            fh = open(filename, 'w', encoding='utf-8')

        for sent in sample_tokens:
            sent = tx.utils.compat_as_text(list(sent))
            end_id = len(sent)
            if vocab.eos_token in sent:
                end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        print('Output done')
        if filename is not None:  # avoid closing sys.stdout
            fh.close()

    if args.mode == "predict":
        _generate(start_tokens, end_token, args.out)
        return

    # Counts trainable parameters
    total_parameters = sum(param.numel() for param in model.parameters())
    print(f"{total_parameters} total parameters")

    best_nll = best_ppl = 0.

    for epoch in range(config.num_epochs):
        _, _ = _run_epoch(epoch, 'train', display=200)
        val_nll, _ = _run_epoch(epoch, 'valid')
        test_nll, test_ppl = _run_epoch(epoch, 'test')

        if val_nll < opt_vars['best_valid_nll']:
            opt_vars['best_valid_nll'] = val_nll
            opt_vars['steps_not_improved'] = 0

            best_nll = test_nll
            best_ppl = test_ppl

            states = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict()
            }
            torch.save(states, save_path)
        else:
            opt_vars['steps_not_improved'] += 1
            if opt_vars['steps_not_improved'] == decay_ts:
                old_lr = opt_vars['learning_rate']
                opt_vars['learning_rate'] *= decay_factor
                opt_vars['steps_not_improved'] = 0
                new_lr = opt_vars['learning_rate']

                ckpt = torch.load(save_path)
                model.load_state_dict(ckpt['model'])
                optimizer.load_state_dict(ckpt['optimizer'])
                scheduler.load_state_dict(ckpt['scheduler'])
                scheduler.step()
                print(f"-----\nchange lr, old lr: {old_lr}, "
                      f"new lr: {new_lr}\n-----")

                decay_cnt += 1

        if decay_cnt == max_decay:
            break

    print(f"\nbest testing nll: {best_nll:.4f}, "
          f"best testing ppl {best_ppl:.4f}\n")
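# Sanity check of the linear KL annealing above: with
# anneal_r = 1 / (warm_up * steps_per_epoch), the weight climbs from its
# start value to 1.0 over roughly warm_up epochs of batches and is clamped
# there. Standalone sketch; the numbers are illustrative, not the config's.
warm_up_epochs = 10
steps_per_epoch = 500  # len(train_data) / batch_size
anneal_r = 1.0 / (warm_up_epochs * steps_per_epoch)
kl_weight = 0.1        # kl_anneal_hparams["start"]
for _ in range(warm_up_epochs * steps_per_epoch):
    kl_weight = min(1.0, kl_weight + anneal_r)
print(kl_weight)  # 1.0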
class Parser(object):

    NAME = None
    MODEL = None

    def __init__(self, args, model, transform):
        self.args = args
        self.model = model
        self.transform = transform

    def train(self, train, dev, test, buckets=32, batch_size=5000,
              update_steps=1, clip=5.0, epochs=5000, patience=100, **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.train()
        batch_size = batch_size // update_steps
        if dist.is_initialized():
            batch_size = batch_size // dist.get_world_size()
        logger.info("Loading the data")
        train = Dataset(self.transform, args.train, **args).build(
            batch_size, buckets, True, dist.is_initialized())
        dev = Dataset(self.transform, args.dev).build(batch_size, buckets)
        test = Dataset(self.transform, args.test).build(batch_size, buckets)
        logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

        if args.encoder == 'lstm':
            self.optimizer = Adam(self.model.parameters(), args.lr,
                                  (args.mu, args.nu), args.eps, args.weight_decay)
            self.scheduler = ExponentialLR(self.optimizer,
                                           args.decay ** (1 / args.decay_steps))
        else:
            from transformers import AdamW, get_linear_schedule_with_warmup
            steps = len(train.loader) * epochs // args.update_steps
            self.optimizer = AdamW(
                [{'params': p,
                  'lr': args.lr * (1 if n.startswith('encoder') else args.lr_rate)}
                 for n, p in self.model.named_parameters()],
                args.lr)
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer, int(steps * args.warmup), steps)

        if dist.is_initialized():
            self.model = DDP(self.model,
                             device_ids=[args.local_rank],
                             find_unused_parameters=True)

        self.epoch, self.best_e, self.patience = 1, 1, patience
        self.best_metric, self.elapsed = Metric(), timedelta()
        if self.args.checkpoint:
            self.optimizer.load_state_dict(
                self.checkpoint_state_dict.pop('optimizer_state_dict'))
            self.scheduler.load_state_dict(
                self.checkpoint_state_dict.pop('scheduler_state_dict'))
            set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
            for k, v in self.checkpoint_state_dict.items():
                setattr(self, k, v)
            train.loader.batch_sampler.epoch = self.epoch

        for epoch in range(self.epoch, args.epochs + 1):
            start = datetime.now()

            logger.info(f"Epoch {epoch} / {args.epochs}:")
            self._train(train.loader)
            loss, dev_metric = self._evaluate(dev.loader)
            logger.info(f"{'dev:':5} loss: {loss:.4f} - {dev_metric}")
            loss, test_metric = self._evaluate(test.loader)
            logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

            t = datetime.now() - start
            self.epoch += 1
            self.patience -= 1
            self.elapsed += t

            if dev_metric > self.best_metric:
                self.best_e, self.patience, self.best_metric = epoch, patience, dev_metric
                if is_master():
                    self.save_checkpoint(args.path)
                logger.info(f"{t}s elapsed (saved)\n")
            else:
                logger.info(f"{t}s elapsed\n")
            if self.patience < 1:
                break
        parser = self.load(**args)
        loss, metric = parser._evaluate(test.loader)
        parser.save(args.path)

        logger.info(f"Epoch {self.best_e} saved")
        logger.info(f"{'dev:':5} {self.best_metric}")
        logger.info(f"{'test:':5} {metric}")
        logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch")

    def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.train()
        logger.info("Loading the data")
        dataset = Dataset(self.transform, data)
        dataset.build(batch_size, buckets)
        logger.info(f"\n{dataset}")

        logger.info("Evaluating the dataset")
        start = datetime.now()
        loss, metric = self._evaluate(dataset.loader)
        elapsed = datetime.now() - start
        logger.info(f"loss: {loss:.4f} - {metric}")
        logger.info(f"{elapsed}s elapsed, "
                    f"{len(dataset) / elapsed.total_seconds():.2f} Sents/s")

        return loss, metric

    def predict(self, data, pred=None, lang=None, buckets=8,
                batch_size=5000, prob=False, **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.eval()
        if args.prob:
            self.transform.append(Field('probs'))

        logger.info("Loading the data")
        dataset = Dataset(self.transform, data, lang=lang)
        dataset.build(batch_size, buckets)
        logger.info(f"\n{dataset}")

        logger.info("Making predictions on the dataset")
        start = datetime.now()
        preds = self._predict(dataset.loader)
        elapsed = datetime.now() - start

        for name, value in preds.items():
            setattr(dataset, name, value)
        if pred is not None and is_master():
            logger.info(f"Saving predicted results to {pred}")
            self.transform.save(pred, dataset.sentences)
        logger.info(f"{elapsed}s elapsed, "
                    f"{len(dataset) / elapsed.total_seconds():.2f} Sents/s")

        return dataset

    def _train(self, loader):
        raise NotImplementedError

    @torch.no_grad()
    def _evaluate(self, loader):
        raise NotImplementedError

    @torch.no_grad()
    def _predict(self, loader):
        raise NotImplementedError

    @classmethod
    def build(cls, path, **kwargs):
        raise NotImplementedError

    @classmethod
    def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
        r"""
        Loads a parser with data fields and pretrained model parameters.

        Args:
            path (str):
                - a string with the shortcut name of a pretrained model
                  defined in ``supar.MODEL`` to load from cache or download,
                  e.g., ``'biaffine-dep-en'``.
                - a local path to a pretrained model, e.g., ``./<path>/model``.
            reload (bool):
                Whether to discard the existing cache and force a fresh
                download. Default: ``False``.
            src (str):
                Specifies where to download the model.
                ``'github'``: github release page.
                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00
                (UTC+8). Default: ``'github'``.
            checkpoint (bool):
                If ``True``, loads all checkpoint states to restore the
                training process. Default: ``False``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training
                configs and initializing the model.

        Examples:
            >>> from supar import Parser
            >>> parser = Parser.load('biaffine-dep-en')
            >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
        """
        args = Config(**locals())
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        state = torch.load(path if os.path.exists(path) else download(
            supar.MODEL[src].get(path, path), reload=reload))
        cls = supar.PARSER[state['name']] if cls.NAME is None else cls
        args = state['args'].update(args)
        model = cls.MODEL(**args)
        model.load_pretrained(state['pretrained'])
        model.load_state_dict(state['state_dict'], False)
        model.to(args.device)
        transform = state['transform']
        parser = cls(args, model, transform)
        parser.checkpoint_state_dict = state['checkpoint_state_dict'] \
            if args.checkpoint else None
        return parser

    def save(self, path):
        model = self.model
        if hasattr(model, 'module'):
            model = self.model.module
        args = model.args
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        pretrained = state_dict.pop('pretrained.weight', None)
        state = {
            'name': self.NAME,
            'args': args,
            'state_dict': state_dict,
            'pretrained': pretrained,
            'transform': self.transform
        }
        torch.save(state, path, pickle_module=dill)

    def save_checkpoint(self, path):
        model = self.model
        if hasattr(model, 'module'):
            model = self.model.module
        args = model.args
        checkpoint_state_dict = {
            k: getattr(self, k)
            for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']
        }
        checkpoint_state_dict.update({
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'rng_state': get_rng_state()
        })
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        pretrained = state_dict.pop('pretrained.weight', None)
        state = {
            'name': self.NAME,
            'args': args,
            'state_dict': state_dict,
            'pretrained': pretrained,
            'checkpoint_state_dict': checkpoint_state_dict,
            'transform': self.transform
        }
        torch.save(state, path, pickle_module=dill)
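# Usage sketch for checkpoint-resumed training (file names illustrative):
# load(checkpoint=True) stashes checkpoint_state_dict on the parser, and
# train() then restores the optimizer/scheduler/RNG state from it before
# continuing from self.epoch. The exact argument plumbing depends on the
# surrounding Config.
parser = Parser.load('./ptb.biaffine.dep.lstm.char', checkpoint=True)
parser.train(train='train.conllx', dev='dev.conllx', test='test.conllx',
             checkpoint=True)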
def main():
    """Entrypoint.
    """
    config: Any = importlib.import_module(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data is kept on CPU here; the GPU-resident variant is left for reference.
    # train_data = tx.data.MonoTextData(config.train_data_hparams, device=device)
    # val_data = tx.data.MonoTextData(config.val_data_hparams, device=device)
    # test_data = tx.data.MonoTextData(config.test_data_hparams, device=device)
    train_data = tx.data.MonoTextData(config.train_data_hparams,
                                      device=torch.device("cpu"))
    val_data = tx.data.MonoTextData(config.val_data_hparams,
                                    device=torch.device("cpu"))
    test_data = tx.data.MonoTextData(config.test_data_hparams,
                                     device=torch.device("cpu"))
    iterator = tx.data.DataIterator(
        {"train": train_data, "valid": val_data, "test": test_data})

    opt_vars = {
        'learning_rate': config.lr_decay_hparams["init_lr"],
        'best_valid_nll': 1e100,
        'steps_not_improved': 0,
        'kl_weight': config.kl_anneal_hparams["start"]
    }

    decay_cnt = 0
    max_decay = config.lr_decay_hparams["max_decay"]
    decay_factor = config.lr_decay_hparams["decay_factor"]
    decay_ts = config.lr_decay_hparams["threshold"]

    if 'pid' in args.model_name:
        save_dir = args.model_name + '_' + str(config.dataset) + '_KL' + str(args.exp_kl)
    elif 'cost' in args.model_name:
        save_dir = args.model_name + '_' + str(config.dataset) + '_step' + str(args.anneal_steps)
    elif 'cyclical' in args.model_name:
        save_dir = args.model_name + '_' + str(config.dataset) + '_cyc_' + str(args.cycle)
    else:
        # Guard added: otherwise an unrecognized model name leaves save_dir
        # undefined and crashes below with a NameError.
        raise ValueError(f"Unrecognized model name: {args.model_name}")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    suffix = f"{config.dataset}_{config.decoder_type}Decoder.ckpt"
    save_path = os.path.join(save_dir, suffix)

    # KL term annealing rate
    warm_up = 10  # TODO: replace with a sigmoid schedule
    anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] *
                      (len(train_data) / config.batch_size))

    vocab = train_data.vocab
    model = VAE(train_data.vocab.size, config)
    model.to(device)

    start_tokens = torch.full((config.batch_size,),
                              vocab.bos_token_id,
                              dtype=torch.long).to(device)
    end_token = vocab.eos_token_id
    optimizer = tx.core.get_optimizer(params=model.parameters(),
                                      hparams=config.opt_hparams)
    scheduler = ExponentialLR(optimizer, decay_factor)

    # Max number of optimization steps
    max_iter = config.num_epochs * len(train_data) / config.batch_size
    max_iter = min(max_iter, args.max_steps)
    print('max steps:', max_iter)
    pbar = tqdm(total=int(max_iter))

    if args.mode == "train":
        outFile = os.path.join(save_dir, 'train.log')
        fw_log = open(outFile, "w")

    global_steps = {'step': 0}
    pid = PIDControl()
    opt_vars["kl_weight"] = 0.0
    Kp = args.Kp
    Ki = args.Ki
    exp_kl = args.exp_kl

    # Train model
    def _run_epoch(epoch: int, mode: str, display: int = 10) \
            -> Tuple[Tensor, float]:
        iterator.switch_to_dataset(mode)

        if mode == 'train':
            model.train()
            kl_weight = opt_vars["kl_weight"]
        else:
            model.eval()
            kl_weight = 1.0
            # kl_weight = opt_vars["kl_weight"]

        start_time = time.time()
        num_words = 0
        nll_total = 0.
        avg_rec = tx.utils.AverageRecorder()
        for batch in iterator:
            # Run the model to get the loss terms
            if global_steps['step'] >= args.max_steps:
                break
            ret = model(batch, kl_weight, start_tokens, end_token)
            if mode == "train":
                pbar.update(1)
                global_steps['step'] += 1
                kl_loss = ret['kl_loss'].item()
                rec_loss = ret['rc_loss'].item()
                total_loss = ret["nll"].item()
                if 'cost' in args.model_name:
                    kl_weight = _cost_annealing(global_steps['step'], 1.0,
                                                args.anneal_steps)
                elif 'pid' in args.model_name:
                    kl_weight = pid.pid(exp_kl, kl_loss, Kp, Ki)
                elif 'cyclical' in args.model_name:
                    kl_weight = _cyclical_annealing(global_steps['step'],
                                                    max_iter / args.cycle)
                opt_vars["kl_weight"] = kl_weight
                # Total loss
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()
                fw_log.write(
                    'epoch:{0} global_step:{1} total_loss:{2:.3f} '
                    'kl_loss:{3:.3f} rec_loss:{4:.3f} kl_weight:{5:.4f}\n'
                    .format(epoch, global_steps['step'], total_loss,
                            kl_loss, rec_loss, kl_weight))
                fw_log.flush()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add(
                [ret["nll"].item(), ret["kl_loss"].item(), ret["rc_loss"].item()],
                batch_size)

            if global_steps['step'] % display == 1 and mode == 'train':
                nll = avg_rec.avg(0)
                klw = opt_vars["kl_weight"]
                KL = avg_rec.avg(1)
                rc = avg_rec.avg(2)
                writer.add_scalar(f'Loss/Rec_loss_{args.model_name}',
                                  rc, global_steps['step'])
                writer.add_scalar(f'Loss/KL_diverg_{args.model_name}',
                                  KL, global_steps['step'])
                writer.add_scalar(f'Loss/KL_weight_{args.model_name}',
                                  klw, global_steps['step'])

        nll = avg_rec.avg(0)
        KL = avg_rec.avg(1)
        rc = avg_rec.avg(2)
        if num_words > 0:
            log_ppl = nll_total / num_words
            ppl = math.exp(log_ppl)
        else:
            # Fallback sentinels when the loop exited before any batch
            log_ppl = 100
            ppl = math.exp(log_ppl)
            nll = 1000
            KL = args.exp_kl
        print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
              f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
        return nll, ppl  # type: ignore

    args.model = save_path

    @torch.no_grad()
    def _generate(start_tokens: torch.LongTensor,
                  end_token: int,
                  filename: Optional[str] = None):
        ckpt = torch.load(args.model)
        model.load_state_dict(ckpt['model'])
        model.eval()

        batch_size = train_data.batch_size
        dst = MultivariateNormalDiag(
            loc=torch.zeros(batch_size, config.latent_dims),
            scale_diag=torch.ones(batch_size, config.latent_dims))

        # Sampling strategies for the latent code; the uniform draw is the
        # active choice, the alternatives are left for reference.
        # latent_z = dst.rsample().to(device)
        latent_z = torch.FloatTensor(batch_size, config.latent_dims) \
            .uniform_(-1, 1).to(device)
        # latent_z = torch.randn(batch_size, config.latent_dims).to(device)

        helper = model.decoder.create_helper(
            decoding_strategy='infer_sample',
            start_tokens=start_tokens,
            end_token=end_token)
        outputs = model.decode(
            helper=helper, latent_z=latent_z, max_decoding_length=100)
        if config.decoder_type == "transformer":
            outputs = outputs[0]
        sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

        if filename is None:
            fh = sys.stdout
        else:
            fh = open(filename, 'a', encoding='utf-8')

        for sent in sample_tokens:
            sent = tx.utils.compat_as_text(list(sent))
            end_id = len(sent)
            if vocab.eos_token in sent:
                end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        print('Output done')
        if filename is not None:  # avoid closing sys.stdout
            fh.close()

    if args.mode == "predict":
        out_path = os.path.join(save_dir, 'results.txt')
        for _ in range(10):
            _generate(start_tokens, end_token, out_path)
        return

    # Counts trainable parameters
    total_parameters = sum(param.numel() for param in model.parameters())
    print(f"{total_parameters} total parameters")
    best_nll = best_ppl = 0.
    # Start running the model
    for epoch in range(config.num_epochs):
        _, _ = _run_epoch(epoch, 'train', display=200)
        val_nll, _ = _run_epoch(epoch, 'valid')
        test_nll, test_ppl = _run_epoch(epoch, 'test')

        if val_nll < opt_vars['best_valid_nll']:
            opt_vars['best_valid_nll'] = val_nll
            opt_vars['steps_not_improved'] = 0
            best_nll = test_nll
            best_ppl = test_ppl

            states = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict()
            }
            torch.save(states, save_path)
        else:
            opt_vars['steps_not_improved'] += 1
            if opt_vars['steps_not_improved'] == decay_ts:
                old_lr = opt_vars['learning_rate']
                opt_vars['learning_rate'] *= decay_factor
                opt_vars['steps_not_improved'] = 0
                new_lr = opt_vars['learning_rate']

                ckpt = torch.load(save_path)
                model.load_state_dict(ckpt['model'])
                optimizer.load_state_dict(ckpt['optimizer'])
                scheduler.load_state_dict(ckpt['scheduler'])
                scheduler.step()
                print(f"-----\nchange lr, old lr: {old_lr}, "
                      f"new lr: {new_lr}\n-----")

                decay_cnt += 1

        if decay_cnt == max_decay:
            break
        if global_steps['step'] >= args.max_steps:
            break

    print(f"\nbest testing nll: {best_nll:.4f}, "
          f"best testing ppl {best_ppl:.4f}\n")
    if args.mode == "train":
        fw_log.write(f"\nbest testing nll: {best_nll:.4f}, "
                     f"best testing ppl {best_ppl:.4f}\n")
        fw_log.close()
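# _cost_annealing, _cyclical_annealing, and PIDControl are dispatched in
# _run_epoch but defined elsewhere. Minimal sketches of the standard
# formulations (monotonic linear annealing; cyclical annealing a la
# Fu et al., 2019; a PI controller in the spirit of ControlVAE); the
# definitions in this codebase may differ in detail.
import math

def _cost_annealing(step, max_weight, anneal_steps):
    # Monotonic "cost" annealing: ramp linearly, then hold at max_weight.
    return min(max_weight, max_weight * step / anneal_steps)

def _cyclical_annealing(step, period, ratio=0.5):
    # Ramp up during the first `ratio` of each cycle, then stay at 1.0.
    tau = (step % period) / period
    return min(1.0, tau / ratio)

class PIDControl:
    # PI controller sketch: drives the observed KL toward the target exp_kl.
    def __init__(self):
        self.i_term = 0.0

    def pid(self, exp_kl, kl_loss, Kp, Ki):
        error = exp_kl - kl_loss
        p_term = Kp / (1.0 + math.exp(error))  # grows when KL overshoots
        self.i_term -= Ki * error              # accumulates in the same direction
        return min(1.0, max(0.0, p_term + self.i_term))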
def main_worker(train_loader, val_loader, ntokens, args, device):
    global best_ppl

    model_kwargs = {
        'dropout': args.dropout,
        'tie_weights': not args.not_tied,
        'norm': args.norm_mode,
        'alpha_fwd': args.afwd,
        'alpha_bkw': args.abkw,
        'batch_size': args.batch_size,  # Deprecated
        'ecm': args.ecm,
        'cell_norm': args.cell_norm,
    }

    # create model
    print("=> creating model: '{}'".format(args.ru_type))
    model = models.RNNModel(args.ru_type, ntokens, args.emsize, args.nhid,
                            args.nlayers, **model_kwargs).to(device)
    print(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                weight_decay=args.weight_decay)

    scheduler = ExponentialLR(optimizer, gamma=1 / args.lr_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_ppl = checkpoint['best_ppl']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = False if args.seed else True

    if args.evaluate:
        validate(val_loader, model, criterion, device, args, ntokens)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if epoch:
            scheduler.step()

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, device, args, ntokens)

        # evaluate on validation set
        ppl = validate(val_loader, model, criterion, device, args, ntokens)

        # remember best ppl and save checkpoint
        is_best = ppl < best_ppl
        best_ppl = min(ppl, best_ppl)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_ppl': best_ppl,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args)
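# save_checkpoint() is called above but defined elsewhere. A minimal sketch
# following the common PyTorch-examples convention; the filenames are
# assumptions, not necessarily what the real helper uses.
import shutil

def save_checkpoint(state, is_best, args, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        # keep a separate copy of the best-performing weights
        shutil.copyfile(filename, 'model_best.pth.tar')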