def test(team_raw_data):
    test_x = []
    y = []
    test_data = test_data_func()
    boost_model = boost.Boost()
    boost_model.load_model()
    for x in test_data:
        home = x[1]
        away = x[0]
        home_vector = team_raw_data[home]
        away_vector = team_raw_data[away]

        away_state = x[2:4]
        home_state = x[4:6]

        # Concatenate each team's embedding with its current-state features.
        input_vector = home_vector.tolist() + home_state \
            + away_vector.tolist() + away_state
        y.append(x[-1])
        test_x.append(input_vector)
    pred_y = boost_model.predict(test_x)
    auc_ = auc(y, pred_y)
    acc_ = acc(y, pred_y)
    print("AUC:%s" % auc_)
    print("ACC:%s" % acc_)

    with open('log/xgboost_test.log', 'w+') as f:
        for i in range(len(pred_y)):
            f.write('%s,%s\n' % (y[i], pred_y[i]))
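These snippets all lean on small auc/acc helpers defined elsewhere in the project. A minimal sketch of what they could look like, assuming they are thin wrappers over scikit-learn (the wrapper bodies and the sklearn dependency are assumptions, not part of the original source):

# Hypothetical helpers, assuming thin wrappers around scikit-learn metrics.
from sklearn.metrics import accuracy_score, roc_auc_score

def auc(y_true, y_score):
    # Area under the ROC curve from binary labels and predicted scores.
    return roc_auc_score(y_true, y_score)

def acc(y_true, y_score, threshold=0.5):
    # Accuracy after thresholding scores into hard 0/1 predictions.
    return accuracy_score(y_true, [int(s >= threshold) for s in y_score])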
Example #2
def test(team_data, opt):
    data_provider = MatchData(1000)
    data_provider.roll_data()

    #testing_data = data_provider.get_test_data()
    testing_data = test_data_func()

    log_file = open('testing.log', 'w+')

    correct = 0
    wrong = 0
    y_label = []
    y_pred = []
    for i in range(len(testing_data)):

        away_id = testing_data[i][0]
        home_id = testing_data[i][1]

        away_current_state = testing_data[i][2:4]
        home_current_state = testing_data[i][4:6]
        score = [testing_data[i][7], testing_data[i][6]]  # (home, away) final score; unused below
        away_vector = team_data[away_id]
        home_vector = team_data[home_id]
        result = testing_data[i][8]  # scalar label, so the comparison with argmax below is well-defined

        prob = predict(opt.model_name,
                       home_state=home_current_state,
                       home_vector=home_vector,
                       away_state=away_current_state,
                       away_vector=away_vector,
                       opt=opt)

        pred_win = np.argmax(prob.data.cpu().numpy())

        y_label.append(result)
        y_pred.append(prob.data.cpu().numpy()[1])

        if pred_win == result:
            correct += 1
            #line = 'Test: %s Correct! Confidence=%s' % (i, prob.data[pred_win])
        else:
            wrong += 1
            #line = 'Test: %s Wrong! Confidence=%s' % (i, prob.data.cpu().numpy().tolist())

        # print(line)
        # log_file.write(line+'\n')
    auc_score = auc(y_label, y_pred)
    print("TEST: auc score: %s" % auc_score)
    log_file.close()

    print("Wrong: %s Correct: %s" % (wrong, correct))
Example #3
def test(team_raw_data, opt):
    __test_data = test_data_func()
    bayes_model = bayes.Bayes()
    bayes_model.load_model()
    log_file = open('log/svm.log', 'w+')  # note: legacy filename; this logs the Bayes run
    correct = 0
    wrong = 0
    probs = []
    y = []
    for x in __test_data:
        home = x[1]
        away = x[0]
        home_vector = team_raw_data[home]
        away_vector = team_raw_data[away]
        away_state = np.array(x[2:4])
        home_state = np.array(x[4:6])
        input_vector = home_vector.tolist() + away_vector.tolist() \
            + home_state.tolist() + away_state.tolist()
        input_vector = np.array(input_vector)

        pred = bayes_model.predict([input_vector])[0]
        if pred[0] > pred[1]:
            pred_id = 0
        else:
            pred_id = 1
        line = 'Bayes: Pred:%s Real: %s Confidence=%s' % (pred_id, x[-1],
                                                          pred[pred_id])
        if pred_id == x[-1]:
            correct += 1
        else:
            wrong += 1
        print(line)
        probs.append(pred[1])  # AUC needs the positive-class probability, not the winner's confidence
        y.append(x[-1])
        log_file.write(line + '\n')

    print("Correct: %s Wrong: %s" % (correct, wrong))
    print("Acc=%s" % (float(correct) / float(correct + wrong)))
    print("AUC=%s" % (auc(y, probs)))
    log_file.close()
Example #4
def test(team_data, opt):

    testing_data = test_data_func()

    correct = 0
    wrong = 0
    y_label = []
    y_pred = []
    for i in range(len(testing_data)):

        away_id = testing_data[i][0]
        home_id = testing_data[i][1]

        away_current_state = testing_data[i][2:4]
        home_current_state = testing_data[i][4:6]
        score = [testing_data[i][7], testing_data[i][6]]  # (home, away) final score; unused below
        away_vector = team_data[away_id]
        home_vector = team_data[home_id]
        result = testing_data[i][8]  # scalar label, so the comparison with argmax below is well-defined

        prob = predict(opt.model_name,
                       home_state=home_current_state,
                       home_vector=home_vector,
                       away_state=away_current_state,
                       away_vector=away_vector,
                       opt=opt)

        pred_win = np.argmax(prob.data.cpu().numpy())

        y_label.append(result)
        y_pred.append(prob.data.cpu().numpy()[1])

        if pred_win == result:
            correct += 1
        else:
            wrong += 1
    auc_score = auc(y_label, y_pred)
    print("TEST: auc score: %s" % auc_score)

    print("Wrong: %s Correct: %s" % (wrong, correct))
Example #5
    for n in data.columns.values[:-1]:
        bins = bin.chi_merge(n)
        woe.add_woe_col(data, bins)

    # Single-variable AR (accuracy ratio) calculation
    # ar = ARUtil.cal_ar(data['SepalWidth_woe'], data['Label'])

    train_data, test_data = split_data(data, 0.7)
    model = modeling.model(
        train_data, ['SepalLength_woe', 'PetalLength_woe', 'PetalWidth_woe'],
        'Label')
    predict_score = modeling.score_trans(
        test_data[['SepalLength_woe', 'PetalLength_woe', 'PetalWidth_woe']],
        model, 300, 25)
    pprint(list(zip(test_data['Label'].values, predict_score)))
    auc = evaluate.auc(
        model, test_data[[
            'SepalLength_woe', 'PetalLength_woe', 'PetalWidth_woe', 'Label'
        ]])
    print("auc值: " + str(auc))
    evaluate.roc(
        model, test_data[[
            'SepalLength_woe', 'PetalLength_woe', 'PetalWidth_woe', 'Label'
        ]])

    # select_func = feature_selection.fea_select(data[['SepalLength', 'SepalWidth']], data['Label'], 1)
    # print(select_func.transform(data[['SepalLength', 'SepalWidth']]))

    # feature_selection.fea_select(data[['SepalLength_woe', 'SepalWidth_woe']], data['Label'])
    # feature_selection.mi(data['SepalWidth_woe'], data['Label'])
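The score_trans(..., model, 300, 25) call above maps model output onto a scorecard scale. A minimal sketch of such a transform, assuming 300 is the base score, 25 is the PDO ("points to double the odds"), and model exposes a scikit-learn-style predict_proba; the formula is standard scorecard practice, not taken from the original modeling module:

# Hypothetical scorecard transform; base_score/pdo semantics are assumed.
import numpy as np

def score_trans(X, model, base_score, pdo):
    factor = pdo / np.log(2)               # points added per doubling of the odds
    p = model.predict_proba(X)[:, 1]       # predicted probability of the bad class
    odds = (1 - p) / p                     # good-to-bad odds
    return base_score + factor * np.log(odds)  # higher score = lower risk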
Example #6
def train_fn(with_facts, with_constraints, iterations, KB, prior_mean,
             prior_lambda, sess, kb_label):
    random.seed(config.RANDOM_SEED)
    train_writer = tf.summary.FileWriter('logging/' + kb_label + '/train',
                                         graph=None)
    test_writer = tf.summary.FileWriter('logging/' + kb_label + '/test',
                                        graph=None)
    modus_ponens_writer = tf.summary.FileWriter('logging/' + kb_label +
                                                '/modus_ponens',
                                                graph=None)
    modus_tollens_writer = tf.summary.FileWriter('logging/' + kb_label +
                                                 '/modus_tollens',
                                                 graph=None)

    train_kb = True
    for i in range(iterations):
        ti = time.time()
        if i % config.FREQ_OF_FEED_DICT_GENERATION == 0:
            train_kb = True
            feed_dict = get_feed_dict(pairs_of_train_data,
                                      with_constraints=with_constraints,
                                      with_facts=with_facts)
            feed_dict[KB.prior_mean] = prior_mean
            feed_dict[KB.prior_lambda] = prior_lambda
            feed_dict[weight_ontology] = config.WEIGHT_ONTOLOGY_CLAUSES_START if \
                i < config.ITERATIONS_UNTIL_WEIGHT_SWAP else config.WEIGHT_ONTOLOGY_CLAUSES_END

        if train_kb:
            sat_level, normal_loss, reg_loss = KB.train(sess, feed_dict)
            if np.isnan(sat_level):
                train_kb = False
            if normal_loss < 0:  # Using log-likelihood aggregation
                sat_level = np.exp(-sat_level)
            if sat_level >= config.SATURATION_LIMIT:
                train_kb = False
        if i % config.FREQ_OF_PRINT == 0:
            iter_time = time.time() - ti
            feed_dict = get_feed_dict(pairs_of_train_data)
            for kb in [KB_facts, KB_rules, KB_full, KB_ontology, KB_logical]:
                feed_dict[kb.prior_mean] = prior_mean
                feed_dict[kb.prior_lambda] = prior_lambda
                feed_dict[weight_ontology] = config.WEIGHT_ONTOLOGY_CLAUSES_START if \
                    i < config.ITERATIONS_UNTIL_WEIGHT_SWAP else config.WEIGHT_ONTOLOGY_CLAUSES_END

            summary = sess.run(summary_merge, feed_dict)
            train_writer.add_summary(summary, i)

            print(i, 'Sat level', str(sat_level), 'loss', normal_loss,
                  'regularization', reg_loss, 'iteration time', iter_time)
        if (i + 1) % config.FREQ_OF_SAVE == 0:  # parentheses needed: % binds tighter than +
            KB.save(sess, kb_label)
        if i % config.FREQ_OF_TEST == 0:
            predicted_types_values_tensor = tf.concat(
                [isOfType[t].tensor() for t in selected_types], 1)
            predicted_partOf_value_tensor = ltn.Literal(
                True, isPartOf, pairs_of_objects).tensor

            feed_dict = {}
            feed_dict[objects.tensor] = test_data[:, 1:]
            feed_dict[pairs_of_objects.tensor] = pairs_of_test_data
            feed_dict[weight_ontology] = config.WEIGHT_ONTOLOGY_CLAUSES_START if \
                i < config.ITERATIONS_UNTIL_WEIGHT_SWAP else config.WEIGHT_ONTOLOGY_CLAUSES_END

            chosen_pairs = np.random.choice(range(pairs_of_test_data.shape[0]),
                                            config.NUMBER_PAIRS_AXIOMS_TESTING)
            # chosen_pairs = np.random.choice(idxs_of_test_pos_examples_of_partOf, config.NUMBER_PAIRS_AXIOMS_TESTING)

            feed_dict_rules(feed_dict, pairs_of_test_data[chosen_pairs])

            tensors_to_retrieve = [predicted_types_values_tensor, predicted_partOf_value_tensor, rules_summary] \
                                  + grad_MP_tensors + grad_MT_tensors + literal_debug1

            outputs = sess.run(tensors_to_retrieve, feed_dict)
            values_of_types = outputs[0]
            values_of_partOf = outputs[1]
            summary_r = outputs[2]
            r_ind = 3
            if config.USE_IMPLICATION_CLAUSES:
                grad_MP = outputs[3:3 + len(rules_ontology)]
                r_ind = 3 + 2 * len(rules_ontology)
                grad_MT = outputs[3 + len(rules_ontology):r_ind]

            literals = outputs[r_ind:]

            is_pair_of = partOF_of_pairs_of_test_data[chosen_pairs]
            bb_ids = pairs_of_bb_idxs_test[chosen_pairs]

            n_correct_mp_reasoning = 0
            n_wrong_mp_reasoning = 0
            n_correct_mp_updates = 0
            n_pos_correct_mp_reasoning = 0

            tot_mp_grad_magn = 0
            correct_mp_reason_magn = 0
            correct_mp_upd_magn = 0

            n_correct_mt_reasoning = 0
            n_wrong_mt_reasoning = 0
            n_correct_mt_updates = 0
            n_pos_correct_mt_reasoning = 0

            tot_mt_grad_magn = 0
            correct_mt_reason_magn = 0
            correct_mt_upd_magn = 0

            tot_evals = len(rules_ontology) * len(chosen_pairs)

            for r_i in range(len(rules_ontology)):
                rule = rules_ontology[r_i]
                wholes_of = rule in clauses_for_wholes_of_parts
                if config.PRINT_GRAD_DEBUG:
                    print(' and '.join([l.name
                                        for l in rule.literals[:2]]) + "->" +
                          ' or '.join([l.name for l in rule.literals[2:]]))
                for j in range(len(chosen_pairs)):
                    label1, label2 = bb_ids[j][0], bb_ids[j][1]
                    t1 = types_of_test_data[label1]
                    t2 = types_of_test_data[label2]
                    x = t1 if wholes_of else t2
                    y = t2 if wholes_of else t1
                    conseq_true = y in [l.name for l in rule.literals[2:]]
                    ant_true = is_pair_of[j] and x == rule.literals[0].name
                    is_correct_mp_reasoning = ant_true and conseq_true
                    is_correct_mt_reasoning = (not ant_true) and (
                        not conseq_true)

                    grad_mp_xy = 0
                    grad_mt_xy = 0

                    if is_correct_mp_reasoning:
                        n_pos_correct_mp_reasoning += 1
                    if is_correct_mt_reasoning:
                        n_pos_correct_mt_reasoning += 1
                    if config.USE_IMPLICATION_CLAUSES:
                        grad_mp_xy = grad_MP[r_i][j]
                        grad_mt_xy = grad_MT[r_i][j]
                    else:
                        p_1 = literals[r_i][j, 0]
                        p_2 = literals[r_i][j, 1]
                        Q = literals[r_i][j, 2:]
                        nP_t = p_1 + p_2 - p_1 * p_2
                        Q_t = 1 - np.prod([1 - y for y in Q])
                        truth_val = nP_t + Q_t - nP_t * Q_t
                        grad_mp_xy = (1 - nP_t) / truth_val
                        grad_mt_xy = (1 - Q_t) / truth_val

                    tot_mp_grad_magn += grad_mp_xy
                    tot_mt_grad_magn += grad_mt_xy
                    if is_correct_mp_reasoning:
                        correct_mp_reason_magn += grad_mp_xy
                    if conseq_true:
                        correct_mp_upd_magn += grad_mp_xy

                    if is_correct_mt_reasoning:
                        correct_mt_reason_magn += grad_mt_xy
                    if not ant_true:
                        correct_mt_upd_magn += grad_mt_xy

                    if grad_mp_xy > 0.1:
                        if is_correct_mp_reasoning:
                            n_correct_mp_reasoning += 1
                        else:
                            n_wrong_mp_reasoning += 1
                        if conseq_true:
                            n_correct_mp_updates += 1
                    if grad_mt_xy > 0.1:
                        if is_correct_mt_reasoning:
                            n_correct_mt_reasoning += 1
                        else:
                            n_wrong_mt_reasoning += 1
                        if not ant_true:
                            n_correct_mt_updates += 1
                    if config.PRINT_GRAD_DEBUG:
                        print("Correct MP reason", is_correct_mp_reasoning,
                              "Correct conseq update", conseq_true,
                              "Correct MT reason", is_correct_mt_reasoning,
                              "Correct ant update", ant_true)
                        print(literals[r_i][j, :])
                        print(grad_mp_xy, grad_mt_xy)

            # Compute AUC of partof and precision of types
            cm = compute_confusion_matrix_pof(config.THRESHOLDS,
                                              values_of_partOf,
                                              pairs_of_test_data,
                                              partOF_of_pairs_of_test_data)
            measures = {}
            compute_measures(cm, 'test', measures)
            precision, recall = stat(measures, 'test')
            precision = adjust_prec(precision)
            auc_pof = auc(precision, recall)

            max_type_labels = np.argmax(values_of_types, 1)
            max_type_labels = selected_types[max_type_labels]
            correct = np.where(max_type_labels == types_of_test_data)[0]
            prec_types = len(correct) / len(max_type_labels)

            # Creating IR measures for the consequent updating
            n_conseq_gradient_updates = n_correct_mp_reasoning + n_wrong_mp_reasoning
            prec_conseq_gradient_update = recall_conseq_gradient_update = \
                f1_conseq_gradient_update = prec_conseq_update = 0
            if n_conseq_gradient_updates > 0:
                prec_conseq_gradient_update = n_correct_mp_reasoning / n_conseq_gradient_updates
                recall_conseq_gradient_update = 0 if n_pos_correct_mp_reasoning == 0 \
                    else n_correct_mp_reasoning / n_pos_correct_mp_reasoning
                prec_conseq_update = n_correct_mp_updates / n_conseq_gradient_updates
                f1_conseq_gradient_update = 0 if recall_conseq_gradient_update == 0 or prec_conseq_gradient_update == 0 \
                    else 2/(1/recall_conseq_gradient_update + 1/prec_conseq_gradient_update)

            avg_conseq_grad_magn = tot_mp_grad_magn / tot_evals
            ratio_magn_correct_mp_reason = 0 if tot_mp_grad_magn == 0 \
                else correct_mp_reason_magn / tot_mp_grad_magn
            # Average gradient of possible correct inferences (should be high)
            recall_magn_correct_mp_reason = 0 if n_pos_correct_mp_reasoning == 0 \
                else correct_mp_reason_magn / n_pos_correct_mp_reasoning
            ratio_magn_correct_mp_upd = 0 if tot_mp_grad_magn == 0 \
                else correct_mp_upd_magn / tot_mp_grad_magn

            # Creating IR measures for the antecedent updating
            n_ant_gradient_updates = n_correct_mt_reasoning + n_wrong_mt_reasoning
            prec_ant_gradient_update = recall_ant_gradient_update = f1_ant_gradient_update = prec_ant_update = 0
            if n_ant_gradient_updates > 0:
                prec_ant_gradient_update = n_correct_mt_reasoning / n_ant_gradient_updates
                recall_ant_gradient_update = 0 if n_pos_correct_mt_reasoning == 0 \
                    else n_correct_mt_reasoning / n_pos_correct_mt_reasoning
                prec_ant_update = n_correct_mt_updates / n_ant_gradient_updates
                f1_ant_gradient_update = 0 if recall_ant_gradient_update == 0 or prec_ant_gradient_update == 0 \
                    else 2 / (1 / recall_ant_gradient_update + 1 / prec_ant_gradient_update)

            avg_ant_grad_magn = tot_mt_grad_magn / tot_evals
            ratio_magn_correct_mt_reason = 0 if tot_mt_grad_magn == 0 \
                else correct_mt_reason_magn / tot_mt_grad_magn
            # Average gradient of possible correct inferences (should be high)
            recall_magn_correct_mt_reason = 0 if n_pos_correct_mt_reasoning == 0 \
                else correct_mt_reason_magn / n_pos_correct_mt_reasoning
            ratio_magn_correct_mt_upd = 0 if tot_mt_grad_magn == 0 \
                else correct_mt_upd_magn / tot_mt_grad_magn

            tot_grad_magn = tot_mp_grad_magn + tot_mt_grad_magn
            fract_conseq_gradient = 0 if tot_grad_magn == 0 else \
                tot_mp_grad_magn / tot_grad_magn
            fract_ant_gradient = 0 if tot_grad_magn == 0 else \
                tot_mt_grad_magn / tot_grad_magn

            summary_t = tf.Summary(value=[
                tf.Summary.Value(tag="test/auc_pof", simple_value=auc_pof),
                tf.Summary.Value(tag="test/prec_types",
                                 simple_value=prec_types),
            ])

            summary_mp = tf.Summary(value=[
                tf.Summary.Value(tag="gradients/prec_reasoning",
                                 simple_value=prec_conseq_gradient_update),
                tf.Summary.Value(tag="gradients/recall_reasoning",
                                 simple_value=recall_conseq_gradient_update),
                tf.Summary.Value(tag="gradients/f1_reasoning",
                                 simple_value=f1_conseq_gradient_update),
                tf.Summary.Value(tag="gradients/num_update",
                                 simple_value=n_conseq_gradient_updates),
                tf.Summary.Value(tag="gradients/prec_update",
                                 simple_value=prec_conseq_update),
                tf.Summary.Value(tag="gradients/avg_magnitude_update",
                                 simple_value=avg_conseq_grad_magn),
                tf.Summary.Value(tag="gradients/ratio_of_gradient",
                                 simple_value=fract_conseq_gradient),
                tf.Summary.Value(tag="gradients/ratio_magn_correct_reason",
                                 simple_value=ratio_magn_correct_mp_reason),
                tf.Summary.Value(tag="gradients/avg_magn_correct_reason",
                                 simple_value=recall_magn_correct_mp_reason),
                tf.Summary.Value(tag="gradients/ratio_magn_correct_update",
                                 simple_value=ratio_magn_correct_mp_upd)
            ])

            summary_mt = tf.Summary(value=[
                tf.Summary.Value(tag="gradients/prec_reasoning",
                                 simple_value=prec_ant_gradient_update),
                tf.Summary.Value(tag="gradients/recall_reasoning",
                                 simple_value=recall_ant_gradient_update),
                tf.Summary.Value(tag="gradients/f1_reasoning",
                                 simple_value=f1_ant_gradient_update),
                tf.Summary.Value(tag="gradients/num_update",
                                 simple_value=n_ant_gradient_updates),
                tf.Summary.Value(tag="gradients/prec_update",
                                 simple_value=prec_ant_update),
                tf.Summary.Value(tag="gradients/avg_magnitude_update",
                                 simple_value=avg_ant_grad_magn),
                tf.Summary.Value(tag="gradients/ratio_of_gradient",
                                 simple_value=fract_ant_gradient),
                tf.Summary.Value(tag="gradients/ratio_magn_correct_reason",
                                 simple_value=ratio_magn_correct_mt_reason),
                tf.Summary.Value(tag="gradients/avg_magn_correct_reason",
                                 simple_value=recall_magn_correct_mt_reason),
                tf.Summary.Value(tag="gradients/ratio_magn_correct_update",
                                 simple_value=ratio_magn_correct_mt_upd)
            ])

            test_writer.add_summary(summary_r, i)
            test_writer.add_summary(summary_t, i)

            modus_ponens_writer.add_summary(summary_mp, i)
            modus_tollens_writer.add_summary(summary_mt, i)

    train_writer.flush()
    test_writer.flush()
    modus_ponens_writer.flush()
    modus_tollens_writer.flush()

    return feed_dict, auc_pof, prec_types
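For reference, the inline else-branch above (the one computing nP_t, Q_t and truth_val) evaluates each clause with the probabilistic sum (product t-conorm) and then normalizes the partial derivatives by the clause's truth value. With negated antecedent literals $p_1, p_2$ and consequent literals $q_k$ (a reading inferred from the code, not stated in it):

\begin{align*}
  \mathrm{nP}_t &= p_1 + p_2 - p_1 p_2, \qquad
  Q_t = 1 - \prod_k (1 - q_k), \\
  \mathrm{truth} &= \mathrm{nP}_t + Q_t - \mathrm{nP}_t Q_t, \\
  \mathtt{grad\_mp} &= \frac{1}{\mathrm{truth}} \frac{\partial\,\mathrm{truth}}{\partial Q_t} = \frac{1 - \mathrm{nP}_t}{\mathrm{truth}}, \qquad
  \mathtt{grad\_mt} = \frac{1}{\mathrm{truth}} \frac{\partial\,\mathrm{truth}}{\partial\,\mathrm{nP}_t} = \frac{1 - Q_t}{\mathrm{truth}}.
\end{align*}

A large grad_mp therefore means most of the gradient flows to the consequent (a modus-ponens-style update); a large grad_mt means it flows to the negated antecedent (modus tollens).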
Example #7
history = []
# Accumulators for the validation pass (this fragment assumes they start at zero).
valid_loss = valid_auc = valid_macro_f1 = valid_micro_f1 = 0.0
count_valid = 0

for valid_inputs, labels in val_DataLoader:
    valid_inputs = valid_inputs.transpose(1,
                                          3).transpose(2,
                                                       3).to(device).float()
    labels = labels.to(device).float()
    outputs = model(valid_inputs)
    loss = criterion(outputs, labels)
    valid_loss += loss.item() * valid_inputs.size(0)

    res = torch.gt(outputs, 0.5)

    # Index every metric by the same per-sample index i over the batch dimension.
    valid_auc += np.sum(
        np.array([
            evaluate.auc(labels[i].cpu(), res[i].cpu())
            for i in range(res.size(0))
        ]))
    valid_macro_f1 += np.sum(
        np.array([
            evaluate.macro_f1(labels[i].cpu(), res[i].cpu())
            for i in range(res.size(0))
        ]))
    valid_micro_f1 += np.sum(
        np.array([
            evaluate.micro_f1(labels[i].cpu(), res[i].cpu())
            for i in range(res.size(0))
        ]))
    count_valid += labels.size(0)
print("valid ---- epoch == > {}, auc ==> {},  macro_f1==> {}, micro_f1 ==> {}".
      format(0, valid_auc / count_valid, valid_macro_f1 / count_valid,
Example #8
    train_true = []
    train_pred = []
    for key in di.keys():
        print(di_true[key])
        print(di[key])
        print()
        train_pred.append(di[key].tolist())
        train_true.append(di_true[key].tolist())

    train_pred = np.array(train_pred)
    train_true = np.array(train_true)
    print(train_pred.shape)
    print(train_true.shape)

    # Accumulators (valid_auc / valid_macro_f1 / valid_micro_f1) are assumed to be
    # initialized earlier in the surrounding script; this fragment only updates them.
    valid_auc += evaluate.auc(train_true, train_pred)
    valid_macro_f1 += evaluate.macro_f1(train_true, train_pred)
    valid_micro_f1 += evaluate.micro_f1(train_true, train_pred)

    # try:
    #     valid_auc +=  evaluate.auc(labels.cpu(),res.cpu())
    #     valid_macro_f1 += evaluate.macro_f1(labels[0].cpu(),res[0].cpu())
    #     valid_micro_f1 += evaluate.micro_f1(labels[0].cpu(),res[0].cpu())
    #     count_valid += 1
    # except:
    #     print("error ",count_valid)
    #     # if (count_valid> 32):
    #     #     break
    print(
        "valid ---- epoch == > {},loss == {}, auc ==> {},  macro_f1==> {}, micro_f1 ==> {}"
        .format(epoch, valid_loss, valid_auc, valid_macro_f1, valid_micro_f1))
Example #9
def train_and_valid(model, loss_function, optimizer, epochs, train_data, valid_data, rep):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device :", device)
    history = []
    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1,epochs))
        model.train()
        train_loss = 0.0
        train_auc = 0.0
        train_macro_f1 = 0.0
        train_micro_f1 = 0.0

        valid_loss = 0.0
        valid_auc = 0.0
        valid_macro_f1 = 0.0
        valid_micro_f1 = 0.0

        count_item = 0
        count_valid = 0
        print("-----train mode-----")

        for inputs, labels in train_data:
            # print("count item", count_item)
            inputs = inputs.transpose(1, 3).transpose(2, 3).to(device).float()
            labels = labels.to(device).float()
            # print(inputs.shape)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)

            res = torch.gt(outputs, 0.5)

            train_auc += np.sum(np.array([evaluate.auc(labels[i].cpu(), res[i].cpu()) for i in range(res.size(0))]))
            train_macro_f1 += np.sum(np.array([evaluate.macro_f1(labels[i].cpu(), res[i].cpu()) for i in range(res.size(0))]))
            train_micro_f1 += np.sum(np.array([evaluate.micro_f1(labels[i].cpu(), res[i].cpu()) for i in range(res.size(0))]))
            count_item += labels.size(0)
            # if (count_item > 128):
            #     break
        print("train ---- epoch == > {},loss==>{}, auc ==> {},  macro_f1==> {}, micro_f1 ==> {}".format(epoch, train_loss/count_item, train_auc/count_item, train_macro_f1/count_item, train_micro_f1/count_item))
        epoch_end = time.time()
        print("time : ", epoch_end - epoch_start)
        print("-----valid mode-----")
        model.eval()
        with torch.no_grad():  # the validation loop itself must run without gradient tracking
            for valid_inputs, labels in valid_data:
                valid_inputs = valid_inputs.transpose(1, 3).transpose(2, 3).to(device).float()
                labels = labels.to(device).float()
                outputs = model(valid_inputs)
                loss = loss_function(outputs, labels)
                valid_loss += loss.item() * valid_inputs.size(0)

                res = torch.gt(outputs, 0.5)

                valid_auc += np.sum(np.array([evaluate.auc(labels[i].cpu(), res[i].cpu()) for i in range(res.size(0))]))
                valid_macro_f1 += np.sum(np.array([evaluate.macro_f1(labels[i].cpu(), res[i].cpu()) for i in range(res.size(0))]))
                valid_micro_f1 += np.sum(np.array([evaluate.micro_f1(labels[i].cpu(), res[i].cpu()) for i in range(res.size(0))]))
                count_valid += labels.size(0)
                # if (count_valid > 128):
                #     break
        print("valid ---- epoch == > {},loss == {}, auc ==> {},  macro_f1==> {}, micro_f1 ==> {}".format(epoch, valid_loss/count_valid, valid_auc/count_valid, valid_macro_f1/count_valid, valid_micro_f1/count_valid))
        epochs_end = time.time()
        print("time : ", epochs_end - epoch_start)
        history.append([train_auc/count_item, train_macro_f1/count_item, train_micro_f1/count_item,
                        valid_auc/count_valid, valid_macro_f1/count_valid, valid_micro_f1/count_valid])
        epoch_begin = 8 + rep
        torch.save(model, 'not_freeze_models/' + 'resnet' + str(epoch + 3 + epoch_begin) + '.pt')
    return model
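A minimal usage sketch for train_and_valid, assuming a multi-label setup with sigmoid outputs thresholded at 0.5 (the random stand-in data, the 14-class head, and all hyperparameters below are placeholders, not from the original source):

# Hypothetical wiring; data loaders use tiny random stand-in tensors.
import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset

# NHWC images and multi-hot labels, matching the transpose(1, 3).transpose(2, 3)
# calls inside the training and validation loops.
images = torch.rand(16, 64, 64, 3)
labels = (torch.rand(16, 14) > 0.5).float()
train_loader = DataLoader(TensorDataset(images, labels), batch_size=4)
valid_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

model = torch.nn.Sequential(
    torchvision.models.resnet18(num_classes=14),  # 14 is an arbitrary label count
    torch.nn.Sigmoid(),  # the loops threshold outputs at 0.5, so end with a sigmoid
)
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Note: the function saves checkpoints under not_freeze_models/, which must exist.
model = train_and_valid(model, loss_function, optimizer, epochs=1,
                        train_data=train_loader, valid_data=valid_loader, rep=0)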