def train_and_eval(tag,
                   dataroot,
                   test_ratio=0.0,
                   cv_fold=0,
                   reporter=None,
                   metric='last',
                   save_path=None,
                   only_eval=False,
                   local_rank=-1,
                   evaluation_interval=5):
    """Train the configured model and evaluate it on the valid/test loaders.

    Dataset, model, batch size, learning rate, optimizer, lr schedule, EMA
    and epoch count are all read from the global configuration ``C.get()``.

    Args:
        tag: Name of the tensorboard run (``./logs/<tag>/{train,valid,test}``).
            When falsy, or on non-master ranks, a dummy writer is used.
        dataroot: Forwarded to ``get_dataloaders``.
        test_ratio: Forwarded to ``get_dataloaders``.
        cv_fold: Forwarded to ``get_dataloaders`` as ``split_idx``.
        reporter: Optional callable receiving ``loss_valid``, ``top1_valid``,
            ``loss_test`` and ``top1_test`` as keyword arguments each time
            the result is updated. Defaults to a no-op.
        metric: ``'last'`` to record every evaluation, or the name of a split
            present in the per-epoch results (``'train'``, ``'valid'``,
            ``'test'``) whose top1 must improve for the result to be updated.
        save_path: Checkpoint path. The literal ``'test.pth'`` skips all
            checkpoint handling. If the path is set but missing, the
            ``resnet50-19c8e357.pth`` weights are downloaded and, for the
            ``cifar10`` dataset only, merged into the model and written to
            ``save_path``. An existing file is loaded afterwards.
        only_eval: Evaluate the loaded model once and return. Forced to
            ``True`` when the checkpoint already reached the configured
            epoch count, and to ``False`` when no checkpoint could be loaded.
        local_rank: ``>= 0`` enables ``torch.distributed`` (NCCL, env://) and
            selects the CUDA device; a negative value means single process.
        evaluation_interval: Evaluate every this many epochs (and always at
            the last epoch).

    Returns:
        OrderedDict with ``'<loss|top1|top5>_<train|valid|test>'`` and
        ``'epoch'`` entries for the recorded evaluation, plus ``'top1_test'``.

    Raises:
        ValueError: Unknown optimizer type or lr scheduler type.
        Exception: The training loss became NaN.
    """
    distributed = local_rank >= 0
    if distributed:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=int(os.environ['WORLD_SIZE']))
        device = torch.device('cuda', local_rank)
        torch.cuda.set_device(device)

        # The learning rate is scaled by the number of workers.
        C.get()['lr'] *= dist.get_world_size()
        logger.info(
            f'local batch={C.get()["batch"]} world_size={dist.get_world_size()} ----> total batch={C.get()["batch"] * dist.get_world_size()}'
        )

    is_master = local_rank < 0 or dist.get_rank() == 0
    if is_master:
        add_filehandler(logger, 'master' + '.log')

    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'],
        C.get()['batch'],
        dataroot,
        test_ratio,
        split_idx=cv_fold,
        multinode=distributed)

    # create a model & an optimizer
    model = get_model(C.get()['model'],
                      num_class(C.get()['dataset']),
                      local_rank=local_rank)
    # Separate copy of the network used only to evaluate the EMA weights.
    model_ema = get_model(C.get()['model'],
                          num_class(C.get()['dataset']),
                          local_rank=-1)
    model_ema.eval()

    # criterion_ce is always the plain label-smoothing loss (used for
    # evaluation); criterion is swapped for the mixup variant when enabled.
    criterion_ce = criterion = CrossEntropyLabelSmooth(
        num_class(C.get()['dataset']),
        C.get().conf.get('lb_smooth', 0))
    if C.get().conf.get('mixup', 0.0) > 0.0:
        criterion = CrossEntropyMixUpLabelSmooth(
            num_class(C.get()['dataset']),
            C.get().conf.get('lb_smooth', 0))
    # weight_decay is 0 here; the configured decay is handed to run_epoch
    # through its ``wd`` argument instead.
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()['lr'],
            momentum=C.get()['optimizer'].get('momentum', 0.9),
            weight_decay=0.0,
            nesterov=C.get()['optimizer'].get('nesterov', True))
    elif C.get()['optimizer']['type'] == 'rmsprop':
        optimizer = RMSpropTF(model.parameters(),
                              lr=C.get()['lr'],
                              weight_decay=0.0,
                              alpha=0.9,
                              momentum=0.9,
                              eps=0.001)
    else:
        raise ValueError('invalid optimizer type=%s' %
                         C.get()['optimizer']['type'])

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    elif lr_scheduler_type == 'efficientnet':
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda x: 0.97**int(
                (x + C.get()['lr_schedule']['warmup']['epoch']) / 2.4))
    else:
        raise ValueError('invalid lr_schduler=%s' % lr_scheduler_type)

    if C.get()['lr_schedule'].get(
            'warmup', None) and C.get()['lr_schedule']['warmup']['epoch'] > 0:
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler)

    if not tag or not is_master:
        from FastAutoAugment.metrics import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [
        SummaryWriter(log_dir='./logs/%s/%s' % (tag, x))
        for x in ['train', 'valid', 'test']
    ]

    # The EMA shadow weights live on the master only, but every rank needs to
    # know whether EMA is enabled so that all of them join the sync below.
    ema_enabled = C.get()['optimizer']['ema'] > 0.0
    if ema_enabled and is_master:
        # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4?u=ildoonet
        ema = EMA(C.get()['optimizer']['ema'])
    else:
        ema = None

    result = OrderedDict()
    epoch_start = 1
    #TODO: change only eval=False when without save_path ??
    if save_path != 'test.pth':  # and is_master: --> should load all data(not able to be broadcasted)
        if save_path and not os.path.exists(save_path):
            import torch.utils.model_zoo as model_zoo
            data = model_zoo.load_url(
                'https://download.pytorch.org/models/resnet50-19c8e357.pth',
                model_dir=os.path.join(os.getcwd(), 'FastAutoAugment/models'))
            if C.get()['dataset'] == 'cifar10':
                # Drop the classifier head of the downloaded weights and keep
                # the model's own fc layer.
                data.pop('fc.weight')
                data.pop('fc.bias')
                model_dict = model.state_dict()
                model_dict.update(data)
                model.load_state_dict(model_dict)
                torch.save(model_dict, save_path)

        # Only load when there really is a file: save_path may be None (the
        # default) or may still be missing for datasets other than cifar10.
        if save_path and os.path.exists(save_path):
            logger.info('%s file found. loading...' % save_path)
            data = torch.load(save_path)
            key = 'model' if 'model' in data else 'state_dict'

            if 'epoch' not in data:
                # Bare state dict (e.g. the pretrained weights saved above).
                model.load_state_dict(data)
            else:
                logger.info('checkpoint epoch@%d' % data['epoch'])
                # Add/strip the 'module.' prefix so the keys match the
                # (un)wrapped model.
                if not isinstance(model,
                                  (DataParallel, DistributedDataParallel)):
                    model.load_state_dict({
                        k.replace('module.', ''): v
                        for k, v in data[key].items()
                    })
                else:
                    model.load_state_dict({
                        k if 'module.' in k else 'module.' + k: v
                        for k, v in data[key].items()
                    })
                logger.info('optimizer.load_state_dict+')
                optimizer.load_state_dict(data['optimizer'])
                if data['epoch'] < C.get()['epoch']:
                    epoch_start = data['epoch']
                else:
                    only_eval = True
                if ema is not None:
                    ema.shadow = data.get('ema', {}) if isinstance(
                        data.get('ema', {}), dict) else data['ema'].state_dict()
            del data
        else:
            logger.info('"%s" file not found. skip to pretrain weights...' %
                        save_path)
            if only_eval:
                logger.warning(
                    'model checkpoint not found. only-evaluation mode is off.')
            only_eval = False

    if distributed:
        # Make every rank start from rank 0's weights.
        for x in model.state_dict().values():
            dist.broadcast(x, 0)
        logger.info(
            f'multinode init. local_rank={dist.get_rank()} is_master={is_master}'
        )
        torch.cuda.synchronize()

    tqdm_disabled = bool(os.environ.get(
        'TASK_NAME', '')) and local_rank != 0  # KakaoBrain Environment

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model,
                                trainloader,
                                criterion,
                                None,
                                desc_default='train',
                                epoch=0,
                                writer=writers[0],
                                is_master=is_master)

        with torch.no_grad():
            rs['valid'] = run_epoch(model,
                                    validloader,
                                    criterion,
                                    None,
                                    desc_default='valid',
                                    epoch=0,
                                    writer=writers[1],
                                    is_master=is_master)
            rs['test'] = run_epoch(model,
                                   testloader_,
                                   criterion,
                                   None,
                                   desc_default='*test',
                                   epoch=0,
                                   writer=writers[2],
                                   is_master=is_master)
            if ema is not None and len(ema) > 0:
                # EMA results replace the plain-model valid/test results.
                model_ema.load_state_dict({
                    k.replace('module.', ''): v
                    for k, v in ema.state_dict().items()
                })
                rs['valid'] = run_epoch(model_ema,
                                        validloader,
                                        criterion_ce,
                                        None,
                                        desc_default='valid(EMA)',
                                        epoch=0,
                                        writer=writers[1],
                                        verbose=is_master,
                                        tqdm_disabled=tqdm_disabled)
                rs['test'] = run_epoch(model_ema,
                                       testloader_,
                                       criterion_ce,
                                       None,
                                       desc_default='*test(EMA)',
                                       epoch=0,
                                       writer=writers[2],
                                       verbose=is_master,
                                       tqdm_disabled=tqdm_disabled)
        for key, setname in itertools.product(['loss', 'top1', 'top5'],
                                              ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        if distributed:
            trainsampler.set_epoch(epoch)

        model.train()
        rs = dict()
        rs['train'] = run_epoch(model,
                                trainloader,
                                criterion,
                                optimizer,
                                desc_default='train',
                                epoch=epoch,
                                writer=writers[0],
                                verbose=(is_master and local_rank <= 0),
                                scheduler=scheduler,
                                ema=ema,
                                wd=C.get()['optimizer']['decay'],
                                tqdm_disabled=tqdm_disabled)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        # Periodically overwrite the live weights with the EMA weights. The
        # decision depends on config only, so that in distributed mode every
        # rank takes part in the broadcast (a collective must be entered by
        # all processes), and no torch.distributed call is made when no
        # process group exists.
        if ema_enabled:
            ema_interval = C.get()['optimizer']['ema_interval']
            if ema_interval > 0 and epoch % ema_interval == 0:
                rank = dist.get_rank() if distributed else 0
                logger.info(f'ema synced+ rank={rank}')
                if ema is not None:
                    model.load_state_dict(ema.state_dict())
                if distributed:
                    for x in model.state_dict().values():
                        dist.broadcast(x, 0)
                    torch.cuda.synchronize()
                logger.info(f'ema synced- rank={rank}')

        if is_master and (epoch % evaluation_interval == 0
                          or epoch == max_epoch):
            with torch.no_grad():
                rs['valid'] = run_epoch(model,
                                        validloader,
                                        criterion_ce,
                                        None,
                                        desc_default='valid',
                                        epoch=epoch,
                                        writer=writers[1],
                                        verbose=is_master,
                                        tqdm_disabled=tqdm_disabled)
                rs['test'] = run_epoch(model,
                                       testloader_,
                                       criterion_ce,
                                       None,
                                       desc_default='*test',
                                       epoch=epoch,
                                       writer=writers[2],
                                       verbose=is_master,
                                       tqdm_disabled=tqdm_disabled)

                if ema is not None:
                    # EMA results replace the plain-model valid/test results.
                    model_ema.load_state_dict({
                        k.replace('module.', ''): v
                        for k, v in ema.state_dict().items()
                    })
                    rs['valid'] = run_epoch(model_ema,
                                            validloader,
                                            criterion_ce,
                                            None,
                                            desc_default='valid(EMA)',
                                            epoch=epoch,
                                            writer=writers[1],
                                            verbose=is_master,
                                            tqdm_disabled=tqdm_disabled)
                    rs['test'] = run_epoch(model_ema,
                                           testloader_,
                                           criterion_ce,
                                           None,
                                           desc_default='*test(EMA)',
                                           epoch=epoch,
                                           writer=writers[2],
                                           verbose=is_master,
                                           tqdm_disabled=tqdm_disabled)

            logger.info(
                f'epoch={epoch} '
                f'[train] loss={rs["train"]["loss"]:.4f} top1={rs["train"]["top1"]:.4f} '
                f'[valid] loss={rs["valid"]["loss"]:.4f} top1={rs["valid"]["top1"]:.4f} '
                f'[test] loss={rs["test"]["loss"]:.4f} top1={rs["test"]["top1"]:.4f} '
            )

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(
                    ['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'],
                                      epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'],
                                      epoch)

                reporter(loss_valid=rs['valid']['loss'],
                         top1_valid=rs['valid']['top1'],
                         loss_test=rs['test']['loss'],
                         top1_test=rs['test']['top1'])

                # save checkpoint
                if is_master and save_path:
                    logger.info('save model@%d to %s, err=%.4f' %
                                (epoch, save_path, 1 - best_top1))
                    torch.save(
                        {
                            'epoch': epoch,
                            'log': {
                                'train': rs['train'].get_dict(),
                                'valid': rs['valid'].get_dict(),
                                'test': rs['test'].get_dict(),
                            },
                            'optimizer': optimizer.state_dict(),
                            'model': model.state_dict(),
                            'ema':
                            ema.state_dict() if ema is not None else None,
                        }, save_path)

    del model

    # With metric='last' best_top1 is never updated (it stays 0), so it must
    # not overwrite the test accuracy recorded by the last evaluation. The
    # key is still filled in when nothing was recorded (e.g. non-master rank).
    if metric != 'last' or 'top1_test' not in result:
        result['top1_test'] = best_top1
    return result
# Example #2
# 0
def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metric='last', save_path=None, only_eval=False, horovod=False, evaluation_interval=5):
    """Train the configured model (optionally with horovod) and evaluate it.

    Dataset, model, batch size, learning rate, optimizer, lr schedule and
    epoch count are read from the global configuration ``C.get()``.

    Args:
        tag: Name of the tensorboard run (``./logs/<tag>/{train,valid,test}``).
            When falsy, or on non-master ranks, a dummy writer is used.
        dataroot: Forwarded to ``get_dataloaders``.
        test_ratio: Forwarded to ``get_dataloaders``.
        cv_fold: Forwarded to ``get_dataloaders`` as ``split_idx``.
        reporter: Optional callable receiving ``loss_valid``, ``top1_valid``,
            ``loss_test`` and ``top1_test`` as keyword arguments each time
            the result is updated. Defaults to a no-op.
        metric: ``'last'`` to record every evaluation, or the name of a split
            (``'train'``, ``'valid'``, ``'test'``) whose top1 must improve
            for the result to be updated.
        save_path: Checkpoint path; loaded when the file exists and written
            by the master whenever the result is updated.
        only_eval: Evaluate the loaded model once and return. Forced to
            ``True`` when the checkpoint already reached the configured
            epoch count, and to ``False`` when no checkpoint was found.
        horovod: Initialize horovod, wrap the optimizer in a
            ``DistributedOptimizer`` and broadcast the initial state from
            rank 0. The model is built without data-parallel wrapping then.
        evaluation_interval: Evaluate every this many epochs (and always at
            the last epoch). Defaults to 5.

    Returns:
        OrderedDict with ``'<loss|top1|top5>_<train|valid|test>'`` and
        ``'epoch'`` entries for the recorded evaluation, plus ``'top1_test'``.

    Raises:
        ValueError: Unknown optimizer type or lr scheduler type.
        Exception: The training loss became NaN.
    """
    if horovod:
        import horovod.torch as hvd
        hvd.init()
        device = torch.device('cuda', hvd.local_rank())
        torch.cuda.set_device(device)

    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold, horovod=horovod)

    # create a model & an optimizer
    model = get_model(C.get()['model'], num_class(C.get()['dataset']), data_parallel=(not horovod))

    criterion = nn.CrossEntropyLoss()
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()['lr'],
            momentum=C.get()['optimizer'].get('momentum', 0.9),
            weight_decay=C.get()['optimizer']['decay'],
            nesterov=C.get()['optimizer']['nesterov']
        )
    else:
        raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])

    is_master = True
    if horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        optimizer._requires_update = set()  # issue : https://github.com/horovod/horovod/issues/1099
        # Start every worker from rank 0's parameters and optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        if hvd.rank() != 0:
            is_master = False
    logger.debug('is_master=%s' % is_master)

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    elif lr_scheduler_type == 'pyramid':
        scheduler = adjust_learning_rate_pyramid(optimizer, C.get()['epoch'])
    else:
        raise ValueError('invalid lr_schduler=%s' % lr_scheduler_type)

    if C.get()['lr_schedule'].get('warmup', None):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler
        )

    if not tag or not is_master:
        from FastAutoAugment.metrics import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']]

    result = OrderedDict()
    epoch_start = 1
    if save_path and os.path.exists(save_path):
        logger.info('%s file found. loading...' % save_path)
        data = torch.load(save_path)
        if 'model' in data:
            logger.info('checkpoint epoch@%d' % data['epoch'])
            # Add/strip the 'module.' prefix so the keys match the
            # (un)wrapped model.
            if not isinstance(model, DataParallel):
                model.load_state_dict({k.replace('module.', ''): v for k, v in data['model'].items()})
            else:
                model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data['model'].items()})
            optimizer.load_state_dict(data['optimizer'])
            if data['epoch'] < C.get()['epoch']:
                epoch_start = data['epoch']
            else:
                only_eval = True
        else:
            # Bare state dict without training metadata.
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('"%s" file not found. skip to pretrain weights...' % save_path)
        if only_eval:
            logger.warning('model checkpoint not found. only-evaluation mode is off.')
        only_eval = False

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0])
        rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1])
        rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        if horovod:
            trainsampler.set_epoch(epoch)

        model.train()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, optimizer, desc_default='train', epoch=epoch, writer=writers[0], verbose=is_master, scheduler=scheduler)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        if epoch % evaluation_interval == 0 or epoch == max_epoch:
            rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=is_master)
            rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=is_master)

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)

                reporter(
                    loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'],
                    loss_test=rs['test']['loss'], top1_test=rs['test']['top1']
                )

                # save checkpoint
                if is_master and save_path:
                    logger.info('save model@%d to %s' % (epoch, save_path))
                    torch.save({
                        'epoch': epoch,
                        'log': {
                            'train': rs['train'].get_dict(),
                            'valid': rs['valid'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }, save_path)
    del model

    # With metric='last' best_top1 is never updated (it stays 0), so it must
    # not overwrite the test accuracy recorded by the last evaluation. The
    # key is still filled in when no evaluation was recorded.
    if metric != 'last' or 'top1_test' not in result:
        result['top1_test'] = best_top1
    return result
# Example #3
# 0
def train_and_eval(
    args_save,
    tag,
    dataroot,
    test_ratio=0.0,
    cv_fold=0,
    reporter=None,
    metric="last",
    save_path=None,
    only_eval=False,
    local_rank=-1,
    evaluation_interval=5,
    get_dataloaders=None,
    num_class=None,
    get_model=None,
):
    assert get_dataloaders is not None
    assert num_class is not None
    assert get_model is not None
    total_batch = C.get()["batch"]
    if local_rank >= 0:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=int(os.environ["WORLD_SIZE"]),
        )
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(device)

        C.get()["lr"] *= dist.get_world_size()
        logger.info(
            f'local batch={C.get()["batch"]} world_size={dist.get_world_size()} ----> total batch={C.get()["batch"] * dist.get_world_size()}'
        )
        total_batch = C.get()["batch"] * dist.get_world_size()

    is_master = local_rank < 0 or dist.get_rank() == 0
    if is_master:
        add_filehandler(logger, args_save + ".log")

    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()["epoch"]
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()["dataset"],
        C.get()["batch"],
        dataroot,
        test_ratio,
        split_idx=cv_fold,
        multinode=(local_rank >= 0),
    )

    # create a model & an optimizer
    model = get_model(
        C.get()["model"], num_class(C.get()["dataset"]), local_rank=local_rank
    )
    model_ema = get_model(
        C.get()["model"], num_class(C.get()["dataset"]), local_rank=-1
    )
    model_ema.eval()

    criterion_ce = criterion = CrossEntropyLabelSmooth(
        num_class(C.get()["dataset"]), C.get().conf.get("lb_smooth", 0)
    )
    if C.get().conf.get("mixup", 0.0) > 0.0:
        criterion = CrossEntropyMixUpLabelSmooth(
            num_class(C.get()["dataset"]), C.get().conf.get("lb_smooth", 0)
        )
    if C.get()["optimizer"]["type"] == "sgd":
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()["lr"],
            momentum=C.get()["optimizer"].get("momentum", 0.9),
            weight_decay=0.0,
            nesterov=C.get()["optimizer"].get("nesterov", True),
        )
    elif C.get()["optimizer"]["type"] == "rmsprop":
        optimizer = RMSpropTF(
            model.parameters(),
            lr=C.get()["lr"],
            weight_decay=0.0,
            alpha=0.9,
            momentum=0.9,
            eps=0.001,
        )
    else:
        raise ValueError("invalid optimizer type=%s" % C.get()["optimizer"]["type"])

    lr_scheduler_type = C.get()["lr_schedule"].get("type", "cosine")
    if lr_scheduler_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()["epoch"], eta_min=0.0
        )
    elif lr_scheduler_type == "resnet":
        scheduler = adjust_learning_rate_resnet(optimizer)
    elif lr_scheduler_type == "efficientnet":
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda x: 0.97
            ** int((x + C.get()["lr_schedule"]["warmup"]["epoch"]) / 2.4),
        )
    else:
        raise ValueError("invalid lr_schduler=%s" % lr_scheduler_type)

    if (
        C.get()["lr_schedule"].get("warmup", None)
        and C.get()["lr_schedule"]["warmup"]["epoch"] > 0
    ):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()["lr_schedule"]["warmup"]["multiplier"],
            total_epoch=C.get()["lr_schedule"]["warmup"]["epoch"],
            after_scheduler=scheduler,
        )

    if not tag or not is_master:
        from FastAutoAugment.metrics import SummaryWriterDummy as SummaryWriter

        logger.warning("tag not provided, no tensorboard log.")
    else:
        from tensorboardX import SummaryWriter
    writers = [
        SummaryWriter(log_dir="./logs/%s/%s" % (tag, x))
        for x in ["train", "valid", "test"]
    ]

    if C.get()["optimizer"]["ema"] > 0.0 and is_master:
        # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4?u=ildoonet
        ema = EMA(C.get()["optimizer"]["ema"])
    else:
        ema = None

    result = OrderedDict()
    epoch_start = 1
    if (
        save_path != "test.pth"
    ):  # and is_master: --> should load all data(not able to be broadcasted)
        if save_path and os.path.exists(save_path):
            logger.info("%s file found. loading..." % save_path)
            data = torch.load(save_path)
            key = "model" if "model" in data else "state_dict"

            if "epoch" not in data:
                model.load_state_dict(data)
            else:
                logger.info("checkpoint epoch@%d" % data["epoch"])
                if not isinstance(model, (DataParallel, DistributedDataParallel)):
                    model.load_state_dict(
                        {k.replace("module.", ""): v for k, v in data[key].items()}
                    )
                else:
                    model.load_state_dict(
                        {
                            k if "module." in k else "module." + k: v
                            for k, v in data[key].items()
                        }
                    )
                logger.info("optimizer.load_state_dict+")
                optimizer.load_state_dict(data["optimizer"])
                if data["epoch"] < C.get()["epoch"]:
                    epoch_start = data["epoch"]
                else:
                    only_eval = True
                if ema is not None:
                    ema.shadow = (
                        data.get("ema", {})
                        if isinstance(data.get("ema", {}), dict)
                        else data["ema"].state_dict()
                    )
            del data
        else:
            logger.info('"%s" file not found. skip to pretrain weights...' % save_path)
            if only_eval:
                logger.warning(
                    "model checkpoint not found. only-evaluation mode is off."
                )
            only_eval = False

    if local_rank >= 0:
        for name, x in model.state_dict().items():
            dist.broadcast(x, 0)
        logger.info(
            f"multinode init. local_rank={dist.get_rank()} is_master={is_master}"
        )
        torch.cuda.synchronize()

    tqdm_disabled = (
        bool(os.environ.get("TASK_NAME", "")) and local_rank != 0
    )  # KakaoBrain Environment

    if only_eval:
        logger.info("evaluation only+")
        model.eval()
        rs = dict()
        rs["train"] = run_epoch(
            model,
            trainloader,
            criterion,
            None,
            desc_default="train",
            epoch=0,
            writer=writers[0],
            is_master=is_master,
        )

        with torch.no_grad():
            rs["valid"] = run_epoch(
                model,
                validloader,
                criterion,
                None,
                desc_default="valid",
                epoch=0,
                writer=writers[1],
                is_master=is_master,
            )
            rs["test"] = run_epoch(
                model,
                testloader_,
                criterion,
                None,
                desc_default="*test",
                epoch=0,
                writer=writers[2],
                is_master=is_master,
            )
            if ema is not None and len(ema) > 0:
                model_ema.load_state_dict(
                    {k.replace("module.", ""): v for k, v in ema.state_dict().items()}
                )
                rs["valid"] = run_epoch(
                    model_ema,
                    validloader,
                    criterion_ce,
                    None,
                    desc_default="valid(EMA)",
                    epoch=0,
                    writer=writers[1],
                    verbose=is_master,
                    tqdm_disabled=tqdm_disabled,
                )
                rs["test"] = run_epoch(
                    model_ema,
                    testloader_,
                    criterion_ce,
                    None,
                    desc_default="*test(EMA)",
                    epoch=0,
                    writer=writers[2],
                    verbose=is_master,
                    tqdm_disabled=tqdm_disabled,
                )
        for key, setname in itertools.product(
            ["loss", "top1", "top5"], ["train", "valid", "test"]
        ):
            if setname not in rs:
                continue
            result["%s_%s" % (key, setname)] = rs[setname][key]
        result["epoch"] = 0
        return result

    # train loop
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        if local_rank >= 0:
            trainsampler.set_epoch(epoch)

        model.train()
        rs = dict()
        rs["train"] = run_epoch(
            model,
            trainloader,
            criterion,
            optimizer,
            desc_default="train",
            epoch=epoch,
            writer=writers[0],
            verbose=(is_master and local_rank <= 0),
            scheduler=scheduler,
            ema=ema,
            wd=C.get()["optimizer"]["decay"],
            tqdm_disabled=tqdm_disabled,
        )
        model.eval()

        if math.isnan(rs["train"]["loss"]):
            raise Exception("train loss is NaN.")

        if (
            ema is not None
            and C.get()["optimizer"]["ema_interval"] > 0
            and epoch % C.get()["optimizer"]["ema_interval"] == 0
        ):
            logger.info(f"ema synced+ rank={dist.get_rank()}")
            if ema is not None:
                model.load_state_dict(ema.state_dict())
            for name, x in model.state_dict().items():
                # print(name)
                dist.broadcast(x, 0)
            torch.cuda.synchronize()
            logger.info(f"ema synced- rank={dist.get_rank()}")

        if is_master and (epoch % evaluation_interval == 0 or epoch == max_epoch):
            with torch.no_grad():
                rs["valid"] = run_epoch(
                    model,
                    validloader,
                    criterion_ce,
                    None,
                    desc_default="valid",
                    epoch=epoch,
                    writer=writers[1],
                    verbose=is_master,
                    tqdm_disabled=tqdm_disabled,
                )
                rs["test"] = run_epoch(
                    model,
                    testloader_,
                    criterion_ce,
                    None,
                    desc_default="*test",
                    epoch=epoch,
                    writer=writers[2],
                    verbose=is_master,
                    tqdm_disabled=tqdm_disabled,
                )

                if ema is not None:
                    model_ema.load_state_dict(
                        {
                            k.replace("module.", ""): v
                            for k, v in ema.state_dict().items()
                        }
                    )
                    rs["valid"] = run_epoch(
                        model_ema,
                        validloader,
                        criterion_ce,
                        None,
                        desc_default="valid(EMA)",
                        epoch=epoch,
                        writer=writers[1],
                        verbose=is_master,
                        tqdm_disabled=tqdm_disabled,
                    )
                    rs["test"] = run_epoch(
                        model_ema,
                        testloader_,
                        criterion_ce,
                        None,
                        desc_default="*test(EMA)",
                        epoch=epoch,
                        writer=writers[2],
                        verbose=is_master,
                        tqdm_disabled=tqdm_disabled,
                    )

            logger.info(
                f"epoch={epoch} "
                f'[train] loss={rs["train"]["loss"]:.4f} top1={rs["train"]["top1"]:.4f} '
                f'[valid] loss={rs["valid"]["loss"]:.4f} top1={rs["valid"]["top1"]:.4f} '
                f'[test] loss={rs["test"]["loss"]:.4f} top1={rs["test"]["top1"]:.4f} '
            )

            if metric == "last" or rs[metric]["top1"] > best_top1:
                if metric != "last":
                    best_top1 = rs[metric]["top1"]
                for key, setname in itertools.product(
                    ["loss", "top1", "top5"], ["train", "valid", "test"]
                ):
                    result["%s_%s" % (key, setname)] = rs[setname][key]
                result["epoch"] = epoch

                writers[1].add_scalar("valid_top1/best", rs["valid"]["top1"], epoch)
                writers[2].add_scalar("test_top1/best", rs["test"]["top1"], epoch)

                reporter(
                    loss_valid=rs["valid"]["loss"],
                    top1_valid=rs["valid"]["top1"],
                    loss_test=rs["test"]["loss"],
                    top1_test=rs["test"]["top1"],
                )

                # save checkpoint
                if is_master and save_path:
                    logger.info(
                        "save model@%d to %s, err=%.4f"
                        % (epoch, save_path, 1 - best_top1)
                    )
                    torch.save(
                        {
                            "epoch": epoch,
                            "log": {
                                "train": rs["train"].get_dict(),
                                "valid": rs["valid"].get_dict(),
                                "test": rs["test"].get_dict(),
                            },
                            "optimizer": optimizer.state_dict(),
                            "model": model.state_dict(),
                            "ema": ema.state_dict() if ema is not None else None,
                        },
                        save_path,
                    )

    del model

    result["top1_test"] = best_top1
    return result
def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metric='last', save_path=None, only_eval=False, horovod=False):
    """Train the configured model and periodically evaluate it.

    All hyper-parameters (dataset, model, batch size, optimizer, learning-rate
    schedule, number of epochs) are read from the global config ``C.get()``.

    NOTE(review): this module also defines an earlier ``train_and_eval`` with a
    different signature (``local_rank``/``evaluation_interval``). Being defined
    later, this one replaces it at import time -- confirm that is intended.

    Args:
        tag: Sub-directory name for the tensorboard logs. A blank tag disables
            tensorboard logging (a dummy writer is used instead).
        dataroot: Dataset root, forwarded to ``get_dataloaders``.
        test_ratio: Forwarded to ``get_dataloaders``.
        cv_fold: Forwarded to ``get_dataloaders`` as ``split_idx``.
        reporter: Optional callable invoked with the keyword arguments
            ``loss_valid``, ``top1_valid``, ``loss_test`` and ``top1_test``
            every time ``result`` is updated. Defaults to a no-op.
        metric: ``'last'`` to record every evaluation, or the name of a split
            (``'train'``, ``'valid'`` or ``'test'``) whose loss has to improve
            for an evaluation to be recorded.
        save_path: Directory in which the master process stores the checkpoint
            after each evaluation. Nothing is saved when it is falsy.
        only_eval: If true, skip training and evaluate the model once.
        horovod: If true, initialise Horovod and train with its distributed
            optimizer.

    Returns:
        An ``OrderedDict`` with ``'{loss,top1,top5}_{train,valid,test}'``
        entries and the ``'epoch'`` they were measured at (``0`` for
        ``only_eval``).

    Raises:
        ValueError: If the configured optimizer type or learning-rate schedule
            is not supported.
        Exception: If the training loss becomes NaN.
    """
    if horovod:
        import horovod.torch as hvd
        hvd.init()
        device = torch.device('cuda', hvd.local_rank())
        torch.cuda.set_device(device)

    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'],
        C.get()['batch'],
        dataroot,
        test_ratio,
        split_idx=cv_fold,
        horovod=horovod,
    )

    # create a model & an optimizer
    model = get_model(
        C.get()['model'],
        num_class(C.get()['dataset']),
        data_parallel=(not horovod),
    )

    criterion = nn.CrossEntropyLoss()
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()['lr'],
            momentum=C.get()['optimizer'].get('momentum', 0.9),
            weight_decay=C.get()['optimizer']['decay'],
            nesterov=C.get()['optimizer']['nesterov'],
        )
    else:
        raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])

    is_master = True
    if horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        optimizer._requires_update = set()  # issue : https://github.com/horovod/horovod/issues/1099
        # Start every worker from rank 0's weights and optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        if hvd.rank() != 0:
            is_master = False
    logger.debug('is_master=%s' % is_master)

    warmup = C.get()['lr_schedule'].get('warmup', None)
    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        # The cosine schedule only covers the epochs left after warm-up.
        t_max = C.get()['epoch']
        if warmup:
            t_max -= warmup['epoch']
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    elif lr_scheduler_type == 'pyramid':
        scheduler = adjust_learning_rate_pyramid(optimizer, C.get()['epoch'])
    else:
        raise ValueError('invalid lr_schduler=%s' % lr_scheduler_type)

    if warmup:
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=warmup['multiplier'],
            total_epoch=warmup['epoch'],
            after_scheduler=scheduler,
        )

    tag_missing = not tag.strip()
    if tag_missing or not is_master:
        from FastAutoAugment.metrics import SummaryWriterDummy as SummaryWriter
        # Only warn when the tag really is missing; non-master workers use the
        # dummy writer by design and should not report a missing tag.
        if tag_missing:
            logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [SummaryWriter(log_dir='/app/results/logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']]

    result = OrderedDict()
    # TODO: resuming from an existing checkpoint in `save_path` is not
    # implemented in this function, so training always starts at epoch 1.
    epoch_start = 1

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        # No optimizer is passed, so gradients are never needed here.
        with torch.no_grad():
            rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0])
            rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1])
            rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
            result['%s_%s' % (key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # train loop
    best_valid_loss = float('inf')
    # Evaluate every 10 epochs on CIFAR datasets, every 30 otherwise.
    eval_interval = 10 if 'cifar' in C.get()['dataset'] else 30

    for epoch in range(epoch_start, max_epoch + 1):
        if horovod:
            trainsampler.set_epoch(epoch)
        model.train()
        rs = dict()
        rs['train'] = run_epoch(
            model,
            trainloader,
            criterion,
            optimizer,
            desc_default='train',
            epoch=epoch,
            writer=writers[0],
            verbose=is_master,
            scheduler=scheduler,
        )
        AugmentationPba.epoch += 1
        # NOTE(review): the scheduler is also handed to run_epoch above --
        # confirm it is not stepped twice per epoch.
        scheduler.step(epoch)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        if epoch % eval_interval == 0 or epoch == max_epoch:
            # Evaluation never back-propagates; disabling autograd avoids
            # building graphs that would only consume memory.
            with torch.no_grad():
                rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=is_master)
                rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=is_master)

            if metric == 'last' or rs[metric]['loss'] < best_valid_loss:    # TODO
                if metric != 'last':
                    best_valid_loss = rs[metric]['loss']
                for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['%s_%s' % (key, setname)] = rs[setname][key]
                result['epoch'] = epoch

                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)

                reporter(
                    loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'],
                    loss_test=rs['test']['loss'], top1_test=rs['test']['top1']
                )

            # save checkpoint (written after every evaluation, not only the best one)
            if is_master and save_path:
                model_name = C.get()['model']['type']
                if 'skip' in model_name:
                    # Encode alpha in the file name as its negative base-10 exponent.
                    alpha = int(np.log10(1/C.get()['alpha']))
                    filename = '{}/{}_last_epoch_alpha_{}.pth'.format(save_path, model_name, alpha)
                else:
                    filename = '{}/{}_last_epoch.pth'.format(save_path, model_name)
                logger.info('save model@%d to %s' % (epoch, filename))
                torch.save({
                    'epoch': epoch,
                    'log': {
                        'train': rs['train'].get_dict(),
                        'valid': rs['valid'].get_dict(),
                        'test': rs['test'].get_dict(),
                    },
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict()
                }, filename)

    del model

    return result