def run_valid_epoch(self, data_iter: BatchIterable) -> float:
    state = TrainerState(self.model, -1)
    with tqdm(data_iter, total=data_iter.num_batches, unit='batch',
              dynamic_ncols=True) as data_bar:
        for i, batch in enumerate(data_bar):
            batch = batch.to(device)
            # Step: clear gradients
            self.model.zero_grad()
            # Step: run the forward pass
            outp_log_probs = self.model(batch)
            loss = self.loss_func(outp_log_probs, batch, train_mode=False)
            bar_msg, _ = state.step(batch.y_toks, loss)
            data_bar.set_postfix_str(bar_msg, refresh=False)
            del batch
    return state.running_loss()
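# Illustrative sketch only (not the actual TrainerState implementation): the validation loop
# above relies on TrainerState accumulating a token-weighted loss via step() and exposing it
# through running_loss(). A minimal stand-in with that bookkeeping might look like this:
class _RunningLossSketch:
    """Hypothetical stand-in showing the bookkeeping run_valid_epoch depends on."""

    def __init__(self):
        self.total_loss = 0.0   # sum of (per-token loss * token count)
        self.total_toks = 0     # number of target tokens seen so far

    def step(self, num_toks: int, loss: float):
        self.total_loss += loss * num_toks
        self.total_toks += num_toks
        bar_msg = f'loss={self.running_loss():.4f}'
        return bar_msg, False   # second value mirrors the is_check_pt flag

    def running_loss(self) -> float:
        return self.total_loss / max(1, self.total_toks)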
def train(self, steps: int, check_point: int, batch_size: int, fine_tune=False,
          check_pt_callback: Optional[Callable] = None, **args):
    log.info(f'Going to train for {steps} steps; batch_size={batch_size}; '
             f'check point size:{check_point}; fine tune={fine_tune}')
    keep_models = args.get('keep_models', 4)  # keep last _ models and delete the old
    sort_by = args.get('sort_by', 'random')
    if steps <= self.start_step:
        raise Exception(f'The model was already trained to {self.start_step} steps. '
                        f'Please increase the steps or clear the existing models')
    train_data = self.exp.get_train_data(batch_size=batch_size, steps=steps - self.start_step,
                                         sort_by=sort_by, shuffle=True, batch_first=True,
                                         fine_tune=fine_tune)
    val_data = self.exp.get_val_data(batch_size, shuffle=False, batch_first=True,
                                     sort_desc=True)

    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)
    unsaved_state = False
    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
              dynamic_ncols=True) as data_bar:
        for batch in data_bar:
            batch = batch.to(device)
            # Step: clear gradients
            self.model.zero_grad()
            # Step: run the forward pass
            outp_log_probs = self.model(batch)
            loss = self.loss_func(outp_log_probs, batch, True)
            unsaved_state = True
            self.tbd.add_scalars('training', {'step_loss': loss,
                                              'learn_rate': self.opt.curr_lr},
                                 self.opt.curr_step)

            bar_msg, is_check_pt = train_state.step(batch.y_toks, loss)
            bar_msg += f', LR={self.opt.curr_lr:g}'
            data_bar.set_postfix_str(bar_msg, refresh=False)
            del batch  # TODO: force free memory

            if is_check_pt:
                train_loss = train_state.reset()
                train_state.train_mode(False)
                val_loss = self.run_valid_epoch(val_data)
                self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model,
                                      step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
                unsaved_state = False

    if unsaved_state:  # end of training
        train_loss = train_state.reset()
        train_state.train_mode(False)
        val_loss = self.run_valid_epoch(val_data)
        self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
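# Example usage (hypothetical wiring; the concrete trainer/experiment constructors may differ
# from what is sketched here). The callback receives the model, the current optimizer step,
# and the checkpoint's train loss:
def _demo_checkpoint_callback(model=None, step: int = 0, train_loss: float = 0.0):
    log.info(f'checkpoint at step {step}: train_loss={train_loss:.4f}')

# trainer = SteppedTrainer(exp)   # assumed constructor, for illustration only
# trainer.train(steps=100_000, check_point=1_000, batch_size=4096,
#               fine_tune=False, check_pt_callback=_demo_checkpoint_callback,
#               keep_models=4, sort_by='random')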
def train(self, steps: int, check_point: int, batch_size: int,
          check_pt_callback: Optional[Callable] = None, fine_tune=False, dec_bos_cut=False,
          keep_models=10, sort_by='eq_len_rand_batch', log_interval: int = 10, **args):
    """
    :param steps: how many optimizer steps to train (also, how many batches)
    :param check_point: after how many steps to make a checkpoint
    :param batch_size: maximum number of target tokens in a batch ( = max_len * num_sentences)
    :param check_pt_callback: function to call back after each checkpoint
    :param fine_tune: use the fine-tune corpus instead of the training corpus
    :param dec_bos_cut: copy the first time step of the input as the decoder's BOS
    :param keep_models: how many checkpoints to keep
    :param args: any extra args
    :return:
    """
    log_resources = args.pop('log_resources', False)
    assert log_interval > 0
    if args:  # no extra args; let the user know if an extra arg is passed
        raise Exception(f'Found extra args: {args}')
    log.info(f'Going to train for {steps} steps (from {self.start_step} steps);'
             f' batch_size={batch_size} toks; sort_by={sort_by};'
             f' check point size:{check_point}; fine_tune={fine_tune};'
             f' dec_bos_cut={dec_bos_cut}')

    if self.n_gpus > 1:
        batch_size *= self.n_gpus
        log.info(f"# GPUs = {self.n_gpus}, batch_size is set to {batch_size}")

    if steps <= self.start_step:
        raise Exception(f'The model was already trained to {self.start_step} steps. '
                        f'Please increase the steps or clear the existing models')
    train_data = self.exp.get_train_data(batch_size=batch_size, steps=steps - self.start_step,
                                         sort_by=sort_by, batch_first=True, fine_tune=fine_tune)
    val_data = self.exp.get_val_data(batch_size, shuffle=False, batch_first=True,
                                     sort_desc=False)

    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)
    unsaved_state = False
    cuda_available = torch.cuda.is_available()
    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
              dynamic_ncols=True) as data_bar:
        for batch in data_bar:
            self.model.zero_grad()
            batch = batch.to(device)
            num_toks = batch.y_toks
            x_seqs = batch.x_seqs
            if dec_bos_cut:
                bos_step = x_seqs[:, :1]
                x_seqs = x_seqs[:, 1:]
            else:
                bos_step = torch.full((len(batch), 1), fill_value=Batch.bos_val,
                                      dtype=torch.long, device=device)
            x_mask = (x_seqs != batch.pad_value).unsqueeze(1)
            y_seqs_with_bos = torch.cat([bos_step, batch.y_seqs], dim=1)
            y_mask = Batch.make_target_mask(y_seqs_with_bos)
            out = self.model(x_seqs, y_seqs_with_bos, x_mask, y_mask)  # [Batch x Time x D]
            # skip the last time step (the one with EOS as input)
            out = out[:, :-1, :]
            # assumption: y_seqs has EOS, and not BOS
            loss = self.loss_func(out, batch.y_seqs, num_toks, True)
            unsaved_state = True
            if self.opt.curr_step % log_interval == 0:
                self.tbd.add_scalars('training', {'step_loss': loss,
                                                  'learn_rate': self.opt.curr_lr},
                                     self.opt.curr_step)
                if log_resources and cuda_available:
                    self._log_resources(batch)

            progress_msg, is_check_pt = train_state.step(num_toks, loss)
            progress_msg += f', LR={self.opt.curr_lr:g}'
            data_bar.set_postfix_str(progress_msg, refresh=False)
            del batch

            if is_check_pt:
                train_loss = train_state.reset()
                train_state.train_mode(False)
                val_loss = self.run_valid_epoch(val_data, dec_bos_cut=dec_bos_cut)
                self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model,
                                      step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
                unsaved_state = False
                gc.collect()

    # End of training
    if unsaved_state:
        train_loss = train_state.reset()
        train_state.train_mode(False)
        val_loss = self.run_valid_epoch(val_data, dec_bos_cut=dec_bos_cut)
        self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
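# Sketch of the decoder-input preparation used above, with made-up toy ids. The bos_id/pad_id
# values and the exact behavior of Batch.make_target_mask are assumptions; the causal mask
# here is built explicitly with a lower-triangular matrix for illustration:
import torch

def _demo_decoder_inputs():
    pad_id, bos_id, eos_id = 0, 1, 2
    y_seqs = torch.tensor([[5, 6, 7, eos_id],
                           [8, 9, eos_id, pad_id]])           # targets: EOS, no BOS
    bos_step = torch.full((y_seqs.size(0), 1), fill_value=bos_id, dtype=torch.long)
    y_with_bos = torch.cat([bos_step, y_seqs], dim=1)          # [Batch x (Time+1)]
    # autoregressive mask: position t may attend to positions <= t, and never to padding
    pad_mask = (y_with_bos != pad_id).unsqueeze(1)             # [Batch x 1 x Time+1]
    causal = torch.ones(y_with_bos.size(1), y_with_bos.size(1)).tril().bool()
    y_mask = pad_mask & causal                                 # [Batch x Time+1 x Time+1]
    # the model output at the last position (EOS as input) is dropped, so the remaining
    # Time positions line up one-to-one with y_seqs as prediction targets
    return y_with_bos, y_mask

_demo_decoder_inputs()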
def train(self, steps: int, check_point: int, batch_size: int,
          check_pt_callback: Optional[Callable] = None, keep_models=4, **args):
    log.info(f'Going to train for {steps} steps; batch_size={batch_size}; '
             f'check point size:{check_point}')
    rem_steps = steps - self.start_step
    if rem_steps <= 0:
        raise Exception(f'The model was already trained to {self.start_step} steps. '
                        f'Please increase the steps or clear the existing models')
    side = 'tgt'  # TODO: this should be inferrable or configurable instead of hardcoded
    train_data = self.exp.get_mono_data('train', side, batch_size=batch_size,
                                        batch_first=True, sort_dec=False,
                                        num_batches=rem_steps, shuffle=True)
    val_data = self.exp.get_mono_data('valid', side, batch_size=batch_size,
                                      batch_first=True, sort_dec=False)

    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)
    unsaved_state = False
    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
              dynamic_ncols=True) as data_bar:
        for batch in data_bar:
            self.model.zero_grad()
            assert batch.eos_x  # must have EOS
            batch = batch.to(device)
            num_toks = batch.x_toks
            seqs = batch.x_seqs
            bos_step = torch.full((len(batch), 1), fill_value=Batch.bos_val,
                                  dtype=torch.long, device=device)
            seqs_with_bos = torch.cat([bos_step, batch.x_seqs], dim=1)
            seq_mask = Batch.make_target_mask(seqs_with_bos)
            out = self.model(seqs_with_bos, seq_mask, gen_probs=False)  # [Batch x Time x D]
            # skip the last time step (the one with EOS as input)
            out = out[:, :-1, :]
            # assumption: seqs have EOS, and not BOS
            loss = self.loss_func(out, seqs, num_toks, True)
            unsaved_state = True
            self.tbd.add_scalars('training', {'step_loss': loss,
                                              'learn_rate': self.opt.curr_lr},
                                 self.opt.curr_step)

            progress_msg, is_check_pt = train_state.step(num_toks, loss)
            progress_msg += f', LR={self.opt.curr_lr:g}'
            data_bar.set_postfix_str(progress_msg, refresh=False)
            del batch

            if is_check_pt:
                train_loss = train_state.reset()
                train_state.train_mode(False)
                val_loss = self.run_valid_epoch(val_data)
                self.make_check_point(train_loss, val_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model,
                                      step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
                unsaved_state = False

    if unsaved_state:  # end of training
        train_loss = train_state.reset()
        train_state.train_mode(False)
        val_loss = self.run_valid_epoch(val_data)
        self.make_check_point(train_loss, val_loss, keep_models=keep_models)
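# Sketch of the language-model input/target alignment used above, with toy ids (the BOS/EOS
# values are assumptions). The model reads [BOS, w1, ..., wn, EOS]; after dropping the output
# of the final position, each remaining position t is trained to predict seqs[t]:
import torch

def _demo_lm_alignment():
    bos_id, eos_id = 1, 2
    seqs = torch.tensor([[5, 6, 7, eos_id]])               # target side: ends with EOS, no BOS
    bos_step = torch.full((seqs.size(0), 1), bos_id, dtype=torch.long)
    seqs_with_bos = torch.cat([bos_step, seqs], dim=1)     # model input: [BOS, w1, ..., EOS]
    # dropping the last step (whose input is EOS) leaves exactly len(seqs) positions,
    # each of which predicts the token at the same index of seqs
    inputs_used = seqs_with_bos[:, :-1]
    assert inputs_used.shape == seqs.shape
    return inputs_used, seqs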
def train(self, steps: int, check_point: int, batch_size: int,
          check_pt_callback: Optional[Callable] = None, **args):
    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)
    if self.start_step >= steps:
        log.warning(f"Already trained to {self.start_step}. Considering it as done.")
        return
    rem_steps = steps - self.start_step
    side = 'tgt'  # TODO: this should be inferrable or configurable instead of hardcoded
    train_data = self.exp.get_mono_data('train', side, batch_size=batch_size,
                                        batch_first=True, sort_dec=True,
                                        num_batches=rem_steps, shuffle=True)
    val_data = self.exp.get_mono_data('valid', side, batch_size=batch_size,
                                      batch_first=True, sort_dec=True)

    keep_models = 8
    unsaved_state = False
    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
              dynamic_ncols=True) as data_bar:
        for batch in data_bar:
            batch = batch.to(device)
            outp_log_probs = self.model.batch_forward(batch)
            loss = self.simple_loss_func(outp_log_probs, seq_lens=batch.x_len,
                                         tot_toks=batch.x_toks,
                                         max_seq_len=batch.max_x_len, train_mode=True)
            unsaved_state = True
            bar_msg, is_check_pt = train_state.step(batch.x_toks, loss)
            data_bar.set_postfix_str(bar_msg, refresh=True)
            del batch  # TODO: force free memory

            if is_check_pt:
                train_loss = train_state.reset()
                train_state.train_mode(False)
                val_loss = self.run_valid_epoch(val_data)
                self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model,
                                      step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
                unsaved_state = False

    log.info("End of training session")
    if unsaved_state:  # end of training
        train_loss = train_state.reset()
        train_state.train_mode(False)
        val_loss = self.run_valid_epoch(val_data)
        self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
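# A minimal sketch of a length-masked negative log-likelihood over padded batches, similar in
# spirit to what simple_loss_func above needs (function name, arguments, and the mean-per-token
# reduction are assumptions for illustration, not this trainer's exact loss):
import torch

def _demo_length_masked_nll(log_probs, targets, seq_lens):
    """
    log_probs: [Batch x Time x Vocab] log-probabilities
    targets:   [Batch x Time] token ids (padded)
    seq_lens:  [Batch] true sequence lengths
    Returns the average NLL per non-pad token.
    """
    batch, time, _ = log_probs.shape
    tok_nll = -log_probs.gather(dim=2, index=targets.unsqueeze(2)).squeeze(2)        # [B x T]
    mask = torch.arange(time, device=targets.device).unsqueeze(0) < seq_lens.unsqueeze(1)
    return (tok_nll * mask).sum() / mask.sum()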
def train(self, steps: int, check_point: int, batch_size: int,
          check_pt_callback: Optional[Callable] = None, fine_tune=False, dec_bos_cut=False,
          keep_models=10, sort_by='eq_len_rand_batch', log_interval: int = 10,
          keep_in_mem=False, early_stop=None, **args):
    """
    :param steps: how many optimizer steps to train (also, how many batches)
    :param check_point: after how many steps to make a checkpoint
    :param batch_size: maximum number of target tokens in a batch ( = max_len * num_sentences)
    :param check_pt_callback: function to call back after each checkpoint
    :param fine_tune: use the fine-tune corpus instead of the training corpus
    :param dec_bos_cut: copy the first time step of the input as the decoder's BOS
    :param keep_models: how many checkpoints to keep
    :param keep_in_mem: keep training data in memory
    :param early_stop: {patience: N validations, by: loss, enabled: True}
    :param args: any extra args
    :return:
    """
    log_resources = args.pop('log_resources', False)
    log_embedding = args.pop('log_embedding', False)
    split_ratio = args.pop('split_ratio', 0.)
    dynamic_epoch = args.pop('dynamic_epoch', False)
    assert log_interval > 0

    # Gradient accumulation: one optimizer step per grad_accum_interval batches
    opt_steps = steps
    batches = steps * self.grad_accum_interval
    start_batch = self.start_step * self.grad_accum_interval
    check_point = check_point * self.grad_accum_interval

    if isinstance(batch_size, int):
        max_toks, max_sents = batch_size, float('inf')
    else:
        max_toks, max_sents = batch_size

    if args:  # no extra args; let the user know if an extra arg is passed
        raise Exception(f'Found extra args: {args}')

    log.info(f'Going to train for {opt_steps} optimizer steps over {batches} batches'
             f' (from {self.start_step} steps);'
             f' batch_size={batch_size} toks; sort_by={sort_by};'
             f' check point size:{check_point}; fine_tune={fine_tune};'
             f' dec_bos_cut={dec_bos_cut}')

    distr = DistribTorch.instance()
    if batches <= start_batch:
        raise Exception(f'The model was already trained to {self.start_step} steps. '
                        f'Please increase the steps or clear the existing models')

    train_data = self.exp.get_train_data(batch_size=batch_size, steps=batches - start_batch,
                                         sort_by=sort_by, batch_first=True, fine_tune=fine_tune,
                                         keep_in_mem=keep_in_mem, split_ratio=split_ratio,
                                         dynamic_epoch=dynamic_epoch)
    val_data = None
    if distr.is_global_main:
        val_data = self.exp.get_val_data(batch_size=max_toks, shuffle=False, batch_first=True,
                                         sort_desc=False)

    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)

    unsaved_state = False
    cuda_available = torch.cuda.is_available()
    batch_count = -1
    stopper = None
    early_stopped = False  # or converged
    if early_stop:
        stopper = EarlyStopper(cur_step=self.start_step, **early_stop)

    with tqdm(train_data, initial=start_batch, total=batches, unit='batch',
              dynamic_ncols=True, disable=not distr.is_global_main) as data_bar:
        for batch in data_bar:
            batch_count += 1
            take_step = (batch_count % self.grad_accum_interval) == 0
            # if update_interval == 0:
            #     self.model.zero_grad()

            # if not dataparallel, then move
            if self.n_gpus <= 1:
                batch = batch.to(device)
            num_toks = batch.y_toks
            x_seqs = batch.x_seqs
            if dec_bos_cut:
                bos_step = x_seqs[:, :1]
                x_seqs = x_seqs[:, 1:]
            else:
                bos_step = torch.full((len(batch), 1), fill_value=batch.bos_val,
                                      dtype=torch.long, device=batch.y_seqs.device)
            # Prep masks
            x_mask = (x_seqs != batch.pad_val).unsqueeze(1)
            y_seqs_with_bos = torch.cat([bos_step, batch.y_seqs], dim=1)
            y_mask = batch.make_autoreg_mask(y_seqs_with_bos)

            with autocast(enabled=dtorch.fp16):
                out = self.model(x_seqs, y_seqs_with_bos, x_mask, y_mask)  # [Batch x Time x D]
                # skip the last time step (the one with EOS as input)
                out = out[:, :-1, :]
                # assumption: y_seqs has EOS, and not BOS
                loss = self.loss_func(out, batch.y_seqs, num_toks, train_mode=True,
                                      take_step=take_step)
            if stopper and take_step:
                stopper.step()

            # Log
            unsaved_state = True
            if self.opt.curr_step % log_interval == 0:
                self.tbd.add_scalars('training', {'step_loss': loss,
                                                  'learn_rate': self.opt.curr_lr},
                                     self.opt.curr_step)
                if log_resources and cuda_available:
                    self._log_resources(batch)

            progress_msg, is_check_pt = train_state.step(num_toks, loss)
            progress_msg += f', LR={self.opt.curr_lr:0.8f}'
            data_bar.set_postfix_str(progress_msg, refresh=False)
            del batch

            # Save checkpoint
            if is_check_pt:
                train_loss = train_state.reset()
                log.info(f"Chkpt Train loss={train_loss}; Runs validation? {distr.is_global_main}")
                if distr.is_global_main:
                    train_state.train_mode(False)
                    with torch.no_grad():
                        val_loss = self.run_valid_epoch(val_data, dec_bos_cut=dec_bos_cut)
                        self.make_check_point(train_loss, val_loss=val_loss,
                                              keep_models=keep_models,
                                              log_embedding=log_embedding)
                        if check_pt_callback:
                            check_pt_callback(model=self.model,
                                              step=self.opt.curr_step,
                                              train_loss=train_loss)
                    train_state.train_mode(True)
                    if stopper:
                        stopper.validation(val_loss)
                        if stopper.is_stop():
                            log.info(f"Stopping at {stopper.cur_step} because {stopper.by}"
                                     f" didn't improve over {stopper.patience} checkpoints")
                            early_stopped = True
                            break
                unsaved_state = False
                gc.collect()
                distr.barrier()

    # End of training
    if unsaved_state and distr.is_global_main:
        train_loss = train_state.reset()
        train_state.train_mode(False)
        val_loss = self.run_valid_epoch(val_data, dec_bos_cut=dec_bos_cut)
        self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)

    distr.barrier()
    return early_stopped
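# A minimal sketch of the gradient-accumulation pattern that the loop above delegates to its
# loss function via take_step. The explicit optimizer handling here is a generic illustration,
# not this trainer's exact mechanism (which hides backward/step inside loss_func):
import torch

def _demo_grad_accum(model, batches, optimizer, grad_accum_interval=2):
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        loss = torch.nn.functional.cross_entropy(model(x), y)
        (loss / grad_accum_interval).backward()   # scale so accumulated grads average out
        if (i + 1) % grad_accum_interval == 0:    # take one optimizer step every N batches
            optimizer.step()
            optimizer.zero_grad()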