# Imports these trainer methods rely on; module paths are assumed to match
# the flair codebase, where this method lives on ModelTrainer.
import datetime
import logging
import statistics
from pathlib import Path
from typing import Union

import torch

import flair
from flair.datasets import DataLoader
from flair.optim import ExpAnnealLR
from flair.training_utils import init_output_file, log_line

log = logging.getLogger("flair")


# Variant 1: a single pass over one DataLoader, with optional exponential
# smoothing of the loss and early stopping on divergence.
def find_learning_rate(
    self,
    base_path: Union[Path, str],
    file_name: str = "learning_rate.tsv",
    start_learning_rate: float = 1e-7,
    end_learning_rate: float = 10,
    iterations: int = 100,
    mini_batch_size: int = 32,
    stop_early: bool = True,
    smoothing_factor: float = 0.98,
    **kwargs,
) -> Path:
    best_loss = None
    moving_avg_loss = 0

    # cast string to Path
    if type(base_path) is str:
        base_path = Path(base_path)
    learning_rate_tsv = init_output_file(base_path, file_name)

    with open(learning_rate_tsv, "a") as f:
        f.write("ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n")

    optimizer = self.optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)

    train_data = self.corpus.train
    batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)

    scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

    # remember model state and device so both can be restored afterwards
    model_state = self.model.state_dict()
    model_device = next(self.model.parameters()).device

    self.model.train()

    for itr, batch in enumerate(batch_loader):
        # forward pass
        loss = self.model.forward_loss(batch)

        # update optimizer and scheduler
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
        optimizer.step()
        scheduler.step(itr + 1)

        learning_rate = scheduler.get_lr()[0]

        loss_item = loss.item()
        if itr == 0:
            best_loss = loss_item
        else:
            if smoothing_factor > 0:
                # bias-corrected exponential moving average of the loss
                moving_avg_loss = smoothing_factor * moving_avg_loss + (1 - smoothing_factor) * loss_item
                loss_item = moving_avg_loss / (1 - smoothing_factor ** (itr + 1))
            if loss_item < best_loss:
                best_loss = loss_item

        # stop once the loss clearly diverges
        if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
            log_line(log)
            log.info("loss diverged - stopping early!")
            break

        if itr > iterations:
            break

        with open(str(learning_rate_tsv), "a") as f:
            f.write(f"{itr}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n")

    # restore the original weights and device
    self.model.load_state_dict(model_state)
    self.model.to(model_device)

    log_line(log)
    log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
    log_line(log)

    return Path(learning_rate_tsv)
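# All three variants drive the sweep with ExpAnnealLR, which exponentially
# anneals the learning rate from the optimizer's initial lr up to
# end_learning_rate over a fixed number of iterations. flair ships this
# scheduler in flair.optim; the minimal sketch below (class name is
# hypothetical) only illustrates the schedule, assuming the standard
# PyTorch _LRScheduler interface.
from torch.optim.lr_scheduler import _LRScheduler


class ExpAnnealLRSketch(_LRScheduler):
    """Exponentially anneals each param group's lr toward end_lr."""

    def __init__(self, optimizer, end_lr: float, iterations: int, last_epoch: int = -1):
        self.end_lr = end_lr
        self.iterations = iterations
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # fraction of the schedule completed so far
        iteration = self.last_epoch + 1
        pct = iteration / self.iterations
        # geometric interpolation: base_lr at pct=0, end_lr at pct=1
        return [base_lr * (self.end_lr / base_lr) ** pct for base_lr in self.base_lrs]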
# Variant 2: re-creates the DataLoader until `iterations` steps are reached,
# logs a timestamp per step, and restores the model to flair.device.
def find_learning_rate(
    self,
    base_path: Union[Path, str],
    file_name: str = "learning_rate.tsv",
    start_learning_rate: float = 1e-7,
    end_learning_rate: float = 10,
    iterations: int = 100,
    mini_batch_size: int = 32,
    stop_early: bool = True,
    smoothing_factor: float = 0.98,
    **kwargs,
) -> Path:
    best_loss = None
    moving_avg_loss = 0

    # cast string to Path
    if type(base_path) is str:
        base_path = Path(base_path)
    learning_rate_tsv = init_output_file(base_path, file_name)

    with open(learning_rate_tsv, "a") as f:
        f.write("ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n")

    optimizer = self.optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)

    train_data = self.corpus.train

    scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

    model_state = self.model.state_dict()
    self.model.train()

    step = 0
    while step < iterations:
        batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
        for batch in batch_loader:
            step += 1

            # forward pass
            loss = self.model.forward_loss(batch)

            # update optimizer and scheduler
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            optimizer.step()
            scheduler.step(step)

            learning_rate = scheduler.get_lr()[0]

            loss_item = loss.item()
            if step == 1:
                best_loss = loss_item
            else:
                if smoothing_factor > 0:
                    # bias-corrected exponential moving average of the loss
                    moving_avg_loss = smoothing_factor * moving_avg_loss + (1 - smoothing_factor) * loss_item
                    loss_item = moving_avg_loss / (1 - smoothing_factor ** (step + 1))
                if loss_item < best_loss:
                    best_loss = loss_item

            if step > iterations:
                break

            if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
                log_line(log)
                log.info("loss diverged - stopping early!")
                step = iterations
                break

            with open(str(learning_rate_tsv), "a") as f:
                f.write(f"{step}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n")

    self.model.load_state_dict(model_state)
    self.model.to(flair.device)

    log_line(log)
    log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
    log_line(log)

    return Path(learning_rate_tsv)
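# The smoothing in the variant above is a bias-corrected exponential moving
# average of the loss (the same correction Adam applies to its moment
# estimates): without the division, early values would be pulled toward the
# zero initialisation. A self-contained sketch of the standard formula, with
# made-up loss values:
def smoothed_losses(losses, smoothing_factor: float = 0.98):
    moving_avg = 0.0
    smoothed = []
    for step, loss in enumerate(losses, start=1):
        # exponential moving average ...
        moving_avg = smoothing_factor * moving_avg + (1 - smoothing_factor) * loss
        # ... divided by (1 - s**step) to undo the bias toward zero
        smoothed.append(moving_avg / (1 - smoothing_factor ** step))
    return smoothed


print(smoothed_losses([2.0, 1.8, 1.9, 1.7]))  # ≈ [2.0, 1.899, 1.899, 1.849]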
# Variant 3: takes the optimizer class as a parameter, averages the loss over
# all iterations so far, and additionally logs how much that average dropped
# over the last ten steps.
def find_learning_rate(
    self,
    base_path: Union[Path, str],
    optimizer,
    mini_batch_size: int = 32,
    start_learning_rate: float = 1e-7,
    end_learning_rate: float = 10,
    iterations: int = 1000,
    stop_early: bool = True,
    file_name: str = "learning_rate.tsv",
    **kwargs,
) -> Path:
    best_loss = None

    # cast string to Path
    if type(base_path) is str:
        base_path = Path(base_path)
    base_path.mkdir(exist_ok=True, parents=True)
    learning_rate_tsv = init_output_file(base_path, file_name)

    with open(learning_rate_tsv, "a") as f:
        f.write("ITERATION\tLEARNING_RATE\tTRAIN_LOSS\tAVG_LOSS\tDROP\n")

    optimizer = optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)

    train_data = self.corpus.train

    scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

    model_state = self.model.state_dict()
    self.model.train()

    step = 0
    loss_list = []
    average_loss_list = []
    while step < iterations:
        batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
        for batch in batch_loader:
            step += 1

            # forward pass
            loss = self.model.forward_loss(batch)
            if isinstance(loss, tuple):
                loss = loss[0]

            # update optimizer and scheduler
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            optimizer.step()
            scheduler.step()

            learning_rate = scheduler.get_lr()[0]

            # append current loss to list of losses for all iterations
            loss_list.append(loss.item())

            # average loss over all iterations so far
            moving_avg_loss = statistics.mean(loss_list)
            average_loss_list.append(moving_avg_loss)

            # how much the averaged loss dropped over the last ten steps
            if len(average_loss_list) > 10:
                drop = average_loss_list[-10] - moving_avg_loss
            else:
                drop = 0.0

            if not best_loss or moving_avg_loss < best_loss:
                best_loss = moving_avg_loss

            if step > iterations:
                break

            if stop_early and (moving_avg_loss > 4 * best_loss or torch.isnan(loss)):
                log_line(log)
                log.info("loss diverged - stopping early!")
                step = iterations
                break

            with open(str(learning_rate_tsv), "a") as f:
                f.write(f"{step}\t{learning_rate}\t{loss.item()}\t{moving_avg_loss}\t{drop}\n")

    self.model.load_state_dict(model_state)
    self.model.to(flair.device)

    log_line(log)
    log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
    log_line(log)

    return Path(learning_rate_tsv)
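# A hypothetical usage sketch for the final variant: `tagger` and `corpus`
# stand in for any flair model/corpus pair, and Plotter.plot_learning_rate
# is assumed from flair.visual.training_curves as in the flair tutorials.
from torch.optim import SGD

from flair.trainers import ModelTrainer
from flair.visual.training_curves import Plotter

trainer = ModelTrainer(tagger, corpus)  # model and corpus defined elsewhere

# sweep learning rates from 1e-7 toward 10 and write the curve to a TSV
learning_rate_tsv = trainer.find_learning_rate("resources/lr_finder", optimizer=SGD)

# plot loss against learning rate; a common heuristic is to pick a rate about
# an order of magnitude below the point where the loss starts to climb
Plotter().plot_learning_rate(learning_rate_tsv)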