def shallow_finetune_labelflip(feat_net='resnet50', double_seeds=True, outlier_model=False):
    """Grid-search fine-tuning of the H8K shallow net with its label-flip layer enabled.

    For every combination of load epoch (``load_iter``), ``lf_decay`` value
    (``decay``) and learning rate (``lrs``), an H8K shallow network is built with
    ``init(lf=True)``, loaded through a ``ShallowLoader`` at that epoch, and
    trained for 20 epochs with Nesterov SGD (checkpoint every epoch). Each run
    is tagged with the extra name ``_ft@<epoch>_dc-<decay>_lr-<lr>``.

    Args:
        feat_net: name of the feature-extraction network whose features are
            the shallow net's input (key into ``cfg.feat_shape_dict``).
        double_seeds: when not using the toy dataset, train on the ``_ds``
            variant of the training set.
        outlier_model: currently unused; kept for interface compatibility.
    """
    if cfg.use_toy_dataset:
        # No separate validation set for the toy dataset: ShallowTrainer gets a
        # split value instead (presumably the held-out fraction — confirm).
        trainset_name = cfg.dataset + '_train'
        validset_name = None
        valid_split = 0.3
    else:
        trainset_name = cfg.dataset + '_train' + ('_ds' if double_seeds else '')
        validset_name = cfg.dataset + '_valid'
        valid_split = 0

    print("Shallow train on features from net: " + feat_net)
    print("Trainset: " + trainset_name)

    in_shape = cfg.feat_shape_dict[feat_net]
    out_shape = feat_dataset_n_classes(trainset_name, feat_net)

    SNB = ShallowNetBuilder(in_shape, out_shape)
    SL = ShallowLoader(trainset_name, feat_net)

    load_iter = ['10']                      # epoch(s) whose weights are loaded
    decay = [0.1, 0.01, 0.001, 0.0001]      # values passed as lf_decay
    lrs = [0.01, 0.001, 0.0001, 0.00001]    # SGD learning rates

    for li in load_iter:
        for dc in decay:
            for lr in lrs:
                extr_n = '_ft@{}_dc-{}_lr-{}'.format(li, dc, lr)
                # lr and dc must come from the grid; hard-coding them would make
                # every run train the same configuration under a different name.
                opt = SGD(lr=lr, momentum=0.9, decay=1e-6, nesterov=True)
                snet = [SNB.H8K(extr_n, lf_decay=dc).init(lf=True).load(SL, li, SNB.H8K())]
                ST = ShallowTrainer(feat_net, trainset_name, validset_name, valid_split,
                                    batch_size=BATCH, loss=LOSS, metric=METRIC)
                ST.train(snet, opt, epochs=20, chk_period=1)


def extract_shallow_features(feat_net='resnet50', dataset_name=None):
    """Run a feature dataset through a trained H8K shallow net and save a hidden layer.

    Loads the precomputed ``feat_net`` features of ``dataset_name``. It rebuilds
    the H8K shallow network tagged ``_ft@10`` with ``init(lf=False)`` and loads
    it through a ``ShallowLoader`` bound to the ``<dataset>_train_ds`` set at
    epoch ``'00'``. It then writes the activations of layer
    ``additional_hidden_0`` to
    ``shallow_extracted_features/shallow_feat_<dataset_name>.h5``.

    Args:
        feat_net: feature-extraction network whose features feed the shallow net.
        dataset_name: feature dataset to process. Defaults to
            ``cfg.dataset + '_test'``, resolved after ``cfg.init`` has run.

    Returns:
        None. Prints a message and returns early if the dataset cannot be opened.
    """
    cfg.init(include_nets=[feat_net])

    # The shallow net's weights are looked up relative to this training set.
    old_trainset_name = cfg.dataset + '_train_ds'
    if dataset_name is None:
        dataset_name = cfg.dataset + '_test'

    print("\nloading dataset: " + dataset_name)
    try:
        dataset = common.feat_dataset(dataset_name, feat_net)
    except IOError:
        print("Can't open dataset.")
        return
    print("dataset loaded.")

    in_shape = cfg.feat_shape_dict[feat_net]
    # NOTE(review): out_shape is derived from the dataset being processed, not
    # from old_trainset_name the weights come from — confirm both report the
    # same number of classes.
    out_shape = feat_dataset_n_classes(dataset_name, feat_net)

    B = ShallowNetBuilder(in_shape, out_shape)
    SL = ShallowLoader(old_trainset_name, feat_net)

    pretrain_weight_epoch = '10'
    labelflip_finetune_epoch = '00'
    out_layer = 'additional_hidden_0'

    extr_n = '_ft@' + pretrain_weight_epoch
    model = B.H8K(extr_n, lf_decay=0.01).init(lf=False).load(SL, labelflip_finetune_epoch).model()
    # Use the module-wide BATCH (as the training code does); the lowercase
    # `batch_size` previously used here is not defined in this function. The
    # batch size should not affect the extracted values.
    feature_vectors = net_utils.extract_features(model, dataset, out_layer, BATCH, True)
    feature_vectors.save_hdf5("shallow_extracted_features/shallow_feat_" + dataset_name + ".h5")