示例#1
0
import matplotlib.pyplot as plt

from conv_qsar_fast.main.core import build_model, train_model, save_model
from conv_qsar_fast.main.test import test_model, test_embeddings_demo, rocauc_plot, parity_plot
from conv_qsar_fast.main.data import get_data_full
from conv_qsar_fast.utils.parse_cfg import read_config
import conv_qsar_fast.utils.reset_layers as reset_layers

if __name__ == '__main__':
    if len(sys.argv) < 2:
        print('Usage: {} "settings.cfg"'.format(sys.argv[0]))
        quit(1)

    # Load settings
    try:
        config = read_config(sys.argv[1])
    except:
        print('Could not read config file {}'.format(sys.argv[1]))
        quit(1)

    # Get model label
    try:
        fpath = config['IO']['model_fpath']
    except KeyError:
        print('Must specify model_fpath in IO in config')
        quit(1)

    ###################################################################################
    ### BUILD MODEL
    ###################################################################################
示例#2
0
def f():
    # 	if len(sys.argv) < 2:
    # 		print('Usage: {} "settings.cfg"'.format(sys.argv[0]))
    # 		quit(1)

    # Load settings
    try:
        # 		config = read_config(sys.argv[1])
        config = read_config('../inputs/Ab-oct.cfg')
    except:
        print('Could not read config file {}'.format(sys.argv[1]))
        quit(1)

    # Get model label
    try:
        fpath = config['IO']['model_fpath']
    except KeyError:
        print('Must specify model_fpath in IO in config')
        quit(1)

    ###################################################################################
    ### DEFINE DATA
    ###################################################################################

    data_kwargs = config['DATA']
    if '__name__' in data_kwargs:
        del data_kwargs['__name__']  #  from configparser
    if 'batch_size' in config['TRAINING']:
        data_kwargs['batch_size'] = int(config['TRAINING']['batch_size'])
    if 'use_fp' in config['ARCHITECTURE']:
        data_kwargs['use_fp'] = config['ARCHITECTURE']['use_fp']
    if 'shuffle_seed' in data_kwargs:
        data_kwargs['shuffle_seed'] = int(data_kwargs['shuffle_seed'])
    else:
        data_kwargs['shuffle_seed'] = int(time.time())
    if 'truncate_to' in data_kwargs:
        data_kwargs['truncate_to'] = int(data_kwargs['truncate_to'])
    if 'training_ratio' in data_kwargs:
        data_kwargs['training_ratio'] = float(data_kwargs['training_ratio'])
    if 'molecular_attributes' in data_kwargs:
        data_kwargs['molecular_attributes'] = input_to_bool(
            data_kwargs['molecular_attributes'])

    if 'cv_folds' in data_kwargs:
        try:
            os.makedirs(os.path.dirname(fpath))
        except:  # folder exists
            pass
        if '<this_fold>' in data_kwargs['cv_folds']:
            cv_folds = data_kwargs['cv_folds']
            total_folds = int(cv_folds.split('/')[1])
            all_cv_folds = [
                '{}/{}'.format(i + 1, total_folds) for i in range(total_folds)
            ]
        else:
            all_cv_folds = [data_kwargs['cv_folds']]

    # Iterate through all folds
    ref_fpath = fpath
    for cv_fold in all_cv_folds:

        ###################################################################################
        ### BUILD MODEL
        ###################################################################################

        print('...building model')
        try:
            kwargs = config['ARCHITECTURE']
            if '__name__' in kwargs:
                del kwargs['__name__']  #  from configparser
            if 'batch_size' in config['TRAINING']:
                kwargs['padding'] = int(config['TRAINING']['batch_size']) > 1
            if 'embedding_size' in kwargs:
                kwargs['embedding_size'] = int(kwargs['embedding_size'])
            if 'hidden' in kwargs:
                kwargs['hidden'] = int(kwargs['hidden'])
            if 'hidden2' in kwargs:
                kwargs['hidden2'] = int(kwargs['hidden2'])
            if 'depth' in kwargs:
                kwargs['depth'] = int(kwargs['depth'])
            if 'scale_output' in kwargs:
                kwargs['scale_output'] = float(kwargs['scale_output'])
            if 'dr1' in kwargs:
                kwargs['dr1'] = float(kwargs['dr1'])
            if 'dr2' in kwargs:
                kwargs['dr2'] = float(kwargs['dr2'])
            if 'output_size' in kwargs:
                kwargs['output_size'] = int(kwargs['output_size'])
            if 'sum_after' in kwargs:
                kwargs['sum_after'] = input_to_bool(kwargs['sum_after'])
            if 'optimizer' in kwargs:
                kwargs['optimizer'] = kwargs['optimizer']

            if 'molecular_attributes' in config['DATA']:
                kwargs['molecular_attributes'] = config['DATA'][
                    'molecular_attributes']

            model = build_model(**kwargs)
            print('...built untrained model')
        except KeyboardInterrupt:
            print('User cancelled model building')
            quit(1)

        print('Using CV fold {}'.format(cv_fold))
        data_kwargs['cv_folds'] = cv_fold
        fpath = ref_fpath.replace('<this_fold>', cv_fold.split('/')[0])
        data = get_data_full(**data_kwargs)

        ###################################################################################
        ### LOAD WEIGHTS?
        ###################################################################################

        if 'weights_fpath' in config['IO']:
            weights_fpath = config['IO']['weights_fpath']
        else:
            weights_fpath = fpath + '.h5'

        try:
            use_old_weights = input_to_bool(
                config['IO']['use_existing_weights'])
        except KeyError:
            print('Must specify whether or not to use existing model weights')
            quit(1)

        if use_old_weights and os.path.isfile(weights_fpath):
            model.load_weights(weights_fpath)
            print('...loaded weight information')

            # Reset final dense?
            if 'reset_final' in config['IO']:
                if config['IO']['reset_final'] in [
                        'true', 'y', 'Yes', 'True', '1'
                ]:
                    layer = model.layers[-1]
                    layer.W.set_value((layer.init(
                        layer.W.shape.eval()).eval()).astype(np.float32))
                    layer.b.set_value(
                        np.zeros(layer.b.shape.eval(), dtype=np.float32))

        elif use_old_weights and not os.path.isfile(weights_fpath):
            print(
                'Weights not found at specified path {}'.format(weights_fpath))
            quit(1)
        else:
            pass

        ###################################################################################
        ### CHECK FOR TESTING CONDITIONS
        ###################################################################################

        # Testing embeddings?
        try:
            if input_to_bool(config['TESTING']['test_embedding']):
                test_embeddings_demo(model, fpath)
                quit(1)
        except KeyError:
            pass

        ###################################################################################
        ### TRAIN THE MODEL
        ###################################################################################

        # Train model
        try:
            print('...training model')
            kwargs = config['TRAINING']
            if '__name__' in kwargs:
                del kwargs['__name__']  #  from configparser
            if 'nb_epoch' in kwargs:
                kwargs['nb_epoch'] = int(kwargs['nb_epoch'])
            if 'batch_size' in kwargs:
                kwargs['batch_size'] = int(kwargs['batch_size'])
            if 'patience' in kwargs:
                kwargs['patience'] = int(kwargs['patience'])
            (model, loss, val_loss) = train_model(model, data, **kwargs)
            print('...trained model')
        except KeyboardInterrupt:
            pass

        ###################################################################################
        ### SAVE MODEL
        ###################################################################################

        # Get the current time
        tstamp = datetime.datetime.utcnow().strftime('%m-%d-%Y_%H-%M')
        print('...saving model')
        save_model(model,
                   loss,
                   val_loss,
                   fpath=fpath,
                   config=config,
                   tstamp=tstamp)
        print('...saved model')

        ###################################################################################
        ### TEST MODEL
        ###################################################################################

        print('...testing model')
        data_withresiduals = test_model(model,
                                        data,
                                        fpath,
                                        tstamp=tstamp,
                                        batch_size=int(
                                            config['TRAINING']['batch_size']))
        print('...tested model')