Example 1
    def __init__(self,
                 dataset,
                 weights_manager,
                 objective,
                 rollout_type="compare",
                 guide_evaluator_type="mepa",
                 guide_evaluator_cfg=None,
                 arch_network_type="pointwise_comparator",
                 arch_network_cfg=None,
                 guide_batch_size=16,
                 schedule_cfg=None):
        super(ArchNetworkEvaluator,
              self).__init__(dataset, weights_manager, objective, rollout_type,
                             schedule_cfg)

        # construct the evaluator that will be used to guide the learning of the predictor
        ge_cls = BaseEvaluator.get_class_(guide_evaluator_type)
        self.guide_evaluator = ge_cls(dataset,
                                      weights_manager,
                                      objective,
                                      rollout_type=rollout_type,
                                      **(guide_evaluator_cfg or {}))

        # construct the architecture network
        an_cls = ArchNetwork.get_class_(arch_network_type)
        self.arch_network = an_cls(
            search_space=self.weights_manager.search_space,
            **(arch_network_cfg or {}))

        # configurations
        self.guide_batch_size = guide_batch_size
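
`BaseEvaluator.get_class_` and `ArchNetwork.get_class_` above resolve a configuration string to a registered subclass. Below is a minimal sketch of that name-to-class registry pattern; `Registry` and `PointwiseComparator` are illustrative placeholders, not the real aw_nas classes.

class Registry:
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(subclass):
            cls._registry[name] = subclass
            return subclass
        return decorator

    @classmethod
    def get_class_(cls, name):
        return cls._registry[name]

@Registry.register("pointwise_comparator")
class PointwiseComparator:
    def __init__(self, search_space, **cfg):
        self.search_space = search_space
        self.cfg = cfg

# resolve the class from its configured name, then instantiate it,
# mirroring `an_cls = ArchNetwork.get_class_(arch_network_type)` above
an_cls = Registry.get_class_("pointwise_comparator")
arch_network = an_cls(search_space=None)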
Example 2
    def __init__(self,
                 search_space,
                 rollout_type="compare",
                 arch_network_type="pointwise_comparator",
                 arch_network_cfg=None,
                 schedule_cfg=None):
        super(BatchUpdateArchNetworkEvaluator,
              self).__init__(dataset=None,
                             weights_manager=None,
                             objective=None,
                             rollout_type=rollout_type,
                             schedule_cfg=schedule_cfg)

        # construct the architecture network
        an_cls = ArchNetwork.get_class_(arch_network_type)
        self.arch_network = an_cls(
            # `weights_manager` is None for this evaluator, so pass the
            # `search_space` constructor argument directly
            search_space=search_space,
            **(arch_network_cfg or {}))
Example 3
def main(argv):
    parser = argparse.ArgumentParser(prog="train_cellss_pkl.py")
    parser.add_argument("cfg_file")
    parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
    parser.add_argument("--num-workers", default=4, type=int)
    parser.add_argument("--report-freq", default=200, type=int)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--train-dir",
                        default=None,
                        help="Save train log/results into TRAIN_DIR")
    parser.add_argument("--save-every", default=None, type=int)
    parser.add_argument("--test-only", default=False, action="store_true")
    parser.add_argument("--test-funcs",
                        default=None,
                        help="comma-separated list of test funcs")
    parser.add_argument("--load",
                        default=None,
                        help="Load comparator from disk.")
    parser.add_argument("--sample", default=None, type=int)
    parser.add_argument("--sample-batchify-inner-sample-n",
                        default=None,
                        type=int)
    parser.add_argument("--sample-to-file", default=None, type=str)
    parser.add_argument("--sample-from-file", default=None, type=str)
    parser.add_argument("--sample-conflict-file",
                        default=None,
                        type=str,
                        action="append")
    parser.add_argument("--sample-ratio", default=10, type=float)
    parser.add_argument("--sample-output-dir", default="./sample_output/")
    # parser.add_argument("--data-fname", default="cellss_data.pkl")
    # parser.add_argument("--data-fname", default="cellss_data_round1_999.pkl")
    parser.add_argument("--data-fname", default="enas_data_round1_980.pkl")
    parser.add_argument("--addi-train",
                        default=[],
                        action="append",
                        help="additional train data")
    parser.add_argument("--addi-train-only",
                        action="store_true",
                        default=False)
    parser.add_argument("--addi-valid",
                        default=[],
                        action="append",
                        help="additional valid data")
    parser.add_argument("--addi-valid-only",
                        action="store_true",
                        default=False)
    parser.add_argument("--valid-true-split", default=None)
    parser.add_argument("--valid-score-split", default=None)
    parser.add_argument("--enas-ss", default=True, action="store_true")
    args = parser.parse_args(argv)

    setproctitle.setproctitle("python train_cellss_pkl.py config: {}; train_dir: {}; cwd: {}"\
                              .format(args.cfg_file, args.train_dir, os.getcwd()))

    # log
    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt="%m/%d %I:%M:%S %p")

    if not args.test_only:
        assert args.train_dir is not None, "Must specify `--train-dir` when training"
        # if training, setting up log file, backup config file
        if not os.path.exists(args.train_dir):
            os.makedirs(args.train_dir)
        log_file = os.path.join(args.train_dir, "train.log")
        logging.getLogger().addHandler(logging.FileHandler(log_file))

        # copy config file
        backup_cfg_file = os.path.join(args.train_dir, "config.yaml")
        shutil.copyfile(args.cfg_file, backup_cfg_file)
    else:
        backup_cfg_file = args.cfg_file

    # cuda
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        logging.info("GPU device = %d" % args.gpu)
    else:
        logging.info("no GPU available, use CPU!!")

    if args.seed is not None:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Load pkl cache from cellss_data.pkl")
    data_fname = args.data_fname
    with open(data_fname, "rb") as rf:
        data = pickle.load(rf)
    with open(backup_cfg_file, "r") as cfg_f:
        cfg = yaml.load(cfg_f, Loader=yaml.FullLoader)

    logging.info("Config: %s", cfg)

    arch_network_type = cfg.get("arch_network_type", "pointwise_comparator")
    model_cls = ArchNetwork.get_class_(arch_network_type)
    # search space
    if args.enas_ss:
        ss_cfg_str = """
search_space_type: cnn
search_space_cfg:
  cell_layout: null
  num_cell_groups: 2
  num_init_nodes: 2
  num_layers: 8
  num_node_inputs: 2
  num_steps: 4
  reduce_cell_groups:
  - 1
  shared_primitives:
  - skip_connect
  - sep_conv_3x3
  - sep_conv_5x5
  - avg_pool_3x3
  - max_pool_3x3
        """
    else:
        ss_cfg_str = """
search_space_cfg:
  cell_layout: null
  num_cell_groups: 2
  num_init_nodes: 2
  num_layers: 8
  num_node_inputs: 2
  num_steps: 4
  reduce_cell_groups:
  - 1
  shared_primitives:
  - none
  - max_pool_3x3
  - avg_pool_3x3
  - skip_connect
  - sep_conv_3x3
  - sep_conv_5x5
  - dil_conv_3x3
  - dil_conv_5x5
search_space_type: cnn
        """

    ss_cfg = yaml.load(StringIO(ss_cfg_str), Loader=yaml.FullLoader)
    search_space = get_search_space(ss_cfg["search_space_type"],
                                    **ss_cfg["search_space_cfg"])
    model = model_cls(search_space, **cfg.pop("arch_network_cfg"))
    if args.load is not None:
        logging.info("Load %s from %s", arch_network_type, args.load)
        model.load(args.load)
    model.to(device)

    args.__dict__.update(cfg)
    logging.info("Combined args: %s", args)

    # init data loaders
    if hasattr(args, "train_size"):
        train_valid_split = args.train_size
    else:
        train_valid_split = int(
            getattr(args, "train_valid_split", 0.6) * len(data))
    train_data = data[:train_valid_split]
    valid_data = data[train_valid_split:]
    if hasattr(args, "train_ratio") and args.train_ratio is not None:
        _num = len(train_data)
        train_data = train_data[:int(_num * args.train_ratio)]
        logging.info("Train dataset ratio: %.3f", args.train_ratio)
    if args.addi_train:
        if args.addi_train_only:
            train_data = []
        for addi_train_fname in args.addi_train:
            with open(addi_train_fname, "rb") as rf:
                addi_train_data = pickle.load(rf)
            train_data += addi_train_data
    if args.addi_valid:
        if args.addi_valid_only:
            valid_data = []
        for addi_fname in args.addi_valid:
            with open(addi_fname, "rb") as rf:
                addi_valid_data = pickle.load(rf)
            valid_data += addi_valid_data
    num_train_archs = len(train_data)
    logging.info("Number of architectures: train: %d; valid: %d",
                 num_train_archs, len(valid_data))
    train_data = CellSSDataset(train_data,
                               minus=cfg.get("dataset_minus", None),
                               div=cfg.get("dataset_div", None))
    valid_data = CellSSDataset(valid_data,
                               minus=cfg.get("dataset_minus", None),
                               div=cfg.get("dataset_div", None))
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=lambda items: [np.array(x) for x in zip(*items)])
    val_loader = DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=lambda items: [np.array(x) for x in zip(*items)])

    if args.test_funcs is not None:
        test_func_names = args.test_funcs.split(",")
        test_funcs = [globals()[func_name] for func_name in test_func_names]
    else:
        test_funcs = []

    # init test
    if arch_network_type != "pairwise_comparator" or args.test_only:
        corr, func_res = valid(val_loader, model, args, funcs=test_funcs)

        if args.sample is not None:
            if args.sample_from_file:
                logging.info("Read genotypes from: {}".format(
                    args.sample_from_file))
                with open(args.sample_from_file, "r") as rf:
                    from_genotypes = yaml.load(rf, Loader=yaml.FullLoader)
                assert len(from_genotypes) == args.sample * int(
                    args.sample_ratio)
            else:
                from_genotypes = None
            if args.sample_conflict_file:
                conflict_archs = []
                for scf in args.sample_conflict_file:
                    with open(scf, "rb") as rf:
                        conflict_archs += pickle.load(rf)
            else:
                conflict_archs = None
            if args.sample_batchify_inner_sample_n is not None:
                # multi-stage is not supported here yet
                genotypes = sample_batchify(search_space,
                                            model,
                                            args.sample_ratio,
                                            args.sample,
                                            args,
                                            conflict_archs=conflict_archs)
            else:
                genotypes = sample(search_space,
                                   model,
                                   args.sample * int(args.sample_ratio),
                                   args.sample,
                                   args,
                                   from_genotypes=from_genotypes,
                                   conflict_archs=conflict_archs)
            if args.sample_to_file:
                with open(args.sample_to_file, "w") as wf:
                    yaml.dump([str(geno) for geno in genotypes], wf)
            else:
                with open("./final_template.yaml", "r") as rf:
                    logging.info("Load final template config file!")
                    template_cfg = yaml.load(rf, Loader=yaml.FullLoader)
                for i, genotype in enumerate(genotypes):
                    sample_cfg = copy.deepcopy(template_cfg)
                    sample_cfg["final_model_cfg"]["genotypes"] = str(genotype)
                    if not os.path.exists(args.sample_output_dir):
                        os.makedirs(args.sample_output_dir)
                    with open(
                            os.path.join(args.sample_output_dir,
                                         "{}.yaml".format(i)), "w") as wf:
                        yaml.dump(sample_cfg, wf)

        if args.test_funcs is None:
            logging.info("INIT: kendall tau {:.4f}".format(corr))
        else:
            logging.info("INIT: kendall tau {:.4f};\n\t{}".format(
                corr, "\n\t".join([
                    "{}: {}".format(name, _get_float_format(res, "{:.4f}"))
                    for name, res in zip(test_func_names, func_res)
                ])))

    if args.test_only:
        return

    _multi_stage = getattr(args, "multi_stage", False)
    _multi_stage_pair_pool = getattr(args, "multi_stage_pair_pool", False)
    if _multi_stage:
        all_perfs = np.array([item[1] for item in train_data.data])
        all_inds = np.arange(all_perfs.shape[0])

        stage_epochs = getattr(args, "stage_epochs", [0, 1, 2, 3])
        num_stages = len(stage_epochs)

        default_stage_nums = [all_perfs.shape[0] // num_stages] * (num_stages - 1) + \
            [all_perfs.shape[0] - all_perfs.shape[0] // num_stages * (num_stages - 1)]
        stage_nums = getattr(args, "stage_nums", default_stage_nums)

        assert np.sum(stage_nums) == all_perfs.shape[0]
        logging.info("Stage nums: {}".format(stage_nums))

        stage_inds_lst = []
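        # Greedy stage assignment: at stage i, the stage_nums[i] remaining
        # archs with the worst perf at epoch stage_epochs[i] are stopped at
        # this stage; the rest survive to be ranked at the next stage.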
        for i_stage in range(num_stages):
            max_stage_ = np.max(all_perfs[all_inds, stage_epochs[i_stage]])
            min_stage_ = np.min(all_perfs[all_inds, stage_epochs[i_stage]])
            logging.info(
                "Stage {}, epoch {}: min {:.2f} %; max {:.2f}% (range {:.2f} %)"
                .format(i_stage, stage_epochs[i_stage], min_stage_ * 100,
                        max_stage_ * 100, (max_stage_ - min_stage_) * 100))
            sorted_inds = np.argsort(all_perfs[all_inds,
                                               stage_epochs[i_stage]])
            stage_inds, all_inds = (all_inds[sorted_inds[:stage_nums[i_stage]]],
                                    all_inds[sorted_inds[stage_nums[i_stage]:]])
            stage_inds_lst.append(stage_inds)
        train_stages = [[train_data.data[ind] for ind in _stage_inds]
                        for _stage_inds in stage_inds_lst]
        avg_score_stages = []
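        # For each non-final stage i, record the pair: (sum of perfs at epoch
        # stage_epochs[i] over archs stopped at stage i, sum of perfs at the
        # same epoch over archs surviving past stage i); these pairs are
        # consumed by the multi-stage training routines below.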
        for i_stage in range(num_stages - 1):
            avg_score_stages.append(
                (all_perfs[stage_inds_lst[i_stage],
                           stage_epochs[i_stage]].sum(),
                 np.sum([
                     all_perfs[stage_inds_lst[j_stage],
                               stage_epochs[i_stage]].sum()
                     for j_stage in range(i_stage + 1, num_stages)
                 ])))
        if _multi_stage_pair_pool:
            all_stages, pairs_list = make_pair_pool(train_stages, args,
                                                    stage_epochs)

        total_eval_time = all_perfs.shape[0] * all_perfs.shape[1]
        multi_stage_eval_time = sum([
            (stage_epochs[i_stage] + 1) * len(_stage_inds)
            for i_stage, _stage_inds in enumerate(stage_inds_lst)
        ])
        logging.info("Percentage of evaluation time: {:.2f} %".format(
            float(multi_stage_eval_time) / total_eval_time * 100))

    for i_epoch in range(1, args.epochs + 1):
        model.on_epoch_start(i_epoch)
        if _multi_stage:
            if _multi_stage_pair_pool:
                avg_loss = train_multi_stage_pair_pool(all_stages, pairs_list,
                                                       model, i_epoch, args)
            else:
                if getattr(args, "use_listwise", False):
                    avg_loss = train_multi_stage_listwise(
                        train_stages, model, i_epoch, args, avg_score_stages,
                        stage_epochs)
                else:
                    avg_loss = train_multi_stage(train_stages, model, i_epoch,
                                                 args, avg_score_stages,
                                                 stage_epochs)
        else:
            avg_loss = train(train_loader, model, i_epoch, args)
        logging.info("Train: Epoch {:3d}: train loss {:.4f}".format(
            i_epoch, avg_loss))
        train_corr, train_func_res = valid(train_loader,
                                           model,
                                           args,
                                           funcs=test_funcs)
        if args.test_funcs is not None:
            for name, res in zip(test_func_names, train_func_res):
                logging.info("Train: Epoch {:3d}: {}: {}".format(
                    i_epoch, name, _get_float_format(res, "{:.4f}")))
        logging.info("Train: Epoch {:3d}: train kd {:.4f}".format(
            i_epoch, train_corr))
        corr, func_res = valid(val_loader, model, args, funcs=test_funcs)
        if args.test_funcs is not None:
            for name, res in zip(test_func_names, func_res):
                logging.info("Valid: Epoch {:3d}: {}: {}".format(
                    i_epoch, name, _get_float_format(res, "{:.4f}")))
        logging.info("Valid: Epoch {:3d}: kendall tau {:.4f}".format(
            i_epoch, corr))
        if args.save_every is not None and i_epoch % args.save_every == 0:
            save_path = os.path.join(args.train_dir, "{}.ckpt".format(i_epoch))
            model.save(save_path)
            logging.info("Epoch {:3d}: Save checkpoint to {}".format(
                i_epoch, save_path))
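
The collate_fn lambdas above transpose a batch of per-item tuples into one np.array per field. A self-contained sketch of that behavior on toy data follows; the (arch_encoding, score) item layout is an assumption, not necessarily the real `CellSSDataset` item format.

import numpy as np

items = [([0, 1, 2], 0.91), ([3, 4, 5], 0.87), ([6, 7, 8], 0.93)]
batch = [np.array(x) for x in zip(*items)]
print(batch[0].shape, batch[1].shape)  # (3, 3) (3,)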
Example 4
    def __init__(
            self,
            search_space,
            device,
            rollout_type,
            mode="eval",
            inner_controller_type=None,
            inner_controller_cfg=None,
            arch_network_type="pointwise_comparator",
            arch_network_cfg=None,

            # how to use the inner controller and arch network to sample new archs
            inner_sample_n=1,
            inner_samples=1,
            inner_steps=200,
            inner_report_freq=50,
            predict_batch_size=512,
            inner_random_init=True,
            inner_iter_random_init=True,
            inner_enumerate_search_space=False,  # DEPRECATED
            inner_enumerate_sample_ratio=None,  # DEPRECATED
            min_inner_sample_ratio=10,

            # how to train the arch network
            begin_train_num=0,
            predictor_train_cfg={
                "epochs": 200,
                "num_workers": 2,
                "batch_size": 128,
                "compare": True,
                "max_compare_ratio": 4,
                "compare_threshold": 0.,
                "report_freq": 50,
                "train_valid_split": None,
                "n_cross_valid": None,
            },
            training_on_load=False,  # force retraining on load
            schedule_cfg=None):
        super(PredictorBasedController,
              self).__init__(search_space, rollout_type, mode, schedule_cfg)

        expect(inner_controller_type is not None,
               "Must specificy inner controller type", ConfigException)

        self.device = device
        self.predictor_train_cfg = predictor_train_cfg
        self.inner_controller_reinit = True
        self.inner_sample_n = inner_sample_n
        self.inner_samples = inner_samples
        self.inner_steps = inner_steps
        self.inner_report_freq = inner_report_freq
        self.inner_random_init = inner_random_init
        self.inner_iter_random_init = inner_iter_random_init
        self.inner_enumerate_search_space = inner_enumerate_search_space
        if inner_enumerate_search_space:
            warnings.warn(
                "The `inner_enumerate_search_space` option is DEPRECATED. "
                "Use inner_controller, and set `inner_samples`, `inner_steps` "
                "accordingly", warnings.DeprecationWarning)
        self.inner_enumerate_sample_ratio = inner_enumerate_sample_ratio
        self.min_inner_sample_ratio = min_inner_sample_ratio
        self.predict_batch_size = predict_batch_size
        self.begin_train_num = begin_train_num
        self.training_on_load = training_on_load

        # initialize the inner controller
        inner_controller_cfg = inner_controller_cfg or {}
        tmp_r_type = inner_controller_cfg.get("rollout_type", None)
        if tmp_r_type is not None:
            expect(
                tmp_r_type == rollout_type,
                "If specified, inner_controller's `rollout_type` must match "
                "the outer `rollout_type`", ConfigException)
            inner_controller_cfg.pop("rollout_type", None)
            inner_controller_cfg.pop("mode", None)

        self.inner_controller_type = inner_controller_type
        self.inner_controller_cfg = inner_controller_cfg
        # if not self.inner_controller_reinit:
        self.inner_controller = BaseController.get_class_(
            self.inner_controller_type)(self.search_space,
                                        self.device,
                                        rollout_type=self.rollout_type,
                                        **self.inner_controller_cfg)
        # else:
        #     self.inner_controller = None
        # Currently, we do not use controller with parameters to be optimized (e.g. RL-learned RNN)
        self.inner_cont_optimizer = None

        # initialize the predictor
        arch_network_cfg = arch_network_cfg or {}
        expect(arch_network_type == "pointwise_comparator",
               "only support pointwise_comparator arch network for now",
               ConfigException)
        model_cls = ArchNetwork.get_class_(arch_network_type)
        self.model = model_cls(self.search_space, **arch_network_cfg)
        self.model.to(self.device)

        self.gt_rollouts = []
        self.gt_arch_scores = []
        self.num_gt_rollouts = 0
        # self.train_loader = None
        # self.val_loader = None
        self.is_predictor_trained = False
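
`expect(cond, msg, exc_cls)` is used above as an assertion that raises a configurable exception. A minimal sketch of the assumed semantics follows; the actual aw_nas helper may differ.

class ConfigException(Exception):
    pass

def expect(condition, message="", exception_cls=ConfigException):
    # raise `exception_cls(message)` when `condition` does not hold
    if not condition:
        raise exception_cls(message)

# usage mirroring the constructor above
inner_controller_type = "random_sample"  # hypothetical value
expect(inner_controller_type is not None,
       "Must specify inner controller type", ConfigException)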
Example 5

if __name__ == "__main__":
    import sys
    base_dir = sys.argv[1]

    def print_with_fmt(container, fmt):
        if isinstance(container, (list, tuple)):
            return ", ".join([print_with_fmt(i, fmt) for i in container])
        return fmt.format(container)

    cfg_file = os.path.join(base_dir, "config.yaml")
    with open(cfg_file, "r") as rf:
        cfg = yaml.load(rf, Loader=yaml.FullLoader)
    ss = get_search_space(cfg["search_space_type"], **cfg["search_space_cfg"])
    pred = ArchNetwork.get_class_(cfg["controller_cfg"]["arch_network_type"])(
        ss, **cfg["controller_cfg"]["arch_network_cfg"])

    def calc_correlation(name_list, rollouts, corr_func=kendalltau):
        perfs = [[r.perf[name] for r in rollouts] for name in name_list]
        num_perf = len(perfs)
        corr_mat = np.zeros((num_perf, num_perf))
        for i in range(num_perf):
            for j in range(num_perf):
                # corr funcs like kendalltau return (correlation, p-value);
                # keep only the correlation statistic
                corr_mat[i][j] = corr_func(perfs[i], perfs[j])[0]
        return corr_mat

    def print_corr_mat(name_list, corr_mat, print_name):
        print(print_name)
        sz = len(name_list)
        name_list = [name[:12] for name in name_list]
        fmt_str = " ".join(["{:12}"] * (sz + 1))
Example 6
def main(argv):
    parser = argparse.ArgumentParser(prog="train_nasbench201_pkl.py")
    parser.add_argument("cfg_file")
    parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
    parser.add_argument("--num-workers", default=4, type=int)
    parser.add_argument("--report_freq", default=200, type=int)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--train-dir",
                        default=None,
                        help="Save train log/results into TRAIN_DIR")
    parser.add_argument("--save-every", default=10, type=int)
    parser.add_argument("--test-only", default=False, action="store_true")
    parser.add_argument("--test-funcs",
                        default=None,
                        help="comma-separated list of test funcs")
    parser.add_argument("--load",
                        default=None,
                        help="Load comparator from disk.")
    parser.add_argument("--eval-only-last",
                        default=None,
                        type=int,
                        help=("for pairwise compartor, the evaluation is slow,"
                              " only evaluate in the final epochs"))
    parser.add_argument("--save-predict",
                        default=None,
                        help="Save the predict scores")
    parser.add_argument("--train-pkl",
                        default="nasbench201_05.pkl",
                        help="Training Datasets pickle")
    parser.add_argument("--valid-pkl",
                        default="nasbench201_05_valid.pkl",
                        help="Evaluate Datasets pickle")
    args = parser.parse_args(argv)

    setproctitle.setproctitle("python train_nasbench201_pkl.py config: {}; train_dir: {}; cwd: {}"\
                              .format(args.cfg_file, args.train_dir, os.getcwd()))

    # log
    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt="%m/%d %I:%M:%S %p")

    if not args.test_only:
        assert args.train_dir is not None, "Must specify `--train-dir` when training"
        # if training, setting up log file, backup config file
        if not os.path.exists(args.train_dir):
            os.makedirs(args.train_dir)
        log_file = os.path.join(args.train_dir, "train.log")
        logging.getLogger().addHandler(logging.FileHandler(log_file))

        # copy config file
        backup_cfg_file = os.path.join(args.train_dir, "config.yaml")
        shutil.copyfile(args.cfg_file, backup_cfg_file)
    else:
        backup_cfg_file = args.cfg_file

    # cuda
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        logging.info("GPU device = %d" % args.gpu)
    else:
        logging.info("no GPU available, use CPU!!")

    if args.seed is not None:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    search_space = get_search_space("nasbench-201", load_nasbench=False)
    logging.info("Load pkl cache from %s and %s", args.train_pkl,
                 args.valid_pkl)
    with open(args.train_pkl, "rb") as rf:
        train_data = pickle.load(rf)
    with open(args.valid_pkl, "rb") as rf:
        valid_data = pickle.load(rf)

    with open(backup_cfg_file, "r") as cfg_f:
        cfg = yaml.load(cfg_f, Loader=yaml.FullLoader)

    logging.info("Config: %s", cfg)

    arch_network_type = cfg.get("arch_network_type", "pointwise_comparator")
    model_cls = ArchNetwork.get_class_(arch_network_type)
    model = model_cls(search_space, **cfg.pop("arch_network_cfg"))
    if args.load is not None:
        logging.info("Load %s from %s", arch_network_type, args.load)
        model.load(args.load)
    model.to(device)

    args.__dict__.update(cfg)
    logging.info("Combined args: %s", args)

    # init nasbench data loaders
    if hasattr(args, "train_ratio") and args.train_ratio is not None:
        _num = len(train_data)
        train_data = train_data[:int(_num * args.train_ratio)]
        logging.info("Train dataset ratio: %.3f", args.train_ratio)
    num_train_archs = len(train_data)
    logging.info("Number of architectures: train: %d; valid: %d",
                 num_train_archs, len(valid_data))
    # decide how many archs would only train to halftime
    if hasattr(args, "ignore_quantile") and args.ignore_quantile is not None:
        half_accs = [item[2] for item in train_data]
        if not args.compare or getattr(args, "ignore_halftime", False):
            # just ignore halftime archs
            full_inds = np.argsort(half_accs)[int(num_train_archs *
                                                  args.ignore_quantile):]
            train_data = [train_data[ind] for ind in full_inds]
            logging.info(
                "#Train architectures after ignoring the worst %.2f quantile of half-time archs: %d",
                args.ignore_quantile, len(train_data))
        else:
            half_inds = np.argsort(half_accs)[:int(num_train_archs *
                                                   args.ignore_quantile)]
            logging.info(
                "#Architectures that do not need to be trained to the end: %d (%.2f %%)",
                len(half_inds), args.ignore_quantile * 100)
            for ind in half_inds:
                train_data[ind] = (train_data[ind][0], None,
                                   train_data[ind][2])
    train_data = NasBench201Dataset(train_data,
                                    minus=cfg.get("dataset_minus", None),
                                    div=cfg.get("dataset_div", None))
    valid_data = NasBench201Dataset(valid_data,
                                    minus=cfg.get("dataset_minus", None),
                                    div=cfg.get("dataset_div", None))
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=args.num_workers,
                              collate_fn=lambda items: list(zip(*items)))
    val_loader = DataLoader(valid_data,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers,
                            collate_fn=lambda items: list(zip(*items)))

    # init test
    if (arch_network_type not in {"pairwise_comparator", "random_forest"}
            or args.test_only):
        if args.test_funcs is not None:
            test_func_names = args.test_funcs.split(",")
            test_funcs = [globals()[name] for name in test_func_names]
        else:
            test_funcs = []
        corr, func_res = valid(val_loader, model, args, funcs=test_funcs)
        if args.test_funcs is None:
            logging.info("INIT: kendall tau {:.4f}".format(corr))
        else:
            logging.info("INIT: kendall tau {:.4f};\n\t{}".format(
                corr, "\n\t".join([
                    "{}: {}".format(name, res)
                    for name, res in zip(test_func_names, func_res)
                ])))

    if args.test_only:
        return

    for i_epoch in range(1, args.epochs + 1):
        model.on_epoch_start(i_epoch)
        if getattr(args, "use_listwise", False):
            avg_loss = train_listwise(train_data, model, i_epoch, args,
                                      arch_network_type)
        else:
            avg_loss = train(train_loader, model, i_epoch, args,
                             arch_network_type)
        logging.info("Train: Epoch {:3d}: train loss {:.4f}".format(
            i_epoch, avg_loss))
        if args.eval_only_last is None or (args.epochs - i_epoch <
                                           args.eval_only_last):
            corr, _ = valid(val_loader, model, args)
            logging.info("Valid: Epoch {:3d}: kendall tau {:.4f}".format(
                i_epoch, corr))
        if i_epoch % args.save_every == 0:
            save_path = os.path.join(args.train_dir, "{}.ckpt".format(i_epoch))
            model.save(save_path)
            logging.info("Epoch {:3d}: Save checkpoint to {}".format(
                i_epoch, save_path))
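
`args.__dict__.update(cfg)` merges keys from the YAML config into the argparse namespace, which is why the loops above can read `args.epochs` and `args.batch_size` even though no such CLI flags are defined. A minimal demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", type=int, default=0)
args = parser.parse_args([])

cfg = {"epochs": 200, "batch_size": 128}  # stand-in for the loaded config.yaml
args.__dict__.update(cfg)
print(args.gpu, args.epochs, args.batch_size)  # 0 200 128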