def main(_):
    with tf.Graph().as_default() as single_gpu_graph:
        model_config = configuration.model_config(
            input_file_pattern=FLAGS.input_file_pattern,
            batch_size=FLAGS.batch_size)
        training_config = configuration.training_config()
        model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                      mode="train")
        model.build()

        # Setup learning rate
        if training_config.learning_rate_decay_factor > 0:
            learning_rate = tf.train.exponential_decay(
                learning_rate=float(training_config.learning_rate),
                global_step=model.global_step,
                decay_steps=training_config.learning_rate_decay_steps,
                decay_rate=training_config.learning_rate_decay_factor,
                staircase=False)
        else:
            learning_rate = tf.constant(training_config.learning_rate)

        optimizer = tf.train.AdamOptimizer(learning_rate)

        train_tensor = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            global_step=model.global_step,
            clip_gradient_norm=training_config.clip_gradient_norm)

    def run(sess, num_iters, tensor_or_op_name_to_replica_names, num_workers,
            worker_id, num_replicas_per_worker):
        fetches = {
            'global_step':
            tensor_or_op_name_to_replica_names[model.global_step.name][0],
            'cost':
            tensor_or_op_name_to_replica_names[model.total_loss.name][0],
            'train_op':
            tensor_or_op_name_to_replica_names[train_tensor.name][0],
        }

        start = time.time()
        for i in range(num_iters):
            results = sess.run(fetches)
            if i % FLAGS.log_frequency == 0:
                end = time.time()
                throughput = float(FLAGS.log_frequency) / float(end - start)
                parallax.log.info(
                    "global step: %d, loss: %f, throughput: %f steps/sec" %
                    (results['global_step'], results['cost'], throughput))
                start = time.time()

    parallax.parallel_run(single_gpu_graph,
                          run,
                          FLAGS.resource_info_file,
                          FLAGS.max_steps,
                          sync=FLAGS.sync,
                          parallax_config=parallax_config.build_config())
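The flags referenced above (input_file_pattern, batch_size, resource_info_file, max_steps, sync, log_frequency) are not defined in this listing. A minimal sketch of how they might be declared with tf.app.flags follows; the flag names come from the code above, while the defaults and help strings are assumptions.

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string("input_file_pattern", None,
                           "File pattern of sharded TFRecord training files.")
tf.app.flags.DEFINE_integer("batch_size", 128, "Per-GPU batch size.")
tf.app.flags.DEFINE_string("resource_info_file", "resource_info",
                           "Path to the Parallax resource description file.")
tf.app.flags.DEFINE_integer("max_steps", 500000,
                            "Maximum number of training steps.")
tf.app.flags.DEFINE_boolean("sync", True,
                            "Whether Parallax aggregates gradients synchronously.")
tf.app.flags.DEFINE_integer("log_frequency", 100,
                            "How often (in steps) to log loss and throughput.")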
Example #2
def main(unused_argv):
    if not FLAGS.input_file_pattern:
        raise ValueError("--input_file_pattern is required.")
    if not FLAGS.train_dir:
        raise ValueError("--train_dir is required.")

    model_config = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern)
    training_config = configuration.training_config()

    tf.logging.info("Building training graph.")
    g = tf.Graph()
    with g.as_default():
        model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                      mode="train")
        model.build()

        learning_rate = _setup_learning_rate(training_config,
                                             model.global_step)
        optimizer = tf.train.AdamOptimizer(learning_rate)

        train_tensor = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            global_step=model.global_step,
            clip_gradient_norm=training_config.clip_gradient_norm)

        saver = tf.train.Saver()

    tf.contrib.slim.learning.train(
        train_op=train_tensor,
        logdir=FLAGS.train_dir,
        graph=g,
        global_step=model.global_step,
        number_of_steps=training_config.number_of_steps,
        save_summaries_secs=training_config.save_summaries_secs,
        saver=saver,
        save_interval_secs=training_config.save_model_secs)
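Example #2 calls _setup_learning_rate, which is not part of this listing. A minimal sketch that mirrors the inline logic of Example #1 (exponential decay when a decay factor is configured, otherwise a constant rate):

def _setup_learning_rate(training_config, global_step):
    """Sets up the learning rate with an optional exponential decay."""
    if training_config.learning_rate_decay_factor > 0:
        learning_rate = tf.train.exponential_decay(
            learning_rate=float(training_config.learning_rate),
            global_step=global_step,
            decay_steps=training_config.learning_rate_decay_steps,
            decay_rate=training_config.learning_rate_decay_factor,
            staircase=False)
    else:
        learning_rate = tf.constant(training_config.learning_rate)
    return learning_rate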
Example #3
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    rollout_gen = rollout(config=config_gen)

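    # Build the target LSTM (oracle)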
    target_params = pickle.load(open('save/target_params.pkl', 'rb'),
                                encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

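    # Build the discriminator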
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    # Build the optimizer op for pre-training; restrict pre-training to the
    # generator ('teller') variables.
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()

        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run(
                [gen_pre_update, generator.pretrained_loss],
                feed_dict={
                    generator.input_seqs_pre: batch,
                    generator.input_seqs_mask: np.ones_like(batch)
                })

        if epoch % config_train.test_per_epoch == 0:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train ', epoch, ' test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x:
                    x_batch,
                    discriminator.input_y:
                    y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob  # dropout
                }
                _ = sess.run(discriminator.train_op, feed_dict)

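    # Build the optimizer op for adversarial training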
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

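    # Start adversarial training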
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshpae)
            feed = {'pred_seq_rollout:0': samples}
            reward_rollout = []
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
                rewards = np.sum(reward_rollout,
                                 axis=0) / config_train.rollout_num
                _, gen_loss = sess.run(
                    [train_adv_update, generator.gen_loss_adv],
                    feed_dict={
                        generator.input_seqs_adv: samples,
                        generator.rewards: rewards
                    })

        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:
                        x_batch,
                        discriminator.input_y:
                        y_batch,
                        discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
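Examples #3 and #5 rely on a generate_samples helper that is not included in this listing. A minimal sketch of what it typically does in SeqGAN-style code, assuming the model exposes a generate(sess) method that returns a batch of token-id sequences (the actual sampling call in this implementation may differ):

def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    # Sample generated_num sequences in batches and write them to output_file,
    # one space-separated sequence of token ids per line.
    generated_samples = []
    for _ in range(int(generated_num / batch_size)):
        generated_samples.extend(trainable_model.generate(sess))
    with open(output_file, 'w') as fout:
        for sample in generated_samples:
            fout.write(' '.join(str(token) for token in sample) + '\n')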
Example #4
def main(unused_argv):

    model_config = configuration.model_config(
        lambda_rho=float(FLAGS.lambda_rho),
        lambda_rho_f=float(FLAGS.lambda_rho_f),
        lambda_ranking=float(FLAGS.lambda_ranking),
        lambda_ranking_text=float(FLAGS.lambda_ranking_text),
        lambda_ranking_text_f=float(FLAGS.lambda_ranking_text_f),
        lambda_cap2cap=float(FLAGS.lambda_cap2cap),
        lambda_infer=float(FLAGS.lambda_infer),
        lambda_quick=float(FLAGS.lambda_quick),
        lambda_Iinfer=float(FLAGS.lambda_Iinfer),
        lambda_nn=float(FLAGS.lambda_nn),
        hidden_dim=FLAGS.hidden_dim,
        train_folder=train_dir,
        distance=FLAGS.distance,
        rank_distance=FLAGS.rank_distance,
        nb_layer=FLAGS.nb_layer,
        word_embedding_dim=FLAGS.word_embedding_dim,
        SICK_trial=os.environ["SICK_TRIAL"],
        bin_dir=os.environ["GROUNDSENT_BIN_DIR"],
        mu_reg=FLAGS.mu_reg,
        val_sim_pattern=val_sim_pattern,
        sim_pattern=sim_pattern,
        nn_pattern=nn_pattern,
        val_nn_pattern=val_nn_pattern,
        validation_pattern=validation_pattern,
        bookcorpus_pattern=input_bookcorpus_pattern,
        visual_pattern=input_visual_pattern,
        encoder_dim=FLAGS.encoder_dim,
        ST=FLAGS.ST,
        video_embedding=FLAGS.video_embedding,
        projected_image_dim=int(FLAGS.projected_im),
        text_negative_number=int(FLAGS.neg_words),
        visual_batch_size=int(FLAGS.v_batch),
        lambda_text=float(FLAGS.lambda_text),
        textual=FLAGS.textual,
        visual_batch=int(FLAGS.v_batch),
        bookcorpus_batch_size=batch)

    training_config = configuration.training_config(
        lr=lr, learning_rate_decay_factor=lr_decay)

    tf.logging.info("Building training graph.")
    g = tf.Graph()

    with g.as_default():

        model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                      mode="train")
        model.build()
        optimizer = tf.train.AdamOptimizer(
            _setup_learning_rate(training_config, model.global_step, "total"))

        ckpt = tf.train.get_checkpoint_state(train_dir)

        train_total = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            global_step=model.global_step,
            clip_gradient_norm=training_config.clip_gradient_norm,
            variables_to_train=tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES))

        if ckpt:
            #ckpt_file = train_dir + "/model.ckpt-" + ckpt.model_checkpoint_path.split("-")[1]
            ckpt_file = ckpt.model_checkpoint_path
            variables_to_restore = optimistic_restore_vars(ckpt_file)
        else:
            variables_to_restore = None

        saver = tf.train.Saver(var_list=variables_to_restore, max_to_keep=2)

        if FLAGS.encoder_train == 0:
            encoder_variables = [
                v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                if ("encoder" in v.name) or ("w_embedding" in v.name)
            ]
            saver_encoder = tf.train.Saver(var_list=encoder_variables,
                                           max_to_keep=2)
        else:
            encoder_variables = []

        trainable_variables = [
            v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            if v not in encoder_variables
        ]

        train_total = tf.contrib.slim.learning.create_train_op(
            total_loss=model.total_loss,
            optimizer=optimizer,
            global_step=model.global_step,
            clip_gradient_norm=training_config.clip_gradient_norm,
            variables_to_train=trainable_variables)

        if ckpt:
            var_to_init = []
            for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                if v not in variables_to_restore:
                    if v not in encoder_variables:
                        var_to_init.append(v)
            local_init_op = tf.variables_initializer(var_to_init)
        else:
            local_init_op = tf.global_variables_initializer()

        #tf.logging.info("Variables to init")
        #for v in var_to_init:
        #  print(v.name)
        #tf.logging.info("Variables to restore")
        #for v in variables_to_restore:
        #  print(v.name)

        tf.logging.info("Variables to train")
        for v in trainable_variables:
            print(v.name)

        if ckpt:
            sess = tf.InteractiveSession()
            saver.restore(sess, ckpt.model_checkpoint_path)
            sess.run(local_init_op)
            saver_2 = tf.train.Saver(max_to_keep=2)
            saver_2.save(sess, ckpt.model_checkpoint_path)
            current_step = sess.run(model.global_step)
            sess.close()
        else:
            saver_2 = tf.train.Saver(max_to_keep=2)

        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        tf.contrib.slim.learning.train(
            session_config=session_config,
            train_op=train_total,
            saver=saver_2,
            local_init_op=local_init_op,
            logdir=train_dir,
            graph=g,
            global_step=model.global_step,
            number_of_steps=FLAGS.nb_steps,
            save_summaries_secs=training_config.save_summaries_secs,
            save_interval_secs=training_config.save_model_secs)

    evaluation.evaluate(train_dir, FLAGS.encoder_dim, "T", FLAGS.textual, 0)
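optimistic_restore_vars is used above to build the restore list but is not defined in this listing. A common implementation returns only the global variables whose names and shapes are present in the checkpoint, so newly added variables are initialized rather than breaking the restore; a sketch under that assumption:

def optimistic_restore_vars(ckpt_file):
    # Keep only the global variables that also exist in the checkpoint with a
    # matching shape, so tf.train.Saver restores just what was actually saved.
    reader = tf.train.NewCheckpointReader(ckpt_file)
    saved_shapes = reader.get_variable_to_shape_map()
    restore_vars = []
    for var in tf.global_variables():
        name = var.name.split(':')[0]
        if name in saved_shapes and var.get_shape().as_list() == saved_shapes[name]:
            restore_vars.append(var)
    return restore_vars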
Example #5
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model


    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()


    # Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name]
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt','w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
                                                                                      generator.input_seqs_mask:np.ones_like(batch)})

        if epoch % config_train.test_per_epoch == 0:
            # Evaluate: generate a batch of sequences with the generator
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            # Build a data loader over the generated sequences
            likelihood_data_loader.create_batches(config_train.eval_file)
            # Compute the cross-entropy loss (NLL) under the oracle
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # Print and log the result
            print('pre-train ', epoch, ' test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)


    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch,y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x : x_batch,
                    discriminator.input_y : y_batch,
                    discriminator.dropout_keep_prob : config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op,feed_dict)



    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshpae)

            feed = {'pred_seq_rollout:0':samples}
            reward_rollout = []
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,feed_dict=feed)
                # np.vstack stacks the rollout arrays vertically (row-wise)
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:rollout_list_stack,discriminator.dropout_keep_prob:1.0
                })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:samples,discriminator.dropout_keep_prob:1.0
                })
                reward_allseq = np.concatenate((reward_rollout_seq,reward_last_tok),axis=0)[:,1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r,config_gen.gen_batch_size * config_gen.sequence_length,config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
                rewards = np.sum(reward_rollout,axis = 0) / config_train.rollout_num
                _,gen_loss = sess.run([train_adv_update,generator.gen_loss_adv],feed_dict={generator.input_seqs_adv:samples,
                                                                                           generator.rewards:rewards})


        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print ('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)


        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)

            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch,y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:x_batch,
                        discriminator.input_y:y_batch,
                        discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op,feed)

    log.close()
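Both SeqGAN examples score the generated data with a target_loss helper that is not shown here. A minimal sketch, assuming the TARGET_LSTM exposes a pretrained_loss tensor fed through the same placeholders as the generator (these attribute names are assumptions about this particular implementation):

def target_loss(sess, target_lstm, data_loader):
    # Average the oracle's per-batch NLL over the evaluation data.
    nll = []
    data_loader.reset_pointer()
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        g_loss = sess.run(target_lstm.pretrained_loss,
                          feed_dict={target_lstm.input_seqs_pre: batch,
                                     target_lstm.input_seqs_mask: np.ones_like(batch)})
        nll.append(g_loss)
    return np.mean(nll)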