Example #1
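    # NOTE: assumes module-level `import time` and `import datetime`,
    # plus the MetricMeter / AverageMeter helpers sketched after this example.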
    def train(self, print_freq=10, fixbase_epoch=0, open_layers=None):
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        accuracy = AverageMeter()

        self.set_model_mode('train')

        self.two_stepped_transfer_learning(
            self.epoch, fixbase_epoch, open_layers
        )

        self.num_batches = len(self.train_loader)
        end = time.time()
        for self.batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)
            loss_summary, avg_acc = self.forward_backward(data)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)
            accuracy.update(avg_acc)

            if (self.batch_idx + 1) % print_freq == 0:
                nb_this_epoch = self.num_batches - (self.batch_idx + 1)
                nb_future_epochs = (self.max_epoch - (self.epoch + 1)) * self.num_batches
                eta_seconds = batch_time.avg * (nb_this_epoch + nb_future_epochs)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'epoch: [{0}/{1}][{2}/{3}]\t'
                    'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'cls acc {accuracy.val:.3f} ({accuracy.avg:.3f})\t'
                    'eta {eta}\t'
                    '{losses}\t'
                    'lr {lr:.6f}'.format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        accuracy=accuracy,
                        eta=eta_str,
                        losses=losses,
                        lr=self.get_current_lr()
                    )
                )

            if self.writer is not None:
                n_iter = self.epoch * self.num_batches + self.batch_idx
                self.writer.add_scalar('Train/time', batch_time.avg, n_iter)
                self.writer.add_scalar('Train/data', data_time.avg, n_iter)
                self.writer.add_scalar('Aux/lr', self.get_current_lr(), n_iter)
                self.writer.add_scalar('Accuracy/train', accuracy.avg, n_iter)
                for name, meter in losses.meters.items():
                    self.writer.add_scalar('Loss/' + name, meter.avg, n_iter)

            end = time.time()

        self.update_lr()
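
Both examples depend on `MetricMeter` and `AverageMeter` helpers that are not shown. A minimal sketch of the semantics the loops above rely on (the real torchreid-style implementations may differ in detail; this version is an assumption):

from collections import defaultdict

import torch


class AverageMeter:
    # Tracks the most recent value (`val`) and the running average (`avg`).
    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MetricMeter:
    # Keeps one AverageMeter per named loss, updated from a dict such as
    # the `loss_summary` returned by `forward_backward`.
    def __init__(self):
        self.meters = defaultdict(AverageMeter)

    def update(self, input_dict):
        for name, value in input_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.item()  # reduce scalar tensors to plain floats
            self.meters[name].update(value)

    def __str__(self):
        # Rendered by the '{losses}' field of the progress line above.
        return ' '.join(
            '{} {:.4f} ({:.4f})'.format(name, m.val, m.avg)
            for name, m in self.meters.items()
        )
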
Example #2
    def train(self,
              print_freq=10,
              fixbase_epoch=0,
              open_layers=None,
              lr_finder=False,
              perf_monitor=None,
              stop_callback=None):
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        accuracy = AverageMeter()

        self.set_model_mode('train')

        if not self._should_freeze_aux_models(self.epoch):
            # NB: this must run before `two_stepped_transfer_learning`
            # so that some layers can still be frozen in the unlikely event
            # that `two_stepped_transfer_learning` is used together with NNCF.
            self._unfreeze_aux_models()

        self.two_stepped_transfer_learning(self.epoch, fixbase_epoch,
                                           open_layers)

        if self._should_freeze_aux_models(self.epoch):
            self._freeze_aux_models()

        self.num_batches = len(self.train_loader)
        end = time.time()
        for self.batch_idx, data in enumerate(self.train_loader):
            if perf_monitor and not lr_finder:
                perf_monitor.on_train_batch_begin(self.batch_idx)

            data_time.update(time.time() - end)

            if self.compression_ctrl:
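                # Advance NNCF's per-batch compression scheduler before the
                # forward/backward pass.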
                self.compression_ctrl.scheduler.step(self.batch_idx)

            loss_summary, avg_acc = self.forward_backward(data)
            batch_time.update(time.time() - end)

            losses.update(loss_summary)
            accuracy.update(avg_acc)
            if perf_monitor and not lr_finder:
                perf_monitor.on_train_batch_end(self.batch_idx)

            if not lr_finder and (((self.batch_idx + 1) % print_freq) == 0
                                  or self.batch_idx == self.num_batches - 1):
                nb_this_epoch = self.num_batches - (self.batch_idx + 1)
                nb_future_epochs = (self.max_epoch -
                                    (self.epoch + 1)) * self.num_batches
                eta_seconds = batch_time.avg * (nb_this_epoch +
                                                nb_future_epochs)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('epoch: [{0}/{1}][{2}/{3}]\t'
                      'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'cls acc {accuracy.val:.3f} ({accuracy.avg:.3f})\t'
                      'eta {eta}\t'
                      '{losses}\t'
                      'lr {lr:.6f}'.format(self.epoch + 1,
                                           self.max_epoch,
                                           self.batch_idx + 1,
                                           self.num_batches,
                                           batch_time=batch_time,
                                           data_time=data_time,
                                           accuracy=accuracy,
                                           eta=eta_str,
                                           losses=losses,
                                           lr=self.get_current_lr()))

            if self.writer is not None and not lr_finder:
                n_iter = self.epoch * self.num_batches + self.batch_idx
                self.writer.add_scalar('Train/time', batch_time.avg, n_iter)
                self.writer.add_scalar('Train/data', data_time.avg, n_iter)
                self.writer.add_scalar('Aux/lr', self.get_current_lr(), n_iter)
                self.writer.add_scalar('Accuracy/train', accuracy.avg, n_iter)
                for name, meter in losses.meters.items():
                    self.writer.add_scalar('Loss/' + name, meter.avg, n_iter)

            end = time.time()
            self.current_lr = self.get_current_lr()
            if stop_callback and stop_callback.check_stop():
                break
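            # Keep the exponential-moving-average copy of the main model in
            # sync after each step (skipped during LR search).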
            if not lr_finder and self.use_ema_decay:
                self.ema_model.update(self.models[self.main_model_name])
            if self.per_batch_annealing:
                self.update_lr()

        return losses.meters['loss'].avg
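
Both variants are intended to be called once per epoch from an outer run loop. Example #1 steps the LR scheduler itself at the end of the epoch, while Example #2 only steps it per batch when `per_batch_annealing` is set and instead returns the epoch's average loss. A hedged sketch of such a driver loop (the `engine` object and its attributes are assumptions, not a documented API):

for epoch in range(engine.max_epoch):
    engine.epoch = epoch
    avg_loss = engine.train(print_freq=20)  # Example #2 returns the mean 'loss' meter
    if not engine.per_batch_annealing:
        engine.update_lr()  # epoch-level LR step, mirroring Example #1
    print(f'epoch {epoch + 1}: avg loss {avg_loss:.4f}')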