# Example 1
def main():
    """SNG-based one-shot NAS: warm-up, distribution search, final training.

    Stages:
      1. Warm-up: train the supernet weights on random (or enumerated)
         architecture samples.
      2. Search: sample architectures, train/validate, and update the
         architecture distribution via ``distribution_optimizer``.
      3. Final: briefly train the best architecture found.
      4. Optionally evaluate the result through the nasbench APIs.
    """
    setup_env()
    # loading search space
    search_space = build_space()
    search_space.cuda()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()

    # weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)

    # build distribution_optimizer: one categorical variable per decision
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        # nasbench1shot1 spaces describe themselves via ConfigSpace; collect
        # the cardinality of every categorical hyperparameter.
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            # isinstance is the idiomatic (subclass-safe) type check
            if isinstance(
                    h, ConfigSpace.hyperparameters.CategoricalHyperparameter):
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)

    lr_scheduler = lr_scheduler_builder(w_optim)
    # training loop
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    start_time = time.time()
    _over_all_epoch = 0
    for epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
        # warm up training
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            # one random architecture per warm-up epoch
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=epoch)
            logger.info("The sample is: {}".format(one_hot_to_index(sample)))
            train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                  sample, loss_fun, warm_train_meter)
            test_epoch(val_, search_space, warm_val_meter, _over_all_epoch,
                       sample, writer)
            _over_all_epoch += 1
        else:
            # enumerate num_ops disjoint samples (a random op permutation per
            # edge) so every op on every edge gets trained once per epoch
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for _ in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                      sample, loss_fun, warm_train_meter)
                test_epoch(val_, search_space, warm_val_meter,
                           _over_all_epoch, sample, writer)
                _over_all_epoch += 1
    logger.info("end warm up training")
    logger.info("start One shot searching")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    for epoch in range(cfg.OPTIM.MAX_EPOCH):
        # some distribution optimizers signal their own convergence
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        # sample by the distribution optimizer (optionally uniform-random,
        # controlled by cfg.SNG.RANDOM_SAMPLE)
        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("The sample is: {}".format(one_hot_to_index(sample)))

        # training
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch, sample,
              loss_fun, train_meter)

        # validation; top-1 accuracy is the reward for the distribution update
        top1 = test_epoch(val_, search_space, val_meter, _over_all_epoch,
                          sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()

        # Evaluate the model periodically and at the very end
        next_epoch = epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache(
            )  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    # NOTE(review): end_time is taken before the final-training loop, so the
    # figure reported below covers warm-up + search only — presumably
    # intentional ("search cost"); confirm.
    end_time = time.time()
    lr = w_optim.param_groups[0]['lr']
    for epoch in range(cfg.OPTIM.FINAL_EPOCH):
        # train the best architecture found by the search
        if cfg.SPACE.NAME == 'darts':
            genotype = search_space.genotype(
                distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch, sample,
              loss_fun, train_meter)
        test_epoch(val_, search_space, val_meter, _over_all_epoch, sample,
                   writer)
    logger.info("Overall training time (hr) is:{}".format(
        str((end_time - start_time) / 3600.)))

    # whether to evaluate through nasbench ;
    if cfg.SPACE.NAME in [
            "nasbench201", "nasbench1shot1_1", "nasbench1shot1_2",
            "nasbench1shot1_3"
    ]:
        logger.info("starting test using nasbench:{}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
# Example 2
def main():
    """Run a DARTS search, alternating weight and alpha optimization."""
    setup_env()
    # loading search space
    net = build_space()
    # init controller and architect
    criterion = nn.CrossEntropyLoss().cuda()
    controller = DartsCNNController(net, criterion)
    controller.cuda()
    arch = Architect(controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    train_, val_ = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                     cfg.SEARCH.BATCH_SIZE)
    # weights optimizer
    weight_optim = torch.optim.SGD(controller.weights(),
                                   cfg.OPTIM.BASE_LR,
                                   momentum=cfg.OPTIM.MOMENTUM,
                                   weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(controller.alphas(),
                                   cfg.DARTS.ALPHA_LR,
                                   betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        weight_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    start_epoch = 0
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        cur_lr = scheduler.get_last_lr()[0]
        train_epoch(train_, val_, controller, arch, criterion, weight_optim,
                    alpha_optim, cur_lr, train_meter, epoch)
        # Save a checkpoint
        if (epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            ckpt = checkpoint.save_checkpoint(controller, weight_optim, epoch)
            logger.info("Wrote checkpoint to: {}".format(ckpt))
        scheduler.step()
        # Evaluate the model periodically and at the very end
        if (epoch + 1) % cfg.SEARCH.EVAL_PERIOD == 0 or (
                epoch + 1) == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, controller, val_meter, epoch,
                       tensorboard_writer=writer)
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(epoch))
            logger.info(controller.genotype())
            logger.info(
                "########################################################")
            controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
# Example 3
def main():
    """One-shot NAS with an SNG distribution optimizer (compact variant)."""
    setup_env()
    # loading search space
    net = build_space()
    net.cuda()
    # init controller and architect
    criterion = nn.CrossEntropyLoss().cuda()

    # weights optimizer
    weight_optim = torch.optim.SGD(net.parameters(),
                                   cfg.OPTIM.BASE_LR,
                                   momentum=cfg.OPTIM.MOMENTUM,
                                   weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    train_, val_ = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                     cfg.SEARCH.BATCH_SIZE)

    sng_optim = sng_builder([net.num_ops] * net.all_edges)
    scheduler = lr_scheduler_builder(weight_optim)
    num_ops, total_edges = net.num_ops, net.all_edges
    # training loop
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    for epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # lr_scheduler.step()
        cur_lr = scheduler.get_last_lr()[0]
        # warm up training: a random permutation of ops per edge, trained
        # column by column
        shuffled = np.array([
            random.sample(list(range(num_ops)), num_ops)
            for _ in range(total_edges)
        ])
        for col in range(num_ops):
            pick = np.transpose(shuffled[:, col])
            train(train_, val_, net, weight_optim, cur_lr, epoch, pick,
                  criterion, warm_train_meter)
    logger.info("end warm up training")
    logger.info("start One shot searching")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    for epoch in range(cfg.OPTIM.MAX_EPOCH):
        # some distribution optimizers signal their own convergence
        if getattr(sng_optim, 'training_finish', False):
            break
        cur_lr = weight_optim.param_groups[0]['lr']
        pick = sng_optim.sampling()

        # training
        train(train_, val_, net, weight_optim, cur_lr, epoch, pick, criterion,
              train_meter)

        # validation; top-1 accuracy feeds the distribution update
        acc = test_epoch(val_, net, val_meter, epoch, pick, writer)
        scheduler.step()
        sng_optim.record_information(pick, acc)
        sng_optim.update()

        # Evaluate the model periodically and at the very end
        if (epoch + 1) % cfg.SEARCH.EVAL_PERIOD == 0 or (
                epoch + 1) == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(epoch))
            logger.info(net.genotype(sng_optim.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for theta_row in sng_optim.p_model.theta:
                logger.info(theta_row)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
# Example 4
def main():
    """Progressive DARTS (P-DARTS) search over several shrinking stages.

    Each stage rebuilds the supernet with a progressively pruned candidate
    op set, trains it (weights only during the warm-up epochs, then weights
    plus architecture), keeps only the top-k ops per edge for the next
    stage, and finally reports the stage's best genotype while restricting
    the number of skip-connects.
    """
    setup_env()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)

    # per-stage schedule: ops kept per edge, arch-free warm-up epochs,
    # skip-connect dropout rate, and extra supernet layers/width
    num_to_keep = [5, 3, 1]
    eps_no_archs = [10, 10, 10]
    drop_rate = [0.1, 0.4, 0.7]
    add_layers = [0, 6, 12]
    add_width = [0, 0, 0]
    PRIMITIVES = cfg.SPACE.PRIMITIVES
    # edges per cell; * 2 below covers normal + reduction cells
    edge_num = (cfg.SPACE.NODES + 3) * cfg.SPACE.NODES // 2
    # initial candidate set: every primitive on every edge
    basic_op = [PRIMITIVES for _ in range(edge_num * 2)]
    for sp in range(len(num_to_keep)):
        # update the info of the supernet config
        cfg.defrost()
        cfg.SEARCH.add_layers = add_layers[sp]
        cfg.SEARCH.add_width = add_width[sp]
        cfg.SEARCH.dropout_rate = float(drop_rate[sp])
        cfg.SPACE.BASIC_OP = basic_op

        search_space = build_space()
        controller = PdartsCNNController(search_space, loss_fun)
        controller.cuda()
        architect = Architect(controller, cfg.OPTIM.MOMENTUM,
                              cfg.OPTIM.WEIGHT_DECAY)
        # weights optimizer
        w_optim = torch.optim.SGD(controller.subnet_weights(),
                                  cfg.OPTIM.BASE_LR,
                                  momentum=cfg.OPTIM.MOMENTUM,
                                  weight_decay=cfg.OPTIM.WEIGHT_DECAY)
        # alphas optimizer
        alpha_optim = torch.optim.Adam(
            controller.alphas(),
            cfg.DARTS.ALPHA_LR,
            betas=(0.5, 0.999),
            weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
        train_meter = meters.TrainMeter(len(train_))
        val_meter = meters.TestMeter(len(val_))
        start_epoch = 0
        # Perform the training loop
        logger.info("Start epoch: {}".format(start_epoch + 1))
        scale_factor = 0.2
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            print('cur_epoch', cur_epoch)
            lr = lr_scheduler.get_last_lr()[0]
            if cur_epoch < eps_no_archs[sp]:
                # warm-up phase: linearly decaying dropout, weights only
                controller.update_p(
                    float(drop_rate[sp]) *
                    (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) /
                    cfg.OPTIM.MAX_EPOCH)
                train_epoch(train_,
                            val_,
                            controller,
                            architect,
                            loss_fun,
                            w_optim,
                            alpha_optim,
                            lr,
                            train_meter,
                            cur_epoch,
                            train_arch=False)
            else:
                # architecture phase: exponentially decaying dropout
                controller.update_p(
                    float(drop_rate[sp]) *
                    np.exp(-(cur_epoch - eps_no_archs[sp]) * scale_factor))
                train_epoch(train_,
                            val_,
                            controller,
                            architect,
                            loss_fun,
                            w_optim,
                            alpha_optim,
                            lr,
                            train_meter,
                            cur_epoch,
                            train_arch=True)
            # Save a checkpoint
            if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
                checkpoint_file = checkpoint.save_checkpoint(
                    controller, w_optim, cur_epoch)
                logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
            lr_scheduler.step()
            # Evaluate the model during the stage's last 5 epochs
            next_epoch = cur_epoch + 1
            if next_epoch >= cfg.OPTIM.MAX_EPOCH - 5:
                logger.info("Start testing")
                test_epoch(val_,
                           controller,
                           val_meter,
                           cur_epoch,
                           tensorboard_writer=writer)
                logger.info(
                    "###############Optimal genotype at epoch: {}############".
                    format(cur_epoch))
                logger.info(controller.genotype())
                logger.info(
                    "########################################################")
                controller.print_alphas(logger)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache(
                )  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            gc.collect()
            print("now top k primitive", num_to_keep[sp],
                  controller.get_topk_op(num_to_keep[sp]))

        if sp < len(num_to_keep) - 1:
            # prune: keep only the top-k ops per edge for the next stage
            basic_op = controller.get_topk_op(num_to_keep[sp])
        # Stage report: final genotype plus skip-connect restriction.
        # (Previously this whole section was duplicated inside an
        # `if sp == len(num_to_keep) - 1:` branch, so the last stage ran and
        # logged it twice; it now runs exactly once per stage.)
        logger.info("###############final Optimal genotype: {}############")
        logger.info(controller.genotype(final=True))
        logger.info("########################################################")
        controller.print_alphas(logger)

        logger.info('Restricting skipconnect...')
        # lower the skip-connect ceiling step by step from 8 down to 0
        for max_sk in range(8, -1, -1):
            num_sk = controller.get_skip_number()
            if num_sk <= max_sk:
                continue
            while num_sk > max_sk:
                controller.delete_skip()
                num_sk = controller.get_skip_number()

            logger.info('Number of skip-connect: %d', max_sk)
            logger.info(controller.genotype(final=True))