def train(model, loader, criterion, optimizer, scheduler, args, epoch, device):
    logger.info('Current learning rate: %.6f', optimizer.param_groups[0]['lr'])
    model.train()
    meters = AverageMeterGroup()

    for step, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        if args.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        meters.update({'acc': accuracy(logits, targets), 'loss': loss.item()})

        if step % args.log_frequency == 0 or step + 1 == len(loader):
            logger.info('Epoch [%d/%d] Step [%d/%d]  %s', epoch, args.epochs, step + 1, len(loader), meters)
        scheduler.step()
    return meters.acc.avg, meters.loss.avg
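
# Note: scheduler.step() above runs once per batch, so a scheduler such as
# torch.optim.lr_scheduler.CosineAnnealingLR is typically constructed with
# T_max equal to the total number of optimizer steps. Sketch (illustrative):
#
#   scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#       optimizer, T_max=args.epochs * len(loader))
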
def test(model, loader, criterion, args, epoch, device):
    model.eval()
    meters = AverageMeterGroup()
    correct = loss = total = 0.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            bs = targets.size(0)
            logits = model(inputs)
            loss += criterion(logits, targets).item() * bs
            correct += accuracy(logits, targets) * bs
            total += bs
    logger.info('Eval Epoch [%d/%d] Loss = %.6f Acc = %.6f',
                epoch, args.epochs, loss / total, correct / total)
    return correct / total, loss / total
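
# `accuracy` is assumed to return the fraction of correct top-1 predictions in
# a batch (so `accuracy(logits, targets) * bs` counts correct samples).
# A minimal sketch under that assumption:
#
#   def accuracy(logits, targets):
#       return (logits.argmax(dim=1) == targets).float().mean().item()
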
def main():
    valid_splits = [
        "172", "334", "860", "91-172", "91-334", "91-860", "denoise-91",
        "denoise-80", "all"
    ]
    parser = ArgumentParser()
    parser.add_argument("--train_split", choices=valid_splits, default="172")
    parser.add_argument("--eval_split", choices=valid_splits, default="all")
    parser.add_argument("--gcn_hidden", type=int, default=144)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--train_batch_size", default=10, type=int)
    parser.add_argument("--eval_batch_size", default=1000, type=int)
    parser.add_argument("--epochs", default=300, type=int)
    parser.add_argument("--lr", "--learning_rate", default=1e-4, type=float)
    parser.add_argument("--wd", "--weight_decay", default=1e-3, type=float)
    parser.add_argument("--train_print_freq", default=None, type=int)
    parser.add_argument("--eval_print_freq", default=10, type=int)
    parser.add_argument("--visualize", default=False, action="store_true")
    args = parser.parse_args()

    reset_seed(args.seed)

    dataset = Nb101Dataset(split=args.train_split)
    dataset_test = Nb101Dataset(split=args.eval_split)
    data_loader = DataLoader(dataset,
                             batch_size=args.train_batch_size,
                             shuffle=True,
                             drop_last=True)
    test_data_loader = DataLoader(dataset_test,
                                  batch_size=args.eval_batch_size)
    net = NeuralPredictor(gcn_hidden=args.gcn_hidden)
    net.cuda()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    logger = get_logger()

    net.train()
    for epoch in range(args.epochs):
        meters = AverageMeterGroup()
        lr = optimizer.param_groups[0]["lr"]
        for step, batch in enumerate(data_loader):
            batch = to_cuda(batch)
            target = batch["val_acc"]
            predict = net(batch)
            loss = criterion(predict, target)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()
            mse = accuracy_mse(predict, target)
            meters.update({
                "loss": loss.item(),
                "mse": mse.item()
            },
                          n=target.size(0))
            if (args.train_print_freq and step % args.train_print_freq == 0) or \
                    step + 1 == len(data_loader):
                logger.info("Epoch [%d/%d] Step [%d/%d] lr = %.3e  %s",
                            epoch + 1, args.epochs, step + 1, len(data_loader),
                            lr, meters)
        lr_scheduler.step()

    net.eval()
    meters = AverageMeterGroup()
    predict_, target_ = [], []
    with torch.no_grad():
        for step, batch in enumerate(test_data_loader):
            batch = to_cuda(batch)
            target = batch["val_acc"]
            predict = net(batch)
            predict_.append(predict.cpu().numpy())
            target_.append(target.cpu().numpy())
            meters.update(
                {
                    "loss": criterion(predict, target).item(),
                    "mse": accuracy_mse(predict, target).item()
                },
                n=target.size(0))

            if (args.eval_print_freq and step % args.eval_print_freq == 0) or \
                    step + 1 == len(test_data_loader):
                logger.info("Evaluation Step [%d/%d]  %s", step + 1,
                            len(test_data_loader), meters)
    predict_ = np.concatenate(predict_)
    target_ = np.concatenate(target_)
    if args.visualize:
        visualize_scatterplot(predict_, target_)
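
# This example relies on `to_cuda` and `accuracy_mse` helpers that are not
# shown here. Minimal sketches, assuming a batch is a dict of tensors and the
# val_acc targets need no denormalization:
#
#   def to_cuda(batch):
#       # move every tensor in the batch dict onto the default CUDA device
#       return {k: v.cuda() if torch.is_tensor(v) else v for k, v in batch.items()}
#
#   def accuracy_mse(predict, target):
#       # mean squared error between predicted and true validation accuracies
#       return ((predict - target) ** 2).mean()

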
class DefaultEvaluator(BaseEvaluator):
    def __init__(self, cfg):
        super(DefaultEvaluator, self).__init__()
        self.cfg = cfg
        self.debug = cfg.debug
        self.callbacks = self.generate_callbacks()

        self.arcs = self.load_arcs(cfg.args.arc_path)
        self.writter = SummaryWriter(
            os.path.join(self.cfg.logger.path, 'summary_runs'))
        self.logger = MyLogger(__name__, cfg).getlogger()
        self.size_acc = {
        }  # {'epoch1': [model_size, acc], 'epoch2': [model_size, acc], ...}
        self.init_basic_settings()

    def init_basic_settings(self):
        '''init train_epochs, device, loss_fn, dataset, and dataloaders
        '''
        # train epochs
        try:
            self.train_epochs = self.cfg.args.train_epochs
        except AttributeError:
            self.train_epochs = 1  # default when the config does not specify it

        # device
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        # loss_fn
        self.loss_fn = build_loss_fn(self.cfg)
        self.loss_fn.to(self.device)
        self.logger.info(f"Building loss function ...")

        # dataset
        self.train_dataset, self.test_dataset = build_dataset(self.cfg)

        # dataloader
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.cfg.dataset.batch_size,
            shuffle=True,
            num_workers=self.cfg.dataset.workers,
            pin_memory=True)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.cfg.dataset.batch_size,
            shuffle=False,
            num_workers=self.cfg.dataset.workers,
            pin_memory=True)
        self.logger.info(f"Building dataset and dataloader ...")

    def load_arcs(self, arc_path):
        '''Load architecture json files.
        Args:
            arc_path:
                (file): [arc_path]
                (dir): [arc_path/epoch_0.json, arc_path/epoch_1.json, ...]
        '''
        if os.path.isfile(arc_path):
            return [arc_path]
        else:
            arcs = os.listdir(arc_path)
            arcs = [
                os.path.join(arc_path, arc) for arc in arcs
                if arc.split('.')[-1] == 'json'
            ]
            arcs = sorted(
                arcs,
                key=lambda x: int(
                    os.path.splitext(os.path.basename(x))[0].split('_')[1]))
            return arcs
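
    # Usage sketch (paths are illustrative): a single file is returned as-is,
    # while a directory is expanded to its epoch_*.json files in epoch order.
    #   self.load_arcs('checkpoints/epoch_3.json')
    #       -> ['checkpoints/epoch_3.json']
    #   self.load_arcs('checkpoints/')
    #       -> ['checkpoints/epoch_0.json', 'checkpoints/epoch_1.json', ...]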

    def reset(self):
        '''The mutable can only be initialized once, so the model, optimizer,
        and scheduler must be reset whenever a new trial is run.
        '''
        # model
        self.model = build_model(self.cfg)
        self.model.to(self.device)
        self.logger.info(f"Building model {self.cfg.model.name} ...")

        # load teacher model if using knowledge distillation
        if hasattr(self.cfg, 'kd') and self.cfg.kd.enable:
            self.kd_model = load_kd_model(self.cfg).to(self.device)
            self.kd_model.eval()
            self.logger.info(
                f"Building teacher model {self.cfg.kd.model.name} ...")
        else:
            self.kd_model = None

        # optimizer
        self.optimizer = generate_optimizer(
            model=self.model,
            optim_name=self.cfg.optim.name,
            lr=self.cfg.optim.base_lr,
            momentum=self.cfg.optim.momentum,
            weight_decay=self.cfg.optim.weight_decay)
        self.logger.info(f"Building optimizer {self.cfg.optim.name} ...")

        # scheduler
        self.scheduler_params = parse_cfg_for_scheduler(
            self.cfg, self.cfg.optim.scheduler.name)
        self.lr_scheduler = generate_scheduler(self.optimizer,
                                               self.cfg.optim.scheduler.name,
                                               **self.scheduler_params)
        self.logger.info(
            f"Building optim.scheduler {self.cfg.optim.scheduler.name} ...")

    def compare(self):
        self.logger.info("=" * 20)
        self.logger.info("Selecting the best architecture ...")
        self.enable_writter = False
        # split train dataset into train and valid dataset
        train_size = int(0.8 * len(self.train_dataset))
        valid_size = len(self.train_dataset) - train_size
        self.train_dataset_part, self.valid_dataset_part = torch.utils.data.random_split(
            self.train_dataset, [train_size, valid_size])

        # dataloader
        self.train_loader_part = torch.utils.data.DataLoader(
            self.train_dataset_part,
            batch_size=self.cfg.dataset.batch_size,
            shuffle=True,
            num_workers=self.cfg.dataset.workers,
            pin_memory=True)
        self.valid_loader_part = torch.utils.data.DataLoader(
            self.valid_dataset_part,
            batch_size=self.cfg.dataset.batch_size,
            shuffle=True,
            num_workers=self.cfg.dataset.workers,
            pin_memory=True)

        # choose the best architecture
        for arc in self.arcs:
            self.reset()
            self.mutator = apply_fixed_architecture(self.model, arc)
            size = self.model_size()
            arc_name = os.path.basename(arc)
            self.logger.info(f"{arc} Model size={size*4/1024**2} MB")

            # train
            for epoch in range(self.train_epochs):
                self.train_one_epoch(epoch, self.train_loader_part)
            val_acc = self.valid_one_epoch(-1, self.valid_loader_part)
            self.size_acc[arc_name] = {
                'size': size,
                'val_acc': val_acc,
                'arc': arc
            }
        sorted_size_acc = sorted(
            self.size_acc.items(),
            key=lambda x: x[1]['val_acc']['save_metric'].avg,
            reverse=True)
        return sorted_size_acc[0][1]
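
    # The value returned by compare() is the dict stored in self.size_acc for
    # the best architecture, e.g. (illustrative values):
    #   {'size': <param count>, 'val_acc': <AverageMeterGroup>, 'arc': '.../epoch_7.json'}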

    def run(self, arc, validate=True, test=False):
        '''Retrain the best-performing architecture from scratch.
            arc: the json file path of the best-performing architecture
        '''
        self.logger.info("=" * 20)
        self.logger.info("Retraining the best architecture ...")
        self.enable_writter = True
        self.reset()

        # init model and mutator
        self.mutator = apply_fixed_architecture(self.model, arc)
        size = self.model_size()
        arc_name = os.path.basename(arc)
        self.logger.info(f"{arc_name} Model size={size*4/1024**2} MB")

        # callbacks
        for callback in self.callbacks:
            callback.build(self.model, self.mutator, self)

        # resume
        self.start_epoch = 0
        self.resume()

        # finetune
        # TODO: improve robustness; optimizer resume currently has a bug
        # if self.cfg.model.finetune:
        #     self.logger.info("Freezing params of conv part ...")
        #     for name, param in self.model.named_parameters():
        #         if 'dense' not in name:
        #             param.requires_grad = False

        # dataparallel
        if len(self.cfg.trainer.device_ids) > 1:
            device_ids = self.cfg.trainer.device_ids
            num_gpus_available = torch.cuda.device_count()
            assert num_gpus_available >= len(device_ids), \
                "requested {} device(s) but only {} available".format(
                    len(device_ids), num_gpus_available)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)
            if self.kd_model:
                self.kd_model = torch.nn.DataParallel(self.kd_model,
                                                      device_ids=device_ids)

        if test:
            meters = self.test_one_epoch(-1, self.test_loader)
            self.logger.info(f"Final test metrics= {meters}")
            return meters

        # start training
        for epoch in range(self.start_epoch, self.cfg.evaluator.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            self.logger.info("Epoch %d Training", epoch)
            self.train_one_epoch(epoch, self.train_loader)

            if validate:
                self.logger.info("Epoch %d Validating", epoch)
                self.valid_one_epoch(epoch, self.test_loader)

            self.lr_scheduler.step()

            self.cur_meters = getattr(self, 'valid_meters', self.train_meters)
            for callback in self.callbacks:
                if isinstance(callback, CheckpointCallback):
                    callback.update_best_metric(
                        self.cur_meters.meters['save_metric'].avg)
                callback.on_epoch_end(epoch)

        self.logger.info("Final best Prec@1 = {:.4%}".format(self.best_metric))

    def train_one_epoch(self, epoch, dataloader):
        config = self.cfg
        self.train_meters = AverageMeterGroup()

        cur_lr = self.optimizer.param_groups[0]["lr"]
        self.logger.info("Epoch %d LR %.6f", epoch, cur_lr)
        if self.enable_writter:
            self.writter.add_scalar("lr", cur_lr, global_step=epoch)

        self.model.train()

        for step, (x, y) in enumerate(dataloader):
            if self.debug and step > 1:
                break
            for callback in self.callbacks:
                callback.on_batch_begin(epoch)
            x, y = x.to(self.device,
                        non_blocking=True), y.to(self.device,
                                                 non_blocking=True)
            bs = x.size(0)
            # mixup data
            if config.mixup.enable:
                x, y_a, y_b, lam = mixup_data(x, y, config.mixup.alpha)
                mixup_y = [y_a, y_b, lam]

            # forward
            logits = self.model(x)

            # loss
            if isinstance(logits, tuple):
                logits, aux_logits = logits
                if config.mixup.enable:
                    aux_loss = mixup_loss_fn(self.loss_fn, aux_logits,
                                             *mixup_y)
                else:
                    aux_loss = self.loss_fn(aux_logits, y)
            else:
                aux_loss = 0.
            if config.mixup.enable:
                loss = mixup_loss_fn(self.loss_fn, logits, *mixup_y)
            else:
                loss = self.loss_fn(logits, y)
            if config.model.aux_weight > 0:
                loss += config.model.aux_weight * aux_loss
            if self.kd_model:
                teacher_output = self.kd_model(x)
                # weight the hard-label loss and add the distillation term
                loss = (1 - config.kd.loss.alpha) * loss + loss_fn_kd(
                    logits, teacher_output, self.cfg.kd.loss)

            # backward
            loss.backward()
            # gradient clipping
            # nn.utils.clip_grad_norm_(model.parameters(), 20)

            if (step + 1) % config.trainer.accumulate_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            # post-processing
            accuracy = metrics(logits, y,
                               topk=(1, 3))  # e.g. {'acc1':0.65, 'acc3':0.86}
            self.train_meters.update(accuracy)
            self.train_meters.update({'train_loss': loss.item()})
            if step % config.logger.log_frequency == 0 or step == len(
                    dataloader) - 1:
                self.logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} {}".format(
                        epoch + 1, config.trainer.num_epochs, step,
                        len(dataloader) - 1, self.train_meters))

            for callback in self.callbacks:
                callback.on_batch_end(epoch)

        if self.enable_writter:
            self.writter.add_scalar("loss/train",
                                    self.train_meters['train_loss'].avg,
                                    global_step=epoch)
            self.writter.add_scalar("acc1/train",
                                    self.train_meters['acc1'].avg,
                                    global_step=epoch)
            self.writter.add_scalar("acc3/train",
                                    self.train_meters['acc3'].avg,
                                    global_step=epoch)

        self.logger.info("Train: [{:3d}/{}] Final result {}".format(
            epoch + 1, config.trainer.num_epochs, self.train_meters))

        return self.train_meters
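
    # The `accumulate_steps` logic above emulates a larger effective batch by
    # stepping the optimizer only every N micro-batches. A common variant
    # (sketch, not what the code above does) also rescales the loss so the
    # accumulated gradient matches a single large batch:
    #
    #   loss = loss / config.trainer.accumulate_steps
    #   loss.backward()
    #   if (step + 1) % config.trainer.accumulate_steps == 0:
    #       self.optimizer.step()
    #       self.optimizer.zero_grad()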

    def valid_one_epoch(self, epoch, dataloader):
        config = self.cfg
        self.valid_meters = AverageMeterGroup()
        self.model.eval()
        y_true = []
        y_pred = []

        with torch.no_grad():
            for step, (X, y) in enumerate(dataloader):
                if self.debug and step > 1:
                    break
                X, y = X.to(self.device,
                            non_blocking=True), y.to(self.device,
                                                     non_blocking=True)
                bs = X.size(0)

                # forward
                logits = self.model(X)

                # loss
                if isinstance(logits, tuple):
                    logits, aux_logits = logits
                    aux_loss = self.loss_fn(aux_logits, y)
                else:
                    aux_loss = 0.
                loss = self.loss_fn(logits, y)
                if config.model.aux_weight > 0:
                    loss = loss + config.model.aux_weight * aux_loss

                # post-processing
                y_true.append(y.cpu().detach())
                y_pred.append(logits.cpu().detach())

                accuracy = metrics(logits, y, topk=(1, 3))
                self.valid_meters.update(accuracy)
                self.valid_meters.update({'valid_loss': loss.item()})
                if step % config.logger.log_frequency == 0 or step == len(
                        dataloader) - 1:
                    self.logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} {}".format(
                            epoch + 1, config.trainer.num_epochs, step,
                            len(dataloader) - 1, self.valid_meters))

        y_true = torch.cat(y_true)
        y_pred = torch.cat(y_pred)
        self.valid_report = parse_preds(
            np.array(y_true.detach().cpu().numpy()),
            np.array(y_pred.detach().cpu().numpy()))
        self.valid_report['acc1'] = self.valid_meters['acc1'].avg
        self.valid_report['epoch'] = epoch
        self.logger.info(self.valid_report['cls_report'])
        self.logger.info(self.valid_report['covid_auc'])
        if self.enable_writter and epoch > 0:
            self.writter.add_scalar("loss/valid",
                                    self.valid_meters['valid_loss'].avg,
                                    global_step=epoch)
            self.writter.add_scalar("acc1/valid",
                                    self.valid_meters['acc1'].avg,
                                    global_step=epoch)
            self.writter.add_scalar("acc3/valid",
                                    self.valid_meters['acc3'].avg,
                                    global_step=epoch)

        self.logger.info("Valid: [{:3d}/{}] Final result {}".format(
            epoch + 1, config.trainer.num_epochs, self.valid_meters))

        return self.valid_meters
        # if self.cfg.callback.checkpoint.mode: # the more the better, e.g. acc
        #     return self.valid_meters['acc1'].avg
        # else: # the less, the better, e.g. epe
        #     return self.valid_meters['valid_loss'].avg

    def test_one_epoch(self, epoch, dataloader):
        config = self.cfg
        self.valid_meters = AverageMeterGroup()
        self.model.eval()
        y_true = []
        y_pred = []
        with torch.no_grad():
            for step, (X, y) in enumerate(dataloader):
                if self.debug and step > 1:
                    break
                X, y = X.to(self.device,
                            non_blocking=True), y.to(self.device,
                                                     non_blocking=True)
                bs = X.size(0)

                # forward
                logits = self.model(X)

                # loss
                if isinstance(logits, tuple):
                    logits, aux_logits = logits
                    aux_loss = self.loss_fn(aux_logits, y)
                else:
                    aux_loss = 0.
                loss = self.loss_fn(logits, y)
                if config.model.aux_weight > 0:
                    loss = loss + config.model.aux_weight * aux_loss

                # post-processing
                y_true.append(y.cpu().detach())
                y_pred.append(logits.cpu().detach())

                accuracy = metrics(logits, y, topk=(1, 3))
                self.valid_meters.update(accuracy)
                self.valid_meters.update({'valid_loss': loss.item()})
                if step % config.logger.log_frequency == 0 or step == len(
                        dataloader) - 1:
                    self.logger.info(
                        "Test: [{:3d}/{}] Step {:03d}/{:03d} {}".format(
                            epoch + 1, config.trainer.num_epochs, step,
                            len(dataloader) - 1, self.valid_meters))

        y_true = torch.cat(y_true)
        y_pred = torch.cat(y_pred)
        self.valid_report = parse_preds(
            np.array(y_true.detach().cpu().numpy()),
            np.array(y_pred.detach().cpu().numpy()))
        self.valid_report['acc1'] = self.valid_meters['acc1'].avg
        self.valid_report['epoch'] = epoch
        self.logger.info(self.valid_report['cls_report'])
        self.logger.info(self.valid_report['covid_auc'])
        torch.save(
            self.valid_report,
            os.path.join(config.logger.path,
                         f'best_epoch{epoch}_valid_report.pth'))
        # if self.enable_writter  and epoch > 0:
        #     self.writter.add_scalar("loss/valid", self.valid_meters['valid_loss'].avg, global_step=epoch)
        #     self.writter.add_scalar("acc1/valid", self.valid_meters['acc1'].avg, global_step=epoch)
        #     self.writter.add_scalar("acc3/valid", self.valid_meters['acc3'].avg, global_step=epoch)

        self.logger.info("Test: [{:3d}/{}] Final result {}".format(
            epoch + 1, config.trainer.num_epochs, self.valid_meters))

        return self.valid_meters

    def resume(self, mode=True):
        self.best_metric = -999
        path = self.cfg.model.resume_path
        if path:
            assert os.path.exists(path), "{} does not exist".format(path)
            ckpt = torch.load(path)
            try:
                self.model.load_state_dict(ckpt['model_state_dict'])
            except RuntimeError:
                self.logger.info('Loading from DataParallel model...')
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in ckpt['model_state_dict'].items():
                    name = k[7:]  # strip the `module.` prefix added by DataParallel
                    new_state_dict[name] = v
                # load params
                self.model.load_state_dict(new_state_dict)
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.lr_scheduler.load_state_dict(ckpt['lr_scheduler_state_dict'])
            self.best_metric = ckpt['best_metric']
            self.start_epoch = ckpt['epoch'] + 1
            self.logger.info('Resuming training from epoch {}'.format(
                self.start_epoch))

        for callback in self.callbacks:
            if isinstance(callback, CheckpointCallback):
                if self.best_metric == -999:
                    self.best_metric = callback.best_metric
                else:
                    callback.best_metric = self.best_metric
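
    # resume() expects checkpoints saved with roughly this layout (sketch; the
    # exact path and any extra keys depend on CheckpointCallback):
    #
    #   torch.save({
    #       'model_state_dict': self.model.state_dict(),
    #       'optimizer_state_dict': self.optimizer.state_dict(),
    #       'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
    #       'epoch': epoch,
    #       'best_metric': self.best_metric,
    #   }, path)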

    def generate_callbacks(self):
        '''Build the default callbacks (checkpoint saving and CAM visualization).
        Return:
            a list of callbacks.
        '''
        self.ckpt_callback = CheckpointCallback(
            checkpoint_dir=self.cfg.logger.path,
            name='best_retrain.pth',
            mode=self.cfg.callback.checkpoint.mode)
        self.cam_callback = CAMCallback(self.cfg)
        callbacks = [self.ckpt_callback, self.cam_callback]
        return callbacks

    def model_size(self, name='size'):
        assert name in ['size', 'flops']
        size = self.cfg.input.size
        if self.cfg.dataset.is_3d:
            input_size = (1, 1, self.cfg.dataset.slice_num, *size)
        else:
            input_size = (1, 3, *size)
        return flops_size_counter(self.model, input_size)[name]
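

# A minimal driver for DefaultEvaluator (sketch; building `cfg` and choosing
# between the search and retrain workflows depends on the surrounding project):
#
#   evaluator = DefaultEvaluator(cfg)
#   best = evaluator.compare()               # rank candidate architectures
#   evaluator.run(best['arc'], test=False)   # retrain the best one from scratch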