예제 #1
0
    def make_bads(self, minimal_duration=None):
        '''Combine bads (epochs and channels) into the format used for artifact clean files.

        Clean sections should be at least minimal_duration long.

        The inputs are self.bc.bad_channels (minus channels listed in
        self.remove_ch) and self.be.bad_epochs. A bad is skipped when:
          - it is annotated 'clean';
          - it is marked 'incorrect' while self.corrected is set (collected
            in self.incorrect_auto_annotations);
          - its duration is 30 or less (collected in self.ignorably_short).
        Every other bad becomes a new 'artifact' Bad_epoch that gets the next
        self.epoch_id. These are then processed with cal.combine_overlaps,
        cal.stitch_artifacts, cal.stitch_stiches and cal.combine_artifacts,
        checked with cal.check_artifacts, and stored in self.bads (an empty
        list when nothing remains after filtering).

        minimal_duration    minimal clean-section length handed to the cal
                            helpers; None means self.minimal_clean_duration
        '''
        if minimal_duration is None:
            minimal_duration = self.minimal_clean_duration
        self.temp_bads = []
        artifact_channels = [
            ch for ch in self.bc.bad_channels
            if ch.channel not in self.remove_ch
        ]
        # keep an independent copy of the filtered bad channels
        self.bad_channels = artifact_channels[:]
        self.temp = artifact_channels + self.be.bad_epochs
        self.incorrect_auto_annotations = []
        self.ignorably_short = []
        for b in self.temp:
            if b.annotation == 'clean':
                continue
            if b.correct == 'incorrect' and self.corrected:
                # presumably an automatic annotation rejected by a corrector
                self.incorrect_auto_annotations.append(b)
                continue
            if b.duration <= 30:
                self.ignorably_short.append(b)
                continue
            if b.annotation != 'artifact':
                print(
                    b, 'annotation: ' + b.annotation, 'dur:', b.duration,
                    'this annotation could be unwanted, (clean annotations are ignored)'
                )
            start = bad_epoch.Boundary(b.st_sample,
                                       boundary_type='start',
                                       visible=False)
            end = bad_epoch.Boundary(b.et_sample,
                                     boundary_type='end',
                                     visible=False)
            # positional arguments follow bad_epoch.Bad_epoch's signature (not
            # visible here); the last one is presumably the block end sample
            be = bad_epoch.Bad_epoch(start, end, 'artifact', b.coder, 'blue',
                                     b.pp_id, b.exp_type, b.block_st_sample,
                                     b.bid, self.epoch_id, False, 'correct',
                                     -9, '',
                                     b.block_st_sample + self.block_duration)
            self.temp_bads.append(be)
            self.epoch_id += 1
        # relies on Bad_epoch instances being orderable
        self.temp_bads.sort()
        if not self.temp_bads:
            self.bads = []
            return

        # work on a copy so self.temp_bads keeps the unmerged artifacts
        temp_bads = copy.deepcopy(self.temp_bads)
        temp_bads = cal.combine_overlaps(temp_bads)
        self.stiches = cal.stitch_artifacts(temp_bads, minimal_duration)
        self.stiched_stiches = cal.stitch_stiches(self.stiches)
        temp_bads = cal.combine_artifacts(temp_bads, self.stiched_stiches)
        cal.check_artifacts(artifacts=temp_bads,
                            fo=self.fo,
                            default='clean',
                            minimal_duration=minimal_duration)
        self.bads = temp_bads
예제 #2
0
	def toggle_complete_channel(self):
		'''Toggle whether the current channel-mode channel is marked bad as a whole.

		Returns False without doing anything when channel mode is off;
		otherwise returns None after calling handle_plot(force_redraw=True).

		If the channel is already in self.complete_bad_channel, it is removed
		from that list, and the first of its bad channels annotated 'all' is
		deleted through delete_bad_epoch. Otherwise a Bad_channel annotated
		'all' running from sample 0 to self.data.shape[1] is appended to
		self.bad_channels, and the channel is added to self.complete_bad_channel.
		'''
		if self.channel_mode == 'off':
			return False
		channel = self.ch_names[self.channel_mode_index]
		if channel in self.complete_bad_channel:
			# remove() drops the first occurrence, exactly like pop(index(channel))
			self.complete_bad_channel.remove(channel)
			for bc in self.bad_channels:
				if bc.channel == channel and bc.annotation == 'all':
					self.delete_bad_epoch(bc.epoch_id)
					# stop at once: delete_bad_epoch may modify self.bad_channels
					break
		else:
			sboundary = bad_epoch.Boundary(0, 'start')
			eboundary = bad_epoch.Boundary(self.data.shape[1], 'end')
			self.bad_channels.append(bad_channel.Bad_channel(
				channel,
				start_boundary=sboundary,
				end_boundary=eboundary,
				pp_id=self.pp_id,
				coder=self.coder,
				exp_type=self.exp_type,
				bid=self.bid,
				block_st_sample=self.block_st_sample,
				epoch_id=self.make_bad_epoch_id(),
				offset=self.channel_mode_index * self.offset_value,
				annotation='all'))
			self.complete_bad_channel.append(channel)

		self.handle_plot(force_redraw=True)
예제 #3
0
 def xml2bad_epochs(self, load_data=True, multiplier=1, remove_clean=False):
     '''Create a list of bad epochs from the bad_epoch elements of self.artifacts.

     load_data     call self.load_xml() first (presumably it (re)loads
                   self.artifacts -- TODO confirm)
     multiplier    factor applied to st_sample/et_sample; the product is
                   truncated with int()
     remove_clean  skip bad_epoch elements annotated 'clean'

     Each sub-element's text is read. A missing sub-element is read as 'NA'.
     bad_epoch elements without st_sample or et_sample are skipped. The result
     is stored in self.bad_epochs and returned.
     '''
     self.bad_epochs = []
     if load_data:
         self.load_xml()
     element_names = ('st_sample', 'et_sample', 'block_st_sample',
                      'block_et_sample', 'pp_id', 'bid', 'annotation',
                      'color', 'exp_type', 'coder', 'epoch_ids', 'correct')
     for be_xml in self.artifacts.iter('bad_epoch'):
         # text of each sub-element, or 'NA' when the sub-element is absent
         fields = {}
         for name in element_names:
             element = be_xml.find(name)
             fields[name] = 'NA' if element is None else element.text
         if remove_clean and fields['annotation'] == 'clean':
             continue
         if fields['st_sample'] == 'NA' or fields['et_sample'] == 'NA':
             continue
         epoch_id = be_xml.attrib['id']
         # create start and end boundary
         start = bad_epoch.Boundary(
             x=int(int(fields['st_sample']) * multiplier),
             boundary_type='start',
             visible=False)
         end = bad_epoch.Boundary(
             x=int(int(fields['et_sample']) * multiplier),
             boundary_type='end',
             visible=False)
         be = bad_epoch.Bad_epoch(start_boundary=start,
                                  end_boundary=end,
                                  annotation=fields['annotation'],
                                  color=fields['color'],
                                  pp_id=fields['pp_id'],
                                  exp_type=fields['exp_type'],
                                  bid=fields['bid'],
                                  block_st_sample=fields['block_st_sample'],
                                  epoch_id=epoch_id,
                                  visible=False,
                                  epoch_ids=fields['epoch_ids'],
                                  block_et_sample=fields['block_et_sample'],
                                  coder=fields['coder'],
                                  correct=fields['correct'])
         self.bad_epochs.append(be)
     return self.bad_epochs
예제 #4
0
 def xml2bad_epochs(self):
     '''Create a list of bad epochs from the bad_epoch elements of self.cnn_result.

     Each sub-element's text is read. A missing sub-element is read as 'NA',
     so int() raises ValueError when st_sample or et_sample is absent. The
     coder of every bad epoch is self.cnn_model_name; a coder sub-element in
     the xml is not used. The result is stored in self.artifacts and returned.
     '''
     self.artifacts = []
     element_names = ('st_sample', 'et_sample', 'block_st_sample', 'pp_id',
                      'bid', 'annotation', 'color', 'exp_type', 'correct',
                      'perc_clean')
     for be_xml in self.cnn_result.iter('bad_epoch'):
         # text of each sub-element, or 'NA' when the sub-element is absent
         fields = {}
         for name in element_names:
             element = be_xml.find(name)
             fields[name] = 'NA' if element is None else element.text
         epoch_id = be_xml.attrib['id']
         # create start and end boundary
         start = bad_epoch.Boundary(x=int(fields['st_sample']),
                                    boundary_type='start',
                                    visible=False)
         end = bad_epoch.Boundary(x=int(fields['et_sample']),
                                  boundary_type='end',
                                  visible=False)
         be = bad_epoch.Bad_epoch(start_boundary=start,
                                  end_boundary=end,
                                  annotation=fields['annotation'],
                                  color=fields['color'],
                                  pp_id=fields['pp_id'],
                                  exp_type=fields['exp_type'],
                                  bid=fields['bid'],
                                  block_st_sample=fields['block_st_sample'],
                                  epoch_id=epoch_id,
                                  visible=False,
                                  correct=fields['correct'],
                                  perc_clean=fields['perc_clean'],
                                  # the cnn model is recorded as the coder
                                  coder=self.cnn_model_name)
         self.artifacts.append(be)
     return self.artifacts
예제 #5
0
	def handle_end(self):
		'''Place an end boundary at the clicked position and attach it to an epoch.

		The position is self.event.xdata, or self.data.shape[1] - 1 when xdata
		is not below self.data.shape[1]. The boundary is recorded in
		self.channel_boundaries in channel mode and in self.boundaries
		otherwise.

		If find_completion_bad_epoch(boundary_type='start') returns an epoch,
		the boundary becomes that epoch's end, and in channel mode the epoch is
		re-plotted on its channel. Otherwise a new Bad_channel (channel mode) or
		Bad_epoch that has only this end boundary is appended.
		'''
		xdata = self.event.xdata
		n_samples = self.data.shape[1]
		# NOTE(review): only xdata >= n_samples is clamped; other values are used as-is
		x = xdata if xdata < n_samples else n_samples - 1
		boundary = bad_epoch.Boundary(x, 'end')
		targets = self.channel_boundaries if self.channel_mode == 'on' else self.boundaries
		targets.append(boundary)

		be = self.find_completion_bad_epoch(boundary_type='start')
		if not be:
			# no start boundary to complete: open a new epoch holding only this end
			if self.channel_mode == 'on':
				channel = self.ch_names[self.channel_mode_index]
				self.bad_channels.append(bad_channel.Bad_channel(
					channel,
					end_boundary=boundary,
					pp_id=self.pp_id,
					exp_type=self.exp_type,
					bid=self.bid,
					block_st_sample=self.block_st_sample,
					epoch_id=self.make_bad_epoch_id(),
					offset=self.channel_mode_index * self.offset_value,
					annotation=self.default_annotation_channel))
			else:
				self.bad_epochs.append(bad_epoch.Bad_epoch(
					end_boundary=boundary,
					pp_id=self.pp_id,
					exp_type=self.exp_type,
					bid=self.bid,
					block_st_sample=self.block_st_sample,
					epoch_id=self.make_bad_epoch_id(),
					annotation=self.default_annotation))
			return

		print('combining boundaries')
		be.set_end(boundary)
		if self.channel_mode == 'on':
			row = self.ch_names.index(be.channel)
			be.plot(channel_data=self.data[row], offset=row * self.offset_value)
예제 #6
0
    def make_bad_epochs(self, minimal_duration=2000):
        '''Create a list of bad epochs from the CNN window predictions.

        Runs of consecutive windows with the same adjusted prediction form one
        Bad_epoch. In self.pred_class_adj, 0 maps to 'clean' and anything else
        to 'artifact'. Window positions come from self.w.windows['sf1000'],
        block information from self.b, and epoch ids from get_cnn_epoch_id.
        Each epoch's perc_clean is a "mean std" string computed from
        self.pred_perc[index][0] over the windows collected for it.

        minimal_duration    passed to self.combine_bad_epochs when more than
                            one bad epoch was created

        Sets self.bad_epochs. When at most one epoch was created,
        self.artifacts is set to a deep copy of it (presumably
        combine_bad_epochs sets self.artifacts otherwise -- TODO confirm).
        Finally calls calc_clean_artifact_samples and artifacts2indices.

        NOTE(review): with a single prediction no epoch is created. At the
        last window the epoch is closed with the previous annotation up to the
        block end, so a last window with a different annotation gets no epoch
        of its own. In both cases the last window's perc value is popped and
        not included in the statistics.
        '''
        self.bad_epochs = []
        ws = self.w.windows['sf1000']
        previous_annotation = ''
        perc_clean = []
        # presumably one prediction per window of ws
        for index in range(self.pred_class.shape[0]):
            annotation = 'clean' if self.pred_class_adj[
                index] == 0 else 'artifact'
            perc_clean.append(self.pred_perc[index][0])
            if previous_annotation == '':
                # first index
                st_sample = ws.start_snippets[index]
                previous_annotation = annotation

            elif previous_annotation != annotation or index == self.pred_class.shape[
                    0] - 1:
                # annotation changed (or last window): create the bad epoch for
                # the run that started at st_sample
                b = self.b
                if index == self.pred_class.shape[0] - 1:
                    # last window: the epoch ends at the end of the block
                    if b.start_marker_missing or b.end_marker_missing:
                        # presumably assumes a step of 10 samples per window -- TODO confirm
                        et_sample = self.pred_class.shape[0] * 10
                    else:
                        et_sample = b.duration_sample
                else:
                    # end at the midpoint of the previous window; may be a float,
                    # Boundary receives int(et_sample)
                    et_sample = (ws.start_snippets[index - 1] +
                                 ws.end_snippets[index - 1]) / 2
                start = bad_epoch.Boundary(x=int(st_sample),
                                           boundary_type='start',
                                           visible=False)
                end = bad_epoch.Boundary(x=int(et_sample),
                                         boundary_type='end',
                                         visible=False)
                epoch_id = get_cnn_epoch_id(increment=True)
                # the current window starts the next run; keep it out of this
                # epoch's statistics and carry it over below
                perc_clean_last = perc_clean.pop(-1)
                # perc_clean is rebound to a "mean std" string here and reset
                # to a list at the end of this branch
                perc_clean = ' '.join([
                    str(np.mean(np.array(perc_clean))),
                    str(np.std(np.array(perc_clean)))
                ])

                # create bad epoch
                be = bad_epoch.Bad_epoch(start_boundary=start,
                                         end_boundary=end,
                                         annotation=previous_annotation,
                                         pp_id=b.pp_id,
                                         exp_type=b.exp_type,
                                         bid=b.bid,
                                         block_st_sample=b.st_sample,
                                         epoch_id=epoch_id,
                                         coder=self.cnn_model_name,
                                         correct='unk',
                                         visible=False,
                                         perc_clean=perc_clean)

                self.bad_epochs.append(be)
                # the next run starts right after this epoch's end
                st_sample = et_sample + 1
                previous_annotation = annotation
                perc_clean = [perc_clean_last]
        if len(self.bad_epochs) > 1: self.combine_bad_epochs(minimal_duration)
        else: self.artifacts = copy.deepcopy(self.bad_epochs)

        self.calc_clean_artifact_samples()
        self.artifacts2indices()