def basic_eval_all(args, loaders, net, criterion, epoch_str, logger, opt_config):
  """Evaluate *net* on every (loader, is_video) pair and return the NMEs.

  Each dataset is evaluated with ``basic_eval``; the per-dataset meta
  result is saved under the logger's 'meta' path, and the normalized
  mean errors are returned as a comma-separated string.
  """
  args = deepcopy(args)  # private copy: never mutate the caller's args
  num_loaders = len(loaders)
  logger.log('Basic-Eval-All evaluates {:} dataset'.format(num_loaders))
  nme_values = []
  for idx, (loader, is_video) in enumerate(loaders):
    logger.log('==>>{:}, [{:}], evaluate the {:}/{:}-th dataset [{:}] : {:}'.format(time_string(), epoch_str, idx, num_loaders, 'video' if is_video else 'image', loader.dataset))
    # inference only — no gradients needed
    with torch.no_grad():
      eval_loss, eval_meta = basic_eval(args, loader, net, criterion, epoch_str + "::{:}/{:}".format(idx, num_loaders), logger, opt_config)
    nme, _, _ = eval_meta.compute_mse(logger)
    meta_path = logger.path('meta') / 'eval-{:}-{:02d}-{:02d}.pth'.format(epoch_str, idx, num_loaders)
    eval_meta.save(meta_path)
    nme_values.append(nme)
  return ', '.join('{:.1f}'.format(value) for value in nme_values)
Пример #2
0
def visualize_rank_over_time(meta_file, vis_save_dir):
    """Plot each architecture's validation rank against its final test rank.

    For every training epoch (0..199), draws a scatter plot comparing the
    ranking of all architectures by validation accuracy at that epoch with
    their ranking by test accuracy at the final epoch (199).  Accuracy data
    is collected from the NAS benchmark once and cached under
    *vis_save_dir*; figures are saved as both PDF and PNG per epoch.

    Args:
      meta_file: path to the NAS benchmark meta file, passed to ``API``.
      vis_save_dir: directory (a ``pathlib.Path``) for the cache file and
        the generated figures.
    """
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(
        time_string(), vis_save_dir))
    # Collecting all per-epoch metrics is expensive; cache the result so
    # subsequent runs can load it directly.
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        # train/valid/test/otest accuracies are keyed by epoch index;
        # params/flops are epoch-independent flat lists.
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(
            list), defaultdict(list), defaultdict(list), defaultdict(list)
        #for iepoch in range(200): for index in range( len(nas_bench) ):
        for index in tqdm(range(len(nas_bench))):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            for iepoch in range(200):
                res = info.get_metrics('cifar10', 'train', iepoch)
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid', iepoch)
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                test_acc = res['accuracy']
                # NOTE(review): test_acc and otest_acc are both queried from
                # ('cifar10', 'ori-test') — presumably intentional because
                # cifar10 has no separate x-test split, but confirm.
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                otest_acc = res['accuracy']
                train_accs[iepoch].append(train_acc)
                valid_accs[iepoch].append(valid_acc)
                test_accs[iepoch].append(test_acc)
                otest_accs[iepoch].append(otest_acc)
                if iepoch == 0:
                    # compute costs do not depend on the epoch; record once
                    res = info.get_comput_costs('cifar10')
                    flop, param = res['flops'], res['params']
                    flops.append(flop)
                    params.append(param)
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = info[
            'params'], info['flops'], info['train_accs'], info[
                'valid_accs'], info['test_accs'], info['otest_accs']
    print('{:} collect data done.'.format(time_string()))
    #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
    selected_epochs = list(range(200))
    # Reference ordering: architectures sorted by final-epoch test accuracy.
    x_xtests = test_accs[199]
    indexes = list(range(len(x_xtests)))
    ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
    for sepoch in selected_epochs:
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        # valid_ord_lbls[j] = rank (by this epoch's validation accuracy) of
        # the architecture whose final-test rank is j.
        valid_ord_lbls = []
        for idx in ord_idxs:
            valid_ord_lbls.append(valid_ord_idxs.index(idx))
        # labeled data
        dpi, width, height = 300, 2600, 2600
        figsize = width / float(dpi), height / float(dpi)
        LabelSize, LegendFontsize = 18, 18

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(min(indexes), max(indexes))
        plt.ylim(min(indexes), max(indexes))
        plt.yticks(np.arange(min(indexes), max(indexes),
                             max(indexes) // 6),
                   fontsize=LegendFontsize,
                   rotation='vertical')
        plt.xticks(np.arange(min(indexes), max(indexes),
                             max(indexes) // 6),
                   fontsize=LegendFontsize)
        # green triangles: validation rank at this epoch vs final test rank;
        # blue circles: the identity line (perfect rank agreement).
        ax.scatter(indexes,
                   valid_ord_lbls,
                   marker='^',
                   s=0.5,
                   c='tab:green',
                   alpha=0.8)
        ax.scatter(indexes,
                   indexes,
                   marker='o',
                   s=0.5,
                   c='tab:blue',
                   alpha=0.8)
        # off-screen points at (-1, -1) exist only to create legend entries
        # with a visible marker size.
        ax.scatter([-1], [-1],
                   marker='^',
                   s=100,
                   c='tab:green',
                   label='CIFAR-10 validation')
        ax.scatter([-1], [-1],
                   marker='o',
                   s=100,
                   c='tab:blue',
                   label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy',
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set',
                      fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print('{:} save into {:}'.format(time_string(), save_path))
        plt.close('all')
Пример #3
0
def main(args):
    """Run a trained facial-landmark detector over the configured eval sets.

    Loads a checkpoint (network weights plus the args it was trained with),
    builds image- and video-list evaluation dataloaders, runs inference,
    maps predicted landmark locations back into original-image coordinates,
    and saves predictions/ground-truths per loader under ``args.save_path``.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.set_num_threads(args.workers)
    print('Training Base Detector : prepare_seed : {:}'.format(args.rand_seed))
    prepare_seed(args.rand_seed)

    logger = prepare_logger(args)

    # The checkpoint also stores the training-time arguments (xargs); those
    # drive the transforms, model config, and dataset settings below.
    checkpoint = load_checkpoint(args.init_model)
    xargs = checkpoint['args']
    logger.log('Previous args : {:}'.format(xargs))

    # General Data Augmentation
    # NOTE(review): `xargs.use_gray == False` would read better as
    # `not xargs.use_gray` (PEP 8); behavior is identical either way.
    if xargs.use_gray == False:
        mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        mean_fill = (0.5, )
        normalize = transforms.Normalize(mean=[mean_fill[0]], std=[0.5])
    eval_transform  = transforms.Compose2V([transforms.ToTensor(), normalize, \
                                                transforms.PreCrop(xargs.pre_crop_expand), \
                                                transforms.CenterCrop(xargs.crop_max)])

    # Model Configure Load
    model_config = load_configure(xargs.model_config, logger)
    shape = (xargs.height, xargs.width)
    logger.log('--> {:}\n--> Sigma : {:}, Shape : {:}'.format(
        model_config, xargs.sigma, shape))

    # Evaluation Dataloader
    # Image lists are tagged (loader, False); video lists (loader, True).
    eval_loaders = []
    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = EvalDataset(eval_transform, xargs.sigma,
                                     model_config.downsample,
                                     xargs.heatmap_type, shape, xargs.use_gray,
                                     xargs.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, xargs.boxindicator,
                                 xargs.normalizeL, True)
            eval_iloader = torch.utils.data.DataLoader(
                eval_idata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_iloader, False))
    if args.eval_vlists is not None:
        for eval_vlist in args.eval_vlists:
            eval_vdata = EvalDataset(eval_transform, xargs.sigma,
                                     model_config.downsample,
                                     xargs.heatmap_type, shape, xargs.use_gray,
                                     xargs.data_indicator)
            eval_vdata.load_list(eval_vlist, args.num_pts, xargs.boxindicator,
                                 xargs.normalizeL, True)
            eval_vloader = torch.utils.data.DataLoader(
                eval_vdata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_vloader, True))

    # define the detector
    detector = obtain_pro_model(model_config, xargs.num_pts, xargs.sigma,
                                xargs.use_gray)
    assert model_config.downsample == detector.downsample, 'downsample is not correct : {:} vs {:}'.format(
        model_config.downsample, detector.downsample)
    logger.log("=> detector :\n {:}".format(detector))
    logger.log("=> Net-Parameters : {:} MB".format(
        count_parameters_in_MB(detector)))
    logger.log('=> Eval-Transform : {:}'.format(eval_transform))

    detector = detector.cuda()
    net = torch.nn.DataParallel(detector)
    # Inference only: eval mode, then restore the trained weights.
    net.eval()
    net.load_state_dict(checkpoint['detector'])
    cpu = torch.device('cpu')

    assert len(args.use_stable) == 2

    for iLOADER, (loader, is_video) in enumerate(eval_loaders):
        logger.log(
            '{:} The [{:2d}/{:2d}]-th test set [{:}] = {:} with {:} batches.'.
            format(time_string(), iLOADER, len(eval_loaders),
                   'video' if is_video else 'image', loader.dataset,
                   len(loader)))
        with torch.no_grad():
            all_points, all_results, all_image_ps = [], [], []
            for i, (inputs, targets, masks, normpoints, transthetas,
                    image_index, nopoints, shapes) in enumerate(loader):
                image_index = image_index.squeeze(1).tolist()
                (batch_size, C, H, W), num_pts = inputs.size(), xargs.num_pts
                # batch_heatmaps is a list for stage-predictions, each element should be [Batch, C, H, W]
                if xargs.procedure == 'heatmap':
                    batch_features, batch_heatmaps, batch_locs, batch_scos = net(
                        inputs)
                    # drop the last predicted channel — presumably the
                    # background/auxiliary point of the heatmap head; confirm.
                    batch_locs = batch_locs[:, :-1, :]
                else:
                    batch_locs = net(inputs)
                batch_locs = batch_locs.detach().to(cpu)
                # evaluate the training data
                for ibatch, (imgidx,
                             nopoint) in enumerate(zip(image_index, nopoints)):
                    # Build a homogeneous (3 x num_pts) coordinate matrix so
                    # the affine transtheta can be applied with one matmul.
                    if xargs.procedure == 'heatmap':
                        norm_locs = normalize_points(
                            (H, W), batch_locs[ibatch].transpose(1, 0))
                        norm_locs = torch.cat(
                            (norm_locs, torch.ones(1, num_pts)), dim=0)
                    else:
                        norm_locs = torch.cat((batch_locs[ibatch].permute(
                            1, 0), torch.ones(1, num_pts)),
                                              dim=0)
                    transtheta = transthetas[ibatch][:2, :]
                    norm_locs = torch.mm(transtheta, norm_locs)
                    real_locs = denormalize_points(shapes[ibatch].tolist(),
                                                   norm_locs)
                    #real_locs  = torch.cat((real_locs, batch_scos[ibatch].permute(1,0)), dim=0)
                    real_locs = torch.cat((real_locs, torch.ones(1, num_pts)),
                                          dim=0)
                    xpoints = loader.dataset.labels[imgidx].get_points().numpy(
                    )
                    image_path = loader.dataset.datas[imgidx]
                    # put into the list
                    all_points.append(torch.from_numpy(xpoints))
                    all_results.append(real_locs)
                    all_image_ps.append(image_path)
            total = len(all_points)
            logger.log(
                '{:} The [{:2d}/{:2d}]-th test set finishes evaluation : {:} frames/images'
                .format(time_string(), iLOADER, len(eval_loaders), total))
        # NOTE(review): the triple-quoted string below is dead code kept as
        # a string literal (video rendering via ffmpeg); it is never used.
        """
    if args.use_stable[0] > 0:
      save_dir = Path( osp.join(args.save_path, '{:}-X-{:03d}'.format(args.model_name, iLOADER)) )
      save_dir.mkdir(parents=True, exist_ok=True)
      wrap_parallel = WrapParallel(save_dir, all_image_ps, all_results, all_points, 180, (255, 0, 0))
      wrap_loader   = torch.utils.data.DataLoader(wrap_parallel, batch_size=args.workers, shuffle=False, num_workers=args.workers, pin_memory=True)
      for iL, INDEXES in enumerate(wrap_loader): _ = INDEXES
      cmd = 'ffmpeg -y -i {:}/%06d.png -framerate 30 {:}.avi'.format(save_dir, save_dir)
      logger.log('{:} possible >>>>> : {:}'.format(time_string(), cmd))
      os.system( cmd )

    if args.use_stable[1] > 0:
      save_dir = Path( osp.join(args.save_path, '{:}-Y-{:03d}'.format(args.model_name, iLOADER)) )
      save_dir.mkdir(parents=True, exist_ok=True)
      Xpredictions, Xgts = torch.stack(all_results), torch.stack(all_points)
      new_preds = fc_solve(Xgts, Xpredictions, is_cuda=True)
      wrap_parallel = WrapParallel(save_dir, all_image_ps, new_preds, all_points, 180, (0, 0, 255))
      wrap_loader   = torch.utils.data.DataLoader(wrap_parallel, batch_size=args.workers, shuffle=False, num_workers=args.workers, pin_memory=True)
      for iL, INDEXES in enumerate(wrap_loader): _ = INDEXES
      cmd = 'ffmpeg -y -i {:}/%06d.png -framerate 30 {:}.avi'.format(save_dir, save_dir)
      logger.log('{:} possible >>>>> : {:}'.format(time_string(), cmd))
      os.system( cmd )
    """
        # Persist paths, ground-truths, and raw predictions for this loader.
        Xpredictions, Xgts = torch.stack(all_results), torch.stack(all_points)
        save_path = Path(
            osp.join(args.save_path,
                     '{:}-result-{:03d}.pth'.format(args.model_name, iLOADER)))
        torch.save(
            {
                'paths': all_image_ps,
                'ground-truths': Xgts,
                'predictions': all_results
            }, save_path)
        logger.log('{:} save into {:}'.format(time_string(), save_path))
        # Disabled branch: render stabilized-vs-raw comparison video.
        if False:
            new_preds = fc_solve_v2(Xgts, Xpredictions, is_cuda=True)
            # create the dir
            save_dir = Path(
                osp.join(args.save_path,
                         '{:}-T-{:03d}'.format(args.model_name, iLOADER)))
            save_dir.mkdir(parents=True, exist_ok=True)
            wrap_parallel = WrapParallelV2(save_dir, all_image_ps, Xgts,
                                           all_results, new_preds, all_points,
                                           180, [args.model_name, 'SRT'])
            wrap_parallel[0]
            wrap_loader = torch.utils.data.DataLoader(wrap_parallel,
                                                      batch_size=args.workers,
                                                      shuffle=False,
                                                      num_workers=args.workers,
                                                      pin_memory=True)
            for iL, INDEXES in enumerate(wrap_loader):
                _ = INDEXES
            cmd = 'ffmpeg -y -i {:}/%06d.png -vb 5000k {:}.avi'.format(
                save_dir, save_dir)
            logger.log('{:} possible >>>>> : {:}'.format(time_string(), cmd))
            os.system(cmd)

    logger.close()
    return
Пример #4
0
def main(xargs):
    """Search a GDAS cell architecture on CIFAR-10 with TensorFlow.

    Splits the CIFAR-10 training set in half (weights half / alphas half),
    then alternates per batch between optimizing network weights (SGDW) and
    architecture parameters (AdamW) under a Gumbel-softmax temperature that
    is annealed linearly from ``xargs.tau_max`` to ``xargs.tau_min``.
    Prints per-epoch metrics and the derived genotype.

    Fix: valid_loss / valid_accuracy are now reset at the start of every
    epoch, matching the train/test metrics; previously they accumulated
    across all epochs, so the printed per-epoch validation numbers were
    running averages over the whole search.
    """
    cifar10 = tf.keras.datasets.cifar10

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    x_train, x_test = x_train.astype('float32'), x_test.astype('float32')

    # Add a channels dimension
    # Even/odd split of the shuffled indices: one half trains the weights,
    # the other half trains the architecture parameters.
    all_indexes = list(range(x_train.shape[0]))
    random.shuffle(all_indexes)
    s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
    search_train_x, search_train_y = x_train[s_train_idxs], y_train[
        s_train_idxs]
    search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[
        s_valid_idxs]
    #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]

    # Use tf.data
    #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
    search_ds = tf.data.Dataset.from_tensor_slices(
        (search_train_x, search_train_y, search_valid_x, search_valid_y))
    search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)

    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

    # Create an instance of the model
    config = dict2config(
        {
            'name': 'GDAS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': 10,
            'space': 'nas-bench-201',
            'affine': True
        }, None)
    model = get_cell_based_tiny_net(config)
    #import pdb; pdb.set_trace()
    #model.build(((64, 32, 32, 3), (1,)))
    #for x in model.trainable_variables:
    #  print('{:30s} : {:}'.format(x.name, x.shape))
    # Choose optimizer
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    w_optimizer = SGDW(learning_rate=xargs.w_lr,
                       weight_decay=xargs.w_weight_decay,
                       momentum=xargs.w_momentum,
                       nesterov=True)
    a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate,
                        weight_decay=xargs.arch_weight_decay,
                        beta_1=0.5,
                        beta_2=0.999,
                        epsilon=1e-07)
    #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
    #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
    ####
    # metrics
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='valid_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy')

    @tf.function
    def search_step(train_images, train_labels, valid_images, valid_labels,
                    tf_tau):
        """One bi-level step: update weights on train data, alphas on valid."""
        # optimize weights
        # NOTE(review): gradients are taken w.r.t. model.get_weights() /
        # model.get_alphas(); this presumes the custom model returns
        # tf.Variable objects (unlike Keras' get_weights, which returns
        # numpy arrays) — confirm against the model implementation.
        with tf.GradientTape() as tape:
            predictions = model(train_images, tf_tau, True)
            w_loss = loss_object(train_labels, predictions)
        net_w_param = model.get_weights()
        gradients = tape.gradient(w_loss, net_w_param)
        w_optimizer.apply_gradients(zip(gradients, net_w_param))
        train_loss(w_loss)
        train_accuracy(train_labels, predictions)
        # optimize alphas
        with tf.GradientTape() as tape:
            predictions = model(valid_images, tf_tau, True)
            a_loss = loss_object(valid_labels, predictions)
        net_a_param = model.get_alphas()
        gradients = tape.gradient(a_loss, net_a_param)
        a_optimizer.apply_gradients(zip(gradients, net_a_param))
        valid_loss(a_loss)
        valid_accuracy(valid_labels, predictions)

    # TEST
    @tf.function
    def test_step(images, labels):
        predictions = model(images)
        t_loss = loss_object(labels, predictions)

        test_loss(t_loss)
        test_accuracy(labels, predictions)

    print(
        '{:} start searching with {:} epochs ({:} batches per epoch).'.format(
            time_string(), xargs.epochs,
            tf.data.experimental.cardinality(search_ds).numpy()))

    for epoch in range(xargs.epochs):
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        train_accuracy.reset_states()
        valid_loss.reset_states()
        valid_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()
        # Linear temperature annealing: tau_max at epoch 0 down to tau_min
        # at the final epoch.
        cur_tau = xargs.tau_max - (xargs.tau_max -
                                   xargs.tau_min) * epoch / (xargs.epochs - 1)
        tf_tau = tf.cast(cur_tau, dtype=tf.float32, name='tau')

        for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
            search_step(trn_imgs, trn_labels, val_imgs, val_labels, tf_tau)
        genotype = model.genotype()
        genotype = CellStructure(genotype)

        #for test_images, test_labels in test_ds:
        #  test_step(test_images, test_labels)

        template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | tau={:.3f}'
        print(
            template.format(time_string(), epoch + 1, xargs.epochs,
                            train_loss.result(),
                            train_accuracy.result() * 100, valid_loss.result(),
                            valid_accuracy.result() * 100, cur_tau))
        print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype,
                                                 model.get_np_alphas()))
Пример #5
0
def visualize_info(meta_file, dataset, vis_save_dir):
    """Visualize accuracy statistics of all architectures on *dataset*.

    Collects (or loads from a cache file) per-architecture #params/FLOPs and
    train/valid/test accuracies from the NAS benchmark, then saves four
    scatter plots into *vis_save_dir* (each as PDF and PNG):
    params-vs-valid, params-vs-test, params-vs-train, and test-over-ID.
    The ResNet-like architecture (index 11472) is highlighted with a star.

    Fix: the y-axis label of the params-vs-train figure read
    'the trarining accuracy (%)'; corrected to 'the training accuracy (%)'.

    Args:
      meta_file: path to the NAS benchmark meta file, passed to ``API``.
      dataset: dataset name ('cifar10', 'cifar100', or an ImageNet variant).
      vis_save_dir: directory (a ``pathlib.Path``) for cache and figures.
    """
    print('{:} start to visualize {:} information'.format(
        time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
        for index in range(len(nas_bench)):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            resx = info.get_comput_costs(dataset)
            flop, param = resx['flops'], resx['params']
            if dataset == 'cifar10':
                res = info.get_metrics('cifar10', 'train')
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                test_acc = res['accuracy']
                # cifar10 has no separate x-test split, so the original test
                # set is used for both test_acc and otest_acc.
                res = info.get_metrics('cifar10', 'ori-test')
                otest_acc = res['accuracy']
            else:
                res = info.get_metrics(dataset, 'train')
                train_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-test')
                test_acc = res['accuracy']
                res = info.get_metrics(dataset, 'ori-test')
                otest_acc = res['accuracy']
            if index == 11472:  # resnet
                resnet = {
                    'params': param,
                    'flops': flop,
                    'index': 11472,
                    'train_acc': train_acc,
                    'valid_acc': valid_acc,
                    'test_acc': test_acc,
                    'otest_acc': otest_acc
                }
            flops.append(flop)
            params.append(param)
            train_accs.append(train_acc)
            valid_accs.append(valid_acc)
            test_accs.append(test_acc)
            otest_accs.append(otest_acc)
        #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        # NOTE: assumes the benchmark contains index 11472; otherwise
        # `resnet` would be unbound here.
        info['resnet'] = resnet
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = info[
            'params'], info['flops'], info['train_accs'], info[
                'valid_accs'], info['test_accs'], info['otest_accs']
        resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    indexes = list(range(len(params)))
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    resnet_scale, resnet_alpha = 120, 0.5

    # Figure 1: #parameters vs validation accuracy.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['valid_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=0.4)
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-valid.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # Figure 2: #parameters vs test accuracy.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['test_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-test.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-test.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # Figure 3: #parameters vs training accuracy.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(20, 100)
        plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(25, 76)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['train_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    # Fixed typo: was 'the trarining accuracy (%)'.
    ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-train.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-train.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # Figure 4: test accuracy over architecture ID.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(0, max(indexes))
    plt.xticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 5),
               fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['index']], [resnet['test_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ID', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-test-over-ID.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-test-over-ID.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
Пример #6
0
def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
         splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any],
         to_evaluate_indexes: tuple, cover_mode: bool):
  """Evaluate a range of candidate architectures on several datasets.

  For every index in *to_evaluate_indexes* and every seed in *seeds*, runs
  ``evaluate_all_datasets`` and saves the results under *save_dir* as
  ``arch-{index}-seed-{seed}.pth``.  With ``cover_mode`` enabled, existing
  result files are removed and re-evaluated; otherwise they are skipped.
  Logs progress and an estimated time-left after each architecture.
  """
  log_dir = save_dir / 'logs'
  log_dir.mkdir(parents=True, exist_ok=True)
  logger = Logger(str(log_dir), os.getpid(), False)

  logger.log('xargs : seeds      = {:}'.format(seeds))
  logger.log('xargs : cover_mode = {:}'.format(cover_mode))
  logger.log('-' * 100)

  logger.log(
    'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
   +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
  for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
    logger.log(
      '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
  logger.log('--->>> optimization config : {:}'.format(opt_config))
  #to_evaluate_indexes = list(range(srange[0], srange[1] + 1))

  start_time, epoch_time = time.time(), AverageMeter()
  for i, index in enumerate(to_evaluate_indexes):
    # nets[index] is the channel-configuration string of this architecture.
    channelstr = nets[index]
    logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(time_string(), i,
                       len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15))
    logger.log('{:} {:} {:}'.format('-' * 15, channelstr, '-' * 15))

    # test this arch on different datasets with different seeds
    # has_continue tracks whether any seed was skipped, so a partially
    # skipped arch does not pollute the per-arch timing average.
    has_continue = False
    for seed in seeds:
      to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
      if to_save_name.exists():
        if cover_mode:
          logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
          os.remove(str(to_save_name))
        else:
          logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
          has_continue = True
          continue
      results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
      torch.save(results, to_save_name)
      logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}]  ===>>> {:}'.format(time_string(), i,
                    len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
    # measure elapsed time
    if not has_continue: epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True))
    logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True)))
    logger.log('{:}'.format('*' * 100))
    logger.log('{:}   {:74s}   {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(
      to_evaluate_indexes), index, len(nets), need_time), '*' * 10))
    logger.log('{:}'.format('*' * 100))

  logger.close()
Пример #7
0
def train_controller(xloader, network, criterion, optimizer, prev_baseline,
                     epoch_str, print_freq, logger):
    """Train the ENAS-style controller with REINFORCE for one epoch.

    The controller samples architectures; each sample is rewarded with the
    validation top-1 accuracy (plus an entropy bonus) of the shared network
    evaluated under that architecture.  Gradients are aggregated over
    ``controller_num_aggregate`` samples before every optimizer step.

    Args:
      xloader: loader of (inputs, targets) batches used to compute rewards.
      network: shared super-network exposing ``controller`` and
        ``set_cal_mode``; its weights stay frozen (eval mode) here.
      criterion: unused; kept so all train functions share one signature.
      optimizer: optimizer over ``network.controller`` parameters only.
      prev_baseline: moving-average baseline from the previous epoch, or
        None on the first epoch.
      epoch_str: epoch tag used in log messages.
      print_freq: log every ``print_freq`` steps.
      logger: object with a ``log(str)`` method.

    Returns:
      Tuple of epoch averages:
      (loss, validation top-1 accuracy in %, baseline, reward).
    """
    # config. (containing some necessary arg)
    #   baseline: The baseline score (i.e. average val_acc) from the previous epoch
    data_time, batch_time = AverageMeter(), AverageMeter()
    GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(
    ), AverageMeter(), time.time()

    # REINFORCE hyper-parameters, fixed following the ENAS paper.
    controller_num_aggregate = 20
    controller_train_steps = 50
    controller_bl_dec = 0.99
    controller_entropy_weight = 0.0001

    network.eval()
    network.controller.train()
    network.controller.zero_grad()
    loader_iter = iter(xloader)
    for step in range(controller_train_steps * controller_num_aggregate):
        try:
            inputs, targets = next(loader_iter)
        except StopIteration:
            # BUG FIX: was a bare ``except`` which silently swallowed *any*
            # error (CUDA OOM, keyboard interrupt, ...); only loader
            # exhaustion should restart the iterator.
            loader_iter = iter(xloader)
            inputs, targets = next(loader_iter)
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        log_prob, entropy, sampled_arch = network.controller()
        with torch.no_grad():
            # Evaluate the sampled architecture without touching shared weights.
            network.set_cal_mode('dynamic', sampled_arch)
            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.data,
                                                 targets.data,
                                                 topk=(1, 5))
            val_top1 = val_top1.view(-1) / 100
        reward = val_top1 + controller_entropy_weight * entropy
        if prev_baseline is None:
            baseline = val_top1
        else:
            # Exponential moving average of the reward as the baseline.
            baseline = prev_baseline - (1 - controller_bl_dec) * (
                prev_baseline - reward)

        # Policy-gradient loss: advantage-weighted negative log-probability.
        loss = -1 * log_prob * (reward - baseline)

        # account
        RewardMeter.update(reward.item())
        BaselineMeter.update(baseline.item())
        ValAccMeter.update(val_top1.item() * 100)
        LossMeter.update(loss.item())
        EntropyMeter.update(entropy.item())

        # Average gradient over controller_num_aggregate samples
        loss = loss / controller_num_aggregate
        loss.backward(retain_graph=True)

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()
        if (step + 1) % controller_num_aggregate == 0:
            # One optimizer step per aggregation window, with grad clipping.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                network.controller.parameters(), 5.0)
            GradnormMeter.update(grad_norm)
            optimizer.step()
            network.controller.zero_grad()

        if step % print_freq == 0:
            Sstr = '*Train-Controller* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step,
                controller_train_steps * controller_num_aggregate)
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(
                loss=LossMeter,
                top1=ValAccMeter,
                reward=RewardMeter,
                basel=BaselineMeter)
            Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val,
                                                    EntropyMeter.avg)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)

    return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
def visualize_all_rank_info(api, vis_save_dir, indicator):
    """Draw correlation heatmaps of valid/test accuracies across datasets.

    Two panels are produced: correlations over all candidates, and over
    candidates whose CIFAR-10 test accuracy exceeds 92%.  The figure is
    written to ``<indicator>-all-relative-rank.png`` in ``vis_save_dir``.
    """
    vis_save_dir = vis_save_dir.resolve()
    # print ('{:} start to visualize {:} information'.format(time_string(), api))
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    # Load the per-dataset accuracy caches produced by the visualize_* helpers.
    dataset_names = ('cifar10', 'cifar100', 'ImageNet16-120')
    all_infos = [
        torch.load(vis_save_dir /
                   '{:}-cache-{:}-info.pth'.format(name, indicator))
        for name in dataset_names
    ]
    cifar010_info, cifar100_info, imagenet_info = all_infos
    indexes = list(range(len(cifar010_info['params'])))

    print('{:} start to visualize relative ranking'.format(time_string()))

    dpi, width, height = 250, 3200, 1400
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 14, 14

    fig, (ax_all, ax_top) = plt.subplots(1, 2, figsize=figsize)

    sns_size = 15
    tick_names = ['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T']

    def draw_heatmap(matrix, ax):
        # Shared styling for both correlation heatmaps.
        sns.heatmap(matrix,
                    annot=True,
                    annot_kws={'size': sns_size},
                    fmt='.3f',
                    linewidths=0.5,
                    ax=ax,
                    xticklabels=tick_names,
                    yticklabels=tick_names)

    # Left panel: correlations computed over every candidate.
    draw_heatmap(
        calculate_correlation(cifar010_info['valid_accs'],
                              cifar010_info['test_accs'],
                              cifar100_info['valid_accs'],
                              cifar100_info['test_accs'],
                              imagenet_info['valid_accs'],
                              imagenet_info['test_accs']), ax_all)

    # Right panel: same correlations restricted to strong CIFAR-10 candidates.
    acc_bar = 92
    selected_indexes = [
        i for i, acc in enumerate(cifar010_info['test_accs']) if acc > acc_bar
    ]
    filtered_columns = []
    for info in all_infos:
        for key in ('valid_accs', 'test_accs'):
            filtered_columns.append(np.array(info[key])[selected_indexes])
    draw_heatmap(calculate_correlation(*filtered_columns), ax_top)

    ax_all.set_title('Correlation coefficient over ALL candidates')
    ax_top.set_title(
        'Correlation coefficient over candidates with accuracy > {:}%'.format(
            acc_bar))
    save_path = (vis_save_dir /
                 '{:}-all-relative-rank.png'.format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
Пример #9
0
def main(xargs):
  """RANDOM-NAS search entry point.

  Trains a weight-sharing super-network while sampling random architectures,
  checkpoints every epoch (with automatic resume), and finally evaluates
  ``xargs.select_num`` random genotypes to report the best one.

  Args:
    xargs: parsed command-line namespace (dataset, paths, search-space and
           optimization settings).
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # BUG FIX: previously called prepare_logger(args), silently relying on a
  # module-level global instead of the function argument.
  logger = prepare_logger(xargs)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  logger.log('config : {:}'.format(config))
  # The validation set is a copy of the training data with validation
  # transforms; the SubsetRandomSampler below restricts it to valid_split.
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data    = train_data_v2
  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loader
  search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space}, None)
  search_model = get_cell_based_tiny_net(model_config)

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  if xargs.arch_nas_dataset is None: api = None
  else                             : api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

  if last_info.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info   = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict( checkpoint['search_model'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

    # Train the shared weights with randomly sampled architectures.
    search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
    logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
    cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
    genotypes[epoch] = cur_arch
    # check the best accuracy
    valid_accuracies[epoch] = valid_a_top1
    if valid_a_top1 > valid_accuracies['best']:
      valid_accuracies['best'] = valid_a_top1
      find_best = True
    else: find_best = False

    # save checkpoint
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    # BUG FIX: the last-info checkpoint saved deepcopy(args) (a global)
    # instead of deepcopy(xargs), inconsistent with the main checkpoint above.
    last_info = save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, logger.path('info'), logger)
    if find_best:
      logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
      copy_checkpoint(model_base_path, model_best_path, logger)
    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*200)
  logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
  start_time = time.time()
  # Final selection: evaluate select_num random genotypes with the shared
  # weights and keep the most accurate one.
  best_arch, best_acc = None, -1
  for iarch in range(xargs.select_num):
    arch = search_model.random_genotype( True )
    valid_a_loss, valid_a_top1, valid_a_top5  = valid_func(valid_loader, network, criterion)
    logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
    if best_arch is None or best_acc < valid_a_top1:
      best_arch, best_acc = arch, valid_a_top1
  search_time.update(time.time() - start_time)
  logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
  if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
  logger.close()
def visualize_info(api, vis_save_dir, indicator):
    """Plot cross-dataset relative rankings of architectures.

    Loads cached accuracy info for CIFAR-10, CIFAR-100 and ImageNet16-120
    (files named ``<dataset>-cache-<indicator>-info.pth`` in
    ``vis_save_dir``), ranks every architecture by its test accuracy on each
    dataset, and scatters the CIFAR-100 / ImageNet rankings against the
    CIFAR-10 ranking.  Saves ``<indicator>-relative-rank.pdf`` and ``.png``.

    Args:
        api: benchmark API object (unused except by the commented log line).
        vis_save_dir: pathlib.Path holding the caches and receiving plots.
        indicator: tag used in cache and output file names.
    """
    vis_save_dir = vis_save_dir.resolve()
    # print ('{:} start to visualize {:} information'.format(time_string(), api))
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(
        'cifar10', indicator)
    cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(
        'cifar100', indicator)
    imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(
        'ImageNet16-120', indicator)
    cifar010_info = torch.load(cifar010_cache_path)
    cifar100_info = torch.load(cifar100_cache_path)
    imagenet_info = torch.load(imagenet_cache_path)
    # NOTE(review): assumes all three caches list the same architectures in
    # the same order -- verify when regenerating the caches.
    indexes = list(range(len(cifar010_info['params'])))

    print('{:} start to visualize relative ranking'.format(time_string()))

    # Sort architecture ids by test accuracy per dataset; an architecture's
    # position in the sorted list is its rank for that dataset.
    cifar010_ord_indexes = sorted(indexes,
                                  key=lambda i: cifar010_info['test_accs'][i])
    cifar100_ord_indexes = sorted(indexes,
                                  key=lambda i: cifar100_info['test_accs'][i])
    imagenet_ord_indexes = sorted(indexes,
                                  key=lambda i: imagenet_info['test_accs'][i])

    cifar100_labels, imagenet_labels = [], []
    # Walk architectures in CIFAR-10 rank order; record each one's rank on
    # the other two datasets.  NOTE(review): list.index is O(n), so this
    # loop is O(n^2) over the whole benchmark.
    for idx in cifar010_ord_indexes:
        cifar100_labels.append(cifar100_ord_indexes.index(idx))
        imagenet_labels.append(imagenet_ord_indexes.index(idx))
    print('{:} prepare data done.'.format(time_string()))

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12
    resnet_scale, resnet_alpha = 120, 0.5

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(min(indexes), max(indexes))
    plt.ylim(min(indexes), max(indexes))
    # plt.ylabel('y').set_rotation(30)
    plt.yticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 3),
               fontsize=LegendFontsize,
               rotation='vertical')
    plt.xticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 5),
               fontsize=LegendFontsize)
    # CIFAR-100 ranking (green triangles) and ImageNet ranking (red stars)
    # versus the CIFAR-10 ranking on the x-axis.
    ax.scatter(indexes,
               cifar100_labels,
               marker='^',
               s=0.5,
               c='tab:green',
               alpha=0.8)
    ax.scatter(indexes,
               imagenet_labels,
               marker='*',
               s=0.5,
               c='tab:red',
               alpha=0.8)
    # The y=x diagonal: CIFAR-10 against itself as the reference.
    ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
    # Off-canvas points at (-1,-1) exist only to give the legend full-size
    # marker entries.
    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='CIFAR-10')
    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
    ax.scatter([-1], [-1],
               marker='*',
               s=100,
               c='tab:red',
               label='ImageNet-16-120')
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
    ax.set_ylabel('architecture ranking', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-relative-rank.pdf'.format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-relative-rank.png'.format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    # NOTE(review): the figure is never closed here, unlike the sibling
    # functions that call plt.close('all') -- confirm this is intentional.
def visualize_rank_info(api, vis_save_dir, indicator):
    """Plot, per dataset, the validation-vs-test ranking agreement.

    For each of CIFAR-10, CIFAR-100 and ImageNet16-120, loads the cached
    accuracy info from ``vis_save_dir`` and draws a scatter of each
    architecture's validation rank against its test rank.  Saves
    ``<indicator>-same-relative-rank.pdf`` and ``.png``.

    Args:
        api: benchmark API object (unused except by the commented log line).
        vis_save_dir: pathlib.Path holding the caches and receiving plots.
        indicator: tag used in cache and output file names.
    """
    vis_save_dir = vis_save_dir.resolve()
    # print ('{:} start to visualize {:} information'.format(time_string(), api))
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(
        'cifar10', indicator)
    cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(
        'cifar100', indicator)
    imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(
        'ImageNet16-120', indicator)
    cifar010_info = torch.load(cifar010_cache_path)
    cifar100_info = torch.load(cifar100_cache_path)
    imagenet_info = torch.load(imagenet_cache_path)
    # NOTE(review): assumes all three caches cover the same architectures in
    # the same order -- verify when regenerating the caches.
    indexes = list(range(len(cifar010_info['params'])))

    print('{:} start to visualize relative ranking'.format(time_string()))

    dpi, width, height = 250, 3800, 1200
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 14, 14

    fig, axs = plt.subplots(1, 3, figsize=figsize)
    ax1, ax2, ax3 = axs

    def get_labels(info):
        """Return, for archs in test-rank order, each arch's validation rank.

        NOTE(review): list.index is O(n), so this is O(n^2) per dataset.
        """
        ord_test_indexes = sorted(indexes, key=lambda i: info['test_accs'][i])
        ord_valid_indexes = sorted(indexes,
                                   key=lambda i: info['valid_accs'][i])
        labels = []
        for idx in ord_test_indexes:
            labels.append(ord_valid_indexes.index(idx))
        return labels

    def plot_ax(labels, ax, name):
        """Scatter validation rank (green) vs. the y=x test rank (blue)."""
        # NOTE(review): Tick.label is deprecated in newer matplotlib
        # (use label1) -- confirm against the pinned matplotlib version.
        for tick in ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize)
        for tick in ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize)
            tick.label.set_rotation(90)
        ax.set_xlim(min(indexes), max(indexes))
        ax.set_ylim(min(indexes), max(indexes))
        ax.yaxis.set_ticks(
            np.arange(min(indexes), max(indexes),
                      max(indexes) // 3))
        ax.xaxis.set_ticks(
            np.arange(min(indexes), max(indexes),
                      max(indexes) // 5))
        ax.scatter(indexes,
                   labels,
                   marker='^',
                   s=0.5,
                   c='tab:green',
                   alpha=0.8)
        ax.scatter(indexes,
                   indexes,
                   marker='o',
                   s=0.5,
                   c='tab:blue',
                   alpha=0.8)
        # Off-canvas points at (-1,-1) only populate the legend with
        # full-size markers.
        ax.scatter([-1], [-1],
                   marker='^',
                   s=100,
                   c='tab:green',
                   label='{:} test'.format(name))
        ax.scatter([-1], [-1],
                   marker='o',
                   s=100,
                   c='tab:blue',
                   label='{:} validation'.format(name))
        ax.legend(loc=4, fontsize=LegendFontsize)
        ax.set_xlabel('ranking on the {:} validation'.format(name),
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking', fontsize=LabelSize)

    labels = get_labels(cifar010_info)
    plot_ax(labels, ax1, 'CIFAR-10')
    labels = get_labels(cifar100_info)
    plot_ax(labels, ax2, 'CIFAR-100')
    labels = get_labels(imagenet_info)
    plot_ax(labels, ax3, 'ImageNet-16-120')

    save_path = (vis_save_dir /
                 '{:}-same-relative-rank.pdf'.format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-same-relative-rank.png'.format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
def visualize_tss_info(api, dataset, vis_save_dir):
    """Plot accuracy-vs-cost scatters for the topology search space (TSS).

    Collects per-architecture #params, #FLOPs and train/valid/test accuracy
    for ``dataset`` (cached in ``vis_save_dir``), then draws four panels:
    train/test accuracy against parameters and FLOPs, highlighting the
    ResNet-like cell and the largest cell.  Saves ``tss-<dataset>.png``.

    Args:
        api: NATS-Bench / NAS-Bench-201 style API object.
        dataset: dataset name, e.g. 'cifar10'.
        vis_save_dir: pathlib.Path output directory.
    """
    vis_save_dir = vis_save_dir.resolve()
    print('{:} start to visualize {:} information'.format(
        time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
        for index in range(len(api)):
            cost_info = api.get_cost_info(index, dataset, hp='12')
            params.append(cost_info['params'])
            flops.append(cost_info['flops'])
            # accuracy
            info = api.get_more_info(index, dataset, hp='200', is_random=False)
            train_accs.append(info['train-accuracy'])
            test_accs.append(info['test-accuracy'])
            # For CIFAR-10 the validation accuracy lives in the separate
            # 'cifar10-valid' records; other datasets carry it directly.
            # (Collapsed a duplicated if/else whose branches ended with the
            # same append, and removed a stray per-iteration print('') that
            # emitted one blank line per architecture.)
            if dataset == 'cifar10':
                info = api.get_more_info(index,
                                         'cifar10-valid',
                                         hp='200',
                                         is_random=False)
            valid_accs.append(info['valid-accuracy'])
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs = info[
            'params'], info['flops'], info['train_accs'], info[
                'valid_accs'], info['test_accs']
    print('{:} collect data done.'.format(time_string()))

    # Reference cells to highlight: the ResNet-like cell and the all-conv
    # ("largest") cell of the NAS-Bench-201 topology space.
    resnet = [
        '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'
    ]
    resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
    largest_indexes = [
        api.query_index_by_arch(
            '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'
        )
    ]

    indexes = list(range(len(params)))
    dpi, width, height = 250, 8500, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 24, 24
    xscale, xalpha = 120, 0.8

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    for ax in axs:
        # NOTE(review): Tick.label is deprecated in newer matplotlib
        # (use label1) -- confirm against the pinned matplotlib version.
        for tick in ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
        for tick in ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize)
    ax2, ax3, ax4, ax5 = axs

    def _draw_panel(ax, xvalues, yvalues, xlabel, ylabel):
        # One accuracy-vs-cost scatter: every candidate plus the highlighted
        # ResNet and largest cells.  Factored out of four copy-pasted blocks.
        ax.scatter(xvalues, yvalues, marker='o', s=0.5, c='tab:blue')
        ax.scatter([xvalues[x] for x in resnet_indexes],
                   [yvalues[x] for x in resnet_indexes],
                   marker='*',
                   s=xscale,
                   c='tab:orange',
                   label='ResNet',
                   alpha=xalpha)
        ax.scatter([xvalues[x] for x in largest_indexes],
                   [yvalues[x] for x in largest_indexes],
                   marker='x',
                   s=xscale,
                   c='tab:green',
                   label='Largest Candidate',
                   alpha=xalpha)
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        ax.set_ylabel(ylabel, fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)

    _draw_panel(ax2, params, train_accs, '#parameters (MB)',
                'train accuracy (%)')
    _draw_panel(ax3, params, test_accs, '#parameters (MB)',
                'test accuracy (%)')
    _draw_panel(ax4, flops, train_accs, '#FLOPs (M)', 'train accuracy (%)')
    _draw_panel(ax5, flops, test_accs, '#FLOPs (M)', 'test accuracy (%)')

    save_path = vis_save_dir / 'tss-{:}.png'.format(dataset)
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
Пример #13
0
        matrix = api.str2matrix(arch_str)
        print("Compute the adjacency matrix of {:}".format(arch_str))
        print(matrix)
    info = api.simulate_train_eval(123, "cifar10")
    print("simulate_train_eval : {:}\n\n".format(info))


if __name__ == "__main__":

    # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
    # Exercise the topology search-space API under every fast/verbose combo.
    for fast_mode in [True, False]:
        for verbose in [True, False]:
            # BUG FIX: ``verbose`` was hard-coded to True, so the loop over
            # verbose values never actually varied the API's verbosity.
            api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=verbose)
            print(
                "{:} create with fast_mode={:} and verbose={:}".format(
                    time_string(), fast_mode, verbose
                )
            )
            test_api(api_nats_tss, False)
            # Free the (large) benchmark tables before the next iteration.
            del api_nats_tss
            gc.collect()

    # Same sweep for the size search space (creation only, no test_api run).
    for fast_mode in [True, False]:
        for verbose in [True, False]:
            print(
                "{:} create with fast_mode={:} and verbose={:}".format(
                    time_string(), fast_mode, verbose
                )
            )
            # BUG FIX: same hard-coded verbose=True as above.
            api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=verbose)
            print("{:} --->>> {:}".format(time_string(), api_nats_sss))
Пример #14
0
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, logger, bilevel):
    """One epoch of differentiable architecture search with a cost penalty.

    Args:
      xloader: loader yielding (base_inputs, base_targets, arch_inputs,
        arch_targets) tuples for the weight and architecture updates.
      network: super-network returning (features, logits, cost), where
        ``cost`` is an estimated resource cost (tensor or plain number).
      criterion: classification loss.
      scheduler: LR scheduler, updated fractionally every step.
      w_optimizer: optimizer over network weights.
      a_optimizer: optimizer over architecture parameters.
      epoch_str: epoch tag for logging.
      print_freq: log every ``print_freq`` steps (and on the last step).
      logger: object with a ``log(str)`` method.
      bilevel: if True, update architecture parameters on the separate arch
        split (bi-level); otherwise update them jointly with the weights.

    Returns:
      (base loss avg, base top-1 avg, base top-5 avg,
       arch loss avg, arch top-1 avg, arch top-5 avg).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda()
        base_targets = base_targets.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda()
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        w_optimizer.zero_grad()
        if not bilevel:
            a_optimizer.zero_grad()
        _, logits, cost = network(base_inputs)
        # Penalize the estimated resource cost (scaled to ~loss magnitude).
        base_loss = criterion(logits, base_targets) + (cost / 1e9)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        if not bilevel:
            a_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        # BUG FIX: replaced a bare ``except`` (which masked any error) with
        # an explicit check -- ``cost`` may be a tensor or a plain number.
        cost_value = cost.item() if hasattr(cost, 'item') else cost
        # Record the pure classification loss, without the cost penalty.
        base_losses.update(base_loss.item() - (cost_value / 1e9),
                           base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        if bilevel:
            # update the architecture-weight
            a_optimizer.zero_grad()
            _, logits, cost = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets) + (cost / 1e9)
            arch_loss.backward()
            a_optimizer.step()
            # record
            arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                     arch_targets.data,
                                                     topk=(1, 5))
            arch_losses.update(arch_loss.item(), arch_inputs.size(0))
            arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
            arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
        else:
            # Keep the arch meters populated so return shapes match.
            arch_losses.update(0, arch_inputs.size(0))
            arch_top1.update(0, arch_inputs.size(0))
            arch_top5.update(0, arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Пример #15
0
def main(xargs):
    """Run the ProxylessNAS-style differentiable architecture search.

    Builds the search model and data loaders from ``xargs``, optionally
    resumes from the last checkpoint, alternates weight/architecture
    optimization for ``config.epochs + config.warmup`` epochs (plus
    ``config.t_epochs`` of extra weight training), then evaluates the
    architectures sampled along the way and reports the top-10 of the
    final sampling on the full validation set.

    Args:
        xargs: parsed command-line arguments (dataset, data path, search
            hyper-parameters, checkpoint locations, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # BUG FIX: was `prepare_logger(args)`, silently relying on a module-level
    # global instead of the function argument.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    #config_path = 'configs/nas-benchmark/algos/GDAS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    # BUG FIX: was `not args.constrain` (global); use the argument.
    if xargs.model_config is None and not xargs.constrain:
        model_config = dict2config(
            {
                'name': 'ProxylessNAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 0,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    elif xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'ProxylessNAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 32,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    #logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    # Separate optimizers: SGD-style (per config) for network weights,
    # Adam for the architecture parameters.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    #logger.log('{:}'.format(search_model))
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()
    #network, criterion = search_model.cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    sampled_weights = []
    # ROBUSTNESS: initialize the validation iterator up front instead of
    # relying on a broad `except Exception` to catch the NameError of a
    # never-assigned `valid_iter` on first use.
    valid_iter = iter(valid_loader)
    for epoch in range(start_epoch, total_epoch + config.t_epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(
                epoch_time.val * (total_epoch - epoch + config.t_epochs),
                True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # Linearly anneal the Gumbel-softmax temperature from tau_max to
        # tau_min over the search epochs.
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))
        if epoch < total_epoch:
            # Regular bilevel search epoch.
            search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
                      = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger, xargs.bilevel)
        else:
            # Post-search epochs: train weights only, cycling through the
            # architectures sampled at the end of the search.  On the very
            # first pass `sampled_weights` may be empty -> IndexError ->
            # sample 100 architectures and retry.
            try:
                search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5, arch_iter \
                          = train_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, sampled_weights[0], arch_iter, logger)
            except IndexError:
                weights = search_model.sample_weights(100)
                sampled_weights.append(weights)
                arch_iter = iter(weights)
                search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5, arch_iter \
                          = train_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, sampled_weights[0], arch_iter, logger)

        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # Periodically snapshot 100 sampled architectures for later eval.
        if (epoch + 1) % 50 == 0 and not config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
        elif (epoch + 1) == total_epoch and config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
            arch_iter = iter(weights)
        # validate with single arch
        single_weight = search_model.sample_weights(1)[0]
        single_valid_acc = AverageMeter()
        network.eval()
        for _step in range(10):
            try:
                val_input, val_target = next(valid_iter)
            except StopIteration:
                valid_iter = iter(valid_loader)
                val_input, val_target = next(valid_iter)
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=single_weight)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                single_valid_acc.update(val_acc1.item(), n_val)
        logger.log('[{:}] valid : accuracy = {:.2f}'.format(
            epoch_str, single_valid_acc.avg))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        if epoch < total_epoch:
            genotypes[epoch] = search_model.genotype()
            logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
                epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # BUG FIX: was `deepcopy(args)` (global); save the same xargs as above.
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None and epoch < total_epoch:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    network.eval()
    # Evaluate the architectures sampled throughout the search
    for i in range(len(sampled_weights) - 1):
        logger.log('Sample eval : epoch {}'.format((i + 1) * 50 - 1))
        for w in sampled_weights[i]:
            sample_valid_acc = AverageMeter()
            # NOTE: the inner index was previously also named `i`,
            # shadowing the outer loop variable; renamed for clarity.
            for _step in range(10):
                try:
                    val_input, val_target = next(valid_iter)
                except StopIteration:
                    valid_iter = iter(valid_loader)
                    val_input, val_target = next(valid_iter)
                n_val = val_input.size(0)
                with torch.no_grad():
                    val_target = val_target.cuda(non_blocking=True)
                    _, logits, _ = network(val_input, weights=w)
                    val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                         val_target.data,
                                                         topk=(1, 5))
                    sample_valid_acc.update(val_acc1.item(), n_val)
            w_gene = search_model.genotype(w)
            if api is not None:
                ind = api.query_index_by_arch(w_gene)
                info = api.query_meta_info_by_index(ind)
                metrics = info.get_metrics('cifar10', 'ori-test')
                acc = metrics['accuracy']
            else:
                acc = 0.0
            logger.log(
                'sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
                    sample_valid_acc.avg, acc))
    # Evaluate the final sampling separately to find the top 10 architectures
    logger.log('Final sample eval')
    final_archs = []
    for w in sampled_weights[-1]:
        sample_valid_acc = AverageMeter()
        for _step in range(10):
            try:
                val_input, val_target = next(valid_iter)
            except StopIteration:
                valid_iter = iter(valid_loader)
                val_input, val_target = next(valid_iter)
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                sample_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        if api is not None:
            ind = api.query_index_by_arch(w_gene)
            info = api.query_meta_info_by_index(ind)
            metrics = info.get_metrics('cifar10', 'ori-test')
            acc = metrics['accuracy']
        else:
            acc = 0.0
        logger.log('sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
            sample_valid_acc.avg, acc))
        final_archs.append((w, sample_valid_acc.avg))
    top_10 = sorted(final_archs, key=lambda x: x[1], reverse=True)[:10]
    # Evaluate the top 10 architectures on the entire validation set
    logger.log('Evaluating top archs')
    for w, prev_acc in top_10:
        full_valid_acc = AverageMeter()
        for val_input, val_target in valid_loader:
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                full_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        logger.log('genotype {}'.format(w_gene))
        if api is not None:
            ind = api.query_index_by_arch(w_gene)
            info = api.query_meta_info_by_index(ind)
            metrics = info.get_metrics('cifar10', 'ori-test')
            acc = metrics['accuracy']
        else:
            acc = 0.0
        logger.log(
            'full valid : val_acc = {:.2f} test_acc = {:.2f} pval_acc = {:.2f}'
            .format(full_valid_acc.avg, acc, prev_acc))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'ProxylessNAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.
        format(total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()
def simplify(save_dir, save_name, nets, total, sup_config):
    """Aggregate per-architecture result files into one benchmark pickle.

    First verifies that, for each hyper-parameter setting ('12' and '200'
    epochs), every architecture has the expected raw checkpoints.  It then
    loads the already-simplified per-architecture pickle files, bundles
    them with the meta information into ``save_name.pickle``, and finally
    moves the (compressed) benchmark file plus the full/simple directories
    to MD5-tagged paths.

    Args:
        save_dir: pathlib.Path root holding raw data and outputs.
        save_name: base name for the aggregated benchmark file.
        nets: list of architecture strings; index i corresponds to file
            '{i:06d}.pickle' in the simplified directory.
        total: expected number of architectures (must equal len(nets)).
        sup_config: unused here; kept for interface compatibility.
    """
    hps, seeds = ['12', '200'], set()
    for hp in hps:
        sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
        ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth')))
        seed2names = defaultdict(list)
        for ckp in ckps:
            # raw string: '\.' is an invalid escape sequence in a plain
            # literal (DeprecationWarning since Python 3.6).
            parts = re.split(r'-|\.', ckp.name)
            # file names look like 'arch-<idx>-seed-<seed>.pth',
            # so parts[3] is the seed token.
            seed2names[parts[3]].append(ckp.name)
        print('DIR : {:}'.format(sub_save_dir))
        nums = []
        for seed, xlist in seed2names.items():
            seeds.add(seed)
            nums.append(len(xlist))
            print('  [seed={:}] there are {:} checkpoints.'.format(
                seed, len(xlist)))
        assert len(nets) == total == max(
            nums), 'there are some missed files : {:} vs {:}'.format(
                max(nums), total)
    print('{:} start simplify the checkpoint.'.format(time_string()))

    datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')

    # Create the directory to save the processed data
    # full_save_dir contains all benchmark files with trained weights.
    # simplify_save_dir contains all benchmark files without trained weights.
    full_save_dir = save_dir / (save_name + '-FULL')
    simple_save_dir = save_dir / (save_name + '-SIMPLIFY')
    full_save_dir.mkdir(parents=True, exist_ok=True)
    simple_save_dir.mkdir(parents=True, exist_ok=True)
    # all data in memory
    arch2infos, evaluated_indexes = dict(), set()
    end_time, arch_time = time.time(), AverageMeter()
    # save the meta information
    for index in tqdm(range(total)):
        # (dead locals `arch_str = nets[index]` and `hp2info` removed)
        simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)

        arch2infos[index] = pickle_load(simple_save_path)
        evaluated_indexes.add(index)

        # measure elapsed time
        arch_time.update(time.time() - end_time)
        end_time = time.time()
        need_time = '{:}'.format(
            convert_secs2time(arch_time.avg * (total - index - 1), True))
        # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
    print('{:} {:} done.'.format(time_string(), save_name))
    final_infos = {
        'meta_archs': nets,
        'total_archs': total,
        'arch2infos': arch2infos,
        'evaluated_indexes': evaluated_indexes
    }
    save_file_name = save_dir / '{:}.pickle'.format(save_name)
    pickle_save(final_infos, str(save_file_name))
    # move the benchmark file to a new path
    # NOTE(review): assumes pickle_save writes a bz2-compressed '.pbz2'
    # alongside/instead of the plain path -- confirm against pickle_save.
    hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
    hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(
        NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(str(save_file_name) + '.pbz2', hd5_file_name)
    print('Save {:} / {:} architecture results into {:} -> {:}.'.format(
        len(evaluated_indexes), total, save_file_name, hd5_file_name))
    # move the directory to a new path
    hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(
        NATS_TSS_BASE_NAME, hd5sum)
    hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(
        NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(full_save_dir, hd5_full_save_dir)
    shutil.move(simple_save_dir, hd5_simple_save_dir)
                        help='Folder to save checkpoints and log.')
    parser.add_argument('--rand_seed',
                        type=int,
                        default=-1,
                        help='manual seed')
    args = parser.parse_args()

    api = create(None, args.search_space, verbose=False)

    args.save_dir = os.path.join(
        '{:}-{:}'.format(args.save_dir, args.search_space), args.dataset,
        'RANDOM')
    print('save-dir : {:}'.format(args.save_dir))

    if args.rand_seed < 0:
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print('{:} : {:03d}/{:03d}'.format(time_string(), i,
                                               args.loops_if_rand))
            args.rand_seed = random.randint(1, 100000)
            save_dir, all_archs, all_total_times = main(args, api)
            all_info[i] = {
                'all_archs': all_archs,
                'all_total_times': all_total_times
            }
        save_path = save_dir / 'results.pth'
        print('save into {:}'.format(save_path))
        torch.save(all_info, save_path)
    else:
        main(args, api)
Пример #18
0
                        type=int,
                        default=-1,
                        help="manual seed")
    args = parser.parse_args()

    api = create(None, args.search_space, fast_mode=False, verbose=False)

    args.save_dir = os.path.join(
        "{:}-{:}".format(args.save_dir, args.search_space),
        "{:}-T{:}".format(args.dataset, args.time_budget),
        "BOHB",
    )
    print("save-dir : {:}".format(args.save_dir))

    if args.rand_seed < 0:
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print("{:} : {:03d}/{:03d}".format(time_string(), i,
                                               args.loops_if_rand))
            args.rand_seed = random.randint(1, 100000)
            save_dir, all_archs, all_total_times = main(args, api)
            all_info[i] = {
                "all_archs": all_archs,
                "all_total_times": all_total_times
            }
        save_path = save_dir / "results.pth"
        print("save into {:}".format(save_path))
        torch.save(all_info, save_path)
    else:
        main(args, api)
Пример #19
0
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, algo, logger):
    """Run one epoch of alternating weight / architecture optimization.

    Each batch from ``xloader`` yields two splits: (base_inputs,
    base_targets) train the super-network weights via ``w_optimizer``,
    and (arch_inputs, arch_targets) update the architecture parameters
    via ``a_optimizer``.  How the super-network is sampled before each
    forward pass is controlled by ``algo`` ('setn', 'gdas', 'darts*',
    'random', or 'enas').

    Returns:
        Tuple of epoch averages: (base_loss, base_top1, base_top5,
        arch_loss, arch_top1, arch_top5).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        # fractional epoch progress drives the LR schedule within the epoch
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # Update the weights
        # Select how the super-network is evaluated for the weight step.
        if algo == 'setn':
            sampled_arch = network.dync_genotype(True)
            network.set_cal_mode('dynamic', sampled_arch)
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo == 'enas':
            # The ENAS controller samples an architecture; no gradients
            # flow through the controller here.
            with torch.no_grad():
                network.controller.eval()
                _, _, sampled_arch = network.controller()
            network.set_cal_mode('dynamic', sampled_arch)
        else:
            raise ValueError('Invalid algo name : {:}'.format(algo))

        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        # Re-select the evaluation mode for the architecture step (ENAS
        # keeps the mode set above).
        if algo == 'setn':
            network.set_cal_mode('joint')
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo != 'enas':
            raise ValueError('Invalid algo name : {:}'.format(algo))
        network.zero_grad()
        if algo == 'darts-v2':
            # second-order DARTS: unrolled one-step weight update
            arch_loss, logits = backward_step_unrolled(
                network, criterion, base_inputs, base_targets, w_optimizer,
                arch_inputs, arch_targets)
            a_optimizer.step()
        elif algo == 'random' or algo == 'enas':
            # these algos do not learn architecture weights by gradient;
            # just measure the validation loss
            with torch.no_grad():
                _, logits = network(arch_inputs)
                arch_loss = criterion(logits, arch_targets)
        else:
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            arch_loss.backward()
            a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Пример #20
0
def main(xargs, api):
    """Search NATS-Bench with BOHB and report the best found architecture.

    Sets up a BOHB name server plus worker(s), runs the optimization for
    ``xargs.n_iters`` iterations on the topology ('tss') or size search
    space, and reconstructs the incumbent after each evaluation from the
    worker's trajectory.

    Args:
        xargs: parsed command-line arguments (search space, dataset,
            BOHB hyper-parameters, ...).
        api: NATS-Bench API object used to query architecture results.

    Returns:
        Tuple (log_dir, current_best_index, total_times): the logger
        directory, the benchmark index of the best-so-far architecture
        after each evaluation, and the cumulative simulated time of each
        evaluation.
    """
    torch.set_num_threads(4)
    prepare_seed(xargs.rand_seed)
    # BUG FIX: was `prepare_logger(args)`, silently relying on a
    # module-level global instead of the function argument.
    logger = prepare_logger(xargs)

    logger.log("{:} use api : {:}".format(time_string(), api))
    api.reset_time()
    search_space = get_search_spaces(xargs.search_space, "nats-bench")
    if xargs.search_space == "tss":
        cs = get_topology_config_space(search_space)
        config2structure = config2topology_func()
    else:
        cs = get_size_config_space(search_space)
        config2structure = config2size_func(search_space)

    hb_run_id = "0"

    NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0)
    ns_host, ns_port = NS.start()
    num_workers = 1

    workers = []
    for i in range(num_workers):
        w = MyWorker(
            nameserver=ns_host,
            nameserver_port=ns_port,
            convert_func=config2structure,
            dataset=xargs.dataset,
            api=api,
            run_id=hb_run_id,
            id=i,
        )
        w.run(background=True)
        workers.append(w)

    start_time = time.time()
    bohb = BOHB(
        configspace=cs,
        run_id=hb_run_id,
        eta=3,
        min_budget=1,
        max_budget=12,
        nameserver=ns_host,
        nameserver_port=ns_port,
        num_samples=xargs.num_samples,
        random_fraction=xargs.random_fraction,
        bandwidth_factor=xargs.bandwidth_factor,
        ping_interval=10,
        min_bandwidth=xargs.min_bandwidth,
    )

    results = bohb.run(xargs.n_iters, min_n_workers=num_workers)

    bohb.shutdown(shutdown_workers=True)
    NS.shutdown()

    # print('There are {:} runs.'.format(len(results.get_all_runs())))
    # workers[0].total_times
    # workers[0].trajectory
    # Reconstruct the incumbent (best-so-far architecture) after each
    # evaluation.  A running maximum replaces the original O(n^2)
    # `max` over every trajectory prefix; strict `>` keeps the first
    # occurrence on ties, matching `max`'s behavior.
    current_best_index = []
    best_score, best_arch = None, None
    for item in workers[0].trajectory:  # items look like (score, arch)
        if best_score is None or item[0] > best_score:
            best_score, best_arch = item[0], item[1]
        current_best_index.append(api.query_index_by_arch(best_arch))

    best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
    logger.log("Best found configuration: {:} within {:.3f} s".format(
        best_arch, workers[0].total_times[-1]))
    info = api.query_info_str_by_arch(
        best_arch, "200" if xargs.search_space == "tss" else "90")
    logger.log("{:}".format(info))
    logger.log("-" * 100)
    logger.close()

    return logger.log_dir, current_best_index, workers[0].total_times
Пример #21
0
def main(xargs):
    """Run differentiable/sampling-based NAS over a NATS-Bench topology space.

    Builds the super-network and optimizers for the algorithm selected by
    ``xargs.algo`` (gdas / setn / enas / darts* / random), resumes from the
    last checkpoint when one exists, runs the search epochs, and finally
    derives and logs the best genotype (querying the NATS-Bench API when
    ``xargs.use_api`` is set).

    Args:
        xargs: parsed argparse namespace; its attributes (dataset, data_path,
            config_path, algo, channel, num_cells, max_nodes, ...) fully
            configure the run.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    # deterministic cuDNN kernels for reproducible searches
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # use the xargs parameter (previously this read the module-level `args`,
    # silently coupling the function to global state)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    if xargs.overwite_epochs is None:
        extra_info = {'class_num': class_num, 'xshape': xshape}
    else:
        # an explicit epoch override replaces the schedule from the config file
        extra_info = {
            'class_num': class_num,
            'xshape': xshape,
            'epochs': xargs.overwite_epochs
        }
    config = load_config(xargs.config_path, extra_info, logger)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, 'nats-bench')

    model_config = dict2config(
        dict(name='generic',
             C=xargs.channel,
             N=xargs.num_cells,
             max_nodes=xargs.max_nodes,
             num_classes=class_num,
             space=search_space,
             affine=bool(xargs.affine),
             track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    logger.log('model config : {:}'.format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log('{:}'.format(search_model))

    # weight optimizer/scheduler come from the config; the architecture
    # parameters get their own Adam optimizer as in DARTS-style methods
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(search_model.alphas,
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay,
                                   eps=xargs.arch_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log('The parameters of the search model = {:.2f} MB'.format(params))
    logger.log('search-space : {:}'.format(search_space))
    if bool(xargs.use_api):
        api = create(None, 'topology', fast_mode=True, verbose=False)
    else:
        api = None
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = search_model.cuda(), criterion.cuda(
    )  # use a single GPU

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        # seed the genotype history with one sampled architecture at key -1
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: network.return_topK(1, True)[0]
        }
        baseline = None

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        network.set_drop_path(
            float(epoch + 1) / total_epoch, xargs.drop_path_rate)
        if xargs.algo == 'gdas':
            # anneal the Gumbel-softmax temperature linearly from tau_max to tau_min
            network.set_tau(xargs.tau_max -
                            (xargs.tau_max - xargs.tau_min) * epoch /
                            (total_epoch - 1))
            logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(
                network.tau, network.drop_path))
        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
        if xargs.algo == 'enas':
            # ENAS trains an explicit controller with a moving-average baseline
            ctl_loss, ctl_acc, baseline, ctl_reward \
                                       = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
            logger.log(
                '[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'
                .format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.eval_candidate_num,
                                                xargs.algo)
        # each algorithm evaluates candidates under its own calculation mode
        if xargs.algo == 'setn' or xargs.algo == 'enas':
            network.set_cal_mode('dynamic', genotype)
        elif xargs.algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif xargs.algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif xargs.algo == 'random':
            network.set_cal_mode('urs', None)
        else:
            raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
        logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(
            epoch_str, genotype, temp_accuracy))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, xargs.algo, logger)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # previously `deepcopy(args)` (the global) -- keep it on xargs
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.eval_candidate_num,
                                            xargs.algo)
    if xargs.algo == 'setn' or xargs.algo == 'enas':
        network.set_cal_mode('dynamic', genotype)
    elif xargs.algo == 'gdas':
        network.set_cal_mode('gdas', None)
    elif xargs.algo.startswith('darts'):
        network.set_cal_mode('joint', None)
    elif xargs.algo == 'random':
        network.set_cal_mode('urs', None)
    else:
        raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, xargs.algo, logger)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.algo, total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
Пример #22
0
def main(xargs):
    """Run SETN-style architecture search on a NAS-Bench-301 search space.

    Builds the super-network, alternately trains weights and architecture
    parameters via ``search_func``, evaluates the best sampled architecture
    each epoch, checkpoints progress, and logs the final genotype.

    Args:
        xargs: parsed argparse namespace configuring the run (dataset,
            data_path, config_path, select_num, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    # deterministic cuDNN kernels for reproducible searches
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # use the xargs parameter (previously this read the module-level `args`,
    # silently coupling the function to global state)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
                                          (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')

    model_config = dict2config(
        dict(name='generic',
             C=xargs.channel,
             N=xargs.num_cells,
             max_nodes=xargs.max_nodes,
             num_classes=class_num,
             space=search_space,
             affine=bool(xargs.affine),
             track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    logger.log('model config : {:}'.format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)

    # weight optimizer/scheduler come from the config; the architecture
    # parameters get their own Adam optimizer as in DARTS-style methods
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(search_model.alphas,
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log('The parameters of the search model = {:.2f} MB'.format(params))
    logger.log('search-space : {:}'.format(search_space))
    api = API()
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = search_model.cuda(), criterion.cuda(
    )  # use a single GPU

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        # NOTE: a leftover `import pdb; pdb.set_trace()` breakpoint was
        # removed here -- it halted every run at the top of the epoch loop.
        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.select_num)
        # `network` is the bare module (DataParallel wrap is disabled above),
        # so call set_cal_mode directly instead of via `.module`
        network.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # previously `deepcopy(args)` (the global) -- keep it on xargs
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.select_num)
    search_time.update(time.time() - start_time)
    network.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
Пример #23
0
def search_train(
    search_loader,
    network,
    criterion,
    scheduler,
    base_optimizer,
    arch_optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    """Run one epoch of joint weight/architecture search training.

    For every batch the loader yields a (base, arch) pair of splits: the
    network weights are updated on the base split, then the architecture
    parameters are updated on the arch split with an additional FLOP-budget
    penalty (via ``get_flop_loss``).

    Returns:
        Tuple of (avg base loss, avg arch loss, avg top-1 acc, avg top-5 acc).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(
    ), AverageMeter(), AverageMeter()
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    # unpack the FLOP-budget settings that shape the architecture loss
    epoch_str, flop_need, flop_weight, flop_tolerant = (
        extra_info["epoch-str"],
        extra_info["FLOP-exp"],
        extra_info["FLOP-weight"],
        extra_info["FLOP-tolerant"],
    )

    network.train()
    logger.log(
        "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
            epoch_str, flop_need, flop_weight))
    end = time.time()
    # switch all sub-modules into search mode (forward returns expected FLOPs)
    network.apply(change_key("search_mode", "search"))
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(search_loader):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        # network.apply( change_key('search_mode', 'basic') )
        # features, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data,
                                       base_targets.data,
                                       topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture: classification loss plus a weighted
        # FLOP penalty steering the search toward the FLOP budget
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        # NOTE(review): assumes `network` is wrapped (e.g. DataParallel) so
        # `.module` exists -- confirm against the caller
        flop_cur = network.module.get_flop("genotype", None, None)
        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur,
                                                   flop_need, flop_tolerant)
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()

        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = "**TRAIN** " + time_string(
            ) + " [{:}][{:03d}/{:03d}]".format(epoch_str, step,
                                               len(search_loader))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=base_losses, top1=top1, top5=top5)
            Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
                aloss=arch_cls_losses,
                floss=arch_flop_losses,
                loss=arch_losses)
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])

    logger.log(
        " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}"
        .format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            baseloss=base_losses.avg,
            archloss=arch_losses.avg,
        ))
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
Пример #24
0
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, logger):
    """One SETN-style search epoch alternating weight and architecture steps.

    Per batch: a random architecture is sampled and the shared weights are
    trained on the base split in 'dynamic' mode, then the architecture
    parameters are trained on the arch split in 'joint' mode.

    Returns:
        (base loss, base top-1, base top-5, arch loss, arch top-1, arch top-5)
        averaged over the epoch.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        # fractional-epoch LR update within the current epoch
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights: sample one architecture and train the shared
        # weights only along its path
        sampled_arch = network.module.dync_genotype(True)
        network.module.set_cal_mode('dynamic', sampled_arch)
        #network.module.set_cal_mode( 'urs' )
        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight in 'joint' mode (all candidate ops
        # mixed by the architecture parameters)
        network.module.set_cal_mode('joint')
        network.zero_grad()
        _, logits = network(arch_inputs)
        arch_loss = criterion(logits, arch_targets)
        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
            #print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
            #print (network.module.arch_parameters)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Пример #25
0
def visualize_relative_ranking(vis_save_dir):
    """Plot cross-dataset architecture-ranking figures from cached results.

    Loads the per-dataset cache files ('<dataset>-cache-info.pth', produced
    earlier into *vis_save_dir*) for CIFAR-10, CIFAR-100 and ImageNet16-120,
    then saves: a scatter of each architecture's CIFAR-100/ImageNet ranking
    against its CIFAR-10 ranking ('relative-rank.pdf'/'.png'), and
    correlation heatmaps over valid/test accuracies for all architectures
    ('co-relation-all.pdf') and for high-accuracy subsets
    ('co-relation-top-*.pdf').

    Args:
        vis_save_dir: pathlib.Path directory holding the caches; figures are
            written back into it.
    """
    print('\n' + '-' * 100)
    cifar010_cache_path = vis_save_dir / '{:}-cache-info.pth'.format('cifar10')
    cifar100_cache_path = vis_save_dir / '{:}-cache-info.pth'.format(
        'cifar100')
    imagenet_cache_path = vis_save_dir / '{:}-cache-info.pth'.format(
        'ImageNet16-120')
    cifar010_info = torch.load(cifar010_cache_path)
    cifar100_info = torch.load(cifar100_cache_path)
    imagenet_info = torch.load(imagenet_cache_path)
    # one entry per benchmarked architecture
    indexes = list(range(len(cifar010_info['params'])))

    print('{:} start to visualize relative ranking'.format(time_string()))
    # maximum accuracy with ResNet-level params 11472
    # (index 11472 is used as the parameter-count reference; architectures
    # above that budget are masked to -1)
    x_010_accs = [
        cifar010_info['test_accs'][i]
        if cifar010_info['params'][i] <= cifar010_info['params'][11472] else -1
        for i in indexes
    ]
    x_100_accs = [
        cifar100_info['test_accs'][i]
        if cifar100_info['params'][i] <= cifar100_info['params'][11472] else -1
        for i in indexes
    ]
    x_img_accs = [
        imagenet_info['test_accs'][i]
        if imagenet_info['params'][i] <= imagenet_info['params'][11472] else -1
        for i in indexes
    ]

    # architecture indexes sorted by test accuracy per dataset
    cifar010_ord_indexes = sorted(indexes,
                                  key=lambda i: cifar010_info['test_accs'][i])
    cifar100_ord_indexes = sorted(indexes,
                                  key=lambda i: cifar100_info['test_accs'][i])
    imagenet_ord_indexes = sorted(indexes,
                                  key=lambda i: imagenet_info['test_accs'][i])

    # for each architecture (in CIFAR-10 rank order) find its rank position
    # on the other two datasets
    cifar100_labels, imagenet_labels = [], []
    for idx in cifar010_ord_indexes:
        cifar100_labels.append(cifar100_ord_indexes.index(idx))
        imagenet_labels.append(imagenet_ord_indexes.index(idx))
    print('{:} prepare data done.'.format(time_string()))

    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18
    resnet_scale, resnet_alpha = 120, 0.5

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(min(indexes), max(indexes))
    plt.ylim(min(indexes), max(indexes))
    #plt.ylabel('y').set_rotation(0)
    plt.yticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 6),
               fontsize=LegendFontsize,
               rotation='vertical')
    plt.xticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 6),
               fontsize=LegendFontsize)
    #ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8, label='CIFAR-100')
    #ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8, label='ImageNet-16-120')
    #ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8, label='CIFAR-10')
    ax.scatter(indexes,
               cifar100_labels,
               marker='^',
               s=0.5,
               c='tab:green',
               alpha=0.8)
    ax.scatter(indexes,
               imagenet_labels,
               marker='*',
               s=0.5,
               c='tab:red',
               alpha=0.8)
    ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
    # invisible off-plot points (at -1,-1) give the legend large markers
    # while the real scatter points stay tiny (s=0.5)
    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='CIFAR-10')
    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
    ax.scatter([-1], [-1],
               marker='*',
               s=100,
               c='tab:red',
               label='ImageNet-16-120')
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
    ax.set_ylabel('architecture ranking', fontsize=LabelSize)
    save_path = (vis_save_dir / 'relative-rank.pdf').resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / 'relative-rank.png').resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # calculate correlation
    sns_size = 15
    CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'],
                                        cifar010_info['test_accs'],
                                        cifar100_info['valid_accs'],
                                        cifar100_info['test_accs'],
                                        imagenet_info['valid_accs'],
                                        imagenet_info['test_accs'])
    fig = plt.figure(figsize=figsize)
    plt.axis('off')
    h = sns.heatmap(CoRelMatrix,
                    annot=True,
                    annot_kws={'size': sns_size},
                    fmt='.3f',
                    linewidths=0.5)
    save_path = (vis_save_dir / 'co-relation-all.pdf').resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    print('{:} save into {:}'.format(time_string(), save_path))

    # calculate correlation, restricted to architectures whose CIFAR-10 test
    # accuracy exceeds each bar
    acc_bars = [92, 93]
    for acc_bar in acc_bars:
        selected_indexes = []
        for i, acc in enumerate(cifar010_info['test_accs']):
            if acc > acc_bar: selected_indexes.append(i)
        print('select {:} architectures'.format(len(selected_indexes)))
        cifar010_valid_accs = np.array(
            cifar010_info['valid_accs'])[selected_indexes]
        cifar010_test_accs = np.array(
            cifar010_info['test_accs'])[selected_indexes]
        cifar100_valid_accs = np.array(
            cifar100_info['valid_accs'])[selected_indexes]
        cifar100_test_accs = np.array(
            cifar100_info['test_accs'])[selected_indexes]
        imagenet_valid_accs = np.array(
            imagenet_info['valid_accs'])[selected_indexes]
        imagenet_test_accs = np.array(
            imagenet_info['test_accs'])[selected_indexes]
        CoRelMatrix = calculate_correlation(
            cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs,
            cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
        fig = plt.figure(figsize=figsize)
        plt.axis('off')
        h = sns.heatmap(CoRelMatrix,
                        annot=True,
                        annot_kws={'size': sns_size},
                        fmt='.3f',
                        linewidths=0.5)
        save_path = (
            vis_save_dir /
            'co-relation-top-{:}.pdf'.format(len(selected_indexes))).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        print('{:} save into {:}'.format(time_string(), save_path))
    # free all figures to avoid matplotlib memory growth
    plt.close('all')
Пример #26
0
    parser.add_argument(
        '--arch_nas_dataset',
        type=str,
        help='The path to load the architecture dataset (tiny-nas-benchmark).')
    parser.add_argument('--print_freq',
                        type=int,
                        help='print frequency (default: 200)')
    parser.add_argument('--rand_seed', type=int, help='manual seed')
    args = parser.parse_args()
    #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
    if args.arch_nas_dataset is None or not os.path.isfile(
            args.arch_nas_dataset):
        nas_bench = None
    else:
        print('{:} build NAS-Benchmark-API from {:}'.format(
            time_string(), args.arch_nas_dataset))
        nas_bench = API(args.arch_nas_dataset)
    if args.rand_seed < 0:
        save_dir, all_indexes, num, all_times = None, [], 500, []
        for i in range(num):
            print('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
            args.rand_seed = random.randint(1, 100000)
            save_dir, index, ctime = main(args, nas_bench)
            all_indexes.append(index)
            all_times.append(ctime)
        print('\n average time : {:.3f} s'.format(
            sum(all_times) / len(all_times)))
        torch.save(all_indexes, save_dir / 'results.pth')
    else:
        main(args, nas_bench)
Пример #27
0
def main(args):
    """Train a video facial-landmark detector (detector + LK module) end-to-end.

    Workflow: seed/logging setup, data-augmentation pipelines, a video
    training dataset plus image/video evaluation dataloaders, model
    construction from ``args.model_config`` / ``args.lk_config``, optional
    checkpoint resume or detector-only warm start, then the main
    train / checkpoint / evaluate loop for ``opt_config.epochs`` epochs.

    NOTE(review): ``args`` is mutated in place (``args.sigma`` is rescaled
    by ``args.scale_eval`` below); callers must not reuse it afterwards.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    prepare_seed(args.rand_seed)

    # Per-run log directory keyed by the seed and wall-clock time.
    logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
    logger = Logger(args.save_path, logstr)
    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow  version : {}".format(PIL.__version__))
    logger.log("PyTorch version : {}".format(torch.__version__))
    logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

    # General Data Argumentation
    # ImageNet channel means scaled to 0-255: used as the fill color when
    # crops need padding; `normalize` is the standard ImageNet normalization.
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(
        args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [
        transforms.TrainScale2WH((args.crop_width, args.crop_height))
    ]
    train_transform += [
        transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)
    ]
    #if args.arg_flip:
    #  train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [
        transforms.AugCrop(args.crop_width, args.crop_height,
                           args.crop_perturb_max, mean_fill)
    ]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)

    # Evaluation uses deterministic transforms only (no random augmentation).
    eval_transform = transforms.Compose([
        transforms.PreCrop(args.pre_crop_expand),
        transforms.TrainScale2WH((args.crop_width, args.crop_height)),
        transforms.ToTensor(), normalize
    ])
    assert (
        args.scale_min + args.scale_max
    ) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(
        args.scale_min, args.scale_max, args.scale_eval)

    # Model Configure Load
    model_config = load_configure(args.model_config, logger)
    # Heatmap sigma is specified at the evaluation scale, so rescale it here.
    args.sigma = args.sigma * args.scale_eval
    logger.log('Real Sigma : {:}'.format(args.sigma))

    # Training Dataset
    train_data = VDataset(train_transform, args.sigma, model_config.downsample,
                          args.heatmap_type, args.data_indicator,
                          args.video_parser)
    train_data.load_list(args.train_lists, args.num_pts, True)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # Evaluation Dataloader
    # Each entry is (loader, is_video): video lists first, then image lists.
    eval_loaders = []
    if args.eval_vlists is not None:
        for eval_vlist in args.eval_vlists:
            eval_vdata = IDataset(eval_transform, args.sigma,
                                  model_config.downsample, args.heatmap_type,
                                  args.data_indicator)
            eval_vdata.load_list(eval_vlist, args.num_pts, True)
            eval_vloader = torch.utils.data.DataLoader(
                eval_vdata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_vloader, True))

    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = IDataset(eval_transform, args.sigma,
                                  model_config.downsample, args.heatmap_type,
                                  args.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, True)
            eval_iloader = torch.utils.data.DataLoader(
                eval_idata,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            eval_loaders.append((eval_iloader, False))

    # Define network
    lk_config = load_configure(args.lk_config, logger)
    logger.log('model configure : {:}'.format(model_config))
    logger.log('LK configure : {:}'.format(lk_config))
    # num_pts + 1 : one extra channel, presumably background — TODO confirm.
    net = obtain_model(model_config, lk_config, args.num_pts + 1)
    assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(
        model_config.downsample, net.downsample)
    logger.log("=> network :\n {}".format(net))

    logger.log('Training-data : {:}'.format(train_data))
    for i, eval_loader in enumerate(eval_loaders):
        eval_loader, is_video = eval_loader
        logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(
            i, len(eval_loaders), 'video' if is_video else 'image',
            eval_loader.dataset))

    logger.log('arguments : {:}'.format(args))

    opt_config = load_configure(args.opt_config, logger)

    # Use per-parameter-group LR/decay if the model defines them.
    if hasattr(net, 'specify_parameter'):
        net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
    else:
        net_param_dict = net.parameters()

    optimizer, scheduler, criterion = obtain_optimizer(net_param_dict,
                                                       opt_config, logger)
    logger.log('criterion : {:}'.format(criterion))
    net, criterion = net.cuda(), criterion.cuda()
    net = torch.nn.DataParallel(net)

    # Resume priority: last-info checkpoint > --init_model > from scratch.
    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint[
            'epoch'], 'Last-Info is not right {:} vs {:}'.format(
                last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(
            logger.last_info(), checkpoint['epoch']))
    elif args.init_model is not None:
        # Warm start: load only the detector sub-module from a standalone
        # model; optimizer and scheduler start fresh at epoch 0.
        init_model = Path(args.init_model)
        assert init_model.exists(), 'init-model {:} does not exist'.format(
            init_model)
        checkpoint = torch.load(init_model)
        checkpoint = remove_module_dict(checkpoint['state_dict'], True)
        net.module.detector.load_state_dict(checkpoint)
        logger.log("=> initialize the detector : {:}".format(init_model))
        start_epoch = 0
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    # Evaluation runs on the detector alone (without the LK module).
    detector = torch.nn.DataParallel(net.module.detector)

    eval_results = eval_all(args, eval_loaders, detector, criterion,
                            'start-eval', logger, opt_config)
    if args.eval_once:
        logger.log("=> only evaluate the model once")
        logger.close()
        return

    # Main Training and Evaluation Loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(start_epoch, opt_config.epochs):

        # NOTE(review): scheduler stepped at epoch start (pre-PyTorch-1.1
        # convention); moving it after train() would shift the LR schedule.
        scheduler.step()
        need_time = convert_secs2time(
            epoch_time.avg * (opt_config.epochs - epoch), True)
        epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
        LRs = scheduler.get_lr()
        logger.log(
            '\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.
            format(time_string(), epoch_str, need_time, min(LRs), max(LRs),
                   opt_config))

        # train for one epoch
        # The final flag enables the LK branch once epoch >= lk_config.start.
        train_loss = train(args, train_loader, net, criterion, optimizer,
                           epoch_str, logger, opt_config, lk_config,
                           epoch >= lk_config.start)
        # log the results
        logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}'.format(
            time_string(), epoch_str, train_loss))

        # remember best prec@1 and save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch,
                'args': deepcopy(args),
                'arch': model_config.arch,
                'state_dict': net.state_dict(),
                'detector': detector.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            logger.path('model') /
            '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

        last_info = save_checkpoint(
            {
                'epoch': epoch,
                'last_checkpoint': save_path,
            }, logger.last_info(), logger)

        eval_results = eval_all(args, eval_loaders, detector, criterion,
                                epoch_str, logger, opt_config)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.close()
Пример #28
0
def main(args):
  """Train an image classifier (hand-designed or NAS-inferred) with resume.

  Builds dataloaders for ``args.dataset``, constructs the model from
  ``args.model_config`` (``args.model_source`` selects normal vs nas),
  resumes from the last-info file / ``args.resume`` / ``args.init_model``
  when available, then runs the train/validate loop for
  ``optim_config.epochs + optim_config.warmup`` epochs, tracking the best
  validation accuracy and per-epoch GPU-memory usage.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  #torch.backends.cudnn.deterministic = True
  torch.set_num_threads( args.workers )
  
  prepare_seed(args.rand_seed)
  logger = prepare_logger(args)
  
  train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
  # get configures
  model_config = load_config(args.model_config, {'class_num': class_num}, logger)
  optim_config = load_config(args.optim_config, {'class_num': class_num}, logger)

  # 'normal' = plain config-defined model; 'nas' = model built from a
  # NAS-found architecture description.
  if args.model_source == 'normal':
    base_model   = obtain_model(model_config)
  elif args.model_source == 'nas':
    base_model   = obtain_nas_infer_model(model_config)
  else:
    raise ValueError('invalid model-source : {:}'.format(args.model_source))
  flop, param  = get_model_infos(base_model, xshape)
  logger.log('model ====>>>>:\n{:}'.format(base_model))
  logger.log('model information : {:}'.format(base_model.get_message()))
  logger.log('-'*50)
  logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3))
  logger.log('-'*50)
  logger.log('train_data : {:}'.format(train_data))
  logger.log('valid_data : {:}'.format(valid_data))
  optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config)
  logger.log('optimizer  : {:}'.format(optimizer))
  logger.log('scheduler  : {:}'.format(scheduler))
  logger.log('criterion  : {:}'.format(criterion))
  
  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda()

  # Resume priority: last-info checkpoint > --resume > --init_model > scratch.
  if last_info.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_infox  = torch.load(last_info)
    start_epoch = last_infox['epoch'] + 1
    last_checkpoint_path = last_infox['last_checkpoint']
    # The recorded path may come from a different output root; fall back to a
    # path re-rooted next to the last-info file.
    if not last_checkpoint_path.exists():
      logger.log('Does not find {:}, try another path'.format(last_checkpoint_path))
      last_checkpoint_path = last_info.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name
    checkpoint  = torch.load( last_checkpoint_path )
    base_model.load_state_dict( checkpoint['base-model'] )
    scheduler.load_state_dict ( checkpoint['scheduler'] )
    optimizer.load_state_dict ( checkpoint['optimizer'] )
    valid_accuracies = checkpoint['valid_accuracies']
    max_bytes        = checkpoint['max_bytes']
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  elif args.resume is not None:
    assert Path(args.resume).exists(), 'Can not find the resume file : {:}'.format(args.resume)
    checkpoint  = torch.load( args.resume )
    start_epoch = checkpoint['epoch'] + 1
    base_model.load_state_dict( checkpoint['base-model'] )
    scheduler.load_state_dict ( checkpoint['scheduler'] )
    optimizer.load_state_dict ( checkpoint['optimizer'] )
    valid_accuracies = checkpoint['valid_accuracies']
    max_bytes        = checkpoint['max_bytes']
    logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch))
  elif args.init_model is not None:
    # Weights-only initialization: optimizer/scheduler/history start fresh.
    assert Path(args.init_model).exists(), 'Can not find the initialization file : {:}'.format(args.init_model)
    checkpoint  = torch.load( args.init_model )
    base_model.load_state_dict( checkpoint['base-model'] )
    start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {}
    logger.log('=> initialize the model from {:}'.format( args.init_model ))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {}

  train_func, valid_func = get_procedures(args.procedure)
  
  total_epoch = optim_config.epochs + optim_config.warmup
  # Main Training and Evaluation Loop
  start_time  = time.time()
  epoch_time  = AverageMeter()
  for epoch in range(start_epoch, total_epoch):
    scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) )
    epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch)
    LRs       = scheduler.get_lr()
    find_best = False
    # set-up drop-out ratio
    # Drop-path probability ramps linearly with training progress.
    if hasattr(base_model, 'update_drop_path'): base_model.update_drop_path(model_config.drop_path_prob * epoch / total_epoch)
    logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler))
    
    # train for one epoch
    train_loss, train_acc1, train_acc5 = train_func(train_loader, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger)
    # log the results
    logger.log('***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_loss, train_acc1, train_acc5))

    # evaluate the performance
    # Validation runs every `eval_frequency` epochs and at the last epoch.
    if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch):
      logger.log('-'*150)
      valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger)
      valid_accuracies[epoch] = valid_acc1
      logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best']))
      if valid_acc1 > valid_accuracies['best']:
        valid_accuracies['best'] = valid_acc1
        find_best                = True
        logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path))
      # NOTE(review): max_memory_cached is deprecated in newer PyTorch in
      # favor of max_memory_reserved — verify against the pinned version.
      num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0
      logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9))
      max_bytes[epoch] = num_bytes
    if epoch % 10 == 0: torch.cuda.empty_cache()

    # save checkpoint
    save_path = save_checkpoint({
          'epoch'        : epoch,
          'args'         : deepcopy(args),
          'max_bytes'    : deepcopy(max_bytes),
          'FLOP'         : flop,
          'PARAM'        : param,
          'valid_accuracies': deepcopy(valid_accuracies),
          'model-config' : model_config._asdict(),
          'optim-config' : optim_config._asdict(),
          'base-model'   : base_model.state_dict(),
          'scheduler'    : scheduler.state_dict(),
          'optimizer'    : optimizer.state_dict(),
          }, model_base_path, logger)
    if find_best: copy_checkpoint(model_base_path, model_best_path, logger)
    last_info = save_checkpoint({
          'epoch': epoch,
          'args' : deepcopy(args),
          'last_checkpoint': save_path,
          }, logger.path('info'), logger)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*200)
  logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path('info')))
  logger.log('-'*200 + '\n')
  logger.close()
Пример #29
0
  # channels and number-of-cells
  parser.add_argument('--ea_cycles',          type=int,   help='The number of cycles in EA.')
  parser.add_argument('--ea_population',      type=int,   help='The population size in EA.')
  parser.add_argument('--ea_sample_size',     type=int,   help='The sample size in EA.')
  parser.add_argument('--time_budget',        type=int,   default=20000, help='The total time cost budge for searching (in seconds).')
  parser.add_argument('--loops_if_rand',      type=int,   default=500,   help='The total runs for evaluation.')
  # log
  parser.add_argument('--save_dir',           type=str,   default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--rand_seed',          type=int,   default=-1,    help='manual seed')
  args = parser.parse_args()

  api = create(None, args.search_space, fast_mode=True, verbose=False)

  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
  print('save-dir : {:}'.format(args.save_dir))
  print('xargs : {:}'.format(args))

  if args.rand_seed < 0:
    save_dir, all_info = None, collections.OrderedDict()
    for i in range(args.loops_if_rand):
      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
      args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
      all_info[i] = {'all_archs': all_archs,
                     'all_total_times': all_total_times}
    save_path = save_dir / 'results.pth'
    print('save into {:}'.format(save_path))
    torch.save(all_info, save_path)
  else:
    main(args, api)
Пример #30
0
def visualize_curve(api, vis_save_dir, search_space, suffix):
    """Plot test accuracy vs searching epoch for three datasets side by side.

    For each dataset, fetches per-algorithm search checkpoints via
    ``fetch_data``, queries ``api`` for the test accuracy of the architecture
    selected at every searching epoch (averaged over repeated runs), and
    draws one sub-plot per dataset; the combined figure is saved as a PNG
    under ``vis_save_dir``.

    Args:
      api: benchmark API providing ``get_more_info`` and ``search_space_name``.
      vis_save_dir: directory (pathlib.Path) for the figure; created if needed.
      search_space: name of the searched space (used for lookup and filename).
      suffix: experiment suffix used with ``name2suffix`` to locate the data.
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    dpi, width, height = 250, 5200, 1400
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 16, 16

    def sub_plot_fn(ax, dataset):
        # Draw every algorithm's averaged test-accuracy curve on one axis.
        print('{:} plot {:10s}'.format(time_string(), dataset))
        alg2data = fetch_data(search_space=search_space,
                              dataset=dataset,
                              suffix=name2suffix[(search_space, suffix)])
        alg2accuracies = OrderedDict()
        epochs = 100
        colors = ['b', 'g', 'c', 'm', 'y', 'r']
        ax.set_xlim(0, epochs)
        # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
        for idx, (alg, data) in enumerate(alg2data.items()):
            xs, accuracies = [], []
            for iepoch in range(epochs + 1):
                # FIX: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit; catch only lookup failures
                # from checkpoints that are too short or malformed.
                try:
                    structures, accs = [_[iepoch - 1] for _ in data], []
                except (IndexError, KeyError):
                    raise ValueError(
                        'This alg {:} on {:} has invalid checkpoints.'.format(
                            alg, dataset))
                for structure in structures:
                    info = api.get_more_info(
                        structure,
                        dataset=dataset,
                        hp=90 if api.search_space_name == 'size' else 200,
                        is_random=False)
                    accs.append(info['test-accuracy'])
                # Average the test accuracy over the repeated runs.
                accuracies.append(sum(accs) / len(accs))
                xs.append(iepoch)
            alg2accuracies[alg] = accuracies
            ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg))
            ax.set_xlabel('The searching epoch', fontsize=LabelSize)
            ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]),
                          fontsize=LabelSize)
            ax.set_title('Searching results on {:}'.format(
                name2label[dataset]),
                         fontsize=LabelSize + 4)
            # Report final-epoch validation/test statistics for this algorithm.
            structures, valid_accs, test_accs = [_[epochs - 1]
                                                 for _ in data], [], []
            print('{:} plot alg : {:} -- final {:} architectures.'.format(
                time_string(), alg, len(structures)))
            for arch in structures:
                valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset)
                test_accs.append(test_acc)
                valid_accs.append(valid_acc)
            # FIX: raw string — '\p' is an invalid escape sequence in a normal
            # literal (DeprecationWarning); the runtime bytes are unchanged.
            print(
                r'{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}'
                .format(time_string(), alg, np.mean(valid_accs),
                        np.std(valid_accs), np.mean(test_accs),
                        np.std(test_accs)))
        ax.legend(loc=4, fontsize=LegendFontsize)

    fig, axs = plt.subplots(1, 3, figsize=figsize)
    datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
    for dataset, ax in zip(datasets, axs):
        sub_plot_fn(ax, dataset)
        print('sub-plot {:} on {:} done.'.format(dataset, search_space))
    save_path = (
        vis_save_dir /
        '{:}-ws-{:}-curve.png'.format(search_space, suffix)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
def main(args):
  """Train a single-image facial-landmark detector and evaluate per epoch.

  Image counterpart of the video pipeline: builds augmentation pipelines,
  the training Dataset and image/video evaluation dataloaders, constructs
  the network from ``args.model_config``, resumes from the last-info
  checkpoint when present, then runs the train / checkpoint / evaluate loop
  for ``opt_config.epochs`` epochs.

  NOTE(review): ``args`` is mutated in place (``args.sigma`` is rescaled by
  ``args.scale_eval`` below).
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  prepare_seed(args.rand_seed)

  # Per-run log directory keyed by the seed and wall-clock time.
  logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
  logger = Logger(args.save_path, logstr)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow  version : {}".format(PIL.__version__))
  logger.log("PyTorch version : {}".format(torch.__version__))
  logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

  # General Data Argumentation
  # ImageNet channel means scaled to 0-255: padding fill color for crops.
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
  assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
  train_transform  = [transforms.PreCrop(args.pre_crop_expand)]
  train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
  train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
  #if args.arg_flip:
  #  train_transform += [transforms.AugHorizontalFlip()]
  if args.rotate_max:
    train_transform += [transforms.AugRotate(args.rotate_max)]
  train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
  train_transform += [transforms.ToTensor(), normalize]
  train_transform  = transforms.Compose( train_transform )

  # Evaluation uses deterministic transforms only (no random augmentation).
  eval_transform  = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),  transforms.ToTensor(), normalize])
  assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)

  # Model Configure Load
  model_config = load_configure(args.model_config, logger)
  # Heatmap sigma is specified at the evaluation scale, so rescale it here.
  args.sigma   = args.sigma * args.scale_eval
  logger.log('Real Sigma : {:}'.format(args.sigma))

  # Training Dataset
  train_data   = Dataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
  train_data.load_list(args.train_lists, args.num_pts, True)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)


  # Evaluation Dataloader
  # Each entry is (loader, is_video): video lists first, then image lists.
  eval_loaders = []
  if args.eval_vlists is not None:
    for eval_vlist in args.eval_vlists:
      eval_vdata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_vdata.load_list(eval_vlist, args.num_pts, True)
      eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_vloader, True))

  if args.eval_ilists is not None:
    for eval_ilist in args.eval_ilists:
      eval_idata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_idata.load_list(eval_ilist, args.num_pts, True)
      eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_iloader, False))

  # Define network
  logger.log('configure : {:}'.format(model_config))
  # num_pts + 1 : one extra channel, presumably background — TODO confirm.
  net = obtain_model(model_config, args.num_pts + 1)
  assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
  logger.log("=> network :\n {}".format(net))

  logger.log('Training-data : {:}'.format(train_data))
  for i, eval_loader in enumerate(eval_loaders):
    eval_loader, is_video = eval_loader
    logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))

  logger.log('arguments : {:}'.format(args))

  opt_config = load_configure(args.opt_config, logger)

  # Use per-parameter-group LR/decay if the model defines them.
  if hasattr(net, 'specify_parameter'):
    net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
  else:
    net_param_dict = net.parameters()

  optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
  logger.log('criterion : {:}'.format(criterion))
  net, criterion = net.cuda(), criterion.cuda()
  net = torch.nn.DataParallel(net)

  # Resume from the last-info checkpoint when present; else start fresh.
  last_info = logger.last_info()
  if last_info.exists():
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch'] + 1
    checkpoint  = torch.load(last_info['last_checkpoint'])
    assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done" .format(logger.last_info(), checkpoint['epoch']))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0


  if args.eval_once:
    logger.log("=> only evaluate the model once")
    eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
    logger.close() ; return


  # Main Training and Evaluation Loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(start_epoch, opt_config.epochs):

    # NOTE(review): scheduler stepped at epoch start (pre-PyTorch-1.1
    # convention); moving it after train() would shift the LR schedule.
    scheduler.step()
    need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
    epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
    LRs       = scheduler.get_lr()
    logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

    # train for one epoch
    train_loss, train_nme = train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
    # log the results
    logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))

    # remember best prec@1 and save checkpoint
    save_path = save_checkpoint({
          'epoch': epoch,
          'args' : deepcopy(args),
          'arch' : model_config.arch,
          'state_dict': net.state_dict(),
          'scheduler' : scheduler.state_dict(),
          'optimizer' : optimizer.state_dict(),
          }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

    last_info = save_checkpoint({
          'epoch': epoch,
          'last_checkpoint': save_path,
          }, logger.last_info(), logger)

    eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.close()
Пример #32
0
def train_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str,
               print_freq, archs, arch_iter, logger):
    """Run one epoch of shared-weight training for a NAS super-network.

    For every batch, one candidate architecture is drawn from ``arch_iter``
    (cycling back over ``archs`` when exhausted), the shared weights are
    updated on the "base" split, and the drawn architecture is evaluated on
    the "val" split without a weight update.

    Args:
      xloader: loader yielding ``(base_inputs, base_targets, val_inputs, val_targets)``.
      network: super-network; its forward returns a 3-tuple whose middle
        element is the logits tensor.
      criterion: loss function applied to ``(logits, targets)``.
      scheduler: LR scheduler exposing ``update(epoch, fraction)``.
      w_optimizer: optimizer over the shared weights.
      epoch_str: human-readable epoch tag used in log lines.
      print_freq: log every ``print_freq`` steps (and on the final step).
      archs: sequence of candidate architectures to cycle through.
      arch_iter: current iterator over ``archs``; returned so the caller can
        resume the cycle in the next epoch.
      logger: object with a ``log(str)`` method.

    Returns:
      Tuple ``(base_loss_avg, base_top1_avg, base_top5_avg,
      val_loss_avg, val_top1_avg, val_top5_avg, arch_iter)``.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    val_losses, val_top1, val_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, val_inputs,
               val_targets) in enumerate(xloader):
        # Fractional-epoch LR update: step/len gives progress within the epoch.
        scheduler.update(None, 1.0 * step / len(xloader))
        try:
            arch = next(arch_iter)
        except StopIteration:
            # Catch only iterator exhaustion; a bare ``except`` here would
            # also swallow KeyboardInterrupt and genuine errors.
            arch_iter = iter(archs)
            arch = next(arch_iter)
        base_inputs = base_inputs.cuda()
        base_targets = base_targets.cuda(non_blocking=True)
        val_inputs = val_inputs.cuda()
        val_targets = val_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the shared weights on the base split
        w_optimizer.zero_grad()
        _, logits, _ = network(base_inputs)  #, arch)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record training statistics
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # evaluate the sampled architecture on the val split; no backward pass
        # follows, so skip autograd bookkeeping (identical values, less memory)
        with torch.no_grad():
            _, logits, _ = network(val_inputs, arch)
            val_loss = criterion(logits, val_targets)
        # record validation statistics
        val_prec1, val_prec5 = obtain_accuracy(logits.data,
                                               val_targets.data,
                                               topk=(1, 5))
        val_losses.update(val_loss.item(), val_inputs.size(0))
        val_top1.update(val_prec1.item(), val_inputs.size(0))
        val_top5.update(val_prec5.item(), val_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*TRAIN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Val  [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=val_losses, top1=val_top1, top5=val_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, val_losses.avg, val_top1.avg, val_top5.avg, arch_iter