Esempio n. 1
0
def main(_):
    """Run one experiment, selected by the command-line flags.

    Exactly one mode runs; the flags are checked in this order:

    * ``FLAGS.snake``: train on the Snake edge pickles and print the test
      accuracy returned by ``run_model``.
    * ``FLAGS.das_train``: train on the ABP DAS graphs (with reverse arcs) and
      write checkpoints plus validation scores under ``models_all/C<configid>``.
    * ``FLAGS.das_predict``: rebuild an ``EdgeConvNet``, restore the checkpoint
      of the epoch with the best validation accuracy written by the
      ``das_train`` mode, and pickle its predictions for the col9142 data.
    * ``FLAGS.qsub_taskid > -1``: train/test the ``(fold, config)`` pair found
      at that index of the grid returned by ``_make_grid_qsub(0)``.
    * otherwise: cross-validation, either on folds 1-4 for every config listed
      in ``FLAGS.grid_configs`` (when ``FLAGS.fold == -1``) or on the single
      fold ``FLAGS.fold`` with ``FLAGS.configid``.

    Args:
        _: Ignored positional argument (presumably argv from the app runner).

    Returns:
        -1 if ``FLAGS.qsub_taskid`` does not select a valid grid entry,
        otherwise None.
    """
    if FLAGS.snake is True:
        # Other Snake datasets available in the same directory:
        #   snake_tlXlY_{trn,tst}.pkl, snake_tlXlY_fixed_{trn,tst}.pkl,
        #   snake_tlXlY_2_fixed_{trn,tst}.pkl
        pickle_train = '/home/meunier/Snake/snake_tlXlY_edge_trn.pkl'
        pickle_test = '/home/meunier/Snake/snake_tlXlY_edge_tst.pkl'

        train_graph = GCNDataset.load_snake_pickle(pickle_train)
        test_graph = GCNDataset.load_snake_pickle(pickle_test)

        config = get_config(FLAGS.configid)
        acc_test = run_model(train_graph, config, test_graph)
        print('Accuracy Test', acc_test)

    elif FLAGS.das_train is True:
        # Load the table graphs, then train one model on all of them.
        graph_train = []

        debug = True
        if debug:
            pickle_train = '/nfs/project/read/testJL/TABLE/das_abp_models/abp_full_tlXlY_trn.pkl'
            pickle_train_ra = '/nfs/project/read/testJL/TABLE/abp_DAS_CRF_Xr.pkl'
            print(pickle_train_ra, pickle_train)
            graph_train = GCNDataset.load_transkribus_reverse_arcs_pickle(
                pickle_train, pickle_train_ra, format_reverse='lx')
        else:
            # Merge the train and test split of a single CV fold.
            i = 1
            pickle_train = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
                i) + '_tlXlY_trn.pkl'
            pickle_test = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
                i) + '_tlXlY_tst.pkl'

            # Reversed-edge features for the same fold.
            pickle_train_ra = '/nfs/project/read/testJL/TABLE/das_abp_models/abp_CV_fold_' + str(
                i) + '_tlXrlY_trn.pkl'
            pickle_test_ra = '/nfs/project/read/testJL/TABLE/das_abp_models/abp_CV_fold_' + str(
                i) + '_tlXrlY_tst.pkl'

            train_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
                pickle_train, pickle_train_ra)
            test_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
                pickle_test, pickle_test_ra)

            graph_train.extend(train_graph)
            graph_train.extend(test_graph)

        print('Graph Train Nb', len(graph_train))
        configid = FLAGS.configid
        config = get_config(configid)

        # Checkpoints and validation scores are read back by the das_predict
        # mode, which rebuilds exactly these paths.
        dirp = os.path.join('models_all', 'C' + str(configid))
        mkdir_p(dirp)
        save_model_dir = os.path.join(
            dirp, 'alldas_exp1_C' + str(configid) + '.ckpt')
        outpicklefname = os.path.join(
            dirp,
            'alldas_exp1_C' + str(configid) + '.validation_scores.pickle')
        run_model_train_val_test(graph_train,
                                 config,
                                 outpicklefname,
                                 ratio_train_val=0.1,
                                 save_model_path=save_model_dir)
        # For a test run, add gcn_graph_test=train_graph above.

    elif FLAGS.das_predict is True:

        # Internal switch: also evaluate an older checkpoint on CV folds 1-4.
        do_test = False

        # NOTE(review): hard-coded model dimensions; presumably they must match
        # the checkpoint restored below. TODO derive them from the data:
        #   node_dim = gcn_graph[0].X.shape[1]
        #   edge_dim = gcn_graph[0].E.shape[1] - 2
        #   nb_class = gcn_graph[0].Y.shape[1]
        node_dim = 29
        edge_dim = 140
        nb_class = 5

        configid = FLAGS.configid
        config = get_config(configid)

        # Pick the best epoch from the validation scores saved by das_train.
        val_pickle = os.path.join(
            'models_all', 'C' + str(configid),
            "alldas_exp1_C" + str(configid) + '.validation_scores.pickle')
        print('Reading Training Info from:', val_pickle)
        # Unpickling can run arbitrary code: only load files this pipeline wrote.
        with open(val_pickle, 'rb') as f:
            R = pickle.load(f)
        val = R['val_acc']
        print('Validation scores', val)

        epoch_index = np.argmax(val)
        print('Best performance on val set: Epoch', epoch_index)

        gcn_model = gcn_models.EdgeConvNet(
            node_dim,
            edge_dim,
            nb_class,
            num_layers=config['num_layers'],
            learning_rate=config['lr'],
            mu=config['mu'],
            node_indim=config['node_indim'],
            nconv_edge=config['nconv_edge'],
        )

        gcn_model.stack_instead_add = config['stack_instead_add']

        if 'fast_convolve' in config:
            gcn_model.fast_convolve = config['fast_convolve']

        gcn_model.create_model()

        if do_test:
            graph_train = []
            for i in range(1, 5):
                pickle_train = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
                    i) + '_tlXlY_trn.pkl'
                print('loading ', pickle_train)
                train_graph = GCNDataset.load_transkribus_pickle(pickle_train)
                graph_train.extend(train_graph)

        # Graphs to predict: the workflow or the CRF version of col9142.
        if FLAGS.das_predict_workflow:
            pickle_predict = '/nfs/project/read/testJL/TABLE/abp_DAS_col9142_workflow_X.pkl'
            pickle_predict_ra = '/nfs/project/read/testJL/TABLE/abp_DAS_col9142_workflow_Xr.pkl'
        else:
            pickle_predict = '/nfs/project/read/testJL/TABLE/abp_DAS_col9142_CRF_X.pkl'
            pickle_predict_ra = '/nfs/project/read/testJL/TABLE/abp_DAS_col9142_CRF_Xr.pkl'

        print('loading ', pickle_predict, pickle_predict_ra)
        predict_graph = GCNDataset.load_test_pickle(
            pickle_predict, nb_class, pickle_reverse_arc=pickle_predict_ra)

        with tf.Session() as session:
            session.run(gcn_model.init)

            if do_test:
                gcn_model.restore_model(session, "models/das_exp1_C31.ckpt-99")
                print('Loaded models')

                graphAcc, node_acc = gcn_model.test_lG(session, graph_train)
                print(graphAcc, node_acc)

            # NOTE(review): assumes one checkpoint every 10 epochs, aligned
            # with the entries of val_acc — confirm against
            # run_model_train_val_test.
            model_path = os.path.join(
                'models_all', 'C' + str(configid), "alldas_exp1_C" +
                str(configid) + ".ckpt-" + str(10 * epoch_index))
            print('Model_path', model_path)
            gcn_model.restore_model(session, model_path)
            print('Loaded models')

            start_time = time.time()
            lY_pred = gcn_model.predict_lG(session,
                                           predict_graph,
                                           verbose=False)
            end_time = time.time()
            print("--- %s seconds ---" % (end_time - start_time))
            print('Number of graphs:', len(lY_pred))

            if FLAGS.das_predict_workflow:
                outpicklefname = 'allmodel_das_predict_C' + str(
                    configid) + '_workflow.pickle'
            else:
                outpicklefname = 'allmodel_das_predict_C' + str(
                    configid) + '.pickle'
            # NOTE(review): lY_pred is pickled exactly as returned by
            # predict_lG (its items are not converted to lists); confirm this
            # is the format downstream readers expect.
            # protocol=2 with fix_imports keeps the file loadable from Python 2.
            with open(outpicklefname, 'wb') as g:
                pickle.dump(lY_pred, g, protocol=2, fix_imports=True)

    elif FLAGS.qsub_taskid > -1:

        GRID = _make_grid_qsub(0)

        try:
            fold_id, configid = GRID[FLAGS.qsub_taskid]
        except (IndexError, KeyError, TypeError, ValueError):
            # Task id outside the grid, or an entry that is not a
            # (fold, config) pair.
            print('Invalid Grid Parameters', FLAGS.qsub_taskid, GRID)
            return -1
        print('Experiement with FOLD', fold_id, ' CONFIG', configid)
        pickle_train = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
            fold_id) + '_tlXlY_trn.pkl'
        pickle_test = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
            fold_id) + '_tlXlY_tst.pkl'

        train_graph = GCNDataset.load_transkribus_pickle(pickle_train)
        test_graph = GCNDataset.load_transkribus_pickle(pickle_test)

        config = get_config(configid)

        if not os.path.exists(FLAGS.out_dir):
            print('Creating Dir', FLAGS.out_dir)
            # makedirs also creates missing parents, and exist_ok tolerates the
            # directory appearing after the check (e.g. a parallel task).
            os.makedirs(FLAGS.out_dir, exist_ok=True)

        outpicklefname = os.path.join(
            FLAGS.out_dir,
            'table_F' + str(fold_id) + '_C' + str(configid) + '.pickle')
        run_model_train_val_test(train_graph,
                                 config,
                                 outpicklefname,
                                 ratio_train_val=0.1,
                                 gcn_graph_test=test_graph)

    else:

        if FLAGS.fold == -1:
            # Run every fold for each config listed as "id1_id2_...".
            FOLD_IDS = [1, 2, 3, 4]
            sel_configs = [int(x) for x in FLAGS.grid_configs.split('_')]
            print('GRID on FOLDS', FOLD_IDS)
            print('Model Configs', sel_configs)

            for cid in sel_configs:
                for fid in FOLD_IDS:
                    print('Running Fold', fid, 'on Config', cid)
                    main_fold(fid, cid, FLAGS.out_dir)

        else:

            pickle_train = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
                FLAGS.fold) + '_tlXlY_trn.pkl'
            pickle_test = '/nfs/project/read/testJL/TABLE/abp_quantile_models/abp_CV_fold_' + str(
                FLAGS.fold) + '_tlXlY_tst.pkl'

            # Reversed-edge features for the same fold.
            pickle_train_ra = '/nfs/project/read/testJL/TABLE/das_abp_models/abp_CV_fold_' + str(
                FLAGS.fold) + '_tlXrlY_trn.pkl'
            pickle_test_ra = '/nfs/project/read/testJL/TABLE/das_abp_models/abp_CV_fold_' + str(
                FLAGS.fold) + '_tlXrlY_tst.pkl'

            train_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
                pickle_train, pickle_train_ra)
            print('Loaded Trained Graphs:', len(train_graph))
            test_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
                pickle_test, pickle_test_ra)
            print('Loaded Test Graphs:', len(test_graph))

            config = get_config(FLAGS.configid)

            outpicklefname = os.path.join(
                FLAGS.out_dir, 'table_F' + str(FLAGS.fold) + '_C' +
                str(FLAGS.configid) + '.pickle')
            run_model_train_val_test(train_graph,
                                     config,
                                     outpicklefname,
                                     gcn_graph_test=test_graph)
Esempio n. 2
0
def main(_):
    """Train and test one model configuration on one cross-validation fold.

    The config selected by ``FLAGS.configid`` decides which graphs are loaded:

    * configs with a ``'model'`` key (baselines) use only the plain ``tlXlY``
      fold pickles found in ``FLAGS.dpath``;
    * all other configs also load reverse-arc features. These come either from
      the fold pickles in ``FLAGS.dpath``, or, when
      ``FLAGS.das_predict_workflow`` is True, from the full DAS training set,
      with the col9142 collection as the test set.

    The result pickle is written to ``FLAGS.out_dir`` under
    ``FLAGS.outname``, or under a name built from the fold/config when
    ``FLAGS.outname == 'default'``.

    Args:
        _: Ignored positional argument (presumably argv from the app runner).
    """
    config = get_config(FLAGS.configid)
    print(config)

    mkdir_p(FLAGS.out_dir)

    # Pickles of the requested fold; these alone are enough for the logit
    # baselines.
    fold_stem = 'abp_CV_fold_' + str(FLAGS.fold)
    pickle_train = os.path.join(FLAGS.dpath, fold_stem + '_tlXlY_trn.pkl')
    pickle_test = os.path.join(FLAGS.dpath, fold_stem + '_tlXlY_tst.pkl')

    default_name = ('table_F' + str(FLAGS.fold) + '_C' +
                    str(FLAGS.configid) + '.pickle')
    is_baseline = 'model' in config

    if is_baseline:
        # Baseline models do not need reverse arc features.
        train_graph = GCNDataset.load_transkribus_pickle(pickle_train)
        test_graph = GCNDataset.load_transkribus_pickle(pickle_test)
        print('Loaded Test Graphs:', len(test_graph))
    elif FLAGS.das_predict_workflow is True:
        print('Doing Experiment on Predict Workflow ....')
        full_train = '/nfs/project/read/testJL/TABLE/das_abp_models/abp_full_tlXlY_trn.pkl'
        full_train_ra = '/nfs/project/read/testJL/TABLE/abp_DAS_CRF_Xr.pkl'
        print(full_train_ra, full_train)
        train_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
            full_train, full_train_ra, format_reverse='lx')

        # col9142 features (X), reverse-arc features (Xr) and ground truth
        # (Y). These are relative paths, resolved from the working directory.
        das_dir = '../../usecases/ABP/resources/DAS_2018/'
        test_graph = GCNDataset.load_transkribus_list_X_Xr_Y(
            das_dir + 'abp_DAS_col9142_CRF_X.pkl',
            das_dir + 'abp_DAS_col9142_CRF_Xr.pkl',
            das_dir + 'DAS_col9142_l_Y_GT.pkl')

        default_name = 'col9142_C' + str(FLAGS.configid) + '.pickle'
    else:
        # Same fold, plus the reversed-edge feature pickles.
        train_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
            pickle_train,
            os.path.join(FLAGS.dpath, fold_stem + '_tlXrlY_trn.pkl'),
            attach_edge_label=True)
        test_graph = GCNDataset.load_transkribus_reverse_arcs_pickle(
            pickle_test,
            os.path.join(FLAGS.dpath, fold_stem + '_tlXrlY_tst.pkl'))

    chosen_name = default_name if FLAGS.outname == 'default' else FLAGS.outname
    outpicklefname = os.path.join(FLAGS.out_dir, chosen_name)

    if not is_baseline:
        print('Loaded Trained Graphs:', len(train_graph))
        print('Loaded Test Graphs:', len(test_graph))

    run_model_train_val_test(train_graph,
                             config,
                             outpicklefname,
                             gcn_graph_test=test_graph)