示例#1
0
def train_model():
    """Run the end-to-end training routine.

    Sets up the environment, builds the model, loss function and optimizer,
    optionally restores state, and then trains from the starting epoch up to
    ``cfg.OPTIM.MAX_EPOCH`` with periodic checkpointing and evaluation.

    State restoration:
        * If ``cfg.TRAIN.AUTO_RESUME`` is set and a checkpoint exists, the
          last checkpoint is loaded into both model and optimizer, and
          training continues from the epoch after the one it recorded.
        * Otherwise, if ``cfg.TRAIN.WEIGHTS`` is set, it is loaded into the
          model only and training starts at epoch 0.
    """
    setup_env()

    # Model, loss function (moved to the GPU via ``.cuda()``) and optimizer.
    network = setup_model()
    criterion = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(network)

    # Work out the first epoch to run, restoring state when requested.
    first_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        resume_path = checkpoint.get_last_checkpoint()
        resumed_epoch = checkpoint.load_checkpoint(resume_path, network,
                                                   optimizer)
        logger.info("Loaded checkpoint from: {}".format(resume_path))
        first_epoch = resumed_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        # The optimizer is not passed here: only the model is initialised.
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, network)
        logger.info("Loaded initial weights from: {}".format(
            cfg.TRAIN.WEIGHTS))

    # Data loaders, and meters sized by the number of batches in each loader.
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))

    # Timing benchmark: only on a fresh run and only when enabled.
    if first_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(network, criterion, train_loader,
                                    test_loader)

    # Main loop. Epoch indices are 0-based; the log line is 1-based.
    logger.info("Start epoch: {}".format(first_epoch + 1))
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        epochs_done = epoch + 1

        train_epoch(train_loader, network, criterion, optimizer, train_meter,
                    epoch)

        # Optionally recompute BN statistics over the training loader.
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(network, train_loader)

        # Periodic checkpoint.
        if epochs_done % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            saved_path = checkpoint.save_checkpoint(network, optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_path))

        # Periodic evaluation, plus always after the very last epoch.
        if (epochs_done % cfg.TRAIN.EVAL_PERIOD == 0
                or epochs_done == cfg.OPTIM.MAX_EPOCH):
            logger.info("Start testing")
            test_epoch(test_loader, network, test_meter, epoch)

        # Release cached GPU memory between epochs.
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
示例#2
0
def darts_train_model():
    """Run the DARTS architecture search.

    Builds the search space, a ``DartsCNNController`` and its ``Architect``,
    creates separate optimizers for the network weights (SGD) and the
    architecture parameters (Adam), optionally restores a checkpoint, and
    then alternates training, checkpointing and evaluation until
    ``cfg.OPTIM.MAX_EPOCH``. The total wall time measured by ``Timer`` is
    logged at the end.
    """
    setup_env()

    # Search space -> controller (on GPU) -> architect.
    # TODO: fix the complexity function
    # search_space = setup_model()
    space = build_space()
    criterion = build_loss_fun().cuda()
    controller = DartsCNNController(space, criterion)
    controller.cuda()
    architect = Architect(controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)

    # Training / validation loaders for the search phase.
    train_loader, val_loader = construct_loader(cfg.SEARCH.DATASET,
                                                cfg.SEARCH.SPLIT,
                                                cfg.SEARCH.BATCH_SIZE)

    # One optimizer per parameter group: weights (SGD) and alphas (Adam).
    weight_optimizer = torch.optim.SGD(controller.weights(),
                                       cfg.OPTIM.BASE_LR,
                                       momentum=cfg.OPTIM.MOMENTUM,
                                       weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    alpha_optimizer = torch.optim.Adam(
        controller.alphas(),
        cfg.DARTS.ALPHA_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    scheduler = lr_scheduler_builder(weight_optimizer)

    # Meters sized by the number of batches in each loader.
    train_meter = meters.TrainMeter(len(train_loader))
    val_meter = meters.TestMeter(len(val_loader))

    # Restore the latest checkpoint (controller + both optimizers) when
    # auto-resume is on; otherwise optionally load initial weights into the
    # controller only.
    first_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        resume_path = checkpoint.get_last_checkpoint()
        resumed_epoch = darts_load_checkpoint(resume_path, controller,
                                              weight_optimizer,
                                              alpha_optimizer)
        logger.info("Loaded checkpoint from: {}".format(resume_path))
        first_epoch = resumed_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, controller)
        logger.info("Loaded initial weights from: {}".format(
            cfg.SEARCH.WEIGHTS))

    # Timing benchmark: only on a fresh run and only when enabled.
    if first_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(controller, criterion, train_loader,
                                    val_loader)

    stopwatch = Timer()
    # Epoch indices are 0-based; the log line is 1-based.
    logger.info("Start epoch: {}".format(first_epoch + 1))
    stopwatch.tic()
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        epochs_done = epoch + 1

        # The LR handed to train_epoch is read before the scheduler steps.
        current_lr = scheduler.get_last_lr()[0]
        train_epoch(train_loader, val_loader, controller, architect,
                    criterion, weight_optimizer, alpha_optimizer, current_lr,
                    train_meter, epoch)

        # Periodic checkpoint.
        if epochs_done % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            saved_path = darts_save_checkpoint(controller, weight_optimizer,
                                               alpha_optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_path))
        scheduler.step()

        # Periodic evaluation, plus always after the very last epoch.
        if (epochs_done % cfg.SEARCH.EVAL_PERIOD == 0
                or epochs_done == cfg.OPTIM.MAX_EPOCH):
            logger.info("Start testing")
            # NOTE(review): `writer` is not defined in this function;
            # presumably a module-level object -- confirm.
            test_epoch(val_loader, controller, val_meter, epoch, writer)
            banner = "###############Optimal genotype at epoch: {}############"
            logger.info(banner.format(epoch))
            logger.info(controller.genotype())
            logger.info(
                "########################################################")

            if cfg.SPACE.NAME == "nasbench301":
                logger.info("Evaluating with nasbench301")
                EvaluateNasbench(controller.alpha, controller.net, logger,
                                 "nasbench301")

            controller.print_alphas(logger)

        # Release cached GPU memory between epochs.
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
    stopwatch.toc()
    logger.info("Overall training time (hr) is:{}".format(
        str(stopwatch.total_time)))
示例#3
0
def pcdarts_train_model():
    """Run the PC-DARTS architecture search on CIFAR-10.

    Builds the search space, a ``PCDartsCNNController`` and its ``Architect``,
    splits the CIFAR-10 training set into search-train / search-validation
    subsets, creates separate optimizers for the network weights (SGD) and
    the architecture parameters (Adam), optionally restores a checkpoint, and
    then alternates training, checkpointing and evaluation until
    ``cfg.OPTIM.MAX_EPOCH``. The total wall time measured by ``Timer`` is
    logged at the end.
    """
    setup_env()

    # Search space -> controller (on GPU) -> architect.
    # TODO: fix the complexity function
    # search_space = setup_model()
    space = build_space()
    criterion = build_loss_fun().cuda()
    controller = PCDartsCNNController(space, criterion)
    controller.cuda()
    architect = Architect(controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)

    # CIFAR-10 training set; the validation transform is not used because
    # both subsets below are drawn from this same dataset object.
    train_transform, _ = data_transforms_cifar10(cutout_length=0)
    cifar_train = dset.CIFAR10(root=cfg.SEARCH.DATASET,
                               train=True,
                               download=True,
                               transform=train_transform)

    # The first `split` indices feed weight training, the rest feed the
    # architecture (validation) step.
    num_train = len(cifar_train)
    indices = list(range(num_train))
    split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler

    train_loader = torch.utils.data.DataLoader(
        cifar_train,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=subset_sampler(indices[:split]),
        pin_memory=True,
        num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        cifar_train,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=subset_sampler(indices[split:]),
        pin_memory=True,
        num_workers=2)

    # One optimizer per parameter group: weights (SGD) and alphas (Adam).
    weight_optimizer = torch.optim.SGD(controller.weights(),
                                       cfg.OPTIM.BASE_LR,
                                       momentum=cfg.OPTIM.MOMENTUM,
                                       weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    alpha_optimizer = torch.optim.Adam(
        controller.alphas(),
        cfg.DARTS.ALPHA_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    scheduler = lr_scheduler_builder(weight_optimizer)

    # Meters sized by the number of batches in each loader.
    train_meter = meters.TrainMeter(len(train_loader))
    val_meter = meters.TestMeter(len(val_loader))

    # Restore the latest checkpoint (controller + both optimizers) when
    # auto-resume is on; otherwise optionally load initial weights into the
    # controller only.
    first_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        resume_path = checkpoint.get_last_checkpoint()
        resumed_epoch = darts_load_checkpoint(resume_path, controller,
                                              weight_optimizer,
                                              alpha_optimizer)
        logger.info("Loaded checkpoint from: {}".format(resume_path))
        first_epoch = resumed_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, controller)
        logger.info("Loaded initial weights from: {}".format(
            cfg.SEARCH.WEIGHTS))

    # Timing benchmark: only on a fresh run and only when enabled.
    if first_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(controller, criterion, train_loader,
                                    val_loader)

    stopwatch = Timer()
    # Epoch indices are 0-based; the log line is 1-based.
    logger.info("Start epoch: {}".format(first_epoch + 1))
    stopwatch.tic()
    for epoch in range(first_epoch, cfg.OPTIM.MAX_EPOCH):
        epochs_done = epoch + 1

        # The LR handed to train_epoch is read before the scheduler steps.
        current_lr = scheduler.get_last_lr()[0]
        train_epoch(train_loader, val_loader, controller, architect,
                    criterion, weight_optimizer, alpha_optimizer, current_lr,
                    train_meter, epoch)

        # Periodic checkpoint.
        if epochs_done % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            saved_path = darts_save_checkpoint(controller, weight_optimizer,
                                               alpha_optimizer, epoch)
            logger.info("Wrote checkpoint to: {}".format(saved_path))
        scheduler.step()

        # Periodic evaluation, plus always after the very last epoch.
        if (epochs_done % cfg.SEARCH.EVAL_PERIOD == 0
                or epochs_done == cfg.OPTIM.MAX_EPOCH):
            logger.info("Start testing")
            # NOTE(review): `writer` is not defined in this function;
            # presumably a module-level object -- confirm.
            test_epoch(val_loader, controller, val_meter, epoch, writer)
            banner = "###############Optimal genotype at epoch: {}############"
            logger.info(banner.format(epoch))
            logger.info(controller.genotype())
            logger.info(
                "########################################################")
            controller.print_alphas(logger)

        # Release cached GPU memory between epochs.
        # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
    stopwatch.toc()
    logger.info("Overall training time (hr) is:{}".format(
        str(stopwatch.total_time)))
示例#4
0
def pdarts_train_model():
    """Run the multi-stage PDARTS architecture search.

    The search runs ``len(num_to_keep)`` stages. Each stage rebuilds the
    search space and a fresh ``PDartsCNNController`` from the current
    per-edge candidate operations (``cfg.SPACE.BASIC_OP``), trains it for up
    to ``cfg.OPTIM.MAX_EPOCH`` epochs, and then:

    * for every stage but the last, keeps only the top
      ``num_to_keep[stage]`` operations per edge as candidates for the
      next stage;
    * for the last stage, logs the final genotype and then progressively
      removes skip-connections (from at most 8 down to 0), logging the
      genotype at each level that required a removal.

    Within a stage, the first ``eps_no_archs[stage]`` epochs train with
    ``train_arch=False``; later epochs train with ``train_arch=True``. The
    value passed to ``update_p`` decays linearly during the first phase and
    exponentially afterwards.
    """
    # Per-stage schedules; position i applies to stage i.
    num_to_keep = [5, 3, 1]
    eps_no_archs = [10, 10, 10]
    drop_rate = [0.1, 0.4, 0.7]
    add_layers = [0, 6, 12]
    add_width = [0, 0, 0]
    scale_factor = 0.2

    # Every edge (the edge count is doubled, presumably for the two cell
    # types -- confirm) starts with the full primitive list as candidates.
    PRIMITIVES = cfg.SPACE.PRIMITIVES
    edgs_num = (cfg.SPACE.NODES + 3) * cfg.SPACE.NODES // 2
    basic_op = [PRIMITIVES] * (edgs_num * 2)

    setup_env()
    loss_fun = build_loss_fun().cuda()
    # Load dataset
    train_, val_ = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                    cfg.SEARCH.BATCH_SIZE)

    last_stage = len(num_to_keep) - 1
    for sp, keep in enumerate(num_to_keep):
        # Update info of supernet config for this stage.
        cfg.defrost()
        cfg.SEARCH.add_layers = add_layers[sp]
        cfg.SEARCH.add_width = add_width[sp]
        cfg.SEARCH.dropout_rate = float(drop_rate[sp])
        cfg.SPACE.BASIC_OP = basic_op
        # Loading search space
        search_space = build_space()
        # TODO: fix the complexity function
        # search_space = setup_model()
        # Init controller and architect
        pdarts_controller = PDartsCNNController(search_space, loss_fun)
        pdarts_controller.cuda()
        architect = Architect(pdarts_controller, cfg.OPTIM.MOMENTUM,
                              cfg.OPTIM.WEIGHT_DECAY)
        # weights optimizer
        w_optim = torch.optim.SGD(pdarts_controller.subnet_weights(),
                                  cfg.OPTIM.BASE_LR,
                                  momentum=cfg.OPTIM.MOMENTUM,
                                  weight_decay=cfg.OPTIM.WEIGHT_DECAY)
        # alphas optimizer
        a_optim = torch.optim.Adam(pdarts_controller.alphas(),
                                   cfg.DARTS.ALPHA_LR,
                                   betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
        lr_scheduler = lr_scheduler_builder(w_optim)
        # Init meters
        train_meter = meters.TrainMeter(len(train_))
        val_meter = meters.TestMeter(len(val_))

        # Load checkpoint or initial weights.
        # NOTE(review): this runs at the start of every stage, so with
        # AUTO_RESUME on, a checkpoint written by an earlier stage of this
        # same run would be picked up here -- confirm this is intended.
        start_epoch = 0
        if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
            last_checkpoint = checkpoint.get_last_checkpoint()
            checkpoint_epoch = darts_load_checkpoint(last_checkpoint,
                                                     pdarts_controller,
                                                     w_optim, a_optim)
            logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
            start_epoch = checkpoint_epoch + 1
        elif cfg.SEARCH.WEIGHTS:
            darts_load_checkpoint(cfg.SEARCH.WEIGHTS, pdarts_controller)
            logger.info("Loaded initial weights from: {}".format(
                cfg.SEARCH.WEIGHTS))
        # Compute model and loader timings
        if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
            benchmark.compute_time_full(pdarts_controller, loss_fun, train_,
                                        val_)
        # TODO: Setup timer
        # train_timer = Timer()
        # Perform the training loop
        logger.info("Start epoch: {}".format(start_epoch + 1))
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            print('cur_epoch', cur_epoch)
            lr = lr_scheduler.get_last_lr()[0]

            # Architecture parameters are only trained once the first
            # `eps_no_archs[sp]` epochs have passed; the value given to
            # update_p follows a different decay in each phase.
            train_arch = cur_epoch >= eps_no_archs[sp]
            if train_arch:
                drop_prob = float(drop_rate[sp]) * np.exp(
                    -(cur_epoch - eps_no_archs[sp]) * scale_factor)
            else:
                drop_prob = (float(drop_rate[sp]) *
                             (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) /
                             cfg.OPTIM.MAX_EPOCH)
            pdarts_controller.update_p(drop_prob)
            train_epoch(train_,
                        val_,
                        pdarts_controller,
                        architect,
                        loss_fun,
                        w_optim,
                        a_optim,
                        lr,
                        train_meter,
                        cur_epoch,
                        train_arch=train_arch)

            # Save a checkpoint
            if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
                checkpoint_file = darts_save_checkpoint(
                    pdarts_controller, w_optim, a_optim, cur_epoch)
                logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
            lr_scheduler.step()

            # Evaluate the model: unlike a periodic schedule, evaluation
            # here only happens once next_epoch reaches MAX_EPOCH - 5.
            next_epoch = cur_epoch + 1
            if next_epoch >= cfg.OPTIM.MAX_EPOCH - 5:
                logger.info("Start testing")
                # NOTE(review): `writer` is not defined in this function;
                # presumably a module-level object -- confirm.
                test_epoch(val_, pdarts_controller, val_meter, cur_epoch,
                           writer)
                logger.info(
                    "###############Optimal genotype at epoch: {}############".
                    format(cur_epoch))
                logger.info(pdarts_controller.genotype())
                logger.info(
                    "########################################################")
                pdarts_controller.print_alphas(logger)

            # Release cached GPU memory between epochs.
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            gc.collect()
            print("Top-{} primitive: {}".format(
                keep, pdarts_controller.get_topk_op(keep)))

        if sp != last_stage:
            # Shrink the candidate set for the next stage.
            basic_op = pdarts_controller.get_topk_op(keep)
            continue

        # Final stage only: report the result and prune skip-connections.
        # This used to be followed by an unconditional copy of the same
        # code, which ran after every stage (and twice on the last one).
        # The banner also carried a "{}" placeholder that was never filled.
        logger.info("###############final Optimal genotype############")
        logger.info(pdarts_controller.genotype(final=True))
        logger.info("########################################################")
        pdarts_controller.print_alphas(logger)

        logger.info('Restricting skipconnect...')
        for max_sk in range(8, -1, -1):
            num_sk = pdarts_controller.get_skip_number()
            if num_sk <= max_sk:
                # Already within this limit; nothing to delete or log.
                continue
            while num_sk > max_sk:
                pdarts_controller.delete_skip()
                num_sk = pdarts_controller.get_skip_number()

            logger.info('Number of skip-connect: %d', max_sk)
            logger.info(pdarts_controller.genotype(final=True))