def create_models(args):
    """Construct (train_model, infer_model, classifier_model) for the run mode.

    Depending on the flags, some slots stay None: classification runs build
    only the training model; plain seq2seq runs build no classifier.
    """
    model, model_dev, model_classifier = None, None, None
    if args.adv:
        # Adversarial AE + classifier; infer-time copies share the train
        # model's embedding so all three graphs use one embedding matrix.
        if args.classification_model == 'BERT':
            args.bert_config = modeling.BertConfig.from_json_file(
                args.bert_config_file)
        model = AdversarialModelCopy(args, mode='Train')
        model_dev = AdversarialModelCopy(args, mode="Infer",
                                         include_cls=False,
                                         embedding=model.get_embedding())
        model_classifier = AdversarialModelCopy(args, mode="Infer",
                                                include_ae=False,
                                                embedding=model.get_embedding())
    elif args.classification:
        if args.classification_model == 'RNN':
            utils.print_out('Initialise classification model: RNN')
            model = ClassificationModel(args, 'Train')
        elif args.classification_model == 'CNN':
            utils.print_out('Initialise classification model: CNN')
            model = CNNClassificationModel(args, mode='Train')
        elif args.classification_model == 'BERT':
            bert_config = modeling.BertConfig.from_json_file(
                args.bert_config_file)
            model = BertClassificationModel(args, bert_config, mode='Train')
            modeling.init_bert(args.bert_init_chk,
                               word_embedding_trainable=True)
    else:
        # Plain autoencoder: a training graph plus an inference graph.
        model = Seq2SeqModel(args, mode="Train")
        model_dev = Seq2SeqModel(args, mode="Infer")
    return model, model_dev, model_classifier
def eval_ae(sess, dev_iter, model_dev, dev_next, vocab, step, demo_per_step):
    """Evaluate the autoencoder on up to ~31 dev batches.

    Decodes each batch with `model_dev`, converts id sequences back to
    sentences, optionally prints sample pairs, and scores the decodes.

    Returns:
        (acc, word_acc, rouge, bleu) from the `evaluate` helpers.
    """
    sess.run(dev_iter.initializer)
    decoder_reference_list = []
    decoder_prediction_list = []
    i = 0
    while True:
        try:
            dev_batch = sess.run(dev_next)
            decoder_predictions_inference = sess.run(
                model_dev.make_infer_outputs(),
                feed_dict=model_dev.make_train_inputs(dev_batch))
            # dev_batch[2] holds the reference id sequences — assumed from the
            # dataset tuple layout used throughout this file; TODO confirm.
            decoder_reference_list.append(dev_batch[2])
            decoder_prediction_list.append(decoder_predictions_inference)
            if i >= 30:  # cap eval cost on large dev sets
                break
            i += 1
        except tf.errors.OutOfRangeError:
            break
    reference_list, translation_list = [], []
    for decoder_references in decoder_reference_list:
        reference_list.extend(getSentencesFromIDs(decoder_references, vocab))
    for decoder_predictions in decoder_prediction_list:
        translation_list.extend(getSentencesFromIDs(decoder_predictions, vocab))
    if demo_per_step == 1:
        for i in range(len(reference_list)):
            utils.print_out('Example ' + str(i) + ': src: ' +
                            ' '.join(reference_list[i]))
            utils.print_out('Example ' + str(i) + ': nmt: ' +
                            ' '.join(translation_list[i]))
    elif step % demo_per_step == 0:
        # BUGFIX: random.randint(1, len-1) raised ValueError when only one
        # sentence was decoded and could never sample example 0;
        # randrange covers the full [0, len) range.
        rand_ind = random.randrange(len(reference_list))
        utils.print_out('Step: ' + str(step) + ', src: ' +
                        ' '.join(reference_list[rand_ind]))
        utils.print_out('Step: ' + str(step) + ', nmt: ' +
                        ' '.join(translation_list[rand_ind]))
    acc = evaluate._accuracy(reference_list, translation_list)
    word_acc = evaluate._word_accuracy(reference_list, translation_list)
    rouge = evaluate._rouge(reference_list, translation_list)
    bleu = evaluate._bleu(reference_list, translation_list)
    return acc, word_acc, rouge, bleu
def printSentence(tokenized_sentences, vocab):
    """Print a sequence of token ids as words looked up in `vocab`.

    Output format (word + trailing space per token) matches the original.
    """
    # join is linear; the previous per-token `+=` concatenation was quadratic.
    utils.print_out(''.join(vocab[word] + ' ' for word in tokenized_sentences))
def eval_classification(args, sess, dev_iter, dev_next, model, vocab):
    """Run the classifier over the dev/test iterator and report AUC/accuracy.

    Also writes misclassified examples to `<output_dir>/false_positive.txt`
    and `<output_dir>/false_negative.txt`.

    Returns:
        (auc, acc) from `evaluate._clss_auc` / `evaluate._clss_accuracy`.
    """
    sess.run(dev_iter.initializer)
    dev_srcs = []
    dev_labels = []
    dev_logits = []
    alphas = []
    count = 0
    while True:
        try:
            dev_batch = sess.run(dev_next)
            results = sess.run(model.make_test_outputs(),
                               feed_dict=model.make_train_inputs(dev_batch))
            dev_srcs.append(results[-2])
            dev_labels.append(results[-1])
            dev_logits.append(results[1])
            if args.cls_attention:
                alphas.append(results[3])  # attention weights, if enabled
            count += 1
        except tf.errors.OutOfRangeError:
            break
    reference_list = []
    alphas_list = []
    for ind, decoder_references in enumerate(dev_srcs):
        reference_list.extend(getSentencesFromIDs(decoder_references, vocab))
        if len(alphas) > 0:
            alphas_list.extend(alphas[ind])
    dev_logits = np.concatenate(dev_logits, axis=0)
    dev_labels = np.concatenate(dev_labels, axis=0)
    # BUGFIX: random.randint(1, len-1) crashed when only one example was
    # decoded and could never sample example 0; randrange covers [0, len).
    rand_ind = random.randrange(len(reference_list))
    utils.print_out('src : ' + str(reference_list[rand_ind]) + ', label: ' +
                    str(dev_labels[rand_ind]))
    if len(alphas_list) > 0:
        utils.print_out('alpha: ' + ', '.join(
            ["{:.3f}".format(a[0]) for a in alphas_list[rand_ind]]))
    # Print: False positives, false negatives.  `with` guarantees the files
    # are closed even if evaluation raises (plain open()/close() leaked them).
    with open(args.output_dir + '/false_positive.txt', 'w') as fp_spl, \
            open(args.output_dir + '/false_negative.txt', 'w') as fn_spl:
        for i in range(len(reference_list)):
            spl_predict = evaluate.max_index(dev_logits[i])
            if dev_labels[i][1] == 1:
                if spl_predict == 0:  # gold positive predicted negative
                    fn_spl.write('Example ' + str(i) + ': src:\t' +
                                 ' '.join(reference_list[i]) + '\t: ' +
                                 str(dev_labels[i]) + '\n')
                    fn_spl.write('Example ' + str(i) + ': spl:\t' +
                                 ' '.join(reference_list[i]) + '\t: ' +
                                 str(dev_logits[i]) + '\n')
                    fn_spl.write('\n')
            elif dev_labels[i][1] == 0:
                if spl_predict == 1:  # gold negative predicted positive
                    fp_spl.write('Example ' + str(i) + ': src:\t' +
                                 ' '.join(reference_list[i]) + '\t: ' +
                                 str(dev_labels[i]) + '\n')
                    fp_spl.write('Example ' + str(i) + ': spl:\t' +
                                 ' '.join(reference_list[i]) + '\t: ' +
                                 str(dev_logits[i]) + '\n')
                    fp_spl.write('\n')
    acc = evaluate._clss_accuracy(dev_labels, dev_logits)
    auc = evaluate._clss_auc(dev_labels, dev_logits)
    return auc, acc
def eval_adv(args, sess, dev_iter, model_dev, model_classifier, dev_next,
             vocab, step, demo_per_step, is_train=True):
    """Run the adversarial autoencoder over the dev/test iterator and score it.

    Decodes each batch with `model_dev`, classifies both the adversarial
    decode and the original input with `model_classifier` (plus the defending
    classifier when `args.defending`), then delegates all metric computation
    and printing to `evaluate_attack`.

    Returns:
        (cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu) from evaluate_attack.
    """
    sess.run(dev_iter.initializer)
    # sentiment_distances = []
    cls_labels = []
    cls_logits_def = []        # defending classifier on adversarial text
    cls_origs_def = []         # defending classifier on original text
    cls_logits = []            # target classifier on adversarial text
    cls_orig_logits = []       # target classifier on original text
    decoder_reference_list = []
    decoder_prediction_list = []
    sent_embs = []             # sentence embeddings of originals
    adv_sent_embs = []         # sentence embeddings of adversarial decodes
    orig_alphas, trans_alphas = [], []
    trans_alphas_def = []
    copy_masks = []
    i = 0
    start = default_timer()
    while True:
        try:
            dev_batch = sess.run(dev_next)
            # NOTE(review): dev_batch appears to be a tuple where [2] is the
            # source ids, [3] the labels, and [5] the lengths — inferred from
            # usage below; confirm against the dataset pipeline.
            if args.copy:
                # Append a copy mask limiting which positions may be attacked.
                copy_mask = get_copy_mask(sess, model_dev, dev_batch,
                                          np.max(dev_batch[5]),
                                          args.top_k_attack)
                dev_batch = dev_batch + (copy_mask, )
            results = sess.run(
                model_dev.make_infer_outputs(),
                feed_dict=model_dev.make_train_inputs(dev_batch))
            cls_labels.extend(dev_batch[3])
            if args.beam_width > 0:
                # Beam search returns per-beam outputs; keep the top beam.
                decoder_outputs = results[0][0]
            else:
                decoder_outputs = results[0]
            if args.ae_vocab_file is not None:
                # Map AE-vocabulary ids back into the classifier vocabulary.
                decoder_outputs = lookup(decoder_outputs, args.vocab_map)
            decoder_reference_list.extend(dev_batch[2])
            if args.copy:
                copy_masks.extend(results[1])
            decoder_prediction_list.extend(decoder_outputs)
            # 102/101 are BERT's [SEP]/[CLS] token ids.
            decoder_preds_lengths = get_lengths(
                decoder_outputs,
                eos_id=102 if args.classification_model == 'BERT'
                else input_data.EOS_ID)
            decoder_outputs = padding(
                decoder_outputs,
                eos_id=102 if args.classification_model == 'BERT'
                else input_data.EOS_ID)
            if args.classification_model == 'BERT':
                # Prepend [CLS] to every decoded sequence for BERT input.
                decoder_outputs = np.concatenate(
                    [np.array([[101]] * len(decoder_outputs)),
                     decoder_outputs], axis=1)
            # classification based on decoder_predictions_inference
            cls_logit = sess.run(
                model_classifier.make_classifier_outputs(),
                feed_dict=model_classifier.make_classifier_input(
                    dev_batch, decoder_outputs, decoder_preds_lengths))
            cls_logits.extend(cls_logit[0])
            adv_sent_embs.extend(cls_logit[1])
            if args.cls_attention:
                trans_alphas.extend(cls_logit[2])
            # classification based on original input
            cls_orig_logit = sess.run(
                model_classifier.make_classifier_outputs(),
                feed_dict=model_classifier.make_classifier_input(
                    dev_batch, dev_batch[2], dev_batch[5]))
            cls_orig_logits.extend(cls_orig_logit[0])
            sent_embs.extend(cls_orig_logit[1])
            if args.cls_attention:
                orig_alphas.extend(cls_orig_logit[2])
            # defending classification based on decoder_predictions_inference
            if args.defending:
                cls_logit_def = sess.run(
                    model_classifier.make_def_classifier_outputs(),
                    feed_dict=model_classifier.make_classifier_input(
                        dev_batch, decoder_outputs, decoder_preds_lengths))
                cls_logits_def.extend(cls_logit_def[0])
                if args.cls_attention:
                    trans_alphas_def.extend(cls_logit_def[2])
                cls_orig_def = sess.run(
                    model_classifier.make_def_classifier_outputs(),
                    feed_dict=model_classifier.make_classifier_input(
                        dev_batch, dev_batch[2], dev_batch[5]))
                cls_origs_def.extend(cls_orig_def[0])
            # During training only a ~31-batch sample is evaluated; at test
            # time the loop runs until the iterator is exhausted.
            if is_train and i >= 30:
                # if i >= 0:
                break
            i += 1
        except tf.errors.OutOfRangeError:
            break
    end = default_timer()
    if not is_train:
        utils.print_out('Adversarial attack elapsed:' +
                        '{0:.4f}'.format(end - start) + 's')
    cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu = evaluate_attack(
        args, step, decoder_reference_list, decoder_prediction_list,
        cls_logits, cls_orig_logits, cls_labels, vocab, sent_embs,
        adv_sent_embs, is_test=(not is_train), orig_alphas=orig_alphas,
        trans_alphas=trans_alphas, cls_logits_def=cls_logits_def,
        cls_origs_def=cls_origs_def, copy_masks=copy_masks)
    return cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu
def evaluate_attack(args, step, decoder_reference_list, decoder_prediction_list,
                    cls_logits, cls_orig_logits, cls_labels, vocab, sent_embs,
                    adv_sent_embs, is_test=False, X_adv_flip_num=None,
                    orig_alphas=None, trans_alphas=None, cls_logits_def=None,
                    cls_origs_def=None, copy_masks=None):
    """Score an adversarial attack and print/dump the results.

    Computes classifier metrics on original vs adversarial text, text-quality
    metrics (ROUGE/BLEU/USE/acceptability/semantic similarity) overall and per
    subset (gold-positive, gold-negative, prediction-flipped), prints every
    flipped example, and — when `is_test` — writes sources, adversarial
    outputs and scores under `args.output_dir`.

    Returns:
        (cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu); the changed_*
        metrics are -1.0 when no example flipped the classifier's prediction.
    """
    cls_orig_acc = general_evaluate._clss_accuracy(cls_labels, cls_orig_logits)
    cls_orig_auc = general_evaluate._clss_auc(cls_labels, cls_orig_logits)
    cls_acc = general_evaluate._clss_accuracy(cls_labels, cls_logits)
    cls_auc = general_evaluate._clss_auc(cls_labels, cls_logits)
    cls_acc_pos = general_evaluate._clss_accuracy_micro(cls_labels, cls_logits,
                                                        orig_label=1)
    cls_acc_neg = general_evaluate._clss_accuracy_micro(cls_labels, cls_logits,
                                                        orig_label=0)
    has_def = cls_logits_def is not None and len(cls_logits_def) > 0
    if has_def:
        cls_def_acc = general_evaluate._clss_accuracy(cls_labels,
                                                      cls_logits_def)
        cls_def_auc = general_evaluate._clss_auc(cls_labels, cls_logits_def)
        org_def_acc = general_evaluate._clss_accuracy(cls_labels,
                                                      cls_origs_def)
        org_def_auc = general_evaluate._clss_auc(cls_labels, cls_origs_def)
    reference_list = getSentencesFromIDs(decoder_reference_list, vocab)
    translation_list = getSentencesFromIDs(decoder_prediction_list, vocab)
    # Bucket examples by gold label and collect the ones whose classifier
    # prediction flipped under the attack.
    ref_pos, ref_neg, trans_pos, trans_neg = [], [], [], []
    ref_changed, trans_changed = [], []
    label_changed, logits_changed, flip_num_changed, ids_changed = [], [], [], []
    ref_emb_pos, trans_emb_pos = [], []
    ref_emb_neg, trans_emb_neg = [], []
    ref_emb_cha, trans_emb_cha = [], []
    for ind, references in enumerate(reference_list):
        # IDIOM: the original used conditional expressions purely for their
        # side effects (`a.append(x) if c else b.append(x)`); plain if/else
        # does the same thing readably.
        if cls_labels[ind][1] > 0:
            ref_pos.append(references)
            trans_pos.append(translation_list[ind])
            ref_emb_pos.append(sent_embs[ind])
            trans_emb_pos.append(adv_sent_embs[ind])
        else:
            ref_neg.append(references)
            trans_neg.append(translation_list[ind])
            ref_emb_neg.append(sent_embs[ind])
            trans_emb_neg.append(adv_sent_embs[ind])
        if np.argmax(cls_logits[ind]) != np.argmax(cls_orig_logits[ind]):
            ids_changed.append(ind)
            ref_changed.append(references)
            trans_changed.append(translation_list[ind])
            label_changed.append(cls_labels[ind])
            logits_changed.append(cls_logits[ind])
            ref_emb_cha.append(sent_embs[ind])
            trans_emb_cha.append(adv_sent_embs[ind])
            if X_adv_flip_num is not None:
                flip_num_changed.append(X_adv_flip_num[ind])
    # Reconstruction / similarity metrics over all examples.
    ae_acc = general_evaluate._accuracy(reference_list, translation_list)
    word_acc = general_evaluate._word_accuracy(reference_list,
                                               translation_list)
    rouge = general_evaluate._rouge(reference_list, translation_list)
    bleu = general_evaluate._bleu(reference_list, translation_list)
    use = general_evaluate._use_scores(reference_list, translation_list,
                                       args.use_model)
    accept = general_evaluate._accept_score(reference_list, translation_list,
                                            args)
    # positive examples
    pos_rouge = general_evaluate._rouge(ref_pos, trans_pos)
    pos_bleu = general_evaluate._bleu(ref_pos, trans_pos)
    pos_accept = general_evaluate._accept_score(ref_pos, trans_pos, args)
    pos_semsim = avgcos(ref_emb_pos, trans_emb_pos)
    pos_use = general_evaluate._use_scores(ref_pos, trans_pos, args.use_model)
    # negative examples
    neg_rouge = general_evaluate._rouge(ref_neg, trans_neg)
    neg_bleu = general_evaluate._bleu(ref_neg, trans_neg)
    neg_accept = general_evaluate._accept_score(ref_neg, trans_neg, args)
    neg_semsim = avgcos(ref_emb_neg, trans_emb_neg)
    neg_use = general_evaluate._use_scores(ref_neg, trans_neg, args.use_model)
    # changed examples: -1.0 sentinels when nothing flipped
    if len(ref_changed) == 0:
        changed_rouge = -1.0
        changed_bleu = -1.0
        changed_accept = -1.0
        changed_semsim = -1.0
        changed_use = -1.0
    else:
        changed_rouge = general_evaluate._rouge(ref_changed, trans_changed)
        changed_bleu = general_evaluate._bleu(ref_changed, trans_changed)
        changed_accept = general_evaluate._accept_score(ref_changed,
                                                        trans_changed, args)
        changed_semsim = avgcos(ref_emb_cha, trans_emb_cha)
        changed_use = general_evaluate._use_scores(ref_changed, trans_changed,
                                                   args.use_model)
        # changed_use = 0.0
    # print out src, spl, and nmt for every flipped example
    for i in range(len(ref_changed)):
        reference_changed = ref_changed[i]
        translation_changed = trans_changed[i]
        if orig_alphas is not None and len(orig_alphas) > 0:
            # Annotate each token with its attention weight.
            orig_alpha = orig_alphas[ids_changed[i]]
            reference_changed = [
                s + '(' + '{:.3f}'.format(orig_alpha[ind][0]) + ')'
                for ind, s in enumerate(ref_changed[i])
            ]
            trans_alpha = trans_alphas[ids_changed[i]]
            translation_changed = [
                s + '(' + '{:.3f}'.format(trans_alpha[ind][0]) + ')'
                for ind, s in enumerate(trans_changed[i])
            ]
        utils.print_out('Example ' + str(ids_changed[i]) + ': src:\t' +
                        ' '.join(reference_changed) + '\t' +
                        str(label_changed[i]))
        utils.print_out('Example ' + str(ids_changed[i]) + ': nmt:\t' +
                        ' '.join(translation_changed) + '\t' +
                        str(logits_changed[i]))
        if copy_masks is not None and len(copy_masks) > 0:
            copy_mask = copy_masks[ids_changed[i]]
            copy_mask_str = [str(mask) for mask in copy_mask]
            utils.print_out('Example ' + str(ids_changed[i]) + ': msk:\t' +
                            ' '.join(copy_mask_str))
        if X_adv_flip_num is not None:
            utils.print_out('Example ' + str(ids_changed[i]) +
                            ' flipped tokens: ' + str(flip_num_changed[i]))
        utils.print_out(' ')
    if X_adv_flip_num is not None:
        # BUGFIX: the counter (previously misspelled `lenght`) was used as a
        # divisor unconditionally, raising ZeroDivisionError when no example
        # had any flipped token; guard before dividing.
        flipped_count = sum(1 for num in X_adv_flip_num if num > 0)
        if flipped_count > 0:
            utils.print_out('Average flipped tokens: ' +
                            str(sum(X_adv_flip_num) / flipped_count))
    utils.print_out('Step: ' + str(step) + ', cls_acc_pos=' + str(cls_acc_pos) +
                    ', cls_acc_neg=' + str(cls_acc_neg))
    utils.print_out('Step: ' + str(step) + ', rouge_pos=' + str(pos_rouge) +
                    ', rouge_neg=' + str(neg_rouge) + ', rouge_changed=' +
                    str(changed_rouge))
    utils.print_out('Step: ' + str(step) + ', bleu_pos=' + str(pos_bleu) +
                    ', bleu_neg=' + str(neg_bleu) + ', bleu_changed=' +
                    str(changed_bleu))
    utils.print_out('Step: ' + str(step) + ', accept_pos=' + str(pos_accept) +
                    ', accept_neg=' + str(neg_accept) + ', accept_changed=' +
                    str(changed_accept))
    utils.print_out('Step: ' + str(step) + ', semsim_pos=' + str(pos_semsim) +
                    ', semsim_neg=' + str(neg_semsim) + ', semsim_changed=' +
                    str(changed_semsim))
    utils.print_out('Step: ' + str(step) + ', use_pos=' + str(pos_use) +
                    ', use_neg=' + str(neg_use) + ', use_changed=' +
                    str(changed_use))
    utils.print_out('Step: ' + str(step) + ', ae_acc=' + str(ae_acc) +
                    ', word_acc=' + str(word_acc) + ', rouge=' + str(rouge) +
                    ', bleu=' + str(bleu) + ', accept=' + str(accept) +
                    ', use=' + str(use) + ', semsim=' +
                    str(avgcos(sent_embs, adv_sent_embs)))
    utils.print_out('Step: ' + str(step) + ', cls_orig_acc=' +
                    str(cls_orig_acc) + ', cls_orig_auc=' + str(cls_orig_auc))
    utils.print_out('Step: ' + str(step) + ', cls_acc=' + str(cls_acc) +
                    ', cls_auc=' + str(cls_auc))
    if has_def:
        utils.print_out('Step: ' + str(step) + ', org_def_acc=' +
                        str(org_def_acc) + ', org_def_auc=' + str(org_def_auc))
        utils.print_out('Step: ' + str(step) + ', cls_def_acc=' +
                        str(cls_def_acc) + ', cls_def_auc=' + str(cls_def_auc))
    if is_test:
        with open(args.output_dir + '/src_changed.txt', 'w') as output_file:
            output_file.write('\n'.join([' '.join(a) for a in ref_changed]))
        with open(args.output_dir + '/adv_changed.txt', 'w') as output_file:
            output_file.write('\n'.join([' '.join(a) for a in trans_changed]))
        with open(args.output_dir + '/adv.txt', 'w') as output_file:
            output_file.write('\n'.join([' '.join(a)
                                         for a in translation_list]))
        with open(args.output_dir + '/adv_score.txt', 'w') as output_file:
            for score in cls_logits:
                output_file.write(' '.join([str(a) for a in score]) + '\n')
    return cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu
def test_adv_pos_neg(args):
    """Attack with two checkpoints (positive/negative target) and merge.

    Runs the adversarial model once per checkpoint, then per example keeps:
    the positive-attack decode for correctly-classified positives, the
    negative-attack decode for correctly-classified negatives, and the
    original text for misclassified examples; finally scores the merged
    outputs with `evaluate_attack`.
    """
    data_task = 'adv'
    if args.ae_vocab_file is not None:
        data_task = 'adv_counter_fitting'
    vocab, _ = input_data.load_vocab(args.vocab_file)
    ae_vocab, _ = (args.ae_vocab_file, None) if args.ae_vocab_file is None \
        else input_data.load_vocab(args.ae_vocab_file)
    args.stop_words = setStopWord(vocab) if args.ae_vocab_file is None \
        else setStopWord(ae_vocab)
    args.vocab_map = None if args.ae_vocab_file is None else (
        maping_vocabs_bert(ae_vocab, vocab)
        if args.classification_model == 'BERT'
        else maping_vocabs(ae_vocab, vocab))
    test_iter = input_data.get_dataset_iter(
        args, args.test_file, args.test_output, data_task, is_training=False,
        is_test=True, is_bert=(args.classification_model == 'BERT'))
    test_next = test_iter.get_next()
    step = 0
    model_test = AdversarialModelCopy(args, mode="Infer", include_cls=False)
    model_classifier = AdversarialModelCopy(
        args, mode="Infer", include_ae=False,
        embedding=model_test.get_embedding())
    utils.print_out('Testing model constructed.')
    saver = tf.train.Saver()
    sess_pos = tf.Session()
    sess_pos.run([tf.global_variables_initializer(), tf.tables_initializer()])
    sess_pos.run(test_iter.initializer)
    saver.restore(sess_pos, args.load_model_pos)
    _, _, dev_logits, dev_labels = eval_steps.run_classification(
        args, model_classifier, vocab, sess_pos, test_iter, test_next)
    target_predicts = np.argmax(dev_logits, axis=-1)
    pos_mask = (dev_labels[:, 1] == 1)
    neg_mask = (dev_labels[:, 1] == 0)
    correct_mask = (target_predicts == dev_labels[:, 1])
    pos_predict_mask = pos_mask & correct_mask
    neg_predict_mask = neg_mask & correct_mask
    # Misclassified examples keep the original text (no attack needed).
    keep_orig_mask = 1 - correct_mask
    sess_neg = tf.Session()
    sess_neg.run([tf.global_variables_initializer(), tf.tables_initializer()])
    saver.restore(sess_neg, args.load_model_neg)
    if args.use_model is not None:
        args.use_model.set_sess(sess_pos)
    start = default_timer()
    decoder_reference_list, decoder_prediction_list_pos, cls_labels, copy_masks_pos = \
        eval_steps.run_adv(args, model_test, sess_pos, test_iter, test_next)
    decoder_reference_list, decoder_prediction_list_neg, cls_labels, copy_masks_neg = \
        eval_steps.run_adv(args, model_test, sess_neg, test_iter, test_next)
    decoder_prediction_list = []
    num_batches = (len(decoder_reference_list) // args.batch_size) + (
        1 if (len(decoder_reference_list) % args.batch_size > 0) else 0)
    for i in range(num_batches):
        # BUGFIX: these slice bounds were previously named `start`/`end`,
        # clobbering the wall-clock `start` captured above, so the elapsed
        # time below was computed against a batch index, not a timestamp.
        lo, hi = i * args.batch_size, (i + 1) * args.batch_size
        decoder_prediction_batch = \
            np.array(decoder_reference_list[lo:hi]) * \
            np.expand_dims(keep_orig_mask[lo:hi], axis=1) + \
            np.array(decoder_prediction_list_pos[lo:hi]) * \
            np.expand_dims(pos_predict_mask[lo:hi], axis=1) + \
            np.array(decoder_prediction_list_neg[lo:hi]) * \
            np.expand_dims(neg_predict_mask[lo:hi], axis=1)
        decoder_prediction_list.extend(decoder_prediction_batch)
    end = default_timer()
    utils.print_out('Adversarial attack elapsed:' +
                    '{0:.4f}'.format(end - start) + 's')
    cls_logits_def, cls_origs_def, cls_logits, cls_orig_logits, sent_embs, adv_sent_embs, \
        orig_alphas, trans_alphas, trans_alphas_def = \
        eval_steps.run_classifications(args, sess_pos, test_iter,
                                       decoder_prediction_list,
                                       model_classifier, test_next)
    evaluate_attack(args, step, decoder_reference_list,
                    decoder_prediction_list, cls_logits, cls_orig_logits,
                    cls_labels, vocab, sent_embs, adv_sent_embs, is_test=True,
                    orig_alphas=orig_alphas, trans_alphas=trans_alphas,
                    cls_logits_def=cls_logits_def, cls_origs_def=cls_origs_def)
def test(args):
    """Evaluate a trained model on the test set (adv / classification / AE)."""
    # Select the dataset task from the run-mode flags.
    data_task = 'ae'
    if args.classification:
        data_task = 'clss'
    if args.adv:
        data_task = 'adv'
        if args.ae_vocab_file is not None:
            data_task = 'adv_counter_fitting'
    vocab, _ = input_data.load_vocab(args.vocab_file)
    # Stop words and the AE->classifier vocab map depend on whether a
    # separate autoencoder vocabulary is in use.
    if args.ae_vocab_file is None:
        ae_vocab = args.ae_vocab_file
        args.stop_words = setStopWord(vocab)
        args.vocab_map = None
    else:
        ae_vocab, _ = input_data.load_vocab(args.ae_vocab_file)
        args.stop_words = setStopWord(ae_vocab)
        if args.classification_model == 'BERT':
            args.vocab_map = maping_vocabs_bert(ae_vocab, vocab)
        else:
            args.vocab_map = maping_vocabs(ae_vocab, vocab)
    test_iter = input_data.get_dataset_iter(
        args, args.test_file, args.test_output, data_task, is_training=False,
        is_test=True, is_bert=(args.classification_model == 'BERT'))
    test_next = test_iter.get_next()
    step = 0
    # Build the graph that matches the run mode.
    if args.adv:
        model_test = AdversarialModelCopy(args, mode="Infer",
                                          include_cls=False)
        model_classifier = AdversarialModelCopy(
            args, mode="Infer", include_ae=False,
            embedding=model_test.get_embedding())
    elif args.classification:
        if args.classification_model == 'RNN':
            utils.print_out('Initialise classification model: RNN')
            model_test = ClassificationModel(args, mode='Train')
        elif args.classification_model == 'CNN':
            utils.print_out('Initialise classification model: CNN')
            model_test = CNNClassificationModel(args, mode='Train')
        elif args.classification_model == 'BERT':
            bert_config = modeling.BertConfig.from_json_file(
                args.bert_config_file)
            model_test = BertClassificationModel(args, bert_config,
                                                 mode='Test')
    else:
        model_test = Seq2SeqModel(args, mode="Infer")
    utils.print_out('Testing model constructed.')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        if args.classification and args.use_defending_as_target:
            # Load 'my_classifier' weights from the checkpoint's
            # 'defending_classifier' scope.
            ckpt_var_names = [name for name, _ in
                              tf.train.list_variables(args.load_model)]
            def_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             scope='my_classifier')
            map_def = {}
            for variable in def_var_list:
                renamed = variable.op.name.replace('my_classifier',
                                                   'defending_classifier')
                if renamed in ckpt_var_names:
                    map_def[renamed] = variable
            tf.train.init_from_checkpoint(args.load_model, map_def)
        if args.adv or (args.classification and
                        not args.use_defending_as_target):
            # Restore every graph variable that exists in the checkpoint.
            ckpt_var_names = [name for name, _ in
                              tf.train.list_variables(args.load_model)]
            var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            map_all = {}
            for variable in var_list:
                if variable.op.name in ckpt_var_names:
                    map_all[variable.op.name] = variable
            tf.train.init_from_checkpoint(args.load_model, map_all)
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run(test_iter.initializer)
        # if args.adv or (args.classification and not args.defending):
        #     saver.restore(sess, args.load_model)
        if args.use_model is not None:
            args.use_model.set_sess(sess)
        if args.adv:
            eval_steps.eval_adv(args, sess, test_iter, model_test,
                                model_classifier, test_next, vocab, step,
                                demo_per_step=1, is_train=False)
        elif args.classification:
            auc, acc = eval_steps.eval_classification(
                args, sess, test_iter, test_next, model_test, vocab)
            utils.print_out('Test: acc=' + str(acc) + ', auc=' + str(auc))
        else:
            acc, word_acc, rouge, bleu = eval_steps.eval_ae(
                sess, test_iter, model_test, test_next, vocab, step,
                demo_per_step=1)
            utils.print_out('Test: acc=' + str(acc) + ', word_acc=' +
                            str(word_acc) + ', rouge=' + str(rouge) +
                            ', bleu=' + str(bleu))
def train(args):
    """Main training loop for AE / classification / adversarial modes.

    Builds the data iterators and models, optionally restores pretrained
    weights, then alternates training steps with periodic dev evaluation,
    checkpointing per-threshold bests (adv mode) or the single best model.
    """
    vocab, _ = input_data.load_vocab(args.vocab_file)
    ae_vocab, _ = (args.ae_vocab_file, None) if args.ae_vocab_file is None \
        else input_data.load_vocab(args.ae_vocab_file)
    args.stop_words = setStopWord(vocab) if args.ae_vocab_file is None \
        else setStopWord(ae_vocab)
    # Map AE-vocabulary ids onto the classifier vocabulary when they differ.
    args.vocab_map = None if args.ae_vocab_file is None else (
        maping_vocabs_bert(ae_vocab, vocab)
        if args.classification_model == 'BERT'
        else maping_vocabs(ae_vocab, vocab))
    train_iter, train_next, dev_iter, dev_next = load_data_iters(args)
    model, model_dev, model_classifier = create_models(args)
    utils.print_out('Training model constructed.')
    saver = tf.train.Saver(max_to_keep=args.max_to_keep)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # initialise models with pretrained weights
        init_model(args)
        if args.use_model is not None:
            args.use_model.set_sess(sess)
        sess.run([tf.global_variables_initializer(),
                  tf.tables_initializer()])
        sess.run(train_iter.initializer)
        # Restore a pretrained classifier (non-BERT) when provided.
        if args.load_model_cls is not None and \
                args.classification_model != 'BERT' and (
                    not args.use_defending_as_target):
            saver_cls = tf.train.Saver(var_list=tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope='my_classifier'))
            saver_cls.restore(sess, args.load_model_cls)
        # Restore the full adversarial model when continuing a run.
        if args.adv and args.load_model is not None:
            saver_all = tf.train.Saver(
                var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
            saver_all.restore(sess, args.load_model)
        utils.print_out('start training...')
        # No more graph mutations past this point.
        tf.get_default_graph().finalize()
        # init best infos
        step = 0
        last_improvement_step = -1
        best_loss = 1e6
        best_auc_max = -1
        best_auc_min = 1e6
        # Per-threshold best scores; a checkpoint is kept for each attack
        # accuracy band defined by (lower_bounds[i], upper_bounds[i]).
        best_T = [100.0] * 9
        upper_bounds = [95.0, 88.0, 78.0, 64.0, 54.0, 44.0, 34.0, 24.0, 14.0]
        lower_bounds = [90.0, 84.0, 74.0, 58.0, 48.0, 38.0, 28.0, 18.0, 8.0]
        while True:
            try:
                batch = sess.run(train_next)
                if args.copy:
                    copy_mask = eval_steps.get_copy_mask(
                        sess, model, batch, np.max(batch[5]),
                        n_top_k=args.top_k_attack)
                    batch = batch + (copy_mask, )
                results = sess.run(
                    model.make_train_outputs(
                        full_loss_step=(step % args.at_steps == 0),
                        defence=args.defending),
                    feed_dict=model.make_train_inputs(
                        batch))  # Alternative training
            except tf.errors.OutOfRangeError:
                break
            if step % args.print_every_steps == 0:
                step_name = 'train'
                # Off-cycle steps train the defending classifier instead.
                if args.defending and step % args.at_steps > 0:
                    step_name = 'defending'
                utils.print_out(
                    'Step: ' + str(step) + ', ' + step_name + ' loss=' +
                    str(results[1]) +
                    (', ae_loss=' + str(results[4]) + ', cls_loss=' +
                     str(results[5]) if len(results) > 5 else '') +
                    (', senti_loss=' + str(results[7]) + ', aux_loss=' +
                     str(results[6]) + ', def_loss=' + str(results[8])
                     if len(results) > 6 else '') +
                    (' *' if (results[1] < best_loss) else ''))
                if (results[1] < best_loss):
                    best_loss = results[1]
            if step % (10 * args.print_every_steps) == 0:
                if args.adv:
                    cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu = \
                        eval_steps.eval_adv(
                            args, sess, dev_iter, model_dev, model_classifier,
                            dev_next, vocab, step,
                            10 * args.print_every_steps)
                    eval_score = cls_acc
                    eval_bleu = changed_bleu
                    if args.target_label is not None:
                        # Targeted attack: score only the targeted class.
                        eval_score = cls_acc_neg if args.target_label == 1 \
                            else cls_acc_pos
                        # eval_bleu = neg_bleu if args.target_label == 1 else pos_bleu
                    # use accuracy as best selection measure in each threshold
                    thre_score = cls_acc
                    if args.save_checkpoints:
                        for i in range(9):
                            if eval_score >= lower_bounds[i] and \
                                    eval_score < upper_bounds[i]:
                                if thre_score < best_T[i]:
                                    best_T[i] = thre_score
                                    saver.save(
                                        sess, args.output_dir + '/' +
                                        'nmt-T' + str(i) + '.ckpt')
                                    utils.print_out(
                                        'Step: ' + str(step) +
                                        ' model saved for T' + str(i) + ' *')
                    # Lower attack-time classifier accuracy is better; stop
                    # once it drops below the configured floor.
                    if (eval_score <= 0.0) or \
                            eval_score >= args.lowest_bound_score:
                        if eval_score < best_auc_min:
                            best_auc_min = eval_score
                            last_improvement_step = step
                    else:
                        break
                elif args.classification:
                    auc, acc = eval_steps.eval_classification(
                        args, sess, dev_iter, dev_next, model, vocab)
                    utils.print_out('Step: ' + str(step) + ', test acc=' +
                                    str(acc) + ', auc=' + str(auc))
                    eval_score = acc
                    # NOTE(review): both branches assign acc — looks like the
                    # multi-class case was meant to use a different metric.
                    if args.output_classes > 2:
                        eval_score = acc
                    if eval_score > best_auc_max:
                        best_auc_max = eval_score
                        last_improvement_step = step
                        saver.save(sess, args.output_dir + '/' + 'nmt.ckpt')
                else:
                    acc, word_acc, rouge, bleu = eval_steps.eval_ae(
                        sess, dev_iter, model_dev, dev_next, vocab, step,
                        50 * args.print_every_steps)
                    utils.print_out('Step: ' + str(step) + ', test acc=' +
                                    str(acc) + ', word_acc=' + str(word_acc) +
                                    ', rouge=' + str(rouge) + ', bleu=' +
                                    str(bleu))
                    if bleu > best_auc_max:
                        best_auc_max = bleu
                        last_improvement_step = step
                        saver.save(sess, args.output_dir + '/' + 'nmt.ckpt')
            # Stopping: early-stop on stagnation, or at a fixed step budget.
            if args.total_steps is None:
                if step - last_improvement_step > args.stop_steps:
                    break
            else:
                if step >= args.total_steps:
                    break
            step += 1
        utils.print_out('finish training')
        if args.do_test:
            args.load_model = args.output_dir + '/' + 'nmt.ckpt-' + str(step)