Example no. 1
0
    def train_one_epoch(self, epoch, verbose=True):
        """Run a single training epoch and return (top1, top5, loss).

        Loads the datasets if needed, trains under the "train" activation
        collectors, and logs sparsity/activation statistics. When *verbose*
        is true, per-layer weight sparsity is also logged.
        """
        self.load_datasets()

        loggers = [self.tflogger, self.pylogger]
        with collectors_context(
                self.activations_collectors["train"]) as stats_collectors:
            # Delegate the actual optimization loop to the module-level
            # train() helper.
            top1, top5, train_loss = train(self.train_loader,
                                           self.model,
                                           self.criterion,
                                           self.optimizer,
                                           epoch,
                                           self.compression_scheduler,
                                           loggers=loggers,
                                           args=self.args)
            if verbose:
                distiller.log_weights_sparsity(self.model, epoch, loggers)
            distiller.log_activation_statistics(
                epoch,
                "train",
                loggers=[self.tflogger],
                collector=stats_collectors["sparsity"])
            if self.args.masks_sparsity:
                # Optionally summarize the sparsity of the pruning masks.
                msglogger.info(
                    distiller.masks_sparsity_tbl_summary(
                        self.model, self.compression_scheduler))
        return top1, top5, train_loss
Example no. 2
0
    def validate_one_epoch(self, epoch, verbose=True):
        """Evaluate on the validation set and return (top1, top5, loss).

        Runs validation under the "valid" activation collectors, logs
        activation statistics, and persists collector data. When *verbose*
        is true, summary metrics are also logged via the TF logger.
        """
        self.load_datasets()
        with collectors_context(
                self.activations_collectors["valid"]) as stats_collectors:
            top1, top5, val_loss = validate(self.val_loader, self.model,
                                            self.criterion, [self.pylogger],
                                            self.args, epoch)
            distiller.log_activation_statistics(
                epoch,
                "valid",
                loggers=[self.tflogger],
                collector=stats_collectors["sparsity"])
            # Persist collected activation stats to the log directory.
            save_collectors_data(stats_collectors, msglogger.logdir)

        if verbose:
            metrics = OrderedDict([('Loss', val_loss),
                                   ('Top1', top1),
                                   ('Top5', top5)])
            stats = ('Performance/Validation/', metrics)
            distiller.log_training_progress(stats,
                                            None,
                                            epoch,
                                            steps_completed=0,
                                            total_steps=1,
                                            log_freq=1,
                                            loggers=[self.tflogger])
        return top1, top5, val_loss
Example no. 3
0
def test(test_loader, model, criterion, loggers, activations_collectors, args):
    """Evaluate *model* on the test set and return (top1, top5, losses).

    Runs validation under the "test" activation collectors, logs
    activation statistics (with epoch fixed at -1 to mark a standalone
    test run), and persists collector data to the log directory.

    If *activations_collectors* is None, a default set of collectors is
    created for the model.
    """
    msglogger.info('--- test ---------------------')
    if activations_collectors is None:
        activations_collectors = create_activation_stats_collectors(model, None)
    with collectors_context(activations_collectors["test"]) as collectors:
        # Fix: renamed misspelled local 'lossses' -> 'losses' for
        # consistency with sibling blocks ('loss', 'vloss').
        top1, top5, losses = _validate(test_loader, model, criterion, loggers, args)
        distiller.log_activation_statistics(-1, "test", loggers, collector=collectors['sparsity'])
        save_collectors_data(collectors, msglogger.logdir)
    return top1, top5, losses