Example #1
    def output_and_loss(self, h_block, t_block):
        batch, units, length = h_block.shape
        # shape : (batch * sequence_length, num_classes)
        logits_flat = seq_func(self.affine,
                               h_block,
                               reconstruct_shape=False)
        rebatch, _ = logits_flat.shape
        concat_t_block = t_block.view(rebatch)
        weights = (concat_t_block >= 1).float()
        n_correct, n_total = utils.accuracy(logits_flat,
                                            concat_t_block,
                                            ignore_index=0)

        # shape : (batch * sequence_length, num_classes)
        log_probs_flat = F.log_softmax(logits_flat,
                                       dim=-1)
        # shape : (batch * max_len, 1)
        targets_flat = t_block.view(-1, 1).long()

        if self.label_smoothing is not None and self.label_smoothing > 0.0:
            num_classes = logits_flat.size(-1)
            smoothing_value = self.label_smoothing / (num_classes - 1)
            # Fill all the correct indices with 1 - smoothing value.
            one_hot_targets = input_like(log_probs_flat,
                                         smoothing_value)
            smoothed_targets = one_hot_targets.scatter_(-1,
                                                        targets_flat,
                                                        1.0 - self.label_smoothing)
            negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
            negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1,
                                                                            keepdim=True)
        else:
            # Contribution to the negative log likelihood only comes from the exact indices
            # of the targets, as the target distributions are one-hot. Here we use torch.gather
            # to extract the indices of the num_classes dimension which contribute to the loss.
            # shape : (batch * sequence_length, 1)
            negative_log_likelihood_flat = - torch.gather(log_probs_flat,
                                                          dim=1,
                                                          index=targets_flat)

        # shape : (batch * sequence_length,)
        negative_log_likelihood = negative_log_likelihood_flat.view(rebatch)
        negative_log_likelihood = negative_log_likelihood * weights
        # scalar: mean negative log likelihood over non-padding tokens
        loss = negative_log_likelihood.sum() / (weights.sum() + 1e-13)
        stats = utils.Statistics(loss=utils.to_cpu(loss) * n_total,
                                 n_correct=utils.to_cpu(n_correct),
                                 n_words=n_total)
        return loss, stats
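The label-smoothing branch above leans on a helper, input_like, that is not shown; a reasonable reading is that it builds a tensor shaped like log_probs_flat and filled with smoothing_value. Below is a minimal, self-contained sketch of the same smoothed-target construction using only standard PyTorch ops (the function name smoothed_nll and the toy sizes are illustrative, not from the original):

import torch
import torch.nn.functional as F

def smoothed_nll(logits, targets, smoothing=0.1):
    # logits: (N, num_classes), targets: (N,) integer class indices
    num_classes = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Off-target classes share the smoothing mass; the target keeps 1 - smoothing.
    smoothing_value = smoothing / (num_classes - 1)
    smoothed_targets = torch.full_like(log_probs, smoothing_value)
    smoothed_targets.scatter_(-1, targets.view(-1, 1), 1.0 - smoothing)
    # Cross-entropy against the smoothed distribution, one value per token.
    return -(log_probs * smoothed_targets).sum(-1, keepdim=True)

# 4 tokens, 5 classes
nll = smoothed_nll(torch.randn(4, 5), torch.tensor([0, 2, 1, 4]))
print(nll.shape)  # torch.Size([4, 1])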
Example #2
    def run(self, need_onnx=False):
        utils.makedirs(self.model_dir)
        param_filename = os.path.join(self.model_dir, 'params.npz')
        params_loaded = self.is_up_to_date(param_filename)
        if params_loaded:
            chainer.serializers.load_npz(param_filename, self.model)

        chainer.config.train = False
        inputs = utils.as_list(self.model.inputs())
        if need_onnx:
            need_onnx = self.gen_onnx_model(inputs)

        self.model.to_gpu()
        gpu_inputs = utils.to_gpu(inputs)
        gpu_outputs = self.model(*gpu_inputs)
        gpu_outputs = utils.as_list(gpu_outputs)
        outputs = utils.to_cpu(gpu_outputs)
        self.inputs = inputs
        self.outputs = outputs

        if need_onnx:
            self.gen_onnx_test(inputs, outputs)

        if not params_loaded:
            chainer.serializers.save_npz(param_filename, self.model)

        return inputs, outputs
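Example #2 switches Chainer to inference mode by assigning chainer.config.train = False for the rest of the process. Chainer also provides scoped context managers that restore the previous configuration and skip building the backward graph; a small sketch of that idiom with a placeholder model (not the model from the example):

import numpy as np
import chainer
import chainer.links as L

model = L.Linear(4, 2)  # placeholder model
x = np.zeros((1, 4), dtype=np.float32)

# Inference only inside the block: train mode off, no backward graph kept.
with chainer.using_config('train', False), chainer.no_backprop_mode():
    y = model(x)
print(y.shape)  # (1, 2)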
Example #3
 def forward(self, x):
     layer_outputs = []
     yolo_outputs = []
     x, backbone_outputs = self.backbone(x)
     layer_outputs.extend(backbone_outputs)
     for module in self.cnn_list:
         x = module(x)
         layer_outputs.append(x)
     x, layer_loss = self.first_yolo(x)
     layer_outputs.append(x)
     yolo_outputs.append(x)
     route_output = layer_outputs[-4]
     x = torch.cat([route_output], 1)
     layer_outputs.append(x)
     x = self.second_yolo_conv(x)
     layer_outputs.append(x)
     x = self.upsample(x)
     layer_outputs.append(x)
     x = torch.cat([x, layer_outputs[8]], 1)
     layer_outputs.append(x)
     x = self.second_yolo_conv2(x)
     layer_outputs.append(x)
     x = self.second_yolo_conv3(x)
     layer_outputs.append(x)
     x, layer2_loss = self.second_yolo(x)
     layer_outputs.append(x)
     yolo_outputs.append(x)
     return to_cpu(torch.cat(yolo_outputs, 1))
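Note that torch.cat([route_output], 1) in Example #3 concatenates a single tensor, so it simply passes the values through; this mirrors a Darknet "route" layer that references one earlier layer, while a route with two references becomes a channel-wise concatenation. A tiny sketch with made-up tensor shapes:

import torch

a = torch.randn(1, 128, 13, 13)
b = torch.randn(1, 256, 13, 13)

single = torch.cat([a], 1)     # route with one layer: same values as a
double = torch.cat([a, b], 1)  # route with two layers: channels are stacked
print(single.shape, double.shape)  # (1, 128, 13, 13) (1, 384, 13, 13)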
Example #4
    def run_first(self, task, inputs, sample_outputs):
        self.batch_size = inputs[0].shape[0]
        onnx_filename = task.get_onnx_file()
        with open(onnx_filename, 'rb') as f:
            onnx_proto = f.read()

        logger = tensorrt.Logger()
        # logger = tensorrt.Logger(tensorrt.Logger.Severity.INFO)
        builder = tensorrt.Builder(logger)
        builder.max_batch_size = self.batch_size
        network = builder.create_network()
        parser = tensorrt.OnnxParser(network, logger)
        parser.parse(onnx_proto)
        engine = builder.build_cuda_engine(network)
        self.context = engine.create_execution_context()

        assert len(inputs) + len(sample_outputs) == engine.num_bindings
        for i, input in enumerate(inputs):
            assert self.batch_size == input.shape[0]
            assert input.shape[1:] == engine.get_binding_shape(i)
        for i, output in enumerate(sample_outputs):
            assert self.batch_size == output.shape[0]
            i += len(inputs)
            assert output.shape[1:] == engine.get_binding_shape(i)

        self.inputs = utils.to_gpu(inputs)
        self.outputs = []
        for output in sample_outputs:
            self.outputs.append(cupy.zeros_like(output))
        self.bindings = [a.data.ptr for a in self.inputs]
        self.bindings += [a.data.ptr for a in self.outputs]
        self.run_task()
        return utils.to_cpu(self.outputs)
Example #5
 def run_first(self, task, inputs, sample_outputs):
     self.model = task.model
     self.model.to_gpu()
     self.inputs = utils.to_gpu(inputs)
     gpu_outputs = self.run_task()
     gpu_outputs = utils.as_list(gpu_outputs)
     outputs = utils.to_cpu(gpu_outputs)
     return outputs
Example #6
def visualize_delta(i, var_dict, grad_dict):

	for n in [k for k in grad_dict.keys() if 'rnn' in k]:
		fig, ax = plt.subplots(1,2, figsize=[16,8])
		im = ax[0].imshow(to_cpu(par['learning_rate']*grad_dict[n]), aspect='auto')
		fig.colorbar(im, ax=ax[0])
		im = ax[1].imshow(to_cpu(var_dict[n]), aspect='auto')
		fig.colorbar(im, ax=ax[1])

		fig.suptitle(n)
		ax[0].set_title('Gradient')
		ax[1].set_title('Variable')

		plt.savefig('./savedir/{}_delta_{}_iter{:0>6}.png'.format(par['savefn'], n, i), bbox_inches='tight')
		if par['save_pdfs']:
			plt.savefig('./savedir/{}_delta_{}_iter{:0>6}.pdf'.format(par['savefn'], n, i), bbox_inches='tight')
		plt.clf()
		plt.close()
Example #7
 def forward(self, x):
     yolo_outputs = []
     layer_outputs = []
     x, darknet_outputs = self.darknet(x)
     layer_outputs.extend(darknet_outputs)
     for i, module in enumerate(self.first_path):
         x = module(x)
         layer_outputs.append(x)
     x, layer_loss = self.first_yolo(x)
     layer_outputs.append(x)
     yolo_outputs.append(x)
     print(f"Yolo dim {x.size()}")
     route_output = layer_outputs[-4]
     x = torch.cat([route_output], 1)
     print(f"First route dim {x.size()}")
     layer_outputs.append(x)
     x = self.conv1(x)
     layer_outputs.append(x)
     x = self.scale1(x)
     layer_outputs.append(x)
     x = torch.cat([x, layer_outputs[61]], 1)
     print(f"Second route dim {x.size()}")
     layer_outputs.append(x)
     for i, module in enumerate(self.second_yolo_conv_blocks):
         x = module(x)
         layer_outputs.append(x)
     x, layer2_loss = self.second_yolo(x)
     layer_outputs.append(x)
     yolo_outputs.append(x)
     print(f"Yolo 2 dim {x.size()}")
     route_output = layer_outputs[-4]
     x = torch.cat([route_output], 1)
     print(f"Third route dim {x.size()}")
     layer_outputs.append(x)
     x = self.conv2(x)
     layer_outputs.append(x)
     x = self.scale2(x)
     layer_outputs.append(x)
     x = torch.cat([x, layer_outputs[36]], 1)
     print(f"Fourth route dim {x.size()}")
     layer_outputs.append(x)
     for i, module in enumerate(self.third_yolo_conv_blocks):
         x = module(x)
         layer_outputs.append(x)
     x, layer3_loss = self.third_yolo(x)
     yolo_outputs.append(x)
     #print(f"Yolo output: {x.shape}")
     layer_outputs.append(x)
     #print(f"YOLO before: {yolo_outputs}")
     return to_cpu(torch.cat(yolo_outputs, 1))
Example #8
def activity_plots(i, model):

	V_min = to_cpu(model.v[:,0,:,:].T.min())

	fig, ax = plt.subplots(4,1, figsize=(15,11), sharex=True)
	ax[0].imshow(to_cpu(model.input_data[:,0,:].T), aspect='auto')
	ax[0].set_title('Input Data')
	ax[1].imshow(to_cpu((model.input_data[:,0,:] @ model.eff_var['W_in']).T), aspect='auto')
	ax[1].set_title('Projected Inputs')
	ax[2].imshow(to_cpu(model.z[:,0,:].T), aspect='auto')
	ax[2].set_title('Spiking')
	ax[3].imshow(to_cpu(model.v[:,0,0,:].T), aspect='auto', clim=(V_min,0.))
	ax[3].set_title('Membrane Voltage ($(V_r = {:5.3f}), {:5.3f} \\leq V_j^t \\leq 0$)'.format(par[par['spike_model']]['V_r'].min(), V_min))

	ax[0].set_ylabel('Input Neuron')
	ax[1].set_ylabel('Hidden Neuron')
	ax[2].set_ylabel('Hidden Neuron')
	ax[3].set_ylabel('Hidden Neuron')

	plt.savefig('./savedir/{}_activity_iter{:0>6}.png'.format(par['savefn'], i), bbox_inches='tight')
	if par['save_pdfs']:
		plt.savefig('./savedir/{}_activity_iter{:0>6}.pdf'.format(par['savefn'], i), bbox_inches='tight')
	plt.clf()
	plt.close()
Example #9
def clopath_update_plot(it, cl_in, cl_rnn, gr_in, gr_rnn):

	update_list = to_cpu([cl_in, cl_rnn, gr_in, gr_rnn])
	update_name = ['Clopath W_in', 'Clopath W_rnn', 'Grad W_in', 'Grad W_rnn']

	fig, ax = plt.subplots(2,2, figsize=[12,10])
	for i, j in itertools.product([0,1], [0,1]):
		im = ax[i,j].imshow(update_list[i+2*j], aspect='auto')
		ax[i,j].set_title(update_name[i+2*j])
		fig.colorbar(im, ax=ax[i,j])

	plt.savefig('./savedir/{}_clopath{:0>6}.png'.format(par['savefn'], it), bbox_inches='tight')
	if par['save_pdfs']:
		plt.savefig('./savedir/{}_clopath{:0>6}.pdf'.format(par['savefn'], it), bbox_inches='tight')
	plt.clf()
	plt.close()
Example #10
 def forward(self, x, targets=None):
     img_dim = x.shape[2]
     loss = 0
     layer_outputs, yolo_outputs = [], []
     for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
         if module_def["type"] in ["convolutional", "upsample", "maxpool"]:
             x = module(x)
         elif module_def["type"] == "route":
             x = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
         elif module_def["type"] == "shortcut":
             layer_i = int(module_def["from"])
             x = layer_outputs[-1] + layer_outputs[layer_i]
         elif module_def["type"] == "yolo":
             x, layer_loss = module[0](x, targets, img_dim)
             loss += layer_loss
             yolo_outputs.append(x)
         layer_outputs.append(x)
     yolo_outputs = to_cpu(torch.cat(yolo_outputs, 1))
     return yolo_outputs if targets is None else (loss, yolo_outputs)
Example #11
def plot_grads_and_epsilons(it, trial_info, model, h, eps_v_rec, eps_w_rec, eps_ir_rec):

	h = to_cpu(h[:,0,:])
	eps_v_rec = to_cpu(eps_v_rec)
	eps_w_rec = to_cpu(eps_w_rec)
	eps_ir_rec = to_cpu(eps_ir_rec)

	V_min = to_cpu(model.v[:,0,:,:].T.min())

	fig, ax = plt.subplots(8, 1, figsize=[16,22], sharex=True)

	ax[0].imshow(trial_info['neural_input'][:,0,:].T, aspect='auto')
	ax[0].set_title('Input Data')
	ax[0].set_ylabel('Input Neuron')

	ax[1].imshow(to_cpu(model.z[:,0,:].T), aspect='auto')
	ax[1].set_title('Spiking')
	ax[1].set_ylabel('Hidden Neuron')

	ax[2].plot(to_cpu(model.z[:,0,0]), label='Spike')
	ax[2].plot(to_cpu(model.v[:,0,0,0]) * -10, label='- Voltage x 10')
	ax[2].plot(h[:,0], label='Gradient')
	ax[2].legend()
	ax[2].set_title('Single Neuron')

	ax[3].imshow(h.T, aspect='auto', clim=(0, par['gamma_psd']))
	ax[3].set_title('Pseudogradient (${} \\leq h \\leq {}$) | Sum: $h = {:6.3f}$'.format(0., par['gamma_psd'], np.sum(h)))
	ax[3].set_ylabel('Hidden Neuron')

	ax[4].imshow(to_cpu(model.v[:,0,0,:].T), aspect='auto')
	ax[4].set_title('Membrane Voltage ($(V_r = {:5.3f}), {:5.3f} \\leq V_j^t \\leq 0$)'.format(par[par['spike_model']]['V_r'].min(), V_min))
	ax[4].set_ylabel('Hidden Neuron')

	ax[5].imshow(eps_v_rec.T, aspect='auto')
	ax[5].set_title('Voltage Eligibility (${:6.3f} \\leq e_{{v,rec}} \\leq {:6.3f}$)'.format(eps_v_rec.min(), eps_v_rec.max()))
	ax[5].set_ylabel('Hidden Neuron')

	ax[6].imshow(eps_w_rec.T, aspect='auto')
	ax[6].set_title('Adaptation Eligibility (${:6.3f} \\leq e_{{w,rec}} \\leq {:6.3f}$)'.format(eps_w_rec.min(), eps_w_rec.max()))
	ax[6].set_ylabel('Hidden Neuron')

	ax[7].imshow(eps_ir_rec.T, aspect='auto')
	ax[7].set_title('Current Eligibility (${:6.3f} \\leq e_{{ir,rec}} \\leq {:6.3f}$)'.format(eps_ir_rec.min(), eps_ir_rec.max()))
	ax[7].set_ylabel('Hidden Neuron')

	# ax[0,1].imshow(trial_info['neural_input'][:,0,:].T, aspect='auto')
	# ax[1,1].imshow(to_cpu(model.z[:,0,:].T), aspect='auto')
	# ax[2,1].imshow(h.T, aspect='auto', clim=(0, par['gamma_psd']))
	# ax[3,1].imshow(to_cpu(model.v[:,0,0,:].T), aspect='auto')

	# ax[4,1].imshow(eps_v_rec.T, aspect='auto')
	# ax[4,1].set_xlabel('Time')

	# for i in range(4):
	# 	ax[i,0].set_xticks([])

	# for i in range(5):
	# 	ax[i,1].set_xlim(200,350)

	plt.savefig('./savedir/{}_epsilon_iter{:0>6}.png'.format(par['savefn'], it), bbox_inches='tight')
	if par['save_pdfs']:
		plt.savefig('./savedir/{}_epsilon_iter{:0>6}.pdf'.format(par['savefn'], it), bbox_inches='tight')
	plt.clf()
	plt.close()
Example #12
    runs = 8

    c_all = []
    d_all = []
    v_all = []
    s_all = []

    # Run a couple batches to generate sufficient data points
    for i in range(runs):
        print('R:{:>2}'.format(i), end='\r')
        trial_info = stim.make_batch(var_delay=False)
        model.run_model(trial_info, testing=True)

        c_all.append(trial_info['sample_cat'])
        d_all.append(trial_info['sample_dir'])
        v_all.append(to_cpu(model.v))
        s_all.append(to_cpu(model.s))

    del model
    del stim

    batch_size = runs * par['batch_size']

    c = np.concatenate(c_all, axis=0)
    d = np.concatenate(d_all, axis=0)
    v = np.concatenate(v_all, axis=1)
    s = np.concatenate(s_all, axis=1)

    print('Model run complete.')
    print('Performing SVM decoding on {} trials.\n'.format(batch_size))
Example #13
    def forward(self, x, targets=None, img_dim=None):

        # Tensors for cuda support
        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
        ByteTensor = torch.cuda.ByteTensor if x.is_cuda else torch.ByteTensor

        self.img_dim = img_dim
        num_samples = x.size(0)
        grid_size = x.size(2)

        prediction = (x.view(num_samples, self.num_anchors,
                             self.num_classes + 5, grid_size,
                             grid_size).permute(0, 1, 3, 4, 2).contiguous())

        # Get outputs
        x = torch.sigmoid(prediction[..., 0])  # Center x
        y = torch.sigmoid(prediction[..., 1])  # Center y
        w = prediction[..., 2]  # Width
        h = prediction[..., 3]  # Height
        pred_conf = torch.sigmoid(prediction[..., 4])  # Conf
        pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

        # If grid size does not match current we compute new offsets
        if grid_size != self.grid_size:
            self.compute_grid_offsets(grid_size, cuda=x.is_cuda)

        # Add offset and scale with anchors
        pred_boxes = FloatTensor(prediction[..., :4].shape)
        pred_boxes[..., 0] = x.data + self.grid_x
        pred_boxes[..., 1] = y.data + self.grid_y
        pred_boxes[..., 2] = torch.exp(w.data) * self.anchor_w
        pred_boxes[..., 3] = torch.exp(h.data) * self.anchor_h

        output = torch.cat(
            (
                pred_boxes.view(num_samples, -1, 4) * self.stride,
                pred_conf.view(num_samples, -1, 1),
                pred_cls.view(num_samples, -1, self.num_classes),
            ),
            -1,
        )

        if targets is None:
            return output, 0
        else:
            iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
                pred_boxes=pred_boxes,
                pred_cls=pred_cls,
                target=targets,
                anchors=self.scaled_anchors,
                ignore_thres=self.ignore_thres,
            )

            # Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
            loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
            loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
            loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
            loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
            loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
            loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask],
                                            tconf[noobj_mask])
            loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
            loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])
            total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

            # Metrics
            cls_acc = 100 * class_mask[obj_mask].mean()
            conf_obj = pred_conf[obj_mask].mean()
            conf_noobj = pred_conf[noobj_mask].mean()
            conf50 = (pred_conf > 0.5).float()
            iou50 = (iou_scores > 0.5).float()
            iou75 = (iou_scores > 0.75).float()
            detected_mask = conf50 * class_mask * tconf
            precision = torch.sum(
                iou50 * detected_mask) / (conf50.sum() + 1e-16)
            recall50 = torch.sum(
                iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
            recall75 = torch.sum(
                iou75 * detected_mask) / (obj_mask.sum() + 1e-16)

            self.metrics = {
                "loss": to_cpu(total_loss).item(),
                "x": to_cpu(loss_x).item(),
                "y": to_cpu(loss_y).item(),
                "w": to_cpu(loss_w).item(),
                "h": to_cpu(loss_h).item(),
                "conf": to_cpu(loss_conf).item(),
                "cls": to_cpu(loss_cls).item(),
                "cls_acc": to_cpu(cls_acc).item(),
                "recall50": to_cpu(recall50).item(),
                "recall75": to_cpu(recall75).item(),
                "precision": to_cpu(precision).item(),
                "conf_obj": to_cpu(conf_obj).item(),
                "conf_noobj": to_cpu(conf_noobj).item(),
                "grid_size": grid_size,
            }

            return output, total_loss
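The decode step in Example #13 follows the usual YOLOv3 box parameterization: the x/y offsets are squashed with a sigmoid and added to the grid-cell index, width and height scale the anchor exponentially, and everything is multiplied by the stride to get back to pixels. A standalone numeric sketch of that decode (all values made up for illustration):

import torch

stride = 32.0
grid_x, grid_y = 3.0, 5.0            # cell indices
anchor_w, anchor_h = 3.625, 2.8125   # anchor sizes in grid units

tx, ty, tw, th = torch.tensor([0.2, -0.1, 0.4, 0.1])
bx = (torch.sigmoid(tx) + grid_x) * stride
by = (torch.sigmoid(ty) + grid_y) * stride
bw = torch.exp(tw) * anchor_w * stride
bh = torch.exp(th) * anchor_h * stride
print(bx.item(), by.item(), bw.item(), bh.item())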
Example #14
def run_SVM_analysis():

    print('\nLoading and running model.')
    model = Model()
    stim = Stimulus()
    runs = 8

    m_all = []
    v_all = []
    s_all = []

    for i in range(runs):
        print('R:{:>2}'.format(i), end='\r')
        trial_info = stim.make_batch(var_delay=False)
        model.run_model(trial_info)

        m_all.append(trial_info['sample_cat'])
        v_all.append(to_cpu(model.v))
        s_all.append(to_cpu(model.s))

    del model
    del stim

    batch_size = runs * par['batch_size']

    m = np.concatenate(m_all, axis=0)
    v = np.concatenate(v_all, axis=1)
    s = np.concatenate(s_all, axis=1)

    print('Performing SVM decoding on {} trials.\n'.format(batch_size))
    # Initialize linear classifier
    args = {
        'kernel': 'linear',
        'decision_function_shape': 'ovr',
        'shrinking': False,
        'tol': 1e-3
    }
    lin_clf_v = SVC(**args)
    lin_clf_s = SVC(**args)

    score_v = np.zeros([par['num_time_steps']])
    score_s = np.zeros([par['num_time_steps']])

    # Choose training and testing indices
    train_pct = 0.75
    num_train_inds = int(batch_size * train_pct)

    shuffled = np.random.permutation(batch_size)
    train_inds = shuffled[:num_train_inds]
    test_inds = shuffled[num_train_inds:]

    for t in range(end_dead_time, par['num_time_steps']):
        print('T:{:>4}'.format(t), end='\r')

        lin_clf_v.fit(v[t, train_inds, :], m[train_inds])
        lin_clf_s.fit(s[t, train_inds, :], m[train_inds])

        dec_v = lin_clf_v.predict(v[t, test_inds, :])
        dec_s = lin_clf_s.predict(s[t, test_inds, :])

        score_v[t] = np.mean(m[test_inds] == dec_v)
        score_s[t] = np.mean(m[test_inds] == dec_s)

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.plot(score_v, c=[241 / 255, 153 / 255, 1 / 255], label='Voltage')
    ax.plot(score_s, c=[58 / 255, 79 / 255, 65 / 255], label='Syn. Eff.')

    ax.axhline(0.5, c='k', ls='--')
    ax.axvline(trial_info['timings'][0, 0], c='k', ls='--')
    ax.axvline(trial_info['timings'][1, 0], c='k', ls='--')

    ax.set_title('SVM Decoding of Sample Category')
    ax.set_xlabel('Time')
    ax.set_ylabel('Decoding Accuracy')
    ax.set_yticks([0., 0.25, 0.5, 0.75, 1.])
    ax.grid()
    ax.set_xlim(0, par['num_time_steps'] - 1)

    ax.legend()
    plt.savefig('./analysis/svm_decoding.png', bbox_inches='tight')

    print('SVM decoding complete.')
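Example #14 fits one linear SVM per time step on activity shaped (time, trials, units) and scores it on held-out trials. A toy sketch of that time-resolved decoding loop on synthetic data (array sizes and names are made up):

import numpy as np
from sklearn.svm import SVC

T, trials, units = 10, 40, 6
activity = np.random.randn(T, trials, units)
labels = np.random.randint(0, 2, size=trials)

order = np.random.permutation(trials)
split = int(0.75 * trials)
train, test = order[:split], order[split:]

score = np.zeros(T)
for t in range(T):
    clf = SVC(kernel='linear')
    clf.fit(activity[t, train, :], labels[train])
    score[t] = np.mean(clf.predict(activity[t, test, :]) == labels[test])
print(score)  # decoding accuracy per time step (chance = 0.5)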
Example #15
 def export(self):
     self.results = self._export()
     if self.label != "":
         return { self.label + "_" + k : utils.to_cpu(v) for k, v in self.results.items() }
     else:
         return { k : utils.to_cpu(v) for k, v in self.results.items() }
Example #16
def output_behavior(it, trial_info, y):


	if par['task'] == 'dmswitch':
		task_info = trial_info['task']
		task_names = ['dms', 'dmc']
		num_tasks = 2
		height = 14
	else:
		task_names = [par['task']]
		num_tasks = 1
		height = 8

	match_info, timings = trial_info['match'], trial_info['timings']

	fig, ax = plt.subplots(2*num_tasks, 1, figsize=[16,height], sharex=True)

	for task in range(num_tasks):

		if par['task'] == 'dmswitch':
			task_mask = (task_info == task)
			match = np.where(np.logical_and(task_mask, match_info))[0]
			nonmatch = np.where(np.logical_and(task_mask, np.logical_not(match_info)))[0]

		else:
			match = np.where(match_info)[0]
			nonmatch = np.where(np.logical_not(match_info))[0]

		time = np.arange(par['num_time_steps'])

		y_match        = to_cpu(cp.mean(y[:,match,:], axis=1))
		y_nonmatch     = to_cpu(cp.mean(y[:,nonmatch,:], axis=1))

		y_match_err    = to_cpu(cp.std(y[:,match,:], axis=1))
		y_nonmatch_err = to_cpu(cp.std(y[:,nonmatch,:], axis=1))

		c_res = [[60/255, 21/255, 59/255, 1.0], [164/255, 14/255, 76/255, 1.0], [77/255, 126/255, 168/255, 1.0]]
		c_err = [[60/255, 21/255, 59/255, 0.5], [164/255, 14/255, 76/255, 0.5], [77/255, 126/255, 168/255, 0.5]]

		for i, (r, e) in enumerate(zip([y_match, y_nonmatch], [y_match_err, y_nonmatch_err])):
			j = 2*task + i

			err_low  = r - e
			err_high = r + e

			ax[j].fill_between(time, err_low[:,0], err_high[:,0], color=c_err[0])
			ax[j].fill_between(time, err_low[:,1], err_high[:,1], color=c_err[1])
			ax[j].fill_between(time, err_low[:,2], err_high[:,2], color=c_err[2])

			ax[j].plot(time, r[:,0], c=c_res[0], label='Fixation')
			ax[j].plot(time, r[:,1], c=c_res[1], label='Cat. 1 / Match')
			ax[j].plot(time, r[:,2], c=c_res[2], label='Cat. 2 / Non-Match')

			for t in range(timings.shape[0]):
				ax[j].axvline(timings[t,:].min(), c='k', ls='--')

	fig.suptitle('Output Neuron Behavior')
	for task in range(num_tasks):
		j = task*2
		ax[j].set_title('Task: {} | Cat. 1 / Match Trials'.format(task_names[task].upper()))
		ax[j+1].set_title('Task: {} | Cat. 2 / Non-Match Trials'.format(task_names[task].upper()))

	for j in range(2*num_tasks):
		ax[j].legend(loc="upper left")
		ax[j].set_ylabel('Mean Response')
	ax[0].set_xlim(time.min(), time.max())
	ax[2*num_tasks-1].set_xlabel('Time')

	plt.savefig('./savedir/{}_outputs_iter{:0>6}.png'.format(par['savefn'], it), bbox_inches='tight')
	if par['save_pdfs']:
		plt.savefig('./savedir/{}_outputs_iter{:0>6}.pdf'.format(par['savefn'], it), bbox_inches='tight')
	plt.clf()
	plt.close()
Example #17
    runs = 8
    z_bin = 20 // par['dt']

    c_all = []
    d_all = []
    z_all = []

    # Run a couple batches to generate sufficient data points
    for i in range(runs):
        print('R:{:>2}'.format(i), end='\r')
        trial_info = stim.make_batch(var_delay=False)
        model.run_model(trial_info, testing=True)

        c_all.append(trial_info['sample_cat'])
        d_all.append(trial_info['sample_dir'])
        z_all.append(to_cpu(model.z))

    del model
    del stim

    batch_size = runs * par['batch_size']

    c = np.concatenate(c_all, axis=0)
    d = np.concatenate(d_all, axis=0)
    z = np.concatenate(z_all, axis=1)

    print('Model run complete.')
    print('Performing ROC decoding on {} trials.\n'.format(batch_size))

    local_spikes = np.zeros([par['num_time_steps'], par['n_hidden']])
    for t in range(par['num_time_steps'] - z_bin):