Example #1
0
File: train.py  Project: T4vi/LAIF_task
def run_single_epoch(para, sess, model, train_data_generator, epoch=0):
    """Run one full pass over the training iterator and checkpoint the model.

    Drains ``train_data_generator``'s iterator, accumulating the mean loss,
    and saves the model when the iterator is exhausted.

    Args:
        para: configuration object forwarded to ``save_model``.
        sess: TF session owning the training graph.
        model: model exposing ``loss``, ``global_step`` and ``update`` ops.
        train_data_generator: provides ``iterator.initializer`` for the epoch.
        epoch: epoch number used only for logging. Added as a defaulted
            parameter — the original body logged an ``epoch`` name that was
            never defined in this scope and raised ``NameError``.
    """
    logging.info("\n\nEpoch: %d" % epoch)
    sess.run(train_data_generator.iterator.initializer)

    start_time = time.time()
    train_loss = 0.0
    count = 0
    # Initialized so the summary log below cannot hit an UnboundLocalError
    # when the iterator yields no batches at all.
    global_step = 0
    while True:
        try:
            [
                loss, global_step, _
            ] = sess.run(fetches=[model.loss, model.global_step, model.update])
            train_loss += loss
            count += 1
        except tf.compat.v1.errors.OutOfRangeError:
            # max(count, 1) guards the mean against an empty epoch.
            logging.info("global step: %d, loss: %.5f, epoch time: %.3fs",
                         global_step, train_loss / max(count, 1),
                         time.time() - start_time)
            save_model(para, sess, model)
            break
Example #2
0
def pretrain(para, sess, model):
    """Pretrain the model, periodically logging perplexity and saving weights."""
    embed_dct = read_all_embedding()

    elapsed = 0.0
    for step in range(1, para.steps + 1):
        tic = time.time()

        # Fetch the raw batch (token ids, lengths, targets) from the graph.
        fetched = sess.run(fetches=[
            model.raw_rnn_inputs,
            model.raw_rnn_inputs_len,
            model.raw_target_outputs,
        ])
        raw_rnn_inputs, raw_rnn_inputs_len, raw_target_outputs = fetched

        # Resolve the raw input ids to pretrained embedding vectors.
        rnn_inputs_embedded = read_pretrained_embedding(
            para, embed_dct, raw_rnn_inputs)

        # One optimization step, feeding the embedded inputs back in.
        feed = {
            model.rnn_inputs_embedded: rnn_inputs_embedded,
            model.rnn_inputs_len: raw_rnn_inputs_len,
            model.target_outputs: raw_target_outputs,
        }
        loss, _ = sess.run(fetches=[model.loss, model.update], feed_dict=feed)

        elapsed += time.time() - tic
        if step % para.steps_per_stats == 0 or step == 1:
            logging.info('step: %d, perplexity: %.2f, step_time: %.2f => save model to %s',
                         step, np.exp(loss), elapsed / para.steps_per_stats,
                         save_model_dir(para))
            save_model(para, sess, model)
            elapsed = 0.0
            if para.debug == 1:
                exit()
Example #3
0
File: train.py  Project: zyf0220/TPA-LSTM
def train(para, sess, model, train_data_generator):
    """Train the model and validate it after every epoch.

    A separate validation graph/session is created up front; after each
    training epoch the freshly saved checkpoint is loaded into it and the
    validation metrics are logged (RSE and correlation when ``para.mts`` is
    set, otherwise just the mean validation loss).

    Args:
        para: configuration object (reads ``num_epochs`` and ``mts`` here).
        sess: TF session owning the training graph.
        model: training model exposing ``loss``, ``global_step``, ``update``.
        train_data_generator: supplies the training iterator; its ``rse``
            attribute normalizes the validation RSE.
    """
    valid_para, valid_graph, valid_model, valid_data_generator = \
        create_valid_graph(para)

    with tf.Session(config=config_setup(), graph=valid_graph) as valid_sess:
        valid_sess.run(tf.global_variables_initializer())

        for epoch in range(1, para.num_epochs + 1):
            logging.info("Epoch: %d" % epoch)
            sess.run(train_data_generator.iterator.initializer)

            start_time = time.time()
            train_loss = 0.0
            count = 0
            # Drain the training iterator; OutOfRangeError marks epoch end.
            while True:
                try:
                    [loss, global_step, _] = sess.run(
                        fetches=[model.loss, model.global_step, model.update])
                    train_loss += loss
                    count += 1
                except tf.errors.OutOfRangeError:
                    # NOTE(review): `global_step` is unbound and `count` is 0
                    # if the iterator yields no batches — assumes a non-empty
                    # training set.
                    logging.info(
                        "global step: %d, loss: %.5f, epoch time: %.3f",
                        global_step, train_loss / count,
                        time.time() - start_time)
                    save_model(para, sess, model)
                    break

            # validation: load the checkpoint just saved into the valid graph
            load_weights(valid_para, valid_sess, valid_model)
            valid_sess.run(valid_data_generator.iterator.initializer)
            valid_loss = 0.0
            valid_rse = 0.0
            count = 0
            n_samples = 0
            all_outputs, all_labels = [], []
            while True:
                try:
                    [loss, outputs, labels] = valid_sess.run(fetches=[
                        valid_model.loss,
                        valid_model.all_rnn_outputs,
                        valid_model.labels,
                    ])
                    if para.mts:
                        # accumulate squared error in the original data scale
                        valid_rse += np.sum(
                            ((outputs - labels) * valid_data_generator.scale)
                            **2)
                        all_outputs.append(outputs)
                        all_labels.append(labels)
                        n_samples += np.prod(outputs.shape)
                    valid_loss += loss
                    count += 1
                except tf.errors.OutOfRangeError:
                    break
            if para.mts:
                # Relative squared error plus mean per-series correlation
                # between predictions and labels.
                all_outputs = np.concatenate(all_outputs)
                all_labels = np.concatenate(all_labels)
                sigma_outputs = all_outputs.std(axis=0)
                sigma_labels = all_labels.std(axis=0)
                mean_outputs = all_outputs.mean(axis=0)
                mean_labels = all_labels.mean(axis=0)
                # only series whose labels actually vary enter the mean corr
                idx = sigma_labels != 0
                valid_corr = ((all_outputs - mean_outputs) *
                              (all_labels - mean_labels)).mean(
                                  axis=0) / (sigma_outputs * sigma_labels)
                valid_corr = valid_corr[idx].mean()
                valid_rse = (
                    np.sqrt(valid_rse / n_samples) / train_data_generator.rse)
                valid_loss /= count
                logging.info(
                    "validation loss: %.5f, validation rse: %.5f, validation corr: %.5f",
                    valid_loss, valid_rse, valid_corr)
            else:
                logging.info("validation loss: %.5f", valid_loss / count)
Example #4
0
File: train.py  Project: T4vi/LAIF_task
def train(para, sess, model, train_data_generator):
    """Train the model and validate it after every epoch.

    A separate validation graph/session is created up front; after each
    training epoch the freshly saved checkpoint is loaded into it. Metrics:
    RSE/correlation when ``para.mts`` is set, precision/recall/F1 for the
    'muse'/'lpd5' datasets, plain loss otherwise.

    Args:
        para: configuration (reads ``num_epochs``, ``mts``, ``data_set``).
        sess: TF session owning the training graph.
        model: training model exposing ``loss``, ``global_step``, ``update``.
        train_data_generator: supplies the training iterator; its ``rse``
            attribute normalizes the validation RSE.
    """
    valid_para, valid_graph, valid_model, valid_data_generator = \
        create_valid_graph(para)

    with tf.Session(config=config_setup(), graph=valid_graph) as valid_sess:
        valid_sess.run(tf.global_variables_initializer())

        for epoch in range(1, para.num_epochs + 1):
            logging.info("\n\nEpoch: %d" % epoch)
            sess.run(train_data_generator.iterator.initializer)

            start_time = time.time()
            train_loss = 0.0
            count = 0
            # Initialized so the summary log below cannot hit an
            # UnboundLocalError when the iterator yields no batches.
            global_step = 0
            while True:
                try:
                    [loss, global_step, _] = sess.run(
                        fetches=[model.loss, model.global_step, model.update])
                    train_loss += loss
                    count += 1

                    if count % 25 == 0:
                        logging.debug(count)
                except tf.compat.v1.errors.OutOfRangeError:
                    # max(count, 1) guards the mean against an empty epoch.
                    logging.info(
                        "global step: %d, loss: %.5f, epoch time: %.3fs",
                        global_step, train_loss / max(count, 1),
                        time.time() - start_time)
                    save_model(para, sess, model)
                    break

            # validation: load the checkpoint just saved into the valid graph
            load_weights(valid_para, valid_sess, valid_model)
            valid_sess.run(valid_data_generator.iterator.initializer)
            valid_loss = 0.0
            valid_rse = 0.0
            tp, fp, tn, fn = 0, 0, 0, 0
            count = 0
            n_samples = 0
            all_outputs, all_labels = [], []
            while True:
                try:
                    [loss, outputs, labels] = valid_sess.run(fetches=[
                        valid_model.loss,
                        valid_model.all_rnn_outputs,
                        valid_model.labels,
                    ])
                    if para.mts:
                        valid_rse += np.sum(((outputs - labels) *
                                             valid_data_generator.scale)**2)
                        all_outputs.append(outputs)
                        all_labels.append(labels)
                        n_samples += np.prod(outputs.shape)
                    elif para.data_set == 'muse' or para.data_set == 'lpd5':
                        # Threshold at 0.5 and tally the confusion matrix.
                        # Iterate the actual batch dimension rather than
                        # para.batch_size: the last batch may be smaller.
                        for b in range(np.shape(outputs)[0]):
                            # presumably 128 = MIDI pitch classes — TODO confirm
                            for p in range(128):
                                if outputs[b][p] >= 0.5 and labels[b][p] >= 0.5:
                                    tp += 1
                                elif outputs[b][p] >= 0.5 and labels[b][
                                        p] < 0.5:
                                    fp += 1
                                elif outputs[b][p] < 0.5 and labels[b][p] < 0.5:
                                    tn += 1
                                elif outputs[b][p] < 0.5 and labels[b][
                                        p] >= 0.5:
                                    fn += 1
                    valid_loss += loss
                    count += 1
                except tf.errors.OutOfRangeError:
                    break
            if para.mts:
                # Relative squared error plus mean per-series correlation
                # between predictions and labels.
                all_outputs = np.concatenate(all_outputs)
                all_labels = np.concatenate(all_labels)
                sigma_outputs = all_outputs.std(axis=0)
                sigma_labels = all_labels.std(axis=0)
                mean_outputs = all_outputs.mean(axis=0)
                mean_labels = all_labels.mean(axis=0)
                # only series whose labels actually vary enter the mean corr
                idx = sigma_labels != 0
                valid_corr = ((all_outputs - mean_outputs) *
                              (all_labels - mean_labels)).mean(
                                  axis=0) / (sigma_outputs * sigma_labels)
                valid_corr = valid_corr[idx].mean()
                valid_rse = (np.sqrt(valid_rse / n_samples) /
                             train_data_generator.rse)
                valid_loss /= count
                logging.info(
                    "validation loss: %.5f, validation rse: %.5f, validation corr: %.5f",
                    valid_loss, valid_rse, valid_corr)

            elif para.data_set == 'muse' or para.data_set == 'lpd5':
                # Both ratios are guarded against empty denominators: the
                # original guarded precision but let recall raise
                # ZeroDivisionError when tp + fn == 0.
                precision = tp / (tp + fp) if (tp + fp) != 0 else 0
                recall = tp / (tp + fn) if (tp + fn) != 0 else 0
                if precision + recall >= 1e-6:
                    F1 = 2 * precision * recall / (precision + recall)
                else:
                    F1 = 0.0
                logging.info('validation loss: %.5f', valid_loss / count)
                logging.info('precision: %.5f', precision)
                logging.info('recall: %.5f', recall)
                logging.info('F1 score: %.5f', F1)
Example #5
0
def policy_gradient(para, sess, model): # pylint: disable=too-many-locals
    """The procedure of policy gradient reinforcement learning.

    Each step: sample seed ids, let the current model generate a prefix,
    score it with a frozen copy of the original model, sample next ids,
    compute rewards, and apply the RL update. Periodically logs reward
    statistics and checkpoints the model.

    Args:
        para: configuration (reads batch_size, max_len, steps,
            steps_per_stats, debug).
        sess: TF session owning the trainable (policy) model's graph.
        model: trainable model exposing sampled_ids, rl_update and the
            embedding/length/reward placeholders fed below.
    """
    embed_dct = read_all_embedding()

    seed_id_list = read_all_seed_ids()

    rev_vocab = read_rev_vocab()

    # A frozen copy of the pretrained model in its own graph/session,
    # used only to score the generated sequences.
    original_para, original_graph, original_model = create_original_graph(para)

    with tf.Session(config=config_setup(), graph=original_graph) as original_sess:
        load_original_weights(original_para, original_sess, original_model)

        step_time = 0.0
        for step in range(1, para.steps + 1):
            start_time = time.time()

            # Draw a batch of distinct seed sequences.
            chosen_ids = random.sample(range(0, len(seed_id_list)), para.batch_size)
            seed_ids = [seed_id_list[idx] for idx in chosen_ids]

            # One shared generation length for the whole batch.
            output_lengths = random.randint(1, para.max_len - 1)

            # raw_rnn_inputs: [batch_size, output_lengths]
            raw_rnn_inputs, _ = predict(rev_vocab, embed_dct, para, sess, model,
                                        seed_ids, output_lengths, True)
            # raw_rnn_inputs_len: [batch_size]
            raw_rnn_inputs_len = np.array([output_lengths] * para.batch_size)

            # raw_inputs_embedded: [batch_size, output_lengths, embedding_size]
            rnn_inputs_embedded = read_pretrained_embedding(para, embed_dct, raw_rnn_inputs)

            # get original probs (scores from the frozen pretrained model)
            [probs] = original_sess.run(
                fetches=[original_model.probs],
                feed_dict={
                    original_model.rnn_inputs_embedded: rnn_inputs_embedded,
                    original_model.rnn_inputs_len: raw_rnn_inputs_len,
                })

            # get sampled ids from the current policy
            [sampled_ids] = sess.run(
                fetches=[model.sampled_ids],
                feed_dict={
                    model.rnn_inputs_embedded: rnn_inputs_embedded,
                    model.rnn_inputs_len: raw_rnn_inputs_len,
                })
            sampled_ids = np.reshape(sampled_ids, (para.batch_size))

            # get reward and apply the policy-gradient update
            rewards, msg = reward_functions(para, raw_rnn_inputs, raw_rnn_inputs_len,
                                            sampled_ids, probs)
            [_] = sess.run(
                fetches=[model.rl_update],
                feed_dict={
                    model.rnn_inputs_embedded: rnn_inputs_embedded,
                    model.rnn_inputs_len: raw_rnn_inputs_len,
                    model.sampled_ids_inputs: sampled_ids,
                    model.rewards: rewards
                })

            step_time += (time.time() - start_time)
            if step % para.steps_per_stats == 0 or step == 1:
                logging.info('step: %d, reward: %.5f, step_time: %.2f => save model to %s',
                             step, msg['mean_reward'], step_time / para.steps_per_stats,
                             save_model_dir(para))
                # Log every auxiliary reward statistic besides the mean.
                for key, value in msg.items():
                    if key == 'mean_reward':
                        continue
                    logging.info('%s: %.2f', key, value)
                save_model(para, sess, model)
                step_time = 0

                if para.debug == 1:
                    exit()