Exemplo n.º 1
0
def train_cvrnn_model(model, sess, config, input_data, target_data,
                      trace_lengths, selection_matrix):
    """Run one CVRNN training step and collect the losses of selected steps.

    Runs ``model.train_general_op`` in ``sess`` while fetching the model
    output and both loss arrays, then walks ``selection_matrix`` to keep only
    the selected (batch, time step) positions.

    Args:
        model: Object exposing ``output``, ``kl_loss``, ``likelihood_loss``,
            ``train_general_op`` and the four placeholders fed below.
        sess: Object whose ``run(fetches, feed_dict=...)`` is called once.
        config: Configuration providing ``Learn.max_seq_length`` and
            ``Arch.CVRNN.x_dim``.
        input_data: Value fed to ``model.input_data_ph``.
        target_data: Value fed to ``model.target_data_ph``; also handed to
            ``compute_rnn_acc`` as the target labels.
        trace_lengths: Value fed to ``model.trace_length_ph``.
        selection_matrix: Value fed to ``model.selection_matrix_ph``. It is
            indexed as ``[batch][time step]``; truthy entries are kept.

    Returns:
        A ``(kld_all, ll_all)`` pair of flat lists holding the ``kl_loss``
        and ``likelihood_loss`` entries of every selected position, ordered
        batch first and time step second.
    """
    output_x, kl_loss, likelihood_loss, _ = sess.run(
        [
            model.output,
            model.kl_loss,
            model.likelihood_loss,
            model.train_general_op,
        ],
        feed_dict={
            model.input_data_ph: input_data,
            model.target_data_ph: target_data,
            model.trace_length_ph: trace_lengths,
            model.selection_matrix_ph: selection_matrix,
        })

    max_seq_length = config.Learn.max_seq_length
    output_decoder = []
    kld_all = []
    ll_all = []
    for batch_index in range(len(output_x)):
        output_decoder_batch = []
        for step_index in range(max_seq_length):
            if selection_matrix[batch_index][step_index]:
                output_decoder_batch.append(output_x[batch_index][step_index])
                kld_all.append(kl_loss[batch_index][step_index])
                ll_all.append(likelihood_loss[batch_index][step_index])
            else:
                # Unselected steps are replaced by an all-zero vector so every
                # sequence keeps max_seq_length entries.
                output_decoder_batch.append(
                    np.asarray([0] * config.Arch.CVRNN.x_dim))
        output_decoder.append(output_decoder_batch)
    output_decoder = np.asarray(output_decoder)

    acc = compute_rnn_acc(output_prob=output_decoder,
                          target_label=target_data,
                          selection_matrix=selection_matrix,
                          config=config)

    # Log roughly one call in 21 to keep the training output readable.
    # NOTE: these means cover every position of the fetched loss arrays, not
    # only the selected positions returned below.
    if random.randint(1, 21) == 1:
        print("acc is {0} with mean KLD:{1} and mean ll:{2} ".format(
            acc,
            np.mean(kl_loss),
            np.mean(likelihood_loss),
        ))
    # TODO: we still need to consider how to define convergence

    return kld_all, ll_all
def train_clvrnn_model(model, sess, config, input_data, target_data,
                       trace_lengths, selection_matrix):
    """Run one CLVRNN training step with a linearly increasing KLD weight.

    The weight fed to ``model.kld_beta`` grows by ``0.00001`` per call
    (tracked through the module-level ``TRAINING_ITERATIONS`` counter) and is
    capped at 1. After the step, the loss entries of the positions selected
    by ``selection_matrix`` are collected and returned.

    Args:
        model: Object exposing ``output``, ``kl_loss``, ``likelihood_loss``,
            ``train_general_op``, ``kld_beta`` and the four placeholders fed
            below.
        sess: Object whose ``run(fetches, feed_dict=...)`` is called once.
        config: Configuration providing ``Learn.max_seq_length`` and
            ``Arch.CLVRNN.x_dim``.
        input_data: Value fed to ``model.input_data_ph``.
        target_data: Value fed to ``model.target_data_ph``; also handed to
            ``compute_rnn_acc`` as the target labels.
        trace_lengths: Value fed to ``model.trace_length_ph``.
        selection_matrix: Value fed to ``model.selection_matrix_ph``. It is
            indexed as ``[batch][time step]``; truthy entries are kept.

    Returns:
        A ``(kld_all, ll_all)`` pair of flat lists holding the ``kl_loss``
        and ``likelihood_loss`` entries of every selected position, ordered
        batch first and time step second.

    Side effects:
        Increments the global ``TRAINING_ITERATIONS`` by one and prints a
        progress line whenever the counter (before the increment) is a
        multiple of 20.
    """
    global TRAINING_ITERATIONS
    beta_kld = min(1, TRAINING_ITERATIONS * 0.00001)

    output_x, kl_loss, likelihood_loss, _ = sess.run(
        [
            model.output,
            model.kl_loss,
            model.likelihood_loss,
            model.train_general_op,
        ],
        feed_dict={
            model.input_data_ph: input_data,
            model.target_data_ph: target_data,
            model.trace_length_ph: trace_lengths,
            model.selection_matrix_ph: selection_matrix,
            model.kld_beta: beta_kld,
        })

    max_seq_length = config.Learn.max_seq_length
    output_decoder = []
    kld_all = []
    ll_all = []
    for batch_index in range(len(output_x)):
        output_decoder_batch = []
        for step_index in range(max_seq_length):
            if selection_matrix[batch_index][step_index]:
                output_decoder_batch.append(output_x[batch_index][step_index])
                kld_all.append(kl_loss[batch_index][step_index])
                ll_all.append(likelihood_loss[batch_index][step_index])
            else:
                # Unselected steps are replaced by an all-zero vector so every
                # sequence keeps max_seq_length entries.
                output_decoder_batch.append(
                    np.asarray([0] * config.Arch.CLVRNN.x_dim))
        output_decoder.append(output_decoder_batch)
    output_decoder = np.asarray(output_decoder)

    acc = compute_rnn_acc(output_prob=output_decoder,
                          target_label=target_data,
                          selection_matrix=selection_matrix,
                          config=config)

    if TRAINING_ITERATIONS % 20 == 0:
        print(
            "Embed acc is {0} with mean KLD:{1} and mean ll:{2} with beta:{3} "
            .format(acc, np.mean(kld_all), np.mean(ll_all), beta_kld))
    TRAINING_ITERATIONS += 1
    # TODO: we still need to consider how to define convergence

    return kld_all, ll_all
def validate_model(testing_dir_games_all,
                   data_store,
                   config,
                   sess,
                   model,
                   player_id_cluster_dir,
                   train_game_number,
                   validate_cvrnn_flag=True,
                   validate_td_flag=True):
    """Run the model over every test game and print summary metrics.

    Each entry of ``testing_dir_games_all`` except ``'.DS_Store'`` is passed
    to ``gathering_running_and_run`` with ``training_flag=False``. The four
    running results returned by one call are handed to the next call, so the
    values left after the loop are whatever the last call returned.

    Args:
        testing_dir_games_all: Iterable of game directory names.
        data_store: Forwarded to ``gathering_running_and_run``.
        config: Forwarded to the helper and to ``compute_rnn_acc``.
        sess: Forwarded to ``gathering_running_and_run``.
        model: Forwarded to ``gathering_running_and_run``.
        player_id_cluster_dir: Forwarded to ``gathering_running_and_run``.
        train_game_number: Forwarded to ``plot_players_games``.
        validate_cvrnn_flag: Forwarded to the helper; when truthy the
            accuracy from ``compute_rnn_acc`` is printed at the end.
        validate_td_flag: Forwarded to the helper; when truthy the mean of
            the collected q values along axis 0 is printed at the end.

    Returns:
        None. Results are only printed (and plotted per game when the helper
        returns a non-``None`` player dictionary).
    """
    validate_variance_flag = False
    # Running results: every helper call receives the previous values as
    # keyword arguments and returns the updated ones.
    running = {
        'output_decoder_all': None,
        'target_data_all': None,
        'selection_matrix_all': None,
        'q_values_all': None,
    }

    print('validating model')
    for dir_game in testing_dir_games_all:
        print('validating game {0}'.format(str(dir_game)))
        if dir_game != '.DS_Store':
            game_results = gathering_running_and_run(
                dir_game,
                config,
                player_id_cluster_dir,
                data_store,
                model,
                sess,
                training_flag=False,
                game_number=None,
                validate_cvrnn_flag=validate_cvrnn_flag,
                validate_td_flag=validate_td_flag,
                validate_variance_flag=validate_variance_flag,
                **running)
            (running['output_decoder_all'],
             running['target_data_all'],
             running['selection_matrix_all'],
             running['q_values_all'],
             match_q_values_players_dict) = game_results
            if match_q_values_players_dict is not None:
                plot_players_games(match_q_values_players_dict,
                                   train_game_number)

    if validate_cvrnn_flag:
        acc = compute_rnn_acc(output_prob=running['output_decoder_all'],
                              target_label=running['target_data_all'],
                              selection_matrix=running['selection_matrix_all'],
                              config=config,
                              if_print=True)
        print("testing acc is {0}".format(str(acc)))
    if validate_td_flag:
        average_q_values = np.mean(running['q_values_all'], axis=0)
        print("testing avg qs is {0}".format(str(average_q_values)))
Exemplo n.º 4
0
def validate_model(testing_dir_games_all,
                   data_store,
                   source_data_dir,
                   config,
                   sess,
                   model,
                   player_id_cluster_dir,
                   train_game_number,
                   validate_cvrnn_flag,
                   validate_td_flag,
                   validate_diff_flag,
                   validate_pred_flag,
                   file_writer=None):
    """Run the model over every test game and report reconstruction accuracy.

    Each entry of ``testing_dir_games_all`` except ``'.DS_Store'`` is passed
    to ``gathering_data_and_run`` with ``training_flag=False``. The running
    results returned by one call are handed to the next call; the per-game
    label sequences start from ``None`` for every game.

    Args:
        testing_dir_games_all: Sequence of game directory names.
        data_store: Forwarded to ``gathering_data_and_run``.
        source_data_dir: Forwarded to ``gathering_data_and_run``.
        config: Forwarded to the helper and to ``compute_rnn_acc``.
        sess: Forwarded to ``gathering_data_and_run``.
        model: Forwarded to ``gathering_data_and_run``.
        player_id_cluster_dir: Forwarded to ``gathering_data_and_run``.
        train_game_number: Not used by this function.
        validate_cvrnn_flag: Forwarded to the helper; when truthy the
            accuracy from ``compute_rnn_acc`` is printed at the end.
        validate_td_flag: Forwarded to the helper.
        validate_diff_flag: Forwarded to the helper; when truthy the labels
            returned for each game are copied into one row per game of two
            5000-column arrays pre-filled with -1.
        validate_pred_flag: Forwarded as ``validate_predict_flag``.
        file_writer: Optional object with ``write``; receives the accuracy
            line when given.

    Returns:
        None.
    """
    validate_variance_flag = False
    # Running results threaded through successive helper calls.
    output_decoder_all = target_data_all = selection_matrix_all = None
    q_values_all = None
    pred_target_data_all = pred_output_prob_all = None

    if validate_diff_flag:
        real_label_record = -np.ones([len(testing_dir_games_all), 5000])
        output_label_record = -np.ones([len(testing_dir_games_all), 5000])

    print('validating model')
    for dir_index, dir_game in enumerate(testing_dir_games_all):
        print('validating game {0}'.format(str(dir_game)))
        if dir_game == '.DS_Store':
            continue

        (output_decoder_all, target_data_all, selection_matrix_all,
         q_values_all, real_label_all, output_label_all,
         pred_target_data_all, pred_output_prob_all,
         match_q_values_players_dict) = gathering_data_and_run(
             dir_game,
             config,
             player_id_cluster_dir,
             data_store,
             source_data_dir,
             model,
             sess,
             training_flag=False,
             game_number=None,
             validate_cvrnn_flag=validate_cvrnn_flag,
             validate_td_flag=validate_td_flag,
             validate_diff_flag=validate_diff_flag,
             validate_variance_flag=validate_variance_flag,
             validate_predict_flag=validate_pred_flag,
             output_decoder_all=output_decoder_all,
             target_data_all=target_data_all,
             selection_matrix_all=selection_matrix_all,
             q_values_all=q_values_all,
             # The label sequences are collected per game, so each call
             # starts them from scratch.
             output_label_all=None,
             real_label_all=None,
             pred_target_data_all=pred_target_data_all,
             pred_output_prob_all=pred_output_prob_all)

        if validate_diff_flag:
            # Copy this game's labels into the leading columns of its row.
            real_count = len(real_label_all)
            real_label_record[dir_index][:real_count] = \
                real_label_all[:real_count]
            output_count = len(output_label_all)
            output_label_record[dir_index][:output_count] = \
                output_label_all[:output_count]

    if not validate_cvrnn_flag:
        return

    acc = compute_rnn_acc(output_prob=output_decoder_all,
                          target_label=target_data_all,
                          selection_matrix=selection_matrix_all,
                          config=config,
                          if_print=True)
    print("testing acc is {0}".format(str(acc)))
    if file_writer is not None:
        file_writer.write("testing acc is {0}\n".format(str(acc)))
def train_cvrnn_model(model,
                      sess,
                      config,
                      input_data,
                      target_data,
                      trace_lengths,
                      selection_matrix,
                      pretrain_flag=False):
    """Run one CVRNN training step, optionally with the pre-training op.

    Args:
        model: Object exposing ``output``, ``kl_loss``, ``likelihood_loss``,
            the train ops ``train_ll_op`` / ``train_general_op`` and the four
            placeholders fed below.
        sess: Object whose ``run(fetches, feed_dict=...)`` is called once.
        config: Configuration providing ``Learn.max_seq_length`` and
            ``Arch.CVRNN.x_dim``.
        input_data: Value fed to ``model.input_data_ph``.
        target_data: Value fed to ``model.target_data_ph``; also handed to
            ``compute_rnn_acc`` as the target labels.
        trace_lengths: Value fed to ``model.trace_length_ph``.
        selection_matrix: Value fed to ``model.selection_matrix_ph``. It is
            indexed as ``[batch][time step]``; truthy entries are kept.
        pretrain_flag: When truthy ``model.train_ll_op`` is run, otherwise
            ``model.train_general_op``.

    Returns:
        None.
    """
    # Both modes fetch the same tensors and feed the same placeholders; only
    # the training op differs.
    train_op = model.train_ll_op if pretrain_flag else model.train_general_op
    output_x, _kl_loss, _likelihood_loss, _ = sess.run(
        [
            model.output,
            model.kl_loss,
            model.likelihood_loss,
            train_op,
        ],
        feed_dict={
            model.input_data_ph: input_data,
            model.target_data_ph: target_data,
            model.trace_length_ph: trace_lengths,
            model.selection_matrix_ph: selection_matrix,
        })

    max_seq_length = config.Learn.max_seq_length
    output_decoder = []
    for batch_index in range(len(output_x)):
        output_decoder_batch = []
        for step_index in range(max_seq_length):
            if selection_matrix[batch_index][step_index]:
                output_decoder_batch.append(output_x[batch_index][step_index])
            else:
                # Unselected steps are replaced by an all-zero vector so every
                # sequence keeps max_seq_length entries.
                output_decoder_batch.append(
                    np.asarray([0] * config.Arch.CVRNN.x_dim))
        output_decoder.append(output_decoder_batch)
    output_decoder = np.asarray(output_decoder)

    # The accuracy is not used here; the call is kept because any side
    # effects of compute_rnn_acc are not visible from this function.
    compute_rnn_acc(output_prob=output_decoder,
                    target_label=target_data,
                    selection_matrix=selection_matrix,
                    config=config)
    # TODO: we still need to consider how to define convergence
def validate_model(testing_dir_games_all,
                   data_store,
                   source_data_dir,
                   config,
                   sess,
                   model,
                   player_id_cluster_dir,
                   train_game_number,
                   validate_cvrnn_flag,
                   validate_td_flag,
                   validate_diff_flag,
                   validate_pred_flag,
                   file_writer=None):
    """Run the model over every test game and report the enabled metrics.

    Each entry of ``testing_dir_games_all`` except ``'.DS_Store'`` is passed
    to ``gathering_data_and_run`` with ``training_flag=False``. The running
    results returned by one call are handed to the next call; the per-game
    label sequences start from ``None`` for every game.

    Args:
        testing_dir_games_all: Sequence of game directory names.
        data_store: Forwarded to ``gathering_data_and_run``.
        source_data_dir: Forwarded to ``gathering_data_and_run``.
        config: Forwarded to the helper and to ``compute_rnn_acc``.
        sess: Forwarded to ``gathering_data_and_run``.
        model: Forwarded to ``gathering_data_and_run``.
        player_id_cluster_dir: Forwarded to ``gathering_data_and_run``.
        train_game_number: Not used by this function.
        validate_cvrnn_flag: Forwarded to the helper; when truthy the
            accuracy from ``compute_rnn_acc`` is printed (and written to
            ``file_writer``).
        validate_td_flag: Only forwarded to the helper; nothing is reported
            for it here.
        validate_diff_flag: Forwarded to the helper; when truthy the mean
            absolute difference between the returned output labels and real
            labels is reported for every 100th time step.
        validate_pred_flag: Forwarded as ``validate_predict_flag``; when
            truthy accuracy, precision and recall from ``compute_acc`` are
            printed.
        file_writer: Optional object with ``write``; receives the accuracy
            and diff lines when given.

    Returns:
        None.
    """
    # Marks "no label at this time step". The records are pre-filled with the
    # same value the diff report skips, so games shorter than the record
    # width (and skipped entries) no longer count as zero-difference samples.
    missing_label = -100

    output_decoder_all = None
    target_data_all = None
    selection_matrix_all = None
    q_values_all = None
    validate_variance_flag = False
    pred_target_data_all = None
    pred_output_prob_all = None

    if validate_diff_flag:
        record_shape = [len(testing_dir_games_all), 5000]
        real_label_record = np.ones(record_shape) * missing_label
        output_label_record = np.ones(record_shape) * missing_label

    print('validating model')
    for dir_index, dir_game in enumerate(testing_dir_games_all):
        print('validating game {0}'.format(str(dir_game)))
        if dir_game == '.DS_Store':
            continue

        [
            output_decoder_all, target_data_all, selection_matrix_all,
            q_values_all, real_label_all, output_label_all,
            pred_target_data_all, pred_output_prob_all,
            match_q_values_players_dict
        ] = gathering_data_and_run(
            dir_game,
            config,
            player_id_cluster_dir,
            data_store,
            source_data_dir,
            model,
            sess,
            training_flag=False,
            game_number=None,
            validate_cvrnn_flag=validate_cvrnn_flag,
            validate_td_flag=validate_td_flag,
            validate_diff_flag=validate_diff_flag,
            validate_variance_flag=validate_variance_flag,
            validate_predict_flag=validate_pred_flag,
            output_decoder_all=output_decoder_all,
            target_data_all=target_data_all,
            selection_matrix_all=selection_matrix_all,
            q_values_all=q_values_all,
            # The label sequences are collected per game, so each call
            # starts them from scratch.
            output_label_all=None,
            real_label_all=None,
            pred_target_data_all=pred_target_data_all,
            pred_output_prob_all=pred_output_prob_all)

        if validate_diff_flag:
            # Copy this game's labels into the leading columns of its row.
            real_label_record[dir_index][:len(real_label_all)] = \
                real_label_all
            output_label_record[dir_index][:len(output_label_all)] = \
                output_label_all

    if validate_cvrnn_flag:
        acc = compute_rnn_acc(output_prob=output_decoder_all,
                              target_label=target_data_all,
                              selection_matrix=selection_matrix_all,
                              config=config,
                              if_print=True)
        print("testing acc is {0}".format(str(acc)))
        if file_writer is not None:
            file_writer.write("testing acc is {0}\n".format(str(acc)))

    if validate_diff_flag:
        # Only every 100th time step is reported, so only those columns are
        # evaluated.
        for step in range(0, output_label_record.shape[1], 100):
            diffs = [
                abs(model_output - real_outcome)
                for model_output, real_outcome in zip(
                    output_label_record[:, step], real_label_record[:, step])
                if model_output != missing_label
                and real_outcome != missing_label
            ]
            if not diffs:
                continue
            mean_diff = float(sum(diffs)) / len(diffs)
            print('diff of time {0} is {1}'.format(str(step), str(mean_diff)))
            if file_writer is not None:
                file_writer.write('diff of time {0} is {1}\n'.format(
                    str(step), str(mean_diff)))

    if validate_pred_flag:
        TP, TN, FP, FN, acc, ll, auc = compute_acc(pred_target_data_all,
                                                   pred_output_prob_all,
                                                   if_binary_result=True,
                                                   if_print=False,
                                                   if_add_ll=True)
        # With no predicted (or no actual) positives the ratio is undefined;
        # report 0.0 instead of dividing by zero.
        precision = float(TP) / (TP + FP) if (TP + FP) > 0 else 0.0
        recall = float(TP) / (TP + FN) if (TP + FN) > 0 else 0.0
        print("Prediction acc is {0} with precision {1} and recall {2}".format(
            acc, precision, recall))