Example #1
0
    def train(self, num_epochs, start_epoch=0):
        """
        Run the training loop.

        Per epoch: logs metrics through ``self.logger``, prints progress
        (including an ETA) unless the 'quiet' state flag is set, and stores
        a snapshot of the current model.

        :param num_epochs:
            number of epochs to train
        :param start_epoch:
            the first epoch, default to 0.
            Can be set higher for finetuning, etc.
        :return: the trainer's state dict
        """
        try:
            eta = ETATimer(num_epochs - start_epoch)
            for epoch in range(start_epoch + 1, num_epochs + 1):
                self.state['epoch'] = epoch
                if not self.state['quiet']:
                    print('Epoch', epoch)
                self.print_info(epoch)
                self.logger.info(epoch=epoch, phase='train',
                                 metrics=self.train_epoch())
                # validate only when a validation iterator is configured
                if self.state['val_iter'] is not None:
                    self.logger.info(epoch=epoch, phase='val',
                                     metrics=self.validate_epoch(epoch))
                self.snapshot(epoch)
                if not self.state['quiet']:
                    print('ETA:', eta())
            return self.state
        except Exception as err:
            logging.exception(err)
            raise
        finally:
            self.logger.info(msg='Finished!')
Example #2
0
    def print_info(self, epoch):
        """
        logs the current learning rates for the given epoch and, unless the
        'quiet' state flag is set, also prints them to the console.

        :param epoch: the current epoch.
        """
        # query the learning rates once; used for both printing and logging
        lrs = self._lrs()
        if not self.state['quiet']:
            s = 'learning rates ' + (', '.join(map(str, lrs)))
            print(s)
        # log unconditionally: quiet mode should only silence console output,
        # consistent with the unconditional logger.info calls in train()
        self.logger.info(epoch=epoch, lrs=lrs)
    def train(self, num_epochs, start_epoch=0):
        """
        starts the training, logs loss and metrics in logging file and prints progress
        in the console, including an ETA. Also stores snapshots of current model each epoch.
        On exit (normal or exceptional), saves plots of the training loss and
        MSE metric below the run's plot directory.

        :param num_epochs:
            number of epochs to train
        :param start_epoch:
            the first epoch, default to 0.
            Can be set higher for finetuning, etc.
        """
        try:
            rem = ETATimer(num_epochs - start_epoch)
            for epoch in range(start_epoch + 1, num_epochs + 1):
                self.state['epoch'] = epoch
                if not self.state['quiet']:
                    print('Epoch', epoch)
                self.print_info(epoch)
                train_metrics = self.train_epoch()
                self.logger.info(epoch=epoch,
                                 phase='train',
                                 metrics=train_metrics)
                if self.state['val_iter'] is not None:
                    val_metrics = self.validate_epoch(epoch)
                    self.logger.info(epoch=epoch,
                                     phase='val',
                                     metrics=val_metrics)
                self.snapshot(epoch)
                if not self.state['quiet']:
                    print('ETA:', rem())
            return self.state
        except Exception as e:
            logging.exception(e)
            raise
        finally:
            # save loss plot
            path = pt.join(ROOT, '../res/plots/',
                           pt.basename(self.state['outdir']) + '/')
            if not pt.exists(path):
                os.makedirs(path)
            losses = [k['loss'] for k in self.state['train_metric_values']]
            fig = plt.figure()
            plt.plot(losses)
            # legend() needs a sequence of labels; a bare string would be
            # iterated character by character ('l', 'o', 's', 's')
            plt.legend(['loss'])
            plt.savefig(path + 'loss.png')
            plt.close(fig)

            values = [k['mse'] for k in self.state['train_metric_values']]

            fig = plt.figure()
            plt.plot(values)
            # ('MSE metric') is just a parenthesized string (no comma),
            # so it had the same one-character-labels bug as above
            plt.legend(['MSE metric'])
            plt.savefig(path + 'AE_tiles_metric.png')
            plt.close(fig)
            self.logger.info(msg='Finished!')
Example #4
0
    def train(self, num_epochs, start_epoch=0):
        """
        starts the training, logs loss and metrics in logging file and prints progress
        in the console, including an ETA. Also stores snapshots of current model each epoch.
        On exit (normal or exceptional), saves a plot of the three training
        losses to the plot directory.

        :param num_epochs:
            number of epochs to train
        :param start_epoch:
            the first epoch, default to 0.
            Can be set higher for finetuning, etc.
        """
        try:
            rem = ETATimer(num_epochs - start_epoch)
            for epoch in range(start_epoch + 1, num_epochs + 1):
                self.state['epoch'] = epoch
                if not self.state['quiet']:
                    print('Epoch', epoch)
                self.print_info(epoch)
                train_metrics = self.train_epoch()
                self.logger.info(epoch=epoch,
                                 phase='train',
                                 metrics=train_metrics)
                if self.state['val_iter'] is not None:
                    val_metrics = self.validate_epoch(epoch)
                    self.logger.info(epoch=epoch,
                                     phase='val',
                                     metrics=val_metrics)
                self.snapshot(epoch)
                if not self.state['quiet']:
                    print('ETA:', rem())
            return self.state
        except Exception as e:
            logging.exception(e)
            raise
        finally:
            # save loss plot
            path = pt.join(ROOT, '../res/plots/')
            # create the output directory if needed; savefig does not
            if not pt.exists(path):
                os.makedirs(path)
            fig = plt.figure()
            losses = [
                k['total_loss'] for k in self.state['train_metric_values']
            ]
            ae_loss = [k['AE_loss'] for k in self.state['train_metric_values']]
            cl_loss = [
                k['classifier_loss'] for k in self.state['train_metric_values']
            ]
            # plot each curve with its own call: passing the tuple to a
            # single plot() would be treated as one (3, N)-shaped array and
            # draw N three-point lines instead of three loss curves
            plt.plot(losses)
            plt.plot(ae_loss)
            plt.plot(cl_loss)
            plt.legend(('loss', 'AE_loss', 'classifier_loss'))
            plt.savefig(path + 'jigsaw_loss.png')
            # close the figure so repeated runs don't accumulate open figures
            plt.close(fig)
            self.logger.info(msg='Finished!')
Example #5
0
    def train(self, num_epochs, start_epoch=0):
        """
        starts the training, logs loss and metrics in logging file and prints progress
        in the console, including an ETA. Also stores snapshots of current model each epoch.
        On exit (normal or exceptional), saves plots of the train/val loss and
        top-1 accuracy curves to the plot directory.

        :param num_epochs:
            number of epochs to train
        :param start_epoch:
            the first epoch, default to 0.
            Can be set higher for finetuning, etc.
        """
        try:
            rem = ETATimer(num_epochs - start_epoch)
            for epoch in range(start_epoch + 1, num_epochs + 1):
                self.state['epoch'] = epoch
                if not self.state['quiet']:
                    print('Epoch', epoch)
                self.print_info(epoch)
                train_metrics = self.train_epoch()
                self.logger.info(epoch=epoch,
                                 phase='train',
                                 metrics=train_metrics)
                if self.state['val_iter'] is not None:
                    val_metrics = self.validate_epoch(epoch)
                    self.logger.info(epoch=epoch,
                                     phase='val',
                                     metrics=val_metrics)
                self.snapshot(epoch)
                if not self.state['quiet']:
                    print('ETA:', rem())
            return self.state
        except Exception as e:
            logging.exception(e)
            raise
        finally:
            # here visualize the loss
            path = pt.join(ROOT, '../res/plots/')
            # make sure the plot directory exists, otherwise savefig fails;
            # consistent with the directory guard used by the other trainers
            if not pt.exists(path):
                os.makedirs(path)

            loss_train = [k['loss'] for k in self.state['train_metric_values']]
            loss_val = [k['loss'] for k in self.state['val_metric_values']]
            fig = plt.figure()
            plt.plot(loss_train)
            plt.plot(loss_val)
            plt.title('Imagenet Training')
            plt.xlabel('epochs')
            plt.ylabel('CrossEntropyLoss')
            plt.legend(('train_loss', 'val_loss'), loc='upper right')
            plt.savefig(path + 'loss.png')
            plt.close(fig)

            acc_train = [
                k['top-1 acc'] for k in self.state['train_metric_values']
            ]
            acc_val = [k['top-1 acc'] for k in self.state['val_metric_values']]
            fig = plt.figure()
            plt.plot(acc_train)
            plt.plot(acc_val)
            plt.title('Imagenet Training')
            plt.xlabel('epochs')
            plt.ylabel('top-1 acc')
            plt.legend(('train_acc', 'val_acc'), loc='upper left')
            plt.savefig(path + 'metric.png')
            plt.close(fig)

            self.logger.info(msg='Finished!')
Example #6
0
    def train(self, num_epochs, start_epoch=0):
        """
        Run the training loop.

        Per epoch: logs metrics through ``self.logger``, prints progress
        (including an ETA) unless the 'quiet' state flag is set, and stores
        a snapshot of the current model. On exit (normal or exceptional),
        saves one plot per tracked training quantity into the run's plot
        directory.

        :param num_epochs:
            number of epochs to train
        :param start_epoch:
            the first epoch, default to 0.
            Can be set higher for finetuning, etc.
        :return: the trainer's state dict
        """
        try:
            eta = ETATimer(num_epochs - start_epoch)
            for epoch in range(start_epoch + 1, num_epochs + 1):
                self.state['epoch'] = epoch
                if not self.state['quiet']:
                    print('Epoch', epoch)
                self.print_info(epoch)
                self.logger.info(epoch=epoch,
                                 phase='train',
                                 metrics=self.train_epoch())
                if self.state['val_iter'] is not None:
                    self.logger.info(epoch=epoch,
                                     phase='val',
                                     metrics=self.validate_epoch(epoch))
                self.snapshot(epoch)
                if not self.state['quiet']:
                    print('ETA:', eta())
            return self.state
        except Exception as err:
            logging.exception(err)
            raise
        finally:
            # save one plot per tracked quantity
            path = pt.join(ROOT, '../../res/plots/',
                           pt.basename(self.state['outdir']) + '/')
            if not pt.exists(path):
                os.makedirs(path)

            history = self.state['train_metric_values']
            # plot label -> key under which the value is stored per epoch
            keys = {
                'total_loss': 'loss',
                'ae_loss': 'ae_loss',
                'cl_loss': 'cl_loss',
                'accuracy': 'top-1 acc',
                'sigma_cl': 'sigma_cl',
                'sigma_ae': 'sigma_ae',
            }
            # extract every series up front (mirrors the eager extraction
            # before any figure is created), then draw them one by one
            curves = [(label, [m[key] for m in history])
                      for label, key in keys.items()]
            for label, series in curves:
                fig = plt.figure()
                plt.plot(series)
                plt.xlabel('epochs')
                plt.ylabel(label)
                plt.savefig(path + '{}.png'.format(label))
                plt.close(fig)

            self.logger.info(msg='Finished!')
    def train(self, num_epochs, start_epoch=0):
        """
        Run the training loop.

        Per epoch: logs metrics through ``self.logger``, prints progress
        (including an ETA) unless the 'quiet' state flag is set, and stores
        a snapshot of the current model. On exit (normal or exceptional),
        saves the loss and accuracy curves as individual plots in the run's
        plot directory.

        :param num_epochs:
            number of epochs to train
        :param start_epoch:
            the first epoch, default to 0.
            Can be set higher for finetuning, etc.
        :return: the trainer's state dict
        """
        try:
            eta = ETATimer(num_epochs - start_epoch)
            for epoch in range(start_epoch + 1, num_epochs + 1):
                self.state['epoch'] = epoch
                if not self.state['quiet']:
                    print('Epoch', epoch)
                self.print_info(epoch)
                self.logger.info(epoch=epoch,
                                 phase='train',
                                 metrics=self.train_epoch())
                if self.state['val_iter'] is not None:
                    self.logger.info(epoch=epoch,
                                     phase='val',
                                     metrics=self.validate_epoch(epoch))
                self.snapshot(epoch)
                if not self.state['quiet']:
                    print('ETA:', eta())
            return self.state
        except Exception as err:
            logging.exception(err)
            raise
        finally:
            # save loss plot
            path = pt.join(ROOT, '../res/plots/',
                           pt.basename(self.state['outdir']) + '/')
            if not pt.exists(path):
                os.makedirs(path)

            history = self.state['train_metric_values']
            # (metric key, legend label, output file) for each saved curve
            plots = (
                ('total_loss', 'total_loss', 'total_loss.png'),
                ('AE_loss', 'ae_loss', 'ae_loss.png'),
                ('classifier_loss', 'classifier_loss', 'cl_loss.png'),
                ('top-1 acc', 'metric', 'accuracy.png'),
            )
            for key, label, filename in plots:
                series = [m[key] for m in history]
                fig = plt.figure()
                plt.plot(series, label=label)
                plt.legend(loc='upper right')
                plt.savefig(path + filename)
                plt.close(fig)

            self.logger.info(msg='Finished!')