def load_optimizer_by_config( checkpoint: int, config: lmp.config.BaseConfig, model: lmp.model.BaseRNNModel ) -> Union[torch.optim.SGD, torch.optim.Adam, ]: r"""Helper function for constructing optimizer. Load optimizer from pre-trained checkpoint when `checkpoint != -1`. Args: checkpoint: Pre-trained optimizer's checkpoint. config: Configuration object with attributes `experiment`, `learning_rate` and `optimizer_class`. model: Source of model parameters. Returns: Same as `load_optimizer`. """ return load_optimizer(checkpoint=checkpoint, experiment=config.experiment, learning_rate=config.learning_rate, optimizer_class=config.optimizer_class, parameters=model.parameters())
def test_save_checkpoint(self): r"""Save checkpoint at each `checkpoint_step`.""" msg = 'Must save checkpoint at each `checkpoint_step`.' for ( batch_size, checkpoint_step, epoch, max_norm, max_seq_len, (model_cstr, optimizer_cstr, tokenizer), ) in product(*self.__class__.train_parameters.values()): config = lmp.config.BaseConfig( batch_size=batch_size, checkpoint_step=checkpoint_step, dataset=self.__class__.dataset, epoch=epoch, experiment=self.__class__.experiment, max_norm=max_norm, max_seq_len=max_seq_len) dataset = lmp.dataset.BaseDataset([''] * batch_size) model = model_cstr(d_emb=1, d_hid=1, dropout=0.0, num_linear_layers=1, num_rnn_layers=1, pad_token_id=0, vocab_size=tokenizer.vocab_size).to( config.device) optimizer = optimizer_cstr(params=model.parameters(), lr=1e-4) try: # Create test file. lmp.util.train_model_by_config(checkpoint=-1, config=config, dataset=dataset, model=model, optimizer=optimizer, tokenizer=tokenizer) for ckpt in range(0, epoch, checkpoint_step): if ckpt == 0: continue self.assertTrue(os.path.exists( os.path.join(self.__class__.test_dir, f'model-{ckpt}.pt')), msg=msg) self.assertTrue(os.path.exists( os.path.join(self.__class__.test_dir, f'optimizer-{ckpt}.pt')), msg=msg) finally: # Clean up test file. for ckpt in os.listdir(self.__class__.test_dir): os.remove(os.path.join(self.__class__.test_dir, ckpt)) for log in os.listdir(self.__class__.test_log_dir): os.remove(os.path.join(self.__class__.test_log_dir, log))
def test_save_checkpoint(self): r"""Save checkpoint at each `checkpoint_step`.""" msg = 'Must save checkpoint at each `checkpoint_step`.' for (batch_size, checkpoint_step, epoch, max_norm, (model_cstr, optimizer_cstr, tokenizer), vocab_size) in product(*self.__class__.train_parameters.values()): data_loader = torch.utils.data.DataLoader( batch_size=batch_size, dataset=lmp.dataset.LanguageModelDataset([''] * batch_size), collate_fn=lmp.dataset.LanguageModelDataset.create_collate_fn( tokenizer=tokenizer, max_seq_len=-1)) model = model_cstr(d_emb=1, d_hid=1, dropout=0.0, num_linear_layers=1, num_rnn_layers=1, pad_token_id=0, vocab_size=vocab_size) optimizer = optimizer_cstr(params=model.parameters(), lr=1e-4) try: # Create test file. lmp.util.train_model(checkpoint=-1, checkpoint_step=checkpoint_step, data_loader=data_loader, device=torch.device('cpu'), epoch=epoch, experiment=self.__class__.experiment, max_norm=max_norm, model=model, optimizer=optimizer, vocab_size=vocab_size) for ckpt in range(0, epoch, checkpoint_step): if ckpt == 0: continue self.assertTrue(os.path.exists( os.path.join(self.__class__.test_dir, f'model-{ckpt}.pt')), msg=msg) self.assertTrue(os.path.exists( os.path.join(self.__class__.test_dir, f'optimizer-{ckpt}.pt')), msg=msg) finally: # Clean up test file. for ckpt in os.listdir(self.__class__.test_dir): os.remove(os.path.join(self.__class__.test_dir, ckpt)) for log in os.listdir(self.__class__.test_log_dir): os.remove(os.path.join(self.__class__.test_log_dir, log))
def test_log_loss(self): r"""Log loss.""" msg = 'Must log loss.' for ( batch_size, checkpoint_step, epoch, max_norm, max_seq_len, (model_cstr, optimizer_cstr, tokenizer), ) in product(*self.__class__.train_parameters.values()): config = lmp.config.BaseConfig( batch_size=batch_size, checkpoint_step=checkpoint_step, dataset=self.__class__.dataset, epoch=epoch, experiment=self.__class__.experiment, max_norm=max_norm, max_seq_len=max_seq_len) dataset = lmp.dataset.LanguageModelDataset([''] * batch_size) model = model_cstr(d_emb=1, d_hid=1, dropout=0.0, num_linear_layers=1, num_rnn_layers=1, pad_token_id=0, vocab_size=tokenizer.vocab_size).to( config.device) optimizer = optimizer_cstr(params=model.parameters(), lr=1e-4) try: # Create test file. lmp.util.train_model_by_config(checkpoint=-1, config=config, dataset=dataset, model=model, optimizer=optimizer, tokenizer=tokenizer) self.assertGreater(len(os.listdir( self.__class__.test_log_dir)), 0, msg=msg) finally: # Clean up test file. for ckpt in os.listdir(self.__class__.test_dir): os.remove(os.path.join(self.__class__.test_dir, ckpt)) for log in os.listdir(self.__class__.test_log_dir): os.remove(os.path.join(self.__class__.test_log_dir, log))
def test_log_loss(self): r"""Log loss.""" msg = 'Must log loss.' for (batch_size, checkpoint_step, epoch, max_norm, (model_cstr, optimizer_cstr, tokenizer), vocab_size) in product(*self.__class__.train_parameters.values()): data_loader = torch.utils.data.DataLoader( batch_size=batch_size, dataset=lmp.dataset.BaseDataset([''] * batch_size), collate_fn=lmp.dataset.BaseDataset.create_collate_fn( tokenizer=tokenizer, max_seq_len=-1)) model = model_cstr(d_emb=1, d_hid=1, dropout=0.0, num_linear_layers=1, num_rnn_layers=1, pad_token_id=0, vocab_size=vocab_size) optimizer = optimizer_cstr(params=model.parameters(), lr=1e-4) try: # Create test file. lmp.util.train_model(checkpoint=-1, checkpoint_step=checkpoint_step, data_loader=data_loader, device=torch.device('cpu'), epoch=epoch, experiment=self.__class__.experiment, max_norm=max_norm, model=model, optimizer=optimizer, vocab_size=vocab_size) self.assertGreater(len(os.listdir( self.__class__.test_log_dir)), 0, msg=msg) finally: # Clean up test file. for ckpt in os.listdir(self.__class__.test_dir): os.remove(os.path.join(self.__class__.test_dir, ckpt)) for log in os.listdir(self.__class__.test_log_dir): os.remove(os.path.join(self.__class__.test_log_dir, log))
def load_optimizer_by_config( checkpoint: int, config: lmp.config.BaseConfig, model: Union[lmp.model.BaseRNNModel, lmp.model.BaseResRNNModel] ) -> Union[torch.optim.SGD, torch.optim.Adam]: r"""Helper function for constructing optimizer. Load optimizer from pre-trained checkpoint when `checkpoint != -1`. Args: checkpoint: Pre-trained model's checkpoint. Must be bigger than or equal to `-1`. config: Configuration object with attributes `experiment`, `learning_rate` and `optimizer_class`. model: Source of model parameters. Raises: TypeError: When one of the arguments are not an instance of their type annotation respectively. ValueError: When `checkpoint < -1`. Returns: Same as `load_optimizer`. """ # Type check. if not isinstance(config, lmp.config.BaseConfig): raise TypeError( '`config` must be an instance of `lmp.config.BaseConfig`.') if not isinstance(model, (lmp.model.BaseRNNModel, lmp.model.BaseResRNNModel)): raise TypeError( '`model` must be an instance of ' '`Union[lmp.model.BaseRNNModel, lmp.model.BaseResRNNModel]`.') return load_optimizer(checkpoint=checkpoint, experiment=config.experiment, learning_rate=config.learning_rate, optimizer_class=config.optimizer_class, parameters=model.parameters())
def train_model(checkpoint: int, checkpoint_step: int, data_loader: torch.utils.data.DataLoader, device: torch.device, epoch: int, experiment: str, max_norm: float, model: Union[lmp.model.BaseRNNModel, lmp.model.BaseResRNNModel], optimizer: Union[torch.optim.SGD, torch.optim.Adam], vocab_size: int) -> None: r"""Helper function for training language model. Continue training from pre-trained checkpoint when `checkpoint != -1`. Args: checkpoint: Pre-trained model's checkpoint. Must be bigger than or equal to `-1`. checkpoint_step: Checkpoint save interval. Must be bigger than or equal to `1`. data_loader: `torch.utils.data.DataLoader` for sampling. device: Model running device. epoch: Number of training epoch. Must be bigger than or equal to `1`. experiment: Name of the current experiment. Must not be empty. max_norm: Maximum gradient norm. Must be bigger than `0.0`. model: Language model. optimizer: Language model's optimizer. vocab_size: Number of classes to predict. Must be bigger than or equal to `1`. Raises: TypeError: When one of the arguments are not an instance of their type annotation respectively. ValueError: When one of the arguments do not follow their constraints. See docstring for arguments constraints. """ # Type check. if not isinstance(checkpoint, int): raise TypeError('`checkpoint` must be an instance of `int`.') if not isinstance(checkpoint_step, int): raise TypeError('`checkpoint_step` must be an instance of `int`.') if not isinstance(data_loader, torch.utils.data.DataLoader): raise TypeError('`data_loader` must be an instance of ' '`torch.utils.data.DataLoader`.') if not isinstance(device, torch.device): raise TypeError('`device` must be an instance of `torch.device`.') if not isinstance(epoch, int): raise TypeError('`epoch` must be an instance of `int`.') if not isinstance(experiment, str): raise TypeError('`experiment` must be an instance of `str`.') if not isinstance(max_norm, float): raise TypeError('`max_norm` must be an instance of `float`.') if not isinstance(model, (lmp.model.BaseRNNModel, lmp.model.BaseResRNNModel)): raise TypeError( '`model` must be an instance of ' '`Union[lmp.model.BaseRNNModel, lmp.model.BaseResRNNModel]`.') if not isinstance(optimizer, (torch.optim.SGD, torch.optim.Adam)): raise TypeError('`optimizer` must be an instance of ' '`Union[torch.optim.SGD, torch.optim.Adam]`.') if not isinstance(vocab_size, int): raise TypeError('`vocab_size` must be an instance of `int`.') # Value check. if checkpoint < -1: raise ValueError('`checkpoint` must be bigger than or equal to `-1`.') if checkpoint_step < 1: raise ValueError( '`checkpoint_step` must be bigger than or equal to `1`.') if epoch < 1: raise ValueError('`epoch` must be bigger than or equal to `1`.') if not experiment: raise ValueError('`experiment` must not be empty.') if max_norm < 0.0 or math.isnan(max_norm): raise ValueError('`max_norm` must be bigger than `0.0`.') if vocab_size < 1: raise ValueError('`vocab_size` must be bigger than or equal to `1`.') # Set experiment output folder. file_dir = os.path.join(lmp.path.DATA_PATH, experiment) log_dir = os.path.join(lmp.path.DATA_PATH, 'log', experiment) if not os.path.exists(file_dir): os.makedirs(file_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) # Set experiment log folder. writer = torch.utils.tensorboard.SummaryWriter(log_dir) # Define objective function. criterion = torch.nn.CrossEntropyLoss() # Step = number of updates. # Every update must increment `step`. step = 0 # Set model to train mode. model.train() # Clean up gradient in model parameters. model.zero_grad() # Initialize total loss. total_loss = 0.0 for cur_epoch in range(epoch): epoch_iterator = tqdm(data_loader, desc=f'epoch: {cur_epoch}, loss: {0:.6f}') for x, y in epoch_iterator: # Increment step for each update. step += 1 # Continue training from previous checkpoint step. if step < checkpoint: continue # Put tensors on to specified device (CPU or GPU). Reshape `y` into # shape (B x S) for cross-entropy. # x.size = (B, S) # y.size = (B x S) x = x.to(device) y = y.reshape(-1).to(device) # Forward pass. # pred_y_logits.size = (B, S, V) pred_y_logits = model(x) # Reshape `pred_y_logits` into shape (B x S, V) for cross-entropy. pred_y_logits = pred_y_logits.reshape(-1, vocab_size) # Perform cross-entropy. loss = criterion(pred_y_logits, y) # Calculate total loss. total_loss += loss.item() # Log loss. epoch_iterator.set_description( f'epoch: {cur_epoch}, loss: {loss.item():.6f}') # Backward pass. loss.backward() # Perform gradient clipping to avoid gradient explosion. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) # Gradient descent. optimizer.step() # `torch` required manually clean up gradient. optimizer.zero_grad() # Save checkpoint for each `checkpoint_step`. if step % checkpoint_step == 0: torch.save(model.state_dict(), os.path.join(file_dir, f'model-{step}.pt')) torch.save(optimizer.state_dict(), os.path.join(file_dir, f'optimizer-{step}.pt')) # Log average loss. writer.add_scalar('loss', total_loss / checkpoint_step, step) total_loss = 0.0 # Save last checkpoint. torch.save(model.state_dict(), os.path.join(file_dir, f'model-{step}.pt')) torch.save(optimizer.state_dict(), os.path.join(file_dir, f'optimizer-{step}.pt'))
def main() -> None: r"""Script entry point.""" # Parse command-line argument. args = parse_arg() # Save training configuration. lmp.util.cfg.save(args=args, exp_name=args.exp_name) # Set random seed for reproducibility. lmp.util.rand.set_seed(seed=args.seed) # Get dataset instance with specified version. dset = lmp.util.dset.load(dset_name=args.dset_name, ver=args.ver) # Mini-batch random sampler. dldr = torch.utils.data.DataLoader( dataset=dset, batch_size=args.batch_size, shuffle=True, ) # Load pre-trained tokenizer. tknzr_cfg = lmp.util.cfg.load(exp_name=args.tknzr_exp_name) tknzr = lmp.util.tknzr.load( exp_name=args.tknzr_exp_name, tknzr_name=tknzr_cfg.tknzr_name, ) # Get model running device. device = torch.device('cpu') if torch.cuda.is_available(): device = torch.device('cuda') # Get new model instance. model = lmp.util.model.create(tknzr=tknzr, **args.__dict__) model = model.train() # Move model to running device. model = model.to(device) # Remove weight decay on bias and layer-norm. no_decay = ['bias', 'LayerNorm.weight'] optim_group_params = [ { 'params': [ param for name, param in model.named_parameters() if not any(nd in name for nd in no_decay) ], 'weight_decay': args.wd, }, { 'params': [ param for name, param in model.named_parameters() if any(nd in name for nd in no_decay) ], 'weight_decay': 0.0, }, ] # Get new optimizer instance. optim = torch.optim.AdamW( optim_group_params, betas=(args.beta1, args.beta2), lr=args.lr, eps=args.eps, ) # Get tensorboard logger instance. writer = lmp.util.log.get_tb_logger(exp_name=args.exp_name) # Log performance target. pre_avg_loss = 0.0 avg_loss = 0.0 # Global optimization step. step = 0 for epoch in range(args.n_epoch): tqdm_dldr = tqdm( dldr, desc=f'epoch: {epoch}, loss: {pre_avg_loss:.6f}', ) for batch_txt in tqdm_dldr: # Encode batch text into batch token ids. batch_tkids = tknzr.batch_enc( batch_txt=batch_txt, max_seq_len=args.max_seq_len, ) # Convert batch token ids to `torch.Tensor` with # `dtype == torch.int64`. batch_tkids = torch.LongTensor(batch_tkids) # Move tensors to model running device. batch_tkids = batch_tkids.to(device) # Format batch token ids to satisfy language model training format. batch_prev_tkids = batch_tkids[..., :-1] batch_next_tkids = batch_tkids[..., 1:] # Calculate loss using loss function. loss = model.loss_fn( batch_next_tkids=batch_next_tkids, batch_prev_tkids=batch_prev_tkids, ) # Accumulate average loss. avg_loss += loss.item() # Backward pass / back propagation. loss.backward() # Perform gradient clipping to avoid gradient explosion. torch.nn.utils.clip_grad_norm_( model.parameters(), max_norm=args.max_norm, ) # Gradient descent. optim.step() # Clean up gradient. # This is needed only in `torch`. optim.zero_grad() # Increment global step. step += 1 # Save checkpoint for each `ckpt_step` step. if step % args.ckpt_step == 0: model.save(ckpt=step, exp_name=args.exp_name) # Log performance for each `log_step` step. if step % args.log_step == 0: avg_loss = avg_loss / args.log_step # Log on CLI. tqdm_dldr.set_description( f'epoch: {epoch}, loss: {avg_loss:.6f}', ) # Log on tensorboard writer.add_scalar( f'loss/{args.dset_name}/{args.ver}', avg_loss, step, ) # Refresh log performance. pre_avg_loss = avg_loss avg_loss = 0.0 # Save last checkpoint. model.save(ckpt=step, exp_name=args.exp_name) # Close tensorboard logger. writer.close()
def train_model(checkpoint: int, checkpoint_step: int, data_loader: torch.utils.data.DataLoader, device: torch.device, epoch: int, experiment: str, max_norm: float, model: lmp.model.BaseRNNModel, optimizer: Union[torch.optim.SGD, torch.optim.Adam, ], vocab_size: int): r"""Helper function for training language model. Continue training from pre-trained checkpoint when `checkpoint != -1`. Args: checkpoint: Pre-trained model's checkpoint. checkpoint_step: Checkpoint save interval. data_loader: `torch.utils.data.DataLoader` for sampling. device: Model running device. epoch: Number of training epoch. experiment: Name of the current experiment. max_norm: Maximum gradient norm. model: Language model. optimizer: Language model's optimizer. vocab_size: Number of classes to predict. """ # Set experiment output folder. file_dir = f'{lmp.path.DATA_PATH}/{experiment}' if not os.path.exists(file_dir): os.makedirs(file_dir) # Set experiment log folder. writer = torch.utils.tensorboard.SummaryWriter( f'{lmp.path.DATA_PATH}/log/{experiment}') # Define objective function. criterion = torch.nn.CrossEntropyLoss() # Step = number of updates. # Every update must increment `step`. step = 0 # Set model to train mode. model.train() # Clean up gradient in model parameters. model.zero_grad() # Initialize total loss. total_loss = 0.0 for cur_epoch in range(epoch): epoch_iterator = tqdm(data_loader, desc=f'epoch: {cur_epoch}, loss: {0:.6f}') for x, y in epoch_iterator: # Increment step for each update. step += 1 # Continue training from previous checkpoint step. if step < checkpoint: continue # Put tensors on to specified device (CPU or GPU). Reshape `y` into # shape (B x S) for cross-entropy. # x.size = (B, S) # y.size = (B x S) x = x.to(device) y = y.reshape(-1).to(device) # Forward pass. # pred_y_logits.size = (B, S, V) pred_y_logits = model(x) # Reshape `pred_y_logits` into shape (B x S, V) for cross-entropy. pred_y_logits = pred_y_logits.reshape(-1, vocab_size) # Perform cross-entropy. loss = criterion(pred_y_logits, y) # Calculate total loss. total_loss += loss.item() # Log loss. epoch_iterator.set_description( f'epoch: {cur_epoch}, loss: {loss.item():.6f}') # Backward pass. loss.backward() # Perform gradient clipping to avoid gradient explosion. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) # Gradient descent. optimizer.step() # `torch` required manually clean up gradient. optimizer.zero_grad() # Save checkpoint for each `checkpoint_step`. if step % checkpoint_step == 0: torch.save(model.state_dict(), os.path.join(file_dir, f'model-{step}.pt')) torch.save(optimizer.state_dict(), os.path.join(file_dir, f'optimizer-{step}.pt')) # Log average loss. writer.add_scalar(f'{experiment}/loss', total_loss / checkpoint_step, step) total_loss = 0.0 # Save last checkpoint. torch.save(model.state_dict(), os.path.join(file_dir, f'model-{step}.pt')) torch.save(optimizer.state_dict(), os.path.join(file_dir, f'optimizer-{step}.pt'))
def test_load_result(self): r"""Load result must be consistent.""" msg = 'Inconsistent load result.' test_path = os.path.join( self.__class__.test_dir, f'optimizer-{self.__class__.checkpoint}.pt' ) for ( d_emb, d_hid, dropout, model_cstr, num_linear_layers, num_rnn_layers, pad_token_id, vocab_size, learning_rate, (optimizer_class, optimizer_cstr) ) in product( *self.__class__.model_parameters.values(), *self.__class__.optimizer_parameters.values() ): if vocab_size <= pad_token_id: continue model = model_cstr( d_emb=d_emb, d_hid=d_hid, dropout=dropout, num_linear_layers=num_linear_layers, num_rnn_layers=num_rnn_layers, pad_token_id=pad_token_id, vocab_size=vocab_size ) try: # Create test file. ans_optimizer = optimizer_cstr( params=model.parameters(), lr=learning_rate ) torch.save(ans_optimizer.state_dict(), test_path) self.assertTrue(os.path.exists(test_path), msg=msg) optimizer_1 = lmp.util.load_optimizer( checkpoint=-1, experiment=self.__class__.experiment, learning_rate=learning_rate, optimizer_class=optimizer_class, parameters=model.parameters() ) optimizer_2 = lmp.util.load_optimizer( checkpoint=self.__class__.checkpoint, experiment=self.__class__.experiment, learning_rate=learning_rate, optimizer_class=optimizer_class, parameters=model.parameters() ) self.assertEqual( len(list(ans_optimizer.state_dict())), len(list(optimizer_1.state_dict())), msg=msg ) self.assertEqual( len(list(ans_optimizer.state_dict())), len(list(optimizer_2.state_dict())), msg=msg ) for p1, p2 in zip( ans_optimizer.state_dict(), optimizer_2.state_dict() ): self.assertTrue((p1 == p2), msg=msg) finally: # Clean up test file. os.remove(test_path)
def test_return_type(self): r"""Return `torch.optim.SGD` or `torch.optim.Adam`.""" msg = 'Must return `torch.optim.SGD` or `torch.optim.Adam`.' test_path = os.path.join( self.__class__.test_dir, f'optimizer-{self.__class__.checkpoint}.pt' ) for ( d_emb, d_hid, dropout, model_cstr, num_linear_layers, num_rnn_layers, pad_token_id, vocab_size, learning_rate, (optimizer_class, optimizer_cstr) ) in product( *self.__class__.model_parameters.values(), *self.__class__.optimizer_parameters.values() ): if vocab_size <= pad_token_id: continue model = model_cstr( d_emb=d_emb, d_hid=d_hid, dropout=dropout, num_linear_layers=num_linear_layers, num_rnn_layers=num_rnn_layers, pad_token_id=pad_token_id, vocab_size=vocab_size ) optimizer_1 = lmp.util.load_optimizer( checkpoint=-1, experiment=self.__class__.experiment, learning_rate=learning_rate, optimizer_class=optimizer_class, parameters=model.parameters() ) self.assertIsInstance(optimizer_1, optimizer_cstr, msg=msg) try: # Create test file. torch.save(optimizer_1.state_dict(), test_path) self.assertTrue(os.path.exists(test_path), msg=msg) optimizer_2 = lmp.util.load_optimizer( checkpoint=self.__class__.checkpoint, experiment=self.__class__.experiment, learning_rate=learning_rate, optimizer_class=optimizer_class, parameters=model.parameters() ) self.assertIsInstance(optimizer_2, optimizer_cstr, msg=msg) finally: # Clean up test file. os.remove(test_path)
def main(argv: List[str]) -> None: """Script entry point. Parameters ---------- argv: list[str] List of CLI arguments. Returns ------- None """ # Parse CLI arguments. args = parse_args(argv=argv) # `args.batch_size` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.batch_size], val_names=['1', 'args.batch_size']) # `args.ckpt_step` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.ckpt_step], val_names=['1', 'args.ckpt_step']) # `args.log_step` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.log_step], val_names=['1', 'args.log_step']) # `args.max_norm` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[0, args.max_norm], val_names=['0', 'args.max_norm']) # `args.n_epoch` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.n_epoch], val_names=['1', 'args.n_epoch']) # `args.n_worker` validation. lmp.util.validate.raise_if_wrong_ordered( vals=[0, args.n_worker, len(os.sched_getaffinity(0))], val_names=['0', 'args.n_worker', 'number of available CPUs'], ) lmp.util.validate.raise_if_wrong_ordered( vals=[args.n_worker, args.batch_size], val_names=['args.n_worker', 'args.batch_size'], ) # Save training configuration. lmp.util.cfg.save(args=args, exp_name=args.exp_name) # Set random seed for reproducibility. lmp.util.rand.set_seed(seed=args.seed) # Get model running device. device = torch.device('cpu') if torch.cuda.is_available(): device = torch.device('cuda') # Load pre-trained tokenizer. tknzr = lmp.util.tknzr.load(exp_name=args.tknzr_exp_name) # Get dataset instance and convert samples to tensor. if args.is_dset_in_memory: dset: torch.utils.data.Dataset = lmp.util.dset.FastTensorDset( dset=lmp.util.dset.load(**args.__dict__), max_seq_len=args.max_seq_len, tknzr=tknzr, ) else: dset = lmp.util.dset.SlowTensorDset( dset=lmp.util.dset.load(**args.__dict__), max_seq_len=args.max_seq_len, tknzr=tknzr, ) # Mini-batch random sampler. Only when `args.n_worker > 0` we set `persisten_worker = True`. We set # `pin_memory = True` to speed up process (which only speed up a few seconds). data_loader = torch.utils.data.DataLoader( batch_size=args.batch_size, dataset=dset, shuffle=True, num_workers=args.n_worker, persistent_workers=bool(args.n_worker != 0), pin_memory=True, ) # Get new model instance and move model to running device. model = lmp.util.model.create(tknzr=tknzr, **args.__dict__) model = model.train() model = model.to(device) # Get new optimizer instance. optim = lmp.util.optim.get_optimizer( beta1=args.beta1, beta2=args.beta2, eps=args.eps, lr=args.lr, model=model, wd=args.wd, ) # Get learning rate scheduler. schdl = lmp.util.optim.get_scheduler( optim=optim, total_step=args.n_epoch * len(data_loader), warmup_step=args.warmup_step, ) # Get tensorboard logger instance. writer = lmp.util.log.get_tb_logger(exp_name=args.exp_name) # Log performance target. pre_avg_loss = 0.0 avg_loss = 0.0 # Global optimization step. step = 0 for epoch in range(args.n_epoch): tqdm_data_loader = tqdm(data_loader, desc=f'epoch: {epoch}, loss: {pre_avg_loss:.6f}', dynamic_ncols=True) for batch_tkids in tqdm_data_loader: # Encode batch text into batch token ids. We convert batch token ids into tensor and move to tensor to the same # running device as model. batch_tkids = batch_tkids.to(device) # Format batch token ids to satisfy language model training format. batch_cur_tkids = batch_tkids[..., :-1] batch_next_tkids = batch_tkids[..., 1:] # Calculate loss using loss function. loss = model(batch_cur_tkids=batch_cur_tkids, batch_next_tkids=batch_next_tkids) # Accumulate average loss for logging. Use `.item()` to avoid construct tensor graph. avg_loss += loss.item() # Perform backward pass / back propagation. loss.backward() # Perform gradient clipping to avoid gradient explosion. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_norm) # Gradient descent. optim.step() # Update learning rate. schdl.step() # Clean up gradient. optim.zero_grad() # Increment global step. step += 1 # Save checkpoint for each `ckpt_step` step. We move model to CPU first then move back to CUDA device. if step % args.ckpt_step == 0: lmp.util.model.save(ckpt=step, exp_name=args.exp_name, model=copy.deepcopy(model).to('cpu')) # Log performance for each `log_step` step. if step % args.log_step == 0: avg_loss = avg_loss / args.log_step # Log on CLI. tqdm_data_loader.set_description(f'epoch: {epoch}, loss: {avg_loss:.6f}') # Log on tensorboard. writer.add_scalar(f'train-loss/{args.dset_name}/{args.ver}', avg_loss, step) writer.add_scalar('lr', schdl.get_last_lr()[0], step) # Refresh log performance. pre_avg_loss = avg_loss avg_loss = 0.0 # Save last checkpoint. lmp.util.model.save(ckpt=step, exp_name=args.exp_name, model=copy.deepcopy(model).to('cpu')) # Close tensorboard logger. writer.close() # Free memory. This is only need for unit test. del args del avg_loss del batch_cur_tkids del batch_next_tkids del batch_tkids del data_loader del device del dset del loss del model del optim del pre_avg_loss del schdl del step del tknzr del tqdm_data_loader del writer torch.cuda.empty_cache() gc.collect()
def main(argv: List[str]) -> None: """Script entry point. Parameters ---------- argv: list[str] List of CLI arguments. Returns ------- None """ # Parse CLI arguments. args = parse_args(argv=argv) # `args.batch_size` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.batch_size], val_names=['1', 'args.batch_size']) # `args.ckpt_step` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.ckpt_step], val_names=['1', 'args.ckpt_step']) # `args.log_step` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.log_step], val_names=['1', 'args.log_step']) # `args.max_norm` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[0, args.max_norm], val_names=['0', 'args.max_norm']) # `args.n_epoch` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.n_epoch], val_names=['1', 'args.n_epoch']) # `args.n_worker` validation. lmp.util.validate.raise_if_wrong_ordered( vals=[0, args.n_worker, len(os.sched_getaffinity(0))], val_names=['0', 'args.n_worker', 'number of available CPUs'], ) lmp.util.validate.raise_if_wrong_ordered( vals=[args.n_worker, args.batch_size], val_names=['args.n_worker', 'args.batch_size'], ) # `args.world_size` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[0, args.world_size], val_names=['0', 'args.world_size']) # `args.local_rank` validation. lmp.util.validate.raise_if_wrong_ordered(vals=[0, args.local_rank], val_names=['0', 'args.local_rank']) # `args.rank` validation. lmp.util.validate.raise_if_wrong_ordered( vals=[0, args.rank, args.world_size - 1], val_names=['0', 'args.rank', 'args.world_size - 1'], ) # Save training configuration. Only main process need to save configuration. if args.rank == HOST_RANK: lmp.util.cfg.save(args=args, exp_name=args.exp_name) # We use TCP to perform RPC. Timeout is set to 5 minutes. store = dist.TCPStore( is_master=args.rank == HOST_RANK, host_name=args.host_name, port=args.host_port, timeout=timedelta(minutes=5), world_size=args.world_size, ) # Use NCCL backend to perform CUDA collectives. dist.init_process_group( backend=dist.Backend.NCCL, store=store, rank=args.rank, timeout=timedelta(minutes=5), world_size=args.world_size, ) # Sync arguments. dist_args_k = ['host_name', 'host_port', 'local_rank', 'rank', 'world_size'] for k in args.__dict__.keys(): if k in dist_args_k: continue # Host broadcast arguments. if args.rank == HOST_RANK: store.set(k, str(args.__dict__[k])) # Non-host receive host arguments. else: v = store.get(k) if isinstance(args.__dict__[k], str): args.__dict__[k] = v.decode('utf-8') else: args.__dict__[k] = type(args.__dict__[k])(v) # Set random seed for reproducibility. Note that each process use different seed to get different slice of batch. lmp.util.rand.set_seed(seed=args.seed + args.rank) # Get model running device. device = torch.device('cpu') if torch.cuda.is_available(): device = torch.device(f'cuda:{args.local_rank}') # Load pre-trained tokenizer. tknzr = lmp.util.tknzr.load(exp_name=args.tknzr_exp_name) # Get dataset instance and convert samples to tensor. if args.is_dset_in_memory: dset: torch.utils.data.Dataset = lmp.util.dset.FastTensorDset( dset=lmp.util.dset.load(**args.__dict__), max_seq_len=args.max_seq_len, tknzr=tknzr, ) else: dset = lmp.util.dset.SlowTensorDset( dset=lmp.util.dset.load(**args.__dict__), max_seq_len=args.max_seq_len, tknzr=tknzr, ) # Mini-batch sampler. Each process will get batches exclusive to itself. dist_sampler = torch.utils.data.distributed.DistributedSampler( num_replicas=args.world_size, rank=args.rank, dataset=dset, shuffle=True, ) # Mini-batch distributed random sampler. Only when `args.n_worker > 0` we set `persisten_worker = True`. We set # `pin_memory = True` to speed up process (which only speed up a few seconds). data_loader = torch.utils.data.DataLoader( batch_size=args.batch_size // args.world_size, dataset=dset, num_workers=args.n_worker, persistent_workers=bool(args.n_worker != 0), pin_memory=True, sampler=dist_sampler, ) # Get new model instance and move model to running device. model = lmp.util.model.create(tknzr=tknzr, **args.__dict__) model = model.train() model = model.to(device) # Get new optimizer instance. optim = lmp.util.optim.get_optimizer( beta1=args.beta1, beta2=args.beta2, eps=args.eps, lr=args.lr, model=model, wd=args.wd, ) # Get learning rate scheduler. schdl = lmp.util.optim.get_scheduler( optim=optim, total_step=args.n_epoch * len(data_loader), warmup_step=args.warmup_step, ) # Create DDP model. ddp_model = torch.nn.parallel.DistributedDataParallel(model) # Get tensorboard logger instance. Only main process need to log performance. if args.rank == HOST_RANK: writer = lmp.util.log.get_tb_logger(exp_name=args.exp_name) else: writer = None # Log performance target. pre_avg_loss = 0.0 avg_loss = 0.0 # Global optimization step. step = 0 for epoch in range(args.n_epoch): # Update random sample order. dist_sampler.set_epoch(epoch) # Processes can have unevenly distributed number of batch. Thus one must use `ddp_model.join()` to avoid dead lock. with ddp_model.join(): tqdm_data_loader = tqdm(data_loader, desc=f'epoch: {epoch}, loss: {pre_avg_loss:.6f}', dynamic_ncols=True) for batch_tkids in tqdm_data_loader: # Encode batch text into batch token ids. We convert batch token ids into tensor and move to tensor to the same # running device as model. batch_tkids = batch_tkids.to(device) # Format batch token ids to satisfy language model training format. batch_cur_tkids = batch_tkids[..., :-1] batch_next_tkids = batch_tkids[..., 1:] # Calculate loss using loss function. loss = ddp_model(batch_cur_tkids=batch_cur_tkids, batch_next_tkids=batch_next_tkids) # Accumulate average loss for logging. Use `.item()` to avoid construct tensor graph. avg_loss += loss.item() # Perform backward pass / back propagation. loss.backward() # Perform gradient clipping to avoid gradient explosion. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_norm) # Gradient descent. optim.step() # Update learning rate. schdl.step() # Clean up gradient. optim.zero_grad() # Increment global step. step += 1 # Save checkpoint for each `ckpt_step` step. We move model to CPU first then move back to CUDA device. Only # main process need to save checkpoint. if args.rank == HOST_RANK and step % args.ckpt_step == 0: lmp.util.model.save(ckpt=step, exp_name=args.exp_name, model=copy.deepcopy(model).to('cpu')) # Log performance for each `log_step` step. if step % args.log_step == 0: avg_loss = avg_loss / args.log_step # Log on CLI. tqdm_data_loader.set_description(f'epoch: {epoch}, loss: {avg_loss:.6f}') # Log on tensorboard. Only main process need to log performance. if args.rank == HOST_RANK: writer.add_scalar(f'train-loss/{args.dset_name}/{args.ver}', avg_loss, step) writer.add_scalar('lr', schdl.get_last_lr()[0], step) # Refresh log performance. pre_avg_loss = avg_loss avg_loss = 0.0 # Save last checkpoint. Only main process need to save checkpoint. if args.rank == HOST_RANK: lmp.util.model.save(ckpt=step, exp_name=args.exp_name, model=copy.deepcopy(model).to('cpu')) # Close tensorboard logger. writer.close() # Free memory. This is only need for unit test. del args del avg_loss del batch_cur_tkids del batch_next_tkids del batch_tkids del data_loader del device del dist_args_k del dist_sampler del ddp_model del dset del loss del model del optim del pre_avg_loss del schdl del step del store del tknzr del tqdm_data_loader del writer torch.cuda.empty_cache() gc.collect()