Beispiel #1
0
        if good_partition is True:
            poi_list = sorted(poi_set)
            break

    # train: fit one SSVM on this fold's training trajectories, restricted to
    # the POIs in poi_list, then score every held-out query index in test_ix.
    ssvm = SSVM(inference_train=inference_method,
                inference_pred=inference_method,
                dat_obj=dat_obj,
                share_params=SSVM_SHARE_PARAMS,
                multi_label=SSVM_MULTI_LABEL,
                C=ssvm_C,
                poi_info=poi_info_i.loc[poi_list].copy())
    # Strict identity check kept from the original: only a literal True counts
    # as a successful training run.
    trained = ssvm.train(sorted(trajid_set_train), n_jobs=N_JOBS) is True
    for j in test_ix:  # test
        y_hat_list = None
        if trained:
            ps_cv, L_cv = keys_cv[j]
            y_hat_list = ssvm.predict(ps_cv, L_cv)
        if y_hat_list is None:
            # Both a failed training run and a failed prediction (predict
            # returned None) count as a zero score, so every held-out query
            # contributes to the averaged metrics. Previously a failed
            # prediction was silently skipped, inflating the means.
            F1, pF1, tau = 0, 0, 0
        else:
            F1, pF1, tau = evaluate(dat_obj, keys_cv[j], y_hat_list)
        F1_ssvm.append(F1)
        pF1_ssvm.append(pF1)
        Tau_ssvm.append(tau)

# Average each per-query evaluation metric (F1, pair-F1, Kendall-style tau
# as collected above) into a single summary value.
mean_F1, mean_pF1, mean_Tau = [
    np.mean(metric_scores)
    for metric_scores in (F1_ssvm, pF1_ssvm, Tau_ssvm)
]
Beispiel #2
0
        'start POI of query %s does not exist in training set.\n' %
        str(keys[i]))
    sys.exit(0)

best_C = bestC[(ps, L)]
print('\n--------------- Query: (%d, %d), Best_C: %f ---------------\n' %
      (ps, L, best_C))

# train model using all examples in training set and measure performance on test set
ssvm = SSVM(inference_train=inference_method,
            inference_pred=inference_method,
            dat_obj=dat_obj,
            share_params=SSVM_SHARE_PARAMS,
            multi_label=SSVM_MULTI_LABEL,
            C=best_C,
            poi_info=poi_info_i)
if ssvm.train(sorted(trajid_set_i), n_jobs=N_JOBS) is True:
    y_hat_list = ssvm.predict(ps, L)
    print(y_hat_list)
    if y_hat_list is not None:
        # Record the prediction together with the learned weight vector and
        # the regularisation constant used, keyed by the (ps, L) query.
        recdict_ssvm[(ps, L)] = {
            'PRED': y_hat_list,
            'W': ssvm.osssvm.w,
            'C': ssvm.C
        }

# The results dict is written unconditionally, whether or not an entry was
# added for this query above.
fssvm = os.path.join(
    data_dir, 'ssvm-' + SSVM_VARIANT + '-' + dat_obj.dat_suffix[dat_ix] +
    '-%d.pkl' % (qix))
# Context manager guarantees the pickle is flushed and the handle closed
# deterministically, instead of relying on garbage collection of an
# anonymous file object.
with open(fssvm, 'wb') as pkl_file:
    pickle.dump(recdict_ssvm, pkl_file)