Example #1
def _main(_):
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    val_data = tx.data.MultiAlignedData(config.val_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    if config.manual:
        manual_data = tx.data.MultiAlignedData(config.manual_data)
    vocab = train_data.vocab(0)

    # Each training batch is used three times: once each for the
    # discriminator, generator, and z updates. A feedable data iterator is
    # used for such cases.
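    # Note: registering train_data under three names ('train_g', 'train_d',
    # 'train_z') presumably gives each update step its own independent pass
    # over the same data, so the D, G, and z updates below each consume
    # their own stream of batches.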
    if config.manual:
        iterator = tx.data.FeedableDataIterator({
            'train_g': train_data,
            'train_d': train_data,
            'train_z': train_data,
            'val': val_data,
            'test': test_data,
            'manual': manual_data
        })
    else:
        iterator = tx.data.FeedableDataIterator({
            'train_g': train_data,
            'train_d': train_data,
            'train_z': train_data,
            'val': val_data,
            'test': test_data
        })
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    lambda_z = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_z')
    lambda_z1 = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_z1')
    lambda_z2 = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_z2')
    lambda_ae = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_ae')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, lambda_z, lambda_z1,
                         lambda_z2, lambda_ae, config.model)

    def _train_epoch(sess,
                     gamma_,
                     lambda_g_,
                     lambda_z_,
                     lambda_z1_,
                     lambda_z2_,
                     lambda_ae_,
                     epoch,
                     verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        avg_meters_z = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_d'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_
                }

                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                avg_meters_d.add(vals_d)

                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_g'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_
                }
                vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
                avg_meters_g.add(vals_g)

                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_z'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_
                }
                vals_z = sess.run(model.fetches_train_z, feed_dict=feed_dict)
                avg_meters_z.add(vals_z)

                if verbose and (step == 1 or step % config.display == 0):
                    print('epoch: {}, step: {}, {}'.format(
                        epoch, step, avg_meters_d.to_str(4)))
                    print('epoch: {}, step: {}, {}'.format(
                        epoch, step, avg_meters_z.to_str(4)))
                    print('epoch: {}, step: {}, {}'.format(
                        epoch, step, avg_meters_g.to_str(4)))

                if verbose and step % config.display_eval == 0:
                    iterator.restart_dataset(sess, 'val')
                    _eval_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                                lambda_z2_, lambda_ae_, epoch)

            except tf.errors.OutOfRangeError:
                print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_z.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
                break

    def _eval_epoch(sess,
                    gamma_,
                    lambda_g_,
                    lambda_z_,
                    lambda_z1_,
                    lambda_z2_,
                    lambda_ae_,
                    epoch,
                    val_or_test='val',
                    plot_z=False,
                    plot_max_count=1000,
                    spam=False,
                    repetitions=False,
                    write_text=True,
                    write_labels=False):
        avg_meters = tx.utils.AverageRecorder()

        if plot_z:
            z_vectors = []
            labels = []
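            # t-SNE projects the collected latent z vectors to 2-D; the
            # scatter plot is drawn after the evaluation loop finishes.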
            tsne = TSNE(n_components=2)
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    lambda_z: lambda_z_,
                    lambda_z1: lambda_z1_,
                    lambda_z2: lambda_z2_,
                    lambda_ae: lambda_ae_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                batch_size = vals.pop('batch_size')

                # Computes BLEU
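                # corpus_bleu_moses expects one list of references per
                # hypothesis, hence the expand_dims below.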
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)

                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)

                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                if spam or repetitions:
                    target_labels = samples['labels_target']
                    predicted_labels = samples['labels_predicted']

                    results = list(
                        zip(refs, hyps, target_labels, predicted_labels))

                # Computes repetitions
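                # "Repetitions" here means hypotheses that are verbatim
                # copies of their source sentence.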
                if repetitions:
                    count_equal_strings = 0
                    remain_samples_e = []
                    for r, h, t, p in results:
                        if r == h:
                            count_equal_strings += 1
                        else:
                            remain_samples_e.append((r, h, t, p))
                    vals['equal'] = count_equal_strings / len(hyps)

                # Computes spam
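                # A hypothesis counts as spam when it ends in an immediately
                # repeated word ("... w w") or a repeated bigram
                # ("... w1 w2 w1 w2").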
                if spam:
                    count_spam = 0
                    remain_samples_s = []
                    for r, h, t, p in results:
                        words = h.split()
                        if len(words) > 2 and words[-1] == words[-2]:
                            count_spam += 1
                        elif (len(words) > 4 and words[-1] == words[-3]
                              and words[-2] == words[-4]):
                            count_spam += 1
                        else:
                            remain_samples_s.append((r, h, t, p))
                    vals['spam'] = count_spam / len(hyps)

                if repetitions and spam:
                    remain_samples = [
                        sample for sample in remain_samples_e
                        if sample in remain_samples_s
                    ]

                    refs_remain = [r for r, h, t, p in remain_samples]
                    hyps_remain = [h for r, h, t, p in remain_samples]
                    bleu_remain = tx.evals.corpus_bleu_moses(
                        refs_remain, hyps_remain)
                    vals['bleu_remain'] = bleu_remain

                    if remain_samples:
                        true_labels = sum(
                            1 for _, _, t, p in remain_samples if t == p)
                        vals['acc_remain'] = true_labels / len(remain_samples)
                    else:
                        vals['acc_remain'] = 0.
                elif spam:
                    remain_samples = remain_samples_s
                elif repetitions:
                    remain_samples = remain_samples_e

                avg_meters.add(vals, weight=batch_size)

                if plot_z:
                    z_vectors += samples['z_vector'].tolist()
                    labels += samples['labels_source'].tolist()

                # Writes samples
                if write_text:
                    tx.utils.write_paired_text(
                        refs.squeeze(),
                        hyps,
                        os.path.join(config.sample_path,
                                     'text_{}.{}'.format(val_or_test, epoch)),
                        append=True,
                        mode='v')

                # Writes labels samples
                if write_labels:
                    tx.utils.write_paired_text(
                        [str(l) for l in samples['labels_target'].tolist()],
                        [str(l) for l in samples['labels_predicted'].tolist()],
                        os.path.join(config.sample_path,
                                     'labels_{}.{}'.format(val_or_test,
                                                           epoch)),
                        append=True,
                        mode='v')

            except tf.errors.OutOfRangeError:
                print('epoch: {}, {}: {}'.format(
                    epoch, val_or_test, avg_meters.to_str(precision=4)))
                break

        if plot_z:
            # plot_max_count == 0 means "plot everything"
            if plot_max_count > 0:
                z_vectors = z_vectors[:plot_max_count]
                labels = labels[:plot_max_count]
            tsne_result = tsne.fit_transform(np.array(z_vectors))
            x_data = tsne_result[:, 0]
            y_data = tsne_result[:, 1]
            plt.scatter(x_data,
                        y_data,
                        c=np.array(labels),
                        s=1,
                        cmap=plt.cm.get_cmap('jet', 2))
            plt.clim(0.0, 1.0)
            if not os.path.exists('./images'):
                os.makedirs('./images')
            plt.savefig('./images/{}_{}.png'.format(val_or_test, epoch))
            plt.clf()

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logics
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.
        lambda_g_ = 0.
        lambda_z_ = 0.
        lambda_ae_ = 1.
        lambda_z1_ = config.lambda_z1
        lambda_z2_ = config.lambda_z2
        for epoch in range(1, config.max_nepochs + 1):
            if epoch > config.pretrain_ae_nepochs:
                # Anneals the gumbel-softmax temperature
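                # e.g., with gamma_decay = 0.5 this gives 0.5, 0.25, 0.125,
                # ... on successive epochs, floored at 0.001.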
                gamma_ = max(0.001, gamma_ * config.gamma_decay)
                lambda_g_ = config.lambda_g
                lambda_z_ = config.lambda_z
            if epoch > config.chage_lambda_ae_epoch:
                lambda_ae_ = lambda_ae_ - config.change_lambda_ae
            print(
                'gamma: {}, lambda_g: {}, lambda_z: {}, lambda_z1: {}, lambda_z2: {}, lambda_ae: {}'
                .format(gamma_, lambda_g_, lambda_z_, lambda_z1_, lambda_z2_,
                        lambda_ae_))

            # Train
            iterator.restart_dataset(sess, ['train_g', 'train_d', 'train_z'])
            _train_epoch(sess, gamma_, lambda_g_, lambda_z_, lambda_z1_,
                         lambda_z2_, lambda_ae_, epoch)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess,
                        gamma_,
                        lambda_g_,
                        lambda_z_,
                        lambda_z1_,
                        lambda_z2_,
                        lambda_ae_,
                        epoch,
                        'val',
                        plot_z=config.plot_z,
                        plot_max_count=config.plot_max_count,
                        spam=config.spam,
                        repetitions=config.repetitions,
                        write_text=config.write_text,
                        write_labels=config.write_labels)

            saver.save(sess, os.path.join(config.checkpoint_path, 'ckpt'),
                       epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess,
                        gamma_,
                        lambda_g_,
                        lambda_z_,
                        lambda_z1_,
                        lambda_z2_,
                        lambda_ae_,
                        epoch,
                        'test',
                        plot_z=config.plot_z,
                        plot_max_count=config.plot_max_count,
                        spam=config.spam,
                        repetitions=config.repetitions,
                        write_text=config.write_text,
                        write_labels=config.write_labels)

            if config.manual:
                iterator.restart_dataset(sess, 'manual')
                _eval_epoch(sess,
                            gamma_,
                            lambda_g_,
                            lambda_z_,
                            lambda_z1_,
                            lambda_z2_,
                            lambda_ae_,
                            epoch,
                            'manual',
                            plot_z=config.plot_z,
                            plot_max_count=config.plot_max_count,
                            spam=config.spam,
                            repetitions=config.repetitions,
                            write_text=config.write_text,
                            write_labels=config.write_labels)
Example #2
def _main(_):
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    vocab = train_data.vocab(0)

    # Only the test split is evaluated here; a feedable data iterator keeps
    # the interface consistent with the training script.
    iterator = tx.data.FeedableDataIterator(
        {'test': test_data})
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, config.model)

    def _eval_epoch(sess, gamma_, lambda_g_, val_or_test='test'):
        avg_meters = tx.utils.AverageRecorder()

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)

                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)

                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                avg_meters.add(vals, weight=batch_size)

                # Writes samples
                tx.utils.write_paired_text(
                    refs.squeeze(), hyps,
                    os.path.join(config.sample_path, 'result'),
                    append=True, mode='v')

            except tf.errors.OutOfRangeError:
                print('{}: {}'.format(
                    val_or_test, avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logics
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        # Anneals the Gumbel-softmax temperature for evaluation
        gamma_ = max(0.001, 1. * config.gamma_decay)
        lambda_g_ = config.lambda_g

        # Test
        iterator.restart_dataset(sess, 'test')
        _eval_epoch(sess, gamma_, lambda_g_, 'test')
Example #3
def _main(_):
    # Data
    train_autoencoder = tx.data.MultiAlignedData(config.train_autoencoder)
    dev_autoencoder = tx.data.MultiAlignedData(config.dev_autoencoder)
    test_autoencoder = tx.data.MultiAlignedData(config.test_autoencoder)
    train_discriminator = tx.data.MultiAlignedData(config.train_discriminator)
    dev_discriminator = tx.data.MultiAlignedData(config.dev_discriminator)
    test_discriminator = tx.data.MultiAlignedData(config.test_discriminator)
    train_defender = tx.data.MultiAlignedData(config.train_defender)
    test_defender = tx.data.MultiAlignedData(config.test_defender)
    vocab = train_discriminator.vocab(0)

    iterator = tx.data.FeedableDataIterator({
        'train_autoencoder': train_autoencoder,
        'dev_autoencoder': dev_autoencoder,
        'test_autoencoder': test_autoencoder,
        'train_discriminator': train_discriminator,
        'dev_discriminator': dev_discriminator,
        'test_discriminator': test_discriminator,
        'train_defender': train_defender,
        'test_defender': test_defender,
    })
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_D = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_D')
    lambda_ae_ = 1.0
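    # Unlike gamma and lambda_D, the autoencoder weight is a plain Python
    # constant rather than a placeholder, so it cannot be changed via
    # feed_dict at run time.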
    model = CtrlGenModel(batch, vocab, lambda_ae_, gamma, lambda_D,
                         config.model)

    def autoencoder(sess,
                    lambda_ae_,
                    gamma_,
                    lambda_D_,
                    epoch,
                    mode,
                    verbose=True):
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        step = 0
        if mode == "train":
            dataset = "train_autoencoder"
            while True:
                try:
                    step += 1
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                    }
                    vals_g = sess.run(model.fetches_train_g,
                                      feed_dict=feed_dict)
                    # Popped so the summary tensors are excluded from the
                    # running averages
                    vals_g.pop("loss_g_ae_summary")
                    vals_g.pop("loss_g_clas_summary")
                    avg_meters_g.add(vals_g)

                    if verbose and (step == 1 or step % config.display == 0):
                        print('step: {}, {}'.format(step,
                                                    avg_meters_g.to_str(4)))

                    if verbose and step % config.display_eval == 0:
                        iterator.restart_dataset(sess, 'dev_autoencoder')
                        _eval_epoch(sess, lambda_ae_, gamma_, lambda_D_,
                                    epoch)

                except tf.errors.OutOfRangeError:
                    print('epoch: {}, {}'.format(epoch,
                                                 avg_meters_g.to_str(4)))
                    break
        else:
            dataset = "test_autoencoder"
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                        tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                    }

                    vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                    samples = tx.utils.dict_pop(vals,
                                                list(model.samples.keys()))
                    hyps = tx.utils.map_ids_to_strs(samples['transferred'],
                                                    vocab)
                    refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                    refs = np.expand_dims(refs, axis=1)
                    avg_meters_g.add(vals)
                    # Writes samples
                    tx.utils.write_paired_text(refs.squeeze(),
                                               hyps,
                                               os.path.join(
                                                   config.sample_path,
                                                   'val.%d' % epoch),
                                               append=True,
                                               mode='v')

                except tf.errors.OutOfRangeError:
                    print('{}: {}'.format("test_autoencoder_only",
                                          avg_meters_g.to_str(precision=4)))
                    break

    def discriminator(sess,
                      lambda_ae_,
                      gamma_,
                      lambda_D_,
                      epoch,
                      mode,
                      verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        y_true = []
        y_pred = []
        y_prob = []
        sentences = []
        step = 0
        if mode == "train":
            dataset = "train_discriminator"

            while True:
                try:
                    step += 1
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                    }

                    vals_d = sess.run(model.fetches_train_d,
                                      feed_dict=feed_dict)
                    y_pred.extend(vals_d.pop("y_pred").tolist())
                    y_true.extend(vals_d.pop("y_true").tolist())
                    y_prob.extend(vals_d.pop("y_prob").tolist())
                    sentences.extend(vals_d.pop("sentences").tolist())
                    avg_meters_d.add(vals_d)

                    if verbose and step % 40 == 0:
                        print('step: {}, {}'.format(step,
                                                    avg_meters_d.to_str(4)))

                except tf.errors.OutOfRangeError:
                    iterator.restart_dataset(sess, 'dev_discriminator')
                    _, _, _, _, val_acc = _eval_discriminator(
                        sess, lambda_ae_, gamma_, lambda_D_, epoch,
                        'dev_discriminator')
                    return val_acc

        if mode == 'test':
            dataset = "test_discriminator"
            iterator.restart_dataset(sess, dataset)
            y_pred, y_true, y_prob, sentences, _ = _eval_discriminator(
                sess, lambda_ae_, gamma_, lambda_D_, epoch, dataset)

            assert (len(y_pred) == len(y_true) == len(y_prob) ==
                    len(sentences))

            # Writes discriminator-confirmed negatives: sentences to one
            # file, predicted labels to a parallel file (the label-file name
            # is an assumption, mirroring the sentence file's name)
            sent_path = (
                DATA_DIR +
                'rand_x_sent_from_vocab_Discriminator_neg_confirmed.txt')
            label_path = (
                DATA_DIR +
                'rand_x_sent_from_vocab_Discriminator_neg_confirmed_label.txt')
            with open(sent_path, 'w') as sent_file, \
                    open(label_path, 'w') as label_file:
                for sentence, pred, label in zip(sentences, y_pred, y_true):
                    if pred == 0 and label == 0:
                        sent_file.write(sentence + '\n')
                        label_file.write(str(pred) + '\n')

    def defender(sess,
                 lambda_ae_,
                 gamma_,
                 lambda_D_,
                 epoch,
                 mode,
                 verbose=True):
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        step = 0
        if mode == "train":
            dataset = "train_defender"
            while True:
                try:
                    step += 1
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                    }
                    vals_g = sess.run(model.fetches_train_g,
                                      feed_dict=feed_dict)
                    # Popped so the summary tensors are excluded from the
                    # running averages
                    vals_g.pop("loss_g_ae_summary")
                    vals_g.pop("loss_g_clas_summary")
                    avg_meters_g.add(vals_g)

                    if verbose and (step == 1 or step % config.display == 0):
                        print('step: {}, {}'.format(step,
                                                    avg_meters_g.to_str(4)))

                except tf.errors.OutOfRangeError:
                    print('epoch: {}, {}'.format(epoch,
                                                 avg_meters_g.to_str(4)))
                    break
        else:
            dataset = "test_defender"
            while True:
                try:
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, dataset),
                        gamma: gamma_,
                        lambda_D: lambda_D_,
                        tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                    }

                    vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                    samples = tx.utils.dict_pop(vals,
                                                list(model.samples.keys()))
                    hyps = tx.utils.map_ids_to_strs(samples['transferred'],
                                                    vocab)
                    refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                    refs = np.expand_dims(refs, axis=1)
                    avg_meters_g.add(vals)
                    # Writes samples
                    tx.utils.write_paired_text(refs.squeeze(),
                                               hyps,
                                               os.path.join(
                                                   config.sample_path,
                                                   'defender_val.%d' % epoch),
                                               append=True,
                                               mode='v')

                except tf.errors.OutOfRangeError:
                    print('{}: {}'.format("test_defender",
                                          avg_meters_g.to_str(precision=4)))
                    break

    def _eval_discriminator(sess, lambda_ae_, gamma_, lambda_D_, epoch,
                            dataset):
        avg_meters_d = tx.utils.AverageRecorder()
        y_true = []
        y_pred = []
        y_prob = []
        sentences = []
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, dataset),
                    gamma: gamma_,
                    lambda_D: lambda_D_,
                }

                vals_d = sess.run(model.fetches_dev_test_d,
                                  feed_dict=feed_dict)
                y_pred.extend(vals_d.pop("y_pred").tolist())
                y_true.extend(vals_d.pop("y_true").tolist())
                y_prob.extend(vals_d.pop("y_prob").tolist())
                sentence = vals_d.pop("sentences").tolist()
                sentences.extend(tx.utils.map_ids_to_strs(sentence, vocab))
                batch_size = vals_d.pop('batch_size')
                avg_meters_d.add(vals_d, weight=batch_size)

            except tf.errors.OutOfRangeError:
                acc = avg_meters_d.avg()['accu_d']
                print('{}: {}'.format(dataset,
                                      avg_meters_d.to_str(precision=4)))
                break

        return y_pred, y_true, y_prob, sentences, acc

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logics
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        print(config.restore)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.0
        lambda_D_ = 0.0

        prev_acc = 0
        # Train discriminator
        for epoch in range(1, config.discriminator_nepochs + 1):
            print("Epoch number:", epoch)
            iterator.restart_dataset(sess, ['train_discriminator'])
            val_acc = discriminator(sess,
                                    lambda_ae_,
                                    gamma_,
                                    lambda_D_,
                                    epoch,
                                    mode='train')
            if val_acc > prev_acc:
                print("Accuracy is better, saving model")
                prev_acc = val_acc
                saver.save(
                    sess,
                    os.path.join(config.checkpoint_path,
                                 'discriminator_only_ckpt'), epoch)
            else:
                print("Accuracy is worse")
        # Test discriminator.
        iterator.restart_dataset(sess, ['test_discriminator'])
        print('gamma:{}'.format(gamma_))
        discriminator(sess, lambda_ae_, gamma_, lambda_D_, 1, mode='test')

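        # Stops after the discriminator stage; the autoencoder and defender
        # stages below do not run in this configuration.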
        exit()

        # Train autoencoder
        for epoch in range(1, config.autoencoder_nepochs + 1):
            iterator.restart_dataset(sess, ['train_autoencoder'])
            autoencoder(sess,
                        lambda_ae_,
                        gamma_,
                        lambda_D_,
                        epoch,
                        mode='train')
            saver.save(
                sess,
                os.path.join(config.checkpoint_path,
                             'discriminator_only_and_autoencoder_only_ckpt'),
                epoch)

        # Test autoencoder
        iterator.restart_dataset(sess, ['test_autoencoder'])
        autoencoder(sess, lambda_ae_, gamma_, lambda_D_, 1, mode='test')

        gamma_ = 1.0
        lambda_D_ = 1.0
        # gamma_decay = 0.5  # Gumbel-softmax temperature anneal rate

        # Train Defender
        for epoch in range(0, config.full_nepochs):
            # gamma_ = max(0.001, gamma_ * 0.5)
            print('gamma: {}, lambda_ae: {}, lambda_D: {}'.format(
                gamma_, lambda_ae_, lambda_D_))

            iterator.restart_dataset(sess, ['train_defender'])
            defender(sess, lambda_ae_, gamma_, lambda_D_, epoch, mode='train')
            saver.save(sess, os.path.join(config.checkpoint_path, 'full_ckpt'),
                       epoch)

        # Test Defender
        iterator.restart_dataset(sess, 'test_defender')
        defender(sess, lambda_ae_, gamma_, lambda_D_, 1, mode='test')
Example #4
def main():
    # Data
    train_data = tx.data.MultiAlignedData(hparams=config.train_data,
                                          device=device)
    val_data = tx.data.MultiAlignedData(hparams=config.val_data, device=device)
    test_data = tx.data.MultiAlignedData(hparams=config.test_data,
                                         device=device)
    vocab = train_data.vocab(0)

    # Each training batch is used twice within one loop iteration: once for
    # updating the discriminator and once for updating the generator.
    iterator = tx.data.DataIterator({
        'train': train_data,
        'val': val_data,
        'test': test_data
    })

    # Annealed hyperparameters
    gamma_ = 1.
    lambda_g_ = 0.

    # Model
    model = CtrlGenModel(vocab, hparams=config.model)
    model.to(device)

    # Create optimizers
    train_op_d = tx.core.get_optimizer(params=model.d_vars,
                                       hparams=config.model['opt'])

    train_op_g = tx.core.get_optimizer(params=model.g_vars,
                                       hparams=config.model['opt'])

    train_op_g_ae = tx.core.get_optimizer(params=model.g_vars,
                                          hparams=config.model['opt'])
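    # Two optimizers over the same g_vars presumably keep separate optimizer
    # state for the pretraining (autoencoder-only) and adversarial phases.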

    def _train_epoch(gamma_, lambda_g_, epoch, verbose=True):
        model.train()
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        iterator.switch_to_dataset("train")
        step = 0
        for batch in iterator:
            train_op_d.zero_grad()
            train_op_g_ae.zero_grad()
            train_op_g.zero_grad()
            step += 1

            vals_d = model(batch,
                           gamma_,
                           lambda_g_,
                           mode="train",
                           component="D")
            loss_d = vals_d['loss_d']
            loss_d.backward()
            train_op_d.step()
            recorder_d = {
                key: value.detach().cpu().data
                for (key, value) in vals_d.items()
            }
            avg_meters_d.add(recorder_d)

            vals_g = model(batch,
                           gamma_,
                           lambda_g_,
                           mode="train",
                           component="G")

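            # During pretraining only the reconstruction loss updates G;
            # afterwards the full generator objective (including the
            # lambda_g_-weighted terms) takes over.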
            if epoch <= config.pretrain_nepochs:
                loss_g_ae = vals_g['loss_g_ae']
                loss_g_ae.backward()
                train_op_g_ae.step()
            else:
                loss_g = vals_g['loss_g']
                loss_g.backward()
                train_op_g.step()

            recorder_g = {
                key: value.detach().cpu().data
                for (key, value) in vals_g.items()
            }
            avg_meters_g.add(recorder_g)

            if verbose and (step == 1 or step % config.display == 0):
                print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

            if verbose and step % config.display_eval == 0:
                _eval_epoch(gamma_, lambda_g_, epoch)

        print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
        print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))

    @torch.no_grad()
    def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'):
        model.eval()
        avg_meters = tx.utils.AverageRecorder()
        iterator.switch_to_dataset(val_or_test)
        for batch in iterator:
            vals, samples = model(batch, gamma_, lambda_g_, mode='eval')

            batch_size = vals.pop('batch_size')

            # Computes BLEU
            hyps = tx.data.map_ids_to_strs(samples['transferred'].cpu(), vocab)

            refs = tx.data.map_ids_to_strs(samples['original'].cpu(), vocab)
            refs = np.expand_dims(refs, axis=1)

            bleu = tx.evals.corpus_bleu_moses(refs, hyps)
            vals['bleu'] = bleu

            avg_meters.add(vals, weight=batch_size)

            # Writes samples
            tx.utils.write_paired_text(refs.squeeze(),
                                       hyps,
                                       os.path.join(
                                           config.sample_path,
                                           '{}.{}'.format(val_or_test, epoch)),
                                       append=True,
                                       mode='v')

        print('{}: {}'.format(val_or_test, avg_meters.to_str(precision=4)))

        return avg_meters.avg()

    os.makedirs(config.sample_path, exist_ok=True)
    os.makedirs(config.checkpoint_path, exist_ok=True)

    # Runs the logics
    if config.restore:
        print('Restore from: {}'.format(config.restore))
        ckpt = torch.load(config.restore)
        model.load_state_dict(ckpt['model'])
        train_op_d.load_state_dict(ckpt['optimizer_d'])
        train_op_g.load_state_dict(ckpt['optimizer_g'])

    for epoch in range(1, config.max_nepochs + 1):
        if epoch > config.pretrain_nepochs:
            # Anneals the gumbel-softmax temperature
            gamma_ = max(0.001, gamma_ * config.gamma_decay)
            lambda_g_ = config.lambda_g
        print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

        # Train
        _train_epoch(gamma_, lambda_g_, epoch)

        # Val
        _eval_epoch(gamma_, lambda_g_, epoch, 'val')

        states = {
            'model': model.state_dict(),
            'optimizer_d': train_op_d.state_dict(),
            'optimizer_g': train_op_g.state_dict()
        }
        torch.save(states, os.path.join(config.checkpoint_path, 'ckpt'))

        # Test
        _eval_epoch(gamma_, lambda_g_, epoch, 'test')
Example #5
def _main(_):
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    val_data = tx.data.MultiAlignedData(config.val_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    vocab = train_data.vocab(0)

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. A feedable data iterator is used
    # for such cases.
    iterator = tx.data.FeedableDataIterator({
        'train_g': train_data,
        'train_d': train_data,
        'val': val_data,
        'test': test_data
    })
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, config.model)

    def _train_epoch(sess, gamma_, lambda_g_, epoch, verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_d'),
                    gamma: gamma_,
                    lambda_g: lambda_g_
                }

                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                avg_meters_d.add(vals_d)

                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_g'),
                    gamma: gamma_,
                    lambda_g: lambda_g_
                }
                vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
                avg_meters_g.add(vals_g)

                if verbose and (step == 1 or step % config.display == 0):
                    print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                    print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

                if verbose and step % config.display_eval == 0:
                    iterator.restart_dataset(sess, 'val')
                    _eval_epoch(sess, gamma_, lambda_g_, epoch)

            except tf.errors.OutOfRangeError:
                print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
                break

    def _eval_epoch(sess, gamma_, lambda_g_, epoch, val_or_test='val'):
        avg_meters = tx.utils.AverageRecorder()

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)

                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)

                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                avg_meters.add(vals, weight=batch_size)

                # Writes samples
                tx.utils.write_paired_text(refs.squeeze(),
                                           hyps,
                                           os.path.join(
                                               config.sample_path,
                                               '{}.{}'.format(
                                                   val_or_test, epoch)),
                                           append=True,
                                           mode='v')

                # Logs BLEU as a Summary proto built from the Python float,
                # so no new graph ops are created per batch
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag='Bleu', simple_value=bleu)])
                writer.add_summary(summary, epoch)

            except tf.errors.OutOfRangeError:
                print('{}: {}'.format(val_or_test,
                                      avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logics
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # Single summary writer reused by _eval_epoch
        writer = tf.summary.FileWriter('./summary', sess.graph)

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.
        lambda_g_ = 0.
        for epoch in range(1, config.max_nepochs + 1):
            if epoch > config.pretrain_nepochs:
                # Anneals the gumbel-softmax temperature
                gamma_ = max(0.001, gamma_ * config.gamma_decay)
                lambda_g_ = config.lambda_g
            print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

            # Train
            iterator.restart_dataset(sess, ['train_g', 'train_d'])
            _train_epoch(sess, gamma_, lambda_g_, epoch)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess, gamma_, lambda_g_, epoch, 'val')

            saver.save(sess, os.path.join(config.checkpoint_path, 'ckpt'),
                       epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess, gamma_, lambda_g_, epoch, 'test')