コード例 #1
0
ファイル: shot.py プロジェクト: zhaoxin94/Dassl.pytorch
    def run_epoch(self):
        self.set_model_mode('train')
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        # Decide to iterate over labeled or unlabeled dataset
        len_train_loader_u = len(self.train_loader_u)
        self.num_batches = len_train_loader_u
        train_loader_u_iter = iter(self.train_loader_u)

        # self.fix_model.load_state_dict(self.model.state_dict())
        # self.fix_model.eval()
        self.netB = copy.deepcopy(self.model.backbone)
        self.netB.eval()

        if self.model.head is not None:
            self.netH = copy.deepcopy(self.model.head)
            self.netH.eval()

        self.center = self.obtain_center()

        end = time.time()
        for self.batch_idx in range(self.num_batches):
            batch_u = next(train_loader_u_iter)
            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch_u)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (self.batch_idx + 1) % self.cfg.TRAIN.PRINT_FREQ == 0:
                nb_this_epoch = self.num_batches - (self.batch_idx + 1)
                nb_future_epochs = (self.max_epoch -
                                    (self.epoch + 1)) * self.num_batches
                eta_seconds = batch_time.avg * (nb_this_epoch +
                                                nb_future_epochs)
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('epoch [{0}/{1}][{2}/{3}]\t'
                      'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'eta {eta}\t'
                      '{losses}\t'
                      'lr {lr}'.format(self.epoch + 1,
                                       self.max_epoch,
                                       self.batch_idx + 1,
                                       self.num_batches,
                                       batch_time=batch_time,
                                       data_time=data_time,
                                       eta=eta,
                                       losses=losses,
                                       lr=self.get_current_lr()))

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar('train/' + name, meter.avg, n_iter)

            end = time.time()
コード例 #2
0
    def run_epoch(self):
        self.set_model_mode('train')
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (self.batch_idx + 1) % self.cfg.TRAIN.PRINT_FREQ == 0:
                nb_this_epoch = self.num_batches - (self.batch_idx + 1)
                nb_future_epochs = (
                    self.max_epoch - (self.epoch + 1)
                ) * self.num_batches
                eta_seconds = batch_time.avg * (nb_this_epoch+nb_future_epochs)
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'epoch [{0}/{1}][{2}/{3}]\t'
                    'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'eta {eta}\t'
                    '{losses}\t'
                    'lr {lr}'.format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr()
                    )
                )

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar('train/' + name, meter.avg, n_iter)

            end = time.time()
コード例 #3
0
    def run_epoch(self):
        self.set_model_mode('train')
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        # Decide to iterate over labeled or unlabeled dataset
        len_train_loader_x = len(self.train_loader_x)
        len_train_loader_u = len(self.train_loader_u)
        if self.cfg.TRAIN.COUNT_ITER == 'train_x':
            self.num_batches = len_train_loader_x
        elif self.cfg.TRAIN.COUNT_ITER == 'train_u':
            self.num_batches = len_train_loader_u
        elif self.cfg.TRAIN.COUNT_ITER == 'smaller_one':
            self.num_batches = min(len_train_loader_x, len_train_loader_u)
        else:
            raise ValueError

        train_loader_x_iter = iter(self.train_loader_x)
        train_loader_u_iter = iter(self.train_loader_u)

        end = time.time()
        for self.batch_idx in range(self.num_batches):
            try:
                batch_x = next(train_loader_x_iter)
            except StopIteration:
                train_loader_x_iter = iter(self.train_loader_x)
                batch_x = next(train_loader_x_iter)

            try:
                batch_u = next(train_loader_u_iter)
            except StopIteration:
                train_loader_u_iter = iter(self.train_loader_u)
                batch_u = next(train_loader_u_iter)

            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch_x, batch_u)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (self.batch_idx + 1) % self.cfg.TRAIN.PRINT_FREQ == 0:
                nb_this_epoch = self.num_batches - (self.batch_idx + 1)
                nb_future_epochs = (self.max_epoch -
                                    (self.epoch + 1)) * self.num_batches
                eta_seconds = batch_time.avg * (nb_this_epoch +
                                                nb_future_epochs)
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('epoch [{0}/{1}][{2}/{3}]\t'
                      'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'eta {eta}\t'
                      '{losses}\t'
                      'lr {lr}'.format(self.epoch + 1,
                                       self.max_epoch,
                                       self.batch_idx + 1,
                                       self.num_batches,
                                       batch_time=batch_time,
                                       data_time=data_time,
                                       eta=eta,
                                       losses=losses,
                                       lr=self.get_current_lr()))

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar('train/' + name, meter.avg, n_iter)
            self.write_scalar('train/lr', self.get_current_lr(), n_iter)

            end = time.time()