Esempio n. 1
0
class CNN_skipCo_trainer(object):
    """Trainer for the dilated image-translation model.

    Builds the ProcessData dataset, the networks and a Logger from the given
    hyper-parameters. It offers a training loop (`fit`), a learning-rate range
    test (`find_lr`) and per-epoch learning-rate schedules (`get_learning_rate`).
    """

    def __init__(self, image_type, batch_size, log_period, epochs, data_type, train_ratio,
                 process_raw_data, pro_and_augm_only_image_type, do_heavy_augment, do_augment,
                 add_augment, do_rchannels, do_flip, do_blur, do_deform, do_crop, do_speckle_noise,
                 trunc_points, get_scale_center, single_sample, do_scale_center,

                 height_channel_oa, use_regressed_oa, include_regression_error, add_f_test,
                 only_f_test_in_target, channel_slice_oa, process_all_raw_folders,
                 conv_channels, kernels, model_name, input_size, output_channels, drop_probs,
                 di_conv_channels, dilations, learning_rates, hetero_mask_to_mask, hyper_no, input_ds_mask,
                 input_ss_mask, ds_mask_channels, add_skip):
        # NOTE(review): input_ds_mask, input_ss_mask and ds_mask_channels are accepted
        # but never used in this class.

        self.image_type = image_type

        self.batch_size = batch_size
        self.log_period = log_period
        self.epochs = epochs

        self.dataset = ProcessData(data_type=data_type, train_ratio=train_ratio, process_raw_data=process_raw_data,
                                   pro_and_augm_only_image_type=pro_and_augm_only_image_type,
                                   do_heavy_augment=do_heavy_augment,
                                   do_augment=do_augment, add_augment=add_augment, do_rchannels=do_rchannels,
                                   do_flip=do_flip, do_blur=do_blur, do_deform=do_deform, do_crop=do_crop,
                                   do_speckle_noise=do_speckle_noise, trunc_points=trunc_points,
                                   image_type=image_type, get_scale_center=get_scale_center,
                                   single_sample=single_sample,
                                   do_scale_center=do_scale_center,
                                   height_channel_oa=height_channel_oa, use_regressed_oa=use_regressed_oa,
                                   include_regression_error=include_regression_error,
                                   add_f_test=add_f_test, only_f_test_in_target=only_f_test_in_target,
                                   channel_slice_oa=channel_slice_oa,
                                   process_all_raw_folders=process_all_raw_folders,
                                   hetero_mask_to_mask=hetero_mask_to_mask)

        # NOTE(review): the conv/deconv network is built but not handed to
        # ImageTranslator below; only the dilated network ends up in self.model.
        self.model_convdeconv = ConvDeconv(conv_channels=conv_channels,
                                           kernels=kernels,
                                           model_name=model_name, input_size=input_size,
                                           output_channels=output_channels, drop_probs=drop_probs,
                                           add_skip=add_skip)

        self.model_dilated = DilatedTranslator(conv_channels=di_conv_channels, dilations=dilations)

        self.model = ImageTranslator([self.model_dilated])

        if torch.cuda.is_available():
            torch.cuda.current_device()
            self.model.cuda()

        self.learning_rates = learning_rates

        self.logger = Logger(model=self.model, project_root_dir=self.dataset.project_root_dir,
                             image_type=self.image_type, dataset=self.dataset, batch_size=self.batch_size,
                             epochs=self.epochs, learning_rates=self.learning_rates, hyper_no=hyper_no)

    def fit(self, learning_rate, lr_method='standard'):
        """Train the model for ``self.epochs`` epochs.

        Each epoch sets the scheduled learning rate and reshuffles the training
        files into batches of ``self.batch_size``. It then trains on every batch
        and records the validation loss. A checkpoint is logged every
        ``self.log_period`` epochs and after the last one.

        :param learning_rate: base learning rate for the optimizer and the schedule
        :param lr_method: schedule passed to ``get_learning_rate``: 'standard',
            'one_cycle' or 'cosine_annealing' (None behaves like 'standard')
        """
        # get scale and center parameters
        scale_params_low, scale_params_high = self.dataset.load_params(param_type="scale_params")
        mean_image_low, mean_image_high = self.dataset.load_params(param_type="mean_images")

        # load validation set, normalize and parse into tensor
        input_tensor_val, target_tensor_val = self.dataset.scale_and_parse_to_tensor(
            batch_files=self.dataset.val_file_names,
            scale_params_low=scale_params_low,
            scale_params_high=scale_params_high,
            mean_image_low=mean_image_low,
            mean_image_high=mean_image_high)

        if torch.cuda.is_available():
            input_tensor_val = input_tensor_val.cuda()
            target_tensor_val = target_tensor_val.cuda()

        # activate optimizer with the base learning rate
        self.model.activate_optimizer(learning_rate)
        # one learning rate per epoch
        self.learning_rates = self.get_learning_rate(learning_rate, self.epochs, lr_method)

        for e in range(0, self.epochs):
            # setting the learning rate each epoch
            lr = self.learning_rates[e]
            self.model.set_learning_rate(lr)

            # separate names into random batches and shuffle every epoch
            self.dataset.batch_names(batch_size=self.batch_size)

            # self.dataset.batch_number is the number of batches in the training set
            for i in range(self.dataset.batch_number):
                input_tensor, target_tensor = self.dataset.scale_and_parse_to_tensor(
                    batch_files=self.dataset.train_batch_chunks[i],
                    scale_params_low=scale_params_low,
                    scale_params_high=scale_params_high,
                    mean_image_low=mean_image_low,
                    mean_image_high=mean_image_high)

                if torch.cuda.is_available():
                    input_tensor = input_tensor.cuda()
                    target_tensor = target_tensor.cuda()

                self.model.train_model(input_tensor, target_tensor, current_epoch=e)

            # calculate the validation loss and add to validation history
            self.logger.get_val_loss(val_in=input_tensor_val, val_target=target_tensor_val)

            # save model every log_period epochs and after the final epoch
            if e % self.log_period == 0 or e == self.epochs - 1:
                self.logger.log(save_appendix='_epoch_' + str(e),
                                current_epoch=e,
                                epochs=self.epochs,
                                mean_images=[mean_image_low, mean_image_high],
                                scale_params=[scale_params_low, scale_params_high],
                                learning_rates=self.learning_rates)

    def find_lr(self, init_value=1e-8, final_value=10., beta=0.98):
        """Learning-rate range test.

        Splits the training files into mini-batches of two. It takes one optimizer
        step per batch and raises the learning rate geometrically from
        ``init_value`` towards ``final_value`` across the batches. The
        exponentially smoothed loss is tracked at each step. Plot the losses over
        the learning rates and pick the highest rate at which the loss is still
        decreasing (not the minimum).

        :param init_value: initial learning rate, must be < final_value. Default: 1e-8
        :param final_value: learning rate reached at the last batch. Default: 10.
        :param beta: smoothing factor for the loss. Default: 0.98
        :return: (log10 learning rates, smoothed losses). The sweep returns early
            if the smoothed loss exceeds 4x the best loss seen.
        """
        import math

        # get scale and center parameters
        scale_params_low, scale_params_high = self.dataset.load_params(param_type="scale_params")
        mean_image_low, mean_image_high = self.dataset.load_params(param_type="mean_images")

        self.dataset.batch_names(batch_size=2)
        print('number of files: ', len(self.dataset.train_file_names))
        print('batch numbers: ', self.dataset.batch_number)

        # geometric step so that the last batch runs at final_value; a single
        # batch would otherwise divide by zero
        num = max(self.dataset.batch_number - 1, 1)
        mult = (final_value / init_value) ** (1 / num)
        lr = init_value
        self.model.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for i in range(self.dataset.batch_number):
            batch_num = i + 1
            sys.stdout.write('\r' + 'current iteration : ' + str(i) + ', lr: ' + str(lr))

            # use the i-th training mini-batch (not the validation set)
            input_tensor, target_tensor = self.dataset.scale_and_parse_to_tensor(
                batch_files=self.dataset.train_batch_chunks[i],
                scale_params_low=scale_params_low,
                scale_params_high=scale_params_high,
                mean_image_low=mean_image_low,
                mean_image_high=mean_image_high)

            if torch.cuda.is_available():
                input_tensor = input_tensor.cuda()
                target_tensor = target_tensor.cuda()

            self.model.optimizer.zero_grad()
            outputs = self.model.forward(input_tensor)
            loss = self.model.criterion(outputs, target_tensor)

            # bias-corrected exponential moving average of the loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            # stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                print('loss exploding')
                return log_lrs, losses
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            loss.backward()
            self.model.optimizer.step()

            # raise the lr for the next step
            lr *= mult
            self.model.optimizer.param_groups[0]['lr'] = lr

        return log_lrs, losses

    def get_learning_rate(self, learning_rate, epochs, method):
        """Create one learning rate per epoch for the given schedule.

        :param learning_rate: base learning rate; the maximum rate for
            'one_cycle' and 'cosine_annealing'
        :param epochs: number of epochs the model is trained
        :param method: 'standard' (or None) for a constant rate; 'one_cycle' for a
            linear ramp from lr/10 up to lr, back down to lr/10, then towards 0;
            'cosine_annealing' for lr * (1 + cos(pi * e / epochs)) / 2
        :return: sequence of ``epochs`` learning rates
        :raises ValueError: if ``method`` is not one of the schedules above
        """
        if method == 'standard' or method is None:
            return [learning_rate for _ in range(epochs)]

        if method == 'one_cycle':
            higher_rate = learning_rate
            lower_rate = 1 / 10 * higher_rate

            # final annealing phase: at most 50 epochs, a third of training for short runs
            ann_frac = min(50, int(epochs / 3))

            up_num = int((epochs - ann_frac) / 2)
            down_num = up_num
            ann_num = epochs - up_num - down_num

            lr_up = np.linspace(lower_rate, higher_rate, num=up_num)
            lr_down = np.linspace(higher_rate, lower_rate, num=down_num)
            lr_annihilating = np.linspace(lower_rate, 0, num=ann_num)

            return np.append(np.append(lr_up, lr_down), lr_annihilating)

        if method == 'cosine_annealing':
            import math
            # cosine decay from learning_rate (epoch 0) towards 0, without restarts
            return [0.5 * learning_rate * (1 + math.cos(math.pi * e / epochs))
                    for e in range(epochs)]

        raise ValueError('unknown learning rate method: ' + repr(method))
Esempio n. 2
0
class CNN_skipCo_trainer(object):
    """Trainer for 'OA' images using the cnn_toy_model skip-connection network.

    Wires a ProcessData dataset, the model (MSE loss, Adam, lr 0.001) and a
    Logger together. Trains for ``self.epochs`` epochs (2 by default).
    """

    def __init__(self):

        self.image_type = 'OA'

        self.dataset = ProcessData(train_ratio=0.9,
                                   process_raw_data=False,
                                   do_augment=False,
                                   add_augment=True,
                                   do_flip=True,
                                   do_blur=True,
                                   do_deform=True,
                                   do_crop=True,
                                   image_type=self.image_type,
                                   get_scale_center=False,
                                   single_sample=True)

        self.model = cnn_toy_model.cnn_skipC_model(criterion=nn.MSELoss(),
                                                   optimizer=torch.optim.Adam,
                                                   learning_rate=0.001,
                                                   weight_decay=0)

        if torch.cuda.is_available():
            torch.cuda.current_device()
            self.model.cuda()

        self.logger = Logger(model=self.model,
                             project_root_dir=self.dataset.project_root_dir,
                             image_type=self.image_type)
        self.epochs = 2

    def fit(self, batch_size=32, log_period=25):
        """Train the model for ``self.epochs`` epochs.

        Each epoch reshuffles the training files into batches, trains on every
        batch and records the validation loss. A checkpoint is logged every
        ``log_period`` epochs and after the final epoch. The saved scale
        parameters and mean images can be used with
        ``utils.scale_and_center_reverse`` to undo the normalization.

        :param batch_size: number of files per training batch. Default: 32
        :param log_period: log a checkpoint every this many epochs. Default: 25
        """
        # get scale and center parameters
        scale_params_low, scale_params_high = self.dataset.load_params(
            param_type="scale_params")
        mean_image_low, mean_image_high = self.dataset.load_params(
            param_type="mean_images")

        # load validation set, normalize and parse into tensor
        input_tensor_val, target_tensor_val = self.dataset.scale_and_parse_to_tensor(
            batch_files=self.dataset.val_file_names,
            scale_params_low=scale_params_low,
            scale_params_high=scale_params_high,
            mean_image_low=mean_image_low,
            mean_image_high=mean_image_high)

        if torch.cuda.is_available():
            input_tensor_val = input_tensor_val.cuda()
            target_tensor_val = target_tensor_val.cuda()

        for e in range(0, self.epochs):
            # separate names into random batches and shuffle every epoch
            self.dataset.batch_names(batch_size=batch_size)
            # self.dataset.batch_number is the number of batches in the training set
            for i in range(self.dataset.batch_number):
                input_tensor, target_tensor = self.dataset.scale_and_parse_to_tensor(
                    batch_files=self.dataset.train_batch_chunks[i],
                    scale_params_low=scale_params_low,
                    scale_params_high=scale_params_high,
                    mean_image_low=mean_image_low,
                    mean_image_high=mean_image_high)

                if torch.cuda.is_available():
                    input_tensor = input_tensor.cuda()
                    target_tensor = target_tensor.cuda()

                self.model.train_model(input_tensor,
                                       target_tensor,
                                       current_epoch=e)

            # calculate the validation loss and add to validation history
            self.logger.get_val_loss(val_in=input_tensor_val,
                                     val_target=target_tensor_val)
            # save model every log_period epochs and after the final epoch
            if e % log_period == 0 or e == self.epochs - 1:
                self.logger.log(
                    save_appendix='_epoch_' + str(e),
                    current_epoch=e,
                    epochs=self.epochs,
                    mean_images=[mean_image_low, mean_image_high],
                    scale_params=[scale_params_low, scale_params_high])

    def predict(self):
        """Placeholder: prediction is not implemented yet.

        The validation data would be self.dataset.X_val / self.dataset.Y_val.
        """
        pass
Esempio n. 3
0
class CNN_skipCo_trainer(object):
    """Trainer for 'US' images using cnn_skipC_model.

    Batches are normalized with utils.scale_and_center and converted to
    (N, 1, H, W) tensors. A checkpoint and the loss history are saved every
    5 epochs.
    """

    def __init__(self):

        self.dataset = ProcessData(train_ratio=0.3,
                                   process_raw_data=False,
                                   do_augment=False,
                                   image_type='US',
                                   get_scale_center=False,
                                   single_sample=True)

        self.model = cnn_skipC_model.cnn_skipC_model(
            criterion=nn.MSELoss(),
            optimizer=torch.optim.Adam,
            learning_rate=0.01,
            weight_decay=0)

        # NOTE(review): model_2 is constructed but never trained or logged here.
        self.model_2 = awesomeImageTranslator1000.AwesomeImageTranslator1000(
            criterion=nn.MSELoss(),
            optimizer=torch.optim.Adam,
            learning_rate=0.01,
            weight_decay=0)

        # fit() moves every batch to the GPU when one is available, so the
        # trained model has to live on the same device
        if torch.cuda.is_available():
            torch.cuda.current_device()
            self.model.cuda()

        self.logger = Logger()

    def _scale_to_tensor(self, images, scale_params, mean_image):
        """Scale and center one batch and return it as an (N, 1, H, W) tensor.

        Undo the scaling with utils.scale_and_center_reverse using the same
        scale_params / mean_image.
        """
        scaled = utils.scale_and_center(
            images,
            scale_params,
            mean_image,
            image_type=self.dataset.image_type)
        # wrapping in a list adds a leading axis of size 1: (1, N, H, W);
        # because that axis has size 1, this reshape to (N, 1, H, W) is a transpose
        scaled = np.array([scaled])
        scaled = scaled.reshape(scaled.shape[1], scaled.shape[0],
                                scaled.shape[2], scaled.shape[3])
        return torch.from_numpy(scaled)

    def fit(self, epochs=10):
        """Train the model for ``epochs`` epochs with batches of 5 files.

        Every 5th epoch (starting at 0) the model and loss history are saved.
        After epoch 0 the scale/center parameters and the model representation
        are saved once.

        :param epochs: number of training epochs. Default: 10
        """
        # get scale and center parameters
        scale_params_low, scale_params_high = utils.load_params(
            image_type=self.dataset.image_type, param_type="scale_params")
        mean_image_low, mean_image_high = utils.load_params(
            image_type=self.dataset.image_type, param_type="mean_images")

        for e in range(0, epochs):
            # separate names into random batches and shuffle every epoch
            self.dataset.batch_names(batch_size=5)
            # self.dataset.batch_number is the number of batches in the training set
            for i in range(self.dataset.batch_number):
                X, Y = self.dataset.create_train_batches(
                    self.dataset.train_batch_chunks[i])

                input_tensor = self._scale_to_tensor(X, scale_params_low, mean_image_low)
                target_tensor = self._scale_to_tensor(Y, scale_params_high, mean_image_high)

                if torch.cuda.is_available():
                    # Tensor.cuda() returns a copy, so the result must be kept
                    input_tensor = input_tensor.cuda()
                    target_tensor = target_tensor.cuda()

                self.model.train_model(input_tensor,
                                       target_tensor,
                                       current_epoch=e)

            # save model every 5 epochs (once per epoch, not once per batch)
            if e % 5 == 0:
                self.logger.save_model(self.model,
                                       model_name=self.model.model_name +
                                       '_' + str(datetime.datetime.now()) +
                                       '_epoch_' + str(e))
                self.logger.save_loss(self.model.model_name,
                                      self.model.train_loss,
                                      self.model.test_loss)

            if e == 0:
                self.logger.save_scale_center_params(
                    model_name=self.model.model_name,
                    mean_images=[mean_image_low, mean_image_high],
                    scale_params=[scale_params_low, scale_params_high])
                self.logger.save_representation_of_model(
                    self.model.model_name, str(self.model))

    def predict(self):
        """Placeholder: prediction is not implemented yet.

        The validation data would be self.dataset.X_val / self.dataset.Y_val.
        """
        pass

    def log_model(self, model_name=None):
        """Log the model, its training loss and its string representation."""
        self.logger.log(self.model,
                        model_name=model_name,
                        train_loss=self.model.train_loss,
                        model_structure=str(self.model))
Esempio n. 4
0
class CNN_skipCo_trainer(object):
    """Trainer for 'US' images using deep_model.

    Wires a ProcessData dataset, the model (MSE loss, Adam, lr 0.001) and a
    Logger together. Offers a training loop with an optional one-cycle
    schedule (`fit`) and a learning-rate range test (`find_lr`).
    """

    def __init__(self):

        self.image_type = 'US'

        self.dataset = ProcessData(train_ratio=0.9,
                                   process_raw_data=False,
                                   do_augment=False,
                                   add_augment=False,
                                   do_flip=True,
                                   do_blur=True,
                                   do_deform=True,
                                   do_crop=True,
                                   image_type=self.image_type,
                                   get_scale_center=False,
                                   single_sample=True)

        self.model = deep_model.deep_model(criterion=nn.MSELoss(),
                                           optimizer=torch.optim.Adam,
                                           learning_rate=0.001,
                                           weight_decay=0)

        if torch.cuda.is_available():
            torch.cuda.current_device()
            self.model.cuda()

        self.logger = Logger(model=self.model,
                             project_root_dir=self.dataset.project_root_dir,
                             image_type=self.image_type)
        self.epochs = 250

    @staticmethod
    def _one_cycle_learning_rates(learning_rate, epochs, ann_epochs=50):
        """Build a one-cycle schedule with one learning rate per epoch.

        Linear ramp from learning_rate/10 up to learning_rate, then back down to
        learning_rate/10, then a final phase of (about) ``ann_epochs`` epochs
        decaying linearly to 0. If there are fewer than ``ann_epochs`` epochs,
        the whole run is the final decay phase.
        """
        higher_rate = learning_rate
        lower_rate = 1 / 10 * higher_rate

        # clamp at 0: np.linspace rejects a negative number of samples
        up_num = max(int((epochs - ann_epochs) / 2), 0)
        down_num = up_num
        ann_num = epochs - up_num - down_num

        lr_up = np.linspace(lower_rate, higher_rate, num=up_num)
        lr_down = np.linspace(higher_rate, lower_rate, num=down_num)
        lr_annihilating = np.linspace(lower_rate, 0, num=ann_num)
        return np.append(np.append(lr_up, lr_down), lr_annihilating)

    def fit(self, learning_rate, use_one_cycle=False, batch_size=32, log_period=25):
        """Train the model for ``self.epochs`` epochs.

        The saved scale parameters and mean images can be used with
        ``utils.scale_and_center_reverse`` to undo the normalization.

        :param learning_rate: constant learning rate, or the peak rate when
            ``use_one_cycle`` is set
        :param use_one_cycle: use the schedule from ``_one_cycle_learning_rates``
        :param batch_size: number of files per training batch. Default: 32
        :param log_period: log a checkpoint every this many epochs (and after
            the final epoch). Default: 25
        """
        # get scale and center parameters
        scale_params_low, scale_params_high = self.dataset.load_params(
            param_type="scale_params")
        mean_image_low, mean_image_high = self.dataset.load_params(
            param_type="mean_images")

        # load validation set, normalize and parse into tensor
        input_tensor_val, target_tensor_val = self.dataset.scale_and_parse_to_tensor(
            batch_files=self.dataset.val_file_names,
            scale_params_low=scale_params_low,
            scale_params_high=scale_params_high,
            mean_image_low=mean_image_low,
            mean_image_high=mean_image_high)

        if torch.cuda.is_available():
            input_tensor_val = input_tensor_val.cuda()
            target_tensor_val = target_tensor_val.cuda()

        learning_rates = None
        if use_one_cycle:
            learning_rates = self._one_cycle_learning_rates(learning_rate, self.epochs)
        else:
            self.model.set_learning_rate(learning_rate)

        for e in range(0, self.epochs):
            if use_one_cycle:
                self.model.set_learning_rate(learning_rates[e])
            # separate names into random batches and shuffle every epoch
            self.dataset.batch_names(batch_size=batch_size)

            # self.dataset.batch_number is the number of batches in the training set
            for i in range(self.dataset.batch_number):
                input_tensor, target_tensor = self.dataset.scale_and_parse_to_tensor(
                    batch_files=self.dataset.train_batch_chunks[i],
                    scale_params_low=scale_params_low,
                    scale_params_high=scale_params_high,
                    mean_image_low=mean_image_low,
                    mean_image_high=mean_image_high)

                if torch.cuda.is_available():
                    input_tensor = input_tensor.cuda()
                    target_tensor = target_tensor.cuda()

                self.model.train_model(input_tensor,
                                       target_tensor,
                                       current_epoch=e)

            # calculate the validation loss and add to validation history
            self.logger.get_val_loss(val_in=input_tensor_val,
                                     val_target=target_tensor_val)
            # save model every log_period epochs and after the final epoch
            if e % log_period == 0 or e == self.epochs - 1:
                self.logger.log(
                    save_appendix='_epoch_' + str(e),
                    current_epoch=e,
                    epochs=self.epochs,
                    mean_images=[mean_image_low, mean_image_high],
                    scale_params=[scale_params_low, scale_params_high])

    def predict(self):
        """Placeholder: prediction is not implemented yet.

        The validation data would be self.dataset.X_val / self.dataset.Y_val.
        """
        pass

    def find_lr(self, init_value=1e-8, final_value=10., beta=0.98):
        """Learning-rate range test.

        Splits the training files into mini-batches of two. It takes one optimizer
        step per batch and raises the learning rate geometrically from
        ``init_value`` towards ``final_value`` across the batches. The
        exponentially smoothed loss is tracked at each step. Plot the losses over
        the learning rates and pick the highest rate at which the loss is still
        decreasing (not the minimum).

        :param init_value: initial learning rate, must be < final_value. Default: 1e-8
        :param final_value: learning rate reached at the last batch. Default: 10.
        :param beta: smoothing factor for the loss. Default: 0.98
        :return: (log10 learning rates, smoothed losses). The sweep returns early
            if the smoothed loss exceeds 4x the best loss seen.
        """
        import math

        # get scale and center parameters
        scale_params_low, scale_params_high = self.dataset.load_params(
            param_type="scale_params")
        mean_image_low, mean_image_high = self.dataset.load_params(
            param_type="mean_images")

        print(len(self.dataset.train_file_names))
        self.dataset.batch_names(batch_size=2)
        print(self.dataset.batch_number)

        # one lr step per batch (not per file), so the last batch runs at
        # final_value; a single batch would otherwise divide by zero
        num = max(self.dataset.batch_number - 1, 1)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        self.model.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for i in range(self.dataset.batch_number):
            batch_num = i + 1
            sys.stdout.write('\r  current iteration : ' + str(i))

            # use the i-th training mini-batch (not the validation set)
            input_tensor, target_tensor = self.dataset.scale_and_parse_to_tensor(
                batch_files=self.dataset.train_batch_chunks[i],
                scale_params_low=scale_params_low,
                scale_params_high=scale_params_high,
                mean_image_low=mean_image_low,
                mean_image_high=mean_image_high)

            if torch.cuda.is_available():
                input_tensor = input_tensor.cuda()
                target_tensor = target_tensor.cuda()

            self.model.optimizer.zero_grad()
            outputs = self.model.forward(input_tensor)
            loss = self.model.criterion(outputs, target_tensor)
            # bias-corrected exponential moving average of the loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            # stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                print('loss exploding')
                return log_lrs, losses
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            loss.backward()
            self.model.optimizer.step()
            # raise the lr for the next step
            lr *= mult
            self.model.optimizer.param_groups[0]['lr'] = lr

        return log_lrs, losses