Example #1
File: train.py Project: IBM/adv-def-text
def create_models(args):
    model, model_dev, model_classifier = None, None, None
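    # Adversarial mode builds three graphs that share one embedding: a training
    # model, an inference-only autoencoder view (include_cls=False), and an
    # inference-only classifier view (include_ae=False).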
    if args.adv:
        # if args.copy:
        if args.classification_model == 'BERT':
            args.bert_config = modeling.BertConfig.from_json_file(
                args.bert_config_file)
        model = AdversarialModelCopy(args, mode='Train')
        model_dev = AdversarialModelCopy(args,
                                         mode="Infer",
                                         include_cls=False,
                                         embedding=model.get_embedding())
        model_classifier = AdversarialModelCopy(
            args,
            mode="Infer",
            include_ae=False,
            embedding=model.get_embedding())
    elif args.classification:
        if args.classification_model == 'RNN':
            utils.print_out('Initialise classification model: RNN')
            model = ClassificationModel(args, 'Train')
        elif args.classification_model == 'CNN':
            utils.print_out('Initialise classification model: CNN')
            model = CNNClassificationModel(args, mode='Train')
        elif args.classification_model == 'BERT':
            bert_config = modeling.BertConfig.from_json_file(
                args.bert_config_file)
            model = BertClassificationModel(args, bert_config, mode='Train')
            modeling.init_bert(args.bert_init_chk,
                               word_embedding_trainable=True)
    else:
        model = Seq2SeqModel(args, mode="Train")
        model_dev = Seq2SeqModel(args, mode="Infer")
    return model, model_dev, model_classifier
Example #2
def eval_ae(sess, dev_iter, model_dev, dev_next, vocab, step, demo_per_step):
    sess.run(dev_iter.initializer)
    decoder_reference_list = []
    decoder_prediction_list = []
    i = 0
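    # Decode dev batches and collect references/predictions; capped at roughly 30
    # batches (see the break below) so evaluation stays cheap during training.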
    while True:
        try:
            dev_batch = sess.run(dev_next)
            decoder_predictions_inference = sess.run(
                model_dev.make_infer_outputs(),
                feed_dict=model_dev.make_train_inputs(dev_batch))
            decoder_reference_list.append(dev_batch[2])
            decoder_prediction_list.append(decoder_predictions_inference)
            if i >= 30:
                break
            i += 1
        except tf.errors.OutOfRangeError:
            break
    reference_list, translation_list = [], []
    for decoder_references in decoder_reference_list:
        references = getSentencesFromIDs(decoder_references, vocab)
        reference_list.extend(references)
    for decoder_predictions in decoder_prediction_list:
        translations = getSentencesFromIDs(decoder_predictions, vocab)
        translation_list.extend(translations)
    if demo_per_step == 1:
        for i in range(len(reference_list)):
            utils.print_out('Example ' + str(i) + ': src: ' +
                            ' '.join(reference_list[i]))
            utils.print_out('Example ' + str(i) + ': nmt: ' +
                            ' '.join(translation_list[i]))
    elif step % demo_per_step == 0:
        rand_ind = random.randint(1, len(reference_list) - 1)
        utils.print_out('Step: ' + str(step) + ', src: ' +
                        ' '.join(reference_list[rand_ind]))
        utils.print_out('Step: ' + str(step) + ', nmt: ' +
                        ' '.join(translation_list[rand_ind]))
        # utils.print_out('Step' + str(step) + ', copy: ' + ' '.join([str(a) for a in copy_mask_list[rand_ind]]))
        # utils.print_out('Step' + str(step) + ', logits: ' + ' '.join([str(a) for a in decoder_logits_list[rand_ind]]))

    acc = evaluate._accuracy(reference_list, translation_list)
    word_acc = evaluate._word_accuracy(reference_list, translation_list)
    rouge = evaluate._rouge(reference_list, translation_list)
    bleu = evaluate._bleu(reference_list, translation_list)
    return acc, word_acc, rouge, bleu
Example #3
def printSentence(tokenized_sentences, vocab):
    train_sentence = ''
    for word in tokenized_sentences:
        train_sentence += vocab[word] + ' '
    utils.print_out(train_sentence)
Example #4
def eval_classification(args, sess, dev_iter, dev_next, model, vocab):
    sess.run(dev_iter.initializer)
    dev_srcs = []
    dev_labels = []
    dev_logits = []
    alphas = []
    count = 0
    while True:
        try:
            dev_batch = sess.run(dev_next)
            results = sess.run(model.make_test_outputs(),
                               feed_dict=model.make_train_inputs(dev_batch))
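            # Assumed layout of make_test_outputs(): logits at index 1, attention
            # alphas at index 3, source ids and labels as the last two entries.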
            dev_srcs.append(results[-2])
            dev_labels.append(results[-1])
            dev_logits.append(results[1])
            if args.cls_attention:
                alphas.append(results[3])
            count += 1
        except tf.errors.OutOfRangeError:
            break
    reference_list = []
    alphas_list = []
    for ind, decoder_references in enumerate(dev_srcs):
        references = getSentencesFromIDs(decoder_references, vocab)
        reference_list.extend(references)
        if len(alphas) > 0:
            alphas_list.extend(alphas[ind])

    dev_logits = np.concatenate(dev_logits, axis=0)
    dev_labels = np.concatenate(dev_labels, axis=0)

    rand_ind = random.randint(1, len(reference_list) - 1)
    utils.print_out('src  : ' + str(reference_list[rand_ind]) + ', label: ' +
                    str(dev_labels[rand_ind]))
    if len(alphas_list) > 0:
        utils.print_out(
            'alpha: ' +
            ', '.join(["{:.3f}".format(a[0]) for a in alphas_list[rand_ind]]))

    # Print: False positives, false negatives
    fp_spl = open(args.output_dir + '/false_positive.txt', 'w')
    fn_spl = open(args.output_dir + '/false_negative.txt', 'w')
    for i in range(len(reference_list)):
        # utils.print_out('Example ' + str(i) + ': src:\t' + ' '.join(reference_list[i]) + '\t' + str(dev_labels[i]))
        # utils.print_out(' ')
        spl_predict = evaluate.max_index(dev_logits[i])
        if dev_labels[i][1] == 1:
            if spl_predict == 0:
                fn_spl.write('Example ' + str(i) + ': src:\t' +
                             ' '.join(reference_list[i]) + '\t: ' +
                             str(dev_labels[i]) + '\n')
                fn_spl.write('Example ' + str(i) + ': spl:\t' +
                             ' '.join(reference_list[i]) + '\t: ' +
                             str(dev_logits[i]) + '\n')
                fn_spl.write('\n')
        elif dev_labels[i][1] == 0:
            if spl_predict == 1:
                fp_spl.write('Example ' + str(i) + ': src:\t' +
                             ' '.join(reference_list[i]) + '\t: ' +
                             str(dev_labels[i]) + '\n')
                fp_spl.write('Example ' + str(i) + ': spl:\t' +
                             ' '.join(reference_list[i]) + '\t: ' +
                             str(dev_logits[i]) + '\n')
                fp_spl.write('\n')
    fp_spl.close()
    fn_spl.close()

    acc = evaluate._clss_accuracy(dev_labels, dev_logits)
    auc = evaluate._clss_auc(dev_labels, dev_logits)
    return auc, acc
Example #5
def eval_adv(args,
             sess,
             dev_iter,
             model_dev,
             model_classifier,
             dev_next,
             vocab,
             step,
             demo_per_step,
             is_train=True):
    sess.run(dev_iter.initializer)
    # sentiment_distances = []
    cls_labels = []
    cls_logits_def = []
    cls_origs_def = []
    cls_logits = []
    cls_orig_logits = []
    decoder_reference_list = []
    decoder_prediction_list = []
    sent_embs = []
    adv_sent_embs = []
    orig_alphas, trans_alphas = [], []
    trans_alphas_def = []
    copy_masks = []
    i = 0

    start = default_timer()
    while True:
        try:
            dev_batch = sess.run(dev_next)
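            # Copy mode (assumption): get_copy_mask marks which source tokens to copy
            # versus attack, based on the top-k attention positions; the mask is
            # appended to the batch as an extra field.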
            if args.copy:
                copy_mask = get_copy_mask(sess, model_dev, dev_batch,
                                          np.max(dev_batch[5]),
                                          args.top_k_attack)
                dev_batch = dev_batch + (copy_mask, )
            results = sess.run(
                model_dev.make_infer_outputs(),
                feed_dict=model_dev.make_train_inputs(dev_batch))
            cls_labels.extend(dev_batch[3])

            if args.beam_width > 0:
                decoder_outputs = results[0][0]
            else:
                decoder_outputs = results[0]

            if args.ae_vocab_file is not None:
                decoder_outputs = lookup(decoder_outputs, args.vocab_map)

            decoder_reference_list.extend(dev_batch[2])

            if args.copy:
                copy_masks.extend(results[1])

            decoder_prediction_list.extend(decoder_outputs)

            decoder_preds_lengths = get_lengths(
                decoder_outputs,
                eos_id=102
                if args.classification_model == 'BERT' else input_data.EOS_ID)
            decoder_outputs = padding(decoder_outputs,
                                      eos_id=102 if args.classification_model
                                      == 'BERT' else input_data.EOS_ID)

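            # For BERT vocabularies, id 102 ([SEP]) serves as EOS above; prepend
            # id 101 ([CLS]) so the classifier sees standard BERT input.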
            if args.classification_model == 'BERT':
                decoder_outputs = np.concatenate([
                    np.array([[101]] * len(decoder_outputs)), decoder_outputs
                ],
                                                 axis=1)

            # classification based on decoder_predictions_inference
            cls_logit = sess.run(
                model_classifier.make_classifier_outputs(),
                feed_dict=model_classifier.make_classifier_input(
                    dev_batch, decoder_outputs, decoder_preds_lengths))
            cls_logits.extend(cls_logit[0])
            adv_sent_embs.extend(cls_logit[1])
            if args.cls_attention:
                trans_alphas.extend(cls_logit[2])

            # classification based on original input
            cls_orig_logit = sess.run(
                model_classifier.make_classifier_outputs(),
                feed_dict=model_classifier.make_classifier_input(
                    dev_batch, dev_batch[2], dev_batch[5]))
            cls_orig_logits.extend(cls_orig_logit[0])
            sent_embs.extend(cls_orig_logit[1])
            if args.cls_attention:
                orig_alphas.extend(cls_orig_logit[2])

            # defending classification based on decoder_predictions_inference
            if args.defending:
                cls_logit_def = sess.run(
                    model_classifier.make_def_classifier_outputs(),
                    feed_dict=model_classifier.make_classifier_input(
                        dev_batch, decoder_outputs, decoder_preds_lengths))
                cls_logits_def.extend(cls_logit_def[0])
                if args.cls_attention:
                    trans_alphas_def.extend(cls_logit_def[2])

                cls_orig_def = sess.run(
                    model_classifier.make_def_classifier_outputs(),
                    feed_dict=model_classifier.make_classifier_input(
                        dev_batch, dev_batch[2], dev_batch[5]))
                cls_origs_def.extend(cls_orig_def[0])

            if is_train and i >= 30:
                # if i >= 0:
                break
            i += 1
        except tf.errors.OutOfRangeError:
            break
    end = default_timer()
    if not is_train:
        utils.print_out('Adversarial attack elapsed:' +
                        '{0:.4f}'.format(end - start) + 's')

    cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu = evaluate_attack(
        args,
        step,
        decoder_reference_list,
        decoder_prediction_list,
        cls_logits,
        cls_orig_logits,
        cls_labels,
        vocab,
        sent_embs,
        adv_sent_embs,
        is_test=(not is_train),
        orig_alphas=orig_alphas,
        trans_alphas=trans_alphas,
        cls_logits_def=cls_logits_def,
        cls_origs_def=cls_origs_def,
        copy_masks=copy_masks)

    return cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu
Example #6
def evaluate_attack(args, step, decoder_reference_list, decoder_prediction_list,
                    cls_logits, cls_orig_logits, cls_labels, vocab,
                    sent_embs, adv_sent_embs,
                    is_test=False, X_adv_flip_num=None,
                    orig_alphas=None, trans_alphas=None,
                    cls_logits_def=None, cls_origs_def=None,
                    copy_masks=None):

    cls_orig_acc = general_evaluate._clss_accuracy(cls_labels, cls_orig_logits)
    cls_orig_auc = general_evaluate._clss_auc(cls_labels, cls_orig_logits)

    cls_acc = general_evaluate._clss_accuracy(cls_labels, cls_logits)
    cls_auc = general_evaluate._clss_auc(cls_labels, cls_logits)
    cls_acc_pos = general_evaluate._clss_accuracy_micro(cls_labels, cls_logits, orig_label=1)
    cls_acc_neg = general_evaluate._clss_accuracy_micro(cls_labels, cls_logits, orig_label=0)

    if cls_logits_def is not None and len(cls_logits_def) > 0:
        cls_def_acc = general_evaluate._clss_accuracy(cls_labels, cls_logits_def)
        cls_def_auc = general_evaluate._clss_auc(cls_labels, cls_logits_def)
        org_def_acc = general_evaluate._clss_accuracy(cls_labels, cls_origs_def)
        org_def_auc = general_evaluate._clss_auc(cls_labels, cls_origs_def)

    reference_list = getSentencesFromIDs(decoder_reference_list, vocab)
    translation_list = getSentencesFromIDs(decoder_prediction_list, vocab)

    ref_pos, ref_neg, trans_pos, trans_neg, ref_changed, trans_changed = [], [], [], [], [], []
    label_changed, logits_changed, flip_num_changed, ids_changed = [], [], [], []
    ref_emb_pos, trans_emb_pos, ref_emb_neg, trans_emb_neg, ref_emb_cha, trans_emb_cha = [], [], [], [], [], []

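    # Split examples by gold label and collect those whose predicted class flipped
    # after the attack, along with their sentence embeddings for semantic similarity.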
    for ind, references in enumerate(reference_list):
        if cls_labels[ind][1] > 0:
            ref_pos.append(references)
            trans_pos.append(translation_list[ind])
            ref_emb_pos.append(sent_embs[ind])
            trans_emb_pos.append(adv_sent_embs[ind])
        else:
            ref_neg.append(references)
            trans_neg.append(translation_list[ind])
            ref_emb_neg.append(sent_embs[ind])
            trans_emb_neg.append(adv_sent_embs[ind])
        if np.argmax(cls_logits[ind]) != np.argmax(cls_orig_logits[ind]):
            ids_changed.append(ind)
            ref_changed.append(references)
            trans_changed.append(translation_list[ind])
            label_changed.append(cls_labels[ind])
            logits_changed.append(cls_logits[ind])
            ref_emb_cha.append(sent_embs[ind])
            trans_emb_cha.append(adv_sent_embs[ind])
            if X_adv_flip_num is not None:
                flip_num_changed.append(X_adv_flip_num[ind])

    ae_acc = general_evaluate._accuracy(reference_list, translation_list)
    word_acc = general_evaluate._word_accuracy(reference_list, translation_list)
    rouge = general_evaluate._rouge(reference_list, translation_list)
    bleu = general_evaluate._bleu(reference_list, translation_list)
    use = general_evaluate._use_scores(reference_list, translation_list, args.use_model)
    accept = general_evaluate._accept_score(reference_list, translation_list, args)

    # positive examples
    pos_rouge = general_evaluate._rouge(ref_pos, trans_pos)
    pos_bleu = general_evaluate._bleu(ref_pos, trans_pos)
    pos_accept = general_evaluate._accept_score(ref_pos, trans_pos, args)
    pos_semsim = avgcos(ref_emb_pos, trans_emb_pos)
    pos_use = general_evaluate._use_scores(ref_pos, trans_pos, args.use_model)

    # negative examples
    neg_rouge = general_evaluate._rouge(ref_neg, trans_neg)
    neg_bleu = general_evaluate._bleu(ref_neg, trans_neg)
    neg_accept = general_evaluate._accept_score(ref_neg, trans_neg, args)
    neg_semsim = avgcos(ref_emb_neg, trans_emb_neg)
    neg_use = general_evaluate._use_scores(ref_neg, trans_neg, args.use_model)


    # changed examples
    if len(ref_changed) == 0:
        changed_rouge = -1.0
        changed_bleu = -1.0
        changed_accept = -1.0
        changed_semsim = -1.0
        changed_use = -1.0
    else:
        changed_rouge = general_evaluate._rouge(ref_changed, trans_changed)
        changed_bleu = general_evaluate._bleu(ref_changed, trans_changed)
        changed_accept = general_evaluate._accept_score(ref_changed, trans_changed, args)
        changed_semsim = avgcos(ref_emb_cha, trans_emb_cha)
        changed_use = general_evaluate._use_scores(ref_changed, trans_changed, args.use_model)
        # changed_use = 0.0

    # print out src, spl, and nmt
    for i in range(len(ref_changed)):
        reference_changed = ref_changed[i]
        translation_changed = trans_changed[i]
        if orig_alphas is not None and len(orig_alphas) > 0:
            orig_alpha = orig_alphas[ids_changed[i]]
            reference_changed = [s + '('+'{:.3f}'.format(orig_alpha[ind][0])+')' for ind, s in enumerate(ref_changed[i])]
            trans_alpha = trans_alphas[ids_changed[i]]
            translation_changed = [s + '('+'{:.3f}'.format(trans_alpha[ind][0])+')' for ind, s in enumerate(trans_changed[i])]
        utils.print_out('Example ' + str(ids_changed[i]) + ': src:\t' + ' '.join(reference_changed) + '\t' + str(label_changed[i]))
        utils.print_out('Example ' + str(ids_changed[i]) + ': nmt:\t' + ' '.join(translation_changed) + '\t' + str(logits_changed[i]))
        if copy_masks is not None and len(copy_masks) > 0:
            copy_mask = copy_masks[ids_changed[i]]
            copy_mask_str = [str(mask) for mask in copy_mask]
            utils.print_out('Example ' + str(ids_changed[i]) + ': msk:\t' + ' '.join(copy_mask_str))
        if X_adv_flip_num is not None:
            utils.print_out('Example ' + str(ids_changed[i]) + ' flipped tokens: ' + str(flip_num_changed[i]))
        utils.print_out(' ')

    if X_adv_flip_num is not None:
        num_flipped = sum(1 for num in X_adv_flip_num if num > 0)
        utils.print_out('Average flipped tokens: ' +
                        str(sum(X_adv_flip_num) / num_flipped))

    utils.print_out('Step: ' + str(step) + ', cls_acc_pos=' + str(cls_acc_pos) + ', cls_acc_neg=' + str(cls_acc_neg))
    utils.print_out('Step: ' + str(step) + ', rouge_pos=' + str(pos_rouge) + ', rouge_neg=' + str(neg_rouge) + ', rouge_changed=' + str(changed_rouge))
    utils.print_out('Step: ' + str(step) + ', bleu_pos=' + str(pos_bleu) + ', bleu_neg=' + str(neg_bleu) + ', bleu_changed=' + str(changed_bleu))
    utils.print_out('Step: ' + str(step) + ', accept_pos=' + str(pos_accept) + ', accept_neg=' + str(neg_accept) + ', accept_changed=' + str(changed_accept))
    utils.print_out('Step: ' + str(step) + ', semsim_pos=' + str(pos_semsim) + ', semsim_neg=' + str(neg_semsim) + ', semsim_changed=' + str(changed_semsim))
    utils.print_out('Step: ' + str(step) + ', use_pos=' + str(pos_use) + ', use_neg=' + str(neg_use) + ', use_changed=' + str(changed_use))
    utils.print_out('Step: ' + str(step) + ', ae_acc=' + str(ae_acc) + ', word_acc=' + str(word_acc) + ', rouge=' + str(rouge) + ', bleu=' + str(bleu) +
                    ', accept=' + str(accept) + ', use=' + str(use) + ', semsim=' + str(avgcos(sent_embs, adv_sent_embs)))
    utils.print_out('Step: ' + str(step) + ', cls_orig_acc=' + str(cls_orig_acc) + ', cls_orig_auc=' + str(cls_orig_auc))
    utils.print_out('Step: ' + str(step) + ', cls_acc=' + str(cls_acc) + ', cls_auc=' + str(cls_auc))
    if cls_logits_def is not None and len(cls_logits_def) > 0:
        utils.print_out('Step: ' + str(step) + ', org_def_acc=' + str(org_def_acc) + ', org_def_auc=' + str(org_def_auc))
        utils.print_out('Step: ' + str(step) + ', cls_def_acc=' + str(cls_def_acc) + ', cls_def_auc=' + str(cls_def_auc))

    if is_test:
        with open(args.output_dir+'/src_changed.txt', 'w') as output_file:
            output_file.write('\n'.join([' '.join(a) for a in ref_changed]))
        with open(args.output_dir+'/adv_changed.txt', 'w') as output_file:
            output_file.write('\n'.join([' '.join(a) for a in trans_changed]))

        with open(args.output_dir+'/adv.txt', 'w') as output_file:
            output_file.write('\n'.join([' '.join(a) for a in translation_list]))
        with open(args.output_dir+'/adv_score.txt', 'w') as output_file:
            for score in cls_logits:
                output_file.write(' '.join([str(a) for a in score])+'\n')

    return cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu
Example #7
File: train.py Project: IBM/adv-def-text
def test_adv_pos_neg(args):
    data_task = 'adv'
    if args.ae_vocab_file is not None: data_task = 'adv_counter_fitting'
    vocab, _ = input_data.load_vocab(args.vocab_file)
    ae_vocab, _ = (
        args.ae_vocab_file,
        None) if args.ae_vocab_file is None else input_data.load_vocab(
            args.ae_vocab_file)
    args.stop_words = setStopWord(
        vocab) if args.ae_vocab_file is None else setStopWord(ae_vocab)
    args.vocab_map = None if args.ae_vocab_file is None else (
        maping_vocabs_bert(ae_vocab, vocab) if args.classification_model
        == 'BERT' else maping_vocabs(ae_vocab, vocab))

    test_iter = input_data.get_dataset_iter(
        args,
        args.test_file,
        args.test_output,
        data_task,
        is_training=False,
        is_test=True,
        is_bert=(args.classification_model == 'BERT'))
    test_next = test_iter.get_next()

    step = 0
    model_test = AdversarialModelCopy(args, mode="Infer", include_cls=False)
    model_classifier = AdversarialModelCopy(
        args,
        mode="Infer",
        include_ae=False,
        embedding=model_test.get_embedding())

    utils.print_out('Testing model constructed.')

    saver = tf.train.Saver()
    sess_pos = tf.Session()
    sess_pos.run([tf.global_variables_initializer(), tf.tables_initializer()])
    sess_pos.run(test_iter.initializer)

    saver.restore(sess_pos, args.load_model_pos)

    _, _, dev_logits, dev_labels = eval_steps.run_classification(
        args, model_classifier, vocab, sess_pos, test_iter, test_next)

    target_predicts = np.argmax(dev_logits, axis=-1)
    pos_mask = (dev_labels[:, 1] == 1)
    neg_mask = (dev_labels[:, 1] == 0)
    correct_mask = (target_predicts == dev_labels[:, 1])
    pos_predict_mask = pos_mask & correct_mask
    neg_predict_mask = neg_mask & correct_mask
    keep_orig_mask = 1 - correct_mask
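    # The masks decide, per example, which output to keep below: the attack from the
    # positive- or negative-target model when the classifier was correct, or the
    # original sentence (keep_orig_mask) when it was already misclassified.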

    sess_neg = tf.Session()
    sess_neg.run([tf.global_variables_initializer(), tf.tables_initializer()])
    saver.restore(sess_neg, args.load_model_neg)

    if args.use_model is not None:
        args.use_model.set_sess(sess_pos)

    start = default_timer()
    decoder_reference_list, decoder_prediction_list_pos, cls_labels, copy_masks_pos = \
        eval_steps.run_adv(args, model_test, sess_pos, test_iter, test_next)
    decoder_reference_list, decoder_prediction_list_neg, cls_labels, copy_masks_neg = \
        eval_steps.run_adv(args, model_test, sess_neg, test_iter, test_next)

    decoder_prediction_list = []
    for i in range((len(decoder_reference_list) // args.batch_size) + (1 if (
            len(decoder_reference_list) % args.batch_size > 0) else 0)):
        # use batch_start/batch_end so the timer's `start` above is not clobbered
        batch_start, batch_end = i * args.batch_size, (i + 1) * args.batch_size
        decoder_prediction_batch = np.array(decoder_reference_list[batch_start:batch_end]) * np.expand_dims(keep_orig_mask[batch_start:batch_end], axis=1) + \
                              np.array(decoder_prediction_list_pos[batch_start:batch_end]) * np.expand_dims(pos_predict_mask[batch_start:batch_end], axis=1) \
                              + np.array(decoder_prediction_list_neg[batch_start:batch_end]) * np.expand_dims(neg_predict_mask[batch_start:batch_end], axis=1)
        decoder_prediction_list.extend(decoder_prediction_batch)

    end = default_timer()
    utils.print_out('Adversarial attack elapsed:' +
                    '{0:.4f}'.format(end - start) + 's')

    cls_logits_def, cls_origs_def, cls_logits, cls_orig_logits, sent_embs, adv_sent_embs, \
           orig_alphas, trans_alphas, trans_alphas_def = \
        eval_steps.run_classifications(args, sess_pos, test_iter, decoder_prediction_list, model_classifier, test_next)

    evaluate_attack(args,
                    step,
                    decoder_reference_list,
                    decoder_prediction_list,
                    cls_logits,
                    cls_orig_logits,
                    cls_labels,
                    vocab,
                    sent_embs,
                    adv_sent_embs,
                    is_test=True,
                    orig_alphas=orig_alphas,
                    trans_alphas=trans_alphas,
                    cls_logits_def=cls_logits_def,
                    cls_origs_def=cls_origs_def)
Example #8
File: train.py Project: IBM/adv-def-text
def test(args):
    data_task = 'ae'
    if args.classification: data_task = 'clss'
    if args.adv: data_task = 'adv'
    if args.ae_vocab_file is not None: data_task = 'adv_counter_fitting'

    vocab, _ = input_data.load_vocab(args.vocab_file)
    ae_vocab, _ = (
        args.ae_vocab_file,
        None) if args.ae_vocab_file is None else input_data.load_vocab(
            args.ae_vocab_file)
    args.stop_words = setStopWord(
        vocab) if args.ae_vocab_file is None else setStopWord(ae_vocab)
    args.vocab_map = None if args.ae_vocab_file is None else (
        maping_vocabs_bert(ae_vocab, vocab) if args.classification_model
        == 'BERT' else maping_vocabs(ae_vocab, vocab))

    test_iter = input_data.get_dataset_iter(
        args,
        args.test_file,
        args.test_output,
        data_task,
        is_training=False,
        is_test=True,
        is_bert=(args.classification_model == 'BERT'))
    test_next = test_iter.get_next()

    step = 0
    if args.adv:
        model_test = AdversarialModelCopy(args,
                                          mode="Infer",
                                          include_cls=False)
        model_classifier = AdversarialModelCopy(
            args,
            mode="Infer",
            include_ae=False,
            embedding=model_test.get_embedding())
    elif args.classification:
        if args.classification_model == 'RNN':
            utils.print_out('Initialise classification model: RNN')
            model_test = ClassificationModel(args, mode='Train')
        elif args.classification_model == 'CNN':
            utils.print_out('Initialise classification model: CNN')
            model_test = CNNClassificationModel(args, mode='Train')
        elif args.classification_model == 'BERT':
            bert_config = modeling.BertConfig.from_json_file(
                args.bert_config_file)
            model_test = BertClassificationModel(args,
                                                 bert_config,
                                                 mode='Test')
    else:
        model_test = Seq2SeqModel(args, mode="Infer")

    utils.print_out('Testing model constructed.')

    saver = tf.train.Saver()

    with tf.Session() as sess:

        if args.classification and args.use_defending_as_target:
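            # use_defending_as_target: load the checkpoint's defending_classifier/*
            # weights into this graph's my_classifier/* variables by remapping names.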
            vars = [i[0] for i in tf.train.list_variables(args.load_model)]
            def_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             scope='my_classifier')
            map_def = {
                variable.op.name.replace('my_classifier',
                                         'defending_classifier'): variable
                for variable in def_var_list if variable.op.name.replace(
                    'my_classifier', 'defending_classifier') in vars
            }
            tf.train.init_from_checkpoint(args.load_model, map_def)

        if args.adv or (args.classification
                        and not args.use_defending_as_target):
            vars = [i[0] for i in tf.train.list_variables(args.load_model)]
            var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            map_all = {
                variable.op.name: variable
                for variable in var_list if variable.op.name in vars
            }
            tf.train.init_from_checkpoint(args.load_model, map_all)

        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run(test_iter.initializer)

        # if args.adv or (args.classification and not args.defending):
        #     saver.restore(sess, args.load_model)

        if args.use_model is not None:
            args.use_model.set_sess(sess)

        if args.adv:
            eval_steps.eval_adv(args,
                                sess,
                                test_iter,
                                model_test,
                                model_classifier,
                                test_next,
                                vocab,
                                step,
                                demo_per_step=1,
                                is_train=False)

        elif args.classification:
            auc, acc = eval_steps.eval_classification(args, sess, test_iter,
                                                      test_next, model_test,
                                                      vocab)
            utils.print_out('Test: acc=' + str(acc) + ', auc=' + str(auc))

        else:
            acc, word_acc, rouge, bleu = eval_steps.eval_ae(sess,
                                                            test_iter,
                                                            model_test,
                                                            test_next,
                                                            vocab,
                                                            step,
                                                            demo_per_step=1)
            utils.print_out('Test: acc=' + str(acc) + ', word_acc=' +
                            str(word_acc) + ', rouge=' + str(rouge) +
                            ', bleu=' + str(bleu))
Example #9
File: train.py Project: IBM/adv-def-text
def train(args):
    vocab, _ = input_data.load_vocab(args.vocab_file)
    ae_vocab, _ = (
        args.ae_vocab_file,
        None) if args.ae_vocab_file is None else input_data.load_vocab(
            args.ae_vocab_file)
    args.stop_words = setStopWord(
        vocab) if args.ae_vocab_file is None else setStopWord(ae_vocab)
    args.vocab_map = None if args.ae_vocab_file is None else (
        maping_vocabs_bert(ae_vocab, vocab) if args.classification_model
        == 'BERT' else maping_vocabs(ae_vocab, vocab))

    train_iter, train_next, dev_iter, dev_next = load_data_iters(args)

    model, model_dev, model_classifier = create_models(args)

    utils.print_out('Training model constructed.')
    saver = tf.train.Saver(max_to_keep=args.max_to_keep)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # initialise models with pretrained weights
        init_model(args)

        if args.use_model is not None:
            args.use_model.set_sess(sess)

        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run(train_iter.initializer)

        if args.load_model_cls is not None and args.classification_model != 'BERT' and (
                not args.use_defending_as_target):
            saver_cls = tf.train.Saver(var_list=tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope='my_classifier'))
            saver_cls.restore(sess, args.load_model_cls)

        if args.adv and args.load_model is not None:
            saver_all = tf.train.Saver(
                var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
            saver_all.restore(sess, args.load_model)

        utils.print_out('start training...')
        tf.get_default_graph().finalize()

        # init best infos
        step = 0
        last_improvement_step = -1
        best_loss = 1e6
        best_auc_max = -1
        best_auc_min = 1e6
        best_T = [100.0] * 9

        upper_bounds = [95.0, 88.0, 78.0, 64.0, 54.0, 44.0, 34.0, 24.0, 14.0]
        lower_bounds = [90.0, 84.0, 74.0, 58.0, 48.0, 38.0, 28.0, 18.0, 8.0]
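        # Attack-strength bands (post-attack classifier accuracy, in %); one best
        # checkpoint per band is kept when save_checkpoints is enabled.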

        while True:

            try:
                batch = sess.run(train_next)
                if args.copy:
                    copy_mask = eval_steps.get_copy_mask(
                        sess,
                        model,
                        batch,
                        np.max(batch[5]),
                        n_top_k=args.top_k_attack)
                    batch = batch + (copy_mask, )

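                # Alternating training: full_loss_step is True every at_steps steps;
                # on other steps (with args.defending) the run is treated as a
                # defending update (see the step_name logic below).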
                results = sess.run(model.make_train_outputs(
                    full_loss_step=(step % args.at_steps == 0),
                    defence=args.defending),
                                   feed_dict=model.make_train_inputs(
                                       batch))  # Alternative training

            except tf.errors.OutOfRangeError:
                break

            if step % args.print_every_steps == 0:
                step_name = 'train'
                if args.defending and step % args.at_steps > 0:
                    step_name = 'defending'
                utils.print_out(
                    'Step: ' + str(step) + ', ' + step_name + ' loss=' +
                    str(results[1]) +
                    (', ae_loss=' + str(results[4]) + ', cls_loss=' +
                     str(results[5]) if len(results) > 5 else '') +
                    (', senti_loss=' + str(results[7]) + ', aux_loss=' +
                     str(results[6]) + ', def_loss=' +
                     str(results[8]) if len(results) > 6 else '') +
                    (' *' if (results[1] < best_loss) else ''))

                if (results[1] < best_loss):
                    best_loss = results[1]

                if step % (10 * args.print_every_steps) == 0:

                    if args.adv:
                        cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu = eval_steps.eval_adv(
                            args, sess, dev_iter, model_dev, model_classifier,
                            dev_next, vocab, step, 10 * args.print_every_steps)

                        eval_score = cls_acc
                        eval_bleu = changed_bleu
                        if args.target_label is not None:
                            eval_score = cls_acc_neg if args.target_label == 1 else cls_acc_pos
                            # eval_bleu = neg_bleu if args.target_label == 1 else pos_bleu

                        # use accuracy as best selection measure in each threshold
                        thre_score = cls_acc

                        if args.save_checkpoints:
                            for i in range(9):
                                if eval_score >= lower_bounds[
                                        i] and eval_score < upper_bounds[i]:
                                    if thre_score < best_T[i]:
                                        best_T[i] = thre_score
                                        saver.save(
                                            sess, args.output_dir + '/' +
                                            'nmt-T' + str(i) + '.ckpt')
                                        utils.print_out('Step: ' + str(step) +
                                                        ' model saved for T' +
                                                        str(i) + ' *')

                        if (eval_score <=
                                0.0) or eval_score >= args.lowest_bound_score:
                            if eval_score < best_auc_min:
                                best_auc_min = eval_score
                                last_improvement_step = step
                        else:
                            break

                    elif args.classification:
                        auc, acc = eval_steps.eval_classification(
                            args, sess, dev_iter, dev_next, model, vocab)
                        utils.print_out('Step: ' + str(step) + ', test acc=' +
                                        str(acc) + ', auc=' + str(auc))
                        eval_score = acc
                        if args.output_classes > 2:
                            eval_score = acc
                        if eval_score > best_auc_max:
                            best_auc_max = eval_score
                            last_improvement_step = step
                            saver.save(sess,
                                       args.output_dir + '/' + 'nmt.ckpt')
                    else:
                        acc, word_acc, rouge, bleu = eval_steps.eval_ae(
                            sess, dev_iter, model_dev, dev_next, vocab, step,
                            50 * args.print_every_steps)
                        utils.print_out('Step: ' + str(step) + ', test acc=' +
                                        str(acc) + ', word_acc=' +
                                        str(word_acc) + ', rouge=' +
                                        str(rouge) + ', bleu=' + str(bleu))
                        if bleu > best_auc_max:
                            best_auc_max = bleu
                            last_improvement_step = step
                            saver.save(sess,
                                       args.output_dir + '/' + 'nmt.ckpt')

            if args.total_steps is None:
                if step - last_improvement_step > args.stop_steps:
                    break
            else:
                if step >= args.total_steps:
                    break

            step += 1

        utils.print_out('finish training')

        if args.do_test:
            args.load_model = args.output_dir + '/' + 'nmt.ckpt-' + str(step)