Exemplo n.º 1
0
def train_model(model_type, attack_level, num_modified_words, percentage_attacked_samples):
    """Train a model on plots modified by a white-box adversarial attack.

    The graph runs ``predict_batch`` twice per batch: a first pass on the
    original plots whose second and third outputs (``w_atts``, ``s_atts``;
    presumably word- and sentence-level attention weights -- confirm against
    ``predict_batch``) drive the plot modification, and a second pass on the
    modified plots whose logits are used for the loss and the update step.

    Args:
        model_type: ``"lstm"`` selects ``movieqa.conf_lstm``; any other value
            selects ``movieqa.conf_cnn``. The chosen module is bound to the
            module-level global ``model_conf``.
        attack_level: ``"sentence"`` (plots are passed through
            ``remove_plot_sentence``) or ``"word"`` (plots are passed through
            ``modify_plot_sentence``).
        num_modified_words: Only used in the start-up log message here.
        percentage_attacked_samples: Only used in the start-up log message
            here.

    Raises:
        ValueError: If ``attack_level`` is neither ``"sentence"`` nor
            ``"word"``.
    """
    # Fail fast: previously an unknown level only surfaced later as a
    # NameError on `m_p`, after directories and graph ops had been created.
    if attack_level not in ("sentence", "word"):
        raise ValueError(
            "attack_level must be 'sentence' or 'word', got %r" % (attack_level,))

    print("train")
    print("%s white-box adversarial attack modifies %d words of %d%% of the instances: " % (
        attack_level, num_modified_words, percentage_attacked_samples))

    global model_conf
    if model_type == "lstm":
        import movieqa.conf_lstm as model_conf
    else:
        import movieqa.conf_cnn as model_conf

    # Use the compat.v1 API like the rest of this function; tf.contrib is not
    # available in TensorFlow 2.x.
    global_step = tf.compat.v1.train.get_or_create_global_step()

    # A missing training directory means a fresh run: create it and remember
    # to run `set_embeddings_op` once the session is up.
    init = False
    if not tf.io.gfile.exists(data_conf.TRAIN_DIR):
        init = True
        print("RESTORING WEIGHTS")
        tf.io.gfile.makedirs(data_conf.TRAIN_DIR)
    util.save_config_values(data_conf, data_conf.TRAIN_DIR + "/data")
    util.save_config_values(model_conf, data_conf.TRAIN_DIR + "/model")

    # Input pipeline: every file in the record directory, shuffled, repeated
    # for the configured number of epochs and padded per batch.
    filenames = glob.glob(data_conf.TRAIN_RECORD_PATH + '/*')
    print("Reading training dataset from %s" % filenames)
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(get_single_sample)
    dataset = dataset.shuffle(buffer_size=9000)
    dataset = dataset.repeat(data_conf.NUM_EPOCHS)
    batch_size = data_conf.BATCH_SIZE

    dataset = dataset.padded_batch(batch_size, padded_shapes=(
        [None], [ANSWER_COUNT, None], [None], (), [None, None], ()))

    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

    next_q, next_a, next_l, next_plot_ids, next_plots, next_q_types = iterator.get_next()

    # First pass: only the two middle outputs are needed to build the attack.
    _, w_atts, s_atts, _ = predict_batch(model_type, [next_q, next_a, next_plots], training=True)

    if attack_level == "sentence":
        m_p = tf.compat.v1.py_func(remove_plot_sentence, [next_plots, s_atts, next_l], [tf.int64])[0]
    else:  # attack_level == "word", guaranteed by the check at the top
        m_p = tf.compat.v1.py_func(modify_plot_sentence, [next_plots, w_atts, s_atts, next_l], [tf.int64])[0]

    # Second pass: train on the modified plots.
    logits, _, _, _ = predict_batch(model_type, [next_q, next_a, m_p], training=True)

    probabs = model.compute_probabilities(logits=logits)
    loss_batch = model.compute_batch_mean_loss(logits, next_l, model_conf.LOSS_FUNC)
    accuracy = model.compute_accuracies(logits=logits, labels=next_l, dim=1)
    accuracy_batch = tf.reduce_mean(input_tensor=accuracy)
    tf.compat.v1.summary.scalar("train_accuracy", accuracy_batch)
    tf.compat.v1.summary.scalar("train_loss", loss_batch)

    training_op = update_op(loss_batch, global_step, model_conf.OPTIMIZER, model_conf.INITIAL_LEARNING_RATE)
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Stays None if no training step completes (e.g. the input is exhausted
    # on the first run -- the monitored session swallows OutOfRangeError on
    # exit -- or a stop is requested before the loop starts).
    gs_val = None

    with tf.compat.v1.train.MonitoredTrainingSession(
            checkpoint_dir=data_conf.TRAIN_DIR,
            save_checkpoint_secs=60,
            save_summaries_steps=5,
            hooks=[tf.estimator.StopAtStepHook(last_step=model_conf.MAX_STEPS),
                   ], config=config) as sess:
        step = 0
        total_acc = 0.0
        if init:
            sess.run(set_embeddings_op, feed_dict={place: vectors})
        while not sess.should_stop():
            _, loss_val, acc_val, probs_val, lab_val, gs_val = sess.run(
                [training_op, loss_batch, accuracy_batch, probabs, next_l, global_step])

            print(probs_val)
            print(lab_val)
            print("Batch loss: " + str(loss_val))
            print("Batch acc: " + str(acc_val))
            step += 1
            total_acc += acc_val

            print("Total acc: " + str(total_acc / step))
            print("Local_step: " + str(step * batch_size))
            print("Global_step: " + str(gs_val))
            print("===========================================")

    if gs_val is None:
        # Previously this path crashed with a NameError on `gs_val`.
        print("No training step was executed; skipping model copy.")
        return
    util.copy_model(data_conf.TRAIN_DIR, gs_val)
Exemplo n.º 2
0
def train_model():
    """Train the model described by the module-level ``model_conf``.

    Builds a TFRecord input pipeline from ``data_conf.TRAIN_RECORD_PATH``,
    wires up loss, accuracy and the update op, then runs training steps in a
    ``MonitoredTrainingSession`` (checkpoints in ``data_conf.TRAIN_DIR``)
    until the session reports it should stop; a ``StopAtStepHook`` bounds the
    run at ``model_conf.MAX_STEPS``. Per-batch probabilities, labels, loss
    and accuracy are printed. If at least one training step ran,
    ``util.copy_model`` is called with the last observed global step.
    """
    print("train RNN-LSTM model")
    global_step = tf.compat.v1.train.get_or_create_global_step()

    if not tf.io.gfile.exists(data_conf.TRAIN_DIR):
        print("RESTORING WEIGHTS")
        tf.io.gfile.makedirs(data_conf.TRAIN_DIR)
    util.save_config_values(data_conf, data_conf.TRAIN_DIR + "/data")
    util.save_config_values(model_conf, data_conf.TRAIN_DIR + "/model")

    # Input pipeline: every file in the record directory, shuffled, repeated
    # for the configured number of epochs and padded per batch.
    filenames = glob.glob(data_conf.TRAIN_RECORD_PATH + '/*')
    print("Reading training dataset from %s" % filenames)
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(get_single_sample)
    dataset = dataset.shuffle(buffer_size=9000)
    dataset = dataset.repeat(model_conf.NUM_EPOCHS)
    batch_size = model_conf.BATCH_SIZE

    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes=([None], [ANSWER_COUNT, None], [None], (), [None, None], ()))

    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

    next_q, next_a, next_l, next_plot_ids, next_plots, next_q_types = iterator.get_next()

    logits, _, _, _ = predict_batch([next_q, next_a, next_plots],
                                    training=True)

    probabs = model.compute_probabilities(logits=logits)
    loss_batch = model.compute_batch_mean_loss(logits, next_l,
                                               model_conf.LOSS_FUNC)
    accuracy = model.compute_accuracies(logits=logits, labels=next_l, dim=1)
    accuracy_batch = tf.reduce_mean(input_tensor=accuracy)
    tf.compat.v1.summary.scalar("train_accuracy", accuracy_batch)
    tf.compat.v1.summary.scalar("train_loss", loss_batch)

    training_op = update_op(loss_batch, global_step, model_conf.OPTIMIZER,
                            model_conf.INITIAL_LEARNING_RATE)
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # Explicitly keep XLA JIT compilation switched off.
    config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.OFF

    # Stays None if no training step completes. That happens when the input
    # is exhausted on the first run (the monitored session swallows
    # OutOfRangeError on exit) or when the embedding feed below already makes
    # the stop hook fire, e.g. on a re-run with a finished checkpoint.
    gs_val = None

    with tf.compat.v1.train.MonitoredTrainingSession(
            checkpoint_dir=data_conf.TRAIN_DIR,
            save_checkpoint_secs=60,
            save_summaries_steps=5,
            hooks=[
                tf.estimator.StopAtStepHook(last_step=model_conf.MAX_STEPS),
            ],
            config=config) as sess:
        step = 0
        total_acc = 0.0
        # NOTE(review): this runs on every start, including when TRAIN_DIR
        # already existed. If the session restores variables from a
        # checkpoint, this presumably overwrites the restored embedding
        # values -- confirm that is intended.
        sess.run(set_embeddings_op, feed_dict={place: vectors})
        while not sess.should_stop():
            _, loss_val, acc_val, probs_val, lab_val, gs_val = sess.run([
                training_op, loss_batch, accuracy_batch, probabs, next_l,
                global_step
            ])
            print(probs_val)
            print(lab_val)
            print("Batch loss: " + str(loss_val))
            print("Batch acc: " + str(acc_val))
            step += 1
            total_acc += acc_val

            print("Total acc: " + str(total_acc / step))
            print("Local_step: " + str(step * batch_size))
            print("Global_step: " + str(gs_val))
            print("===========================================")

    if gs_val is None:
        # Previously this path crashed with a NameError on `gs_val`.
        print("No training step was executed; skipping model copy.")
        return
    util.copy_model(data_conf.TRAIN_DIR, gs_val)