Пример #1
0
    def montage_weights(self, ckpt_dir, save_dir, sorted_indices):
        """Tile the weight patches of every ``*_W.pvp`` file into a grid image.

        Each weight file found in ``ckpt_dir`` is read with ``pv.readpvpfile``.
        Its patches are reordered with ``sorted_indices`` and byte-scaled with
        ``bytescale_patch_np``. They are then laid out row by row on a square
        grid whose cells are separated by white (255) lines.

        Args:
            ckpt_dir: Directory searched for ``*_W.pvp`` files.
            save_dir: Parent output directory; images go to ``save_dir/Weights``,
                which is created if it does not exist yet.
            sorted_indices: Index array applied to the first (patch) axis
                before tiling.

        If ``self.weight_gif`` is falsy, one PNG per weight file is written,
        named after the file with its last 4 characters (the extension)
        removed. Otherwise all grids are collected and written once to
        ``Weights/weights.gif`` at 5 fps.
        """
        weight_filenames = get_sorted_files(ckpt_dir, keyword='*_W.pvp', add_parent=True)
        save_dir = os.path.join(save_dir, 'Weights')
        # Tolerate an existing output directory (e.g. when the analysis is re-run).
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
        gif = [] if self.weight_gif else None

        for weight_filename in weight_filenames:
            data = pv.readpvpfile(weight_filename)
            # Take index 0 of the two leading axes and of the last axis,
            # leaving a 3-D (f, h, w) stack of patches.
            weights = data['values'][0, 0, :, :, :, 0]
            f, h, w = weights.shape
            weights = weights[sorted_indices, ...]

            # Smallest square grid (measured in patches) that holds all f patches.
            n_side = int(np.ceil(np.sqrt(f)))
            gridh, gridw = h * n_side, w * n_side
            grid = np.zeros([gridh, gridw])

            count = 0
            for i_h in range(0, gridh, h):
                for i_w in range(0, gridw, w):
                    if count < f:
                        grid[i_h:i_h + h, i_w:i_w + w] = bytescale_patch_np(weights[count, ...])
                        count += 1

            # White separator lines along the top/left edge of every cell.
            grid[::h, :] = 255.
            grid[:, ::w] = 255.
            frame = np.uint8(grid)

            if gif is None:
                fig_name = os.path.split(weight_filename)[1][:-4]
                imwrite(os.path.join(save_dir, fig_name + '.png'), frame)
            else:
                gif.append(frame)

        # Encode the animation once, after all frames are collected, rather than
        # rewriting the whole GIF on every loop iteration.
        if gif:
            mimsave(os.path.join(save_dir, 'weights.gif'), gif, fps=5)
Пример #2
0
 def __init__(self, imgList, gtList, inputShape, resizeMethod="crop", shuffle=True, skip=1, seed=None, getGT=True, rangeIdx=None):
     """Load input frames from a pvp file and pair them with ground-truth indices.

     Args:
         imgList: File passed to ``readpvpfile``; its ``"values"`` array is
             stored on ``self.inData`` and must be 2-D (frames first).
         gtList: Forwarded as the first argument of the superclass constructor.
         inputShape: Stored unchanged on ``self.inputShape``.
         resizeMethod, shuffle, skip, seed, getGT, rangeIdx: Forwarded
             unchanged to the superclass constructor.

     Raises:
         ValueError: if the number of entries in ``self.imgFiles`` (set up by
             the superclass) differs from the number of frames in the pvp file.
     """
     self.inputShape = inputShape
     # Read the pvp file.
     self.inData = readpvpfile(imgList, progressPeriod=100)["values"]
     # The two-name unpacking also rejects values that are not 2-D.
     numFrames, _ = self.inData.shape
     # Superclass constructor; presumably reads gtList and populates
     # self.imgFiles (used below) -- TODO confirm.
     super(pvpObj, self).__init__(gtList, resizeMethod, shuffle, skip, seed, getGT, rangeIdx)
     # Ground-truth index = integer name of each file's parent directory
     # (assumes '/'-separated paths).
     self.gtIdx = [int(fn.split('/')[-2]) for fn in self.imgFiles]
     # Explicit check rather than `assert`, so it still runs under `python -O`.
     if len(self.gtIdx) != numFrames:
         raise ValueError(
             "pvpObj: {} ground-truth entries but {} frames in {}".format(
                 len(self.gtIdx), numFrames, imgList))
Пример #3
0
    def plot_recs(self, ckpt_dir, save_dir):
        """Save input/reconstruction panes for every ``Frame*Recon_A.pvp`` file.

        For each reconstruction file in ``ckpt_dir``, the matching input file
        ``Frame{N}_A.pvp`` is read too. N is the concatenation of the digits
        in the reconstruction's file name. Every example in the batch becomes
        one pane with three parts, left to right:

        - feature 0 of the input, byte-scaled;
        - a black divider 5% of the input width wide;
        - feature 0 of the reconstruction, byte-scaled.

        If ``self.recon_gif`` is falsy, panes are written as
        ``save_dir/Recons/Frame{N}/Example{i}Input.png``. Otherwise the panes
        for each example index are collected across all files and written as
        ``save_dir/Recons/Recon_{i}.gif`` at 15 fps.

        Returns immediately if no reconstruction files are found.
        """
        rec_paths = get_sorted_files(ckpt_dir, keyword='Frame*Recon_A.pvp', add_parent=True)
        if not rec_paths:
            return
        save_dir = os.path.join(save_dir, 'Recons')
        # example index -> list of panes (one per file); filled only in GIF mode.
        gifs = {}

        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)

        for rec_path in rec_paths:
            i_input_frame = int(''.join(filter(str.isdigit, os.path.split(rec_path)[1])))
            input_path = os.path.join(ckpt_dir, 'Frame{}_A.pvp'.format(i_input_frame))
            input_batch = pv.readpvpfile(input_path)['values']
            rec_batch = pv.readpvpfile(rec_path)['values']

            # PNG mode gets one sub-directory per frame. Example file names
            # repeat across frames, so a shared directory would overwrite them.
            if self.recon_gif:
                frame_save_dir = save_dir
            else:
                frame_save_dir = os.path.join(save_dir, 'Frame{}'.format(i_input_frame))
                if not os.path.isdir(frame_save_dir):
                    os.mkdir(frame_save_dir)

            for i_example, (input_ex, rec_ex) in enumerate(zip(input_batch, rec_batch)):
                input_scaled = bytescale_patch_np(input_ex[..., 0])
                rec_scaled = bytescale_patch_np(rec_ex[..., 0])

                if np.sum(rec_scaled) == 0:
                    # Warn only when the number embedded in latest_analysis is
                    # non-zero; a name without digits no longer raises on int('').
                    digits = ''.join(c for c in self.latest_analysis if c.isdigit())
                    if digits and int(digits) != 0:
                        print('[WARNING] BATCH {} EXPLODED'.format(os.path.split(rec_path)[1]))

                divider = np.zeros([input_scaled.shape[0], int(input_scaled.shape[1] * 0.05)])
                pane = np.uint8(np.concatenate((input_scaled, divider, rec_scaled), 1))

                if self.recon_gif:
                    gifs.setdefault(i_example, []).append(pane)
                else:
                    imwrite(os.path.join(frame_save_dir, 'Example{}Input.png'.format(i_example)), pane)

        for k, frames in gifs.items():
            mimsave(os.path.join(save_dir, 'Recon_{}.gif'.format(k)), frames, fps=15)
Пример #4
0
def get_fraction_active(filename):
    """Rank features by how often they are non-zero in a pvp activity file.

    The ``"values"`` array read from ``filename`` is converted with
    ``np.array`` and unpacked as 4-D ``(n, h, w, f)``. For every feature
    (last axis), it computes the number of non-zero entries divided by
    ``n * h * w``.

    Returns:
        ``(active_sorted, feat_indices_sorted)``: two lists ordered by
        decreasing fraction. Ties go to the larger feature index first.
    """
    values = np.array(pv.readpvpfile(filename)['values'])
    n, h, w, _ = values.shape

    nonzero_counts = np.sum(values != 0.0, (0, 1, 2))
    fractions = list(nonzero_counts / (n * h * w))

    ranked = sorted(zip(fractions, range(len(fractions))), reverse=True)
    active_sorted = [frac for frac, _ in ranked]
    feat_indices_sorted = [idx for _, idx in ranked]
    return active_sorted, feat_indices_sorted
Пример #5
0
 def __init__(self,
              imgList,
              gtList,
              inputShape,
              resizeMethod="crop",
              shuffle=True,
              skip=1,
              seed=None,
              getGT=True,
              rangeIdx=None):
     """Load input frames from a pvp file and pair them with ground-truth indices.

     Args:
         imgList: File passed to ``readpvpfile``; its ``"values"`` array is
             stored on ``self.inData`` and must be 2-D (frames first).
         gtList: Forwarded as the first argument of the superclass constructor.
         inputShape: Stored unchanged on ``self.inputShape``.
         resizeMethod, shuffle, skip, seed, getGT, rangeIdx: Forwarded
             unchanged to the superclass constructor.

     Raises:
         ValueError: if the number of entries in ``self.imgFiles`` (set up by
             the superclass) differs from the number of frames in the pvp file.
     """
     self.inputShape = inputShape
     # Read the pvp file.
     self.inData = readpvpfile(imgList, progressPeriod=100)["values"]
     # The two-name unpacking also rejects values that are not 2-D.
     numFrames, _ = self.inData.shape
     # Superclass constructor; presumably reads gtList and populates
     # self.imgFiles (used below) -- TODO confirm.
     super(pvpObj, self).__init__(gtList, resizeMethod, shuffle, skip, seed,
                                  getGT, rangeIdx)
     # Ground-truth index = integer name of each file's parent directory
     # (assumes '/'-separated paths).
     self.gtIdx = [int(fn.split('/')[-2]) for fn in self.imgFiles]
     # Explicit check rather than `assert`, so it still runs under `python -O`.
     if len(self.gtIdx) != numFrames:
         raise ValueError(
             "pvpObj: {} ground-truth entries but {} frames in {}".format(
                 len(self.gtIdx), numFrames, imgList))
Пример #6
0
def load_pvp_weights(filename):
    """Read a pvp weight file and return element ``[0, 0]`` of its values as float32.

    The file's ``"values"`` array must be 6-D. Element ``[0, 0]`` is a 4-D
    block whose first axis is rotated to the end, so ``(a, b, c, d)`` becomes
    ``(b, c, d, a)``.
    """
    block = readpvpfile(filename)["values"][0, 0, :, :, :, :]
    rotated = np.transpose(block, (1, 2, 3, 0))
    return rotated.astype(np.float32)
Пример #7
0
def load_pvp_weights(filename):
    """Return element ``[0, 0]`` of a pvp weight file as float32, first axis moved last.

    The ``"values"`` array read from ``filename`` must be 6-D. The selected
    4-D block ``(a, b, c, d)`` is returned with shape ``(b, c, d, a)``.
    """
    return (readpvpfile(filename)["values"][0, 0, :, :, :, :]
            .transpose(1, 2, 3, 0)
            .astype(np.float32))
Пример #8
0
# Estimate layer: softmax output when `softmax` is set, else the S1 reconstruction.
estLayer = "est.pvp" if softmax else "GroundTruthReconS1.pvp"

gtLayer = "GroundTruthBin.pvp"
startFrame = 600


# Create the output directories if they are missing.
if not os.path.exists(plotsDir):
    os.makedirs(plotsDir)
if not os.path.exists(estDir):
    os.makedirs(estDir)

# NOTE(review): file paths are built by plain string concatenation, so
# outputDir is assumed to end with a path separator.
visEst = pv.readpvpfile(outputDir + visEstLayer)
visGt = pv.readpvpfile(outputDir + visGtLayer)

est = pv.readpvpfile(outputDir + estLayer)
gt = pv.readpvpfile(outputDir + gtLayer)

estVals = est["values"]
numFrames, gtNy, gtNx, gtNf = estVals.shape
# Densify the ground-truth values and give them the estimate's 4-D shape.
gtVals = np.array(gt["values"].todense()).reshape((numFrames, gtNy, gtNx, gtNf))

# Indices of the non-zero entries of visGt's values.
validIdx = np.nonzero(visGt["values"])

# One zero-initialised slot per frame.
outPlotVals = np.zeros(numFrames)

for f in range(numFrames):
Пример #9
0
visGtLayer = "GroundTruthDownsample.pvp"

gtLayer = "GroundTruthBin.pvp"
startFrame = 600
# Which layer holds the estimate depends on the `softmax` flag.
estLayer = "est.pvp" if softmax else "GroundTruthReconS1.pvp"

# Ensure the plot and estimate output directories exist.
if not os.path.exists(plotsDir):
    os.makedirs(plotsDir)
if not os.path.exists(estDir):
    os.makedirs(estDir)

# NOTE(review): outputDir is concatenated directly with the layer names, so
# it is assumed to end with a path separator.
visEst = pv.readpvpfile(outputDir + visEstLayer)
visGt = pv.readpvpfile(outputDir + visGtLayer)
est = pv.readpvpfile(outputDir + estLayer)
gt = pv.readpvpfile(outputDir + gtLayer)

estVals = est["values"]
(numFrames, gtNy, gtNx, gtNf) = estVals.shape
# Dense copy of the ground truth, reshaped to the estimate's 4-D shape.
gtVals = np.array(gt["values"].todense())
gtVals = gtVals.reshape((numFrames, gtNy, gtNx, gtNf))

# Positions where visGt's values are non-zero.
validIdx = np.nonzero(visGt["values"])

# Per-frame output buffer, initialised to zero.
outPlotVals = np.zeros(numFrames)
Пример #10
0
                    type=float,
                    default=0.0,
                    help='Threshold if desired for the activity values.')
parser.add_argument('--thresh_type', type=str, default='hard',
                    choices=['hard', 'soft'], help='Type of threshold.')
parser.add_argument('--n_batch', default='all', type=str)
parser.add_argument('--binary', action='store_true', default=False)
args = parser.parse_args()

# Base name of the activity file: no directory, no extension.
name = os.path.splitext(os.path.basename(args.act_file_path))[0]
if not check_if_dir_exists(args.output_dir):
    os.mkdir(args.output_dir)

act_data = pv.readpvpfile(args.act_file_path)['values']
# Optionally keep only the first n_batch samples.
if args.n_batch != 'all':
    act_data = act_data[:int(args.n_batch), ...]
act_data = bytescale(act_data)
# A threshold of 0 means "no thresholding".
if args.thresh != 0:
    act_data = threshold(act_data, args.thresh, mode=args.thresh_type)
act_data = np.uint8(act_data)

pbar = ProgressBar()
for i_sample, sample in pbar(enumerate(act_data)):
    for i_feature in range(sample.shape[-1]):
        feat_map = sample[..., i_feature]
        if args.binary:
            # Binarise this feature map in place around its mean: entries above
            # the mean become 255, all others (including ties) become 0. The
            # mean is taken once, before any entry is overwritten, so the
            # cut-off cannot drift and the result is strictly two-valued.
            cutoff = np.mean(feat_map)
            feat_map[...] = np.where(feat_map > cutoff, 255, 0)

        save_path = os.path.join(
Пример #11
0
                    help='Threshold type.')
parser.add_argument('--downsample_b', type=int, default=1,
                    help='How much to downsample the batch.')
parser.add_argument('--downsample_h', type=int, default=1,
                    help='How much to downsample the height.')
parser.add_argument('--downsample_w', type=int, default=1,
                    help='How much to downsample the width.')
args = parser.parse_args()

# Load the activities and keep every downsample_{b,h,w}-th entry along the
# first three axes; the last axis is kept whole.
acts = pv.readpvpfile(args.act_file)['values'][
    ::args.downsample_b, ::args.downsample_h, ::args.downsample_w, :]

# Byte-scale, then divide by 255.
acts = bytescale(acts) / 255.
print('min and max are {} and {}.'.format(np.amin(acts), np.amax(acts)))

# A threshold of 0 means "no thresholding".
if args.thresh != 0:
    acts = threshold(acts, args.thresh, mode=args.thresh_type)

acts = acts.flatten()
print('sparsity is {}.'.format(get_sparsity(acts)))

# True only when the text file does not exist yet.
write_header = not os.path.isfile(args.txt_file_name)
Пример #12
0
# Imports
import os, sys

lib_path = os.path.abspath("/projects/pcsri/PetaVision/SpikingOpenPV/python/")
sys.path.append(lib_path)
import pvtools as pv
import numpy as np
from matplotlib import pyplot as plt

# Not used by the code below.
width = 1
height = 1

# Maximum number of features to plot, and the side length each feature is
# reshaped to before plotting.
num_features = 128
patch_size = 64
fig_dir = "./figures"

# Load element [0, 0] of the weight file's values and rescale it to [0, 1].
dictionary = pv.readpvpfile("./V1ToInputError_W.pvp")['values'][0, 0]
dictionary -= np.min(dictionary)
value_range = np.max(dictionary)
if value_range > 0:  # a constant dictionary would otherwise become all NaN
    dictionary /= value_range

# Single-argument print(...) produces the same text under Python 2 and 3.
print("Shape: {}".format(dictionary[0].shape))
print("Values min/max: {} {}".format(dictionary.min(), dictionary.max()))

# savefig does not create missing directories.
if not os.path.isdir(fig_dir):
    os.makedirs(fig_dir)

for i in range(min(num_features, dictionary.shape[0])):
    feature = dictionary[i].reshape(patch_size, patch_size)

    plt.imshow(feature)
    plt.savefig(os.path.join(fig_dir, "feature_" + str(i) + ".png"))
    # Clear the figure so images from earlier iterations do not pile up.
    plt.clf()