Ejemplo n.º 1
0
def train_model():
    """Train the model for ``cfg.OPTIM.MAX_EPOCH`` epochs.

    Builds the model, loss and optimizer. It then either resumes from the most
    recent checkpoint (when ``cfg.TRAIN.AUTO_RESUME`` is set and a checkpoint
    exists) or loads ``cfg.TRAIN.WEIGHTS`` as initial weights.

    Each epoch:
    - trains;
    - optionally recomputes precise BN statistics;
    - writes a checkpoint every ``cfg.TRAIN.CHECKPOINT_PERIOD`` epochs;
    - evaluates every ``cfg.TRAIN.EVAL_PERIOD`` epochs and after the final
      epoch.
    """
    setup_env()

    # Model, criterion and optimizer.
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)

    # Resume from the latest checkpoint, or start from initial weights.
    first_epoch = 0
    can_resume = cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint()
    if can_resume:
        resume_path = checkpoint.get_last_checkpoint()
        resumed_epoch = checkpoint.load_checkpoint(resume_path, model,
                                                   optimizer)
        logger.info("Loaded checkpoint from: {}".format(resume_path))
        first_epoch = resumed_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        weights_path = cfg.TRAIN.WEIGHTS
        checkpoint.load_checkpoint(weights_path, model)
        logger.info("Loaded initial weights from: {}".format(weights_path))

    # Data pipelines and their progress meters.
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))

    # Timing benchmark runs only on a fresh (non-resumed) run.
    fresh_run = first_epoch == 0
    if fresh_run and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)

    logger.info("Start epoch: {}".format(first_epoch + 1))
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                    epoch)
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)

        epochs_done = epoch + 1
        if epochs_done % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            saved_to = checkpoint.save_checkpoint(model, optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_to))

        periodic_eval = epochs_done % cfg.TRAIN.EVAL_PERIOD == 0
        if periodic_eval or epochs_done == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(test_loader, model, test_meter, epoch)

        # Release cached GPU memory between epochs.
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
Ejemplo n.º 2
0
def main():
    """Run a DARTS architecture search.

    Builds the search space, a ``DartsCNNController`` and its ``Architect``.
    It then calls ``train_epoch`` for ``cfg.OPTIM.MAX_EPOCH`` epochs with two
    optimizers:
    - an SGD optimizer over the controller weights, with a cosine-annealed LR;
    - an Adam optimizer over the controller alphas.

    A checkpoint is written every ``cfg.SEARCH.CHECKPOINT_PERIOD`` epochs.
    Every ``cfg.SEARCH.EVAL_PERIOD`` epochs, and after the last one, the
    controller is evaluated and its current genotype and alphas are logged.
    """
    setup_env()
    search_space = build_space()

    # Supernet controller and its architect.
    criterion = nn.CrossEntropyLoss().cuda()
    controller = DartsCNNController(search_space, criterion)
    controller.cuda()
    architect = Architect(controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)

    train_data, val_data = _construct_loader(cfg.SEARCH.DATASET,
                                             cfg.SEARCH.SPLIT,
                                             cfg.SEARCH.BATCH_SIZE)

    # Separate optimizers for network weights and architecture alphas.
    w_optim = torch.optim.SGD(controller.weights(), cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    alpha_optim = torch.optim.Adam(controller.alphas(), cfg.DARTS.ALPHA_LR,
                                   betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)

    train_meter = meters.TrainMeter(len(train_data))
    val_meter = meters.TestMeter(len(val_data))

    first_epoch = 0
    logger.info("Start epoch: {}".format(first_epoch + 1))
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        # The LR for this epoch is read before the scheduler is stepped.
        current_lr = scheduler.get_last_lr()[0]
        train_epoch(train_data, val_data, controller, architect, criterion,
                    w_optim, alpha_optim, current_lr, train_meter, epoch)

        finished = epoch + 1
        if finished % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            ckpt_path = checkpoint.save_checkpoint(controller, w_optim, epoch)
            logger.info("Wrote checkpoint to: {}".format(ckpt_path))
        scheduler.step()

        if (finished % cfg.SEARCH.EVAL_PERIOD == 0
                or finished == cfg.OPTIM.MAX_EPOCH):
            logger.info("Start testing")
            test_epoch(val_data, controller, val_meter, epoch,
                       tensorboard_writer=writer)
            header = "###############Optimal genotype at epoch: {}############"
            logger.info(header.format(epoch))
            logger.info(controller.genotype())
            logger.info(
                "########################################################")
            controller.print_alphas(logger)

        # Release cached GPU memory between epochs.
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
Ejemplo n.º 3
0
def main():
    """Run a PC-DARTS architecture search on CIFAR-10.

    Both loaders are built from the CIFAR-10 *training* set, which is
    downloaded to ``cfg.SEARCH.DATASET`` if missing. The first
    ``cfg.SEARCH.SPLIT[0]`` fraction of indices feeds the weight updates and
    the remainder feeds validation / architecture updates.

    Each epoch passes the weight optimizer (SGD, cosine-annealed LR) and the
    alpha optimizer (Adam) to ``train_epoch``. Checkpoints are written every
    ``cfg.SEARCH.CHECKPOINT_PERIOD`` epochs. The controller is evaluated every
    ``cfg.SEARCH.EVAL_PERIOD`` epochs and after the last one. The total
    wall-clock search time is logged in hours at the end.
    """
    tic = time.time()
    setup_env()
    search_space = build_space()

    # Controller and architect.
    loss_fun = nn.CrossEntropyLoss().cuda()
    darts_controller = PCDartsCNNController(search_space, loss_fun)
    darts_controller.cuda()
    architect = Architect(
        darts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)

    # Both splits come from the training set and both use train_transform
    # (the validation transform returned by the helper is not used).
    train_transform, _ = _data_transforms_cifar10()
    train_data = dset.CIFAR10(root=cfg.SEARCH.DATASET, train=True,
                              download=True, transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))

    train_ = torch.utils.data.DataLoader(
        train_data, batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    val_ = torch.utils.data.DataLoader(
        train_data, batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True, num_workers=2)

    # Weights optimizer
    w_optim = torch.optim.SGD(darts_controller.weights(), cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Alphas optimizer
    alpha_optim = torch.optim.Adam(darts_controller.alphas(),
                                   cfg.DARTS.ALPHA_LR, betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    start_epoch = 0
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Log the current architecture at the start of every epoch.
        logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
        logger.info(darts_controller.genotype())
        logger.info("########################################################")
        darts_controller.print_alphas(logger)

        # Read this epoch's LR *before* stepping the scheduler. Since
        # PyTorch 1.1, scheduler.step() must follow the optimizer updates.
        # Stepping first (as before) skipped the initial BASE_LR value, so
        # epoch 0 trained at an already-annealed LR.
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, darts_controller, architect, loss_fun,
                    w_optim, alpha_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                darts_controller, w_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, darts_controller, val_meter, cur_epoch, tensorboard_writer=writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(darts_controller.genotype())
            logger.info("########################################################")
            darts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()

    toc = time.time()
    logger.info("Search-time(GPUh): {}".format((toc - tic)/3600))
Ejemplo n.º 4
0
def train_model():
    """SNG search model training.

    Runs three phases that share one global epoch counter
    (``_over_all_epoch``), used for checkpoint indices and for the meters and
    writer:

    1. Warm-up (``cfg.OPTIM.WARMUP_EPOCHS``): train supernet weights on
       randomly sampled architectures.
    2. One-shot search (``cfg.OPTIM.MAX_EPOCH``): sample from the distribution
       optimizer, train and evaluate, then feed the top-1 result back into the
       distribution optimizer. Stops early if it reports ``training_finish``.
    3. Final epochs (``cfg.OPTIM.FINAL_EPOCH``): keep training the best
       architecture according to the distribution optimizer.

    For NAS-Bench spaces, the learned distribution is then evaluated through
    ``EvaluateNasbench``.

    Raises:
        NotImplementedError: if ``cfg.SPACE.NAME`` is not a supported space.
    """
    setup_env()
    # Load search space
    search_space = build_space()
    search_space.cuda()
    loss_fun = build_loss_fun().cuda()

    # Weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)

    # Build the distribution optimizer over categorical op choices.
    if cfg.SPACE.NAME in ['darts', 'nasbench301', 'proxyless', 'google', 'ofa']:
        # Every edge chooses among the same `num_ops` candidate ops.
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        # One categorical variable per categorical hyperparameter of the
        # NAS-Bench-1Shot1 configuration space.
        cs = search_space.search_space.get_configuration_space()
        category = [
            len(h.choices) for h in cs.get_hyperparameters()
            if isinstance(h,
                          ConfigSpace.hyperparameters.CategoricalHyperparameter)
        ]
        distribution_optimizer = sng_builder(category)
    else:
        raise NotImplementedError(
            "SNG search does not support space: {}".format(cfg.SPACE.NAME))

    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)

    lr_scheduler = lr_scheduler_builder(w_optim)
    all_timer = Timer()
    # Global epoch counter shared by all three phases.
    _over_all_epoch = 0

    def _maybe_save_checkpoint(epoch):
        """Save a checkpoint when ``epoch + 1`` is a multiple of the period."""
        if (epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

    # ===== Warm up training =====
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        _maybe_save_checkpoint(_over_all_epoch)

        lr = lr_scheduler.get_last_lr()[0]
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=cur_epoch)
            logger.info("Sampling: {}".format(one_hot_to_index(sample)))
            train_epoch(train_, val_, search_space, w_optim, lr,
                        _over_all_epoch, sample, loss_fun, warm_train_meter)
            test_epoch_with_sample(val_, search_space, warm_val_meter,
                                   _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            # Each edge gets a random permutation of all ops. Column i picks
            # one op per edge, so over `num_ops` rounds every op is trained
            # on every edge exactly once. Each round counts as one epoch.
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = np.array([
                random.sample(list(range(num_ops)), num_ops)
                for _ in range(total_edges)
            ])
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train_epoch(train_, val_, search_space, w_optim, lr,
                            _over_all_epoch, sample, loss_fun,
                            warm_train_meter)
                test_epoch_with_sample(val_, search_space, warm_val_meter,
                                       _over_all_epoch, sample, writer)
                _over_all_epoch += 1
    all_timer.toc()
    logger.info("end warm up training")

    # ===== Training procedure =====
    logger.info("start one-shot training")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
        _maybe_save_checkpoint(_over_all_epoch)

        # Some distribution optimizers signal convergence and stop early.
        if getattr(distribution_optimizer, 'training_finish', False):
            break
        lr = w_optim.param_groups[0]['lr']

        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=cur_epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("Sampling: {}".format(one_hot_to_index(sample)))
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        top1 = test_epoch_with_sample(val_, search_space, val_meter,
                                      _over_all_epoch, sample, writer)
        _over_all_epoch += 1

        lr_scheduler.step()
        # Feed the sampled architecture's top-1 back to the distribution.
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()

        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(cur_epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache(
            )  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    all_timer.toc()

    # ===== Final epoch =====
    # The LR is frozen at its value after the one-shot phase.
    lr = w_optim.param_groups[0]['lr']
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):
        _maybe_save_checkpoint(_over_all_epoch)

        if cfg.SPACE.NAME in ['darts', 'nasbench301']:
            genotype = search_space.genotype(
                distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        test_epoch_with_sample(val_, search_space, val_meter, _over_all_epoch,
                               sample, writer)
        # Advance the counter after the epoch, as in the other phases.
        # Incrementing first skipped one epoch index.
        _over_all_epoch += 1
    # Stop the timer so the final epochs are included in total_time.
    all_timer.toc()
    logger.info("Overall training time : {} hours".format(
        str((all_timer.total_time) / 3600.)))

    # Evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3",
            "nasbench201", "nasbench301"
    ]:
        logger.info("starting test using nasbench:{}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
Ejemplo n.º 5
0
def main():
    """Run a progressive DARTS (P-DARTS) search over three stages.

    Each stage rebuilds the supernet from the config, using:
    - ``add_layers[sp]`` extra layers;
    - ``add_width[sp]`` extra width;
    - op-dropout rate ``drop_rate[sp]``;
    - the candidate ops in ``basic_op``.

    For the first ``eps_no_archs[sp]`` epochs, ``train_epoch`` is called with
    ``train_arch=False`` (presumably weights only). After that it is called
    with ``train_arch=True`` while the op-dropout probability decays
    exponentially.

    At the end of an intermediate stage, ``controller.get_topk_op`` selects
    the ``num_to_keep[sp]`` ops that form the next stage's candidate set.
    After the last stage, the final genotype is logged and skip-connects are
    progressively removed down to 8, 7, ..., 0. The genotype is logged at
    each level where something was actually removed.
    """
    setup_env()
    loss_fun = nn.CrossEntropyLoss().cuda()
    # Load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)

    # Per-stage schedule (one entry per stage).
    num_to_keep = [5, 3, 1]
    eps_no_archs = [10, 10, 10]
    drop_rate = [0.1, 0.4, 0.7]
    add_layers = [0, 6, 12]
    add_width = [0, 0, 0]
    # Exponential decay rate of the op-dropout probability once
    # train_arch=True.
    scale_factor = 0.2
    last_stage = len(num_to_keep) - 1

    PRIMITIVES = cfg.SPACE.PRIMITIVES
    # Edge count per cell. This matches a DARTS-style cell where node i has
    # i + 2 inputs: sum(i + 2) = NODES * (NODES + 3) / 2.
    edges_num = (cfg.SPACE.NODES + 3) * cfg.SPACE.NODES // 2
    # Every edge of both cells (x2) starts with the full primitive list.
    basic_op = [PRIMITIVES] * (edges_num * 2)

    for sp in range(len(num_to_keep)):
        # Update the supernet config for this stage.
        cfg.defrost()
        cfg.SEARCH.add_layers = add_layers[sp]
        cfg.SEARCH.add_width = add_width[sp]
        cfg.SEARCH.dropout_rate = float(drop_rate[sp])
        cfg.SPACE.BASIC_OP = basic_op

        search_space = build_space()
        controller = PdartsCNNController(search_space, loss_fun)
        controller.cuda()
        architect = Architect(controller, cfg.OPTIM.MOMENTUM,
                              cfg.OPTIM.WEIGHT_DECAY)
        # Weights optimizer
        w_optim = torch.optim.SGD(controller.subnet_weights(),
                                  cfg.OPTIM.BASE_LR,
                                  momentum=cfg.OPTIM.MOMENTUM,
                                  weight_decay=cfg.OPTIM.WEIGHT_DECAY)
        # Alphas optimizer
        alpha_optim = torch.optim.Adam(
            controller.alphas(),
            cfg.DARTS.ALPHA_LR,
            betas=(0.5, 0.999),
            weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
        train_meter = meters.TrainMeter(len(train_))
        val_meter = meters.TestMeter(len(val_))
        start_epoch = 0
        # Perform the training loop
        logger.info("Start epoch: {}".format(start_epoch + 1))
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            logger.info("cur_epoch %d", cur_epoch)
            lr = lr_scheduler.get_last_lr()[0]
            train_arch = cur_epoch >= eps_no_archs[sp]
            if train_arch:
                # Exponential decay after architecture training starts.
                drop_p = float(drop_rate[sp]) * np.exp(
                    -(cur_epoch - eps_no_archs[sp]) * scale_factor)
            else:
                # Linear decay during the weights-only warm-up epochs.
                drop_p = (float(drop_rate[sp]) *
                          (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) /
                          cfg.OPTIM.MAX_EPOCH)
            controller.update_p(drop_p)
            train_epoch(train_,
                        val_,
                        controller,
                        architect,
                        loss_fun,
                        w_optim,
                        alpha_optim,
                        lr,
                        train_meter,
                        cur_epoch,
                        train_arch=train_arch)
            # Save a checkpoint
            if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
                checkpoint_file = checkpoint.save_checkpoint(
                    controller, w_optim, cur_epoch)
                logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
            lr_scheduler.step()
            # Evaluate the model during the last few epochs of each stage.
            next_epoch = cur_epoch + 1
            if next_epoch >= cfg.OPTIM.MAX_EPOCH - 5:
                logger.info("Start testing")
                test_epoch(val_,
                           controller,
                           val_meter,
                           cur_epoch,
                           tensorboard_writer=writer)
                logger.info(
                    "###############Optimal genotype at epoch: {}############".
                    format(cur_epoch))
                logger.info(controller.genotype())
                logger.info(
                    "########################################################")
                controller.print_alphas(logger)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache(
                )  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            gc.collect()
            logger.info("now top k primitive %s %s", num_to_keep[sp],
                        controller.get_topk_op(num_to_keep[sp]))

        if sp != last_stage:
            # Keep the top-k ops of this stage as the next stage's candidates.
            # Intermediate controllers are discarded, so do not prune them.
            basic_op = controller.get_topk_op(num_to_keep[sp])
            logger.info(
                "###############Genotype after stage: {}############".format(sp))
            logger.info(controller.genotype(final=True))
            logger.info(
                "########################################################")
            controller.print_alphas(logger)
            continue

        # Final stage only; this used to run after every stage and twice on
        # the last one.
        logger.info("###############final Optimal genotype############")
        logger.info(controller.genotype(final=True))
        logger.info(
            "########################################################")
        controller.print_alphas(logger)

        logger.info('Restricting skipconnect...')
        for max_sk in range(8, -1, -1):
            num_sk = controller.get_skip_number()
            if num_sk <= max_sk:
                continue
            while num_sk > max_sk:
                controller.delete_skip()
                num_sk = controller.get_skip_number()

            logger.info('Number of skip-connect: %d', max_sk)
            logger.info(controller.genotype(final=True))
Ejemplo n.º 6
0
def main():
    """Train an augmented network built from ``cfg.TRAIN.GENOTYPE``.

    Builds an ``AugmentCNN`` and trains it for ``cfg.OPTIM.MAX_EPOCH`` epochs
    with SGD and the configured LR scheduler. The drop-path probability ramps
    linearly with the epoch. The model is validated after every epoch, and
    the best (lowest) top-1 error is tracked and reported as Prec@1 at the
    end.
    """
    setup_env()

    loss_fun = build_loss_fun().cuda()
    # The auxiliary head is enabled only when its loss weight is positive.
    use_aux = cfg.TRAIN.AUX_WEIGHT > 0.

    # SEARCH.INPUT_CHANNEL is the image channel count (e.g. 3 for RGB) and
    # TRAIN.CHANNELS the initial width, set manually. IM_SIZE, INPUT_CHANNEL
    # and NUM_CLASSES should match the search phase.
    model = AugmentCNN(cfg.SEARCH.IM_SIZE, cfg.SEARCH.INPUT_CHANNEL,
                       cfg.TRAIN.CHANNELS, cfg.SEARCH.NUM_CLASSES,
                       cfg.TRAIN.LAYERS, use_aux, cfg.TRAIN.GENOTYPE)

    # TODO: multi-GPU support (nn.DataParallel); the model is not wrapped yet.
    model.cuda()
    # drop_path_prob lives on the underlying network. Unwrap only if a
    # data-parallel wrapper is actually present; the old
    # `cfg.NUM_GPUS > 1 -> model.module` check crashed on an unwrapped model.
    parallel_types = (torch.nn.DataParallel,
                      torch.nn.parallel.DistributedDataParallel)
    base_model = model.module if isinstance(model, parallel_types) else model

    # Weights optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                cfg.OPTIM.BASE_LR,
                                momentum=cfg.OPTIM.MOMENTUM,
                                weight_decay=cfg.OPTIM.WEIGHT_DECAY)

    # Get data loader
    [train_loader,
     valid_loader] = construct_loader(cfg.TRAIN.DATASET, cfg.TRAIN.SPLIT,
                                      cfg.TRAIN.BATCH_SIZE)

    lr_scheduler = lr_scheduler_builder(optimizer)

    # Lowest validation top-1 error seen so far. It must start at +inf;
    # starting at 0 made min() always return 0.
    best_top1err = float("inf")

    # TODO: DALI backend support; resume from checkpoint.
    train_meter = meters.TrainMeter(len(train_loader))
    valid_meter = meters.TestMeter(len(valid_loader))

    start_epoch = 0
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Drop-path probability grows linearly from 0 toward DROP_PATH_PROB.
        drop_prob = cfg.TRAIN.DROP_PATH_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH
        base_model.drop_path_prob(drop_prob)

        # Training
        train_epoch(train_loader, model, optimizer, loss_fun, cur_epoch,
                    train_meter)

        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        lr_scheduler.step()

        # Validation
        cur_step = (cur_epoch + 1) * len(train_loader)
        top1_err = valid_epoch(valid_loader, model, loss_fun, cur_epoch,
                               cur_step, valid_meter)
        logger.info("top1 error@epoch {}: {}".format(cur_epoch + 1, top1_err))
        best_top1err = min(best_top1err, top1_err)

    # NOTE(review): assumes top1_err is a percentage (0-100), so
    # 100 - err is already a percent. The old "{:.4%}" multiplied it by 100
    # again. Confirm against valid_epoch.
    logger.info("Final best Prec@1 = {:.4f}%".format(100 - best_top1err))