Example #1
File: trainer.py Project: azawalich/flair
 def find_learning_rate(self,
                        base_path: Union[Path, str],
                        file_name: str = 'learning_rate.tsv',
                        start_learning_rate: float = 1e-07,
                        end_learning_rate: float = 10,
                        iterations: int = 100,
                        mini_batch_size: int = 32,
                        stop_early: bool = True,
                        smoothing_factor: float = 0.98,
                        **kwargs) -> Path:
     best_loss = None
     moving_avg_loss = 0
     if type(base_path) is str:
         base_path = Path(base_path)
     learning_rate_tsv = init_output_file(base_path, file_name)
     with open(learning_rate_tsv, 'a') as f:
         f.write('ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n')
     optimizer = self.optimizer(self.model.parameters(),
                                lr=start_learning_rate,
                                **kwargs)
     train_data = self.corpus.train
     batch_loader = DataLoader(train_data,
                               batch_size=mini_batch_size,
                               shuffle=True)
     scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)
     model_state = self.model.state_dict()
     model_device = next(self.model.parameters()).device
     self.model.train()
     for itr, batch in enumerate(batch_loader):
         loss = self.model.forward_loss(batch)
         optimizer.zero_grad()
         loss.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
         optimizer.step()
         scheduler.step()  # advance one iteration; step(1) would pin the schedule
         learning_rate = scheduler.get_lr()[0]
         loss_item = loss.item()
         if itr == 0:
             best_loss = loss_item
         else:
             if smoothing_factor > 0:
                 # exponentially smoothed loss with bias correction
                 moving_avg_loss = (smoothing_factor * moving_avg_loss +
                                    (1 - smoothing_factor) * loss_item)
                 loss_item = moving_avg_loss / (1 - smoothing_factor**(itr + 1))
             if loss_item < best_loss:
                 best_loss = loss_item  # keep the float, not the loss tensor
         if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
             log_line(log)
             log.info('loss diverged - stopping early!')
             break
         if itr > iterations:
             break
         with open(str(learning_rate_tsv), 'a') as f:
             f.write(f'{itr}\t{datetime.datetime.now():%H:%M:%S}\t'
                     f'{learning_rate}\t{loss_item}\n')
     self.model.load_state_dict(model_state)
     self.model.to(model_device)
     log_line(log)
     log.info(f'learning rate finder finished - plot {learning_rate_tsv}')
     log_line(log)
     return Path(learning_rate_tsv)
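
For context, here is a minimal usage sketch of the method above, assuming the flair 0.4-era ModelTrainer API that trainer.py belongs to; the corpus, model, and output path are illustrative assumptions, not taken from the source:

from pathlib import Path

from flair.datasets import UD_ENGLISH
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.visual.training_curves import Plotter

# any Corpus/model pair works; this small POS-tagging setup is illustrative
corpus = UD_ENGLISH()
tagger = SequenceTagger(hidden_size=128,
                        embeddings=WordEmbeddings('glove'),
                        tag_dictionary=corpus.make_tag_dictionary(tag_type='upos'),
                        tag_type='upos')
trainer = ModelTrainer(tagger, corpus)

# sweep learning rates from 1e-7 to 10 over 100 iterations; results go to a TSV
lr_tsv = trainer.find_learning_rate(Path('resources/lr_finder'), iterations=100)

# flair's Plotter renders the TSV next to it as learning_rate.png
Plotter().plot_learning_rate(lr_tsv)

A common way to read the resulting curve is to pick a learning rate roughly an order of magnitude below the point where the loss stops falling.
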
Example #2
    def find_learning_rate(
        self,
        base_path: Union[Path, str],
        file_name: str = "learning_rate.tsv",
        start_learning_rate: float = 1e-7,
        end_learning_rate: float = 10,
        iterations: int = 100,
        mini_batch_size: int = 32,
        stop_early: bool = True,
        smoothing_factor: float = 0.98,
        **kwargs,
    ) -> Path:
        best_loss = None
        moving_avg_loss = 0

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)
        learning_rate_tsv = init_output_file(base_path, file_name)

        with open(learning_rate_tsv, "a") as f:
            f.write("ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n")

        optimizer = self.optimizer(self.model.parameters(),
                                   lr=start_learning_rate,
                                   **kwargs)

        train_data = self.corpus.train

        scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

        model_state = self.model.state_dict()
        self.model.train()

        step = 0
        while step < iterations:
            batch_loader = DataLoader(train_data,
                                      batch_size=mini_batch_size,
                                      shuffle=True)
            for batch in batch_loader:
                step += 1

                # forward pass
                loss = self.model.forward_loss(batch)

                # update optimizer and scheduler
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                optimizer.step()
                scheduler.step(step)

                learning_rate = scheduler.get_lr()[0]

                loss_item = loss.item()
                if step == 1:
                    best_loss = loss_item
                else:
                    if smoothing_factor > 0:
                        moving_avg_loss = (smoothing_factor * moving_avg_loss +
                                           (1 - smoothing_factor) * loss_item)
                        loss_item = moving_avg_loss / (1 - smoothing_factor**
                                                       (step + 1))
                    if loss_item < best_loss:
                        best_loss = loss_item  # keep the float, not the loss tensor

                if step > iterations:
                    break

                if stop_early and (loss_item > 4 * best_loss
                                   or torch.isnan(loss)):
                    log_line(log)
                    log.info("loss diverged - stopping early!")
                    step = iterations
                    break

                with open(str(learning_rate_tsv), "a") as f:
                    f.write(
                        f"{step}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n"
                    )

            # restore the model's initial weights after each pass over the data
            self.model.load_state_dict(model_state)
            self.model.to(flair.device)

        log_line(log)
        log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
        log_line(log)

        return Path(learning_rate_tsv)
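
Examples #1 and #2 smooth the raw batch loss with an exponential moving average and then divide by 1 - smoothing_factor**(step + 1). That divisor is a zero-initialization bias correction, the same trick the Adam optimizer applies to its moment estimates: the running average starts at 0, so early values are biased low until the correction rescales them. A standalone sketch of the arithmetic (the helper name is hypothetical):

def smoothed_losses(losses, smoothing_factor=0.98):
    # debiased exponential moving average, as in examples #1 and #2
    moving_avg = 0.0  # zero init biases early averages toward 0 ...
    for t, loss in enumerate(losses, start=1):
        moving_avg = (smoothing_factor * moving_avg +
                      (1 - smoothing_factor) * loss)
        # ... so divide by 1 - s**t to undo the bias
        yield moving_avg / (1 - smoothing_factor ** t)

# the first smoothed value equals the first raw loss (2.0) exactly
print(list(smoothed_losses([2.0, 1.9, 1.8, 1.7])))
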
Example #3
    def find_learning_rate(
            self,
            base_path: Union[Path, str],
            optimizer,
            mini_batch_size: int = 32,
            start_learning_rate: float = 1e-7,
            end_learning_rate: float = 10,
            iterations: int = 1000,
            stop_early: bool = True,
            file_name: str = "learning_rate.tsv",
            **kwargs,
    ) -> Path:
        best_loss = None

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)
        base_path.mkdir(exist_ok=True, parents=True)
        learning_rate_tsv = init_output_file(base_path, file_name)

        with open(learning_rate_tsv, "a") as f:
            f.write("ITERATION\tTIMESTAMP\tLEARNING_RATE\tTRAIN_LOSS\n")

        optimizer = optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)

        train_data = self.corpus.train

        scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

        model_state = self.model.state_dict()
        self.model.train()

        import statistics  # hoisted from inside the batch loop

        step = 0

        loss_list = []
        average_loss_list = []

        while step < iterations:

            batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)

            for batch in batch_loader:
                step += 1

                # forward pass
                loss = self.model.forward_loss(batch)
                if isinstance(loss, tuple):  # forward_loss may return (loss, count)
                    loss = loss[0]

                # update optimizer and scheduler
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                optimizer.step()
                scheduler.step()

                learning_rate = scheduler.get_lr()[0]

                # append current loss to list of losses for all iterations
                loss_list.append(loss.item())

                # running mean of all losses seen so far
                moving_avg_loss = statistics.mean(loss_list)
                average_loss_list.append(moving_avg_loss)

                if len(average_loss_list) > 10:
                    drop = average_loss_list[-10] - moving_avg_loss
                else:
                    drop = 0.

                if not best_loss or moving_avg_loss < best_loss:
                    best_loss = moving_avg_loss

                if step > iterations:
                    break

                if stop_early and (moving_avg_loss > 4 * best_loss or torch.isnan(loss)):
                    log_line(log)
                    log.info("loss diverged - stopping early!")
                    step = iterations
                    break

                with open(str(learning_rate_tsv), "a") as f:
                    f.write(f"{step}\t{learning_rate}\t{loss.item()}\t{moving_avg_loss}\t{drop}\n")

            self.model.load_state_dict(model_state)
            self.model.to(flair.device)

        log_line(log)
        log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
        log_line(log)

        return Path(learning_rate_tsv)
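
All three examples depend on ExpAnnealLR, which flair defines elsewhere (in flair.optim). Below is a minimal sketch consistent with how the scheduler is used above, assuming standard torch scheduler semantics; it is a reconstruction, not flair's exact code. The learning rate is interpolated geometrically from the optimizer's base LR to end_lr across the requested number of iterations:

import torch
from torch.optim.lr_scheduler import _LRScheduler

class ExpAnnealLR(_LRScheduler):
    # anneal the LR exponentially from base_lr to end_lr over `iterations`
    def __init__(self, optimizer, end_lr, iterations, last_epoch=-1):
        self.end_lr = end_lr
        self.iterations = iterations
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        # fraction of the schedule completed, in [0, 1]
        progress = (self.last_epoch + 1) / self.iterations
        # geometric interpolation: lr = base_lr * (end_lr / base_lr) ** progress
        return [base_lr * (self.end_lr / base_lr) ** progress
                for base_lr in self.base_lrs]

# usage mirroring the examples: anneal from 1e-7 to 10 over 100 steps
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-7)
scheduler = ExpAnnealLR(optimizer, end_lr=10, iterations=100)
for _ in range(100):
    optimizer.step()
    scheduler.step()
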