# Example 1

class BirdsBasicTrainerCV:
    '''
    Trainer that cross-validates a classification network:
    trains across num_folds splits, validates after each
    split, and reports results to tensorboard and csv.
    '''
    # Number of intermediate models to save
    # during training:

    MODEL_ARCHIVE_SIZE = 20

    # For some tensorboard displays:
    # for how many epochs in the past
    # to display data:

    DISPLAY_HISTORY_LEN = 10

    #------------------------------------
    # Constructor
    #-------------------

    def __init__(self,
                 config_info,
                 device=0,
                 percentage=None,
                 debugging=False):
        '''
        
        :param config_info: all path and training parameters
        :type config_info: NeuralNetConfig
        :param debugging: output lots of debug info
        :type debugging: bool
        :param device: number of GPU to use; default is dev 0
            if any GPU is available
        :type device: {None | int}
        :param percentage: percentage of training data to 
            use
        :type percentage: {int | float}
        '''

        self.log = LoggingService()
        if debugging:
            self.log.logging_level = DEBUG

        if percentage is not None:
            # Integrity check:
            if type(percentage) not in [int, float]:
                raise TypeError(
                    f"Percentage must be int or float, not {type(percentage)}")
            if percentage < 1 or percentage > 100:
                raise ValueError(
                    f"Percentage must be between 1 and 100, not {percentage}")

        if device is None:
            device = 0
        available_gpus = torch.cuda.device_count()
        if available_gpus == 0:
            self.log.info("No GPU available; running on CPU")
        else:
            if device > available_gpus - 1:
                raise ValueError(
                    f"Asked to operate on device {device}, but only {available_gpus} are available"
                )
            torch.cuda.set_device(device)

        self.curr_dir = os.path.dirname(os.path.abspath(__file__))

        try:
            self.config = self.initialize_config_struct(config_info)
        except Exception as e:
            msg = f"During config init: {repr(e)}"
            self.log.err(msg)
            raise RuntimeError(msg) from e

        try:
            self.root_train_test_data = self.config.getpath(
                'Paths', 'root_train_test_data', relative_to=self.curr_dir)
        except ValueError as e:
            raise ValueError(
                "Config file must contain an entry 'root_train_test_data' in section 'Paths'"
            ) from e

        self.batch_size = self.config.getint('Training', 'batch_size')
        self.kernel_size = self.config.getint('Training', 'kernel_size')
        self.min_epochs = self.config.Training.getint('min_epochs')
        self.max_epochs = self.config.Training.getint('max_epochs')
        self.lr = self.config.Training.getfloat('lr')
        self.net_name = self.config.Training.net_name
        self.pretrained = self.config.Training.getboolean('pretrained', False)
        self.num_folds = self.config.Training.getint('num_folds')
        self.freeze = self.config.Training.getint('freeze', 0)
        self.to_grayscale = self.config.Training.getboolean(
            'to_grayscale', True)

        self.set_seed(42)

        self.log.info("Parameter summary:")
        self.log.info(f"network     {self.net_name}")
        self.log.info(f"pretrained  {self.pretrained}")
        if self.pretrained:
            self.log.info(f"freeze      {self.freeze}")
        self.log.info(f"min epochs  {self.min_epochs}")
        self.log.info(f"max epochs  {self.max_epochs}")
        self.log.info(f"batch_size  {self.batch_size}")

        self.fastest_device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.device = self.fastest_device
        self.num_classes = self.find_num_classes(self.root_train_test_data)

        self.initialize_model()

        sample_width = self.config.getint('Training', 'sample_width', 400)
        sample_height = self.config.getint('Training', 'sample_height', 400)

        self.train_loader = self.get_dataloader(sample_width,
                                                sample_height,
                                                perc_data_to_use=percentage)
        self.log.info(f"Expecting {len(self.train_loader)} batches per epoch")
        num_train_samples = len(self.train_loader.dataset)
        num_classes = len(self.train_loader.dataset.class_names())
        self.log.info(
            f"Training set contains {num_train_samples} samples across {num_classes} classes"
        )

        self.class_names = self.train_loader.dataset.class_names()

        log_dir = os.path.join(self.curr_dir, 'runs')
        raw_data_dir = os.path.join(self.curr_dir, 'runs_raw_results')

        self.setup_tensorboard(log_dir, raw_data_dir=raw_data_dir)

        # Log a few example spectrograms to tensorboard;
        # one per class:
        TensorBoardPlotter.write_img_grid(
            self.writer,
            self.root_train_test_data,
            len(self.class_names),  # Number of examples to log: one per class
        )

        # All ResultTally instances are
        # collected here: (num_folds * num-epochs)
        # each for training and validation steps.

        self.step_results = ResultCollection()

        self.log.debug(
            f"Just before train: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
        )
        try:
            final_step = self.train()
            self.visualize_final_epoch_results(final_step)
        finally:
            self.close_tensorboard()

    #------------------------------------
    # train
    #-------------------

    def train(self):

        overall_start_time = datetime.datetime.now()
        # Just for sanity: keep track
        # of number of batches...
        total_batch_num = 0

        # Note: since we are cross validating, the
        # data loader's set_epoch() method is only
        # called once (automatically) during instantiation
        # of the associated sampler. Moving from split
        # to split includes shuffling if the caller
        # specified that.

        # Training
        for split_num in range(self.train_loader.num_folds):

            split_start_time = datetime.datetime.now()
            self.initialize_model()
            for epoch in range(self.max_epochs):

                # Set model to train mode:
                self.model.train()

                epoch_start_time = datetime.datetime.now()

                self.log.info(f"Starting epoch {epoch} training")

                # Sanity check record: will record
                # how many samples from each class were
                # used:
                self.class_coverage = {}

                # Sanity records: will record number
                # of samples of each class that are used
                # during training and validation:
                label_distrib = {}
                batch_num = 0

                self.log.info(
                    f"Train epoch {epoch}/{self.max_epochs} split {split_num}/{self.train_loader.num_folds}"
                )
                try:
                    for batch, targets in self.train_loader:
                        # Update the sanity check
                        # num of batches seen, and distribution
                        # of samples across classes:
                        batch_num += 1
                        total_batch_num += 1

                        # Update sanity check records:
                        for lbl in targets:
                            lbl = int(lbl)
                            try:
                                label_distrib[lbl] += 1
                            except KeyError:
                                label_distrib[lbl] = 1
                            try:
                                self.class_coverage[lbl]['train'] += 1
                            except KeyError:
                                self.class_coverage[lbl] = {
                                    'train': 1,
                                    'val': 0
                                }

                        self.log.debug(
                            f"Top of training loop: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
                        )

                        images = FileUtils.to_device(batch, 'gpu')
                        labels = FileUtils.to_device(targets, 'gpu')

                        outputs = self.model(images)
                        loss = self.loss_fn(outputs, labels)
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                        # Remember the last batch's train result of this
                        # split (results for earlier batches of
                        # the same split will be overwritten). This statement
                        # must sit before deleting outputs and labels:

                        step_num = self.step_number(epoch, split_num,
                                                    self.num_folds)
                        self.remember_results(LearningPhase.TRAINING, step_num,
                                              outputs, labels, loss)

                        self.log.debug(
                            f"Just before clearing gpu: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
                        )

                        images = FileUtils.to_device(images, 'cpu')
                        outputs = FileUtils.to_device(outputs, 'cpu')
                        labels = FileUtils.to_device(labels, 'cpu')
                        loss = FileUtils.to_device(loss, 'cpu')

                        del images
                        del outputs
                        del labels
                        del loss
                        torch.cuda.empty_cache()

                        self.log.debug(
                            f"Just after clearing gpu: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
                        )
                except EndOfSplit:

                    end_time = datetime.datetime.now()
                    train_time_duration = end_time - epoch_start_time
                    # A human-readable duration string, down to minutes:
                    duration_str = FileUtils.time_delta_str(
                        train_time_duration, granularity=4)

                    self.log.info(
                        f"Done training epoch {epoch} of split {split_num} (duration: {duration_str})"
                    )

                    #***********
                    #print(f"****** num_batches in split: {batch_num}" )
                    #print(f"****** LblDist: {label_distrib}")
                    #***********
                    self.validate_split(step_num)
                    self.visualize_step(step_num)
                    # Save model, keeping self.model_archive_size models:
                    self.model_archive.save_model(self.model, epoch)

                    self.log.debug(
                        f"After eval: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
                    )

                    # Next Epoch
                    continue

            end_time = datetime.datetime.now()
            train_time_duration = end_time - split_start_time
            # A human-readable duration string, down to minutes:
            duration_str = FileUtils.time_delta_str(train_time_duration,
                                                    granularity=4)

            self.log.info(
                f"Done training split {split_num} (duration: {duration_str})")

            # Next split
            continue

        end_time = datetime.datetime.now()
        epoch_duration = end_time - epoch_start_time
        epoch_dur_str = FileUtils.time_delta_str(epoch_duration, granularity=4)

        cumulative_dur = end_time - overall_start_time
        cum_dur_str = FileUtils.time_delta_str(cumulative_dur, granularity=4)

        msg = f"Done epoch {epoch}  (epoch duration: {epoch_dur_str}; cumulative: {cum_dur_str})"
        self.log.info(msg)

        #******self.scheduler.step()

        # Fresh results tallying
        #self.results.clear()

        self.log.info(
            f"Training complete after {self.train_loader.num_folds} splits")

        # Report the sanity checks:
        self.log.info(f"Total batches processed: {total_batch_num}")
        for cid in self.class_coverage.keys():
            train_use = self.class_coverage[cid]['train']
            val_use = self.class_coverage[cid]['val']
            self.log.info(
                f"{self.class_names[cid]} Training: {train_use}, Validation: {val_use}"
            )

        # All seems to have gone well. Report the
        # overall result of the final epoch for the
        # hparms config used in this process:

        self.report_hparams_summary(self.latest_result)

        # The final epoch number:
        return epoch

    #------------------------------------
    # validate_split
    #-------------------

    def validate_split(self, step):
        '''
        Validate one split, using that split's 
        validation fold. Return time taken. Record
        results for tensorboard and other record keeping.
        
        :param step: current combination of epoch and 
            split
        :type step: int
        :return: wall-clock time taken for the validation
        :rtype: datetime.timedelta
        '''
        # Validation

        self.log.debug(
            f"Start of validation: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
        )

        start_time = datetime.datetime.now()
        self.log.info(f"Starting validation for step {step}")

        self.model.eval()
        with torch.no_grad():
            for img_tensor, target in self.train_loader.validation_samples():
                expanded_img_tensor = unsqueeze(img_tensor, dim=0)
                expanded_target = unsqueeze(target, dim=0)

                # Update sanity record:
                self.class_coverage[int(target)]['val'] += 1

                images = FileUtils.to_device(expanded_img_tensor, 'gpu')
                label = FileUtils.to_device(expanded_target, 'gpu')

                outputs = self.model(images)
                loss = self.loss_fn(outputs, label)

                images = FileUtils.to_device(images, 'cpu')
                outputs = FileUtils.to_device(outputs, 'cpu')
                label = FileUtils.to_device(label, 'cpu')
                loss = FileUtils.to_device(loss, 'cpu')

                self.remember_results(LearningPhase.VALIDATING, step, outputs,
                                      label, loss)
                del images
                del outputs
                del label
                del loss
                torch.cuda.empty_cache()

        end_time = datetime.datetime.now()
        val_time_duration = end_time - start_time
        # A human-readable duration string, down to minutes:
        duration_str = FileUtils.time_delta_str(val_time_duration,
                                                granularity=4)
        self.log.info(f"Done validation (duration: {duration_str})")

        return val_time_duration

    # ------------- Utils -----------

    #------------------------------------
    # report_acc_loss
    #-------------------

    def report_acc_loss(self, phase, epoch, accumulated_loss):

        self.writer.add_scalar(f"loss/{phase}", accumulated_loss, epoch)

    #------------------------------------
    # remember_results
    #-------------------

    def remember_results(
        self,
        phase,
        step,
        outputs,
        labels,
        loss,
    ):

        # Add the results
        tally = ResultTally(step, phase, outputs, labels, loss,
                            self.num_classes, self.batch_size)
        # Add result to intermediate results collection of
        # tallies:
        self.results[step] = tally

        # Same with the session-wide
        # collection:

        self.step_results.add(tally)

    #------------------------------------
    # visualize_step
    #-------------------

    def visualize_step(self, step):
        '''
        Take the ResultTally instances
        in the train and val ResultCollections
        in self.results, and report appropriate
        aggregates to tensorboard. Computes
        f1 scores, accuracies, etc. for given
        step.

        Separately for train and validation
        results: build one long array 
        of predictions, and a corresponding
        array of labels. Also, average the
        loss across all instances.
        
        The preds and labels are also appended as
        rows to the csv file, if one is open.

        '''

        val_tally = self.results[(step, str(LearningPhase.VALIDATING))]
        train_tally = self.results[(step, str(LearningPhase.TRAINING))]

        result_coll = ResultCollection()
        result_coll.add(val_tally, step)
        result_coll.add(train_tally, step)

        self.latest_result = {'train': train_tally, 'val': val_tally}

        # If we are to write preds and labels to
        # .csv for later additional processing:

        if self.csv_writer is not None:
            self.csv_writer.writerow([
                step, train_tally.preds, train_tally.labels, val_tally.preds,
                val_tally.labels
            ])

        TensorBoardPlotter.visualize_step(
            result_coll, self.writer,
            [LearningPhase.TRAINING, LearningPhase.VALIDATING], step,
            self.class_names)
        # History of learning rate adjustments:
        lr_this_step = self.optimizer.param_groups[0]['lr']
        self.writer.add_scalar('learning_rate', lr_this_step, global_step=step)

    #------------------------------------
    # visualize_final_epoch_results
    #-------------------

    def visualize_final_epoch_results(self, epoch):
        '''
        Reports to tensorboard just for the
        final epoch.
 
        Expect self.latest_result to be the latest
        ResultTally.
        '''
        # DISPLAY_HISTORY_LEN holds the number
        # of historic epochs we will show. Two
        # results per epoch --> need
        # 2*DISPLAY_HISTORY_LEN results. But check
        # that there are that many, and show fewer
        # if needed:

        num_res_to_show = min(len(self.step_results),
                              2 * self.DISPLAY_HISTORY_LEN)

        f1_hist = self.step_results[-num_res_to_show:]

        # First: the table of train and val f1-macro
        # scores for the past few epochs:
        #
        #      |phase|ep0  |ep1 |ep2 |
        #      |-----|-----|----|----|
        #      |train| f1_0|f1_1|f1_2|
        #      |  val| f1_0|f1_1|f1_2|

        f1_macro_tbl = TensorBoardPlotter.make_f1_train_val_table(f1_hist)
        self.writer.add_text('f1/history', f1_macro_tbl)

        # Now, in the same tensorboard row: the
        # per_class train/val f1 scores for each
        # class separately:
        #
        # |class|weighted mean f1 train|weighted mean f1 val|
        # |-----|----------------------|--------------------|
        # |  c1 |0.1                   |0.6                 |
        # |  c2 |0.1                   |0.6                 |
        # |  c3 |0.1                   |0.6                 |
        # ------|----------------------|--------------------|

        f1_all_classes = TensorBoardPlotter.make_all_classes_f1_table(
            self.latest_result, self.class_names)
        self.writer.add_text('f1/per-class', f1_all_classes)

    #------------------------------------
    # report_hparams_summary
    #-------------------

    def report_hparams_summary(self, latest_result):
        '''
        Called at the end of training. Constructs
        a summary to report for the hyperparameters
        used in this process. Reports to the tensorboard.
         
        Hyperparameters reported:
         
           o lr
           o optimizer
           o batch_size
           o kernel_size
         
        Included in the measures are:
         
           o balanced_accuracy      (train and val)
           o mean_accuracy_train    (train and val)
           o epoch_prec_weighted
           o epoch_recall_weighted
           o epoch_mean_loss        (train and val)
           
         
        :param latest_result: dict with keys 'train' and
            'val', holding the respective most recent
            (i.e. last-epoch) ResultTally
        :type latest_result: {'train' : ResultTally,
                               'val'   : ResultTally
                               }
        '''

        # Get the latest validation tally:
        train_tally = latest_result['train']
        val_tally = latest_result['val']

        hparms_vals = OrderedDict({
            'net':
            self.net_name,
            'pretrained':
            f"{self.pretrained}",
            'lr_initial':
            self.config.Training.lr,
            'optimizer':
            self.config.Training.opt_name,
            'batch_size':
            self.config.getint('Training', 'batch_size'),
            'kernel_size':
            self.config.getint('Training', 'kernel_size'),
            'to_grayscale':
            self.to_grayscale
        })

        metric_results = {
            'zz_balanced_adj_acc_train': train_tally.balanced_acc,
            'zz_balanced_adj_acc_val': val_tally.balanced_acc,
            'zz_acc_train': train_tally.accuracy,
            'zz_acc_val': val_tally.accuracy,
            'zz_epoch_weighted_prec': val_tally.prec_weighted,
            'zz_epoch_weighted_recall': val_tally.recall_weighted,
            'zz_epoch_mean_loss_train': train_tally.mean_loss,
            'zz_epoch_mean_loss_val': val_tally.mean_loss
        }

        self.writer.add_hparams(hparms_vals, metric_results)

    #------------------------------------
    # get_dataloader
    #-------------------

    def get_dataloader(self,
                       sample_width,
                       sample_height,
                       perc_data_to_use=None):
        '''
        Returns a cross validating dataloader. 
        If perc_data_to_use is None, all samples
        under self.root_train_test_data will be
        used for training. Else percentage indicates
        the percentage of those samples to use. The
        selection is random.
        
        :param sample_width: pixel width of returned images
        :type sample_width: int
        :param sample_height: pixel height of returned images
        :type sample_height: int
        :param perc_data_to_use: amount of available training
            data to use.
        :type perc_data_to_use: {None | int | float}
        :return: a data loader that serves batches of
            images and their associated labels
        :rtype: CrossValidatingDataLoader
        '''

        data_root = self.root_train_test_data

        train_dataset = SingleRootImageDataset(data_root,
                                               sample_width=sample_width,
                                               sample_height=sample_height,
                                               percentage=perc_data_to_use,
                                               to_grayscale=True)

        sampler = SKFSampler(train_dataset,
                             num_folds=self.num_folds,
                             seed=42,
                             shuffle=True,
                             drop_last=True)

        train_loader = CrossValidatingDataLoader(train_dataset,
                                                 batch_size=self.batch_size,
                                                 shuffle=True,
                                                 drop_last=True,
                                                 sampler=sampler,
                                                 num_folds=self.num_folds)
        return train_loader
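
    # A usage sketch of the cross-validating loader (mirrors how
    # train() consumes it; EndOfSplit and validation_samples() are
    # the loader's API as used elsewhere in this class):
    #
    #   loader = self.get_dataloader(400, 400, perc_data_to_use=10)
    #   for split_num in range(loader.num_folds):
    #       try:
    #           for batch, targets in loader:
    #               ...   # train on this fold's training portion
    #       except EndOfSplit:
    #           for img, target in loader.validation_samples():
    #               ...   # validate on the held-out fold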

    #------------------------------------
    # initialize_model
    #-------------------

    def initialize_model(self):
        self.model = NetUtils.get_net(self.net_name,
                                      num_classes=self.num_classes,
                                      pretrained=self.pretrained,
                                      freeze=self.freeze,
                                      to_grayscale=self.to_grayscale)
        self.log.debug(
            f"Before any gpu push: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
        )

        FileUtils.to_device(self.model, 'gpu')

        self.log.debug(
            f"Before after model push: \n{'none--on CPU' if self.fastest_device.type == 'cpu' else torch.cuda.memory_summary()}"
        )

        self.opt_name = self.config.Training.get('optimizer',
                                                 'Adam')  # Default
        self.optimizer = self.get_optimizer(self.opt_name, self.model, self.lr)

        self.loss_fn = nn.CrossEntropyLoss()
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.min_epochs)

    #------------------------------------
    # find_num_classes
    #-------------------

    def find_num_classes(self, data_root):
        '''
        Expect two subdirectories under data_root:
        train and validation. Underneath each are 
        further subdirectories whose names are the
        classes:
        
                train               validation
        class1 class2 class3     class1 class2 class3
          imgs   imgs   imgs       imgs   imgs   imgs
        
        No error checking to confirm this structure
        
        :param data_root: path to parent of train/validation
        :type data_root: str
        :return: number of unique classes as obtained
            from the directory names
        :rtype: int
        '''
        self.classes = FileUtils.find_class_names(data_root)
        return len(self.classes)
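
    # An equivalent inline sketch (assumes exactly the train/validation
    # layout shown in the docstring; FileUtils.find_class_names() is the
    # canonical implementation):
    #
    #   class_names = {entry.name
    #                  for split in ('train', 'validation')
    #                  for entry in os.scandir(os.path.join(data_root, split))
    #                  if entry.is_dir()}
    #   num_classes = len(class_names)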

    #------------------------------------
    # setup_tensorboard
    #-------------------

    def setup_tensorboard(self, logdir, raw_data_dir=True):
        '''
        Initialize tensorboard. To easily compare experiments,
        use runs/exp1, runs/exp2, etc.
        
        Method creates the dir if needed.
        
        Additionally, sets self.csv_writer either to None
        or to an open CSV writer, depending on the value of
        raw_data_dir; see create_csv_writer().
        
        :param logdir: root for tensorboard events
        :type logdir: str
        :param raw_data_dir: where (or whether) to write raw
            preds/labels csv files; see create_csv_writer()
        :type raw_data_dir: {None | True | str}
        '''

        if not os.path.isdir(logdir):
            os.makedirs(logdir)

        # For storing train/val preds/labels
        # for every epoch. Used to create charts
        # after run is finished:
        self.csv_writer = self.create_csv_writer(raw_data_dir)

        # Place to store intermediate models:
        self.model_archive = \
            self.create_model_archive(self.config,
                                      self.num_classes
                                      )

        # Use SummaryWriterPlus to avoid confusing
        # directory creations when calling add_hparams()
        # on the writer:

        self.writer = SummaryWriterPlus(log_dir=logdir)

        # Intermediate storage for train and val results:
        self.results = ResultCollection()

        self.log.info(
            f"To view tensorboard charts: in shell: tensorboard --logdir {logdir}; then browser: localhost:6006"
        )

    #------------------------------------
    # create_csv_writer
    #-------------------

    def create_csv_writer(self, raw_data_dir):
        '''
        Create a csv_writer that will fill a csv
        file during training/validation as follows:
        
            epoch  train_preds   train_labels  val_preds  val_labels
            
        Cols after the integer 'epoch' col will each be
        an array of ints:
        
                  train_preds    train_lbls   val_preds  val_lbls
                2,"[2,5,1,2,3]","[2,6,1,2,1]","[1,2]",    "[1,3]" 
        
        If raw_data_dir is provided as a str, it is
        taken as the directory where csv file with predictions
        and labels are to be written. The dir is created if necessary.
         
        If the arg is instead set to True, a dir 'runs_raw_results' is
        created under this script's directory if it does not
        exist. Then a subdirectory is created for this run,
        using the hparam settings to build a file name. The dir
        is created if needed. Result ex.:
        
              <script_dir>
                   runs_raw_results
                       Run_lr_0.001_br_32
                           run_2021_05_ ... _lr_0.001_br_32.csv
        
        
        Then a file name is created, again from the run
        hparam settings. If this file exists, the user is asked whether
        to overwrite or append. The inst var self.csv_writer is
        initialized to:
        
           o None if the csv file exists, but is not to 
             be overwritten nor appended-to
           o A file descriptor for a file open for either
             'write' or 'append'.
        
        :param raw_data_dir: If simply True, create dir and file names
            from hparams, and create as needed. If a string, it is 
            assumed to be the directory where a .csv file is to be
            created. If None, self.csv_writer is set to None.
        :type raw_data_dir: {None | True | str}
        :return: CSV writer ready for action. Set either to
            write a fresh file, or append to an existing file.
            Unless file exists, and user decided not to overwrite
        :rtype: {None | csv.writer}
        '''

        # Ensure the csv file root dir exists if
        # we'll do a csv dir and run-file below it:

        if type(raw_data_dir) == str:
            raw_data_root = raw_data_dir
        else:
            raw_data_root = os.path.join(self.curr_dir, 'runs_raw_results')

        if not os.path.exists(raw_data_root):
            os.mkdir(raw_data_root)

        # Can rely on raw_data_root being defined and existing:

        if raw_data_dir is None:
            return None

        # Create both a raw-results sub-directory and a .csv file
        # for this run:
        csv_subdir_name = FileUtils.construct_filename(self.config.Training,
                                                       prefix='Run',
                                                       incl_date=True)
        csv_subdir = os.path.join(raw_data_root, csv_subdir_name)
        os.makedirs(csv_subdir, exist_ok=True)

        # Create a csv file name:
        csv_file_nm = FileUtils.construct_filename(self.config.Training,
                                                   prefix='run',
                                                   suffix='.csv',
                                                   incl_date=True)

        csv_path = os.path.join(csv_subdir, csv_file_nm)

        # Get csv_raw_fd appropriately:

        if os.path.exists(csv_path):
            do_overwrite = FileUtils.user_confirm(
                f"File {csv_path} exists; overwrite?", default='N')
            if do_overwrite:
                mode = 'w'
            else:
                do_append = FileUtils.user_confirm("Append instead?",
                                                   default='N')
                if not do_append:
                    return None
                mode = 'a'
        else:
            mode = 'w'

        csv_writer = CSVWriterCloseable(csv_path, mode=mode, delimiter=',')

        header = [
            'epoch', 'train_preds', 'train_labels', 'val_preds', 'val_labels'
        ]
        csv_writer.writerow(header)

        return csv_writer
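
    # Reading the raw-results file back later (a sketch; the column
    # layout matches the header written above, with the array columns
    # stored as quoted list literals):
    #
    #   import ast, csv
    #   with open(csv_path, newline='') as fd:
    #       for row in csv.DictReader(fd):
    #           train_preds = ast.literal_eval(row['train_preds'])
    #           val_labels  = ast.literal_eval(row['val_labels'])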

    #------------------------------------
    # create_model_archive
    #-------------------

    def create_model_archive(self, config, num_classes):
        '''
        Creates facility for saving partially trained
        models along the way.
        
        :param config: configuration structure of this training session
        :type config: NeuralNetConfig
        :param num_classes: number of target classes the models predict
        :type num_classes: int
        :return: ModelArchive instance ready
            for calls to save_model()
        :rtype: ModelArchive
        '''
        model_archive = ModelArchive(config,
                                     num_classes,
                                     history_len=self.MODEL_ARCHIVE_SIZE,
                                     log=self.log)
        return model_archive

    #------------------------------------
    # close_tensorboard
    #-------------------

    def close_tensorboard(self):
        if self.csv_writer is not None:
            try:
                self.csv_writer.close()
            except Exception as e:
                self.log.warn(f"Could not close csv file: {repr(e)}")
        try:
            self.writer.close()
        except AttributeError:
            self.log.warn(
                "Method close_tensorboard() called before setup_tensorboard()?"
            )
        except Exception as e:
            raise RuntimeError(
                f"Problem closing tensorboard: {repr(e)}") from e

    #------------------------------------
    # get_optimizer
    #-------------------

    def get_optimizer(self, optimizer_name, model, lr):

        optimizer_name = optimizer_name.lower()
        if optimizer_name == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=lr,
                                   eps=1e-3,
                                   amsgrad=True)
            return optimizer

        if optimizer_name == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
            return optimizer

        if optimizer_name == 'rmsprop':
            optimizer = optim.RMSprop(model.parameters(), lr=lr, momentum=0.9)
            return optimizer

        raise ValueError(f"Optimizer {optimizer_name} not supported")

    #------------------------------------
    # initialize_config_struct
    #-------------------

    def initialize_config_struct(self, config_info):
        '''
        Initialize a config dict of dict with
        the application's configurations. Sections
        will be:
        
          config['Paths']       -> dict[attr : val]
          config['Training']    -> dict[attr : val]
          config['Parallelism'] -> dict[attr : val]
        
        The config read method will handle config_info
        being None. 
        
        If config_info is a string, it is assumed either 
        to be a file containing the configuration, or
        a JSON string that defines the config.
         
        Else config_info is assumed to be a NeuralNetConfig.
        The latter is relevant only if using this file
        as a library, rather than a command line tool.
        
        If given a NeuralNetConfig instance, it is returned
        unchanged. 
        
        :param config_info: the information needed to construct
            the structure
        :type config_info: {NeuralNetConfig | str}
        :return a NeuralNetConfig instance with all parms
            initialized
        :rtype NeuralNetConfig
        '''

        if isinstance(config_info, str):
            # Is it a JSON str? Should have a better test!
            if config_info.startswith('{'):
                # JSON String:
                config = NeuralNetConfig.from_json(config_info)
            else:
                config = self.read_configuration(config_info)
        elif isinstance(config_info, NeuralNetConfig):
            config = config_info
        else:
            msg = f"Error: must have a config file, not {config_info}. See config.cfg.Example in project root"
            # Since logdir may be in config, need to use print here:
            print(msg)
            raise ConfigError(msg)

        return config
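
    # Usage sketch (values illustrative, not from a real run): the
    # method accepts a JSON string, a config file path, or a ready-made
    # NeuralNetConfig instance:
    #
    #   config = self.initialize_config_struct('{"Training": {"batch_size": 32}}')
    #   config = self.initialize_config_struct('/path/to/config.cfg')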

    #------------------------------------
    # read_configuration
    #-------------------

    def read_configuration(self, conf_file):
        '''
        Parses config file that describes training parameters,
        various file paths, and how many GPUs different machines have.
        Syntax follows Python's configfile package, which includes
        sections, and attr/val pairs in each section.
        
        Expected sections:

           o Paths: various file paths for the application
           o Training: holds batch sizes, number of epochs, etc.
           o Parallelism: holds number of GPUs on different machines
        
        For Parallelism, expect entries like:
        
           foo.bar.com  = 4
           127.0.0.1    = 5
           localhost    = 3
           172.12.145.1 = 6
           
        Method identifies which of the entries refers to
        this machine by comparing against the local hostname;
        'localhost' or '127.0.0.1' may also be used directly.
        
        Returns a dict of dicts: 
            config[section-names][attr-names-within-section]
            
        Types of standard entries, such as epochs, batch_size,
        etc. are coerced, so that, e.g. config['Training']['epochs']
        will be an int. Clients may add non-standard entries.
        For those the client must convert values from string
        (the type in which values are stored by default) to the
        required type. This can be done the usual way: int(...),
        or using one of the configparser's retrieval methods
        getboolean(), getint(), and getfloat():
        
            config['Training'].getfloat('learning_rate')
        
        :param conf_file: path to the configuration file
        :type conf_file: str
        :return: a dict of dicts mirroring the config file sections/entries
        :rtype: dict[dict]
        :raises ValueError
        :raises TypeError
        '''

        if conf_file is None:
            return self.init_defaults()

        config = DottableConfigParser(conf_file)

        if len(config.sections()) == 0:
            # Config file exists, but is empty:
            return self.init_defaults(config)

        # Do type conversion also in other entries that
        # are standard:

        types = {
            'epochs': int,
            'batch_size': int,
            'kernel_size': int,
            'sample_width': int,
            'sample_height': int,
            'seed': int,
            'pytorch_comm_port': int,
            'num_pretrained_layers': int,
            'root_train_test_data': str,
            'net_name': str,
        }
        for section in config.sections():
            for attr_name in config[section].keys():
                try:
                    str_val = config[section][attr_name]
                    required_type = types[attr_name]
                    config[section][attr_name] = required_type(str_val)
                except KeyError:
                    # Current attribute is not standard;
                    # users of the corresponding value need
                    # to do their own type conversion when
                    # accessing this configuration entry:
                    continue
                except TypeError as e:
                    raise ValueError(
                        f"Config file error: {section}.{attr_name} should be convertible to {required_type}"
                    ) from e

        return config
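
    # A minimal config file sketch that read_configuration() accepts
    # (sections as documented above; values illustrative):
    #
    #   [Paths]
    #   root_train_test_data = /path/to/data
    #
    #   [Training]
    #   net_name    = resnet18
    #   batch_size  = 32
    #   kernel_size = 7
    #   min_epochs  = 2
    #   max_epochs  = 10
    #   lr          = 0.001
    #   num_folds   = 5
    #
    #   [Parallelism]
    #   localhost = 1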

    #------------------------------------
    # set_seed
    #-------------------

    def set_seed(self, seed):
        '''
        Set the seed across all different necessary platforms
        to allow for comparison of different models and runs
        
        :param seed: random seed to set for all random num generators
        :type seed: int
        '''
        torch.manual_seed(seed)
        cuda.manual_seed_all(seed)
        # Make cuDNN deterministic, at the cost
        # of some speed:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)

    #------------------------------------
    # time_delta_str
    #-------------------

    def time_delta_str(self, epoch_delta, granularity=2):
        '''
        Takes the difference between two datetime times:
        
               start_time = datetime.datetime.now()
               <some time elapses>
               end_time = datetime.datetime.now()
               
               delta = end_time - start_time
               time_delta_str(delta, granularity=4)
        
        Depending on granularity, returns a string like:
        
            Granularity:
                      1  '160.0 weeks'
                      2  '160.0 weeks, 4.0 days'
                      3  '160.0 weeks, 4.0 days, 6.0 hours'
                      4  '160.0 weeks, 4.0 days, 6.0 hours, 42.0 minutes'
                      5  '160.0 weeks, 4.0 days, 6.0 hours, 42.0 minutes, 13.0 seconds'
        
            For smaller time deltas, such as 10 seconds,
            does not include leading zero times. For
            any granularity:
            
                          '10.0 seconds'

            If the duration is less than a second, returns '< 1 sec'
            
        :param epoch_delta: time difference to render
        :type epoch_delta: datetime.timedelta
        :param granularity: number of time units to include
        :type granularity: int
        '''
        intervals = (
            ('weeks', 604800),  # 60 * 60 * 24 * 7
            ('days', 86400),  # 60 * 60 * 24
            ('hours', 3600),  # 60 * 60
            ('minutes', 60),
            ('seconds', 1),
        )
        secs = epoch_delta.total_seconds()
        result = []
        for name, count in intervals:
            value = secs // count
            if value:
                secs -= value * count
                if value == 1:
                    name = name.rstrip('s')
                result.append("{} {}".format(value, name))
        dur_str = ', '.join(result[:granularity])
        if len(dur_str) == 0:
            dur_str = '< 1 sec'
        return dur_str
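
    # Usage sketch:
    #
    #   delta = datetime.timedelta(hours=2, minutes=5, seconds=13)
    #   self.time_delta_str(delta, granularity=2)
    #   # -> '2.0 hours, 5.0 minutes'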

    #------------------------------------
    # step_number
    #-------------------

    def step_number(self, epoch, split_num, num_folds):
        '''
        Combines an epoch with a split number into 
        a single integer series as epochs increase,
        and split_num cycles from 0 to num_folds.
        
        :param epoch: epoch to encode
        :type epoch: int
        :param split_num: split number to encode
        :type split_num: int
        :param num_folds: number of folds used for CV splitting;
            must be constant across calls!
        :type num_folds: int
        :return: an integer that combines epoch and split-num
        :rtype: int
        '''

        step_num = epoch * num_folds + split_num
        return step_num
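
    # Example of the resulting series for num_folds == 5 (follows
    # directly from the formula above):
    #
    #   epoch 0, splits 0..4  ->  steps 0..4
    #   epoch 1, splits 0..4  ->  steps 5..9
    #   epoch 2, splits 0..4  ->  steps 10..14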

    #------------------------------------
    # cleanup
    #-------------------

    def cleanup(self):
        '''
        Recover resources taken by collaborating
        processes. OK to call multiple times.
        '''
        # self.clear_gpu()

        try:
            self.writer.close()
        except Exception as e:
            self.log.err(f"Could not close tensorboard writer: {repr(e)}")


# Example 2

class Inferencer:
    '''
    Runs one or more trained models over a test set,
    and reports result measures to csv and tensorboard.
    '''
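
    # Overall usage sketch (paths illustrative): construct with one or
    # more model paths, then let go() pair models with GPUs and run
    # inference:
    #
    #   inferencer = Inferencer(['models/run1.pth'],
    #                           'data/test',
    #                           batch_size=1,
    #                           gpu_ids=[0])
    #   inferencer.go()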

    #------------------------------------
    # Constructor
    #-------------------

    def __init__(self,
                 model_paths,
                 samples_path,
                 batch_size=1,
                 labels_path=None,
                 gpu_ids=0):
        '''
        Given the path to a trained model,
        and the path to the root of a set
        of data, compute predictions.
        
        If labels_path is None, the subdirectory
        names between the samples_path root,
        and the samples themselves are used as
        the ground truth labels.
        
        By default, batches are of size 1,
        because drop_last is always set to True,
        and for small test sets leaving out
        any data at all isn't good. The caller
        can still set batch_size higher to gain
        speed on a very large testset, where
        skipping up to batch_size - 1 samples
        is acceptable.
        
        :param model_paths: paths to the trained models to evaluate
        :type model_paths: [str]
        :param samples_path: path to the root of the test samples
        :type samples_path: str
        :param batch_size: number of samples per batch
        :type batch_size: int
        :param labels_path: optional path to ground truth labels
        :type labels_path: {None | str}
        :param gpu_ids: Device number of GPU, in case 
            one is available
        :type gpu_ids: {int | [int]} 
        '''

        self.model_paths = model_paths
        self.samples_path = samples_path
        self.labels_path = labels_path
        self.gpu_ids = gpu_ids if type(gpu_ids) == list else [gpu_ids]
        if batch_size is not None:
            self.batch_size = batch_size
        else:
            self.batch_size = 1

        self.IMG_EXTENSIONS = FileUtils.IMG_EXTENSIONS
        self.log = LoggingService()
        self.curr_dir = os.path.dirname(__file__)

    #------------------------------------
    # prep_model_inference
    #-------------------

    def prep_model_inference(self, model_path):
        '''
        1. Parses model_path into its components, and 
            creates a dict: self.model_props, which 
            contains the network type, grayscale or not,
            whether pretrained, etc.
        2. Creates self.csv_writer to write result measures
            into csv files. The destination file is determined
            as follows:
                <script_dir>/runs_raw_inferences/inf_csv_results_<datetime>/<model-props-derived-fname>.csv
        3. Creates self.writer, a tensorboard writer with destination dir:
                <script_dir>/runs_inferences/inf_results_<datetime>
        4. Creates an ImageFolder-style dataset over self.samples_path
        5. Creates a shuffling DataLoader
        6. Initializes self.num_classes and self.class_names
        7. Creates self.model from the passed-in model_path name
        
        :param model_path: path to model that will be used for
            inference by this instance of Inferencer
        :type model_path: str
        '''

        model_fname = os.path.basename(model_path)

        # Extract model properties
        # from the model filename:
        self.model_props = FileUtils.parse_filename(model_fname)

        csv_results_root = os.path.join(self.curr_dir, 'runs_raw_inferences')
        #self.csv_dir = os.path.join(csv_results_root, f"inf_csv_results_{uuid.uuid4().hex}")
        ts = FileUtils.file_timestamp()
        self.csv_dir = os.path.join(csv_results_root, f"inf_csv_results_{ts}")
        os.makedirs(self.csv_dir, exist_ok=True)

        csv_file_nm = FileUtils.construct_filename(self.model_props,
                                                   prefix='inf',
                                                   suffix='.csv',
                                                   incl_date=True)
        csv_path = os.path.join(self.csv_dir, csv_file_nm)

        self.csv_writer = CSVWriterCloseable(csv_path)

        ts = FileUtils.file_timestamp()
        tensorboard_root = os.path.join(self.curr_dir, 'runs_inferences')
        tensorboard_dest = os.path.join(tensorboard_root, f"inf_results_{ts}")
        #f"inf_results_{ts}{uuid.uuid4().hex}")
        os.makedirs(tensorboard_dest, exist_ok=True)

        self.writer = SummaryWriterPlus(log_dir=tensorboard_dest)

        dataset = SingleRootImageDataset(
            self.samples_path, to_grayscale=self.model_props['to_grayscale'])

        # Make reproducible:
        Utils.set_seed(42)
        #********Utils.set_seed(56)
        self.loader = DataLoader(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 drop_last=True)
        self.class_names = dataset.class_names()
        self.num_classes = len(self.class_names)

        # Get the right type of model.
        # Don't bother with pretrained weights
        # or freezing, b/c we will overwrite
        # the weights anyway:

        self.model = NetUtils.get_net(
            self.model_props['net_name'],
            num_classes=self.num_classes,
            pretrained=False,
            freeze=0,
            to_grayscale=self.model_props['to_grayscale'])

        self.log.info(f"Tensorboard info written to {tensorboard_dest}")
        self.log.info(f"Result measurement CSV file(s) written to {csv_path}")

    #------------------------------------
    # __call__
    #-------------------

    def __call__(self, gpu_id_model_path_pair):
        gpu_id, self.model_path = gpu_id_model_path_pair
        self.prep_model_inference(self.model_path)
        self.log.info(
            f"Begining inference with model {FileUtils.ellipsed_file_path(self.model_path)} on gpu_id {gpu_id}"
        )
        #****************
        #return self.run_inference(gpu_to_use=gpu_id)
        dicts_from_runs = []
        for i in range(3):
            self.curr_dict = {}
            dicts_from_runs.append(self.curr_dict)
            self.run_inference(gpu_to_use=gpu_id)
        print(dicts_from_runs)
        #****************

    #------------------------------------
    # go
    #-------------------

    def go(self):
        # Pair models to GPUs; example for
        # self.gpu_ids == [0,4], and three models:
        #    [(gpu0, model0) (gpu4, model1), (gpu0, model3)]

        repeats = int(np.ceil(len(self.model_paths) / len(self.gpu_ids)))
        gpu_model_pairings = list(zip(self.gpu_ids * repeats,
                                      self.model_paths))

        #************* No parallelism for debugging
        self(gpu_model_pairings[0])
        return
        #************* END No parallelism for debugging

        with Pool(len(self.gpu_ids)) as inf_pool:
            # Run as many inferences in parallel as
            # there are models to try. The first arg,
            # (self): means to invoke the __call__() method
            # on self.
            result_it = inf_pool.imap(self,
                                      gpu_model_pairings,
                                      chunksize=len(self.gpu_ids))
            # Pool.imap() yields results directly, not AsyncResult
            # instances, so no .get() is needed:
            results = list(result_it)
            print(f"******Results: {results}")

    #------------------------------------
    # run_inference
    #-------------------

    def run_inference(self, gpu_to_use=0):
        '''
        Runs model over dataloader. Along
        the way: creates ResultTally for each
        batch, and maintains dict instance variable
        self.raw_results for later conversion of
        logits to class IDs under different threshold
        assumptions. 
        
        self.raw_results: 
                {'all_outputs' : <arr>,
                 'all_labels'  : <arr>
                 }
        
        Returns a ResultCollection with the
        ResultTally instances of each batch.

        :param gpu_to_use: which GPU to deploy to (if it is available)
        :type gpu_to_use: int
        :return: collection of tallies, one for each batch,
            or None if something went wrong.
        :rtype: {None | ResultCollection}
        '''
        # Just in case the loop never runs:
        batch_num = -1
        samples_processed = 0
        overall_start_time = datetime.datetime.now()

        try:
            try:
                if torch.cuda.is_available():
                    self.model.load_state_dict(torch.load(self.model_path))
                    FileUtils.to_device(self.model, 'gpu', gpu_to_use)
                else:
                    self.model.load_state_dict(
                        torch.load(self.model_path,
                                   map_location=torch.device('cpu')))
            except RuntimeError as e:
                emsg = repr(e)
                if emsg.find("size mismatch for conv1") > -1:
                    emsg += " Maybe model was trained with to_grayscale=False, but local net created for grayscale?"
                # Re-raise in every case, so load errors are never
                # silently swallowed:
                raise RuntimeError(emsg) from e

            loss_fn = nn.CrossEntropyLoss()

            result_coll = ResultCollection()

            # Save all per-class logits for ability
            # later to use different thresholds for
            # conversion to class IDs:

            all_outputs = []
            all_labels = []

            self.model.eval()
            num_test_samples = len(self.loader.dataset)
            self.log.info(
                f"Begin inference ({num_test_samples} test samples)...")

            samples_processed = 0

            loop_start_time = overall_start_time
            with torch.no_grad():

                for batch_num, (batch, targets) in enumerate(self.loader):
                    if torch.cuda.is_available():
                        images = FileUtils.to_device(batch, 'gpu')
                        labels = FileUtils.to_device(targets, 'gpu')
                    else:
                        images = batch
                        labels = targets

                    outputs = self.model(images)
                    loss = loss_fn(outputs, labels)

                    images = FileUtils.to_device(images, 'cpu')
                    outputs = FileUtils.to_device(outputs, 'cpu')
                    labels = FileUtils.to_device(labels, 'cpu')
                    loss = FileUtils.to_device(loss, 'cpu')

                    #**********
                    max_logit = outputs[0].max().item()
                    max_idx = (outputs.squeeze() == max_logit).nonzero(
                        as_tuple=False).item()
                    smpl_id = torch.utils.data.dataloader.sample_id_seq[-1]
                    lbl = labels[0].item()
                    pred_cl = max_idx

                    self.curr_dict[smpl_id] = (smpl_id, lbl, pred_cl)
                    #**********

                    # Specify the batch_num in place
                    # of an epoch, which is not applicable
                    # during testing:
                    tally = ResultTally(batch_num, LearningPhase.TESTING,
                                        outputs, labels, loss,
                                        self.num_classes, self.batch_size)
                    result_coll.add(tally, step=None)

                    all_outputs.append(outputs)
                    all_labels.append(labels)

                    samples_processed += len(labels)

                    del images
                    del outputs
                    del labels
                    del loss

                    torch.cuda.empty_cache()

                    time_now = datetime.datetime.now()
                    # Sign of life every 5 seconds:
                    if (time_now - loop_start_time).seconds >= 5:
                        self.log.info(
                            f"GPU{gpu_to_use} processed {samples_processed}/{num_test_samples} samples"
                        )
                        loop_start_time = time_now
        finally:

            #*********
            print(f"Sample seq: {torch.utils.data.dataloader.sample_id_seq}")
            torch.utils.data.dataloader.sample_id_seq = []
            #*********
            time_now = datetime.datetime.now()
            test_time_duration = time_now - overall_start_time
            # A human-readable duration string, down to minutes:
            duration_str = FileUtils.time_delta_str(test_time_duration,
                                                    granularity=4)
            self.log.info(
                f"Done with inference: {samples_processed} test samples; {duration_str}"
            )
            # Total number of batches we ran:
            num_batches = 1 + batch_num  # b/c of zero-base

            # If loader delivered nothing, the loop
            # never ran; warn, and get out:
            if num_batches == 0:
                self.log.warn(
                    f"Dataloader delivered no data from {self.samples_path}")
                self.close()
                return None

            # Var all_outputs is now:
            #  [tensor([pred_cl0, pred_cl1, ..., pred_cl<num_classes - 1>]),  # For sample0
            #   tensor([pred_cl0, pred_cl1, ..., pred_cl<num_classes - 1>]),  # For sample1
            #                     ...
            #   ]
            # Make into one tensor: (num_batches, batch_size, num_classes),
            # unless an exception was raised at some point,
            # throwing us into this finally clause:
            if len(all_outputs) == 0:
                self.log.info(
                    "No outputs were produced; thus no results to report")
                return None

            self.all_outputs_tn = torch.stack(all_outputs)
            # Be afraid...be very afraid:
            assert(self.all_outputs_tn.shape == \
                   torch.Size([num_batches,
                               self.batch_size,
                               self.num_classes])
                   )

            # Var all_labels is now num-batches tensors,
            # each containing batch_size labels:
            assert (len(all_labels) == num_batches)

            # That is a list of per-batch label tensors.
            # Make into one tensor:
            self.all_labels_tn = torch.stack(all_labels)
            assert(self.all_labels_tn.shape == \
                   torch.Size([num_batches, self.batch_size])
                   )
            # And equivalently:
            assert(self.all_labels_tn.shape == \
                   (self.all_outputs_tn.shape[0],
                    self.all_outputs_tn.shape[1]
                    )
                   )

            self.report_results(result_coll)
            self.close()

        return result_coll

    #------------------------------------
    # report_results
    #-------------------

    def report_results(self, tally_coll):
        self._report_textual_results(tally_coll, self.csv_dir)
        self._report_conf_matrix(tally_coll, show_in_tensorboard=True)
        self._report_charted_results()

    #------------------------------------
    # _report_conf_matrix
    #-------------------

    def _report_conf_matrix(self,
                            tally_coll,
                            show=True,
                            show_in_tensorboard=None):
        '''
        Computes the confusion matrix CM from the tally collection.
        Creates an image from CM, and displays it via matplotlib
        if the show arg is True. If show_in_tensorboard is True,
        the figure is additionally posted to this instance's
        tensorboard writer, no matter the value of the show arg.
        
        Returns the Figure object.
        
        :param tally_coll: all ResultTally instances to be included
            in the confusion matrix
        :type tally_coll: result_tallying.ResultCollection
        :param show: whether or not to call show() on the
            confusion matrix figure, or only return the Figure instance
        :type show: bool
        :param show_in_tensorboard: whether or not to post the image
            to tensorboard
        :type show_in_tensorboard: bool
        :return: Figure instance containing confusion matrix heatmap
            with color legend.
        :rtype: matplotlib.pyplot.Figure
        '''

        all_preds = []
        all_labels = []

        for tally in tally_coll.tallies(phase=LearningPhase.TESTING):
            all_preds.extend(tally.preds)
            all_labels.extend(tally.labels)

        conf_matrix = Charter.compute_confusion_matrix(all_labels,
                                                       all_preds,
                                                       self.class_names,
                                                       normalize=True)

        # Normalization in compute_confusion_matrix() is
        # to 0-1. Turn those values into percentages:
        conf_matrix_perc = (100 * conf_matrix).astype(int)

        # Decide whether or not to write
        # confusion cell values into the cells.
        # The decision depends on how many species
        # are represented in the conf matrix; too many,
        # and having numbers in all cells is too cluttered:

        if len(self.class_names
               ) > CELL_LABELING.CONF_MATRIX_CELL_LABEL_LIMIT.value:
            write_in_fields = CELL_LABELING.DIAGONAL
        else:
            write_in_fields = CELL_LABELING.ALWAYS

        fig = Charter.fig_from_conf_matrix(
            conf_matrix_perc,
            supertitle='Confusion Matrix\n',
            subtitle='Normalized to percentages',
            write_in_fields=write_in_fields)
        if show_in_tensorboard:
            self.writer.add_figure('Inference Confusion Matrix',
                                   fig,
                                   global_step=0)

        if show:
            # Something above makes fig lose its
            # canvas manager. Add that back in:
            Utils.add_pyplot_manager_to_fig(fig)
            fig.show()
        return fig
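
    # For reference, a standalone sketch of the normalization done by
    # Charter.compute_confusion_matrix() above, written with scikit-learn
    # directly. Charter is project-internal; the sklearn call below is an
    # assumed equivalent, not the method's actual implementation:
    #
    #     from sklearn.metrics import confusion_matrix
    #     cm = confusion_matrix(all_labels, all_preds, normalize='true')
    #     cm_perc = (100 * cm).astype(int)   # as in the percentage step above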

    #------------------------------------
    # _report_charted_results
    #-------------------

    def _report_charted_results(self, thresholds=None):
        '''
        Computes and (pyplot-)shows a set of precision-recall
        curves in one plot. If precision and/or recall are 
        undefined (b/c of division by zero) for all curves, then
        returns False, else True. If no curves are defined,
        logs a warning.
        
        :param thresholds: list of cutoff thresholds
            for turning logits into class ID predictions.
            If None, the default at Charter.compute_multiclass_pr_curves()
            is used.
        :type thresholds: [float]
        :return: True if curves were computed and shown, else False
        :rtype: bool
        '''

        # Obtain a dict of CurveSpecification instances,
        # one for each class, plus the mean Average Precision
        # across all curves. The dict will be keyed
        # by class ID:

        (all_curves_info, mAP) = \
          Charter.compute_multiclass_pr_curves(
              self.all_labels_tn,
              self.all_outputs_tn,
              thresholds
              )

        # Separate out the curves without
        # ill defined prec, rec, or f1:
        well_defined_curves = list(filter(
                    lambda crv_obj: not(crv_obj['undef_prec'] or\
                                        crv_obj['undef_rec'] or\
                                        crv_obj['undef_f1']
                                        ),
                    all_curves_info.values()
                    )
            )

        if len(well_defined_curves) == 0:
            self.log.warn(
                f"For all thresholds, one or more of precision, recall or f1 are undefined. No p/r curves to show"
            )
            return False

        # Too many curves are clutter. Only
        # show the best and worst by optimal f1:
        f1_sorted = sorted(well_defined_curves,
                           key=lambda obj: obj['best_op_pt']['f1'])
        curves_to_show = {
            crv_obj['class_id']: crv_obj
            for crv_obj in (f1_sorted[0], f1_sorted[-1])
        }

        (_num_classes, fig) = \
          ClassificationPlotter.chart_pr_curves(curves_to_show)

        fig.show()
        return True

    #------------------------------------
    # _report_textual_results
    #-------------------

    def _report_textual_results(self, tally_coll, res_dir):
        '''
        Given a sequence of tallies with results
        from a series of batches, create long
        outputs and labels lists from all tallies.
        
        Computes information retrieval type values:
             precision (macro/micro/weighted/by-class)
             recall    (macro/micro/weighted/by-class)
             f1        (macro/micro/weighted/by-class)
             accuracy
             balanced_accuracy
        
        Combines these results into a Pandas series, 
        and writes them to a csv file. That file's path is constructed
        from the passed-in res_dir, appended with 'ir_results.csv'.
        
        Finally, constructs Github flavored tables from the
        above results, and posts them to the 'text' tab of 
        tensorboard.
        
        Returns the results measures Series.
        
        :param tally_coll: collection of tallies from batches
        :type tally_coll: ResultCollection
        :param res_dir: directory where all .csv and other 
            result files are to be written
        :type res_dir: str
        :return: results of information retrieval-like measures
        :rtype: pandas.Series
        '''

        all_preds = []
        all_labels = []

        for tally in tally_coll.tallies(phase=LearningPhase.TESTING):
            all_preds.extend(tally.preds)
            all_labels.extend(tally.labels)

        res = OrderedDict({})
        res['prec_macro'] = precision_score(all_labels,
                                            all_preds,
                                            average='macro',
                                            zero_division=0)
        res['prec_micro'] = precision_score(all_labels,
                                            all_preds,
                                            average='micro',
                                            zero_division=0)
        res['prec_weighted'] = precision_score(all_labels,
                                               all_preds,
                                               average='weighted',
                                               zero_division=0)
        res['prec_by_class'] = precision_score(all_labels,
                                               all_preds,
                                               average=None,
                                               zero_division=0)

        res['recall_macro'] = recall_score(all_labels,
                                           all_preds,
                                           average='macro',
                                           zero_division=0)
        res['recall_micro'] = recall_score(all_labels,
                                           all_preds,
                                           average='micro',
                                           zero_division=0)
        res['recall_weighted'] = recall_score(all_labels,
                                              all_preds,
                                              average='weighted',
                                              zero_division=0)
        res['recall_by_class'] = recall_score(all_labels,
                                              all_preds,
                                              average=None,
                                              zero_division=0)

        res['f1_macro'] = f1_score(all_labels,
                                   all_preds,
                                   average='macro',
                                   zero_division=0)
        res['f1_micro'] = f1_score(all_labels,
                                   all_preds,
                                   average='micro',
                                   zero_division=0)
        res['f1_weighted'] = f1_score(all_labels,
                                      all_preds,
                                      average='weighted',
                                      zero_division=0)
        res['f1_by_class'] = f1_score(all_labels,
                                      all_preds,
                                      average=None,
                                      zero_division=0)

        res['accuracy'] = accuracy_score(all_labels, all_preds)
        res['balanced_accuracy'] = balanced_accuracy_score(
            all_labels, all_preds)

        res_series = pd.Series(list(res.values()), index=list(res.keys()))

        # Write information retrieval type results
        # to a one-line .csv file, using pandas Series
        # as convenient intermediary:
        res_csv_path = os.path.join(res_dir, 'ir_results.csv')
        res_series.to_csv(res_csv_path)

        res_rnd = {}
        for meas_nm, meas_val in res.items():

            # Measure results are either floats (precision, recall, etc.),
            # or np arrays (e.g. precision-per-class). For both
            # cases, round each measure to one digit:

            res_rnd[meas_nm] = round(meas_val, 1) if isinstance(meas_val, float) \
                                                  else meas_val.round(1)

        ir_measures_skel = {
            'col_header': ['precision', 'recall', 'f1'],
            'row_labels': ['macro', 'micro', 'weighted'],
            'rows': [[
                res_rnd['prec_macro'], res_rnd['recall_macro'],
                res_rnd['f1_macro']
            ],
                     [
                         res_rnd['prec_micro'], res_rnd['recall_micro'],
                         res_rnd['f1_micro']
                     ],
                     [
                         res_rnd['prec_weighted'], res_rnd['recall_weighted'],
                         res_rnd['f1_weighted']
                     ]]
        }

        ir_per_class_rows = [[
            prec_class, recall_class, f1_class
        ] for prec_class, recall_class, f1_class in zip(
            res_rnd['prec_by_class'], res_rnd['recall_by_class'],
            res_rnd['f1_by_class'])]
        ir_by_class_skel = {
            'col_header': ['precision', 'recall', 'f1'],
            'row_labels': self.class_names,
            'rows': ir_per_class_rows
        }

        accuracy_skel = {
            'col_header': ['accuracy', 'balanced_accuracy'],
            'row_labels': ['Overall'],
            'rows': [[res_rnd['accuracy'], res_rnd['balanced_accuracy']]]
        }

        ir_measures_tbl = GithubTableMaker.make_table(ir_measures_skel,
                                                      sep_lines=False)
        ir_by_class_tbl = GithubTableMaker.make_table(ir_by_class_skel,
                                                      sep_lines=False)
        accuracy_tbl = GithubTableMaker.make_table(accuracy_skel,
                                                   sep_lines=False)

        # Write the markup tables to Tensorboard:
        self.writer.add_text('Information retrieval measures',
                             ir_measures_tbl,
                             global_step=0)
        self.writer.add_text('Per class measures',
                             ir_by_class_tbl,
                             global_step=0)
        self.writer.add_text('Accuracy', accuracy_tbl, global_step=0)

        return res_series

    #------------------------------------
    # close
    #-------------------

    def close(self):
        try:
            self.writer.close()
        except Exception as e:
            self.log.err(f"Could not close tensorboard writer: {repr(e)}")
Ejemplo n.º 3
class ModelArchive:
    '''
    Maintains a rotating, bounded archive of
    trained model snapshots on disk.
    '''

    #------------------------------------
    # Constructor 
    #-------------------

    def __init__(self, 
                 config, 
                 num_classes,
                 history_len=8,
                 model_root=None,
                 log=None):
        '''
        Constructor:
        
        :param config: configuration structure
        :type config: NeuralNetConfig
        :param num_classes: number of target classes
        :type num_classes: int
        :param history_len: number of model snapshots to 
            maintain
        :type history_len: int
        :param model_root: path to where models
            will be deposited
        :type model_root: str
        :param log: logging service to use. If
            None, create new one for display output
        :type log: LoggingService
        '''

        self.curr_dir = os.path.dirname(os.path.abspath(__file__))
        
        # Model root directory:
        if model_root is None:
            self.model_root = os.path.abspath(
                os.path.join(self.curr_dir, 
                             '../runs_models')
                )
        else:
            self.model_root = model_root

        if os.path.exists(self.model_root) and \
                not os.path.isdir(self.model_root):
            raise FileExistsError(f"{self.model_root} exists but is not a directory")

        # Ensure that intermediate dirs exist:
        try:
            os.makedirs(self.model_root)
        except FileExistsError:
            pass

        if log is None:
            self.log = LoggingService()
        else:
            self.log = log
            
        self.history_len = history_len

        # Create a subdirectory of model_root
        # where this archive keeps its models.
        # The subdir is guaranteed to be unique
        # among model_root's siblings, and it will
        # be created:
        
        self.run_subdir = self._construct_run_subdir(config, 
                                                    num_classes,
                                                    self.model_root)

        # Queue to track models, keeping the 
        # number of saved models to history_len:
        
        self.model_fnames = deque(maxlen=self.history_len)
        
    #------------------------------------
    # save_model 
    #-------------------
    
    def save_model(self, model, epoch):
        '''
        Saves and retains trained models
        on disk. 
        
        Within a subdir the method maintains a queue
        of files of len history_len: 
        
                 fname_1_ep_0.pth
                 fname_2_ep_1.pth
                      ...
                 fname_<history_len>_ep_<n>.pth
        
        where ep_<n> is the epoch during training
        where the model of that moment is being 
        saved.
        
        When history_len model files are already present, 
        removes the oldest.
        
        Assumptions: 
            o self.fname_els_dict contains prop/value
              pairs for use in FileUtils.construct_filename()
                 {'bs' : 32,
                  'lr' : 0.001,
                     ...
                 }
            o self model_fnames is a deque the size of
              which indicates how many models to save
              before discarding the oldest one as new
              ones are added
                 
        :param model: model to save
        :type model: nn.module
        :param epoch: the epoch that created the model
        :type epoch: int
        '''
        
        deque_len = len(self.model_fnames)
        if deque_len >= self.history_len:
            # Pushing a new model fname to the
            # front will pop the oldest from the
            # end. That file needs to be deleted:
            oldest_model_path = self.model_fnames[-1]
        else:
            # No file will need to be deleted.
            # Still filling our allotment:
            oldest_model_path = None
            
        model_fname = FileUtils.construct_filename(self.fname_els_dict,
                                                   prefix='mod', 
                                                   suffix=f"_ep{epoch}.pth", 
                                                   incl_date=True)
        
        model_path = os.path.join(self.run_subdir, model_fname)
        
        # As recommended by pytorch, save the
        # state_dict for portability:
        torch.save(model.state_dict(), model_path)

        self.model_fnames.appendleft(model_path)
        
        if oldest_model_path is not None:
            try:
                os.remove(oldest_model_path)
            except Exception as e:
                self.log.warn(f"Could not remove old model: {repr(e)}")


    #------------------------------------
    # restore_model 
    #-------------------
    
    def restore_model(self, model_path, config=None):
        '''
        Given the path to a saved model, 
        load and return it. The saved file
        is the saved model's state_dict. 
        So, the method must first create a
        model instance of the correct type.
        Then the state is loaded into that
        instance.
        
        :param model_path: path to the saved model's state_dict file
        :type model_path: str
        :param config: a config structure that will be
            use to decide which model class to instantiate.
            If None, attempts to reconstruct the 
            information from the model_path.
        :type config: NeuralNetConfig
        :return: loaded model
        :rtype: torch.nn.module
        '''
        
        if config is None:
            # No config given: recover the model
            # parameters from the file name itself:
            model = self._instantiate_model(run_path_str=model_path)
        else:
            model = self._instantiate_model(config=config)
         
        model.load_state_dict(torch.load(model_path))
        return model
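
    # Hedged usage sketch of the save/restore pair; the path argument
    # is invented for illustration:
    #
    #     archive = ModelArchive(config, num_classes=10)
    #     ... archive.save_model(model, epoch) during training ...
    #     model = archive.restore_model(saved_path, config=config)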

    #------------------------------------
    # _instantiate_model 
    #-------------------
    
    def _instantiate_model(self, run_path_str=None, config=None):
        '''
        Returns a model based on information in 
        the config structure, or the info encoded
        in the run_path_str file name. 
        
        One of run_path_str or config must be non-None.
        If both are non-None, uses config.
        
        File paths that encode run parameters look like
        this horror:
        
        model_2021-03-11T10_59_02_net_resnet18_pretrain_0_lr_0.01_opt_SGD_bs_64_ks_7_folds_0_gray_True_classes_10.pth 
        
        :param run_path_str: a path name associated with
            a model. 
        :type run_path_str: str
        :param config: run configuration structure 
        :type config: NeuralNetConfig
        :return: a model 
        :rtype: torch.nn.module
        '''
        if config is None:
            # Get a dict with info 
            # in a standard (horrible) file name:
            fname_props = FileUtils.parse_filename(run_path_str)
        else:
            fname_props = config.Training
            data_root   = config.Paths.root_train_test_data
            class_names = FileUtils.find_class_names(data_root)
            fname_props['classes'] = len(class_names)
            fname_props['pretrain'] = config.Training.getint('freeze', 0)
        
        model = NetUtils.get_net(net_name=fname_props['net_name'],
                                 num_classes=fname_props['classes'],
                                 freeze=fname_props['pretrain'],
                                 to_grayscale=fname_props['to_grayscale']
                                 )
        return model

# ---------------- Utils -------------

    #------------------------------------
    # _construct_run_subdir 
    #-------------------
    
    def _construct_run_subdir(self, 
                             config, 
                             num_classes, 
                             model_root):
        '''
        Constructs a directory name composed of
        elements specified in utility.py's 
        FileUtils file/config info dicts.
        
        Ensures that <model_root>/subdir_name does
        not exist. If it does, keeps adding '_r<n>'
        to the end of the dir name.
        
        Final str will look like this:
        
        model_2021-03-23T15_38_39_net_resnet18_pre_True_frz_6_bs_2_folds_5_opt_SGD_ks_7_lr_0.01_gray_False
            
        Details will depend on the passed in 
        configuration.

        Instance var fname_els_dict will contain 
        all run attr/values needed for calls to 
        FileUtils.construct_filename() 
        
        :param config: run configuration
        :type config: NeuralNetConfig
        :param num_classes: number of target classes 
        :type num_classes: int
        :param model_root: full path to dir where the
            subdir is to be created
        :type model_root: str
        :return: full path of the unique subdirectory
            under model_root, which has been created
        :rtype: str
        '''

        # Using config, gather run-property/value 
        # pairs to include in the dir name:
         
        fname_els_dict = {}
        
        section_dict   = config.Training 
        
        for el_name, el_abbr in FileUtils.fname_long_2_short.items():
            
            el_type = FileUtils.fname_el_types[el_abbr]
            
            if el_type == int:
                fname_els_dict[el_name] = section_dict.getint(el_name)
            elif el_type == str:
                fname_els_dict[el_name] = section_dict.get(el_name)
            elif el_type == float:
                fname_els_dict[el_name] = section_dict.getfloat(el_name)
            elif el_type == bool:
                fname_els_dict[el_name] = section_dict.getboolean(el_name)
            elif callable(el_type):
                # A lambda or func. Apply it:
                fname_els_dict[el_name] = el_type(section_dict[el_name])

        fname_els_dict['num_classes'] = num_classes

        # Save this root name:
        self.fname_els_dict = fname_els_dict

        # Get the subdir name (without leading path):
        dir_basename = FileUtils.construct_filename(
            fname_els_dict,
            prefix='models',
            suffix=None, 
            incl_date=True)
        
        final_dir_path = os.path.join(model_root, dir_basename)
        
        # Disambiguate by appending '_r<n>' as needed: 
        disambiguation = 1
        while os.path.exists(final_dir_path):
            new_basename = f"{dir_basename}_r{disambiguation}"
            final_dir_path = os.path.join(model_root, new_basename)
            disambiguation += 1

        os.makedirs(final_dir_path)
        
        return final_dir_path 
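
# A minimal standalone sketch of the rotation policy that save_model()
# implements with its bounded deque: once history_len snapshots exist,
# saving one more evicts (and deletes) the oldest. The file names here
# are invented for illustration; no files are actually written or removed:

from collections import deque

history_len = 3
model_fnames = deque(maxlen=history_len)

for epoch in range(5):
    oldest = model_fnames[-1] if len(model_fnames) >= history_len else None
    model_fnames.appendleft(f"mod_ep{epoch}.pth")  # maxlen drops the rightmost
    if oldest is not None:
        print(f"would os.remove({oldest})")

print(list(model_fnames))  # ['mod_ep4.pth', 'mod_ep3.pth', 'mod_ep2.pth']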
Ejemplo n.º 4
class CrossValidatingDataLoader(DataLoader):
    '''
    
    Subclass of torch.utils.data.DataLoader. Provides
    stratified k-fold crossvalidation in single-machine,
    (optionally) single-GPU context.
    
    Instantiate this class if running only on a
    single machine, optionally using a single GPU. Else,
    instantiate the MultiprocessingDataLoader subclass 
    instead.
    
    An instance of this class wraps any dict-API dataset instance, 
    which provides tuples, for instance (<img-tensor>, class-label-int), 
    from the file system when given a sample ID.
    
    This subclass of torch.utils.data.DataLoader specializes
    the default by using a stratified k-fold cross validation
    sampler. That underlying sampler manages partitioning of
    samples into folds, and successively feeding samples from
    the training folds. The sampler also manages the 'switching out'
    of folds to take the role of test fold in round robin fashion.
        
    This DataLoader instance also manages the combination of 
    samples into batches.
    
    An instance of this class presents an iterator API, additionally
    serving the test samples whenever one set of train folds are 
    exhausted. Example: assume 
          
          o k-fold cross validation k = 5
        
          for split in range(k):
          
              for batch in my_dataloader:
                  try:
                      <feed training batch to emerging model>
                  except EndOfSplit as e:
                      print(e.message) # Just for debugging
                      break
                  
              # Exhausted all train folds of one split
              # Now test current state of the 
              # model using this split's test samples,
              # which are available as an iterator from the
              # dataloader:
              
              for (img_tensor, label) in my_dataloader.validation_samples():
                  <test model on img_tensor>
         
              # next split
              
    The validation_samples() method is a generator that provides the content of 
    the just exhausted split's validation samples.
    
    NOTE: when re-setting an instance of this class
          for a new epoch, the client must call set_epoch()
          with the new epoch number to ensure proper
          shuffling randomness. Such a reset occurs implicitly
          with the often used idiom:
               
                for i, res in enumerate(dataloader):
        
          The enumerate() starts the same dataloader instance
          from the beginning. 
          
          If shuffle is False, set_epoch() need not be called.
          But doing so does no harm.
          
    '''

    #------------------------------------
    # Constructor
    #-------------------

    def __init__(self,
                 dataset,
                 batch_size=32,
                 shuffle=True,
                 seed=42,
                 num_workers=0,
                 pin_memory=False,
                 prefetch_factor=2,
                 drop_last=True,
                 num_folds=10,
                 sampler=None,
                 logger=None):
        '''
        This instance will use cross validation
        as it serves out samples. The client determines
        the number of folds to use. Example for 
        num_folds of 2:
        
         Split1:
           TrainFold1    TrainFold2   ValidationFold  
            sample1      sample2        sample3
            sample4      sample5        sample6

         Split2:
           TrainFold1    TrainFold2   ValidationFold  
            sample3      sample4        sample2
            sample1      sample6        sample5
            
        This dataloader will create two sequences,
        like this:
        
           For use with training:   [sample1, sample4, sample2, sample5]
           For use with validation: [sample3, sample6]
             after the training 
             sequence is used up

        Assuming batch_size of two, this dataloader's
        client will receive one row from each 
        call to next():
        
            [[sample1, sample4],
             [sample2, sample5]
             ]
             
        Once a split has been exhausted, an EndOfSplit
        exception is raised, and it is time to validate.
        The client then calls validation_samples() on
        this dataloader instance to receive one validation
        sample at a time. The client will predict the
        (target) class for each of these validation samples,
        and tally successes and failures. The client should
        then compute the validation accuracy from
        that series of successes and failures. 

        Calling next() again will create a new split,
        and again feed out the samples in the respective
        new folds.
        
        The feed terminates after as many splits as there
        are folds. Any following call to next() will raise
        a StopIteration exception.

        :param dataset: underlying map-store that 
                supplies(img_torch, label) tuples
        :type dataset: BirdDataset
        :param batch_size: number of samples to combine into 
            a batch to feed model during training
        :type batch_size: int
        :param shuffle: whether or not to shuffle the
            dataset once, initially.
        :type shuffle: bool
        :param seed: random seed to use if shuffle is True
        :type seed: int
        :param num_workers: number of threads used to preload
        :type num_workers: int
        :param pin_memory: set to True if using a GPU. Speeds
            transfer of tensors from CPU to GPU
        :type pin_memory: bool
        :param prefetch_factor: how many samples to prefetch from
            underlying database to speed access to file system
        :type prefetch_factor: int
        :param drop_last: whether or not to serve only partially 
            filled batches. Those occur when samples cannot be
            evenly packed into batches. 
        :type drop_last: bool
        :param num_folds: the 'k' in k-fold cross validation
        :type num_folds: int
        :param sampler: Only used when MultiprocessingDataLoader
            is being instantiated, and that class's __init__()
            calls super(). Leave out for singleprocess/single-GPU
            use
        :type sampler: {None | DistributedSKFSampler}
        :param logger: the LoggingService instance to use
            for logging info/warnings/errors. If None, fetches
            the LoggingService singleton.
        :type logger: LoggingService
        '''

        if len(dataset) == 0:
            raise ValueError("Dataset is empty, nothing to load")

        self.drop_last = drop_last

        if logger is None:
            self.log = LoggingService()
        else:
            self.log = logger

        # Sampler will only be set if a subclass instance
        # of MultiprocessingDataLoader is being initialized.
        # Else, running single process:

        if sampler is None:
            self.sampler = SKFSampler(dataset,
                                      num_folds=num_folds,
                                      shuffle=shuffle,
                                      drop_last=drop_last,
                                      seed=seed)
        else:
            self.sampler = sampler

        if not isinstance(batch_size, int) or batch_size <= 0:
            msg = f"Batch size must be a positive int, not "

            # Complete the error msg according which of
            # the two failure conditions occurred:
            msg += type(batch_size).__name__ if not isinstance(batch_size, int)\
                                           else f"{batch_size}"

            raise ValueError(msg)

        self.batch_size = batch_size

        self.num_folds = num_folds

        # Total num of batches served when
        # rotating through all folds is computed
        # the first time __len__() is called:

        self.num_batches = None
        self.curr_split_idx = -1

        super().__init__(dataset,
                         batch_size=batch_size,
                         sampler=self.sampler,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         prefetch_factor=prefetch_factor,
                         drop_last=drop_last)

    #------------------------------------
    # __len__
    #-------------------

    def __len__(self):
        '''
        Number of batches this loader will
        feed out. Example:
            o 12 samples total
            o  3 folds
            o  2 batch size
            o  4 samples in each fold (12/3)
            o  8 training samples per split
                 (samples-per-fold * (num-folds - 1))
            o  4 batches per split (train-samples-per-split / batch-size)
            o  3 splits, one per held-out fold
            o 12 batches total: batches-per-split * num-splits
                   4*3 = 12
        '''

        # Compute number of batches only once,
        # and cache the result in self.num_batches:
        if self.num_batches is None:

            # This computation can surely be more
            # concise and direct. But it happens only
            # once, and this step by step is easier
            # on the eyes than one minimal expression:
            num_samples = len(self.sampler)

            if num_samples == 0:
                raise ValueError("No samples to serve.")

            # Rounded-down number of samples that fit into each fold.
            # Having 34 samples with 3 folds, that is 34/3 == ~11

            samples_per_fold = num_samples // self.num_folds

            # For training we get 2 folds worth of samples,
            # with one fold held out: 11*2 = 22

            samples_per_split = samples_per_fold * (self.num_folds - 1)

            # As many splits as there are folds: 3 * 22 = 66

            total_train_samples = self.num_folds * samples_per_split

            # Convert to batches. Assume batch_size of 2:
            # 66 // 2 = 33

            self.num_batches = total_train_samples // self.batch_size
            if self.num_batches == 0:
                self.log.warn(
                    f"Not enough data ({total_train_samples}) for even one batch (of size {self.batch_size})"
                )

            remainder_samples = total_train_samples % self.batch_size
            if not self.drop_last and remainder_samples > 0:
                # Include the final, partially filled batch
                # holding the samples that did not fill a
                # whole batch:
                self.num_batches += 1

        return self.num_batches

    #------------------------------------
    # __iter__
    #-------------------

    def __iter__(self):
        # __next__() below is a generator function
        # (it contains yield), so calling it returns
        # a generator, which does the right thing
        # with next(), list(), and for loops.
        # Return that generator:

        return self.__next__()

    #------------------------------------
    # __next__
    #-------------------

    def __next__(self):

        # Loop over all splits (i.e. over all
        # configurations of which fold is for
        # validation.

        # Get one list of sample IDs that
        # covers all train samples in one split.
        # And one list of sample IDs that
        # are to be used for validation in this
        # split.

        # Raise EndOfSplit exception at the end of
        # each split, i.e. when client is to validate.
        # When all splits are exhausted, raise StopIteration.

        for split_train_idxs, split_test_idxs in self.sampler:

            # Keep track of which split we are working
            # on. Needed only as info for client; not
            # used for logic in this method:

            self.curr_split_idx += 1

            # split_train_idxs has all sample IDs
            # to use for training in this split.
            # The split_test_idxs holds the left-out
            # sample IDs to use for testing once
            # the split_train_idxs have been served out
            # one batch at a time.

            # Set this split's test sample ids aside for client
            # to retrieve later via: get_split_test_sample_ids()
            # once they pulled all the batches of this
            # split:
            self.curr_test_sample_ids = []
            for sample_idx in split_test_idxs:
                self.curr_test_sample_ids.append(
                    self.dataset.sample_id_by_sample_idx(sample_idx))

            # Create one batch:

            num_train_sample_ids = len(split_train_idxs)
            num_batches = num_train_sample_ids // self.batch_size
            num_remainder_samples = num_train_sample_ids % self.batch_size
            batch_start_idx = 0

            # Create num_batches batches from the
            # training data of this split:

            for _batch_count in range(num_batches):

                batch = None
                # Truth labels for each sample in
                # the current batch:
                y = []
                batch_end_idx = batch_start_idx + self.batch_size
                curr_batch_range = range(batch_start_idx, batch_end_idx)

                for train_sample_idx in curr_batch_range:

                    # Index into the current split's list
                    # of training sample ids:
                    sample_idx = split_train_idxs[train_sample_idx]
                    # Get one pair: <img-tensor>, class_id_int:
                    (img_tensor,
                     label) = self.dataset.sample_by_idx(sample_idx)
                    expanded_img_tensor = unsqueeze(img_tensor, dim=0)
                    batch = (cat((batch, expanded_img_tensor), dim=0)
                             if batch is not None else expanded_img_tensor)
                    y.append(label)

                # Got one batch ready:
                yield (batch, torch.tensor(y))

                # Client consumed one batch in current split.
                # Next batch: Starts another batch size
                # samples onwards in the train split:

                batch_start_idx += self.batch_size

                # Put together next batch:
                continue

            # Done all full batches. Any partial batch
            # left over that we should include?

            if num_remainder_samples > 0 and not self.drop_last:
                batch = None
                y = []

                # Mirror the full-batch loop above: the running
                # index points into this split's train sample
                # list, not directly into the dataset:
                for train_sample_idx in range(
                        batch_start_idx,
                        batch_start_idx + num_remainder_samples):
                    sample_idx = split_train_idxs[train_sample_idx]
                    (img_tensor,
                     label) = self.dataset.sample_by_idx(sample_idx)
                    expanded_img_tensor = unsqueeze(img_tensor, dim=0)
                    batch = (cat((batch, expanded_img_tensor))
                             if batch is not None else expanded_img_tensor)
                    y.append(label)
                yield (batch, torch.tensor(y))

            # Let client know that all batches for one split
            # have been delivered, by raising EndOfSplit:

            raise EndOfSplit()

    #------------------------------------
    # get_curr_fold_idx
    #-------------------

    def get_curr_fold_idx(self):
        return self.curr_split_idx

    #------------------------------------
    # get_split_test_sample_ids
    #-------------------

    def get_split_test_sample_ids(self):
        try:
            return self.curr_test_sample_ids
        except AttributeError:
            # No split has been served out yet:
            return None

    #------------------------------------
    # validation_samples
    #-------------------

    def validation_samples(self):
        '''
        Generator that runs through every
        test sample_id of the current fold, 
        and feeds out (<img_tensor>, label) pairs.
        
           for (img_tensor, label) in my_bird_dataloader.validation_samples():
               <test model>
        '''

        for sample_id in self.get_split_test_sample_ids():
            yield self.dataset[sample_id]

    #------------------------------------
    # file_from_sample_id
    #-------------------

    def file_from_sample_id(self, sample_id):
        '''
        Given a sample_id, return the absolute
        file path of the corresponding sample
        in the file system.
        
        We use the public dataset method.
        
        :param sample_id: sample ID to look up
        :type sample_id: int
        '''
        return self.dataset.file_from_sample_id(sample_id)

    #------------------------------------
    # class_from_sample_id
    #-------------------

    def class_from_sample_id(self, sample_id):
        '''
        Given a sample ID, return its class index.
        
        :param sample_id: ID to look up
        :type sample_id: int
        :return: given sample's class ID
        :rtype: int
        '''
        return self.dataset.sample_id_to_class[sample_id]

    #------------------------------------
    # set_epoch
    #-------------------

    def set_epoch(self, new_epoch):
        '''
        Must be called by client every time
        a new epoch starts. The epoch number
        is used by the sampler to shuffle
        the dataset before beginning to draw
        samples.

        :param new_epoch: the epoch under which the dataloader
            is (re)started
        :type new_epoch: int
        '''
        self.sampler.set_epoch(new_epoch)
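
# Hedged usage sketch of the protocol this class documents. BirdDataset's
# constructor argument and the train/validate helpers are assumptions made
# up for illustration; only the loader calls mirror the API defined above:
#
#     dataset = BirdDataset('path/to/samples')
#     loader  = CrossValidatingDataLoader(dataset,
#                                         batch_size=32,
#                                         num_folds=5)
#     for epoch in range(num_epochs):
#         # Keep shuffling epoch-dependent:
#         loader.set_epoch(epoch)
#         for _split in range(loader.num_folds):
#             try:
#                 for batch, y in loader:
#                     train_one_batch(model, batch, y)
#             except EndOfSplit:
#                 # Train folds of this split are exhausted;
#                 # validate on the held-out fold:
#                 for img_tensor, label in loader.validation_samples():
#                     validate_one_sample(model, img_tensor, label)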