Example #1
    def help_gpus(self):
        msg = QMessageBox()
        msg.setIcon(QMessageBox.Question)
        msg.setText("Setting the number of GPUs")
        msg.setWindowTitle("Number of GPUs")
        if not self.HAVE_CUDA:
            info = "No GPUs are detected on your system"
        else:
            gpu_id = 0
            is_available = True
            while is_available:
                try:
                    cmt.cuda_set_device(gpu_id)
                    gpu_id += 1  # probe the next device; the loop stops at the first invalid id
                except Exception:
                    is_available = False
            info = "%d GPU(s) detected on your system" % gpu_id

        msg.setInformativeText("SpyKING CIRCUS can use several GPUs\n"
                               "either locally or on multiple machine\n"
                               "using MPI (see documentation)"
                               "\n"
                               "\n"
                               "%s" % info)
        msg.setStandardButtons(QMessageBox.Close)
        msg.setDefaultButton(QMessageBox.Close)
        answer = msg.exec_()
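
The probing loop in help_gpus can be factored into a standalone helper. Below is a minimal sketch, assuming cudamat raises an exception when cuda_set_device is given a device id that does not exist (the same behaviour the try/except above relies on); count_gpus and max_devices are illustrative names, not part of the original code.

# Minimal sketch: count available GPUs by probing device ids with cudamat.
import cudamat as cmt


def count_gpus(max_devices=16):
    """Return how many GPUs cudamat can select, probing device ids in order."""
    count = 0
    for gpu_id in range(max_devices):
        try:
            cmt.cuda_set_device(gpu_id)
            count += 1
        except Exception:
            break
    return count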
Example #2
    def __init__(
            self,
            n_stages=10,  # number of training iterations
            latent_size=100,  # hidden layer size
            learning_rate=1e-2,  # learning rate
            momentum=0,  # momentum
            n_gibbs_steps=1,  # number of Gibbs sampling steps
            use_persistent_chain=False,  # use persistent CD?
            minibatch_size=128,  # size of minibatch
            load_data_every=-1,  # how frequently to load data to GPU
            seed=1234):
        self.stage = 0
        self.reload_data = True
        self.n_stages = n_stages
        self.latent_size = latent_size
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.n_gibbs_steps = n_gibbs_steps
        self.use_persistent_chain = use_persistent_chain
        self.minibatch_size = minibatch_size
        self.load_data_every = load_data_every
        self.seed = seed
        self.gpu_dataset = None

        cm.cuda_set_device(0)
        cm.init()

        self.rng = np.random.mtrand.RandomState(seed)
        cm.CUDAMatrix.init_random(seed=seed)
Example #3
    def __init__(   self,
                    n_stages=10,                # number of training iterations
                    latent_size=100,            # hidden layer size
                    learning_rate=1e-2,         # learning rate
                    momentum=0,                 # momentum
                    n_gibbs_steps=1,            # number of Gibbs sampling steps
                    use_persistent_chain=False, # use persistent CD?
                    minibatch_size=128,         # size of minibatch
                    load_data_every=-1,         # how frequently to load data to GPU
                    seed=1234
                    ):
        self.stage = 0
        self.reload_data = True
        self.n_stages = n_stages
        self.latent_size = latent_size
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.n_gibbs_steps = n_gibbs_steps
        self.use_persistent_chain = use_persistent_chain
        self.minibatch_size = minibatch_size
        self.load_data_every = load_data_every
        self.seed = seed
        self.gpu_dataset = None

        cm.cuda_set_device(0)
        cm.init()

        self.rng = np.random.mtrand.RandomState(seed)
        cm.CUDAMatrix.init_random(seed = seed)
Example #4
def LockGPU():
  board = gpu_lock.obtain_lock_id()
  if board == -1:
    print 'No GPU board available.'
    sys.exit(1)
  else:
    cm.cuda_set_device(board)
    cm.cublas_init()
  return board
Example #5
def LockGPU(max_retries=10):
    for retry_count in range(max_retries):
        board = gpu_lock.obtain_lock_id()
        if board != -1:
            break
    if board == -1:
        print 'No GPU board available.'
        sys.exit(1)
    else:
        cm.cuda_set_device(board)
        cm.cublas_init()
Example #6
def LockGPU(max_retries=10):
    for retry_count in range(max_retries):
        board = gpu_lock.obtain_lock_id()
        if board != -1:
            break
    if board == -1:
        print "No GPU board available."
        sys.exit(1)
    else:
        cm.cuda_set_device(board)
        cm.cublas_init()
Example #7
def LockGPU(max_retries=10, board=-1):
  retry_count = 0
  while board == -1 and retry_count < max_retries:
    board = gpu_lock.obtain_lock_id()
    if board == -1:
      sleep(1)
      retry_count += 1
  if board == -1:
    print 'No GPU board available.'
    sys.exit(1)
  else:
    cm.cuda_set_device(board)
    cm.cublas_init()
  return board
Example #8
def LockGPU(max_retries=10):
  """ Locks a free GPU board and returns its id. """
  for retry_count in range(max_retries):
    board = gpu_lock.obtain_lock_id()
    if board != -1:
      break
    sleep(1)
  if board == -1:
    print 'No GPU board available.'
    sys.exit(1)
  else:
    cm.cuda_set_device(board)
    cm.cublas_init()
  return board
Example #9
def LockGPU(max_retries=10, board=-1):
    retry_count = 0
    while board == -1 and retry_count < max_retries:
        board = gpu_lock.obtain_lock_id()
        if board == -1:
            sleep(1)
            retry_count += 1
    if board == -1:
        print 'No GPU board available.'
        sys.exit(1)
    else:
        cm.cuda_set_device(board)
        cm.cublas_init()
    return board
Example #10
def LockGPU(max_retries=10):
  """ Locks a free GPU board and returns its id. """
  for retry_count in range(max_retries):
    board = gpu_lock.obtain_lock_id()
    if board != -1:
      break
    sleep(1)
  if board == -1:
    print 'No GPU board available.'
    sys.exit(1)
  else:
    cm.cuda_set_device(board)
    cm.cublas_init()
  return board
Example #11
def LockGPU(max_retries=10, board=-1):
    # retry_count = 0
    # while board == -1 and retry_count < max_retries:
    #   board = gpu_lock.obtain_lock_id()
    #   if board == -1:
    #     sleep(1)
    #     retry_count += 1
    # if board == -1:
    #   print 'No GPU board available.'
    #   sys.exit(1)
    # else:
    #   cm.cuda_set_device(board)
    #   cm.cublas_init()
    board = 3
    cm.cuda_set_device(board)
    cm.cublas_init()
    return board
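
Every LockGPU variant above acquires a lock with gpu_lock.obtain_lock_id() and binds cudamat to that board, but none of them releases it. The sketch below is a hypothetical counterpart, assuming the deepnet-style gpu_lock module used by these snippets provides free_lock(id) to release the lock; FreeGPU is an illustrative name, not taken from these examples.

# Hypothetical counterpart to the LockGPU variants above. Assumes
# gpu_lock.free_lock(id) exists to release the lock taken by
# obtain_lock_id(); drop that call if your gpu_lock variant lacks it.
import cudamat as cm
import gpu_lock


def FreeGPU(board):
    """Shut down cublas and give the GPU lock back."""
    cm.cublas_shutdown()
    gpu_lock.free_lock(board)

# Typical pairing (LockGPU as in any of the variants above):
#   board = LockGPU()
#   ...            # GPU work
#   FreeGPU(board)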
Example #12
def run(args=None):
    usage = "usage : %prog [options]"
    parser = optparse.OptionParser(usage=usage)

    parser.add_option('--cfg_file', dest='cfg_file', default=None,
            help='File with settings from previously trained net')

    parser.add_option(
        "--test", action="store_true", dest="test", default=False)

    # Architecture
    parser.add_option(
        "--layerSize", dest="layerSize", type="int", default=1824)
    parser.add_option("--numLayers", dest="numLayers", type="int", default=5)
    parser.add_option(
        "--temporalLayer", dest="temporalLayer", type="int", default=3)

    # Optimization
    parser.add_option("--momentum", dest="momentum", type="float",
                      default=0.95)
    parser.add_option("--epochs", dest="epochs", type="int", default=20)
    parser.add_option("--step", dest="step", type="float", default=1e-5)
    parser.add_option("--anneal", dest="anneal", type="float", default=1.3,
                      help="Sets (learning rate := learning rate / anneal) after each epoch.")
    parser.add_option('--reg', dest='reg', type='float', default=0.0,
                      help='lambda for L2 regularization of the weight matrices')

    # Data
    parser.add_option("--dataDir", dest="dataDir", type="string",
                      default=TRAIN_DATA_DIR['fbank'])
    parser.add_option('--alisDir', dest='alisDir', type='string', default=TRAIN_ALIS_DIR)
    parser.add_option('--startFile', dest='startFile', type='int', default=1, help='Start file for running testing')
    parser.add_option("--numFiles", dest="numFiles", type="int", default=384)
    parser.add_option(
        "--inputDim", dest="inputDim", type="int", default=41 * 15)
    parser.add_option("--rawDim", dest="rawDim", type="int", default=41 * 15)
    parser.add_option("--outputDim", dest="outputDim", type="int", default=35)
    parser.add_option(
        "--maxUttLen", dest="maxUttLen", type="int", default=MAX_UTT_LEN)

    # Save/Load
    parser.add_option('--save_every', dest='save_every', type='int',
            default=10, help='During training, save parameters every x number of files')

    parser.add_option('--run_desc', dest='run_desc', type='string', default='', help='Description of experiment run')

    (opts, args) = parser.parse_args(args)

    if opts.cfg_file:
        cfg = load_config(opts.cfg_file)
    else:
        cfg = vars(opts)

    # These config values should be updated every time
    cfg['host'] = get_hostname()
    cfg['git_rev'] = get_git_revision()
    cfg['pid'] = os.getpid()

    # Create experiment output directory

    if not opts.cfg_file:
        time_string = str(TimeString())
        output_dir = pjoin(RUN_DIR, time_string)
        cfg['output_dir'] = output_dir
        if not os.path.exists(output_dir):
            print 'Creating %s' % output_dir
            os.makedirs(output_dir)
        opts.cfg_file = pjoin(output_dir, 'cfg.json')
    else:
        output_dir = cfg['output_dir']

    cfg['output_dir'] = output_dir
    cfg['in_file'] = pjoin(output_dir, 'params.pk')
    cfg['out_file'] = pjoin(output_dir, 'params.pk')
    cfg['test'] = opts.test
    if opts.test:
        cfg['dataDir'] = opts.dataDir
        cfg['numFiles'] = opts.numFiles
        cfg['startFile'] = opts.startFile
    if 'reg' not in cfg:
        cfg['reg'] = 0.0

    # Logging

    logging.basicConfig(filename=pjoin(output_dir, 'train.log'), level=logging.DEBUG)
    logger = logging.getLogger()
    logger.addHandler(logging.StreamHandler())
    logger.info('Running on %s' % cfg['host'])

    # seed for debugging, turn off when stable
    np.random.seed(33)
    import random
    random.seed(33)

    if 'CUDA_DEVICE' in os.environ:
        cm.cuda_set_device(int(os.environ['CUDA_DEVICE']))
    else:
        cm.cuda_set_device(0)  # Default

    opts = CfgStruct(**cfg)

    # Testing
    if opts.test:
        test(opts)
        return

    alisDir = opts.alisDir if opts.alisDir else opts.dataDir
    loader = dl.DataLoader(opts.dataDir, opts.rawDim, opts.inputDim, alisDir)

    nn = rnnet.NNet(opts.inputDim, opts.outputDim, opts.layerSize, opts.numLayers,
                    opts.maxUttLen, temporalLayer=opts.temporalLayer, reg=opts.reg)
    nn.initParams()

    SGD = sgd.SGD(nn, opts.maxUttLen, alpha=opts.step, momentum=opts.momentum)

    # Dump config
    cfg['param_count'] = nn.paramCount()
    dump_config(cfg, opts.cfg_file)

    # Training
    epoch_file = pjoin(output_dir, 'epoch')
    if os.path.exists(epoch_file):
        start_epoch = int(open(epoch_file, 'r').read()) + 1
    else:
        start_epoch = 0

    # Load model if specified
    if os.path.exists(opts.in_file):
        with open(opts.in_file, 'r') as fid:
            SGD.fromFile(fid)
            SGD.alpha = SGD.alpha / (opts.anneal ** start_epoch)
            nn.fromFile(fid)

    num_files_file = pjoin(output_dir, 'num_files')

    for k in range(start_epoch, opts.epochs):
        perm = np.random.permutation(opts.numFiles) + 1
        loader.loadDataFileAsynch(perm[0])

        file_start = 0
        if k == start_epoch:
            if os.path.exists(num_files_file):
                file_start = int(open(num_files_file, 'r').read().strip())
                logger.info('Starting from file %d, epoch %d' % (file_start, start_epoch))
        else:
            open(num_files_file, 'w').write(str(file_start))

        for i in xrange(file_start, perm.shape[0]):
            start = time.time()
            data_dict, alis, keys, sizes = loader.getDataAsynch()
            # Prefetch
            if i + 1 < perm.shape[0]:
                loader.loadDataFileAsynch(perm[i + 1])
            SGD.run(data_dict, alis, keys, sizes)
            end = time.time()
            logger.info('File time %f' % (end - start))

            # Save parameters and cost
            if (i+1) % opts.save_every == 0:
                logger.info('Saving parameters')
                with open(opts.out_file, 'wb') as fid:
                    SGD.toFile(fid)
                    nn.toFile(fid)
                    open(num_files_file, 'w').write('%d' % (i+1))
                logger.info('Done saving parameters')
                with open(pjoin(output_dir, 'last_cost'), 'w') as fid:
                    if opts.reg > 0.0:
                        fid.write(str(SGD.expcost[-1] - SGD.regcost[-1]))
                    else:
                        fid.write(str(SGD.expcost[-1]))

        # Save epoch completed
        open(pjoin(output_dir, 'epoch'), 'w').write(str(k))

        # Save parameters for the epoch
        with open(opts.out_file + '.epoch{0:02}'.format(k), 'wb') as fid:
            SGD.toFile(fid)
            nn.toFile(fid)

        SGD.alpha = SGD.alpha / opts.anneal

    # Run now complete, touch sentinel file
    touch_file(pjoin(output_dir, 'sentinel'))
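
run() hands its args list straight to optparse's parse_args, so it can be driven programmatically as well as from the command line. Below is a hypothetical invocation sketch; the option names come from the parser definitions above and the values simply repeat its defaults.

# Hypothetical invocation of run() above: optparse parses the given list
# instead of sys.argv. The values shown repeat the parser defaults.
run(['--layerSize', '1824',
     '--numLayers', '5',
     '--temporalLayer', '3',
     '--epochs', '20',
     '--step', '1e-5'])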
Example #13
    def train(self):
        '''
        Main train function: a modified version of the original train function.

        Additions:
            - GPU selection (useful for multi-GPU machines)
            - the sum of the squares of the data is saved for post-processing
            - visible data are saved
            - data samples are permuted for training
            - weights are saved every 100 training epochs
            - training energy is visualized every 100 training epochs

        NOTE: the annealed learning rate used in the initial code is NOT used here!
        '''
        #plt.ion()
        f1 = plt.figure()
        ax1 = f1.add_subplot(111)
        #ax2 = f1.add_subplot(122)
        #plt.show()

        cmt.cuda_set_device(self.gpuId)
        cmt.cublas_init()
        cmt.CUDAMatrix.init_random(1)

        np.random.seed(self.npRandSeed)
        prng = RandomState(self.npRandState)

        ################################################################
        ##################### CHANGE PATH ##############################
        # Move to current experiment path:
        os.chdir(self.saveDir)
        # Get current path:
        os.getcwd()

        self.plotsDir = 'plots'
        #self.probabilitiesDir = 'p_all'
        if not os.path.isdir(self.plotsDir):
            os.makedirs(self.plotsDir)
        if not os.path.isdir(self.plotsDir + '/energy'):
            os.makedirs(self.plotsDir + '/energy')
        #if not os.path.isdir(self.probabilitiesDir):
        #	os.makedirs(self.probabilitiesDir)
        if not os.path.isdir('weights'):
            os.makedirs('weights')

        d = self.d.astype(np.float32)
        print("visible size: ", d.shape)

        dsq = np.square(d)
        lsq = np.sum(dsq, axis=0)
        with open('lsqComplete.pkl', 'wb') as pklFile:
            cPickle.dump(lsq, pklFile)

        del dsq, lsq

        # Save visible data :
        visData = d
        np.savez('visData.npz',
                 data=d,
                 obsKeys=self.obsKeys,
                 epochTime=self.epochTime)

        with open('visData.txt', 'w') as f:
            f.write("\n Dataset : %s" % (self.dataFilename))
            f.write("\n visData size: %s " % str(visData.shape))
            f.write("\n visData type: %s " % str(visData.dtype))
            f.write("\n \n visData Range: %s " %
                    str(np.max(visData, axis=0) - np.min(visData, axis=0)))
            f.write("\n \n visData min: %s " % str(np.min(visData, axis=0)))
            f.write("\n \n visData max: %s " % str(np.max(visData, axis=0)))
            f.write("\n \n visData mean: %s " % str(np.mean(visData, axis=0)))
            f.write("\n \n visData std: %s " % str(np.std(visData, axis=0)))

        del visData  #if not needed for computing the latent states

        permIdx = prng.permutation(d.shape[0])

        d = d[permIdx, :]

        #subsetting train and test datasets
        #trainPerc = 0.7
        #trainSampNum = int(np.ceil(trainPerc*d.shape[0]))
        #trainSampNum = int(np.floor(trainSampNum/self.batch_size)*self.batch_size)
        #testSampNum = int(d.shape[0]-trainSampNum-1)

        # The test dataset is not used at the moment, it can be used as
        # a validation set to check for overfitting. To use it, uncomment
        # all the variables with 'test' in their name

        #~ d_test = d[trainSampNum+1:,:]
        #d = d[:trainSampNum,:]
        #obsKeys = self.obsKeys[:trainSampNum]

        totnumcases = d.shape[0]
        num_vis = d.shape[1]

        num_batches = int(totnumcases / self.batch_size)
        print("num_batches: ", num_batches)
        dev_dat = cmt.CUDAMatrix(d.T)  # VxP
        #~ test_dat = cmt.CUDAMatrix(d_test.T)

        del d, self.d, self.epochTime, self.obsKeys

        # training parameters (as in the original code by Ranzato)
        epsilon = self.epsilon
        epsilonVF = 2 * epsilon
        epsilonFH = 0.02 * epsilon
        epsilonb = 0.02 * epsilon
        epsilonw_mean = 0.2 * epsilon
        epsilonb_mean = 0.1 * epsilon
        weightcost_final = self.weightcost_final

        # HMC setting
        hmc_step_nr = self.hmc_step_nr
        hmc_step = 0.01
        hmc_target_ave_rej = self.hmc_target_ave_rej
        hmc_ave_rej = hmc_target_ave_rej

        # initialize weights
        VF = cmt.CUDAMatrix(
            np.array(0.02 * prng.randn(num_vis, self.num_fac),
                     dtype=np.float32,
                     order='F'))  # VxH
        if self.apply_mask == 0:
            FH = cmt.CUDAMatrix(
                np.array(np.eye(self.num_fac, self.num_hid_cov),
                         dtype=np.float32,
                         order='F'))  # HxO
        else:
            dd = loadmat(
                'your_FHinit_mask_file.mat'
            )  # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example
            FH = cmt.CUDAMatrix(np.array(dd["FH"], dtype=np.float32,
                                         order='F'))
        bias_cov = cmt.CUDAMatrix(
            np.array(2.0 * np.ones((self.num_hid_cov, 1)),
                     dtype=np.float32,
                     order='F'))
        bias_vis = cmt.CUDAMatrix(
            np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F'))
        w_mean = cmt.CUDAMatrix(
            np.array(0.05 * prng.randn(num_vis, self.num_hid_mean),
                     dtype=np.float32,
                     order='F'))  # VxH
        bias_mean = cmt.CUDAMatrix(
            np.array(-2.0 * np.ones((self.num_hid_mean, 1)),
                     dtype=np.float32,
                     order='F'))

        # initialize variables to store derivatives
        VFinc = cmt.CUDAMatrix(
            np.array(np.zeros((num_vis, self.num_fac)),
                     dtype=np.float32,
                     order='F'))
        FHinc = cmt.CUDAMatrix(
            np.array(np.zeros((self.num_fac, self.num_hid_cov)),
                     dtype=np.float32,
                     order='F'))
        bias_covinc = cmt.CUDAMatrix(
            np.array(np.zeros((self.num_hid_cov, 1)),
                     dtype=np.float32,
                     order='F'))
        bias_visinc = cmt.CUDAMatrix(
            np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F'))
        w_meaninc = cmt.CUDAMatrix(
            np.array(np.zeros((num_vis, self.num_hid_mean)),
                     dtype=np.float32,
                     order='F'))
        bias_meaninc = cmt.CUDAMatrix(
            np.array(np.zeros((self.num_hid_mean, 1)),
                     dtype=np.float32,
                     order='F'))

        # initialize temporary storage
        data = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # VxP
        normdata = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # VxP
        negdataini = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # VxP
        feat = cmt.CUDAMatrix(
            np.array(np.empty((self.num_fac, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        featsq = cmt.CUDAMatrix(
            np.array(np.empty((self.num_fac, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        negdata = cmt.CUDAMatrix(
            np.array(prng.randn(num_vis, self.batch_size),
                     dtype=np.float32,
                     order='F'))
        old_energy = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        new_energy = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        energy = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        gradient = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # VxP
        normgradient = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # VxP
        thresh = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        feat_mean = cmt.CUDAMatrix(
            np.array(np.empty((self.num_hid_mean, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        vel = cmt.CUDAMatrix(
            np.array(prng.randn(num_vis, self.batch_size),
                     dtype=np.float32,
                     order='F'))
        length = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # 1xP
        lengthsq = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # 1xP
        normcoeff = cmt.CUDAMatrix(
            np.array(np.zeros((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))  # 1xP

        # commented to avoid computing the energy on test data
        #~ data_test = cmt.CUDAMatrix( np.array(np.empty((num_vis, testSampNum)), dtype=np.float32, order='F')) # Vxtest_batch
        #~ normdata_test = cmt.CUDAMatrix( np.array(np.empty((num_vis, testSampNum)), dtype=np.float32, order='F')) # Vxtest_batch
        #~ length_test = cmt.CUDAMatrix( np.array(np.zeros((1, testSampNum)), dtype=np.float32, order='F')) # 1xtest_batch
        #~ lengthsq_test = cmt.CUDAMatrix( np.array(np.zeros((1, testSampNum)), dtype=np.float32, order='F')) # 1xtest_batch
        #~ normcoeff_test = cmt.CUDAMatrix( np.array(np.zeros((1, testSampNum)), dtype=np.float32, order='F')) # 1xtest_batch
        #~ vel_test = cmt.CUDAMatrix( np.array(prng.randn(num_vis, testSampNum), dtype=np.float32, order='F'))
        #~ feat_test = cmt.CUDAMatrix( np.array(np.empty((self.num_fac, testSampNum)), dtype=np.float32, order='F'))
        #~ featsq_test = cmt.CUDAMatrix( np.array(np.empty((self.num_fac, testSampNum)), dtype=np.float32, order='F'))
        #~ feat_mean_test = cmt.CUDAMatrix( np.array(np.empty((self.num_hid_mean, testSampNum)), dtype=np.float32, order='F'))
        #~ energy_test = cmt.CUDAMatrix( np.array(np.zeros((1, testSampNum)), dtype=np.float32, order='F'))

        if self.apply_mask == 1:  # this is used to constrain very large FH matrices, only allowing values in a neighborhood to change
            dd = loadmat('your_FHinit_mask_file.mat')
            mask = cmt.CUDAMatrix(
                np.array(dd["mask"], dtype=np.float32, order='F'))
        normVF = 1
        small = 0.5

        # other temporary vars
        t1 = cmt.CUDAMatrix(
            np.array(np.empty((self.num_hid_cov, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        t2 = cmt.CUDAMatrix(
            np.array(np.empty((self.num_hid_cov, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        t3 = cmt.CUDAMatrix(
            np.array(np.empty((self.num_fac, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        t4 = cmt.CUDAMatrix(
            np.array(np.empty((1, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        t5 = cmt.CUDAMatrix(
            np.array(np.empty((1, 1)), dtype=np.float32, order='F'))
        t6 = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        t7 = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.batch_size)),
                     dtype=np.float32,
                     order='F'))
        t8 = cmt.CUDAMatrix(
            np.array(np.empty((num_vis, self.num_fac)),
                     dtype=np.float32,
                     order='F'))
        t9 = cmt.CUDAMatrix(
            np.array(np.zeros((self.num_fac, self.num_hid_cov)),
                     dtype=np.float32,
                     order='F'))
        t10 = cmt.CUDAMatrix(
            np.array(np.empty((1, self.num_fac)), dtype=np.float32, order='F'))
        t11 = cmt.CUDAMatrix(
            np.array(np.empty((1, self.num_hid_cov)),
                     dtype=np.float32,
                     order='F'))

        # commented to avoid computing the energy on test data
        #~ t1_test = cmt.CUDAMatrix( np.array(np.empty((self.num_hid_cov, testSampNum)), dtype=np.float32, order='F'))
        #~ t2_test = cmt.CUDAMatrix( np.array(np.empty((self.num_hid_cov, testSampNum)), dtype=np.float32, order='F'))
        #~ t3_test = cmt.CUDAMatrix( np.array(np.empty((self.num_fac, testSampNum)), dtype=np.float32, order='F'))
        #~ t4_test = cmt.CUDAMatrix( np.array(np.empty((1,testSampNum)), dtype=np.float32, order='F'))
        #~ t5_test = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F'))
        #~ t6_test = cmt.CUDAMatrix( np.array(np.empty((num_vis, testSampNum)), dtype=np.float32, order='F'))

        meanEnergy = np.zeros(self.num_epochs)
        minEnergy = np.zeros(self.num_epochs)
        maxEnergy = np.zeros(self.num_epochs)
        #~ meanEnergy_test = np.zeros(self.num_epochs)
        #~ minEnergy_test = np.zeros(self.num_epochs)
        #~ maxEnergy_test = np.zeros(self.num_epochs)

        # start training
        for epoch in range(self.num_epochs):

            print "Epoch " + str(epoch)

            # anneal learning rates as found in the original code -
            # uncomment if you wish to use annealing!
            #~ epsilonVFc    = epsilonVF/max(1,epoch/20)
            #~ epsilonFHc    = epsilonFH/max(1,epoch/20)
            #~ epsilonbc    = epsilonb/max(1,epoch/20)
            #~ epsilonw_meanc = epsilonw_mean/max(1,epoch/20)
            #~ epsilonb_meanc = epsilonb_mean/max(1,epoch/20)

            # no annealing is used in our experiments because learning
            # was stopping too early
            epsilonVFc = epsilonVF
            epsilonFHc = epsilonFH
            epsilonbc = epsilonb
            epsilonw_meanc = epsilonw_mean
            epsilonb_meanc = epsilonb_mean

            weightcost = weightcost_final

            if epoch <= self.startFH:
                epsilonFHc = 0
            if epoch <= self.startwd:
                weightcost = 0

            # commented to avoid computing the energy on test data
            #~ data_test = test_dat

            #~ data_test.mult(data_test, target = t6_test) # DxP
            #~ t6_test.sum(axis = 0, target = lengthsq_test) # 1xP
            #~ lengthsq_test.mult(1./num_vis) # normalize by number of components (like std)
            #~ lengthsq_test.add(small) # small avoids division by 0
            #~ cmt.sqrt(lengthsq_test, target = length_test)
            #~ length_test.reciprocal(target = normcoeff_test) # 1xP
            #~ data_test.mult_by_row(normcoeff_test, target = normdata_test) # normalized data

            for batch in range(num_batches):

                # get current minibatch
                data = dev_dat.slice(
                    batch * self.batch_size, (batch + 1) *
                    self.batch_size)  # DxP (nr dims x nr samples)

                # normalize input data
                data.mult(data, target=t6)  # DxP
                t6.sum(axis=0, target=lengthsq)  # 1xP
                lengthsq.mult(
                    1. /
                    num_vis)  # normalize by number of components (like std)
                lengthsq.add(small)  # small avoids division by 0
                cmt.sqrt(lengthsq, target=length)
                length.reciprocal(target=normcoeff)  # 1xP
                data.mult_by_row(normcoeff, target=normdata)  # normalized data
                ## compute positive sample derivatives
                # covariance part
                cmt.dot(VF.T, normdata,
                        target=feat)  # HxP (nr facs x nr samples)
                feat.mult(feat, target=featsq)  # HxP
                cmt.dot(FH.T, featsq,
                        target=t1)  # OxP (nr cov hiddens x nr samples)
                t1.mult(-0.5)
                t1.add_col_vec(bias_cov)  # OxP
                t1.apply_sigmoid(target=t2)  # OxP
                cmt.dot(featsq, t2.T, target=FHinc)  # HxO
                cmt.dot(FH, t2, target=t3)  # HxP
                t3.mult(feat)
                cmt.dot(normdata, t3.T, target=VFinc)  # VxH
                t2.sum(axis=1, target=bias_covinc)
                bias_covinc.mult(-1)
                # visible bias
                data.sum(axis=1, target=bias_visinc)
                bias_visinc.mult(-1)
                # mean part
                cmt.dot(w_mean.T, data,
                        target=feat_mean)  # HxP (nr mean hiddens x nr samples)
                feat_mean.add_col_vec(bias_mean)  # HxP
                feat_mean.apply_sigmoid()  # HxP
                feat_mean.mult(-1)
                cmt.dot(data, feat_mean.T, target=w_meaninc)
                feat_mean.sum(axis=1, target=bias_meaninc)

                # HMC sampling: draw an approximate sample from the model
                if self.doPCD == 0:  # CD-1 (set negative data to current training samples)
                    hmc_step, hmc_ave_rej = self.draw_HMC_samples(
                        data, negdata, normdata, vel, gradient, normgradient,
                        new_energy, old_energy, VF, FH, bias_cov, bias_vis,
                        w_mean, bias_mean, hmc_step, hmc_step_nr, hmc_ave_rej,
                        hmc_target_ave_rej, t1, t2, t3, t4, t5, t6, t7, thresh,
                        feat, featsq, self.batch_size, feat_mean, length,
                        lengthsq, normcoeff, small, num_vis)
                else:  # PCD-1 (use previous negative data as starting point for chain)
                    negdataini.assign(negdata)
                    hmc_step, hmc_ave_rej = self.draw_HMC_samples(
                        negdataini, negdata, normdata, vel, gradient,
                        normgradient, new_energy, old_energy, VF, FH, bias_cov,
                        bias_vis, w_mean, bias_mean, hmc_step, hmc_step_nr,
                        hmc_ave_rej, hmc_target_ave_rej, t1, t2, t3, t4, t5,
                        t6, t7, thresh, feat, featsq, self.batch_size,
                        feat_mean, length, lengthsq, normcoeff, small, num_vis)

                # compute derivatives at the negative samples
                # normalize input data
                negdata.mult(negdata, target=t6)  # DxP
                t6.sum(axis=0, target=lengthsq)  # 1xP
                lengthsq.mult(
                    1. /
                    num_vis)  # normalize by number of components (like std)
                lengthsq.add(small)
                cmt.sqrt(lengthsq, target=length)
                length.reciprocal(target=normcoeff)  # 1xP
                negdata.mult_by_row(normcoeff,
                                    target=normdata)  # normalized data
                # covariance part
                cmt.dot(VF.T, normdata, target=feat)  # HxP
                feat.mult(feat, target=featsq)  # HxP
                cmt.dot(FH.T, featsq, target=t1)  # OxP
                t1.mult(-0.5)
                t1.add_col_vec(bias_cov)  # OxP
                t1.apply_sigmoid(target=t2)  # OxP
                FHinc.subtract_dot(featsq, t2.T)  # HxO
                FHinc.mult(0.5)
                cmt.dot(FH, t2, target=t3)  # HxP
                t3.mult(feat)
                VFinc.subtract_dot(normdata, t3.T)  # VxH
                bias_covinc.add_sums(t2, axis=1)
                # visible bias
                bias_visinc.add_sums(negdata, axis=1)
                # mean part
                cmt.dot(w_mean.T, negdata, target=feat_mean)  # HxP
                feat_mean.add_col_vec(bias_mean)  # HxP
                feat_mean.apply_sigmoid()  # HxP
                w_meaninc.add_dot(negdata, feat_mean.T)
                bias_meaninc.add_sums(feat_mean, axis=1)

                # update parameters
                VFinc.add_mult(VF.sign(), weightcost)  # L1 regularization
                VF.add_mult(VFinc, -epsilonVFc / self.batch_size)
                # normalize columns of VF: normalize by running average of their norm
                VF.mult(VF, target=t8)
                t8.sum(axis=0, target=t10)
                cmt.sqrt(t10)
                t10.sum(axis=1, target=t5)
                t5.copy_to_host()
                normVF = .95 * normVF + (
                    .05 / self.num_fac) * t5.numpy_array[0, 0]  # estimate norm
                t10.reciprocal()
                VF.mult_by_row(t10)
                VF.mult(normVF)
                bias_cov.add_mult(bias_covinc, -epsilonbc / self.batch_size)
                bias_vis.add_mult(bias_visinc, -epsilonbc / self.batch_size)

                if epoch > self.startFH:
                    FHinc.add_mult(FH.sign(), weightcost)  # L1 regularization
                    FH.add_mult(FHinc, -epsilonFHc / self.batch_size)  # update
                    # set to 0 negative entries in FH
                    FH.greater_than(0, target=t9)
                    FH.mult(t9)
                    if self.apply_mask == 1:
                        FH.mult(mask)
                    # normalize columns of FH: L1 norm set to 1 in each column
                    FH.sum(axis=0, target=t11)
                    t11.reciprocal()
                    FH.mult_by_row(t11)
                w_meaninc.add_mult(w_mean.sign(), weightcost)
                w_mean.add_mult(w_meaninc, -epsilonw_meanc / self.batch_size)
                bias_mean.add_mult(bias_meaninc,
                                   -epsilonb_meanc / self.batch_size)

            if self.verbose == 1:
                print "VF: " + '%3.2e' % VF.euclid_norm(
                ) + ", DVF: " + '%3.2e' % (
                    VFinc.euclid_norm() * (epsilonVFc / self.batch_size)
                ) + ", FH: " + '%3.2e' % FH.euclid_norm(
                ) + ", DFH: " + '%3.2e' % (
                    FHinc.euclid_norm() * (epsilonFHc / self.batch_size)
                ) + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm(
                ) + ", Dbias_cov: " + '%3.2e' % (
                    bias_covinc.euclid_norm() * (epsilonbc / self.batch_size)
                ) + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm(
                ) + ", Dbias_vis: " + '%3.2e' % (
                    bias_visinc.euclid_norm() * (epsilonbc / self.batch_size)
                ) + ", wm: " + '%3.2e' % w_mean.euclid_norm(
                ) + ", Dwm: " + '%3.2e' % (
                    w_meaninc.euclid_norm() *
                    (epsilonw_meanc / self.batch_size)
                ) + ", bm: " + '%3.2e' % bias_mean.euclid_norm(
                ) + ", Dbm: " + '%3.2e' % (
                    bias_meaninc.euclid_norm() *
                    (epsilonb_meanc / self.batch_size)
                ) + ", step: " + '%3.2e' % hmc_step + ", rej: " + '%3.2e' % hmc_ave_rej
                with open('terminal.txt', 'a') as f:
                    f.write('\n' + "epoch: %s" % str(epoch) + ", VF: " +
                            '%3.2e' % VF.euclid_norm() + ", DVF: " + '%3.2e' %
                            (VFinc.euclid_norm() *
                             (epsilonVFc / self.batch_size)) + ", FH: " +
                            '%3.2e' % FH.euclid_norm() + ", DFH: " + '%3.2e' %
                            (FHinc.euclid_norm() *
                             (epsilonFHc / self.batch_size)) + ", bias_cov: " +
                            '%3.2e' % bias_cov.euclid_norm() +
                            ", Dbias_cov: " + '%3.2e' %
                            (bias_covinc.euclid_norm() *
                             (epsilonbc / self.batch_size)) + ", bias_vis: " +
                            '%3.2e' % bias_vis.euclid_norm() +
                            ", Dbias_vis: " + '%3.2e' %
                            (bias_visinc.euclid_norm() *
                             (epsilonbc / self.batch_size)) + ", wm: " +
                            '%3.2e' % w_mean.euclid_norm() + ", Dwm: " +
                            '%3.2e' % (w_meaninc.euclid_norm() *
                                       (epsilonw_meanc / self.batch_size)) +
                            ", bm: " + '%3.2e' % bias_mean.euclid_norm() +
                            ", Dbm: " + '%3.2e' %
                            (bias_meaninc.euclid_norm() *
                             (epsilonb_meanc / self.batch_size)) + ", step: " +
                            '%3.2e' % hmc_step + ", rej: " +
                            '%3.2e' % hmc_ave_rej)
                sys.stdout.flush()

            # compute the energy on the training data (for monitoring)
            self.compute_energy_mcRBM_visual(data, normdata, energy, VF, FH,
                                             bias_cov, bias_vis, w_mean,
                                             bias_mean, t1, t2, t6, feat,
                                             featsq, feat_mean, length,
                                             lengthsq, normcoeff, small,
                                             num_vis)
            energy.copy_to_host()
            meanEnergy[epoch] = np.mean(energy.numpy_array)
            minEnergy[epoch] = np.min(energy.numpy_array)
            maxEnergy[epoch] = np.max(energy.numpy_array)

            # commented to avoid computing the energy on test data
            #~ self.compute_energy_mcRBM_visual(data_test,normdata_test,energy_test,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1_test,t2_test,t6_test,feat_test,featsq_test,feat_mean_test,length_test,lengthsq_test,normcoeff_test,small,num_vis)
            #~ energy_test.copy_to_host()
            #~ meanEnergy_test[epoch] = np.mean(energy_test.numpy_array)
            #~ minEnergy_test[epoch] = np.min(energy_test.numpy_array)
            #~ maxEnergy_test[epoch] = np.max(energy_test.numpy_array)

            ax1.cla()
            ax1.plot(range(epoch), meanEnergy[0:epoch])
            ax1.plot(range(epoch), maxEnergy[0:epoch])
            ax1.plot(range(epoch), minEnergy[0:epoch])

            if np.mod(epoch, 100) == 0:
                #f1.savefig(output_folder + str(epoch)+'_'+'fig.png')
                f1.savefig(self.plotsDir +
                           '/energy/energyAt_%s.png' % str(epoch))

            # back-up every once in a while
            if np.mod(epoch, 100) == 0:
                VF.copy_to_host()
                FH.copy_to_host()
                bias_cov.copy_to_host()
                w_mean.copy_to_host()
                bias_mean.copy_to_host()
                bias_vis.copy_to_host()
                savemat(
                    "./weights/ws_temp%s" % str(epoch), {
                        'VF': VF.numpy_array,
                        'FH': FH.numpy_array,
                        'bias_cov': bias_cov.numpy_array,
                        'bias_vis': bias_vis.numpy_array,
                        'w_mean': w_mean.numpy_array,
                        'bias_mean': bias_mean.numpy_array,
                        'epoch': epoch
                    })

                # uncomment if computing the energy in order to store its evolution throughout training
                #~ savemat(self.refDir + '/' + "training_energy_" + str(self.num_fac) + "_cov" + str(self.num_hid_cov) + "_mean" + str(self.num_hid_mean), {'meanEnergy':meanEnergy,'meanEnergy_test':meanEnergy_test,'maxEnergy': maxEnergy, 'maxEnergy_test': maxEnergy_test, 'minEnergy': minEnergy, 'minEnergy_test': minEnergy_test, 'epoch':epoch})
                #savemat("training_energy_" + str(self.num_fac) + "_cov" + str(self.num_hid_cov) + "_mean" + str(self.num_hid_mean), {'meanEnergy':meanEnergy, 'maxEnergy': maxEnergy, 'minEnergy': minEnergy, 'epoch':epoch})

            # in order to stop the training gracefully, create an empty file
            # named 'stop_now' in the folder containing the experiment
            # configuration file
            if os.path.isfile('stop_now'):
                break

        # final back-up
        VF.copy_to_host()
        FH.copy_to_host()
        bias_cov.copy_to_host()
        bias_vis.copy_to_host()
        w_mean.copy_to_host()
        bias_mean.copy_to_host()
        savemat(
            "ws_fac%s" % str(self.num_fac) + "_cov%s" % str(self.num_hid_cov) +
            "_mean%s" % str(self.num_hid_mean), {
                'VF': VF.numpy_array,
                'FH': FH.numpy_array,
                'bias_cov': bias_cov.numpy_array,
                'bias_vis': bias_vis.numpy_array,
                'w_mean': w_mean.numpy_array,
                'bias_mean': bias_mean.numpy_array,
                'epoch': epoch
            })

        # uncomment if computing the energy in order to store its evolution throughout training
        #~ savemat(self.refDir + '/' + "training_energy_" + str(self.num_fac) + "_cov" + str(self.num_hid_cov) + "_mean" + str(self.num_hid_mean), {'meanEnergy':meanEnergy,'meanEnergy_test':meanEnergy_test,'maxEnergy': maxEnergy, 'maxEnergy_test': maxEnergy_test, 'minEnergy': minEnergy, 'minEnergy_test': minEnergy_test, 'epoch':epoch})
        savemat(
            "training_energy_" + str(self.num_fac) + "_cov" +
            str(self.num_hid_cov) + "_mean" + str(self.num_hid_mean), {
                'meanEnergy': meanEnergy,
                'maxEnergy': maxEnergy,
                'minEnergy': minEnergy,
                'epoch': epoch
            })

        # Compute states if desired:
        # normalise data for covariance hidden:
        #dsq = np.square(visData)
        #lsq = np.sum(dsq, axis=0)
        #lsq /= visData.shape[1]
        #lsq += np.spacing(1)
        #l = np.sqrt(lsq)
        #normD = visData/l

        #logisticArg_c = (-0.5*np.dot(FH.numpy_array.T, np.square(np.dot(VF.numpy_array.T, normD.T))) + bias_cov.numpy_array).T
        #p_hc = logisticFunc(logisticArg_c)

        #logisticArg_m = np.dot(visData, w_mean.numpy_array) + bias_mean.numpy_array.T
        #p_hm = logisticFunc(logisticArg_m)

        #p_all = np.concatenate((p_hc, p_hm), axis=1)
        #savemat(self.probabilitiesDir + '/pAll_%i.mat' % epoch, mdict={'p_all':p_all})

        with open('done', 'w') as doneFile:
            doneFile.write(
                datetime.strftime(datetime.now(), '%d/%m/%Y %H:%M:%S'))
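
train() checks for an empty file named 'stop_now' at the end of each epoch, so a run can be stopped gracefully from outside the process. Below is a minimal helper sketch, assuming the file is created in the experiment directory train() changes into (self.saveDir); request_stop and save_dir are illustrative names, not part of the original code.

# Minimal sketch of the graceful-stop convention used by train() above:
# an empty 'stop_now' file in the experiment directory breaks the epoch loop.
import os


def request_stop(save_dir):
    """Ask a running train() to stop after the current epoch."""
    open(os.path.join(save_dir, 'stop_now'), 'w').close()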
Example #14
def run(args=None):
    usage = "usage : %prog [options]"
    parser = optparse.OptionParser(usage=usage)

    parser.add_option('--cfg_file',
                      dest='cfg_file',
                      default=None,
                      help='File with settings from previously trained net')

    parser.add_option("--test",
                      action="store_true",
                      dest="test",
                      default=False)

    # Architecture
    parser.add_option("--layerSize",
                      dest="layerSize",
                      type="int",
                      default=1824)
    parser.add_option("--numLayers", dest="numLayers", type="int", default=5)
    parser.add_option("--temporalLayer",
                      dest="temporalLayer",
                      type="int",
                      default=3)

    # Optimization
    parser.add_option("--momentum",
                      dest="momentum",
                      type="float",
                      default=0.95)
    parser.add_option("--epochs", dest="epochs", type="int", default=20)
    parser.add_option("--step", dest="step", type="float", default=1e-5)
    parser.add_option(
        "--anneal",
        dest="anneal",
        type="float",
        default=1.3,
        help="Sets (learning rate := learning rate / anneal) after each epoch."
    )
    parser.add_option(
        '--reg',
        dest='reg',
        type='float',
        default=0.0,
        help='lambda for L2 regularization of the weight matrices')

    # Data
    parser.add_option("--dataDir",
                      dest="dataDir",
                      type="string",
                      default=TRAIN_DATA_DIR['fbank'])
    parser.add_option('--alisDir',
                      dest='alisDir',
                      type='string',
                      default=TRAIN_ALIS_DIR)
    parser.add_option('--startFile',
                      dest='startFile',
                      type='int',
                      default=1,
                      help='Start file for running testing')
    parser.add_option("--numFiles", dest="numFiles", type="int", default=384)
    parser.add_option("--inputDim",
                      dest="inputDim",
                      type="int",
                      default=41 * 15)
    parser.add_option("--rawDim", dest="rawDim", type="int", default=41 * 15)
    parser.add_option("--outputDim", dest="outputDim", type="int", default=35)
    parser.add_option("--maxUttLen",
                      dest="maxUttLen",
                      type="int",
                      default=MAX_UTT_LEN)

    # Save/Load
    parser.add_option(
        '--save_every',
        dest='save_every',
        type='int',
        default=10,
        help='During training, save parameters every x number of files')

    parser.add_option('--run_desc',
                      dest='run_desc',
                      type='string',
                      default='',
                      help='Description of experiment run')

    (opts, args) = parser.parse_args(args)

    if opts.cfg_file:
        cfg = load_config(opts.cfg_file)
    else:
        cfg = vars(opts)

    # These config values should be updated every time
    cfg['host'] = get_hostname()
    cfg['git_rev'] = get_git_revision()
    cfg['pid'] = os.getpid()

    # Create experiment output directory

    if not opts.cfg_file:
        time_string = str(TimeString())
        output_dir = pjoin(RUN_DIR, time_string)
        cfg['output_dir'] = output_dir
        if not os.path.exists(output_dir):
            print 'Creating %s' % output_dir
            os.makedirs(output_dir)
        opts.cfg_file = pjoin(output_dir, 'cfg.json')
    else:
        output_dir = cfg['output_dir']

    cfg['output_dir'] = output_dir
    cfg['in_file'] = pjoin(output_dir, 'params.pk')
    cfg['out_file'] = pjoin(output_dir, 'params.pk')
    cfg['test'] = opts.test
    if opts.test:
        cfg['dataDir'] = opts.dataDir
        cfg['numFiles'] = opts.numFiles
        cfg['startFile'] = opts.startFile
    if 'reg' not in cfg:
        cfg['reg'] = 0.0

    # Logging

    logging.basicConfig(filename=pjoin(output_dir, 'train.log'),
                        level=logging.DEBUG)
    logger = logging.getLogger()
    logger.addHandler(logging.StreamHandler())
    logger.info('Running on %s' % cfg['host'])

    # seed for debugging, turn off when stable
    np.random.seed(33)
    import random
    random.seed(33)

    if 'CUDA_DEVICE' in os.environ:
        cm.cuda_set_device(int(os.environ['CUDA_DEVICE']))
    else:
        cm.cuda_set_device(0)  # Default

    opts = CfgStruct(**cfg)

    # Testing
    if opts.test:
        test(opts)
        return

    alisDir = opts.alisDir if opts.alisDir else opts.dataDir
    loader = dl.DataLoader(opts.dataDir, opts.rawDim, opts.inputDim, alisDir)

    nn = rnnet.NNet(opts.inputDim,
                    opts.outputDim,
                    opts.layerSize,
                    opts.numLayers,
                    opts.maxUttLen,
                    temporalLayer=opts.temporalLayer,
                    reg=opts.reg)
    nn.initParams()

    SGD = sgd.SGD(nn, opts.maxUttLen, alpha=opts.step, momentum=opts.momentum)

    # Dump config
    cfg['param_count'] = nn.paramCount()
    dump_config(cfg, opts.cfg_file)

    # Training
    epoch_file = pjoin(output_dir, 'epoch')
    if os.path.exists(epoch_file):
        start_epoch = int(open(epoch_file, 'r').read()) + 1
    else:
        start_epoch = 0

    # Load model if specified
    if os.path.exists(opts.in_file):
        with open(opts.in_file, 'r') as fid:
            SGD.fromFile(fid)
            SGD.alpha = SGD.alpha / (opts.anneal**start_epoch)
            nn.fromFile(fid)

    num_files_file = pjoin(output_dir, 'num_files')

    for k in range(start_epoch, opts.epochs):
        perm = np.random.permutation(opts.numFiles) + 1
        loader.loadDataFileAsynch(perm[0])

        file_start = 0
        if k == start_epoch:
            if os.path.exists(num_files_file):
                file_start = int(open(num_files_file, 'r').read().strip())
                logger.info('Starting from file %d, epoch %d' %
                            (file_start, start_epoch))
        else:
            open(num_files_file, 'w').write(str(file_start))

        for i in xrange(file_start, perm.shape[0]):
            start = time.time()
            data_dict, alis, keys, sizes = loader.getDataAsynch()
            # Prefetch
            if i + 1 < perm.shape[0]:
                loader.loadDataFileAsynch(perm[i + 1])
            SGD.run(data_dict, alis, keys, sizes)
            end = time.time()
            logger.info('File time %f' % (end - start))

            # Save parameters and cost
            if (i + 1) % opts.save_every == 0:
                logger.info('Saving parameters')
                with open(opts.out_file, 'wb') as fid:
                    SGD.toFile(fid)
                    nn.toFile(fid)
                    open(num_files_file, 'w').write('%d' % (i + 1))
                logger.info('Done saving parameters')
                with open(pjoin(output_dir, 'last_cost'), 'w') as fid:
                    if opts.reg > 0.0:
                        fid.write(str(SGD.expcost[-1] - SGD.regcost[-1]))
                    else:
                        fid.write(str(SGD.expcost[-1]))

        # Save epoch completed
        open(pjoin(output_dir, 'epoch'), 'w').write(str(k))

        # Save parameters for the epoch
        with open(opts.out_file + '.epoch{0:02}'.format(k), 'wb') as fid:
            SGD.toFile(fid)
            nn.toFile(fid)

        SGD.alpha = SGD.alpha / opts.anneal

    # Run now complete, touch sentinel file
    touch_file(pjoin(output_dir, 'sentinel'))
Example #15
def LockGPU(max_retries=10, board=-1):

  # Assuming you already got GPU lock
  cm.cuda_set_device(board)
  cm.cublas_init()
  return board
Example #16
    def calc_output_legacy(self, data, batch_size):
        """ Calculate the output (probababilies) for a set of data

        The purpose of this function is to calculate the output of a DN on 
        some set of data.  The values will calculated using rbm_cudamat
        on slices of data specified by the batch size

        """

        import cudamat as cm
        import rbm_numpy, rbm_cudamat

        # Initialize CUDA
        cm.cublas_init()
        cm.CUDAMatrix.init_random(1)

        if self.legacy_card_number != 0:
            cm.cuda_set_device(self.legacy_card_number)

        # Create output, use the size of the last layer to do this
        output = np.empty(
            (data.shape[0], self.arch[(self.layer_count - 1)]['node_count']))

        # Slice up data, handling batches of batch_size. USE INT DIVISION
        processed = 0
        for j in range(data.shape[0] // batch_size):

            curr_data = data[j * batch_size:(j + 1) * batch_size, :]

            for i in range(1, self.layer_count):

                # Handle a sigmoid node
                if self.arch[i]['node_type'] == 'S':
                    curr_data = \
                      rbm_cudamat.calc_hidden_probs(curr_data,
                                                    self.weights[i]['w'],
                                                    self.weights[i]['hb'],
                                                    batch_size)

            output[j * batch_size:(j + 1) * batch_size, :] = curr_data[:, :]
            processed = processed + batch_size

        # Now handle anything that was left over i.e., what didn't fit in
        if processed != data.shape[0]:

            curr_data = data[processed:, :]

            for i in range(1, self.layer_count):

                # Handle a sigmoid node
                if self.arch[i]['node_type'] == 'S':
                    curr_data = \
                      rbm_numpy.calc_hidden_probs(curr_data,
                                                  self.weights[i]['w'],
                                                  self.weights[i]['hb'])

            output[processed:, :] = curr_data[:, :]

        cm.cublas_shutdown()

        return output
Example #17
def main(params, nb_cpu, nb_gpu, use_gpu):
    numpy.random.seed(426236)
    #params         = detect_memory(params)
    parallel_hdf5 = get_parallel_hdf5_flag(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.extracting')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    N_total = params.nb_channels
    template_shift = params.getint('detection', 'template_shift')
    chunk_size = params.getint('data', 'chunk_size')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('extracting', 'safety_time')
    max_elts_temp = params.getint('extracting', 'max_elts')
    output_dim = params.getfloat('extracting', 'output_dim')
    noise_thr = params.getfloat('extracting', 'noise_thr')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    blosc_compress = params.getboolean('data', 'blosc_compress')
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    amp_limits = map(float, tmp_limits)
    elt_count = 0
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(["Extracting templates from already found clusters..."],
                      'default', logger)

    thresholds = io.load_data(params, 'thresholds')
    basis_proj, basis_rec = io.load_data(params, 'basis')
    clusters, spiketimes, N_clusters = io.load_data(params, 'spike-cluster')
    inv_clusters = numpy.zeros(clusters.max() + 1, dtype=numpy.int32)
    inv_clusters[numpy.unique(clusters)] = numpy.argsort(
        numpy.unique(clusters))

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    result = {}
    for i in xrange(N_clusters):
        result['data_tmp_' + str(i)] = numpy.zeros(
            (0, N_e * basis_proj.shape[1]), dtype=numpy.float32)
        result['times_' + str(i)] = numpy.zeros(0, dtype=numpy.int32)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Shuffle the chunks so that waveforms are sampled from all over the recording
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks))

    nb_templates = numpy.sum(
        comm.rank == numpy.mod(numpy.arange(N_clusters), comm.size))
    nb_elts = max_elts_temp * nb_templates
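    # Clusters are distributed round-robin across MPI ranks (cluster id modulo
    # comm.size), so each rank only needs room for its own share of waveforms.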

    to_explore = all_chunks

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gidx in to_explore:

        if (elt_count < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            idx = numpy.where((spiketimes >= gidx * chunk_size)
                              & (spiketimes < (gidx + 1) * chunk_size))[0]
            local_offset = t_offset
            local_peaktimes = spiketimes[idx] - local_offset
            local_clusters = inv_clusters[clusters[idx]]

            #print "Removing the useless borders..."
            local_borders = (template_shift, chunk_size - template_shift)
            idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                           local_borders[1])
            local_peaktimes = local_peaktimes[idx]
            local_clusters = local_clusters[idx]

            if len(local_peaktimes) > 0:
                all_times = numpy.zeros(
                    (N_e, local_peaktimes[-1] - local_peaktimes[0] + 1),
                    dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    local_peaktimes[-1] - local_peaktimes[0])

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                clusters_id = local_clusters[argmax_peak]
                local_peaktimes = local_peaktimes[argmax_peak]

                #print "Selection of the peaks with spatio-temporal masks..."
                for idx in xrange(len(local_peaktimes)):

                    if elt_count == nb_elts:
                        break

                    temp = clusters_id[idx]

                    if numpy.mod(temp, comm.size) == comm.rank:

                        elec = numpy.argmin(local_chunk[local_peaktimes[idx]])
                        indices = inv_nodes[edges[nodes[elec]]]
                        myslice = all_times[indices,
                                            min_times[idx]:max_times[idx]]
                        peak = local_peaktimes[idx]
                        if not myslice.any():
                            if (len(result['data_tmp_' + str(temp)]) <
                                    max_elts_temp):
                                elt_count += 1
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, :]
                                sub_mat = numpy.dot(basis_rec, sub_mat)
                                nx, ny = sub_mat.shape
                                sub_mat = sub_mat.reshape((1, nx * ny))
                                result['data_tmp_' + str(temp)] = numpy.vstack(
                                    (result['data_tmp_' + str(temp)], sub_mat))
                                to_add = numpy.array([peak + local_offset],
                                                     dtype=numpy.int32)
                                result['times_' +
                                       str(temp)] = numpy.concatenate(
                                           (result['times_' + str(temp)],
                                            to_add))
                            all_times[indices,
                                      min_times[idx]:max_times[idx]] = True

    total_nb_elts = 0
    for temp in xrange(N_clusters):
        total_nb_elts += len(result['data_tmp_' + str(temp)])

    gdata = gather_array(numpy.array([total_nb_elts], dtype=numpy.float32),
                         comm, 0)
    if comm.rank == 0:
        print_and_log([
            "Found %d spikes over %d requested" %
            (int(numpy.sum(gdata)), int(nb_elts))
        ], 'default', logger)

    #print "Spikes extracted in", time.time() - t_start, "s"

    comm.Barrier()

    local_nb_clusters = 0
    for temp in xrange(comm.rank, N_clusters, comm.size):
        if len(result['data_tmp_' + str(temp)]) > 0:
            local_nb_clusters += 1

    #print total_nb_clusters, "found in", time.time() - t_start, "s"
    gdata3 = gather_array(
        numpy.array([local_nb_clusters], dtype=numpy.float32), comm, 0)

    comm.Barrier()
    if comm.rank == 0:
        print_and_log(["Extracting the templates..."], 'default', logger)

    total_nb_clusters = int(
        comm.bcast(numpy.array([int(numpy.sum(gdata3))], dtype=numpy.int32),
                   root=0)[0])
    offsets = numpy.zeros(comm.size, dtype=numpy.int32)
    for i in xrange(comm.size - 1):
        offsets[i + 1] = comm.bcast(numpy.array([local_nb_clusters],
                                                dtype=numpy.int32),
                                    root=i)

    if parallel_hdf5:
        node_pad = numpy.sum(offsets[:comm.rank + 1])
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'w',
                          driver='mpio',
                          comm=comm,
                          libver='earliest')
        norms = hfile.create_dataset('norms',
                                     shape=(2 * total_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(total_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(total_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = node_pad
        g_offset = total_nb_clusters
    else:
        node_pad = 0
        hfile = h5py.File(file_out_suff + '.templates-%d.hdf5' % comm.rank,
                          'w',
                          libver='earliest')
        electrodes = hfile.create_dataset('electrodes',
                                          shape=(local_nb_clusters, ),
                                          dtype=numpy.int32,
                                          chunks=True)
        norms = hfile.create_dataset('norms',
                                     shape=(2 * local_nb_clusters, ),
                                     dtype=numpy.float32,
                                     chunks=True)
        amps_lims = hfile.create_dataset('limits',
                                         shape=(local_nb_clusters, 2),
                                         dtype=numpy.float32,
                                         chunks=True)
        g_count = 0
        g_offset = local_nb_clusters

    cfile = h5py.File(file_out_suff + '.clusters-%d.hdf5' % comm.rank,
                      'w',
                      libver='earliest')
    count_templates = node_pad

    temp_x = numpy.zeros(0, dtype=numpy.int32)
    temp_y = numpy.zeros(0, dtype=numpy.int32)
    temp_data = numpy.zeros(0, dtype=numpy.float32)

    to_explore = xrange(comm.rank, N_clusters, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for temp in to_explore:
        n_data = len(result['data_tmp_' + str(temp)])
        if n_data > 0:
            data = result['data_tmp_' + str(temp)].reshape(
                n_data, basis_proj.shape[1], N_e)
            first_component = numpy.median(data, axis=0)
            tmp_templates = numpy.dot(first_component.T, basis_rec)
            # 'tmpidx' and 'shift' are not defined anywhere in this excerpt; the two
            # lines below are a plausible reconstruction (the minimum of the template
            # gives its preferred electrode and the temporal re-alignment to apply),
            # and the neighbourhood is taken around the electrode just stored.
            tmpidx = numpy.where(tmp_templates == tmp_templates.min())
            shift = template_shift - tmpidx[1][0]
            electrodes[g_count] = indices[tmpidx[0][0]]
            indices = inv_nodes[edges[nodes[electrodes[g_count]]]]
            templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                templates[indices, :] = tmp_templates

            templates = templates.flatten()
            dx = templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y,
                 count_templates * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, templates[dx]))

            norms[g_count] = numpy.sqrt(
                numpy.sum(templates.flatten()**2) / (N_e * N_t))

            x, y, z = data.shape
            data_flat = data.reshape(x, y * z)
            first_flat = first_component.reshape(y * z, 1)
            amplitudes = numpy.dot(data_flat, first_flat)
            amplitudes /= numpy.sum(first_flat**2)
            for i in xrange(x):
                data_flat[i, :] -= amplitudes[i] * first_flat[:, 0]

            variations = 10 * numpy.median(
                numpy.abs(amplitudes - numpy.median(amplitudes)))
            physical_limit = noise_thr * (
                -thresholds[indices[tmpidx[0][0]]]) / tmp_templates.min()
            amp_min = max(physical_limit,
                          numpy.median(amplitudes) - variations)
            amp_max = min(amp_limits[1], numpy.median(amplitudes) + variations)
            amps_lims[g_count] = [amp_min, amp_max]

            if len(data_flat) > 1:
                pca = PCA(1)
                res_pca = pca.fit_transform(data_flat.astype(numpy.double))
                second_component = pca.components_.T.astype(
                    numpy.float32).reshape(y, z)
            else:
                second_component = data_flat.reshape(y, z) / numpy.sum(
                    data_flat**2)

            tmp_templates = numpy.dot(second_component.T, basis_rec)
            offset = total_nb_clusters + count_templates
            sub_templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                sub_templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                sub_templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                sub_templates[indices, :] = tmp_templates

            sub_templates = sub_templates.flatten()
            dx = sub_templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, offset * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, sub_templates[dx]))

            norms[g_count + g_offset] = numpy.sqrt(
                numpy.sum(sub_templates.flatten()**2) / (N_e * N_t))

            count_templates += 1
            g_count += 1

        # NB: 'to_write' and 'ielec' are assumed to be defined earlier in the full
        # module; they are not part of this excerpt.
        io.write_datasets(cfile,
                          to_write,
                          result,
                          ielec,
                          compress=hdf5_compress)

    #At the end we should have a templates variable to store.
    cfile.close()
    del result, templates, amps_lims
    comm.Barrier()

    #We need to gather the sparse arrays
    temp_x = gather_array(temp_x, comm, dtype='int32', compress=blosc_compress)
    temp_y = gather_array(temp_y, comm, dtype='int32', compress=blosc_compress)
    temp_data = gather_array(temp_data, comm, compress=blosc_compress)

    if parallel_hdf5:
        if comm.rank == 0:
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            for i in xrange(comm.size):
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                rs[i].close()
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            cfile.close()
        hfile.close()
    else:
        hfile.close()
        if comm.rank == 0:
            ts = [
                h5py.File(file_out_suff + '.templates-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i,
                          'r',
                          libver='earliest') for i in xrange(comm.size)
            ]
            result = {}
            hfile = h5py.File(file_out_suff + '.templates.hdf5',
                              'w',
                              libver='earliest')
            cfile = h5py.File(file_out_suff + '.clusters.hdf5',
                              'w',
                              libver='earliest')
            electrodes = hfile.create_dataset('electrodes',
                                              shape=(total_nb_clusters, ),
                                              dtype=numpy.int32,
                                              chunks=True)
            norms = hfile.create_dataset('norms',
                                         shape=(2 * total_nb_clusters, ),
                                         dtype=numpy.float32,
                                         chunks=True)
            amplitudes = hfile.create_dataset('limits',
                                              shape=(total_nb_clusters, 2),
                                              dtype=numpy.float32,
                                              chunks=True)
            count = 0
            for i in xrange(comm.size):
                loc_temp = ts[i].get('templates')
                middle = loc_temp.shape[2] // 2
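                # 'loc_norms' and 'n_clusters' are assumed to be defined in the full
                # module (presumably the per-rank norms and the global number of
                # templates); they are not part of this excerpt.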
                norms[count:count + middle] = loc_norms[:middle]
                norms[n_clusters + count:n_clusters + count +
                      middle] = loc_norms[middle:]
                electrodes[count:count + middle] = ts[i].get('electrodes')
                amplitudes[count:count + middle] = ts[i].get('limits')
                count += middle
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile,
                                      to_write,
                                      rs[i],
                                      j,
                                      compress=hdf5_compress)
                ts[i].close()
                rs[i].close()
                os.remove(file_out_suff + '.templates-%d.hdf5' % i)
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            io.write_datasets(cfile, ['electrodes'],
                              {'electrodes': electrodes[:]},
                              compress=hdf5_compress)
            hfile.close()
            cfile.close()

    if comm.rank == 0:
        hfile = h5py.File(file_out_suff + '.templates.hdf5',
                          'r+',
                          libver='earliest')
        hfile.create_dataset('temp_x', data=temp_x)
        hfile.create_dataset('temp_y', data=temp_y)
        hfile.create_dataset('temp_data', data=temp_data)
        hfile.create_dataset('temp_shape',
                             data=numpy.array(
                                 [N_e, N_t, 2 * total_nb_clusters],
                                 dtype=numpy.int32))
        hfile.close()

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Merging similar templates..."], 'default', logger)

    merged1 = algo.merging_cc(params, parallel_hdf5)

    comm.Barrier()
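    # 'remove_mixture' is assumed to be a boolean read from the configuration
    # earlier in the full module; it is not defined in this excerpt.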
    if remove_mixture:
        if comm.rank == 0:
            print_and_log(["Removing mixtures..."], 'default', logger)
        merged2 = algo.delete_mixtures(params, parallel_hdf5)
    else:
        merged2 = [0, 0]

    if comm.rank == 0:
        print_and_log([
            "Number of global merges    : %d" % merged1[1],
            "Number of mixtures removed : %d" % merged2[1]
        ], 'info', logger)

    comm.Barrier()
    io.get_overlaps(params, erase=True, parallel_hdf5=parallel_hdf5)

    data_file.close()
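
A minimal sketch (an assumption, not part of the package) of how the sparse template arrays written above could be read back: 'temp_x' holds flat indices into each flattened (N_e * N_t) template, 'temp_y' holds the template (column) index, and 'temp_shape' stores [N_e, N_t, 2 * total_nb_clusters].

import h5py
import scipy.sparse

def load_sparse_templates(path):
    # Reassemble the templates as a sparse matrix with one flattened template per
    # column; the second half of the columns holds the second components.
    with h5py.File(path, 'r') as f:
        temp_x = f['temp_x'][:]
        temp_y = f['temp_y'][:]
        temp_data = f['temp_data'][:]
        n_e, n_t, n_templates = f['temp_shape'][:]
    return scipy.sparse.csc_matrix((temp_data, (temp_x, temp_y)),
                                   shape=(n_e * n_t, n_templates))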
Example #18
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    logger         = init_logging(params.logfile)
    logger         = logging.getLogger('circus.fitting')
    data_file      = params.data_file
    data_file.open()
    N_e            = params.getint('data', 'N_e')
    N_total        = params.nb_channels
    N_t            = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out       = params.get('data', 'file_out')
    file_out_suff  = params.get('data', 'file_out_suff')
    sign_peaks     = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh   = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening  = params.getboolean('whitening', 'spatial')
    chunk_size     = params.getint('fitting', 'chunk_size')
    gpu_only       = params.getboolean('fitting', 'gpu_only')
    nodes, edges   = get_nodes_and_edges(params)
    tmp_limits     = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits     = map(float, tmp_limits)
    amp_auto       = params.getboolean('fitting', 'amp_auto')
    space_explo    = params.getfloat('fitting', 'space_explo')
    nb_chances     = params.getint('fitting', 'nb_chances')
    max_chunk      = params.getfloat('fitting', 'max_chunk')
    noise_thr      = params.getfloat('clustering', 'noise_thr')
    collect_all    = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes         = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes]  = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank//nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates  = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm, x    = templates.shape
    else:
        templates  = io.load_data(params, 'templates')
        x, N_tm    = templates.shape

    temp_2_shift   = 2*template_shift
    full_gpu       = use_gpu and gpu_only
    n_tm           = N_tm//2
    n_scalar       = N_e*N_t
    last_spikes    = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window    = numpy.arange(-template_shift, template_shift+1)
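    # Templates are stored in pairs (first and second component), so N_tm counts
    # both halves and n_tm = N_tm // 2 is the number of distinct templates.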

    if not amp_auto:
        amp_limits       = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits       = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    
    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg  = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg))* len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos  = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos))* len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        dead_times = numpy.loadtxt(params.get('triggers', 'dead_file'))
        if len(dead_times.shape) == 1:
            dead_times = dead_times.reshape(1, 2)
        dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
        if dead_in_ms:
            dead_times *= numpy.int64(data_file.sampling_rate*1e-3)
        dead_times = dead_times.astype(numpy.int64)
        all_dead_times = []
        for i in xrange(len(dead_times)):
            all_dead_times += range(dead_times[i, 0], dead_times[i, 1])

    thresholds = io.load_data(params, 'thresholds')


    if collect_all:
        neighbors = {}
        for i in xrange(n_tm):
            tmp  = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string   = ''

    
    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" %(comm.size)
        else:
            info_string = "using %d CPUs" %(comm.size)

    comm.Barrier()

    c_overlap  = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over     = int(numpy.sqrt(over_shape[0]))
    S_over     = over_shape[1]
    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap  = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over     = int(numpy.sqrt(over_shape[0]))
        S_over     = over_shape[1]

    if SHARED_MEMORY:
        c_overs    = io.load_data_memshared(params, 'overlaps', nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    else:
        c_overlap  = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_x     = c_overlap.get('over_x')[:]
        over_y     = c_overlap.get('over_y')[:]
        over_data  = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time
        c_overs   = {}
        overlaps  = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data
        
        for i in xrange(N_over):
            c_overs[i] = overlaps[i*N_over:(i+1)*N_over]
        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." %(info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening  = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks          = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    templates_file  = open(file_out_suff + '.templates-%d.data' %comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' %comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file  = open(file_out_suff + '.gtemplates-%d.data' %comm.rank, 'wb')
        comm.Barrier()


    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last  = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-2*template_shift, 0)
        elif is_first:
            padding = (0, 2*template_shift)
        else:
            padding = (-2*template_shift, 2*template_shift)

        result       = {'spiketimes' : [], 'amplitudes' : [], 'templates' : []}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk             = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)                    
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes)) 
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()


            
        local_peaktimes = numpy.unique(local_peaktimes)

        if ignore_dead_times:
            local_peaktimes = numpy.array(list(set(local_peaktimes + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
            local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders   = (template_shift, len_chunk - template_shift)
        idx             = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.int32)

                if ignore_dead_times:
                    all_found_spikes[i] = numpy.array(list(set(all_found_spikes[i] + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
                    all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx                 = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t             = len(local_peaktimes)
        all_indices     = numpy.arange(n_t)
                            

        if full_gpu:
        #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)


        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."     

            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat     = numpy.zeros((N_e*(2*template_shift+1), n_t), dtype=numpy.float32)
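            # The chunk is flattened channel by channel so that, for each peak, a
            # single numpy.take with 'slice_indices + idx' gathers the whole
            # (N_e x window) snippet around it as one column of sub_mat.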

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk*idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu: 
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b       = cmt.sparse_dot(templates, sub_mat)
            else:
                b       = templates.dot(sub_mat)                

            del sub_mat

            local_offset = padding[0] + t_offset
            local_bounds = (temp_2_shift, len_chunk - temp_2_shift)
            all_spikes   = local_peaktimes + local_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure     = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask     = numpy.zeros((2*n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data     = cmt.empty(mask.shape)
                patch_gpu= b.shape[1] == 1
            else:
                mask     = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b    = b[:n_tm, :]

            min_time     = local_peaktimes.min()
            max_time     = local_peaktimes.max()
            local_len    = max_time - min_time + 1
            min_times    = numpy.maximum(local_peaktimes - min_time - temp_2_shift, 0)
            max_times    = numpy.minimum(local_peaktimes - min_time + temp_2_shift + 1, max_time - min_time)
            max_n_t      = int(space_explo*(max_time-min_time+1)//(2*temp_2_shift + 1))

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True
                    
            while (numpy.mean(failure) < nb_chances):
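                # Greedy template matching: peaks are ranked by their best (masked)
                # template projection, non-overlapping peaks are fitted first, accepted
                # spikes have their template contribution peeled off from neighbouring
                # projections, and rejected peaks accumulate failures up to nb_chances.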

                if full_gpu:
                    gpu_mask    = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat     = data.max(0)
                    argmax_bi   = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data        = sub_b * mask
                    argmax_bi   = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset          = []
                    indices         = []
                    all_times       = numpy.zeros(local_len, dtype=numpy.bool)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset  += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset    = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b   = b_array[:n_tm, :]

                    inds_t, inds_temp = subset, numpy.argmax(numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp  = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b_array[inds_temp + n_tm, inds_t]/n_scalar
                    else:
                        
                        best_amp  = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t]/n_scalar

                    mask[inds_temp, inds_t] = 0

                    best_amp_n   = best_amp/numpy.take(norm_templates, inds_temp)
                    best_amp2_n  = best_amp2/numpy.take(norm_templates, inds_temp + n_tm)

                    all_idx      = ((best_amp_n >= amp_limits[inds_temp, 0]) & (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep      = numpy.where(all_idx)[0]
                    to_reject    = numpy.where(~all_idx)[0]
                    ts           = numpy.take(local_peaktimes, inds_t[to_keep])
                    good         = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep      = to_keep[good]
                    #ts           = ts[good]
                    
                    if len(ts) > 0:
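                        # For each accepted spike, locate the neighbouring peaks within
                        # 2 * template_shift and subtract the (time-shifted) overlap of
                        # the fitted template from their projections in b.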
                        if full_gpu:
                            tmp  = cmt.CUDAMatrix(numpy.ones((len(ts), 1)), copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)), copy_on_host=False)
                            tmp  = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp       = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp      = numpy.dot(numpy.ones((len(ts), 1), dtype=numpy.int32), local_peaktimes.reshape((1, n_t)))
                            tmp     -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):
                            
                            idx_b    = numpy.compress(condition[count, :], all_indices)
                            ytmp     = tmp[count, condition[count, :]] + temp_2_shift
                            
                            indices  = numpy.zeros((S_over, len(ytmp)), dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            if full_gpu: 
                                indices  = cmt.CUDAMatrix(indices, copy_on_host=False)
                                if patch_gpu:
                                    b_lines  = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines  = b.get_col_slice(idx_b[0], idx_b[-1]+1)

                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]], indices, mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] + n_tm], indices, mult=-best_amp2[keep])
                                b_lines.add(tmp1.add(tmp2))
                                del tmp1, tmp2
                            else:
                                tmp1   = c_overs[inds_temp[keep]].multiply(-best_amp[keep]).dot(indices)
                                tmp2   = c_overs[inds_temp[keep] + n_tm].multiply(-best_amp2[keep]).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            if good[count]:

                                t_spike               = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep], best_amp2_n[keep])]
                                result['templates']  += [inds_temp[keep]]

                    myslice           = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx           = (numpy.take(failure, myslice) >= nb_chances)
                    
                    mask[:, numpy.compress(sub_idx, myslice)] = 0


            spikes_to_write     = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write  = numpy.array(result['templates'], dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - local_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes       = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times   = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx      = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx      = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx      = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                
                gspikes  = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write     = numpy.array(gspikes + local_offset, dtype=numpy.uint32)
                gtemplates_to_write  = numpy.array(bestlecs, dtype=numpy.int32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())
            

            if full_gpu:
                del gpu_mask, b, data

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()


    comm.Barrier()
    
    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
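
A minimal sketch (an assumption, not part of the package) of how the raw per-rank result files written above could be read back; the dtypes mirror the arrays written to them (uint32 spike times, float32 amplitude pairs, int32 template ids):

import numpy

def read_fitting_results(file_out_suff, rank):
    spikes = numpy.fromfile(file_out_suff + '.spiketimes-%d.data' % rank, dtype=numpy.uint32)
    amps = numpy.fromfile(file_out_suff + '.amplitudes-%d.data' % rank, dtype=numpy.float32).reshape(-1, 2)
    temps = numpy.fromfile(file_out_suff + '.templates-%d.data' % rank, dtype=numpy.int32)
    return spikes, amps, temps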
Example #19
def main():

    parser = ArgumentParser()
    parser.add_argument("query_file", help = "word2vec file in json format")
    parser.add_argument("bidword_file", help = "word2vec file in json format")
    args = parser.parse_args()

    query_file = args.query_file
    bidword_file = args.bidword_file

    if DEBUG_FLAG:
        print "loading bidword dict ..."
    start = time()
    bidword_list, bidword_matrix = load_normalized_matrix(bidword_file)
    end = time()
    if DEBUG_FLAG:
        print "loading bidword dict done", duration(start, end)
    
    if DEBUG_FLAG:
        print "loading query dict ..."
    start = time()
    query_list, query_matrix = load_normalized_matrix(query_file)
    end = time()
    if DEBUG_FLAG:
        print "loading query dict done", duration(start, end)

    hash_length = 12
    hash_number = 1

    seed_matrix = random((200, hash_length * hash_number)) - 0.5

    if DEBUG_FLAG:
        print "initing cublas ..."
    start = time()
    cuda_set_device(1)
    cublas_init(1000000)
    end = time()
    if DEBUG_FLAG:
        print "initing cublas done", duration(start, end)

    if DEBUG_FLAG:
        print "computing hash_matrix ..."
    start = time()
    cuda_seed_matrix = CUDAMatrix(seed_matrix)
    cuda_bidword_matrix = CUDAMatrix(bidword_matrix)
    bidword_hash_matrix = dot(cuda_bidword_matrix, cuda_seed_matrix).asarray()
    del cuda_bidword_matrix
    cuda_query_matrix = CUDAMatrix(query_matrix)
    query_hash_matrix = dot(cuda_query_matrix, cuda_seed_matrix).asarray()
    del cuda_query_matrix
    end = time()
    if DEBUG_FLAG:
        print "computing hash_matrix done", duration(start, end)

    
    if DEBUG_FLAG:
        print "initing bidword_hash_dict_list ..."
    start = time()
    bidword_hash_dict_list = [dict([]) for i in xrange(hash_number)]
    end = time()
    if DEBUG_FLAG:
        print "initing bidword_hash_dict_list done", duration(start, end)
    
    if DEBUG_FLAG:
        print "aggregating bidword_hash_dict_list ..."
    start = time()
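    # The sign pattern of each row under the random projection is its binary hash
    # key (random-hyperplane LSH); rows sharing a hash_length-bit key end up in the
    # same bucket of bidword_hash_dict_list[j], which keeps candidate lookup cheap.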
    for i in xrange(bidword_hash_matrix.shape[0]):
        hash_string = "".join(['1' if j > 0 else '0' for j in bidword_hash_matrix[i, :]])
        for j in xrange(hash_number):
            hash_index_start = j * hash_length
            hash_index_end = hash_index_start + hash_length
            hash_key = hash_string[hash_index_start:hash_index_end]
            if hash_key in bidword_hash_dict_list[j]:
                bidword_hash_dict_list[j][hash_key].add(i)
            else:
                bidword_hash_dict_list[j][hash_key] = set([i])
    end = time()
    if DEBUG_FLAG:
        print "aggregating bidword_hash_dict_list done", duration(start, end)

    if DEBUG_FLAG:
        print "aggregating query_hash_dict ..."
    start = time()
    query_hash_dict = {}
    for i in xrange(query_hash_matrix.shape[0]):
        hash_string = "".join(['1' if j > 0 else '0' for j in query_hash_matrix[i, :]])
        if hash_string in query_hash_dict:
            query_hash_dict[hash_string].add(i)
        else:
            query_hash_dict[hash_string] = set([i])
    end = time()
    if DEBUG_FLAG:
        print "aggregating querh_hash_dict done", duration(start, end)

    profiler_total = 0
    profiler_first = 0
    profiler_first_zero = 0
    profiler_first_one = 0
    profiler_first_two = 0
    profiler_first_three = 0
    profiler_first_four = 0
    profiler_second = 0
    profiler_third = 0
    timer = time()

    for hash_string in query_hash_dict:
        time_flag_total = time()
        time_flag_first = time()
        # occasionally force a garbage collection to release memory
        
        if random_sample() > 0.95:
            collect()
        
        # aggregating query_index_set and bidword_index_set
        query_index_set = query_hash_dict[hash_string]
        bidword_index_set = set()
        for i in xrange(hash_number):
            time_flag_first_zero = time()
            hash_index_start = i * hash_length
            hash_index_end = hash_index_start + hash_length
            hash_key = hash_string[hash_index_start:hash_index_end]
            profiler_first_zero += time() - time_flag_first_zero
            # circum hash with hamming distance 0
            time_flag_first_one = time()
            bidword_index_set |= bidword_hash_dict_list[i][hash_key]
            profiler_first_one += time() - time_flag_first_one
            # circum hash with hamming distance 1
            time_flag_first_two = time()
            for first_index in xrange(hash_length):
                circum_hash_key = list(hash_key)
                circum_hash_key[first_index] = '1' if hash_key[first_index] == '0' else '0'
                circum_hash_key = "".join(circum_hash_key)
                if circum_hash_key in bidword_hash_dict_list[i]:
                    bidword_index_set |= bidword_hash_dict_list[i][circum_hash_key]
            profiler_first_two += time() - time_flag_first_two
            # circum hash with hamming distance 2
            time_flag_first_three = time()
            for first_index, second_index in combinations(range(hash_length), 2):
                circum_hash_key = list(hash_key)
                circum_hash_key[first_index] = '1' if hash_key[first_index] == '0' else '0'
                circum_hash_key[second_index] = '1' if hash_key[second_index] == '0' else '0'
                circum_hash_key = "".join(circum_hash_key)
                if circum_hash_key in bidword_hash_dict_list[i]:
                    bidword_index_set |= bidword_hash_dict_list[i][circum_hash_key]
            profiler_first_three += time() - time_flag_first_three
            ## circum hash with hamming distance 3
            #time_flag_first_four = time()
            #for first_index, second_index, third_index in combinations(range(hash_length), 3):
            #    circum_hash_key = list(hash_key)
            #    circum_hash_key[first_index] = '1' if hash_key[first_index] == '0' else '0'
            #    circum_hash_key[second_index] = '1' if hash_key[second_index] == '0' else '0'
            #    circum_hash_key[third_index] = '1' if hash_key[third_index] == '0' else '0'
            #    circum_hash_key = "".join(circum_hash_key)
            #    if circum_hash_key in bidword_hash_dict_list[i]:
            #        bidword_index_set |= bidword_hash_dict_list[i][circum_hash_key]
            #profiler_first_four += time() - time_flag_first_four
        # computing sim between query_index_list and bidword_index_list
        profiler_first += time() - time_flag_first
        
        query_index_list = list(query_index_set)
        bidword_index_list = list(bidword_index_set)

        partition_length = 1e8
        if DEBUG_FLAG or True:
            print "### profile ### matrix shape:", query_matrix[query_index_list, :].shape, bidword_matrix[bidword_index_list, :].transpose().shape, len(query_index_list) * len(bidword_index_list)
        if len(bidword_index_list) > partition_length:
            raise Exception("bidword_index_list too long: %d" % len(query_index_list))
        
        step = int(partition_length / len(bidword_index_list))
        partition_begin = 0
        partition_end = 0
        while partition_end < len(query_index_list):
            partition_end = len(query_index_list) if partition_begin + step > len(query_index_list) else partition_begin + step
            if DEBUG_FLAG or True:
                print "### profile ### partition_begin:", partition_begin, "partition_end:", partition_end
            time_flag_second = time()
            sim_matrix = dot(
                CUDAMatrix(query_matrix[query_index_list[partition_begin:partition_end], :]),
                CUDAMatrix(bidword_matrix[bidword_index_list, :].transpose())
            ).asarray().tolist()
            profiler_second += time() - time_flag_second
            profiler_third += sort_matrix(sim_matrix, query_list, query_index_list[partition_begin:partition_end], bidword_list, bidword_index_list)
            partition_begin = partition_end
            
        profiler_total += time() - time_flag_total
        if DEBUG_FLAG or True:
            print "### profile ### total=%f first=%f(%f)[%f(%f)%f(%f)%f(%f)%f(%f)%f(%f)] second=%f(%f) third=%f(%f) %s(%f)" % (
                profiler_total,
                profiler_first, profiler_first / profiler_total,
                profiler_first_zero, profiler_first_zero / profiler_first,
                profiler_first_one, profiler_first_one / profiler_first,
                profiler_first_two, profiler_first_two / profiler_first,
                profiler_first_three, profiler_first_three / profiler_first,
                profiler_first_four, profiler_first_four / profiler_first,
                profiler_second, profiler_second / profiler_total,
                profiler_third, profiler_third / profiler_total,
                duration(timer, time()), time() - timer
            )
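
A compact, self-contained sketch (an assumption, not the author's code) of the random-hyperplane hashing used above: the sign pattern of a random projection gives a binary key, and vectors whose keys are equal, or within a small Hamming distance, become candidate neighbours.

import numpy
from itertools import combinations

def hash_key(vec, planes):
    # 'planes' has shape (dim, hash_length): one bit per random hyperplane
    return "".join('1' if v > 0 else '0' for v in vec.dot(planes))

def neighbour_keys(key, max_flips=2):
    # All keys within Hamming distance <= max_flips of 'key'
    keys = [key]
    for r in range(1, max_flips + 1):
        for idx in combinations(range(len(key)), r):
            bits = list(key)
            for i in idx:
                bits[i] = '1' if bits[i] == '0' else '0'
            keys.append("".join(bits))
    return keys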
Example #20
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    fudge = params.getfloat('whitening', 'fudge')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    ignore_spikes = params.getboolean('whitening', 'ignore_spikes')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    sort_waveforms = params.getboolean('whitening', 'sort_waveforms')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    rejection_threshold = params.getfloat('detection', 'rejection_threshold')
    noise_window = params.getint('detection', 'noise_time')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    nodes_indices = {}
    for elec in numpy.arange(N_e):
        nodes_indices[elec] = inv_nodes[edges[nodes[elec]]]
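    # nodes_indices[elec] now lists, for every electrode, the indices (within the
    # selected channel set) of its neighbouring channels from the probe adjacency graph.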

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Shuffle the chunks so that the analyzed signal is sampled from all over the recording.
    if nb_chunks > comm.size:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks - 1, dtype=numpy.int32))
    else:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks, dtype=numpy.int32))

    all_electrodes = numpy.random.permutation(N_e)

    numpy.random.seed(comm.rank)

    for gidx in [all_chunks[comm.rank]]:

        # print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        # print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in range(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        local_borders = (template_shift, local_shape - template_shift)
        found_peaktimes = []

        if ignore_spikes:
            # Extracting the peaks.
            local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]
            for i in range(N_e):
                peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                    height=thresholds[i],
                                                    width=spike_width,
                                                    wlen=N_t)[0]
                peaktimes = peaktimes.astype(numpy.uint32)

                # print "Removing the useless borders..."
                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = numpy.compress(idx, peaktimes)

                found_peaktimes.append(peaktimes)
        else:
            for i in range(N_e):
                found_peaktimes.append(numpy.zeros(0, dtype=numpy.uint32))

        all_peaktimes = numpy.concatenate(found_peaktimes)
        local_peaktimes = numpy.unique(all_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=bool)
            padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                numpy.int32)
            min_times = numpy.maximum(padded_peaks - safety_time, 0)
            max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                      diff_times + 1)

            test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=bool)
            for i in range(N_e):
                test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True

            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            # print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):

                all_elecs = numpy.where(test_extremas[:, peak - local_peaktimes[0]])[0]
                data = local_chunk[peak, all_elecs]
                elec = all_elecs[numpy.argmax(numpy.abs(data))]
                indices = nodes_indices[elec]
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(
                local_data, fudge=fudge).astype(numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            if ignore_spikes:
                print_and_log([
                    "Found %gs without spikes to compute the whitening matrix..."
                    % (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)
            else:
                print_and_log([
                    "Found %gs to compute the whitening matrix..." %
                    (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(to_write.keys()),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in range(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    # if comm.rank == 0:
    #     if not os.path.exists(plot_path):
    #         os.makedirs(plot_path)
    #     N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #     plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #                   n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    SHARED_MEMORY = get_shared_memory_flag(params)
    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    over_factor = params.getint('detection', 'oversampling_factor')
    nb_jitter = params.getint('detection', 'nb_jitter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    _, positions = get_nodes_and_positions(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    use_barycenter = params.getboolean('detection', 'use_barycenter')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat('detection', 'smoothing_factor')
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in range(N_e):
        groups[i] = 0

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        times_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_pos = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        times_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_neg = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')
    stds = io.load_data(params, 'stds')

    if alignment:
        cdata = numpy.linspace(-jitter_range, +jitter_range, nb_jitter)
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.0
        snippet_duration = template_shift_2
        m_size = 2 * template_shift_2 + 1
        align_factor = m_size
        local_factors = align_factor * ((smoothing_factor * mads)**2)
    else:
        snippet_duration = template_shift
        xdata = numpy.arange(-template_shift, template_shift + 1)

    if rejection_threshold > 0:
        reject_noise = True
        noise_levels = stds * (2 * noise_window + 1)
    else:
        reject_noise = False

    to_explore = all_chunks[comm.rank::comm.size]

    upper_bounds = max_elts_elec

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):

        if (elt_count_pos + elt_count_neg) < nb_elts:
            # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_borders = (snippet_duration, local_shape - snippet_duration)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])

            # Extracting the peaks.
            all_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

            found_peaktimes = []
            found_peak_amplitudes = []
            for i in range(N_e):
                height = thresholds[i]
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=height,
                                                        distance=dist_peaks)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                        height=height,
                                                        distance=dist_peaks)[0]
                else:
                    peaktimes = numpy.empty(0, dtype=numpy.uint32)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(
                        numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(
                        numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]

                idx = (peaktimes >= local_borders[0]) & (peaktimes <
                                                         local_borders[1])
                peaktimes = peaktimes[idx]

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            peaktimes + t_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]])
                        peaktimes = peaktimes[~is_included]

                peaktimes = peaktimes.astype(numpy.uint32)
                found_peaktimes.append(peaktimes)

                peak_amplitudes = local_chunk[peaktimes, i]
                found_peak_amplitudes.append(peak_amplitudes)

            all_peaktimes = numpy.concatenate(
                found_peaktimes)  # i.e. concatenate once for efficiency
            all_peak_amplitudes = numpy.concatenate(found_peak_amplitudes)
            local_peaktimes, local_indices = numpy.unique(all_peaktimes,
                                                          return_inverse=True)

            if len(local_peaktimes) > 0:

                diff_times = (local_peaktimes[-1] - local_peaktimes[0]) + 1
                all_times = numpy.zeros((N_e, diff_times), dtype=bool)

                padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(
                    numpy.int32)
                min_times = numpy.maximum(padded_peaks - safety_time, 0)
                max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                          diff_times + 1)
                test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=bool)
                for i in range(N_e):
                    test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True

                # Consider the peaks by decreasing extremum.
                if sort_waveforms:
                    order = numpy.argsort(-numpy.abs(all_peak_amplitudes))
                    all_idx = numpy.take(all_peaktimes, order)
                    argmax_peak = local_indices[order]
                else:
                    n_times = len(all_peaktimes)
                    shuffling = numpy.random.permutation(numpy.arange(n_times))
                    all_idx = numpy.take(all_peaktimes, shuffling)
                    argmax_peak = local_indices[shuffling]

                # print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    all_elecs = numpy.where(
                        test_extremas[:, peak - local_peaktimes[0]])[0]
                    data = local_chunk[peak, all_elecs]

                    #target_area = test_extremas[:, min_times[midx]:max_times[midx]].sum(1)
                    #all_elecs = numpy.where(target_area)[0]
                    #data = local_chunk[peak, all_elecs]

                    if sign_peaks == 'negative':
                        if N_e > 1:
                            if use_barycenter:
                                weighted_position = data[:, numpy.newaxis] * positions[all_elecs]
                                barycenter = weighted_position.sum(0) / data.sum()
                                elec = numpy.argmin(numpy.linalg.norm(
                                    barycenter - positions[all_elecs], axis=1))
                            else:
                                elec = numpy.argmin(data)
                        else:
                            elec = 0
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        if N_e > 1:
                            if use_barycenter:
                                weighted_position = data[:, numpy.newaxis] * positions[all_elecs]
                                barycenter = weighted_position.sum(0) / data.sum()
                                elec = numpy.argmin(numpy.linalg.norm(
                                    barycenter - positions[all_elecs], axis=1))
                            else:
                                elec = numpy.argmax(data)
                        else:
                            elec = 0
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if data < 0:
                                negative_peak = True
                            elif data > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(data)) > numpy.abs(
                                    numpy.min(data)):
                                elec = numpy.argmax(data)
                                negative_peak = False
                            else:
                                elec = numpy.argmin(data)
                                negative_peak = True

                    elec = all_elecs[elec]

                    if groups[elec] < upper_bounds:

                        indices = nodes_indices[elec]
                        myslice = all_times[indices,
                                            min_times[midx]:max_times[midx]]

                        if not myslice.any():

                            sub_mat = local_chunk[peak - snippet_duration:
                                                  peak + snippet_duration + 1, elec]

                            if reject_noise:
                                slice_window = sub_mat[snippet_duration - noise_window:
                                                       snippet_duration + noise_window + 1]
                                value = numpy.linalg.norm(slice_window) / noise_levels[elec]
                                is_noise = value < rejection_threshold
                            else:
                                is_noise = False

                            if not is_noise:

                                extrema = sub_mat[snippet_duration]

                                if alignment:
                                    smoothed = True
                                    try:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    except Exception:
                                        smoothed = False
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, k=3, s=0)

                                    if negative_peak:
                                        rmin = (numpy.argmin(f(cdata)) - xoff) / over_factor
                                    else:
                                        rmin = (numpy.argmax(f(cdata)) - xoff) / over_factor
                                    ddata = numpy.linspace(
                                        rmin - template_shift,
                                        rmin + template_shift, N_t)

                                    if smoothed:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata,
                                            sub_mat,
                                            s=local_factors[elec],
                                            k=3)
                                    else:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, s=0, k=3)

                                    sub_mat = f(ddata).astype(numpy.float32)

                                if negative_peak:
                                    times_neg[elt_count_neg] = peak + t_offset
                                    electrodes_neg[elt_count_neg] = elec
                                    extremum_neg[elt_count_neg] = extrema
                                    elts_neg[:, elt_count_neg] = sub_mat
                                    elt_count_neg += 1
                                else:
                                    times_pos[elt_count_pos] = peak + t_offset
                                    electrodes_pos[elt_count_pos] = elec
                                    extremum_pos[elt_count_pos] = extrema
                                    elts_pos[:, elt_count_pos] = sub_mat
                                    elt_count_pos += 1

                                groups[elec] += 1
                                all_times[indices, min_times[midx]:max_times[midx]] = True
                                test_extremas[elec, peak - local_peaktimes[0]] = False

    sys.stderr.flush()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        times_neg = gather_array(times_neg[:elt_count_neg],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_neg = gather_array(electrodes_neg[:elt_count_neg],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_neg = gather_array(extremum_neg[:elt_count_neg], comm, 0, 1)
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        times_pos = gather_array(times_pos[:elt_count_pos],
                                 comm,
                                 0,
                                 1,
                                 dtype='int32')
        electrodes_pos = gather_array(electrodes_pos[:elt_count_pos],
                                      comm,
                                      0,
                                      1,
                                      dtype='int32')
        extremum_pos = gather_array(extremum_pos[:elt_count_pos], comm, 0, 1)
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        # DO PCA on elts and store the basis obtained.

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        warning_n_t = False
        if sign_peaks in ['negative', 'both']:
            res['times'] = times_neg
            res['electrodes'] = electrodes_neg
            res['extremum'] = extremum_neg
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            # dispersion = numpy.std(gdata_neg, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            #     warning_n_t = True
            # elif ratio == 1:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            #     warning_n_t = True
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:2500]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            res['times_pos'] = times_pos
            res['electrodes_pos'] = electrodes_pos
            res['extremum_pos'] = extremum_pos
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            # dispersion = numpy.std(gdata_pos, 0) / numpy.median(stds)
            # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion))
            # if ratio < 0.25 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger)
            # elif ratio == 1 and not warning_n_t:
            #     print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:2500]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          list(res.keys()),
                          res,
                          compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()

    if SHARED_MEMORY and ignore_dead_times:
        mpi_memory_3.Free()
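

# The helper `get_whitening_matrix(local_data, fudge=fudge)` used above is defined elsewhere
# in SpyKING CIRCUS and is not part of this example. As a rough, hypothetical sketch of what
# such a helper typically computes, a fudge-regularised ZCA whitening matrix can be built from
# the covariance of the silent samples; this is only an illustration, not the project's code.
import numpy


def zca_whitening_matrix(data, fudge=1e-18):
    """Return W such that numpy.dot(data, W) has (approximately) identity covariance."""
    cov = numpy.cov(data, rowvar=False)                      # channels x channels covariance
    eigvals, eigvecs = numpy.linalg.eigh(cov)                # symmetric eigendecomposition
    scale = numpy.diag(1.0 / numpy.sqrt(eigvals + fudge))    # regularised inverse square root
    return numpy.dot(eigvecs, numpy.dot(scale, eigvecs.T))


# Toy check on synthetic "silent" samples (rows = time samples, columns = channels).
_silent = numpy.random.RandomState(0).randn(1000, 4) * numpy.array([1.0, 2.0, 0.5, 3.0])
_whitened = numpy.dot(_silent, zca_whitening_matrix(_silent, fudge=1e-5))
print(numpy.cov(_whitened, rowvar=False).round(2))           # close to the identity matrix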
Exemple #21
0
    We generate minibatches of data and delayed data of the appropriate size, transposed for use on the GPU.

    If we can't fill the last minibatch, we discard that data.
    """
    numCases, numDims = data.shape
    numBatches = numCases / mbs
    for i in range(numBatches):
        if transpose:
            yield (data[i * mbs:(i + 1) * mbs, :].transpose(),
                   [p[i * mbs:(i + 1) * mbs, :].transpose() for p in past])
        else:
            yield (data[i * mbs:(i + 1) * mbs, :],
                   [p[i * mbs:(i + 1) * mbs, :] for p in past])


def main():
    pass


if __name__ == "__main__":
    print "export LD_LIBRARY_PATH=/u/gdahl/cudaLearn/"
    print "export CUDAMATDIR=/u/gdahl/cudaLearn"

    devId = cm.cuda_get_free_device()
    cm.cuda_set_device(devId)

    cm.cublas_init()
    cm.CUDAMatrix.init_random(1)
    main()
    cm.cublas_shutdown()
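

# The excerpt above opens inside the generator's docstring, so its name and signature are not
# visible. The stand-in below (the name `iter_minibatches` and its arguments are assumptions,
# not the original API) reconstructs the described behaviour purely to illustrate usage: yield
# transposed minibatches of the data plus the matching slices of the delayed arrays, and drop
# the incomplete last minibatch.
import numpy as np


def iter_minibatches(data, past, mbs, transpose=True):
    num_batches = data.shape[0] // mbs
    for i in range(num_batches):
        sl = slice(i * mbs, (i + 1) * mbs)
        if transpose:
            yield data[sl, :].T, [p[sl, :].T for p in past]
        else:
            yield data[sl, :], [p[sl, :] for p in past]


# Toy usage: 10 cases of 3 features, one delayed copy, minibatches of 4 (the last 2 cases are
# discarded, matching the docstring above).
_data = np.arange(30, dtype=np.float32).reshape(10, 3)
_past = [np.roll(_data, 1, axis=0)]
for _batch, _delayed in iter_minibatches(_data, _past, mbs=4):
    print(_batch.shape)  # (3, 4)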
Exemple #22
0
import sys
import numpy as np
import socket
import struct

import cudamat as cm

cuda_device = 0

cm.cuda_set_device(cuda_device)
cm.cublas_init()
cm.CUDAMatrix.init_random(1)

adinserver_host = 'localhost'
adinserver_port = 5532
julius_host = 'localhost'
julius_port = 5531
num_raw = 120
num_input = 1320
num_hid = 2048
num_output = 2004
num_context = 11  # 1320 / 120
batchsize = 32
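
# Note: num_input = num_raw * num_context (1320 = 120 * 11), i.e. each DNN input is a splice of
# num_context consecutive raw frames. The helper below is only a hypothetical illustration of
# that splicing (symmetric context with edge padding is an assumption; the truncated example
# does not show how the server actually builds its input vectors).
def splice_frames(frames, context=num_context):
    half = context // 2
    padded = np.vstack([frames[:1]] * half + [frames] + [frames[-1:]] * half)  # repeat edges
    return np.hstack([padded[i:i + len(frames)] for i in range(context)])


_raw = np.zeros((100, num_raw), dtype=np.float32)   # 100 frames of 120-dim raw features
_spliced = splice_frames(_raw)                      # shape (100, 1320) == (100, num_input)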

w_filename = [
    "dnn_sample/W_l1.npy", "dnn_sample/W_l2.npy", "dnn_sample/W_l3.npy",
    "dnn_sample/W_l4.npy", "dnn_sample/W_l5.npy", "dnn_sample/W_output.npy"
]
b_filename = [
    "dnn_sample/bias_l1.npy", "dnn_sample/bias_l2.npy",
    "dnn_sample/bias_l3.npy", "dnn_sample/bias_l4.npy",
Exemple #23
0
        W = None
    else:
        W = loadMatrix(args.matrixW).astype(np.float32)

    if args.transpose:
        X = X.T.copy()  # copy is needed to keep strides in order (i.e. in column-major order)

    if args.save is None:
        args.save = args.filename

    # choose the preferred lib:
    if args.lib == "C":
        print "Using GPU %i" % args.gpuID
        import cudamat as cm
        cm.cuda_set_device(args.gpuID)
        cm.cublas_init()

        import NMFcudamat as nmf
        print "Chose NMFcudamat library"
    elif args.lib == "P":
        import NMFskcuda as nmf
        print "Chose NMFskcuda library"
    else:
        import NMFnumpy as nmf
        print "Chose NMFnumpy library"

    try:
        if args.type == "P":
            print "Running sparse NMF with sparseness constraints %f for H and %f for W" \
                % (args.sparseH, args.sparseW)
Exemple #24
0
def LockGPU(max_retries=10, board=-1):

    # Assuming you already got GPU lock
    cm.cuda_set_device(board)
    cm.cublas_init()
    return board
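
# Minimal usage sketch for the helper above: the board id is assumed to have already been
# reserved by an external locking mechanism (as the comment in LockGPU states), and
# `import cudamat as cm` is assumed at the top of the original module (not shown here).
# board = LockGPU(board=0)      # bind this process to GPU 0 and initialise cublas
# ...                           # allocate cm.CUDAMatrix objects and run GPU code
# cm.cublas_shutdown()          # release cublas when done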
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    dist_peaks = params.getint('detection', 'dist_peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Extracting MUA activity..."], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.mua-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    electrodes_file = open(file_out_suff + '.elec-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amp-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-dist_peaks, 0)
        elif is_first:
            padding = (0, dist_peaks)
        else:
            padding = (-dist_peaks, dist_peaks)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   padding,
                                                   nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        # print "Extracting the peaks..."

        local_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
        local_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
        local_amps = [numpy.zeros(0, dtype=numpy.float32)]

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(N_e):
                    peaktimes = scipy.signal.find_peaks(
                        filter_chunk[:, i],
                        height=matched_thresholds_pos[i],
                        width=spike_width,
                        distance=dist_peaks,
                        wlen=N_t)[0]
                    local_peaktimes.append(peaktimes)
                    local_elecs.append(
                        i * numpy.ones(len(peaktimes), dtype='uint32'))
                    local_amps.append(filter_chunk[peaktimes, i])
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(N_e):
                    peaktimes = scipy.signal.find_peaks(
                        filter_chunk[:, i],
                        height=matched_thresholds_neg[i],
                        width=spike_width,
                        distance=dist_peaks,
                        wlen=N_t)[0]
                    local_peaktimes.append(peaktimes)
                    local_elecs.append(
                        i * numpy.ones(len(peaktimes), dtype='uint32'))
                    local_amps.append(filter_chunk[peaktimes, i])
        else:
            for i in range(N_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                local_peaktimes.append(peaktimes)
                local_elecs.append(i *
                                   numpy.ones(len(peaktimes), dtype='uint32'))
                local_amps.append(local_chunk[peaktimes, i])

        local_peaktimes = numpy.concatenate(local_peaktimes)
        local_elecs = numpy.concatenate(local_elecs)
        local_amps = numpy.concatenate(local_amps)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(
                all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(
                    local_peaktimes + g_offset,
                    all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_elecs = local_elecs[~is_included]
                local_amps = local_amps[~is_included]

        # print "Removing the useless borders..."
        local_borders = (dist_peaks, len_chunk - dist_peaks)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes) + g_offset
        local_elecs = numpy.compress(idx, local_elecs)
        local_amps = numpy.compress(idx, local_amps)

        spiketimes_file.write(local_peaktimes.astype(numpy.uint32).tobytes())
        electrodes_file.write(local_elecs.tobytes())
        amplitudes_file.write(local_amps.tobytes())

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    electrodes_file.flush()
    os.fsync(electrodes_file.fileno())
    electrodes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_mua(comm.size, params, erase=True)

    data_file.close()
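
# The three binary files written above are flat arrays: uint32 spike times, uint32 electrode
# ids and float32 amplitudes (matching the buffers initialised earlier in this function). A
# minimal, hypothetical reader for one rank's files, for inspection only (in the real pipeline
# io.collect_mua() is responsible for merging them), could look like this:
import numpy


def read_mua_rank(file_out_suff, rank):
    spiketimes = numpy.fromfile(file_out_suff + '.mua-%d.data' % rank, dtype=numpy.uint32)
    electrodes = numpy.fromfile(file_out_suff + '.elec-%d.data' % rank, dtype=numpy.uint32)
    amplitudes = numpy.fromfile(file_out_suff + '.amp-%d.data' % rank, dtype=numpy.float32)
    return spiketimes, electrodes, amplitudes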
Exemple #26
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)

    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    file_out = params.get('data', 'file_out')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('whitening', 'chunk_size')
    plot_path = os.path.join(params.get('data', 'data_file_noext'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='latest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds})
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)
        for i in xrange(N_e):
            peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                          thresholds[i],
                                          valley=False,
                                          mpd=dist_peaks)
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                all_times[indices, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    all_times_Ne = numpy.any(all_times, 0)
    subset = numpy.where(all_times_Ne == False)[0]
    all_silences = []

    if do_spatial_whitening:
        local_silences = numpy.take(local_chunk, subset,
                                    axis=0)[:max_silence_1]
        all_silences = gather_array(local_silences, comm, 0, 1)

    local_res = []

    if do_temporal_whitening:

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
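            # Stack up to max_silence_2 windows of N_t consecutive silent
            # samples on this electrode; their covariance will feed the
            # temporal whitening estimate.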
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res)], dtype=numpy.float32)
        local_res = numpy.array(local_res, dtype=numpy.float32)
        if len(local_res) == 0:
            local_res = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res = numpy.sum(local_res, 0)
        all_res = gather_array(local_res.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res = all_res.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res = numpy.sum(all_res, 0)
            all_res = all_res.reshape((N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            if len(all_silences) / params.rate == 0:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            spatial_whitening = get_whitening_matrix(
                all_silences.astype(numpy.double)).astype(numpy.float32)
            to_write['spatial'] = spatial_whitening
            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (len(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest')
        io.write_datasets(bfile, to_write.keys(), to_write)
        bfile.close()

    del all_silences
    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in xrange(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='latest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds})
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('data', 'chunk_size')
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    if sign_peaks == 'both':
        max_elts_elec *= 2
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in xrange(N_e):
        groups[i] = 0

    # Shuffle the chunks so that the signals are sampled from all over the recording
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')

    if comm.rank == 0:
        pbar = get_progressbar(nb_elts)

    if alignment:
        cdata = numpy.linspace(-template_shift, template_shift, 5 * N_t)
        xdata = numpy.arange(-2 * template_shift, 2 * template_shift + 1)

    for gcount, gidx in enumerate(chunks_to_load):

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.int32)
            all_extremas = numpy.zeros(0, dtype=numpy.int32)

            for i in xrange(N_e):

                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False,
                                                  mpd=dist_peaks)
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.int32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (2 * template_shift,
                                 local_shape - 2 * template_shift)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=numpy.bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if numpy.abs(numpy.max(local_chunk[peak])) > numpy.abs(
                                numpy.min(local_chunk[peak])):
                            elec = numpy.argmax(local_chunk[peak])
                            negative_peak = False
                        else:
                            elec = numpy.argmin(local_chunk[peak])
                            negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if negative_peak:
                                elts_neg[:, elt_count_neg] = local_chunk[
                                    peak - template_shift:peak +
                                    template_shift + 1, elec]
                            else:
                                elts_pos[:, elt_count_pos] = local_chunk[
                                    peak - template_shift:peak +
                                    template_shift + 1, elec]
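                            # Optional sub-sample alignment: fit a spline on a
                            # wider snippet around the peak and re-sample the
                            # waveform centered on the interpolated extremum.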
                            if alignment:
                                ydata = local_chunk[peak -
                                                    2 * template_shift:peak +
                                                    2 * template_shift + 1,
                                                    elec]
                                f = scipy.interpolate.UnivariateSpline(xdata,
                                                                       ydata,
                                                                       s=0)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            len(cdata) / 2.) / 5.
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            len(cdata) / 2.) / 5.
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)

                                if negative_peak:
                                    elts_neg[:,
                                             elt_count_neg] = f(ddata).astype(
                                                 numpy.float32)
                                else:
                                    elts_pos[:,
                                             elt_count_pos] = f(ddata).astype(
                                                 numpy.float32)

                            if negative_peak:
                                elt_count_neg += 1
                            else:
                                elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

            if comm.rank == 0:
                pbar.update(elt_count_pos + elt_count_neg)

        if comm.rank == 0:
            if (elt_count_pos + elt_count_neg <
                (gcount + 1) * max_elts_elec // len(chunks_to_load)):
                pbar.update(
                    (gcount + 1) * max_elts_elec // len(chunks_to_load))

    if comm.rank == 0:
        pbar.finish()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    if comm.rank == 0:
        # Do PCA on the collected waveforms and store the resulting basis.

        nb_waveforms = 0
        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)
        pca = PCA(output_dim, copy=False)
        res = {}
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                res_pca = pca.fit_transform(gdata_neg.astype(
                    numpy.double)).astype(numpy.float32)
                res['proj'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj'] = numpy.identity(N_t, dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                res_pca = pca.fit_transform(gdata_pos.astype(
                    numpy.double)).astype(numpy.float32)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
            else:
                res['proj_pos'] = numpy.identity(N_t, dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest')
        io.write_datasets(bfile, res.keys(), res)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)

        bfile.close()

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='latest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds})
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in xrange(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='latest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds})
                    bfile.close()
                comm.Barrier()

    data_file.close()
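The per-electrode thresholds computed and stored in the function above are robust noise estimates: the median absolute deviation (MAD) of each channel around its own median, averaged across MPI ranks. A minimal sketch of that step outside of any MPI machinery is shown below; the function name mad_thresholds and the chunk argument are illustrative only, and how the spike_thresh multiplier is applied to the stored MADs is not shown in this excerpt.

import numpy

def mad_thresholds(chunk, spike_thresh=1.0):
    # chunk: array of shape (n_samples, n_electrodes)
    medians = numpy.median(chunk, axis=0)
    mads = numpy.median(numpy.abs(chunk - medians), axis=0)
    # The detection threshold is usually a multiple of the MAD;
    # the multiplier is left here as a plain argument.
    return spike_thresh * mads

# Illustrative usage: thresholds = mad_thresholds(local_chunk, spike_thresh)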
Exemple #27
0
def main(filename, params, nb_cpu, nb_gpu, use_gpu):

    try:
        SHARED_MEMORY = True
        MPI.Win.Allocate_shared(1, 1, MPI.INFO_NULL, MPI.COMM_SELF).Free()
    except NotImplementedError:
        SHARED_MEMORY = False

    #################################################################
    sampling_rate = params.getint('data', 'sampling_rate')
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('data', 'N_t')
    N_total = params.getint('data', 'N_total')
    template_shift = params.getint('data', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = int(params.getfloat('fitting', 'chunk') * sampling_rate)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = io.get_nodes_and_edges(params)
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    space_explo = params.getfloat('fitting', 'space_explo')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params,
                                           comm,
                                           'templates',
                                           normalize=True,
                                           transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    N_e = params.getint('data', 'N_e')
    N_t = params.getint('data', 'N_t')
    template_shift = int((N_t - 1) // 2)
    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = N_e * N_t
    last_spikes = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window = numpy.arange(-template_shift, template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx],
                                   templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates)

    info_string = ''

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_tresholds_pos = io.load_data(params,
                                                 'matched-thresholds-pos')

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    thresholds = io.load_data(params, 'thresholds')

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params,
                                         comm,
                                         'overlaps',
                                         nb_cpu=nb_cpu,
                                         nb_gpu=nb_gpu,
                                         use_gpu=use_gpu)
        c_overlap = io.get_overlaps(comm,
                                    params,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]
    else:
        c_overlap = io.get_overlaps(comm,
                                    params,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_x = c_overlap.get('over_x')[:]
        over_y = c_overlap.get('over_y')[:]
        over_data = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time
        c_overs = {}
        overlaps = scipy.sparse.csr_matrix(
            (over_data, (over_x, over_y)),
            shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data

        for i in xrange(N_over):
            c_overs[i] = overlaps[i * N_over:(i + 1) * N_over]

        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        io.print_and_log([
            "Here comes the SpyKING CIRCUS %s and %d templates..." %
            (info_string, n_tm)
        ], 'default', params)
        io.purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i])
        except Exception:
            if comm.rank == 0:
                io.print_and_log([
                    "Not enough memory on GPUs: GPUs are used for projection only"
                ], 'info', params)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    borders, nb_chunks, chunk_len, last_chunk_len = io.analyze_data(
        params, chunk_size)
    nb_chunks = int(min(nb_chunks, max_chunk))

    if comm.rank == 0:
        pbar = get_progressbar(int(nb_chunks // comm.size))

    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank,
                           'wb')
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank,
                           'wb')
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank,
                          'wb')

    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    last_chunk_size = 0

    for gcount, gidx in enumerate(xrange(comm.rank, nb_chunks, comm.size)):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size+template_shift]
        if gidx == (nb_chunks - 1):
            padding = (-2 * borders, 0)
        elif gidx == 0:
            padding = (0, 2 * borders)
        else:
            padding = (-2 * borders, 2 * borders)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, local_shape = io.load_chunk(params,
                                                 gidx,
                                                 chunk_len,
                                                 chunk_size,
                                                 padding,
                                                 nodes=nodes)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False)
                local_peaktimes = numpy.concatenate(
                    (local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)
        n_t = len(local_peaktimes)
        len_chunk = local_chunk.shape[0]
        all_indices = numpy.arange(n_t)

        if full_gpu:
            #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)),
                                     copy_on_host=False)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

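            # Build sub_mat, whose columns are the flattened spatio-temporal
            # snippets around each detected peak, then project them onto the
            # templates to obtain the matching scores b.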
            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((N_e * (2 * template_shift + 1), n_t),
                                  dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate(
                        (slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk,
                                               slice_indices + idx)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_offset = gidx * chunk_size + padding[0] // N_total
            local_bounds = (temp_2_shift, local_shape - temp_2_shift)
            all_spikes = local_peaktimes + local_offset
            penalty = numpy.ones((n_tm, n_t), dtype=numpy.float32)

            # Because slicing by columns is more efficient on the GPU, we would need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = cmt.CUDAMatrix(penalty, copy_on_host=False)
                data = cmt.empty(mask.shape)
                cm_zeros = cmt.CUDAMatrix(numpy.zeros(mask.shape),
                                          copy_on_host=False)
                patch_gpu = b.shape[1] == 1
            else:
                mask = penalty
                sub_b = b[:n_tm, :]

            min_time = local_peaktimes.min()
            max_time = local_peaktimes.max()
            local_len = max_time - min_time + 1
            min_times = numpy.maximum(
                local_peaktimes - min_time - temp_2_shift, 0)
            max_times = numpy.minimum(
                local_peaktimes - min_time + temp_2_shift + 1,
                max_time - min_time)
            max_n_t = int(space_explo * (max_time - min_time + 1) //
                          (2 * temp_2_shift + 1))

            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    sub_b = b.get_row_slice(0, n_tm)
                    sub_b.mult(mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat, sub_b
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset = []
                    indices = []
                    all_times = numpy.zeros(local_len, dtype=numpy.bool)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        sub_b = b.get_row_slice(0, n_tm)
                        tmp_mat = sub_b.argmax(0)
                        inds_t, inds_temp = subset, tmp_mat.asarray()[
                            0, :][subset].astype(numpy.int32)
                        del tmp_mat
                    else:
                        inds_t, inds_temp = subset, numpy.argmax(
                            numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp = sub_b.asarray()[inds_temp,
                                                   inds_t] / n_scalar
                        best_amp2 = b.asarray()[inds_temp + n_tm,
                                                inds_t] / n_scalar
                        sub_mask = numpy.ones((sub_b.shape),
                                              dtype=numpy.float32)
                        sub_mask[inds_temp, inds_t] = 0
                        sub_mask = cmt.CUDAMatrix(sub_mask, copy_on_host=False)
                        mask.mult(sub_mask)
                        del sub_mask
                    else:
                        mask[inds_temp, inds_t] = 0
                        best_amp = sub_b[inds_temp, inds_t] / n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t] / n_scalar

                    best_amp_n = best_amp / numpy.take(norm_templates,
                                                       inds_temp)
                    best_amp2_n = best_amp2 / numpy.take(
                        norm_templates, inds_temp + n_tm)

                    all_idx = ((best_amp_n >= amp_limits[inds_temp, 0]) &
                               (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep = numpy.where(all_idx == True)[0]
                    to_reject = numpy.where(all_idx == False)[0]
                    ts = numpy.take(local_peaktimes, inds_t[to_keep])
                    good = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep      = to_keep[good]
                    #ts           = ts[good]

                    if len(ts) > 0:
                        if full_gpu:
                            tmp = cmt.CUDAMatrix(numpy.ones((len(ts), 1)),
                                                 copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)),
                                                  copy_on_host=False)
                            tmp = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp = numpy.dot(
                                numpy.ones((len(ts), 1), dtype=numpy.int32),
                                local_peaktimes.reshape((1, n_t)))
                            tmp -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):

                            idx_b = numpy.compress(condition[count, :],
                                                   all_indices)
                            ytmp = tmp[count,
                                       condition[count, :]] + temp_2_shift

                            indices = numpy.zeros((S_over, len(ytmp)),
                                                  dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            if full_gpu:
                                indices = cmt.CUDAMatrix(indices,
                                                         copy_on_host=False)
                                if patch_gpu:
                                    b_lines = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines = b.get_col_slice(
                                        idx_b[0], idx_b[-1] + 1)

                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]],
                                                      indices,
                                                      mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] +
                                                              n_tm],
                                                      indices,
                                                      mult=-best_amp2[keep])
                                b_lines.add(tmp1)
                                b_lines.add(tmp2)
                                del tmp1, tmp2
                            else:
                                tmp1 = c_overs[inds_temp[keep]].multiply(
                                    -best_amp[keep]).dot(indices)
                                tmp2 = c_overs[inds_temp[keep] +
                                               n_tm].multiply(-best_amp2[keep]
                                                              ).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            if good[count]:

                                t_spike = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep],
                                                          best_amp2_n[keep])]
                                result['templates'] += [inds_temp[keep]]

                    myslice = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx = (numpy.take(failure, myslice) >= nb_chances)
                    if full_gpu:
                        N = numpy.sum(sub_idx)
                        if N > 0:
                            cu_slice = cmt.CUDAMatrix(numpy.compress(
                                sub_idx, myslice).reshape(1, N),
                                                      copy_on_host=False)
                            mask.set_selected_columns(cu_slice, cm_zeros)
                            del cu_slice
                    else:
                        mask[:, numpy.compress(sub_idx, myslice)] = 0

                    if full_gpu:
                        del sub_b

            spikes_to_write = numpy.array(result['spiketimes'],
                                          dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'],
                                              dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'],
                                             dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if full_gpu:
                del mask, b, cm_zeros, data

        if comm.rank == 0:
            pbar.update(gcount)

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    comm.Barrier()

    if comm.rank == 0:
        pbar.finish()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)
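All of these examples bind each MPI process to a GPU with the same few cudamat calls before any data is moved to the device. Below is a stripped-down sketch of that pattern, assuming mpi4py and cudamat are importable; init_gpu is an illustrative wrapper, not part of SpyKING CIRCUS, and nb_cpu/nb_gpu are passed in as in the functions above.

from mpi4py import MPI
import cudamat as cmt

def init_gpu(comm, nb_cpu, nb_gpu):
    # Spread the MPI ranks over the available GPUs; when there is at most
    # one GPU per process, every rank uses device 0.
    if nb_gpu > nb_cpu:
        gpu_id = int(comm.rank // nb_cpu)
    else:
        gpu_id = 0
    cmt.cuda_set_device(gpu_id)
    cmt.init()
    cmt.cuda_sync_threads()
    return gpu_id

# Illustrative usage: gpu_id = init_gpu(MPI.COMM_WORLD, nb_cpu, nb_gpu)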
Exemple #28
0
def main(params, nb_cpu, nb_gpu, use_gpu):
    # Part 1: Whitening
    numpy.random.seed(420)
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    matched_thresh = params.getfloat('detection', 'matched_thresh')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params, whitening=True)
    plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_1 = int(20 * params.rate // comm.size)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    data_file.open()
    #################################################################

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Shuffle the chunks so that the signals are sampled from all over the recording
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    all_electrodes = numpy.random.permutation(N_e)

    for gidx in [all_chunks[comm.rank]]:

        #print "Node", comm.rank, "is analyzing chunk", gidx,  "/", nb_chunks, " ..."
        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   nodes=nodes)
        local_shape = len(local_chunk)

        #print "Node", comm.rank, "computes the median absolute deviations in a random chunk"
        thresholds = numpy.zeros(N_e, dtype=numpy.float32)
        for i in xrange(N_e):
            u = numpy.median(local_chunk[:, i], 0)
            thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            bfile = h5py.File(file_out_suff + '.basis.hdf5',
                              'w',
                              libver='earliest')
            io.write_datasets(bfile, ['thresholds'],
                              {'thresholds': thresholds},
                              compression=hdf5_compress)
            bfile.close()
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        #print "Extracting the peaks..."
        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
        for i in xrange(N_e):
            peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]),
                                                height=thresholds[i],
                                                width=spike_width,
                                                wlen=N_t)[0]
            local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool)
            min_times = numpy.maximum(
                local_peaktimes - local_peaktimes[0] - safety_time, 0)
            max_times = numpy.minimum(
                local_peaktimes - local_peaktimes[0] + safety_time + 1,
                diff_times)
            argmax_peak = numpy.random.permutation(
                numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            #print "Selection of the peaks with spatio-temporal masks..."
            for idx, peak in zip(argmax_peak, all_idx):
                elec = numpy.argmax(numpy.abs(local_chunk[peak]))
                indices = numpy.take(inv_nodes, edges[nodes[elec]])
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool)

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white,
                                                comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

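        # Spatial whitening: for each electrode, estimate a whitening matrix on
        # the silent samples restricted to its spatial neighbourhood, keep the
        # row corresponding to that electrode, and sum the rows across ranks.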
        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = numpy.take(inv_nodes, edges[nodes[elec]])
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0),
                                       0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(local_data).astype(
                numpy.float32)
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1,
            'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape(
                (N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            print_and_log([
                "Found %gs without spikes for whitening matrices..." %
                (numpy.mean(all_silences) / params.rate)
            ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile,
                          to_write.keys(),
                          to_write,
                          compression=hdf5_compress)
        bfile.close()

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = numpy.zeros(N_e, dtype=numpy.float32)
            for i in range(N_e):
                u = numpy.median(local_chunk[:, i], 0)
                thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u),
                                             0)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                  'r+',
                                  libver='earliest')
                bfile.pop('thresholds')
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
                bfile.close()
            comm.Barrier()

    #if comm.rank == 0:
    #if not os.path.exists(plot_path):
    #    os.makedirs(plot_path)
    #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5)
    #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True,
    #              n_elec=N_elec, save=[plot_path, 'electrodes'])

    # Part 2: Basis
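    # Collect spike waveforms from randomly chosen chunks and build a PCA basis
    # (separate bases for positive and negative peaks when required).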
    numpy.random.seed(422)

    #################################################################
    file_out = params.get('data', 'file_out')
    alignment = params.getboolean('detection', 'alignment')
    smoothing = params.getboolean('detection', 'smoothing')
    isolation = params.getboolean('detection', 'isolation')
    over_factor = params.getint('detection', 'oversampling_factor')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    nodes, edges = get_nodes_and_edges(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat(
        'detection', 'smoothing_factor') * (1. / spike_thresh)**2
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(
        params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too many cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    groups = {}
    for i in range(N_e):
        groups[i] = 0

    # Shuffle the chunks so that waveforms are sampled from all over the recording
    all_chunks = numpy.random.permutation(
        numpy.arange(nb_chunks, dtype=numpy.int32))
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    chunks_to_load = all_chunks[comm.rank::comm.size]

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')

    if alignment:
        cdata = numpy.linspace(-jitter_range, jitter_range,
                               int(over_factor * 2 * jitter_range))
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.
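        # cdata/xdata/xoff are used further down to realign each waveform: a
        # cubic spline is fitted on an enlarged window and resampled around the
        # interpolated extremum, oversampled by `over_factor`.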

    if isolation:
        yoff = numpy.concatenate((numpy.arange(0, N_t // 4),
                                  numpy.arange(3 * N_t // 4, N_t)))

    to_explore = range(comm.rank, nb_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):

        gidx = all_chunks[gidx]

        if ((elt_count_pos + elt_count_neg) < nb_elts):
            #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            #print "Extracting the peaks..."
            all_peaktimes = numpy.zeros(0, dtype=numpy.uint32)
            all_extremas = numpy.zeros(0, dtype=numpy.uint32)

            for i in range(N_e):

                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        distance=dist_peaks,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i],
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        wlen=N_t)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(
                        local_chunk[:, i]),
                                                        height=thresholds[i],
                                                        width=spike_width,
                                                        wlen=N_t)[0]
                all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes))
                all_extremas = numpy.concatenate(
                    (all_extremas,
                     i * numpy.ones(len(peaktimes), dtype=numpy.uint32)))

            #print "Removing the useless borders..."
            if alignment:
                local_borders = (template_shift_2,
                                 local_shape - template_shift_2)
            else:
                local_borders = (template_shift, local_shape - template_shift)
            idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes <
                                                         local_borders[1])
            all_peaktimes = numpy.compress(idx, all_peaktimes)
            all_extremas = numpy.compress(idx, all_extremas)

            local_peaktimes = numpy.unique(all_peaktimes)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])
                if dead_indices[0] != dead_indices[1]:
                    is_included = numpy.in1d(
                        local_peaktimes + t_offset,
                        all_dead_times[dead_indices[0]:dead_indices[1]])
                    local_peaktimes = local_peaktimes[~is_included]
                    local_peaktimes = numpy.sort(local_peaktimes)

            if len(local_peaktimes) > 0:

                diff_times = local_peaktimes[-1] - local_peaktimes[0]
                all_times = numpy.zeros((N_e, diff_times + 1),
                                        dtype=bool)
                min_times = numpy.maximum(
                    local_peaktimes - local_peaktimes[0] - safety_time, 0)
                max_times = numpy.minimum(
                    local_peaktimes - local_peaktimes[0] + safety_time + 1,
                    diff_times)

                n_times = len(local_peaktimes)
                argmax_peak = numpy.random.permutation(numpy.arange(n_times))
                all_idx = numpy.take(local_peaktimes, argmax_peak)

                #print "Selection of the peaks with spatio-temporal masks..."
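                # Peaks are visited in random order; a waveform is stored only
                # if its electrode is a local extremum for that peak, the
                # electrode has not reached its quota (`max_elts_elec`), and no
                # previously accepted peak occupies the surrounding
                # spatio-temporal window.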
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    if sign_peaks == 'negative':
                        elec = numpy.argmin(local_chunk[peak])
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        elec = numpy.argmax(local_chunk[peak])
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if local_chunk[peak] < 0:
                                negative_peak = True
                            elif local_chunk[peak] > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(
                                    local_chunk[peak])) > numpy.abs(
                                        numpy.min(local_chunk[peak])):
                                elec = numpy.argmax(local_chunk[peak])
                                negative_peak = False
                            else:
                                elec = numpy.argmin(local_chunk[peak])
                                negative_peak = True

                    indices = numpy.take(inv_nodes, edges[nodes[elec]])
                    myslice = all_times[indices,
                                        min_times[midx]:max_times[midx]]
                    is_local_extrema = elec in all_extremas[all_peaktimes ==
                                                            peak]
                    if is_local_extrema and not myslice.any():
                        upper_bounds = max_elts_elec

                        if groups[elec] < upper_bounds:

                            if not alignment:
                                sub_mat = local_chunk[peak -
                                                      template_shift:peak +
                                                      template_shift + 1, elec]

                            elif alignment:
                                ydata = local_chunk[peak -
                                                    template_shift_2:peak +
                                                    template_shift_2 + 1, elec]

                                if smoothing:
                                    factor = smoothing_factor * xdata.size
                                    f = scipy.interpolate.UnivariateSpline(
                                        xdata, ydata, s=factor, k=3)
                                else:
                                    f = scipy.interpolate.UnivariateSpline(
                                        xdata, ydata, k=3, s=0)
                                if negative_peak:
                                    rmin = (numpy.argmin(f(cdata)) -
                                            xoff) / over_factor
                                else:
                                    rmin = (numpy.argmax(f(cdata)) -
                                            xoff) / over_factor
                                ddata = numpy.linspace(rmin - template_shift,
                                                       rmin + template_shift,
                                                       N_t)
                                sub_mat = f(ddata).astype(numpy.float32)

                            if alignment:
                                if negative_peak:
                                    if numpy.min(sub_mat) >= -thresholds[elec]:
                                        to_accept = False
                                else:
                                    if numpy.max(sub_mat) <= thresholds[elec]:
                                        to_accept = False

                            if isolation:
                                to_accept = numpy.all(
                                    numpy.max(numpy.abs(sub_mat[yoff])) <=
                                    thresholds[elec])
                            else:
                                to_accept = True

                            if to_accept:
                                if negative_peak:
                                    elts_neg[:, elt_count_neg] = sub_mat
                                else:
                                    elts_pos[:, elt_count_pos] = sub_mat

                                if negative_peak:
                                    elt_count_neg += 1
                                else:
                                    elt_count_pos += 1

                        groups[elec] += 1
                        all_times[indices,
                                  min_times[midx]:max_times[midx]] = True

    sys.stderr.flush()

    if isolation:
        print_and_log([
            "Node %d has collected %d isolated waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)
    else:
        print_and_log([
            "Node %d has collected %d waveforms" %
            (comm.rank, elt_count_pos + elt_count_neg)
        ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        # Do PCA on the collected waveforms and store the basis obtained.
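        # The projection matrix `proj` maps each N_t-long waveform onto
        # `output_dim` principal components, and `rec` (its transpose) maps the
        # components back to waveform space.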

        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        if isolation:
            print_and_log([
                "Found %d isolated waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)
        else:
            print_and_log([
                "Found %d waveforms over %d requested" %
                (nb_waveforms, int(nb_elts * comm.size))
            ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        res = {}
        pca = None
        pca_pos = None
        pca_neg = None
        if sign_peaks in ['negative', 'both']:
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim),
                                             dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_neg.shape[0]))[:1000]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim),
                                                 dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(
                gdata_pos.shape[0]))[:1000]
            res['waveforms_pos'] = gdata_pos[idx, :]

        bfile = h5py.File(file_out_suff + '.basis.hdf5',
                          'r+',
                          libver='earliest')
        io.write_datasets(bfile, res.keys(), res, compression=hdf5_compress)
        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two bases with %s dimensions have been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

        bfile.close()

    comm.Barrier()

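    # With matched filtering, detection operates on the data convolved with the
    # time-reversed, normalized median waveform, so dedicated thresholds are
    # estimated on the filtered signal and stored next to the basis.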
    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx,
                                                       chunk_size,
                                                       nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk,
                                                 copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_neg,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds'],
                                      {'matched_thresholds': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                             waveform_pos,
                                                             axis=0,
                                                             mode='constant')
                thresholds = numpy.zeros(N_e, dtype=numpy.float32)
                for i in range(N_e):
                    u = numpy.median(tmp_chunk[:, i], 0)
                    thresholds[i] = numpy.median(
                        numpy.abs(tmp_chunk[:, i] - u), 0)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    bfile = h5py.File(file_out_suff + '.basis.hdf5',
                                      'r+',
                                      libver='earliest')
                    io.write_datasets(bfile, ['matched_thresholds_pos'],
                                      {'matched_thresholds_pos': thresholds},
                                      compression=hdf5_compress)
                    bfile.close()
                comm.Barrier()

    data_file.close()


def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    # file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    # spike_thresh = params.getfloat('detection', 'spike_thresh')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    # spike_width = params.getfloat('detection', 'spike_width')
    # dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances')
    if auto_nb_chances:
        nb_chances = io.load_data(params, 'nb_chances')
        max_nb_chances = params.getint('fitting', 'max_nb_chances')
        percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances')
        total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances))
        total_nb_chances = min(total_nb_chances, max_nb_chances)
        if comm.rank == 0:
            print_and_log(['nb_chances set automatically to %g' %total_nb_chances], 'debug', logger)
    else:
        total_nb_chances = params.getfloat('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    # noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
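    # Without template normalization, the squared template norms (scaled by
    # n_scalar) are precomputed and later used to turn the scalar products in
    # `b` into amplitude estimates (see the TODO notes further down).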
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            for idx in range(templates.shape[1]):
                myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                templates.data[myslice] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    neighbors = {}
    if collect_all:
        for i in range(0, n_tm):
            tmp = templates[i, :].toarray().reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % comm.size
        else:
            info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
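    # `c_overs[i]` stores the overlaps of template i with every template at the
    # different relative time lags; these overlaps are subtracted from `b`
    # whenever a spike is accepted in the fitting loop below.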
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, _ = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in range(n_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in range(n_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    if debug:
        # Open debug files.
        chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_local_time_steps_debug_file = open(
            file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        peak_scalar_products_debug_file = open(
            file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb'
        )
        comm.Barrier()
        peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
        success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb')
        comm.Barrier()
    else:
        chunk_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        iteration_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_local_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_time_steps_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_scalar_products_debug_file = None  # default assignment (for PyCharm code inspection)
        peak_solved_flags_debug_file = None  # default assignment (for PyCharm code inspection)
        template_nbs_debug_file = None  # default assignment (for PyCharm code inspection)
        success_flags_debug_file = None  # default assignment (for PyCharm code inspection)

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)
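        # Chunks are padded by 3 * template_shift samples on the sides that
        # touch a neighbouring chunk so that spikes near the borders can still
        # be fitted; spikes are only reported within `local_restriction` below.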

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)           
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.

        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

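        # Peak detection is done either on the matched-filtered signal (data
        # convolved with the stored waveform, compared to the matched
        # thresholds) or directly on the whitened signal against the
        # per-electrode thresholds.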
        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)
                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

        local_peaktimes = numpy.unique(local_peaktimes)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            # tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)
            _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)

        if nb_local_peak_times > 0:
            # print "Computing b (could full_gpu speed this up by putting all chunks on the GPU?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

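            # Extract an (N_e * N_t, nb_peaks) matrix of snippets centred on the
            # peaks and project it onto the templates: b[i, j] is the scalar
            # product between template i and the snippet around peak j.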
            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because slicing by columns is more efficient on the GPU, b would need to be transposed
            # b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()           

            failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                # data = cmt.empty(mask.shape)
                _ = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                patch_gpu = None

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            local_max = 0
            numerous_argmax = False
            nb_argmax = n_tm
            best_indices = numpy.zeros(0, dtype=numpy.int32)

            data = b[:n_tm, :]
            flatten_data = data.ravel()

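            # Greedy template matching: repeatedly pick the (template, peak)
            # pair with the largest scalar product, accept it if its normalized
            # amplitude lies within `amp_limits`, subtract the overlap
            # contribution of the fitted template from neighbouring peaks, and
            # give up on a peak after `total_nb_chances` failed attempts.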
            while numpy.mean(failure) < total_nb_chances:

                # Is there a way to update sub_b * mask at the same time?
                if full_gpu:
                    b_array = b.asarray()
                else:
                    b_array = None

                if numerous_argmax:
                    if len(best_indices) == 0:
                        best_indices = largest_indices(flatten_data, nb_argmax)
                    best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape)
                else:
                    best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape)

                peak_scalar_product = data[best_template_index, peak_index]
                best_template2_index = best_template_index + n_tm

                if templates_normalization:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index] / n_scalar
                        best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    else:
                        best_amp = b[best_template_index, peak_index] / n_scalar
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        else:
                            best_amp2 = 0.0
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                else:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        best_amp2 = b_array[best_template2_index, peak_index]
                        best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                        # TODO is `best_amp2` value correct?
                    else:
                        best_amp = b[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index]
                            best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                            # TODO is `best_amp2` value correct?
                        else:
                            best_amp2 = 0.0

                    best_amp_n = best_amp
                    best_amp2_n = best_amp2

                # Verify amplitude constraint.
                a_min, a_max = amp_limits[best_template_index, :]

                if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                    # Keep the matching.
                    peak_time_step = local_peaktimes[peak_index]

                    peak_data = (local_peaktimes - peak_time_step).astype(numpy.int32)
                    is_neighbor = numpy.where(numpy.abs(peak_data) <= temp_2_shift)[0]
                    idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                    nb_neighbors = len(is_neighbor)
                    indices = numpy.zeros((s_over, nb_neighbors), dtype=numpy.int32)
                    indices[idx_neighbor, numpy.arange(nb_neighbors)] = 1

                    if full_gpu:
                        indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                        if patch_gpu:
                            b_lines = b.get_col_slice(0, b.shape[0])
                        else:
                            b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1]+1)
                        tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                        tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                        b_lines.add(tmp1.add(tmp2))
                        del tmp1, tmp2
                    else:
                        tmp1 = c_overs[best_template_index].multiply(-best_amp)
                        if numpy.abs(best_amp2) > min_second_component:
                            tmp1 += c_overs[best_template2_index].multiply(-best_amp2)
                        b[:, is_neighbor] += tmp1.dot(indices)

                    numerous_argmax = False

                    # Add matching to the result.
                    t_spike = all_spikes[peak_index]

                    if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                        result['spiketimes'] += [t_spike]
                        result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                        result['templates'] += [best_template_index]
                    # Mark current matching as tried.
                    b[best_template_index, peak_index] = -numpy.inf
                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [True]
                else:
                    # Reject the matching.
                    numerous_argmax = True
                    # Update failure counter of the peak.
                    failure[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if failure[peak_index] >= total_nb_chances:
                        # Mark all the matching associated to the current peak as tried.
                        b[:, peak_index] = -numpy.inf
                        index = numpy.arange(n_tm) * nb_local_peak_times + peak_index
                    else:
                        # Mark current matching as tried.
                        b[best_template_index, peak_index] = -numpy.inf
                        index = best_template_index * nb_local_peak_times + peak_index

                    if numerous_argmax:
                        best_indices = best_indices[~numpy.in1d(best_indices, index)]

                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [False]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

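                # collect_all: threshold crossings that were not explained by
                # any fitted template are stored as "garbage" spikes together
                # with their best electrode.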
                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times                

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', numpy.bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tostring())

            if full_gpu:
                del b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()
        
        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if debug:
        # Close debug files.
        for field_file in [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
Exemple #30
0
	def __init__(self,layer,step_size=None,dropout=None):
		#TODO: should probably put cudamat initialization elsewhere
		#in case it is used by more than one network
		cm.cuda_set_device(0)
		cm.init()
		super(net_cuda,self).__init__(layer,step_size,dropout)
Exemple #31
0
def main(params, nb_cpu, nb_gpu, use_gpu):

    #################################################################
    #params         = detect_memory(params)
    logger = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('fitting', 'chunk_size')
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting',
                            'amp_limits').replace('(',
                                                  '').replace(')',
                                                              '').split(',')
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multiple GPUs per MPI node?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params,
                                           'templates',
                                           normalize=True,
                                           transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = N_e * N_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = N_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx],
                                   templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) *
                             len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) *
                             len(waveform_pos))
            matched_tresholds_pos = io.load_data(params,
                                                 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    if collect_all:
        neighbors = {}
        for i in xrange(n_tm):
            tmp = templates[i, :].toarray().reshape(N_e,
                                                    N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    c_overlap = io.get_overlaps(params,
                                nb_cpu=nb_cpu,
                                nb_gpu=nb_gpu,
                                use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over = int(numpy.sqrt(over_shape[0]))
    S_over = over_shape[1]
    ## If the number of overlaps differs from the number of templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(
                ['Templates have been modified, recomputing the overlaps...'],
                'default', logger)
        c_overlap = io.get_overlaps(params,
                                    erase=True,
                                    nb_cpu=nb_cpu,
                                    nb_gpu=nb_gpu,
                                    use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"],
                          'default', logger)

        sys.exit(0)

    if comm.rank == 0:
        print_and_log([
            "Here comes the SpyKING CIRCUS %s and %d templates..." %
            (info_string, n_tm)
        ], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i],
                                                  copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log([
                    "Not enough memory on GPUs: GPUs are used for projection only"
                ], 'info', logger)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank,
                           'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank,
                           'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank,
                          'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(
            file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(
            file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                           copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size+template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-temp_2_shift, 0)
        elif is_first:
            padding = (0, temp_2_shift)
        else:
            padding = (-temp_2_shift, temp_2_shift)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, t_offset = data_file.get_data(gidx,
                                                   chunk_size,
                                                   padding,
                                                   nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk,
                                                           temporal_whitening,
                                                           axis=0,
                                                           mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i],
                                                  matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate(
                        (local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i],
                                                  thresholds[i],
                                                  valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]),
                                                  thresholds[i],
                                                  valley=False)
                local_peaktimes = numpy.concatenate(
                    (local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.unique(local_peaktimes)

        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(
                all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                local_peaktimes = numpy.array(list(
                    set(local_peaktimes + g_offset).difference(
                        all_dead_times[dead_indices[0]:dead_indices[1]])),
                                              dtype=numpy.uint32) - g_offset
                local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes <
                                                       local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i],
                                                  dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        all_found_spikes[i] = numpy.array(
                            list(
                                set(all_found_spikes[i] + g_offset).difference(
                                    all_dead_times[
                                        dead_indices[0]:dead_indices[1]])),
                            dtype=numpy.uint32) - g_offset
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (
                    all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)

        if full_gpu:
            #   all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)),
                                     copy_on_host=False)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((size_window, n_t), dtype=numpy.float32)
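            # Each column of sub_mat will hold the flattened (N_e x N_t) snippet of the chunk centered on one detected peak.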

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate(
                        (slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk,
                                               slice_indices + idx)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because slicing by columns is more efficient on the GPU, we need to transpose b
            #b           = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b = b[:n_tm, :]

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(
                    numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(
                    numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True
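            # Greedy template matching: peaks are visited in decreasing order of their best scalar product;
            # for each peak the best template is accepted if its normalized amplitude lies within the allowed
            # limits, in which case its contribution is subtracted from neighboring peaks through the overlaps,
            # otherwise the failure counter of the peak is incremented.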

            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    gpu_mask = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                for peak_index in argmax_bi:

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b = b_array[:n_tm, :]

                    peak_scalar_products = np.take(sub_b, peak_index, axis=1)
                    best_template_index = np.argmax(peak_scalar_products,
                                                    axis=0)
                    best_template2_index = best_template_index + n_tm

                    if full_gpu:
                        best_amp = sub_b[best_template_index,
                                         peak_index] / n_scalar
                        best_amp2 = b_array[best_template2_index,
                                            peak_index] / n_scalar
                    else:
                        best_amp = sub_b[best_template_index,
                                         peak_index] / n_scalar
                        best_amp2 = b[best_template2_index,
                                      peak_index] / n_scalar

                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[
                        best_template2_index]

                    # Verify amplitude constraint.
                    a_min = amp_limits[best_template_index, 0]
                    a_max = amp_limits[best_template_index, 1]

                    if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                        # Keep the matching.
                        peak_time_step = local_peaktimes[peak_index]

                        data = (local_peaktimes - peak_time_step).astype(
                            np.int32)
                        is_neighbor = np.where(np.abs(data) <= temp_2_shift)[0]
                        idx_neighbor = data[is_neighbor] + temp_2_shift
                        nb_neighbors = len(is_neighbor)
                        indices = np.zeros((S_over, nb_neighbors),
                                           dtype=np.int32)
                        indices[idx_neighbor, np.arange(nb_neighbors)] = 1
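                        # Subtract the contribution of the fitted template (and its second component) from the
                        # scalar products of the neighboring peaks, using the precomputed template overlaps.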

                        if full_gpu:
                            indices = cmt.CUDAMatrix(indices,
                                                     copy_on_host=False)
                            if patch_gpu:
                                b_lines = b.get_col_slice(0, b.shape[0])
                            else:
                                b_lines = b.get_col_slice(
                                    is_neighbor[0], is_neighbor[-1] + 1)

                            tmp1 = cmt.sparse_dot(c_overs[best_template_index],
                                                  indices,
                                                  mult=-best_amp)
                            tmp2 = cmt.sparse_dot(
                                c_overs[best_template2_index],
                                indices,
                                mult=-best_amp2)
                            b_lines.add(tmp1.add(tmp2))
                            del tmp1, tmp2
                        else:
                            tmp1 = c_overs[best_template_index].multiply(
                                -best_amp)
                            tmp2 = c_overs[best_template2_index].multiply(
                                -best_amp2)
                            b[:, is_neighbor] += (tmp1 + tmp2).dot(indices)
                        # Add matching to the result.
                        t_spike = all_spikes[peak_index]
                        if (t_spike >= local_restriction[0]) and (
                                t_spike < local_restriction[1]):
                            #print "Accept spikes", t_spike, local_restriction, type(t_spike), t_spike > local_restriction[0], t_spike < local_restriction[1]
                            result['spiketimes'] += [t_spike]
                            result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                            result['templates'] += [best_template_index]
                        # Mark current matching as tried.
                        mask[best_template_index, peak_index] = 0
                    else:
                        # Reject the matching.
                        # Update failure counter of the peak.
                        failure[peak_index] += 1
                        # If the maximal number of failures is reached, mark the peak as solved (i.e. not fitted).
                        if failure[peak_index] == nb_chances:
                            mask[:, peak_index] = 0
                        else:
                            mask[best_template_index, peak_index] = 0

            spikes_to_write = numpy.array(result['spiketimes'],
                                          dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'],
                                              dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'],
                                             dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write,
                                       spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike],
                                neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes,
                                           axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(
                            matched_tresholds_neg[bestlecs],
                            matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset,
                                               dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if full_gpu:
                del gpu_mask, b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
Exemple #32
0
    singleSoftmax(xGPU, tempRow)
    xGPU.copy_to_host()
    diff = xGPU.numpy_array-r
    print num.sum(num.abs(diff))
    #testMaskedSM()

    col = cm.CUDAMatrix(reformat(num.random.rand(5,1)))
    print col.shape
    col.copy_to_host()
    print col.numpy_array
    col.reshape((1,5))
    print col.shape
    col.copy_to_host()
    print col.numpy_array
    garb = cm.CUDAMatrix(reformat(num.zeros((5,5))))
    garb.set_row_slice(2,3,col)
    garb.copy_to_host()
    print garb.numpy_array
    
if __name__ == "__main__":
    print "export LD_LIBRARY_PATH=/u/gdahl/cudaLearn/"
    print "export CUDAMATDIR=/u/gdahl/cudaLearn"
    
    devId = gpu_lock.obtain_lock_id()
    cm.cuda_set_device(devId)
    
    cm.cublas_init()
    cm.CUDAMatrix.init_random(1)
    main()
    cm.cublas_shutdown()
Exemple #33
0
import sys
import numpy as np
import socket
import struct

import cudamat as cm

cuda_device = 0

cm.cuda_set_device(cuda_device)
cm.cublas_init()
cm.CUDAMatrix.init_random(1)

adinserver_host = 'localhost'
adinserver_port = 5532

julius_host = 'localhost'
julius_port = 5531

num_raw = 120
num_input = 1320
num_hid = 2048
num_output = 2004
num_context = 11 # 1320 / 120

batchsize = 32

w_filename = ["dnn_sample/W_l1.npy", "dnn_sample/W_l2.npy", "dnn_sample/W_l3.npy", "dnn_sample/W_l4.npy", "dnn_sample/W_l5.npy", "dnn_sample/W_output.npy"]
b_filename = ["dnn_sample/bias_l1.npy", "dnn_sample/bias_l2.npy", "dnn_sample/bias_l3.npy", "dnn_sample/bias_l4.npy", "dnn_sample/bias_l5.npy", "dnn_sample/bias_output.npy"]

prior_filename = "dnn_sample/seedhmm.cluster.prior"
Exemple #34
0
import numpy, sys, util
import cudamat as cm
cm.cuda_set_device(6)
cm.cublas_init()
cm.CUDAMatrix.init_random(42)

cmMat = cm.empty((20, 128))
for batch in range(10000):
    cmMat.fill_with_randn()
    if numpy.isnan(cmMat.euclid_norm()):
        util.save('test.dat', 'a', {'a': cmMat.asarray()})
        print "nan error in batch: ", batch
        sys.stdout.flush()
        sys.exit(1)

print "Ran without a problem"
Exemple #35
0
    vals=gp.garray(vals)
    #vals=np.transpose(vals)
   

    vals=vals[:,I]
   
    vals=vals.T

    return vals


input_img_file=sys.argv[1] #input image file
output_file=sys.argv[2]


cm.cuda_set_device(0)
cm.cublas_init()


if '/' in input_img_file:
   str_arr=input_img_file.split('/')
   l=len(str_arr)
   arr=str_arr[l-1].split('.')
   name=arr[0]
else:
   arr=input_img_file.split('.')
   name=arr[0]
  
print 'Processing image ' + str(name) + '...'
