Example #1
    def train(self):
        self._init_params()

        for epoch in range(1, self.epochs + 1):
            self.monitor.update()
            self.model_adapter.set_epoch(epoch)

            train_loss = self._run_epoch(epoch)
            val_loss, metrics, batch_sample = self._validate()
            self.scheduler.step(epoch=epoch)

            if self.monitor.should_save_checkpoint():
                self.monitor.reset()
                self._save_checkpoint(
                    file_prefix="model_epoch_{}".format(epoch))
            self._set_checkpoint(val_loss)

            logger.info(
                "\nEpoch: {}; train loss = {}; validation loss = {}".format(
                    epoch, train_loss, val_loss))

            self.model_adapter.write_to_tensorboard(epoch, train_loss,
                                                    val_loss, batch_sample)

        self.model_adapter.on_training_end()
Example #2
    def train(self):
        self._init_params()

        for epoch in range(1, self.epochs + 1):
            self.monitor.update()
            self.model_adapter.set_epoch(epoch)

            train_loss = self._run_epoch(epoch)
            val_loss, metrics, batch_sample = self._validate()
            self.scheduler.step(epoch=epoch)
            # self.scheduler.step(metrics[self.model_adapter.main_metric] if self.model_adapter.main_metric != 'loss'
            #                     else val_loss)
            if self.monitor.should_save_checkpoint():
                self.monitor.reset()
                self._save_checkpoint(file_prefix=f"model_epoch_{epoch}")
            self._set_checkpoint(val_loss)

            logger.info(
                f"\nEpoch: {epoch}; train loss = {train_loss}; validation loss = {val_loss}"
            )

            self.model_adapter.write_to_tensorboard(epoch, train_loss,
                                                    val_loss, batch_sample)

        self.model_adapter.on_training_end()
Example #3
    def _validate(self, epoch: int):
        # switch model to eval mode
        self.model.eval()

        with torch.no_grad():
            self.metric_counter.clear()

            for i, data in enumerate(
                    tqdm(self.val_dataset,
                         postfix=self.metric_counter.loss_message())):
                images, targets = self.model_adapter.get_input(data)
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)
                _, loss_dict = self.model_adapter.get_loss(loss)
                self.metric_counter.add_losses(loss_dict)

                # calculate metrics
                metrics = self.model_adapter.get_metrics(outputs, targets)
                self.metric_counter.add_metrics(metrics)

                if i >= self.validation_steps:
                    logger.info(
                        "Validation steps reached max={}, stopping validation.".
                        format(self.validation_steps))
                    break

            self.metric_counter.write_to_tensorboard(epoch, validation=True)
Example #4
    def __init__(self,
                 imgs: Sequence[str],
                 ):
        self.imgs = imgs
        self.normalize_fn = get_normalize()
        self.approx_img_size = 384
        logger.info(f'Dataset has been created with {len(self.imgs)} samples')
Example #5
    def _train_epoch(self, epoch: int):
        # switch model to train mode
        self.model.train()

        self.metric_counter.clear()
        lr = self.optimizer.param_groups[0]["lr"]
        for i, data in enumerate(
                tqdm(
                    self.train_dataset,
                    desc="Epoch: {}, lr: {}".format(epoch, lr),
                    postfix=self.metric_counter.loss_message(),
                )):
            images, targets = self.model_adapter.get_input(data)
            outputs = self.model(images)

            self.optimizer.zero_grad()
            loss = self.criterion(outputs, targets)

            total_loss, loss_dict = self.model_adapter.get_loss(loss)
            total_loss.backward()
            self.optimizer.step()
            self.metric_counter.add_losses(loss_dict)

            if i >= self.steps_per_epoch:
                logger.info(
                    "Steps per epoch reached max={}, stopping the epoch.".format(
                        self.steps_per_epoch))
                break

        self.metric_counter.write_to_tensorboard(epoch)
        logger.info("Mean loss: {} for epoch: {}".format(
            self.metric_counter.get_loss(), epoch))
Example #6
    @staticmethod
    def from_config(config):
        files_a = sorted(glob(config["files_a"], recursive=True))
        files_b = sorted(glob(config["files_b"], recursive=True))

        logger.info("files_a read: {} files., files_b read: {} files.".format(
            len(files_a), len(files_b)))

        names = list(map(lambda path: splitext(basename(path))[0], files_a))

        transform = get_transforms(config["transform"])

        # ToDo: make augmentations more customizable via transform

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get("verbose", True)
        data = dataset.subsample(
            data=zip(files_a, files_b, names),
            bounds=config.get("bounds", (0, 1)),
            hash_fn=hash_fn,
            verbose=verbose,
        )

        files_a, files_b, names = map(list, zip(*data))

        return PairedDataset(
            files_a=files_a,
            files_b=files_b,
            names=names,
            preload=config["preload"],
            preload_size=config["preload_size"],
            transform=transform,
            verbose=verbose,
        )
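For reference, a minimal config dict covering the keys this factory reads; the key names come from the lookups above, while the values are illustrative assumptions:

config = {
    "files_a": "data/blur/**/*.png",
    "files_b": "data/sharp/**/*.png",
    "transform": {"size": 256},  # whatever get_transforms() expects (assumed)
    "bounds": (0, 0.9),          # keep the first 90% of hash buckets
    "preload": False,
    "preload_size": 0,
    "verbose": True,
}
dataset = PairedDataset.from_config(config)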
Example #7
    def __init__(self,
                 files_a: Tuple[str, ...],
                 files_b: Tuple[str, ...],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):

        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(
            f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn,
                                               (self.data_a, self.data_b))
            self.preload = True
Example #8
    def __init__(self,
                 files: Tuple[str, ...],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 soften_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 mixup: float = 0,
                 verbose=True):

        self.size = preload_size
        self.preload = preload
        self.imgs = files
        self.labels = [self._get_label(f) for f in files]
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        self.soften_fn = soften_fn
        self.mixup_proba = mixup
        logger.info(f'Dataset has been created with {len(self.imgs)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            self.imgs = preload_fn(self.imgs)
Example #9
    def fit_one_epoch(self, n_epoch):
        losses, accs = self._train_epoch(n_epoch)
        val_losses, val_accs = self._val_epoch(n_epoch)

        train_loss = np.mean(losses)
        val_loss = np.mean(val_losses)
        train_acc = np.mean(accs)
        val_acc = np.mean(val_accs)
        msg = f'Epoch {n_epoch}: train loss is {train_loss:.3f}, train accuracy {train_acc:.3f}, ' \
              f'val loss {val_loss:.3f}, val accuracy {val_acc:.3f}'
        logger.info(msg)

        self.scheduler.step(metrics=val_loss, epoch=n_epoch)

        metric = val_acc
        if metric > self.current_metric:
            self.current_metric = metric
            self.last_improvement = n_epoch
            save(self.model, f=self.checkpoint)
            logger.info(f'Best model has been saved at {n_epoch}, accuracy is {metric:.4f}')
        else:
            if self.last_improvement + self.early_stop < n_epoch:
                return True, (train_loss, val_loss, train_acc, val_acc)

        return False, (train_loss, val_loss, train_acc, val_acc)
Example #10
    def fit_one_epoch(self, n_epoch):
        self.model.train(True)
        losses, reg_losses = [], []

        for i, (x, y) in enumerate(self.train):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(x)

            loss = self.loss_fn(outputs, y)
            losses.append(loss.item())

            for param in self.model.model.parameters():
                loss += self.reg_lambda * torch.norm(param, p=self.reg_norm)

            reg_loss = loss.item()
            reg_losses.append(reg_loss)

            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1)

            self.optimizer.step()

        self.model.train(False)

        val_losses = []
        y_pred_acc, y_true_acc = [], []

        with torch.no_grad():
            for i, (x, y) in enumerate(self.val):
                x, y = x.to(self.device), y.to(self.device)
                outputs = self.model(x)

                loss = self.loss_fn(outputs, y)
                val_losses.append(loss.item())

                y_pred_acc.append(outputs.detach().cpu().numpy())
                y_true_acc.append(y.detach().cpu().numpy())

        train_loss = np.mean(losses)
        train_reg_loss = np.mean(reg_losses)
        val_loss = np.mean(val_losses)
        msg = f'Epoch {n_epoch}: train loss is {train_loss:.5f} (raw), {train_reg_loss:.5f} (reg); val loss is {val_loss:.5f}'
        logger.info(msg)

        self.scheduler.step(metrics=val_loss, epoch=n_epoch)
        y_true_acc, y_pred_acc = map(np.vstack, (y_true_acc, y_pred_acc))

        metric = self.evaluate(y_pred=y_pred_acc, y_true=y_true_acc)

        if metric > self.current_metric:
            self.current_metric = metric
            self.last_improvement = n_epoch
            save(self.model, f=self.checkpoint)
            logger.info(
                f'Best model has been saved at {n_epoch}, accuracy is {metric:.4f}'
            )

        return train_loss, val_loss, metric
Example #11
def map_classes(y_full):
    classes = sorted(set(y_full))
    mapping = {y: i for i, y in enumerate(classes)}

    logger.info(f'Mapping is {mapping}')
    return np.array([mapping[y] for y in y_full])
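A quick sanity check of the remapping (assumed inputs, easy to run locally):

import numpy as np

y = np.array([10, 7, 10, 3])
print(map_classes(y))  # logs "Mapping is {3: 0, 7: 1, 10: 2}", prints [2 1 2 0]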
Example #12
def inner_f(*args, **options):
    # wrapper produced by a decorator: `f` is the wrapped parsing function
    try:
        return f(*args, **options)
    except KeyboardInterrupt:
        logger.info('Parsing stopped with KeyboardInterrupt')
        sys.exit()
    except Exception:
        logger.exception('Parsing failed at {}'.format(args))
Example #13
def update_config(config, params):
    for k, v in params.items():
        *path, key = k.split('.')
        conf = config
        for p in path:
            if p not in conf:
                logger.error(f'Overwriting non-existing attribute {k} = {v}')
                conf[p] = {}  # create the missing level instead of raising KeyError
            conf = conf[p]
        logger.info(f'Overwriting {k} = {v} (was {conf.get(key)})')
        conf[key] = v
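A sketch of how the dotted keys resolve, with a hypothetical config:

config = {'train': {'lr': 0.001, 'batch_size': 32}}
update_config(config, {'train.lr': 0.0001})
# logs "Overwriting train.lr = 0.0001 (was 0.001)"
assert config['train']['lr'] == 0.0001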
Example #14
    def _set_checkpoint(self):
        """Saves model weights in the last checkpoint.
        The model is also saved as the best model if it has the best metric.
        """
        if self.metric_counter.update_best_model():
            torch.save(
                {'model': self.model_adapter.get_model_export(self.model)},
                osp.join(self.config['experiment']['folder'],
                         self.config['experiment']['name'], 'best.h5'))
        torch.save({'model': self.model_adapter.get_model_export(self.model)},
                   osp.join(self.config['experiment']['folder'],
                            self.config['experiment']['name'], 'last.h5'))
        logger.info(self.metric_counter.loss_message())
Example #15
def run():
    for c in cameras:
        logger.info(f'Working with {c}')
        photos = []
        for i in range(10):
            url = cameras[c].replace('ID', str(i + 1))
            photos += parse_page(url)
        logger.info(f'Page URLs parsed for {c}: {len(photos)}')

        for img in tqdm(photos):
            parsed = parse_photo(img)
            if parsed:
                src, soft = parsed
                yield {'camera': c, 'url': src, 'soft': soft}
Example #16
def main(preload=False, parallel=True):
    models = ('densenet121',)
    dropouts = (25, 50)
    folds = (0, 1, 2)

    for d in dropouts:
        for name in models:
            for fold in folds:
                model = f'{name}.{d}'
                command = f"python train.py --name {model} --model {model} --batch_size 128 --n_fold {fold}"
                if preload:
                    command += " --train.preload --val.preload"
                if parallel:
                    command += " --parallel"
                logger.info(command)
                subprocess.call(command.split(' '))
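With the defaults above (parallel=True, preload=False), the first command this loop issues would be:

python train.py --name densenet121.25 --model densenet121.25 --batch_size 128 --n_fold 0 --parallel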
Example #17
    def _save_checkpoint(self, epoch: int):
        # update checkpoint
        if self.monitor.should_save_checkpoint():
            self._write_checkpoint("checkpoint_{}".format(epoch), epoch)
            self.monitor.reset()

        # update best model
        if self.metric_counter.update_best_model():
            self._write_checkpoint("best", epoch)
            logger.info("Best model updated. Loss: {}".format(
                self.metric_counter.loss_message()))

        # save last model
        self._write_checkpoint("last", epoch)
        logger.info("Last model saved. Loss: {}".format(
            self.metric_counter.loss_message()))
Example #18
    def __init__(self,
                 imgs,
                 labels,
                 size: int,
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 verbose=True):
        self.imgs = imgs
        self.labels = labels
        self.size = size
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.imgs)} samples')
Example #19
    def __init__(self, classes=20, s=1, pretrained=None, gpus=1):
        super().__init__()
        classificationNet = EESPNet(classes=1000, s=s)
        if gpus >= 1:
            classificationNet = nn.DataParallel(classificationNet)
        # load the pretrained weights
        if pretrained:
            if not os.path.isfile(pretrained):
                logger.info(
                    "Weight file does not exist. Training without pre-trained weights"
                )
            else:
                classificationNet.load_state_dict(torch.load(pretrained))
                logger.info("Model initialized with pretrained weights")

        self.net = classificationNet.module

        del classificationNet
        # delete last few layers
        del self.net.classifier
        del self.net.level5
        del self.net.level5_0
        if s <= 0.5:
            p = 0.1
        else:
            p = 0.2

        self.proj_L4_C = CBR(
            self.net.level4[-1].module_act.num_parameters,
            self.net.level3[-1].module_act.num_parameters,
            1,
            1,
        )
        pspSize = 2 * self.net.level3[-1].module_act.num_parameters
        self.pspMod = nn.Sequential(
            EESP(pspSize, pspSize // 2, stride=1, k=4, r_lim=7),
            PSPModule(pspSize // 2, pspSize // 2),
        )
        self.project_l3 = nn.Sequential(nn.Dropout2d(p=p),
                                        C(pspSize // 2, classes, 1, 1))
        self.act_l3 = BR(classes)
        self.project_l2 = CBR(self.net.level2_0.act.num_parameters + classes,
                              classes, 1, 1)
        self.project_l1 = nn.Sequential(
            nn.Dropout2d(p=p),
            C(self.net.level1.act.num_parameters + classes, classes, 1, 1),
        )
Example #20
def fit(parallel=False, **kwargs):
    with open('config.yaml') as cfg:
        config = yaml.safe_load(cfg)
    update_config(config, kwargs)
    work_dir = config['name']
    os.makedirs(work_dir, exist_ok=True)
    with open(os.path.join(work_dir, 'config.yaml'), 'w') as out:
        yaml.dump(config, out)

    config['train']['salt'] = config['val']['salt'] = config['name']
    config['train']['n_fold'] = config['val']['n_fold'] = config['n_fold']

    train, val = make_dataloaders(config['train'], config['val'], config['batch_size'], multiprocessing=parallel)
    model = DataParallel(get_baseline(config['model']))
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    trainer = Trainer(model=model,
                      train=train,
                      val=val,
                      work_dir=work_dir,
                      loss_fn=None,
                      optimizer=optimizer,
                      scheduler=ReduceLROnPlateau(factor=.2, patience=5, optimizer=optimizer),
                      device='cuda:0',
                      )

    stages = config['stages']
    epochs_completed = 0
    for i, stage in enumerate(stages):
        logger.info(f'Starting stage {i}')
        # ToDo: update train properties: mixup, crop type
        trainer.train.dataset.update_config(stage['train'])
        trainer.epochs = stage['epochs']
        weights = torch.from_numpy(np.array(stage['loss_weights'], dtype='float32')).to('cuda:0')
        trainer.loss_fn = partial(soft_cross_entropy,
                                  weights=weights)
        epochs_completed = trainer.fit(epochs_completed)

    convert_model(model_path=os.path.join(work_dir, 'model.pt'),
                  out_name=os.path.join(work_dir, f'{config["name"]}_{config["n_fold"]}.trcd'),
                  name=config['model']
                  )
Example #21
    def __init__(self,
                 files_a: Tuple[str, ...],
                 files_b: Tuple[str, ...],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):

        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn

        # Train AidedDeblur #
        with open("./dataset/AidedDeblur/train_instance_names.txt", "r") as f_train:
            train_data = [line.rstrip() for line in f_train]

        self.data_a = train_data
        self.data_b = train_data

        logger.info(
            f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn,
                                               (self.data_a, self.data_b))
            self.preload = True
Example #22
def calculate_distances(name):
    index_ids, index_vectors = get_data('index', name)
    test_ids, test_vectors = get_data('test', name)
    logger.info('data is read')

    index_vectors, test_vectors = map(arr, (index_vectors, test_vectors))
    logger.info('tensors are ready')

    shape = len(test_ids), len(index_ids)

    file = File('data/distances.h5', 'w')
    result = file.create_dataset('result', shape=shape, dtype=np.uint8)
    logger.info('h5 file is ready')

    index_vectors = index_vectors.view(-1, SHAPE).cuda()
    for i in tqdm(np.arange(shape[0]), desc='calculating cosine'):
        c = cosine(test_vectors[i].view(-1, SHAPE), index_vectors)
        result[i, :] = c

    for i, v in tqdm(zip(index_ids, index_vectors),
                     desc='removing empty pics'):
        if v is None:
            result[:, i] = 255

    file.close()
Example #23
def subsample(data: Iterable,
              bounds: Tuple[float, float],
              hash_fn: Callable,
              n_buckets=100,
              salt='',
              verbose=True):
    data = list(data)
    buckets = split_into_buckets(data,
                                 n_buckets=n_buckets,
                                 salt=salt,
                                 hash_fn=hash_fn)

    lower_bound, upper_bound = [x * n_buckets for x in bounds]
    msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
    if salt:
        msg += f'; salt is {salt}'
    if verbose:
        logger.info(msg)
    return np.array([
        sample for bucket, sample in zip(buckets, data)
        if lower_bound <= bucket < upper_bound
    ])
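A typical way to use this for a deterministic, disjoint train/val split (a sketch; hash_from_paths is the hash function referenced in Example #6):

pairs = list(zip(files_a, files_b))  # zip() is single-use, so materialize it first
train = subsample(pairs, bounds=(0.0, 0.9), hash_fn=hash_from_paths, salt='split1')
val = subsample(pairs, bounds=(0.9, 1.0), hash_fn=hash_from_paths, salt='split1')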
Example #24
    def fit_one_epoch(self, n_epoch):
        segm_losses, clf_losses, scores = self._train_epoch(n_epoch)
        val_segm_losses, val_clf_losses, val_scores = self._val_epoch(n_epoch)

        train_segm_loss = np.mean(segm_losses)
        val_segm_loss = np.mean(val_segm_losses)
        train_clf_loss = np.mean(clf_losses)
        val_clf_loss = np.mean(val_clf_losses)
        scores = np.mean(scores)
        val_scores = np.mean(val_scores)

        msg = f'Epoch {n_epoch}: ' \
              f'train segm loss is {train_segm_loss:.3f}, ' \
              f'train clf loss {train_clf_loss:.3f}, ' \
              f'train score {scores:.3f}, ' \
              f'val segm loss is {val_segm_loss:.3f}, ' \
              f'val clf loss {val_clf_loss:.3f}, ' \
              f'val score {val_scores:.3f}'
        logger.info(msg)

        self.scheduler.step(metrics=val_segm_loss + val_clf_loss,
                            epoch=n_epoch)

        metric = -val_segm_loss - val_clf_loss
        if metric > self.current_metric:
            self.current_metric = metric
            self.last_improvement = n_epoch
            save(self.model, f=self.checkpoint)
            logger.info(
                f'Best model has been saved at {n_epoch}, metric is {metric:.4f}'
            )
        else:
            if self.last_improvement + self.early_stop < n_epoch:
                return True, (train_segm_loss, train_clf_loss, scores,
                              val_segm_loss, val_clf_loss, val_scores)

        return False, (train_segm_loss, train_clf_loss, scores, val_segm_loss,
                       val_clf_loss, val_scores)
Example #25
def fit(parallel=False, **kwargs):
    with open('config.yaml') as cfg:
        config = yaml.safe_load(cfg)
    update_config(config, kwargs)
    work_dir = config['name']
    os.makedirs(work_dir, exist_ok=True)
    with open(os.path.join(work_dir, 'config.yaml'), 'w') as out:
        yaml.dump(config, out)

    train, val = make_dataloaders(config['train'],
                                  config['val'],
                                  config['batch_size'],
                                  multiprocessing=parallel)

    checkpoint = config.get('checkpoint')
    if checkpoint is not None:
        logger.info(f'Restoring model from {checkpoint}')
        model = load(checkpoint)
    else:
        model = TigerFPN()
        model = DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    trainer = Trainer(model=model,
                      train=train,
                      val=val,
                      clf_loss_fn=F.binary_cross_entropy_with_logits,
                      segm_loss_fn=iou_continuous_loss_with_logits,
                      work_dir=work_dir,
                      optimizer=optimizer,
                      scheduler=ReduceLROnPlateau(factor=.2,
                                                  patience=10,
                                                  optimizer=optimizer),
                      device='cuda:0',
                      epochs=config['n_epochs'],
                      early_stop=config['early_stop'])
    epochs_used = trainer.fit(start_epoch=0)
    logger.info(f'The model trained for {epochs_used} epochs')

    if config['finetune']:
        trainer.train.dataset.corrupt_fn = None
        trainer.optimizer = torch.optim.Adam(model.parameters(),
                                             lr=config['lr'] / 10)
        trainer.checkpoint = os.path.join(trainer.work_dir, 'model_ft.pt')
        trainer.last_improvement = epochs_used
        epochs_used = trainer.fit(start_epoch=epochs_used)
        logger.info(f'The model fine-tuned for {epochs_used} epochs')
Example #26
def check_fonts(char_list, font):
    """
    check fonts
    :param char_list:
    :param font:
    :return: True or False
    """

    for char in reversed(char_list):
        unicode_char = char.encode("unicode_escape")
        utf_8_char = unicode_char.decode('utf-8').split('\\')[-1]
        utf_8_char = utf_8_char if len(utf_8_char) == 1 else utf_8_char[1:]
        if utf_8_char in char_dict:
            utf_8_char_check = char_dict[utf_8_char]

        elif char != '¥':
            utf_8_char = unicode_char.decode('utf-8').split('\\')[-1].strip(
                'u')
            utf_8_char_check = 'uni' + utf_8_char.upper()
        else:
            continue
        try:
            ttf = TTFont(font)
            lower = ttf.getGlyphSet().get(utf_8_char_check)
            if lower is None:
                logger.info('1char {} is not in font'.format(char))
                return False
            else:
                if lower._glyph.numberOfContours == 0:
                    logger.info('2char {} is not in font'.format(char))
                    return False
                else:
                    continue
        except Exception:
            logger.info('3char {} is not in font'.format(char))
            return False
    return True
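A minimal usage sketch, assuming char_dict is populated and the font path exists:

if not check_fonts(list('你好'), 'fonts/simsun.ttf'):  # hypothetical font path
    logger.info('Font lacks required glyphs, skipping it')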
Example #27
    def __init__(self, x_data: np.ndarray, y_data: np.ndarray, folds: tuple):
        data = zip(x_data, y_data)
        self.data = [x for i, x in enumerate(data) if i % 5 in folds]
        logger.info(f'There are {len(self.data)} records in the dataset')
        self.features_shape = x_data.shape
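A hedged usage sketch; the owning class is not shown above, so FoldedDataset is a stand-in name:

train_ds = FoldedDataset(x_data, y_data, folds=(0, 1, 2, 3))  # ~80% of records (i % 5 < 4)
val_ds = FoldedDataset(x_data, y_data, folds=(4,))            # remaining ~20%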
Example #28
def describe_model(m):
    logger.info(
        f'File {m} created: {pd.to_datetime(round(path.getctime(m)), unit="s")}'
    )
    return m
Example #29
    def train(self):
        self.monitor.reset()
        for epoch in range(0, self.config["num_epochs"]):
            self.monitor.update()
            # if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0):
            #     self.model.module.unfreeze()
            #     self.optimizer = self._get_optim(self.model.parameters())
            #     self.scheduler = self._get_scheduler(self.optimizer)

            self._train_epoch(epoch)

            logger.info("Validation ...")
            self._validate(epoch)
            logger.info("Validation finished.")

            logger.info("Updating scheduler ...")
            self._update_scheduler()
            logger.info("Scheduler updated.")

            logger.info("Saving checkpoint ...")
            self._save_checkpoint(epoch)
            logger.info("Checkpoint saved.")

            self.early_stopping(val_metric=self.metric_counter.get_metric())
            if self.early_stopping.early_stop:
                logger.info("Early stopping executed.")
                break