Пример #1
0
def train_text_classifier():
    """Train the classifier named by ``args.model`` on ``args.dataset``.

    Reads the module-level ``args`` (``dataset``, ``level``, ``model``) and
    ``config`` objects. The training split is shuffled with a fixed seed,
    the model is fitted with a 20% validation split while logging to
    TensorBoard under ``./logs/<dataset>/<model>/``, evaluated on the test
    split, and its weights are saved to ``runs/<dataset>/<model>.dat``.

    Raises:
        ValueError: if ``args.dataset``, ``args.level`` or ``args.model`` is
            not supported, or the model prefix does not match the level.
    """
    import os

    dataset = args.dataset
    level = args.level

    # Validate the options up front. Previously an unsupported value
    # silently left the arrays / model as None and the run crashed later
    # with an unrelated TypeError or AttributeError.
    if level == 'word':
        process = word_process
    elif level == 'char':
        process = char_process
    else:
        raise ValueError('Unsupported level: {!r}'.format(level))
    # e.g. 'word_cnn'[:4] == 'word'. Raise rather than assert so the check
    # survives `python -O`.
    if args.model[:4] != level:
        raise ValueError('Model {!r} does not match level {!r}'.format(
            args.model, level))

    if dataset == 'imdb':
        splits = split_imdb_files()
    elif dataset == 'agnews':
        splits = split_agnews_files()
    elif dataset == 'yahoo':
        splits = split_yahoo_files()
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    train_texts, train_labels, test_texts, test_labels = splits
    x_train, y_train, x_test, y_test = process(
        train_texts, train_labels, test_texts, test_labels, dataset)
    x_train, y_train = shuffle(x_train, y_train, random_state=0)

    # Take a look at the shapes
    print('dataset:', dataset, '; model:', args.model, '; level:', args.level)
    print('X_train:', x_train.shape)
    print('y_train:', y_train.shape)
    print('X_test:', x_test.shape)
    print('y_test:', y_test.shape)

    log_dir = r'./logs/{}/{}/'.format(dataset, args.model)
    tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir,
                                              histogram_freq=0,
                                              write_graph=True)

    model_path = r'runs/{}/{}.dat'.format(dataset, args.model)
    if args.model == "word_cnn":
        model = word_cnn(dataset)
        batch_size = config.wordCNN_batch_size[dataset]
        epochs = config.wordCNN_epochs[dataset]
    elif args.model == "word_bdlstm":
        model = bd_lstm(dataset)
        batch_size = config.bdLSTM_batch_size[dataset]
        epochs = config.bdLSTM_epochs[dataset]
    elif args.model == "char_cnn":
        model = char_cnn(dataset)
        batch_size = config.charCNN_batch_size[dataset]
        epochs = config.charCNN_epochs[dataset]
    elif args.model == "word_lstm":
        model = lstm(dataset)
        batch_size = config.LSTM_batch_size[dataset]
        epochs = config.LSTM_epochs[dataset]
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    print('Train...')
    print('batch_size: ', batch_size, "; epochs: ", epochs)
    model.fit(x_train,
              y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_split=0.2,
              shuffle=True,
              callbacks=[tb_callback])
    scores = model.evaluate(x_test, y_test)
    print('test_loss: %f, accuracy: %f' % (scores[0], scores[1]))
    print('Saving model weights...')
    # Create the target directory so a long training run is not lost to a
    # missing folder at save time.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    model.save_weights(model_path)
Пример #2
0
def fool_text_classifier():
    """Craft adversarial examples against a trained Keras classifier.

    Uses the module-level ``args`` (``clean_samples_cap``, ``dataset``,
    ``level``, ``model``). Loads the dataset and writes the clean test texts
    to ``./fool_result/<dataset>/clean_<cap>.txt`` if that file does not
    exist yet. It then loads the weights from ``./runs/<dataset>/<model>.dat``.

    Each of the first ``cap`` test texts that the model classifies
    correctly is attacked with ``adversarial_paraphrase``, using the HowNet
    candidate pickles under ``hownet_candidates/<dataset>/``. Every text
    (adversarial or untouched) is appended to
    ``./fool_result/<dataset>/<model>/adv_<cap>.txt``. Each attack's change
    tuples go to ``change_tuple_<cap>.txt`` in the same directory. Summary
    statistics are printed at the end.

    Raises:
        ValueError: if the dataset, level or model is not supported, or the
            model prefix does not match the level.
    """
    clean_samples_cap = args.clean_samples_cap  # 1000
    print('clean_samples_cap:', clean_samples_cap)

    # get tokenizer
    dataset = args.dataset
    level = args.level
    tokenizer = get_tokenizer(dataset)

    # Validate options up front; previously an unsupported value left
    # x_test / model as None and the run failed much later.
    if level == 'word':
        process = word_process
    elif level == 'char':
        process = char_process
    else:
        raise ValueError('Unsupported level: {!r}'.format(level))
    # Raise rather than assert so the check survives `python -O`.
    if args.model[:4] != level:
        raise ValueError('Model {!r} does not match level {!r}'.format(
            args.model, level))

    # Read data set
    if dataset == 'imdb':
        splits = split_imdb_files()
    elif dataset == 'agnews':
        splits = split_agnews_files()
    elif dataset == 'yahoo':
        splits = split_yahoo_files()
    elif dataset == 'yahoo_csv':
        splits = split_yahoo_csv_files()
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    train_texts, train_labels, test_texts, test_labels = splits
    _, _, x_test, y_test = process(
        train_texts, train_labels, test_texts, test_labels, dataset)

    # Write clean examples into a txt file
    clean_texts_path = r'./fool_result/{}/clean_{}.txt'.format(
        dataset, str(clean_samples_cap))
    if not os.path.isfile(clean_texts_path):
        os.makedirs(os.path.dirname(clean_texts_path), exist_ok=True)
        write_origin_input_texts(clean_texts_path, test_texts)

    # Select the model and load the trained weights
    if args.model == "word_cnn":
        model = word_cnn(dataset)
    elif args.model == "word_bdlstm":
        model = bd_lstm(dataset)
    elif args.model == "char_cnn":
        model = char_cnn(dataset)
    elif args.model == "word_lstm":
        model = lstm(dataset)
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))
    model_path = r'./runs/{}/{}.dat'.format(dataset, args.model)
    model.load_weights(model_path)
    print('model path:', model_path)

    # evaluate classification accuracy of model on clean samples
    clean_x = x_test[:clean_samples_cap]
    clean_y = y_test[:clean_samples_cap]
    scores_origin = model.evaluate(clean_x, clean_y)
    print('clean samples origin test_loss: %f, accuracy: %f' %
          (scores_origin[0], scores_origin[1]))
    all_scores_origin = model.evaluate(x_test, y_test)
    print('all origin test_loss: %f, accuracy: %f' %
          (all_scores_origin[0], all_scores_origin[1]))

    grad_guide = ForwardGradWrapper(model)
    classes_prediction = grad_guide.predict_classes(clean_x)

    # load HowNet candidates.
    # NOTE: pickle.load can execute arbitrary code; only load candidate
    # files produced by this project or another trusted source.
    candidates_dir = os.path.join('hownet_candidates', dataset)
    with open(os.path.join(candidates_dir, 'dataset_50000.pkl'), 'rb') as f:
        dataset_dict = pickle.load(f)
    with open(os.path.join(candidates_dir, 'word_candidates_sense_0515.pkl'),
              'rb') as fp:
        word_candidate = pickle.load(fp)

    print('Crafting adversarial examples...')
    total_perturbated_text = 0
    successful_perturbations = 0
    failed_perturbations = 0
    sub_word_list = []
    sub_rate_list = []
    NE_rate_list = []

    # time.clock() was removed in Python 3.8; process_time() is its
    # CPU-time replacement.
    start_cpu = time.process_time()
    adv_text_path = r'./fool_result/{}/{}/adv_{}.txt'.format(
        dataset, args.model, str(clean_samples_cap))
    change_tuple_path = r'./fool_result/{}/{}/change_tuple_{}.txt'.format(
        dataset, args.model, str(clean_samples_cap))
    os.makedirs(os.path.dirname(adv_text_path), exist_ok=True)
    # `with` guarantees both result files are closed even if an attack raises.
    with open(adv_text_path, "a") as file_1, \
            open(change_tuple_path, "a") as file_2:
        for index, text in enumerate(test_texts[:clean_samples_cap]):
            true_y = np.argmax(y_test[index])
            # Only attack samples whose predicted label equals the true label.
            if true_y == classes_prediction[index]:
                total_perturbated_text += 1
                (adv_doc, adv_y, sub_word, sub_rate, NE_rate,
                 change_tuple_list) = adversarial_paraphrase(
                     input_text=text,
                     true_y=true_y,
                     grad_guide=grad_guide,
                     tokenizer=tokenizer,
                     dataset=dataset,
                     dataset_dict=dataset_dict,
                     word_candidate=word_candidate,
                     level=level)
                if adv_y != true_y:
                    successful_perturbations += 1
                    print('{}. Successful example crafted.'.format(index))
                else:
                    failed_perturbations += 1
                    print('{}. Failure.'.format(index))

                text = adv_doc
                sub_word_list.append(sub_word)
                sub_rate_list.append(sub_rate)
                NE_rate_list.append(NE_rate)
                file_2.write(str(index) + str(change_tuple_list) + '\n')
            file_1.write(text + "\n")
    end_cpu = time.process_time()
    print('CPU second:', end_cpu - start_cpu)

    if total_perturbated_text == 0:
        # Nothing was attacked; the means below would divide by zero.
        print('No correctly classified clean samples to perturb.')
        return
    mean_sub_word = sum(sub_word_list) / len(sub_word_list)
    mean_fool_rate = successful_perturbations / total_perturbated_text
    mean_sub_rate = sum(sub_rate_list) / len(sub_rate_list)
    mean_NE_rate = sum(NE_rate_list) / len(NE_rate_list)
    print('mean substitution word:', mean_sub_word)
    print('mean fooling rate:', mean_fool_rate)
    print('mean substitution rate:', mean_sub_rate)
    print('mean NE rate:', mean_NE_rate)
Пример #3
0
def fool_text_classifier_pytorch(model, dataset='imdb', clean_samples_cap=100):
    """Craft word-level adversarial examples against a PyTorch classifier.

    Args:
        model: the model to attack; it is wrapped with
            ``ForwardGradWrapper_pytorch`` for predictions and gradients.
        dataset: one of ``'imdb'``, ``'agnews'`` or ``'yahoo'``.
        clean_samples_cap: number of leading test texts to consider
            (defaults to 100, the value previously hard-coded).

    Each text is appended, with its substitution and NE rates, to
    ``./fool_result/<dataset>/adv_<cap>.txt``. The text is the adversarial
    version if the sample was attacked, otherwise the original. Change
    tuples of each attack are appended to
    ``./fool_result/<dataset>/change_tuple_<cap>.txt``. Uses the
    module-level ``opt`` for the tokenizer and the IMDB loader.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    import os

    print('clean_samples_cap:', clean_samples_cap)

    # get tokenizer
    tokenizer = get_tokenizer(opt)

    # Read data set
    if dataset == 'imdb':
        # This loader also returns a dev split, which is not used here.
        (train_texts, train_labels, _, _, test_texts,
         test_labels) = split_imdb_files(opt)
    elif dataset == 'agnews':
        (train_texts, train_labels, test_texts,
         test_labels) = split_agnews_files()
    elif dataset == 'yahoo':
        (train_texts, train_labels, test_texts,
         test_labels) = split_yahoo_files()
    else:
        # Previously x_test stayed None and predict_classes failed later.
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    _, _, x_test, y_test = word_process(
        train_texts, train_labels, test_texts, test_labels, dataset)

    grad_guide = ForwardGradWrapper_pytorch(model)
    classes_prediction = grad_guide.predict_classes(x_test[:clean_samples_cap])

    print('Crafting adversarial examples...')
    successful_perturbations = 0
    failed_perturbations = 0
    sub_rate_list = []
    NE_rate_list = []

    # time.clock() was removed in Python 3.8; process_time() is its
    # CPU-time replacement.
    start_cpu = time.process_time()
    adv_text_path = r'./fool_result/{}/adv_{}.txt'.format(
        dataset, str(clean_samples_cap))
    change_tuple_path = r'./fool_result/{}/change_tuple_{}.txt'.format(
        dataset, str(clean_samples_cap))
    os.makedirs(os.path.dirname(adv_text_path), exist_ok=True)
    # `with` guarantees both result files are closed even if an attack raises.
    with open(adv_text_path, "a") as file_1, \
            open(change_tuple_path, "a") as file_2:
        for index, text in enumerate(test_texts[:clean_samples_cap]):
            sub_rate = 0
            NE_rate = 0
            true_y = np.argmax(y_test[index])
            # Only attack samples whose predicted label equals the true label.
            if true_y == classes_prediction[index]:
                (adv_doc, adv_y, sub_rate, NE_rate,
                 change_tuple_list) = adversarial_paraphrase(
                     input_text=text,
                     true_y=true_y,
                     grad_guide=grad_guide,
                     tokenizer=tokenizer,
                     dataset=dataset,
                     level='word')
                if adv_y != true_y:
                    successful_perturbations += 1
                    print('{}. Successful example crafted.'.format(index))
                else:
                    failed_perturbations += 1
                    print('{}. Failure.'.format(index))

                text = adv_doc
                sub_rate_list.append(sub_rate)
                NE_rate_list.append(NE_rate)
                file_2.write(str(index) + str(change_tuple_list) + '\n')
            file_1.write(text + " sub_rate: " + str(sub_rate) + "; NE_rate: " +
                         str(NE_rate) + "\n")
    end_cpu = time.process_time()
    print('CPU second:', end_cpu - start_cpu)

    if not sub_rate_list:
        # Nothing was attacked; the means below would divide by zero.
        print('No correctly classified clean samples to perturb.')
        return
    mean_sub_rate = sum(sub_rate_list) / len(sub_rate_list)
    mean_NE_rate = sum(NE_rate_list) / len(NE_rate_list)
    print('mean substitution rate:', mean_sub_rate)
    print('mean NE rate:', mean_NE_rate)
    print('successful perturbations:', successful_perturbations)
    print('failed perturbations:', failed_perturbations)
Пример #4
0

if __name__ == '__main__':
    # Script entry point. NOTE(review): this excerpt ends after data loading;
    # the remainder of the script is not visible here, so the module-level
    # names assigned below may be used further down.
    args = parser.parse_args()
    clean_samples_cap = args.clean_samples_cap  # 1000

    # get tokenizer
    dataset = args.dataset
    tokenizer = get_tokenizer(dataset)

    # Read data set.
    # The x_*/y_* arrays and test_texts stay None for any dataset/level
    # combination not handled by the branches below.
    x_train = y_train = x_test = y_test = None
    test_texts = None
    # NOTE(review): not used anywhere in the visible excerpt.
    first_get_dataset = False
    if dataset == 'imdb':
        train_texts, train_labels, test_texts, test_labels = split_imdb_files()
        if args.level == 'word':
            x_train, y_train, x_test, y_test = word_process(
                train_texts, train_labels, test_texts, test_labels, dataset)
        elif args.level == 'char':
            x_train, y_train, x_test, y_test = char_process(
                train_texts, train_labels, test_texts, test_labels, dataset)
    elif dataset == 'agnews':
        train_texts, train_labels, test_texts, test_labels = split_agnews_files(
        )
        if args.level == 'word':
            x_train, y_train, x_test, y_test = word_process(
                train_texts, train_labels, test_texts, test_labels, dataset)
        elif args.level == 'char':
            x_train, y_train, x_test, y_test = char_process(
                train_texts, train_labels, test_texts, test_labels, dataset)
Пример #5
0
def fool_text_classifier():
    """Craft adversarial examples against a trained Keras classifier.

    Uses the module-level ``args`` (``clean_samples_cap``, ``dataset``,
    ``level``, ``model``). The clean test texts and processed test arrays
    are cached under a fixed Google Drive directory and then reloaded from
    that cache. Weights are loaded from ``runs/<dataset>/<model>.dat``.

    Each of the first ``clean_samples_cap`` test texts that the model
    classifies correctly is attacked with ``adversarial_paraphrase``. Every
    text (adversarial or untouched) is appended, with its substitution and
    NE rates, to ``./fool_result/<dataset>/<model>/adv_<cap>.txt``. Change
    tuples go to ``change_tuple_<cap>.txt`` in the same directory.

    Raises:
        ValueError: if the dataset, level or model is not supported, or the
            model prefix does not match the level.
    """
    clean_samples_cap = args.clean_samples_cap  # 1000
    print('clean_samples_cap:', clean_samples_cap)

    # get tokenizer
    dataset = args.dataset
    level = args.level
    tokenizer = get_tokenizer(dataset)

    # Validate options up front; previously an unsupported value left
    # x_test / model as None and the run failed much later.
    if level == 'word':
        process = word_process
    elif level == 'char':
        process = char_process
    else:
        raise ValueError('Unsupported level: {!r}'.format(level))
    # Raise rather than assert so the check survives `python -O`.
    if args.model[:4] != level:
        raise ValueError('Model {!r} does not match level {!r}'.format(
            args.model, level))

    # Read data set
    if dataset == 'imdb':
        splits = split_imdb_files()
    elif dataset == 'agnews':
        splits = split_agnews_files()
    elif dataset == 'yahoo':
        splits = split_yahoo_files()
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    train_texts, train_labels, test_texts, test_labels = splits
    _, _, x_test, y_test = process(
        train_texts, train_labels, test_texts, test_labels, dataset)

    # Cache the clean texts and processed test arrays, then reload them.
    clean_path = '/content/drive/My Drive/PWWS_Pretrained'  #r./fool_result/{}'.format(dataset, str(clean_samples_cap))
    clean_texts_file = os.path.join(clean_path, 'clean_1000.txt')
    x_cache_file = os.path.join(clean_path, 'x_processed_1000.npy')
    y_cache_file = os.path.join(clean_path, 'y_1000.npy')
    # Rebuild the cache whenever any file is missing. Previously only a
    # missing directory triggered this, so a partial cache crashed
    # np.load(), and os.mkdir failed on a missing parent directory.
    # NOTE(review): the cache location does not depend on dataset or level,
    # so a cache written for one dataset is reused for another.
    cache_files = (clean_texts_file, x_cache_file, y_cache_file)
    if not all(os.path.isfile(path) for path in cache_files):
        os.makedirs(clean_path, exist_ok=True)
        write_origin_input_texts(clean_texts_file, test_texts)
        np.save(x_cache_file, x_test)
        np.save(y_cache_file, y_test)

    with open(clean_texts_file, 'r') as f:
        # Drop the trailing newline that readlines() kept: it was fed to the
        # attack and split each untouched text's output record in two.
        test_texts = [line.rstrip('\n') for line in f]
    x_test = np.load(x_cache_file)
    y_test = np.load(y_cache_file)

    # Select the model and load the trained weights
    if args.model == "word_cnn":
        model = word_cnn(dataset)
    elif args.model == "word_bdlstm":
        model = bd_lstm(dataset)
    elif args.model == "char_cnn":
        model = char_cnn(dataset)
    elif args.model == "word_lstm":
        model = lstm(dataset)
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))
    model_path = r'runs/{}/{}.dat'.format(dataset, args.model)
    model.load_weights(model_path)
    print('model path:', model_path)

    # Evaluate on the capped clean subset and on the whole test set
    # (previously both evaluations used the whole test set).
    clean_x = x_test[:clean_samples_cap]
    clean_y = y_test[:clean_samples_cap]
    scores_origin = model.evaluate(clean_x, clean_y)
    print('clean samples origin test_loss: %f, accuracy: %f' %
          (scores_origin[0], scores_origin[1]))
    all_scores_origin = model.evaluate(x_test, y_test)
    print('all origin test_loss: %f, accuracy: %f' %
          (all_scores_origin[0], all_scores_origin[1]))

    grad_guide = ForwardGradWrapper(model)
    # Only the first clean_samples_cap predictions are used below.
    classes_prediction = grad_guide.predict_classes(clean_x)

    print('Crafting adversarial examples...')
    successful_perturbations = 0
    failed_perturbations = 0
    sub_rate_list = []
    NE_rate_list = []

    # time.clock() was removed in Python 3.8; process_time() is its
    # CPU-time replacement.
    start_cpu = time.process_time()
    adv_text_path = r'./fool_result/{}/{}/adv_{}.txt'.format(
        dataset, args.model, str(clean_samples_cap))
    change_tuple_path = r'./fool_result/{}/{}/change_tuple_{}.txt'.format(
        dataset, args.model, str(clean_samples_cap))
    os.makedirs(os.path.dirname(adv_text_path), exist_ok=True)

    # `with` guarantees both result files are closed even if an attack raises.
    with open(adv_text_path, "a") as file_1, \
            open(change_tuple_path, "a") as file_2:
        for index, text in enumerate(test_texts[:clean_samples_cap]):
            sub_rate = 0
            NE_rate = 0
            true_y = np.argmax(y_test[index])
            # Only attack samples whose predicted label equals the true label.
            if true_y == classes_prediction[index]:
                (adv_doc, adv_y, sub_rate, NE_rate,
                 change_tuple_list) = adversarial_paraphrase(
                     input_text=text,
                     true_y=true_y,
                     grad_guide=grad_guide,
                     tokenizer=tokenizer,
                     dataset=dataset,
                     level=level)
                if adv_y != true_y:
                    successful_perturbations += 1
                    print('{}. Successful example crafted.'.format(index))
                else:
                    failed_perturbations += 1
                    print('{}. Failure.'.format(index))

                text = adv_doc
                sub_rate_list.append(sub_rate)
                NE_rate_list.append(NE_rate)
                file_2.write(str(index) + str(change_tuple_list) + '\n')
            file_1.write(text + " sub_rate: " + str(sub_rate) + "; NE_rate: " +
                         str(NE_rate) + "\n")
    end_cpu = time.process_time()
    print('CPU second:', end_cpu - start_cpu)

    if not sub_rate_list:
        # Nothing was attacked; the means and success rate would divide by 0.
        print('No correctly classified clean samples to perturb.')
        return
    mean_sub_rate = sum(sub_rate_list) / len(sub_rate_list)
    mean_NE_rate = sum(NE_rate_list) / len(NE_rate_list)
    print('mean substitution rate:', mean_sub_rate)
    print('mean NE rate:', mean_NE_rate)
    print('successful perturbations:', successful_perturbations)
    print('failed perturbations:', failed_perturbations)
    print(
        'success rate:', successful_perturbations /
        (failed_perturbations + successful_perturbations))
Пример #6
0
def train_text_classifier():
    """Train the selected classifier with a quantization-regularized loss.

    Uses the module-level ``args`` (``dataset``, ``level``, ``model``) and
    ``config``. Data and model selection follow the plain trainer. The model
    is compiled with ``binary_crossentropy + 0.1 * quan_loss``, where
    ``quan_loss`` is the mean absolute gap between the first layer's output
    and its sign. ``quan_loss`` and ``ecoc_loss`` are also reported as
    metrics.

    The Hamming and Gilbert-Varshamov distance bounds for
    ``config.wordCNN_embedding_dims[dataset]`` bits and
    ``config.num_words[dataset]`` codewords are printed for reference.
    Weights are saved to ``./runs/<dataset>/<model>.dat_100embed_binary``.

    Raises:
        ValueError: if the dataset, level or model is not supported, or the
            model prefix does not match the level.
    """
    import os

    dataset = args.dataset
    level = args.level

    # Validate options up front; previously an unsupported value left the
    # arrays / model as None and the run failed much later.
    if level == 'word':
        process = word_process
    elif level == 'char':
        process = char_process
    else:
        raise ValueError('Unsupported level: {!r}'.format(level))
    # Raise rather than assert so the check survives `python -O`.
    if args.model[:4] != level:
        raise ValueError('Model {!r} does not match level {!r}'.format(
            args.model, level))

    if dataset == 'imdb':
        splits = split_imdb_files()
    elif dataset == 'agnews':
        splits = split_agnews_files()
    elif dataset == 'yahoo':
        splits = split_yahoo_files()
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    train_texts, train_labels, test_texts, test_labels = splits
    x_train, y_train, x_test, y_test = process(
        train_texts, train_labels, test_texts, test_labels, dataset)
    x_train, y_train = shuffle(x_train, y_train, random_state=0)

    # Take a look at the shapes
    print('dataset:', dataset, '; model:', args.model, '; level:', args.level)
    print('X_train:', x_train.shape)
    print('y_train:', y_train.shape)
    print('X_test:', x_test.shape)
    print('y_test:', y_test.shape)

    log_dir = r'./logs/{}/{}/'.format(dataset, args.model)
    tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir,
                                              histogram_freq=0,
                                              write_graph=True)

    model_path = r'./runs/{}/{}.dat_100embed_binary'.format(
        dataset, args.model)
    if args.model == "word_cnn":
        model = word_cnn(dataset)
        batch_size = config.wordCNN_batch_size[dataset]
    elif args.model == "word_bdlstm":
        model = bd_lstm(dataset)
        batch_size = config.bdLSTM_batch_size[dataset]
    elif args.model == "char_cnn":
        model = char_cnn(dataset)
        batch_size = config.charCNN_batch_size[dataset]
    elif args.model == "word_lstm":
        model = lstm(dataset)
        batch_size = config.LSTM_batch_size[dataset]
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    # NOTE(review): this trainer uses a fixed epoch count instead of the
    # per-model config.*_epochs values; confirm this is intended. The log
    # line now reports the value actually used for training.
    train_epochs = 10
    print('Train...')
    print('batch_size: ', batch_size, "; epochs: ", train_epochs)

    # NOTE(review): the bit length comes from the word-CNN config even for
    # other models; confirm it matches the model's embedding size.
    embedding_dims = config.wordCNN_embedding_dims[dataset]
    num_words = config.num_words[dataset]
    # The bounds are only printed; no margin term using them is in the loss.
    _hamming_bound(embedding_dims, num_words)
    _gv_bound(embedding_dims, num_words)

    def quan_loss(y_true, y_pred):
        """Mean |sign(h) - h| over the first layer's output h."""
        embedding_output = model.layers[0].output
        Bbatch = tf.sign(embedding_output)
        return K.mean(
            keras.losses.mean_absolute_error(Bbatch, embedding_output))

    def ecoc_loss(y_true, y_pred):
        """Mean pairwise cosine similarity of the rows of layer 0's first weight.

        Presumably the embedding table (vocab x dims) -- TODO confirm.
        The diagonal (self-similarity 1) is included in the mean.
        """
        centers = model.layers[0].weights[0]
        norm_centers = tf.math.l2_normalize(centers, axis=1)
        prod = tf.matmul(norm_centers, tf.transpose(norm_centers))
        return K.mean(prod)

    def custom_loss_wrapper(y_true, y_pred):
        """Binary cross-entropy plus 0.1 times the quantization loss.

        The ECOC / margin terms previously computed here were never added
        to the result, so that dead (vocab x vocab) matmul was removed.
        """
        xent_loss = losses.binary_crossentropy(y_true, y_pred)
        return xent_loss + 0.1 * quan_loss(y_true, y_pred)

    model.compile(loss=custom_loss_wrapper,
                  optimizer='adam',
                  metrics=['accuracy', quan_loss, ecoc_loss])

    model.fit(x_train,
              y_train,
              batch_size=batch_size,
              epochs=train_epochs,
              validation_split=0.2,
              shuffle=True,
              callbacks=[tb_callback])
    scores = model.evaluate(x_test, y_test)
    print('test_loss: %f, accuracy: %f' % (scores[0], scores[1]))
    print('Saving model weights...')
    # Create the target directory so training is not lost to a missing folder.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    model.save_weights(model_path)


def _gv_bound(L, M, tight=False):
    """Return (and print) the Gilbert-Varshamov distance for M codes of L bits.

    d grows while 2**L / sum_{i<d} C(L, i) > M. Unless ``tight`` is set,
    the result is one less than the first d failing that check. The sum is
    accumulated incrementally in the same order, so the results match
    recomputing it from scratch for each d.
    """
    codewords = float(pow(2, int(L)))
    d = 1
    ball = scipy.special.binom(L, 0)  # invariant: sum_{i<d} C(L, i)
    while codewords / ball > M:
        ball = ball + scipy.special.binom(L, d)
        d += 1
    if not tight:
        d -= 1
    print('Distance for %d classes with %d bits (GV bound): %d' % (M, L, d))
    return d


def _hamming_bound(L, M, tight=False):
    """Return (and print) the Hamming-bound distance for M codes of L bits.

    d grows while sum_{i<=floor((d-1)/2)} C(L, i) <= 2**L / M. Unless
    ``tight`` is set, the result is one less than the first d failing that
    check.
    """
    capacity = float(pow(2, int(L))) / float(M)

    def fits(d):
        # Volume of a Hamming ball of radius floor((d - 1) / 2).
        radius = math.floor(float(d - 1) / 2.0)
        volume = 0
        for i in range(0, radius + 1):
            volume = volume + scipy.special.binom(L, i)
        return volume <= capacity

    d = 0
    while fits(d):
        d += 1
    if not tight:
        d -= 1
    print('Distance for %d classes with %d bits (Hamming bound): %d' %
          (M, L, d))
    return d