Example No. 1
def emit_cephconf():
    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': config('auth-supported'),
        'mon_hosts': ' '.join(get_mon_hosts()),
        'fsid': leader_get('fsid'),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'osd_journal_size': config('osd-journal-size'),
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
    }

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address

    # Install ceph.conf as an alternative to support
    # co-existence with other charms that write this file
    charm_ceph_conf = "/var/lib/charm/{}/ceph.conf".format(service_name())
    mkdir(os.path.dirname(charm_ceph_conf),
          owner=ceph.ceph_user(),
          group=ceph.ceph_user())
    render('ceph.conf', charm_ceph_conf, cephcontext, perms=0o644)
    install_alternative('ceph.conf', '/etc/ceph/ceph.conf', charm_ceph_conf,
                        100)
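
The charm examples in this listing all call get_networks() to turn a space-separated config option into the list of networks the unit can actually reach. A minimal sketch of such a helper, assuming charm-helpers' config() and get_address_in_network(); the charm's real implementation may differ:

from charmhelpers.core.hookenv import config
from charmhelpers.contrib.network.ip import get_address_in_network


def get_networks(config_opt='ceph-public-network'):
    """Return the configured networks this unit has an address in.

    The option holds zero or more space-separated CIDRs,
    e.g. "10.0.0.0/24 2001:db8::/64".
    """
    networks = config(config_opt)
    if networks:
        # Keep only networks this unit has an address in, so an empty
        # result lets callers fall back to public_addr/cluster_addr.
        return [n for n in networks.split()
                if get_address_in_network(n)]
    return []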
Example No. 2
def emit_cephconf():
    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': config('auth-supported'),
        'mon_hosts': ' '.join(get_mon_hosts()),
        'fsid': leader_get('fsid'),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'osd_journal_size': config('osd-journal-size'),
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
    }

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address

    # Install ceph.conf as an alternative to support
    # co-existence with other charms that write this file
    charm_ceph_conf = "/var/lib/charm/{}/ceph.conf".format(service_name())
    mkdir(os.path.dirname(charm_ceph_conf), owner=ceph.ceph_user(),
          group=ceph.ceph_user())
    render('ceph.conf', charm_ceph_conf, cephcontext, perms=0o644)
    install_alternative('ceph.conf', '/etc/ceph/ceph.conf',
                        charm_ceph_conf, 100)
Example No. 3
def get_ceph_context():
    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': config('auth-supported'),
        'mon_hosts': ' '.join(get_mon_hosts()),
        'fsid': leader_get('fsid'),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
        'dio': str(config('use-direct-io')).lower(),
    }

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address
    else:
        cephcontext['public_addr'] = get_public_addr()
        cephcontext['cluster_addr'] = get_cluster_addr()

    # NOTE(dosaboy): these sections must correspond to what is supported in the
    #                config template.
    sections = ['global', 'mds', 'mon']
    cephcontext.update(CephConfContext(permitted_sections=sections)())
    return cephcontext
Example No. 4
def get_ceph_context(upgrading=False):
    """Returns the current context dictionary for generating ceph.conf

    :param upgrading: bool - determines if the context is invoked as
                      part of an upgrade procedure. Setting this to True
                      causes settings useful during an upgrade to be
                      defined in the ceph.conf file.
    """
    mon_hosts = get_mon_hosts()
    log('Monitor hosts are ' + repr(mon_hosts))

    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': get_auth(),
        'mon_hosts': ' '.join(mon_hosts),
        'fsid': get_fsid(),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'osd_journal_size': config('osd-journal-size'),
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
        'dio': str(config('use-direct-io')).lower(),
        'short_object_len': use_short_objects(),
        'upgrade_in_progress': upgrading,
        'bluestore': config('bluestore'),
    }

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address
    else:
        cephcontext['public_addr'] = get_public_addr()
        cephcontext['cluster_addr'] = get_cluster_addr()

    if config('customize-failure-domain'):
        az = az_info()
        if az:
            cephcontext['crush_location'] = "root=default {} host={}" \
                .format(az, socket.gethostname())
        else:
            log(
                "Your Juju environment doesn't"
                "have support for Availability Zones"
            )

    # NOTE(dosaboy): these sections must correspond to what is supported in the
    #                config template.
    sections = ['global', 'osd']
    cephcontext.update(CephConfContext(permitted_sections=sections)())
    return cephcontext
Example No. 5
    def train(self, model, checkpoint='', is_cuda=True, is_multi_gpu=True, logdir='', savedir=''):
        if model not in self.IMPLEMENTED_MODELS:
            raise NotImplementedError('%s model is not implemented!' % model)

        mode = 'train'
        logger = utils.get_logger(mode)

        # initialize hyperparameters
        hp.set_hparam_yaml(mode)
        logger.info('Setup mode as %s, model : %s' % (mode, model))

        # get network
        network = utils.get_networks(model, checkpoint, is_cuda, is_multi_gpu)

        # Setup dataset
        train_dataloader = DataLoader(Train1Dataset(mode='train'), batch_size=hp.train.batch_size,
                                      shuffle=(mode == 'train'), num_workers=hp.num_workers, drop_last=False)
        test_dataloader = DataLoader(Train1Dataset(mode='test'), batch_size=hp.test.batch_size,
                                     shuffle=(mode == 'test'), num_workers=hp.num_workers, drop_last=False)

        # setup optimizer:
        parameters = network.parameters()
        logger.info(network)
        # TODO: Scheduled LR
        lr = getattr(hp, mode).lr
        optimizer = optim.Adam([p for p in parameters if p.requires_grad], lr=lr)


        # pass model, loss, optimizer and dataset to the trainer
        # get trainer
        trainer = utils.get_trainer()(network, optimizer, train_dataloader, test_dataloader, is_cuda, logdir, savedir)

        # train!
        trainer.run(hp.train.num_epochs)
Example No. 6
    def validate_devices_in_networks(self):
        """
        Get deviceId's of every network using both catalogs and compare results
        Network -> Devices
        """
        networks = get_networks(dmd)
        layer3_catalog = dmd.ZenLinkManager.layer3_catalog
        model_catalog = IModelCatalogTool(dmd)
        failed_networks = {}
        for network in networks:
            # Devices under device class in global catalog
            query = Eq('networkId', network)
            layer3_brains = layer3_catalog.evalAdvancedQuery(query)
            layer3_device_ids = set(brain.deviceId for brain in layer3_brains
                                    if brain.deviceId)

            model_catalog_brains = model_catalog.search(query=query)
            model_catalog_device_ids = set(
                brain.deviceId.split("/")[-1]
                for brain in model_catalog_brains.results if brain.deviceId)

            if layer3_device_ids != model_catalog_device_ids:
                failed_networks[network] = (layer3_device_ids,
                                            model_catalog_device_ids)
        if failed_networks:
            print "TEST FAILED: Catalogs return different devices for the following networks:"
            print "\t\t{0}".format(failed_networks.keys())
        else:
            print "TEST PASSED: Both catalogs returned the same devices for all networks."

        return len(failed_networks) == 0
Example No. 7
def conn(ssid: str = '', password: str = ''):
    """
    Connect to a wifi network.
    """
    wifi_port = get_wifi_port()
    if not ssid:
        headers, networks = get_networks()
        typer.echo(tabulate(networks, headers, tablefmt="grid"))

        input_ssid = int(typer.prompt("SSID"))
        if input_ssid not in range(1, len(networks) + 1):
            typer.echo('Invalid option.')
            sys.exit()

        ssid = networks[input_ssid - 1][1]

    if not password:
        password = typer.prompt("Password", hide_input=True)

    with click_spinner.spinner():
        typer.echo(f"Connecting to {ssid}.")
        output = run_cmd(
            CONNECT_TO_NETWORK.format(port=wifi_port,
                                      ssid=ssid,
                                      password=password))
        if not output:
            typer.echo("Connected to wifi.")
        else:
            typer.echo("Wrong ssid or password.")
Example No. 8
def emit_cephconf():
    mon_hosts = get_mon_hosts()
    log('Monitor hosts are ' + repr(mon_hosts))

    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': get_auth(),
        'mon_hosts': ' '.join(mon_hosts),
        'fsid': get_fsid(),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'osd_journal_size': config('osd-journal-size'),
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
        'dio': str(config('use-direct-io')).lower(),
    }

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address

    # Install ceph.conf as an alternative to support
    # co-existence with other charms that write this file
    charm_ceph_conf = "/var/lib/charm/{}/ceph.conf".format(service_name())
    mkdir(os.path.dirname(charm_ceph_conf), owner=ceph.ceph_user(),
          group=ceph.ceph_user())
    with open(charm_ceph_conf, 'w') as cephconf:
        cephconf.write(render_template('ceph.conf', cephcontext))
    install_alternative('ceph.conf', '/etc/ceph/ceph.conf',
                        charm_ceph_conf, 90)
Example No. 9
def create_members(num_members, tf_env=None):
    """ Can go inside Population class with more parametrization options depending on use-case. """
    if tf_env is None:
        tf_env = get_tf_env()
    members = list()
    for i in range(num_members):
        actor_net, value_net = get_networks(tf_env, FP.ACTOR_FC_LAYERS,
                                            FP.VALUE_FC_LAYERS)
        agent = get_tf_ppo_agent(tf_env,
                                 actor_net,
                                 value_net,
                                 member_id=i,
                                 num_epochs=FP.PPO_NUM_EPOCHS)
        replay_buffer = get_replay_buffer(agent.collect_data_spec)
        step_metrics, train_metrics = get_metrics()
        members.append(
            member.Member(agent,
                          replay_buffer,
                          step_metrics=step_metrics,
                          train_metrics=train_metrics))
    return members
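
In this example get_networks() builds the actor and value networks a PPO agent needs. One plausible implementation using TF-Agents; the fc-layer tuples come from the FP constants above, and the real helper may differ in detail:

from tf_agents.networks import actor_distribution_network, value_network


def get_networks(tf_env, actor_fc_layers, value_fc_layers):
    """Build actor/value networks sized to the environment's specs."""
    observation_spec = tf_env.observation_spec()
    action_spec = tf_env.action_spec()
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec, action_spec, fc_layer_params=actor_fc_layers)
    value_net = value_network.ValueNetwork(
        observation_spec, fc_layer_params=value_fc_layers)
    return actor_net, value_net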
Example No. 10
def network_rm(c, name=""):
    networks = utils.get_networks(env.networks, name)
    for nk, nv in networks.iteritems():
        cmd = "docker network rm {}".format(nk)
        utils.run_with_exit(c, cmd)
Example No. 11
def network_create(c, name=""):
    networks = utils.get_networks(env.networks, name)
    for nk, nv in networks.iteritems():
        opts = " ".join(nv["opts"])
        cmd = "docker network create {} {}".format(opts, nk)
        utils.run_with_exit(c, cmd)
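
Both docker tasks above treat utils.get_networks(env.networks, name) as a filter over a dict of network definitions: everything when name is empty, a single entry otherwise. A plausible sketch (the real helper may differ; the .iteritems() calls above mark this codebase as Python 2):

def get_networks(networks, name=""):
    """Return all network definitions, or just the one matching name."""
    if not name:
        return networks
    return {k: v for k, v in networks.items() if k == name}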
Example No. 12
def get_ceph_context(upgrading=False):
    """Returns the current context dictionary for generating ceph.conf

    :param upgrading: bool - determines if the context is invoked as
                      part of an upgrade procedure. Setting this to True
                      causes settings useful during an upgrade to be
                      defined in the ceph.conf file.
    """
    mon_hosts = get_mon_hosts()
    log('Monitor hosts are ' + repr(mon_hosts))

    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': get_auth(),
        'mon_hosts': ' '.join(mon_hosts),
        'fsid': get_fsid(),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'crush_initial_weight': config('crush-initial-weight'),
        'osd_journal_size': config('osd-journal-size'),
        'osd_max_backfills': config('osd-max-backfills'),
        'osd_recovery_max_active': config('osd-recovery-max-active'),
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
        'dio': str(config('use-direct-io')).lower(),
        'short_object_len': use_short_objects(),
        'upgrade_in_progress': upgrading,
        'bluestore': use_bluestore(),
        'bluestore_experimental': cmp_pkgrevno('ceph', '12.1.0') < 0,
        'bluestore_block_wal_size': config('bluestore-block-wal-size'),
        'bluestore_block_db_size': config('bluestore-block-db-size'),
    }

    try:
        cephcontext['bdev_discard'] = get_bdev_enable_discard()
    except ValueError as ex:
        # The user set bdev-enable-discard to an invalid value; log the
        # issue as a warning and fall back to False (disabled).
        log(str(ex), level=WARNING)
        cephcontext['bdev_discard'] = False

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address
    else:
        cephcontext['public_addr'] = get_public_addr()
        cephcontext['cluster_addr'] = get_cluster_addr()

    if config('customize-failure-domain'):
        az = az_info()
        if az:
            cephcontext['crush_location'] = "root=default {} host={}" \
                .format(az, socket.gethostname())
        else:
            log("Your Juju environment doesn't"
                "have support for Availability Zones")

    # NOTE(dosaboy): these sections must correspond to what is supported in the
    #                config template.
    sections = ['global', 'osd']
    cephcontext.update(
        ch_ceph.CephOSDConfContext(permitted_sections=sections)())
    cephcontext.update(ch_context.CephBlueStoreCompressionContext()())
    return cephcontext
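
Example No. 12 delegates the discard decision to get_bdev_enable_discard(), which raises ValueError on bad input, while Example No. 13 below inlines the same enabled/auto/disabled logic. A sketch of the helper inferred from the two examples; the charm's real implementation may differ:

def get_bdev_enable_discard():
    """Map the bdev-enable-discard option to a boolean.

    Raises ValueError for unrecognised values so the caller can log
    a warning and fall back to False.
    """
    bdev_enable_discard = config('bdev-enable-discard').lower()
    if bdev_enable_discard == 'enabled':
        return True
    elif bdev_enable_discard == 'auto':
        # Probe the configured devices to decide per deployment.
        return should_enable_discard(get_devices())
    elif bdev_enable_discard == 'disabled':
        return False
    raise ValueError("Invalid value for bdev-enable-discard: {}".format(
        bdev_enable_discard))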
Example No. 13
def get_ceph_context(upgrading=False):
    """Returns the current context dictionary for generating ceph.conf

    :param upgrading: bool - determines if the context is invoked as
                      part of an upgrade procedure. Setting this to True
                      causes settings useful during an upgrade to be
                      defined in the ceph.conf file.
    """
    mon_hosts = get_mon_hosts()
    log('Monitor hosts are ' + repr(mon_hosts))

    networks = get_networks('ceph-public-network')
    public_network = ', '.join(networks)

    networks = get_networks('ceph-cluster-network')
    cluster_network = ', '.join(networks)

    cephcontext = {
        'auth_supported': get_auth(),
        'mon_hosts': ' '.join(mon_hosts),
        'fsid': get_fsid(),
        'old_auth': cmp_pkgrevno('ceph', "0.51") < 0,
        'crush_initial_weight': config('crush-initial-weight'),
        'osd_journal_size': config('osd-journal-size'),
        'osd_max_backfills': config('osd-max-backfills'),
        'osd_recovery_max_active': config('osd-recovery-max-active'),
        'use_syslog': str(config('use-syslog')).lower(),
        'ceph_public_network': public_network,
        'ceph_cluster_network': cluster_network,
        'loglevel': config('loglevel'),
        'dio': str(config('use-direct-io')).lower(),
        'short_object_len': use_short_objects(),
        'upgrade_in_progress': upgrading,
        'bluestore': use_bluestore(),
        'bluestore_experimental': cmp_pkgrevno('ceph', '12.1.0') < 0,
        'bluestore_block_wal_size': config('bluestore-block-wal-size'),
        'bluestore_block_db_size': config('bluestore-block-db-size'),
    }

    if config('bdev-enable-discard').lower() == 'enabled':
        cephcontext['bdev_discard'] = True
    elif config('bdev-enable-discard').lower() == 'auto':
        cephcontext['bdev_discard'] = should_enable_discard(get_devices())
    else:
        cephcontext['bdev_discard'] = False

    if config('prefer-ipv6'):
        dynamic_ipv6_address = get_ipv6_addr()[0]
        if not public_network:
            cephcontext['public_addr'] = dynamic_ipv6_address
        if not cluster_network:
            cephcontext['cluster_addr'] = dynamic_ipv6_address
    else:
        cephcontext['public_addr'] = get_public_addr()
        cephcontext['cluster_addr'] = get_cluster_addr()

    if config('customize-failure-domain'):
        az = az_info()
        if az:
            cephcontext['crush_location'] = "root=default {} host={}" \
                .format(az, socket.gethostname())
        else:
            log(
                "Your Juju environment doesn't"
                "have support for Availability Zones"
            )

    # NOTE(dosaboy): these sections must correspond to what is supported in the
    #                config template.
    sections = ['global', 'osd']
    cephcontext.update(CephConfContext(permitted_sections=sections)())
    return cephcontext
Example No. 14
#Subset of channels
channels = ["Fp1", "Fp2", "F7", "F3", "Fz", "F4", "F8", "T7", "C3", "Cz", "C4", "T8", "P7", "P3", "Pz", "P4", "P8", "O1", "O2"]

# Reading the subset of channels from files
small_eo = ut.read_file("../data/S072R01.edf", channels = channels)
small_ec = ut.read_file("../data/S072R02.edf", channels = channels)

# Adjacency Matrices with Bootstrap validation
ut.adjacency_matrix(ut.fit_model(small_eo, fs, resolution, "pdc", freq = freq, boot = True), 0.05, "small_eo_pdc")
ut.adjacency_matrix(ut.fit_model(small_ec, fs, resolution, "pdc", freq = freq, boot = True), 0.05, "small_ec_pdc")


######## 1.5

# Save a png of the network for each network
for network in [elem[:-4] for elem in ut.get_networks()]:
    ut.viz_graph(network)


######## 1.6

# Choosing an alternative frequency
alternative_frequency = 50

###PDC

# Fitting models
alt_eo_pdc = ut.fit_model(eo, fs, resolution, "pdc", alternative_frequency)
alt_ec_pdc = ut.fit_model(ec, fs, resolution, "pdc", alternative_frequency)

# Adjacency Matrices
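
The loop in section 1.5 above strips a 4-character extension from each name that ut.get_networks() returns, suggesting it simply lists saved network files. A hypothetical sketch; the directory and extension are assumptions:

import os

def get_networks(directory="networks", extension=".csv"):
    """List saved network files so callers can strip the extension."""
    return [f for f in os.listdir(directory) if f.endswith(extension)]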
Example No. 15
def main():
    parser = argparse.ArgumentParser(
        description='Graph CNNs for population graphs: '
        'classification of the ABIDE dataset')
    parser.add_argument(
        '--dropout',
        default=0.3,
        type=float,
        help='Dropout rate (1 - keep probability) (default: 0.3)')
    parser.add_argument(
        '--decay',
        default=5e-4,
        type=float,
        help='Weight for L2 loss on embedding matrix (default: 5e-4)')
    parser.add_argument(
        '--hidden1',
        default=32,
        type=int,
        help='Number of filters in hidden layers (default: 32)')
    # parser.add_argument('--lrate', default=0.005, type=float, help='Initial learning rate (default: 0.005)')
    parser.add_argument('--lrate',
                        default=1e-2,
                        type=float,
                        help='Initial learning rate (default: 0.01)')
    # parser.add_argument('--atlas', default='ho', help='atlas for network construction (node definition) (default: ho, '
    #                                                   'see preprocessed-connectomes-project.org/abide/Pipelines.html '
    #                                                   'for more options )')
    parser.add_argument('--epochs',
                        default=100,
                        type=int,
                        help='Number of epochs to train')
    parser.add_argument('--num_features',
                        default=2000,
                        type=int,
                        help='Number of features to keep for '
                        'the feature selection step (default: 2000)')
    parser.add_argument('--num_training',
                        default=1.0,
                        type=float,
                        help='Percentage of training set used for '
                        'training (default: 1.0)')
    parser.add_argument('--depth',
                        default=0,
                        type=int,
                        help='Number of additional hidden layers in the GCN. '
                        'Total number of hidden layers: 1+depth (default: 0)')
    parser.add_argument('--model',
                        default='gcn_cheby',
                        help='gcn model used (default: gcn_cheby, '
                        'uses chebyshev polynomials, '
                        'options: gcn, gcn_cheby, dense )')
    # parser.add_argument('--seed', default=89, type=int, help='Seed for random initialisation (default: 123)')
    parser.add_argument(
        '--folds',
        default=11,
        type=int,
        help='For cross validation, specifies which fold will be '
        'used. All folds are used if set to 11 (default: 11)')
    parser.add_argument(
        '--save',
        default=200,
        type=int,
        help='Parameter that specifies if results have to be saved. '
        'Results will be saved if set to 1 (default: 200)')
    parser.add_argument('--connectivity',
                        default='correlation',
                        help='Type of connectivity used for network '
                        'construction (default: correlation, '
                        'options: correlation, partial correlation, '
                        'tangent)')
    parser.add_argument('--train', default=1, type=int)

    args = parser.parse_args()
    start_time = time.time()

    # GCN Parameters
    params = dict()
    params['model'] = args.model  # gcn model using chebyshev polynomials
    params['lrate'] = args.lrate  # Initial learning rate
    params['epochs'] = args.epochs  # Number of epochs to train
    params['dropout'] = args.dropout  # Dropout rate (1 - keep probability)
    params['hidden1'] = args.hidden1  # Number of units in hidden layers
    params['decay'] = args.decay  # Weight for L2 loss on embedding matrix
    # Tolerance for early stopping (# of epochs); no early stopping if
    # set to params['epochs']
    params['early_stopping'] = params['epochs']
    params['max_degree'] = 3  # Maximum Chebyshev polynomial degree
    # Number of additional hidden layers in the GCN
    # (total hidden layers: 1 + depth)
    params['depth'] = args.depth
    # params['seed'] = args.seed                      # seed for random initialisation

    # Feature selection and training-set parameters
    params['num_features'] = args.num_features  # features kept by feature selection
    params['num_training'] = args.num_training  # fraction of training set used
    params['train'] = args.train  # train/evaluate switch (see --train)
    # atlas = args.atlas                              # atlas for network construction (node definition)
    # connectivity = args.connectivity                # type of connectivity used for network construction

    # Get class labels
    # subject_IDs = Reader.get_ids()
    ##################################################################
    subject_IDs, shuffled_indices = Reader.get_ids()
    ##################################################################

    labels = Reader.get_labels(subject_IDs, score='DX_Group')  # labels

    # Get acquisition site
    # ####### sites = Reader.get_subject_score(subject_IDs, score='SITE_ID')
    ########## unique = np.unique(list(sites.values())).tolist()

    num_classes = 2  # MDD or HC
    num_nodes = len(subject_IDs)

    # Initialise variables for class labels and acquisition sites
    y_data = np.zeros([num_nodes, num_classes])
    y = np.zeros([num_nodes, 1])
    ########## site = np.zeros([num_nodes, 1], dtype=np.int)

    # Get class labels and acquisition site for all subjects
    for i in range(num_nodes):
        y_data[i, int(labels[subject_IDs[i]]) - 1] = 1
        y[i] = int(labels[subject_IDs[i]])
        ########## site[i] = unique.index(sites[subject_IDs[i]])

    import pickle
    # with open('./label.pkl', 'wb') as filehandle:
    #     pickle.dump(np.argmax(y_data, axis=1), filehandle)

    # Compute feature vectors (vectorised connectivity networks)
    ####### Granger Causality Analysis
    # data_fld = './granger_casuality'
    # features = Reader.load_ec_GCA(subject_IDs, data_fld)
    #######

    features = Reader.get_networks(subject_IDs,
                                   variable='correlation',
                                   isDynamic=False,
                                   isEffective=True)
    ############################################################
    shuffled_features = features[shuffled_indices]
    features = shuffled_features.copy()
    ############################################################
    # features = Reader.get_networks(subject_IDs, variable='graph_measure', isDynamic=True)

    # np.save('./MDD_dataset/features_GCA.npy', features)
    # np.save('./MDD_dataset/labels.npy', np.argmax(y_data, axis=1))

    # Compute population graph using gender and acquisition site
    graph = Reader.create_affinity_graph_from_scores(['Age', 'Sex'],
                                                     subject_IDs)
    # graph = Reader.create_affinity_graph_from_scores(['Sex'], subject_IDs)

    # Folds for cross validation experiments
    #num_samples = np.shape(features)[0]
    skf = StratifiedKFold(n_splits=10)
    #loo = LeaveOneOut()

    train_ind_set = []
    test_ind_set = []
    for train_ind, test_ind in reversed(
            list(skf.split(np.zeros(num_nodes), np.squeeze(y)))):
        train_ind_set.append(train_ind)
        test_ind_set.append(test_ind)
    cur_time = time.time()

    # import pickle
    # with open('./MDD_dataset/train_ind.pkl', 'wb') as filehandle:
    #     pickle.dump(train_ind_set, filehandle)
    # with open('./MDD_dataset/test_ind.pkl', 'wb') as filehandle:
    #     pickle.dump(test_ind_set, filehandle)

    if args.folds == 11:  # run cross validation on all folds
        scores = Parallel(n_jobs=10)(delayed(train_fold)(
            cv, train_ind, test_ind, test_ind, graph, features, y, y_data,
            params, subject_IDs, cur_time) for train_ind, test_ind, cv in zip(
                train_ind_set, test_ind_set, range(10)))

        test_auc = [x[0] for x in scores]
        test_accuracy = [x[1] for x in scores]
        test_sensitivity = [x[2] for x in scores]
        test_specificity = [x[3] for x in scores]
        test_pred = [x[4] for x in scores]
        test_lab = [x[5] for x in scores]

        print('Accuracy : ' + str(np.mean(test_accuracy)) + ' + ' +
              str(np.std(test_accuracy)))
        print('Sensitivity : ' + str(np.mean(test_sensitivity)) + ' + ' +
              str(np.std(test_sensitivity)))
        print('Specificity : ' + str(np.mean(test_specificity)) + ' + ' +
              str(np.std(test_specificity)))
        print('AUC : ' + str(np.mean(test_auc)) + ' + ' +
              str(np.std(test_auc)))

        # np.savez('./statistical_test/FC_Lasso_MLP_pred.npz', pred=test_pred, allow_pickle=True)
        # np.savez('./statistical_test/FC_Lasso_MLP_lab.npz', lab=test_lab, allow_pickle=True)

    else:  # compute results for only one fold

        cv_splits = list(skf.split(features, np.squeeze(y)))

        train = cv_splits[args.folds][0]
        test = cv_splits[args.folds][1]

        val = test

        scores_acc, scores_auc, scores_lin, scores_auc_lin, fold_size = train_fold(
            train, test, val, graph, features, y, y_data, params, subject_IDs,
            cur_time)

        print('overall linear accuracy %f' %
              (np.sum(scores_lin) * 1. / fold_size))
        print('overall linear AUC %f' % np.mean(scores_auc_lin))
        print('overall accuracy %f' % (np.sum(scores_acc) * 1. / fold_size))
        print('overall AUC %f' % np.mean(scores_auc))
Example No. 16
def list():
    """
    List available networks.
    """
    headers, networks = get_networks()
    typer.echo(tabulate(networks, headers, tablefmt="grid"))
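
A minimal, hypothetical wiring for the two commands in this CLI, assuming Typer; the real project's app layout may differ:

import typer

app = typer.Typer(help="Manage wifi connections.")
app.command()(list)   # `list` - show available networks
app.command()(conn)   # `conn` - connect to a network

if __name__ == "__main__":
    app()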