Example 1
def sampling_best(self):
    if not self.dynamic_sampling:
        return one_hot_to_index(np.array(self.sampling()))
    # Deterministic best sample: take the argmax category of theta on every edge.
    sample = []
    for i in range(self.p_model.d):
        sample.append(np.argmax(self.p_model.theta[i]))
    sample = np.array(sample)
    return index_to_one_hot(sample, self.p_model.Cmax)
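The snippets in this listing rely on the helpers index_to_one_hot and one_hot_to_index, which are not shown. Below is a minimal sketch of what they presumably do, given how they are called (a length-d index vector mapped to a (d, Cmax) one-hot matrix and back); only the names come from the code above, the bodies are assumptions:

import numpy as np

def index_to_one_hot(sample, Cmax):
    # Assumed helper: map a length-d vector of category indices
    # to a (d, Cmax) one-hot matrix.
    one_hot = np.zeros((len(sample), Cmax))
    one_hot[np.arange(len(sample)), np.asarray(sample, dtype=int)] = 1
    return one_hot

def one_hot_to_index(one_hot):
    # Assumed inverse: index of the active category in each row.
    return np.argmax(np.asarray(one_hot), axis=1)

# Round-trip check on a toy index vector.
idx = np.array([2, 0, 1])
assert np.array_equal(one_hot_to_index(index_to_one_hot(idx, Cmax=4)), idx)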
Example 2
    def sampling(self):
        if not self.dynamic_sampling:
            rand = np.random.rand(self.p_model.d,
                                  1)  # range of random number is [0, 1)
            cum_theta = self.p_model.theta.cumsum(axis=1)  # (d, Cmax)

            # x[i, j] becomes 1 if cum_theta[i, j] - theta[i, j] <= rand[i] < cum_theta[i, j]
            c = (cum_theta - self.p_model.theta <= rand) & (rand < cum_theta)
            return c
        # Dynamic sampling: draw sampling_number_per_edge candidate ops per edge.
        if self.sampling_number_per_edge == 1:
            return index_to_one_hot(self.sampling_index(), self.p_model.Cmax)
        else:
            sample = []
            sample_one_hot_like = np.zeros([self.p_model.d, self.p_model.Cmax])
            for i in range(self.p_model.d):
                # Sample proportionally to theta, renormalized over the
                # remaining candidates of this edge.
                if self.sample_with_prob:
                    prob = copy.deepcopy(
                        self.p_model.theta[i, self.sample_index[i]])
                    prob = prob / prob.sum()
                    sample.append(
                        np.random.choice(self.sample_index[i],
                                         size=self.sampling_number_per_edge,
                                         p=prob,
                                         replace=False))
                else:
                    sample.append(
                        np.random.choice(self.sample_index[i],
                                         size=self.sampling_number_per_edge,
                                         replace=False))
                # Remove the drawn candidates so they cannot be drawn again for this edge.
                if len(self.sample_index[i]) > 0:
                    for j in sample[i]:
                        self.sample_index[i].remove(int(j))
                # Mark the drawn candidates in the one-hot-like matrix.
                for j in range(self.sampling_number_per_edge):
                    sample_one_hot_like[i, int(sample[i][j])] = 1
        return sample_one_hot_like
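The static branch above draws one category per edge in a single vectorized step: theta is cumulated along the category axis, and bracketing a uniform random number per edge between consecutive cumulative values yields exactly one 1 per row. A standalone numpy sketch of the same trick, with made-up sizes and a uniform theta:

import numpy as np

d, Cmax = 3, 4
theta = np.full((d, Cmax), 1.0 / Cmax)   # each edge: uniform distribution over Cmax ops
rand = np.random.rand(d, 1)              # one uniform draw per edge, in [0, 1)
cum_theta = theta.cumsum(axis=1)         # (d, Cmax) cumulative probabilities

# Row i gets a 1 in the column whose interval [cum - theta, cum) contains rand[i].
c = (cum_theta - theta <= rand) & (rand < cum_theta)
assert (c.sum(axis=1) == 1).all()        # exactly one op sampled per edge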
Example 3
def random_sampling(search_space,
                    distribution_optimizer,
                    epoch=-1000,
                    _random=False):
    """Sample a one-hot architecture: at random (optionally weighted by
    theta) when _random is True, otherwise via distribution_optimizer."""
    if _random:
        num_ops, total_edges = search_space.num_ops, search_space.all_edges
        # Edge importance: decide which candidate edges will be dropped (forced to op 7).
        non_edge_idx = []
        if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH:

            assert cfg.SPACE.NAME == 'darts', "only support darts for now!"
            norm_indexes = search_space.norm_node_index
            non_edge_idx = []
            for node in norm_indexes:
                edge_non_prob = distribution_optimizer.p_model.theta[
                    np.array(node), 7]
                edge_non_prob = edge_non_prob / np.sum(edge_non_prob)
                if len(node) == 2:
                    pass
                else:
                    non_edge_sampling_num = len(node) - 2
                    non_edge_idx += list(
                        np.random.choice(node,
                                         non_edge_sampling_num,
                                         p=edge_non_prob,
                                         replace=False))
        if random.random() < cfg.SNG.BIGMODEL_SAMPLE_PROB:

            # Sample a high-complexity network: resample until the number of
            # non-parametric ops is at most cfg.SNG.BIGMODEL_NON_PARA.
            _num = 100
            while _num > cfg.SNG.BIGMODEL_NON_PARA:

                _error = False
                if cfg.SNG.PROB_SAMPLING:
                    sample = np.array([
                        np.random.choice(
                            num_ops,
                            1,
                            p=distribution_optimizer.p_model.theta[i, :])[0]
                        for i in range(total_edges)
                    ])
                else:
                    sample = np.array([
                        np.random.choice(num_ops, 1)[0]
                        for i in range(total_edges)
                    ])
                _num = 0
                for i in sample[0:search_space.num_edges]:
                    if i in non_edge_idx:
                        pass
                    elif i in search_space.non_op_idx:
                        if i == 7:
                            _error = True
                        _num = _num + 1
                if _error:
                    _num = 100
        else:
            if cfg.SNG.PROB_SAMPLING:
                sample = np.array([
                    np.random.choice(
                        num_ops,
                        1,
                        p=distribution_optimizer.p_model.theta[i, :])[0]
                    for i in range(total_edges)
                ])
            else:
                sample = np.array([
                    np.random.choice(num_ops, 1)[0] for i in range(total_edges)
                ])
        if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH:
            for i in non_edge_idx:
                sample[i] = 7
        sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax)
        # In the pruning method we have to call sampling() anyway.
        distribution_optimizer.sampling()
        return sample
    else:
        return distribution_optimizer.sampling()
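The edge-importance step in random_sampling normalizes the probability of op 7 (the dropped-edge op) over each node's candidate edges and samples which edges to drop, keeping two inputs per node. A standalone sketch of just that normalization and draw; the sizes, node indices, and theta values below are invented:

import numpy as np

theta = np.random.dirichlet(np.ones(8), size=14)     # made-up (edges, ops) distribution
node = [2, 3, 4, 5]                                   # made-up edge indices feeding one node

edge_non_prob = theta[np.array(node), 7]              # probability of op 7 on each candidate edge
edge_non_prob = edge_non_prob / edge_non_prob.sum()   # renormalize into a categorical distribution

# Drop all but two edges, favouring edges with a high op-7 probability.
non_edge_sampling_num = len(node) - 2
non_edge_idx = list(np.random.choice(node, non_edge_sampling_num,
                                     p=edge_non_prob, replace=False))
print(non_edge_idx)                                   # e.g. [5, 3]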
Example 4
def main():
    setup_env()
    # Load search space
    search_space = build_space()
    search_space.cuda()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()

    # weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)

    # build distribution_optimizer
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)

    lr_scheduler = lr_scheduler_builder(w_optim)
    # training loop
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    start_time = time.time()
    _over_all_epoch = 0
    for epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
        # warm up training
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=epoch)
            logger.info("The sample is: {}".format(one_hot_to_index(sample)))
            train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                  sample, loss_fun, warm_train_meter)
            top1 = test_epoch(val_, search_space, warm_val_meter,
                              _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                      sample, loss_fun, warm_train_meter)
                top1 = test_epoch(val_, search_space, warm_val_meter,
                                  _over_all_epoch, sample, writer)
                _over_all_epoch += 1
        # new version of warmup epoch
    logger.info("end warm up training")
    logger.info("start One shot searching")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    for epoch in range(cfg.OPTIM.MAX_EPOCH):
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        # Sample an architecture: randomly or from the distribution
        # optimizer, depending on cfg.SNG.RANDOM_SAMPLE.
        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("The sample is: {}".format(one_hot_to_index(sample)))

        # training
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch, sample,
              loss_fun, train_meter)

        # validation
        top1 = test_epoch(val_, search_space, val_meter, _over_all_epoch,
                          sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()

        # Evaluate the model
        next_epoch = epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache(
            )  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    end_time = time.time()
    lr = w_optim.param_groups[0]['lr']
    for epoch in range(cfg.OPTIM.FINAL_EPOCH):
        if cfg.SPACE.NAME == 'darts':
            genotype = search_space.genotype(
                distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch, sample,
              loss_fun, train_meter)
        test_epoch(val_, search_space, val_meter, _over_all_epoch, sample,
                   writer)
    logger.info("Overall training time (hr) is:{}".format(
        str((end_time - start_time) / 3600.)))

    # Whether to evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench201", "nasbench1shot1_1", "nasbench1shot1_2",
            "nasbench1shot1_3"
    ]:
        logger.info("starting test using nasbench:{}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
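In the warm-up branch of main() above (when cfg.SNG.WARMUP_RANDOM_SAMPLE is off), array_sample holds one random permutation of the op indices per edge, and iteration i trains column i, so every op is trained exactly once on every edge during one warm-up epoch. A toy illustration of that scheme with made-up sizes:

import random
import numpy as np

num_ops, total_edges = 4, 3                              # made-up sizes
array_sample = np.array([random.sample(range(num_ops), num_ops)
                         for _ in range(total_edges)])   # one permutation per edge

for i in range(num_ops):
    sample = array_sample[:, i]                          # one op index per edge
    print(sample)                                        # over all i, each op appears once per edge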
Example 5
def sampling(self):
    # return self.sampling_index()
    self.sample = self.sampling_index()
    return index_to_one_hot(self.sample, self.p_model.Cmax)
Example 6
def sampling(self):
    return index_to_one_hot(self.sampling_index(), self.p_model.Cmax)
Example 7
def train_model():
    """SNG search model training"""
    setup_env()
    # Load search space
    search_space = build_space()
    search_space.cuda()
    loss_fun = build_loss_fun().cuda()

    # Weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)

    # Build distribution_optimizer
    if cfg.SPACE.NAME in ['darts', 'nasbench301']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in ['proxyless', 'google', 'ofa']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        raise NotImplementedError

    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)

    lr_scheduler = lr_scheduler_builder(w_optim)
    all_timer = Timer()
    _over_all_epoch = 0

    # ===== Warm up training =====
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCHS):

        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        lr = lr_scheduler.get_last_lr()[0]
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=cur_epoch)
            logger.info("Sampling: {}".format(one_hot_to_index(sample)))
            train_epoch(train_, val_, search_space, w_optim, lr,
                        _over_all_epoch, sample, loss_fun, warm_train_meter)
            top1 = test_epoch_with_sample(val_, search_space, warm_val_meter,
                                          _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train_epoch(train_, val_, search_space, w_optim, lr,
                            _over_all_epoch, sample, loss_fun,
                            warm_train_meter)
                top1 = test_epoch_with_sample(val_, search_space,
                                              warm_val_meter, _over_all_epoch,
                                              sample, writer)
                _over_all_epoch += 1
    all_timer.toc()
    logger.info("end warm up training")

    # ===== Training procedure =====
    logger.info("start one-shot training")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):

        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']

        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=cur_epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("Sampling: {}".format(one_hot_to_index(sample)))
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        top1 = test_epoch_with_sample(val_, search_space, val_meter,
                                      _over_all_epoch, sample, writer)
        _over_all_epoch += 1

        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()

        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(cur_epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache(
            )  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    all_timer.toc()

    # ===== Final epoch =====
    lr = w_optim.param_groups[0]['lr']
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):

        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        if cfg.SPACE.NAME in ['darts', 'nasbench301']:
            genotype = search_space.genotype(
                distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        test_epoch_with_sample(val_, search_space, val_meter, _over_all_epoch,
                               sample, writer)
    logger.info("Overall training time : {} hours".format(
        str((all_timer.total_time) / 3600.)))

    # Evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3",
            "nasbench201", "nasbench301"
    ]:
        logger.info("starting test using nasbench:{}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)