コード例 #1
0
def correlation_between_one_shot_nb(model_path, config, epoch):
    """Correlate one-shot test errors with NAS-Bench test errors.

    Enumerates the architectures yielded by the search space's
    ``generate_search_space_without_loose_ends``, evaluates each one with the
    one-shot weights loaded from ``model_path`` at ``epoch``, and queries
    NAS-Bench for the corresponding ``model_spec``.

    Args:
        model_path: Directory of the trained one-shot model. If the path
            contains ``'random_ws'`` architectures are evaluated with
            ``discrete=True, normalize=False``; otherwise with
            ``discrete=False, normalize=True``.
        config: Run configuration; ``config['search_space']`` selects the
            search space (``'1'``, ``'2'`` or ``'3'``; it is compared as a
            string, so ints are accepted as well).
        epoch: Checkpoint epoch passed to ``model.load``.

    Returns:
        Tuple ``(correlation, nb_test_errors, nb_valid_errors,
        one_shot_test_errors)`` where ``correlation`` is the Pearson
        correlation (``np.corrcoef``) between the one-shot and NAS-Bench
        test errors.

    Raises:
        ValueError: If the search space is unknown.
    """
    # Normalize once so every later comparison agrees (str or int config).
    search_space_id = str(config['search_space'])
    if search_space_id == '1':
        search_space = SearchSpace1()
    elif search_space_id == '2':
        search_space = SearchSpace2()
    elif search_space_id == '3':
        search_space = SearchSpace3()
    else:
        raise ValueError('Unknown search space')
    model = DartsWrapper(
        save_path=model_path,
        seed=0,
        batch_size=128,
        grad_clip=5,
        epochs=200,
        num_intermediate_nodes=search_space.num_intermediate_nodes,
        search_space=search_space,
        cutout=False)
    # Checkpoints from random weight-sharing runs are evaluated on discrete
    # architectures without normalization; all others the other way round.
    discrete = 'random_ws' in model_path
    normalize = not discrete

    model.load(epoch=epoch)

    # For search spaces 1 and 2 the NAS-Bench cell carries an extra 5th node
    # at index -2 (see ops[1:-2] below) that the one-shot model does not
    # have. This does not change per architecture, so decide it once.
    drop_fifth_node = search_space_id in ('1', '2')

    nb_test_errors = []
    nb_valid_errors = []
    one_shot_test_errors = []
    for adjacency_matrix, ops, model_spec in search_space.generate_search_space_without_loose_ends(
    ):
        if drop_fifth_node:
            # Drop row -2 AND column -2 so the adjacency matrix stays square
            # (deleting along axis 0 twice would remove two rows instead).
            adjacency_matrix_ss = np.delete(
                np.delete(adjacency_matrix, -2, axis=0), -2, axis=1)
            # Remove input, output and 5th node
            ops_ss = ops[1:-2]
        else:
            adjacency_matrix_ss = adjacency_matrix
            # Remove input and output node
            ops_ss = ops[1:-1]

        one_shot_test_error = model.evaluate_test(
            (adjacency_matrix_ss, ops_ss),
            split='test',
            discrete=discrete,
            normalize=normalize)
        # Repeated so the lists line up entry-by-entry with the NAS-Bench
        # results below; assumes nasbench.query returns 3 entries per
        # architecture (one per training repeat) — TODO confirm.
        one_shot_test_errors.extend(np.repeat(one_shot_test_error, 3))
        # Query NASBench
        data = nasbench.query(model_spec)
        nb_test_errors.extend([1 - item['test_accuracy'] for item in data])
        nb_valid_errors.extend(
            [1 - item['validation_accuracy'] for item in data])
        print('NB', nb_test_errors[-1], 'OS', one_shot_test_errors[-1],
              'weights', model.model.arch_parameters())

    correlation = np.corrcoef(one_shot_test_errors, nb_test_errors)[0, -1]
    return correlation, nb_test_errors, nb_valid_errors, one_shot_test_errors
def correlation_between_one_shot_nb(model_path, config, epoch, num_samples=100):
    """Evaluate controller-sampled architectures with the one-shot model and NAS-Bench.

    Loads the one-shot weights and the pickled controller saved for ``epoch``
    in ``model_path``. It then samples ``num_samples`` architectures from the
    controller, evaluates each with the one-shot model, and queries NAS-Bench
    for it at every epoch budget in ``[4, 12, 36, 108]``.

    Args:
        model_path: Directory holding the one-shot checkpoint and
            ``controller_epoch_{epoch}.pt``.
        config: Run configuration; ``config['search_space']`` selects the
            search space (``'1'``, ``'2'`` or ``'3'``; compared as a string,
            so ints are accepted as well).
        epoch: Checkpoint epoch to load.
        num_samples: Number of architectures to sample (default 100).

    Returns:
        Tuple ``(None, nb_test_errors, nb_valid_errors, one_shot_test_errors)``.
        The two NAS-Bench dicts are keyed by the epoch budget as a string
        (``'4'``, ``'12'``, ``'36'``, ``'108'``). No correlation is computed
        here; the first element is always ``None``.

    Raises:
        ValueError: If the search space is unknown.
    """
    # Normalize once so every later comparison agrees (str or int config).
    search_space_id = str(config['search_space'])
    if search_space_id == '1':
        search_space = SearchSpace1()
    elif search_space_id == '2':
        search_space = SearchSpace2()
    elif search_space_id == '3':
        search_space = SearchSpace3()
    else:
        raise ValueError('Unknown search space')
    model = DartsWrapper(save_path=model_path, seed=0, batch_size=128, grad_clip=5, epochs=200,
                         num_intermediate_nodes=search_space.num_intermediate_nodes, search_space=search_space,
                         cutout=False)
    # Sampled architectures are evaluated discretely, without normalization.
    discrete = True
    normalize = False

    model.load(epoch=epoch)
    # NOTE(review): torch.load unpickles the whole controller object, so only
    # load checkpoints you trust.
    controller = torch.load(os.path.join(model_path, 'controller_epoch_{}.pt'.format(epoch)))

    nb_epoch_budgets = [4, 12, 36, 108]
    nb_test_errors = {str(budget): [] for budget in nb_epoch_budgets}
    nb_valid_errors = {str(budget): [] for budget in nb_epoch_budgets}
    one_shot_test_errors = []

    for _sample_idx in range(num_samples):
        (adjacency_matrix_ss, ops_ss), _, _ = controller()

        print(adjacency_matrix_ss, ops_ss)

        one_shot_test_error = model.evaluate_test((adjacency_matrix_ss, ops_ss), split='test', discrete=discrete,
                                                  normalize=normalize)
        # Repeated to mirror the per-repeat entries collected from NAS-Bench;
        # assumes 3 entries per query — TODO confirm.
        one_shot_test_errors.extend(np.repeat(one_shot_test_error, 3))

        # Build the NAS-Bench representation on a copy so the controller's
        # returned ops list is not mutated.
        nasbench_ops = list(ops_ss)
        if search_space_id in ('1', '2'):
            # Add the extra 5th node NAS-Bench expects for these spaces
            # (presumably what upscale_to_nasbench_format pads in).
            adjacency_matrix_ss = upscale_to_nasbench_format(adjacency_matrix_ss)
            nasbench_ops.append(CONV1X1)

        # Create nested list from numpy matrix. Builtin `int` replaces the
        # `np.int` alias, which NumPy removed in 1.24.
        nasbench_adjacency_matrix = adjacency_matrix_ss.astype(int).tolist()

        nasbench_ops = [INPUT] + nasbench_ops + [OUTPUT]

        # Assemble the model spec
        model_spec = api.ModelSpec(
            # Adjacency matrix of the module
            matrix=nasbench_adjacency_matrix,
            # Operations at the vertices of the module, matches order of matrix
            ops=nasbench_ops)
        for nb_epoch_budget in nb_epoch_budgets:
            data = nasbench.query(model_spec=model_spec, epochs=nb_epoch_budget)
            nb_test_errors[str(nb_epoch_budget)].extend([1 - item['test_accuracy'] for item in data])
            nb_valid_errors[str(nb_epoch_budget)].extend([1 - item['validation_accuracy'] for item in data])

    return None, nb_test_errors, nb_valid_errors, one_shot_test_errors
コード例 #3
0
def main(args):
    """Run random search with weight sharing, or only evaluate an existing run.

    It creates (or reuses) the output directory and dumps ``args`` to
    ``config.json`` in it. It then logs to stdout and to ``log.txt`` and
    builds the one-shot model and the ``Random_NAS`` searcher.

    Args:
        args: Parsed command-line namespace. Fields read here: ``save_dir``,
            ``eval_only``, ``search_space``, ``seed``, ``benchmark``,
            ``epochs``, ``batch_size``, ``grad_clip``, ``init_channels``,
            ``cutout``.

    Returns:
        The first architecture returned by ``searcher.get_eval_arch`` as a
        space-separated string; the same string is written to ``/tmp/arch``.

    Raises:
        ValueError: If ``eval_only`` is set without ``save_dir``, or the
            search space / benchmark is unsupported.
    """
    # Validate before touching the filesystem; an ``assert`` here would be
    # stripped under ``python -O`` and ran only after the directory was made.
    if args.eval_only and args.save_dir is None:
        raise ValueError('eval_only requires save_dir to be set')

    # Fill in with root output path
    root_dir = os.getcwd()
    print('root_dir', root_dir)
    if args.save_dir is None:
        save_dir = os.path.join(
            root_dir, 'experiments/random_ws/ss_{}_{}_{}'.format(
                time.strftime("%Y%m%d-%H%M%S"), args.search_space, args.seed))
    else:
        save_dir = args.save_dir
    # exist_ok avoids the race between an exists() check and creation.
    os.makedirs(save_dir, exist_ok=True)

    # Dump the config of the run folder
    with open(os.path.join(save_dir, 'config.json'), 'w') as fp:
        json.dump(args.__dict__, fp)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    logging.info(args)

    if args.search_space == '1':
        search_space = SearchSpace1()
    elif args.search_space == '2':
        search_space = SearchSpace2()
    elif args.search_space == '3':
        search_space = SearchSpace3()
    else:
        raise ValueError('Unknown search space')

    if args.benchmark == 'ptb':
        raise ValueError('PTB not supported.')
    else:
        data_size = 25000
        time_steps = 1

    # Total search budget in training iterations.
    B = int(args.epochs * data_size / args.batch_size / time_steps)
    if args.benchmark == 'cnn':
        from optimizers.random_search_with_weight_sharing.darts_wrapper_discrete import DartsWrapper
        model = DartsWrapper(
            save_dir,
            args.seed,
            args.batch_size,
            args.grad_clip,
            args.epochs,
            num_intermediate_nodes=search_space.num_intermediate_nodes,
            search_space=search_space,
            init_channels=args.init_channels,
            cutout=args.cutout)
    else:
        raise ValueError('Benchmarks other cnn on cifar are not available')

    searcher = Random_NAS(B, model, args.seed, save_dir)
    logging.info('budget: %d', searcher.B)
    if not args.eval_only:
        searcher.run()
        archs = searcher.get_eval_arch()
    else:
        # Fixed seed offset so evaluation-only runs are reproducible.
        np.random.seed(args.seed + 1)
        archs = searcher.get_eval_arch(2)
    logging.info(archs)
    arch = ' '.join([str(a) for a in archs[0][0]])
    with open('/tmp/arch', 'w') as f:
        f.write(arch)
    return arch
コード例 #4
0
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info(args)

    if args.benchmark == 'ptb':
        raise ValueError('PTB not supported.')
    else:
        data_size = 25000
        time_steps = 1

    B = int(args.epochs * data_size / args.batch_size / time_steps)
    if args.benchmark == 'cnn':
        from optimizers.random_search_with_weight_sharing.darts_wrapper_discrete import DartsWrapper
        model = DartsWrapper(save_dir, args.seed, args.batch_size, args.grad_clip, args.epochs, gpu=args.gpu,
                             num_intermediate_nodes=search_space.num_intermediate_nodes, search_space=search_space,
                             init_channels=args.init_channels, cutout=args.cutout)
    else:
        raise ValueError('Benchmarks other cnn on cifar are not available')

    searcher = Random_NAS(B, model, args.seed, save_dir)
    logging.info('budget: %d' % (searcher.B))
    if not args.eval_only:
        searcher.run()
        # archs = searcher.get_eval_arch()
    else:
        np.random.seed(args.seed + 1)
        archs = searcher.get_eval_arch(2)
    # logging.info(archs)
    # arch = ' '.join([str(a) for a in archs[0][0]])
    # with open('/tmp/arch', 'w') as f: