예제 #1
0
class Experiment():
    def __init__(self, app, params):
        self.app = app
        self.params = params
        self.main_timer = None
        self.stream = None
        self.thread = None
        self.catch_channels_trouble = True
        self.mock_signals_buffer = None
        self.activate_trouble_catching = False
        self.main = None
        self.restart()
        pass

    def update(self):
        """
        Experiment main update action
        :return: None
        """
        # get next chunk
        # self.stream is a ChannelsSelector instance!
        chunk, other_chunk, timestamp = self.stream.get_next_chunk() if self.stream is not None else (None, None)
        if chunk is not None and self.main is not None:

            # update and collect current samples
            for i, signal in enumerate(self.signals):
                signal.update(chunk)
                # self.current_samples[i] = signal.current_sample

            # push current samples
            sample = np.vstack([np.array(signal.current_chunk) for signal in self.signals]).T.tolist()
            self.signals_outlet.push_chunk(sample)

            # record data
            if self.main.player_panel.start.isChecked():
                if self.params['bShowSubjectWindow']:
                    self.subject.figure.update_reward(self.reward.get_score())
                if self.samples_counter < self.experiment_n_samples:
                    chunk_slice = slice(self.samples_counter, self.samples_counter + chunk.shape[0])
                    self.raw_recorder[chunk_slice] = chunk[:, :self.n_channels]
                    self.raw_recorder_other[chunk_slice] = other_chunk
                    self.timestamp_recorder[chunk_slice] = timestamp
                    # for s, sample in enumerate(self.current_samples):
                    self.signals_recorder[chunk_slice] = sample
                    self.samples_counter += chunk.shape[0]

                    # catch channels trouble

                    if self.activate_trouble_catching:
                        if self.samples_counter > self.seconds:
                            self.seconds += 2 * self.freq
                            raw_std_new = np.std(self.raw_recorder[int(self.samples_counter - self.freq):
                                                                   self.samples_counter], 0)
                            if self.raw_std is None:
                                self.raw_std = raw_std_new
                            else:
                                if self.catch_channels_trouble and any(raw_std_new > 7 * self.raw_std):
                                    w = ChannelTroubleWarning(parent=self.main)
                                    w.pause_clicked.connect(self.handle_channels_trouble_pause)
                                    w.closed.connect(
                                        lambda: self.enable_trouble_catching(w)
                                    )
                                    w.show()
                                    self.catch_channels_trouble = False
                                self.raw_std = 0.5 * raw_std_new + 0.5 * self.raw_std

            # redraw signals and raw data
            self.main.redraw_signals(sample, chunk, self.samples_counter, self.current_protocol_n_samples)
            if self.params['bPlotSourceSpace']:
                self.source_space_window.update_protocol_state(chunk)

            # redraw protocols
            is_half_time = self.samples_counter >= self.current_protocol_n_samples // 2
            current_protocol = self.protocols_sequence[self.current_protocol_index]
            if current_protocol.mock_previous > 0:
                samples = [signal.current_chunk[-1] for signal in current_protocol.mock]
            elif current_protocol.mock_samples_file_path is not None:
                samples = self.mock_signals_buffer[self.samples_counter % self.mock_signals_buffer.shape[0]]
            else:
                samples = sample[-1]

            # self.reward.update(samples[self.reward.signal_ind], chunk.shape[0])
            if (self.main.player_panel.start.isChecked() and
                    self.samples_counter - chunk.shape[0] < self.experiment_n_samples):
                self.reward_recorder[
                self.samples_counter - chunk.shape[0]:self.samples_counter] = self.reward.get_score()

            if self.main.player_panel.start.isChecked():
                # subject update
                if self.params['bShowSubjectWindow']:
                    mark = self.subject.update_protocol_state(samples, self.reward, chunk_size=chunk.shape[0],
                                                              is_half_time=is_half_time)
                else:
                    mark = None
                self.mark_recorder[self.samples_counter - chunk.shape[0]:self.samples_counter] = 0
                self.mark_recorder[self.samples_counter - 1] = int(mark or 0)

            # change protocol if current_protocol_n_samples has been reached
            if self.samples_counter >= self.current_protocol_n_samples and not self.test_mode:
                self.next_protocol()

    def enable_trouble_catching(self, widget):
        self.catch_channels_trouble = not widget.ignore_flag

    def start_test_protocol(self, protocol):
        print('Experiment: test')
        if not self.main_timer.isActive():
            self.main_timer.start(1000 * 1. / self.freq)
        self.samples_counter = 0
        self.main.signals_buffer *= 0
        self.test_mode = True

        self.subject.change_protocol(protocol)

    def close_test_protocol(self):
        if self.main_timer.isActive():
            self.main_timer.stop()
        self.samples_counter = 0
        self.main.signals_buffer *= 0
        self.test_mode = False

    def handle_channels_trouble_pause(self):
        print('pause clicked')
        if self.main.player_panel.start.isChecked():
            self.main.player_panel.start.click()

    def handle_channels_trouble_continue(self, pause_enabled):
        print('continue clicked')
        if not pause_enabled and not self.main.player_panel.start.isChecked():
            self.main.player_panel.start.click()

    def next_protocol(self):
        """
        Change protocol.

        Closes the protocol that has just finished and saves its recordings to
        ``experiment_data.h5``, resets the per-protocol counters and then either switches
        to the next protocol of ``self.protocols_sequence`` (preparing mock data and the
        reward for it) or, when the sequence is exhausted, marks the experiment as
        finished and closes the subject / source space windows.

        :return: None
        """
        finished_protocol = self.protocols_sequence[self.current_protocol_index]
        data_file = self.dir_name + 'experiment_data.h5'
        protocol_number_str = 'protocol' + str(self.current_protocol_index + 1)

        # descale signals:
        signals_recordings = np.array([signal.descale_recording(data)
                                       for signal, data in
                                       zip(self.signals, self.signals_recorder[:self.samples_counter].T)]).T

        # close previous protocol
        finished_protocol.close_protocol(
            raw=self.raw_recorder[:self.samples_counter],
            signals=signals_recordings,
            protocols=self.protocols,
            protocols_seq=[p.name for p in self.protocols_sequence[:self.current_protocol_index + 1]],
            raw_file=data_file,
            marks=self.mark_recorder[:self.samples_counter])

        # save raw and signals samples of the finished protocol
        save_signals(data_file, self.signals, protocol_number_str,
                     raw_data=self.raw_recorder[:self.samples_counter],
                     timestamp_data=self.timestamp_recorder[:self.samples_counter],
                     raw_other_data=self.raw_recorder_other[:self.samples_counter],
                     signals_data=signals_recordings,
                     reward_data=self.reward_recorder[:self.samples_counter],
                     protocol_name=finished_protocol.name,
                     mock_previous=finished_protocol.mock_previous,
                     mark_data=self.mark_recorder[:self.samples_counter])

        # reset samples counter
        previous_counter = self.samples_counter
        self.samples_counter = 0
        if finished_protocol.update_statistics_in_the_end:
            self.main.time_counter1 = 0
            self.main.signals_viewer.reset_buffer()
        self.seconds = self.freq

        # list of real fb protocols (number in protocol sequence)
        if isinstance(finished_protocol, FeedbackProtocol):
            if finished_protocol.mock_previous == 0:
                self.real_fb_number_list += [self.current_protocol_index + 1]
        elif finished_protocol.as_mock:
            self.real_fb_number_list += [self.current_protocol_index + 1]

        if self.current_protocol_index < len(self.protocols_sequence) - 1:

            # update current protocol index and n_samples
            self.current_protocol_index += 1
            current_protocol = self.protocols_sequence[self.current_protocol_index]
            self.current_protocol_n_samples = self.freq * (
                    current_protocol.duration +
                    np.random.uniform(0, current_protocol.random_over_time))

            # prepare mock from raw if necessary
            if current_protocol.mock_previous:
                random_previous_fb = None
                if len(self.real_fb_number_list) > 0:
                    random_previous_fb = self.real_fb_number_list[
                        np.random.randint(0, len(self.real_fb_number_list))]
                if current_protocol.shuffle_mock_previous:
                    current_protocol.mock_previous = random_previous_fb
                print('MOCK from protocol # {}'.format(current_protocol.mock_previous))
                if current_protocol.mock_previous == self.current_protocol_index:
                    # the mocked protocol is the one just finished: its data is still in the recorders
                    mock_raw = self.raw_recorder[:previous_counter]
                    mock_signals = self.signals_recorder[:previous_counter]
                else:
                    mock_raw = load_h5py(data_file,
                                         'protocol{}/raw_data'.format(current_protocol.mock_previous))
                    mock_signals = load_h5py(data_file,
                                             'protocol{}/signals_data'.format(current_protocol.mock_previous))

                current_protocol.prepare_raw_mock_if_necessary(mock_raw, random_previous_fb, mock_signals)

            # change protocol widget
            self.subject.change_protocol(current_protocol)
            if current_protocol.mock_samples_file_path is not None:
                self.mock_signals_buffer = load_h5py_protocol_signals(
                    current_protocol.mock_samples_file_path,
                    current_protocol.mock_samples_protocol)
            self.main.status.update()

            # reward settings of the new protocol
            self.reward.threshold = current_protocol.reward_threshold
            reward_signal_id = current_protocol.reward_signal_id
            self.reward.signal = self.signals[reward_signal_id]  # TODO: REward for MOCK
            self.reward.set_enabled(isinstance(current_protocol, FeedbackProtocol))

        else:
            # status
            self.main.status.finish()
            # action in the end of protocols sequence
            self.current_protocol_n_samples = np.inf
            self.is_finished = True
            if self.params['bShowSubjectWindow']:
                self.subject.close()
            if self.params['bPlotSourceSpace']:
                self.source_space_window.close()

    def restart(self):
        """
        (Re)start the experiment from the first protocol.

        Creates a new timestamped directory under ``results/``, stops the previous timer,
        stream and streaming thread if present, and then rebuilds everything from
        ``self.params``: input stream, signals, protocols, protocols sequence, reward,
        data recorders and the main window. The main timer is connected to ``update``
        and started.

        :return: None
        :raises TypeError: if a protocol has an undefined ``sFb_type``
        """
        timestamp_str = datetime.strftime(datetime.now(), '%m-%d_%H-%M-%S')
        self.dir_name = 'results/{}_{}/'.format(self.params['sExperimentName'], timestamp_str)
        os.makedirs(self.dir_name)
        data_file = self.dir_name + 'experiment_data.h5'

        wait_bar = WaitMessage(WAIT_BAR_MESSAGES['EXPERIMENT_START']).show_and_return()

        # shut down whatever is left from the previous run
        self.test_mode = False
        if self.main_timer is not None:
            self.main_timer.stop()
        if self.stream is not None:
            self.stream.disconnect()
        if self.thread is not None:
            self.thread.terminate()

        # timer
        self.main_timer = QtCore.QTimer(self.app)

        self.is_finished = False

        # current protocol index
        self.current_protocol_index = 0

        # samples counter for protocol sequence
        self.samples_counter = 0

        # run file lsl stream in a thread
        self.thread = None
        if self.params['sInletType'] == 'lsl_from_file':
            self.restart_lsl_from_file()

        # run simulated eeg lsl stream in a thread
        elif self.params['sInletType'] == 'lsl_generator':
            self.thread = stream_generator_in_a_thread(self.params['sStreamName'])

        # use FTB inlet
        aux_streams = None
        if self.params['sInletType'] == 'ftbuffer':
            hostname, port = self.params['sFTHostnamePort'].split(':')
            port = int(port)
            stream = FieldTripBufferInlet(hostname, port)

        # use LSL inlet: the first name is the main stream, the rest are auxiliary ones
        else:
            stream_names = re.split(r"[,; ]+", self.params['sStreamName'])
            streams = [LSLInlet(name=name) for name in stream_names]
            stream = streams[0]
            aux_streams = streams[1:] if len(streams) > 1 else None

        # setup events stream by name
        events_stream_name = self.params['sEventsStreamName']
        events_stream = LSLInlet(events_stream_name) if events_stream_name else None

        # setup main stream
        self.stream = ChannelsSelector(stream, exclude=self.params['sReference'],
                                       subtractive_channel=self.params['sReferenceSub'],
                                       dc=self.params['bDC'], events_inlet=events_stream, aux_inlets=aux_streams,
                                       prefilter_band=self.params['sPrefilterBand'])
        self.stream.save_info(self.dir_name + 'stream_info.xml')
        save_channels_and_fs(data_file, self.stream.get_channels_labels(),
                             self.stream.get_frequency())

        save_xml_str_to_hdf5_dataset(data_file, self.stream.info_as_xml(), 'stream_info.xml')
        self.freq = self.stream.get_frequency()
        self.n_channels = self.stream.get_n_channels()
        self.n_channels_other = self.stream.get_n_channels_other()
        channels_labels = self.stream.get_channels_labels()
        montage = Montage(channels_labels)
        print(montage)
        # state of the channels trouble detection in update()
        self.seconds = 2 * self.freq
        self.raw_std = None

        # signals
        self.signals = [DerivedSignal.from_params(ind, self.freq, self.n_channels, channels_labels, signal)
                        for ind, signal in enumerate(self.params['vSignals']['DerivedSignal']) if
                        not signal['bBCIMode']]

        # composite signals
        self.composite_signals = [CompositeSignal(list(self.signals),
                                                  signal['sExpression'],
                                                  signal['sSignalName'],
                                                  ind + len(self.signals), self.freq)
                                  for ind, signal in enumerate(self.params['vSignals']['CompositeSignal'])]

        # bci signals
        self.bci_signals = [BCISignal(self.freq, channels_labels, signal['sSignalName'], ind)
                            for ind, signal in enumerate(self.params['vSignals']['DerivedSignal']) if
                            signal['bBCIMode']]

        self.signals += self.composite_signals
        self.signals += self.bci_signals

        # signals outlet
        signal_names = [signal.name for signal in self.signals]
        self.signals_outlet = SignalsOutlet(signal_names, fs=self.freq)

        # protocols
        self.protocols = []

        for protocol in self.params['vProtocols']:
            # some general protocol arguments
            source_signal_id = None if protocol['fbSource'] == 'All' else signal_names.index(protocol['fbSource'])
            reward_signal_id = signal_names.index(protocol['sRewardSignal']) if protocol['sRewardSignal'] != '' else 0
            mock_path = (protocol['sMockSignalFilePath'] if protocol['sMockSignalFilePath'] != '' else None,
                         protocol['sMockSignalFileDataset'])
            m_signal = protocol['sMSignal']
            m_signal_index = None if m_signal not in signal_names else signal_names.index(m_signal)

            # general protocol arguments dictionary
            kwargs = dict(
                source_signal_id=source_signal_id,
                name=protocol['sProtocolName'],
                duration=protocol['fDuration'],
                random_over_time=protocol['fRandomOverTime'],
                update_statistics_in_the_end=bool(protocol['bUpdateStatistics']),
                stats_type=protocol['sStatisticsType'],
                mock_samples_path=mock_path,
                show_reward=bool(protocol['bShowReward']),
                reward_signal_id=reward_signal_id,
                reward_threshold=protocol['bRewardThreshold'],
                ssd_in_the_end=protocol['bSSDInTheEnd'],
                timer=self.main_timer,
                freq=self.freq,
                mock_previous=int(protocol['iMockPrevious']),
                drop_outliers=int(protocol['iDropOutliers']),
                experiment=self,
                pause_after=bool(protocol['bPauseAfter']),
                beep_after=bool(protocol['bBeepAfter']),
                reverse_mock_previous=bool(protocol['bReverseMockPrevious']),
                m_signal_index=m_signal_index,
                shuffle_mock_previous=bool(protocol['bRandomMockPrevious']),
                as_mock=bool(protocol['bMockSource']),
                auto_bci_fit=bool(protocol['bAutoBCIFit']),
                montage=montage
            )

            # type specific arguments
            if protocol['sFb_type'] == 'Baseline':
                self.protocols.append(
                    BaselineProtocol(
                        self.signals,
                        text=protocol['cString'] if protocol['cString'] != '' else 'Relax',
                        half_time_text=protocol['cString2'] if bool(protocol['bUseExtraMessage']) else None,
                        **kwargs
                    ))
            elif protocol['sFb_type'] in ['Feedback', 'CircleFeedback']:
                self.protocols.append(
                    FeedbackProtocol(
                        self.signals,
                        circle_border=protocol['iRandomBound'],
                        m_threshold=protocol['fMSignalThreshold'],
                        **kwargs))
            elif protocol['sFb_type'] == 'ThresholdBlink':
                self.protocols.append(
                    ThresholdBlinkFeedbackProtocol(
                        self.signals,
                        threshold=protocol['fBlinkThreshold'],
                        time_ms=protocol['fBlinkDurationMs'],
                        **kwargs))
            elif protocol['sFb_type'] == 'Video':
                self.protocols.append(
                    VideoProtocol(
                        self.signals,
                        video_path=protocol['sVideoPath'],
                        **kwargs))
            else:
                raise TypeError('Undefined protocol type \"{}\"'.format(protocol['sFb_type']))

        # protocols sequence: every entry is either a protocol name or a group name
        names = [protocol.name for protocol in self.protocols]
        group_names = [p['sName'] for p in self.params['vPGroups']['PGroup']]
        print(group_names)
        self.protocols_sequence = []
        for name in self.params['vPSequence']:
            if name in names:
                self.protocols_sequence.append(self.protocols[names.index(name)])
            if name in group_names:
                group = self.params['vPGroups']['PGroup'][group_names.index(name)]
                list_names = group['sList'].split(' ')
                subgroup = []
                if len(list_names) == 1:
                    subgroup.append([group['sList']] * int(group['sNumberList']))
                else:
                    counts = [int(n) for n in group['sNumberList'].split(' ')]
                    for s_name, s_n in zip(list_names, counts):
                        subgroup.append([s_name] * s_n)
                if group['bShuffle']:
                    subgroup = np.concatenate(subgroup)
                    subgroup = list(subgroup[np.random.permutation(len(subgroup))])
                else:
                    # interleave the per-protocol lists, skipping the padding of shorter ones
                    subgroup = [k for k in chain(*zip_longest(*subgroup)) if k is not None]
                print(subgroup)
                for subname in subgroup:
                    self.protocols_sequence.append(self.protocols[names.index(subname)])
                    if len(group['sSplitBy']):
                        self.protocols_sequence.append(self.protocols[names.index(group['sSplitBy'])])

        # reward
        # NOTE(review): initial reward settings come from self.protocols[0] while it is enabled
        # according to self.protocols_sequence[0] - confirm that this is intended
        from pynfb.reward import Reward
        self.reward = Reward(self.protocols[0].reward_signal_id,
                             threshold=self.protocols[0].reward_threshold,
                             rate_of_increase=self.params['fRewardPeriodS'],
                             fs=self.freq)

        self.reward.set_enabled(isinstance(self.protocols_sequence[0], FeedbackProtocol))

        # timer (interval must be an integer number of milliseconds)
        self.main_timer.timeout.connect(self.update)
        self.main_timer.start(int(1000 * 1. / self.freq))

        # current protocol number of samples ('frequency' * 'protocol duration')
        first_protocol = self.protocols_sequence[self.current_protocol_index]
        self.current_protocol_n_samples = self.freq * (
                first_protocol.duration + np.random.uniform(0, first_protocol.random_over_time))

        # experiment number of samples
        max_protocol_n_samples = int(
            max(self.freq * (p.duration + p.random_over_time) for p in self.protocols_sequence))

        # data recorders: NaN-filled buffers 10% longer than the longest protocol
        self.experiment_n_samples = max_protocol_n_samples
        self.samples_counter = 0
        recorder_len = max_protocol_n_samples * 110 // 100
        self.raw_recorder = np.full((recorder_len, self.n_channels), np.nan)
        self.timestamp_recorder = np.full(recorder_len, np.nan)
        self.raw_recorder_other = np.full((recorder_len, self.n_channels_other), np.nan)
        self.signals_recorder = np.full((recorder_len, len(self.signals)), np.nan)
        self.reward_recorder = np.full(recorder_len, np.nan)
        self.mark_recorder = np.full(recorder_len, np.nan)

        # save init signals
        save_signals(data_file, self.signals,
                     group_name='protocol0')

        # save settings
        params_to_xml_file(self.params, self.dir_name + 'settings.xml')
        save_xml_str_to_hdf5_dataset(data_file, params_to_xml(self.params), 'settings.xml')

        # windows
        self.main = MainWindow(signals=self.signals,
                               protocols=self.protocols_sequence,
                               parent=None,
                               experiment=self,
                               current_protocol=self.protocols_sequence[self.current_protocol_index],
                               n_signals=len(self.signals),
                               max_protocol_n_samples=max_protocol_n_samples,
                               freq=self.freq,
                               n_channels=self.n_channels,
                               plot_raw_flag=self.params['bPlotRaw'],
                               plot_signals_flag=self.params['bPlotSignals'],
                               plot_source_space_flag=self.params['bPlotSourceSpace'],
                               show_subject_window=self.params['bShowSubjectWindow'],
                               channels_labels=channels_labels,
                               photo_rect=self.params['bShowPhotoRectangle'])
        self.subject = self.main.subject_window
        if self.params['bPlotSourceSpace']:
            self.source_space_window = self.main.source_space_window

        if self.params['sInletType'] == 'lsl_from_file':
            self.main.player_panel.start_clicked.connect(self.restart_lsl_from_file)

        # create real fb list
        self.real_fb_number_list = []

        wait_bar.close()

    def restart_lsl_from_file(self):
        """
        Restart the thread that streams a raw data file.

        Terminates the current thread if there is one and stores in ``self.thread`` the
        result of ``stream_file_in_a_thread`` called with the file path, reference and
        stream name taken from ``self.params``.

        :return: None
        """
        previous_thread = self.thread
        if previous_thread is not None:
            previous_thread.terminate()

        settings = self.params
        self.thread = stream_file_in_a_thread(
            settings['sRawDataFilePath'],
            settings['sReference'],
            settings['sStreamName'])

    def destroy(self):
        if self.thread is not None:
            self.thread.terminate()
        self.main_timer.stop()
        del self.stream
        self.stream = None
예제 #2
0
    def restart(self):

        timestamp_str = datetime.strftime(datetime.now(), '%m-%d_%H-%M-%S')
        self.dir_name = 'results/{}_{}/'.format(self.params['sExperimentName'], timestamp_str)
        os.makedirs(self.dir_name)

        wait_bar = WaitMessage(WAIT_BAR_MESSAGES['EXPERIMENT_START']).show_and_return()

        self.test_mode = False
        if self.main_timer is not None:
            self.main_timer.stop()
        if self.stream is not None:
            self.stream.disconnect()
        if self.thread is not None:
            self.thread.terminate()

        # timer
        self.main_timer = QtCore.QTimer(self.app)

        self.is_finished = False

        # current protocol index
        self.current_protocol_index = 0

        # samples counter for protocol sequence
        self.samples_counter = 0

        # run file lsl stream in a thread
        self.thread = None
        if self.params['sInletType'] == 'lsl_from_file':
            self.restart_lsl_from_file()

        # run simulated eeg lsl stream in a thread
        elif self.params['sInletType'] == 'lsl_generator':
            self.thread = stream_generator_in_a_thread(self.params['sStreamName'])

        # use FTB inlet
        aux_streams = None
        if self.params['sInletType'] == 'ftbuffer':
            hostname, port = self.params['sFTHostnamePort'].split(':')
            port = int(port)
            stream = FieldTripBufferInlet(hostname, port)

        # use LSL inlet
        else:
            stream_names = re.split(r"[,; ]+", self.params['sStreamName'])
            streams = [LSLInlet(name=name) for name in stream_names]
            stream = streams[0]
            aux_streams = streams[1:] if len(streams) > 1 else None

        # setup events stream by name
        events_stream_name = self.params['sEventsStreamName']
        events_stream = LSLInlet(events_stream_name) if events_stream_name else None

        # setup main stream
        self.stream = ChannelsSelector(stream, exclude=self.params['sReference'],
                                       subtractive_channel=self.params['sReferenceSub'],
                                       dc=self.params['bDC'], events_inlet=events_stream, aux_inlets=aux_streams,
                                       prefilter_band=self.params['sPrefilterBand'])
        self.stream.save_info(self.dir_name + 'stream_info.xml')
        save_channels_and_fs(self.dir_name + 'experiment_data.h5', self.stream.get_channels_labels(),
                             self.stream.get_frequency())

        save_xml_str_to_hdf5_dataset(self.dir_name + 'experiment_data.h5', self.stream.info_as_xml(), 'stream_info.xml')
        self.freq = self.stream.get_frequency()
        self.n_channels = self.stream.get_n_channels()
        self.n_channels_other = self.stream.get_n_channels_other()
        channels_labels = self.stream.get_channels_labels()
        montage = Montage(channels_labels)
        print(montage)
        self.seconds = 2 * self.freq
        self.raw_std = None

        # signals
        self.signals = [DerivedSignal.from_params(ind, self.freq, self.n_channels, channels_labels, signal)
                        for ind, signal in enumerate(self.params['vSignals']['DerivedSignal']) if
                        not signal['bBCIMode']]

        # composite signals
        self.composite_signals = [CompositeSignal([s for s in self.signals],
                                                  signal['sExpression'],
                                                  signal['sSignalName'],
                                                  ind + len(self.signals), self.freq)
                                  for ind, signal in enumerate(self.params['vSignals']['CompositeSignal'])]

        # bci signals
        self.bci_signals = [BCISignal(self.freq, channels_labels, signal['sSignalName'], ind)
                            for ind, signal in enumerate(self.params['vSignals']['DerivedSignal']) if
                            signal['bBCIMode']]

        self.signals += self.composite_signals
        self.signals += self.bci_signals
        # self.current_samples = np.zeros_like(self.signals)

        # signals outlet
        self.signals_outlet = SignalsOutlet([signal.name for signal in self.signals], fs=self.freq)

        # protocols
        self.protocols = []
        signal_names = [signal.name for signal in self.signals]

        for protocol in self.params['vProtocols']:
            # some general protocol arguments
            source_signal_id = None if protocol['fbSource'] == 'All' else signal_names.index(protocol['fbSource'])
            reward_signal_id = signal_names.index(protocol['sRewardSignal']) if protocol['sRewardSignal'] != '' else 0
            mock_path = (protocol['sMockSignalFilePath'] if protocol['sMockSignalFilePath'] != '' else None,
                         protocol['sMockSignalFileDataset'])
            m_signal = protocol['sMSignal']
            m_signal_index = None if m_signal not in signal_names else signal_names.index(m_signal)

            # general protocol arguments dictionary
            kwargs = dict(
                source_signal_id=source_signal_id,
                name=protocol['sProtocolName'],
                duration=protocol['fDuration'],
                random_over_time=protocol['fRandomOverTime'],
                update_statistics_in_the_end=bool(protocol['bUpdateStatistics']),
                stats_type=protocol['sStatisticsType'],
                mock_samples_path=mock_path,
                show_reward=bool(protocol['bShowReward']),
                reward_signal_id=reward_signal_id,
                reward_threshold=protocol['bRewardThreshold'],
                ssd_in_the_end=protocol['bSSDInTheEnd'],
                timer=self.main_timer,
                freq=self.freq,
                mock_previous=int(protocol['iMockPrevious']),
                drop_outliers=int(protocol['iDropOutliers']),
                experiment=self,
                pause_after=bool(protocol['bPauseAfter']),
                beep_after=bool(protocol['bBeepAfter']),
                reverse_mock_previous=bool(protocol['bReverseMockPrevious']),
                m_signal_index=m_signal_index,
                shuffle_mock_previous=bool(protocol['bRandomMockPrevious']),
                as_mock=bool(protocol['bMockSource']),
                auto_bci_fit=bool(protocol['bAutoBCIFit']),
                montage=montage
            )

            # type specific arguments
            if protocol['sFb_type'] == 'Baseline':
                self.protocols.append(
                    BaselineProtocol(
                        self.signals,
                        text=protocol['cString'] if protocol['cString'] != '' else 'Relax',
                        half_time_text=protocol['cString2'] if bool(protocol['bUseExtraMessage']) else None,
                        **kwargs
                    ))
            elif protocol['sFb_type'] in ['Feedback', 'CircleFeedback']:
                self.protocols.append(
                    FeedbackProtocol(
                        self.signals,
                        circle_border=protocol['iRandomBound'],
                        m_threshold=protocol['fMSignalThreshold'],
                        **kwargs))
            elif protocol['sFb_type'] == 'ThresholdBlink':
                self.protocols.append(
                    ThresholdBlinkFeedbackProtocol(
                        self.signals,
                        threshold=protocol['fBlinkThreshold'],
                        time_ms=protocol['fBlinkDurationMs'],
                        **kwargs))
            elif protocol['sFb_type'] == 'Video':
                self.protocols.append(
                    VideoProtocol(
                        self.signals,
                        video_path=protocol['sVideoPath'],
                        **kwargs))
            else:
                raise TypeError('Undefined protocol type \"{}\"'.format(protocol['sFb_type']))

        # protocols sequence
        names = [protocol.name for protocol in self.protocols]
        group_names = [p['sName'] for p in self.params['vPGroups']['PGroup']]
        print(group_names)
        self.protocols_sequence = []
        for name in self.params['vPSequence']:
            if name in names:
                self.protocols_sequence.append(self.protocols[names.index(name)])
            if name in group_names:
                group = self.params['vPGroups']['PGroup'][group_names.index(name)]
                subgroup = []
                if len(group['sList'].split(' ')) == 1:
                    subgroup.append([group['sList']] * int(group['sNumberList']))
                else:
                    for s_name, s_n in zip(group['sList'].split(' '), list(map(int, group['sNumberList'].split(' ')))):
                        subgroup.append([s_name] * s_n)
                if group['bShuffle']:
                    subgroup = np.concatenate(subgroup)
                    subgroup = list(subgroup[np.random.permutation(len(subgroup))])
                else:
                    subgroup = [k for k in chain(*zip_longest(*subgroup)) if k is not None]
                print(subgroup)
                for subname in subgroup:
                    self.protocols_sequence.append(self.protocols[names.index(subname)])
                    if len(group['sSplitBy']):
                        self.protocols_sequence.append(self.protocols[names.index(group['sSplitBy'])])

        # reward
        from pynfb.reward import Reward
        self.reward = Reward(self.protocols[0].reward_signal_id,
                             threshold=self.protocols[0].reward_threshold,
                             rate_of_increase=self.params['fRewardPeriodS'],
                             fs=self.freq)

        self.reward.set_enabled(isinstance(self.protocols_sequence[0], FeedbackProtocol))

        # timer
        # self.main_timer = QtCore.QTimer(self.app)
        self.main_timer.timeout.connect(self.update)
        self.main_timer.start(1000 * 1. / self.freq)

        # current protocol number of samples ('frequency' * 'protocol duration')
        self.current_protocol_n_samples = self.freq * (self.protocols_sequence[self.current_protocol_index].duration +
                                                       np.random.uniform(0, self.protocols_sequence[
                                                           self.current_protocol_index].random_over_time))

        # experiment number of samples
        max_protocol_n_samples = int(
            max([self.freq * (p.duration + p.random_over_time) for p in self.protocols_sequence]))

        # data recorders
        self.experiment_n_samples = max_protocol_n_samples
        self.samples_counter = 0
        self.raw_recorder = np.zeros((max_protocol_n_samples * 110 // 100, self.n_channels)) * np.nan
        self.timestamp_recorder = np.zeros((max_protocol_n_samples * 110 // 100)) * np.nan
        self.raw_recorder_other = np.zeros((max_protocol_n_samples * 110 // 100, self.n_channels_other)) * np.nan
        self.signals_recorder = np.zeros((max_protocol_n_samples * 110 // 100, len(self.signals))) * np.nan
        self.reward_recorder = np.zeros((max_protocol_n_samples * 110 // 100)) * np.nan
        self.mark_recorder = np.zeros((max_protocol_n_samples * 110 // 100)) * np.nan

        # save init signals
        save_signals(self.dir_name + 'experiment_data.h5', self.signals,
                     group_name='protocol0')

        # save settings
        params_to_xml_file(self.params, self.dir_name + 'settings.xml')
        save_xml_str_to_hdf5_dataset(self.dir_name + 'experiment_data.h5', params_to_xml(self.params), 'settings.xml')

        # windows
        self.main = MainWindow(signals=self.signals,
                               protocols=self.protocols_sequence,
                               parent=None,
                               experiment=self,
                               current_protocol=self.protocols_sequence[self.current_protocol_index],
                               n_signals=len(self.signals),
                               max_protocol_n_samples=max_protocol_n_samples,
                               freq=self.freq,
                               n_channels=self.n_channels,
                               plot_raw_flag=self.params['bPlotRaw'],
                               plot_signals_flag=self.params['bPlotSignals'],
                               plot_source_space_flag=self.params['bPlotSourceSpace'],
                               show_subject_window=self.params['bShowSubjectWindow'],
                               channels_labels=channels_labels,
                               photo_rect=self.params['bShowPhotoRectangle'])
        self.subject = self.main.subject_window
        if self.params['bPlotSourceSpace']:
            self.source_space_window = self.main.source_space_window

        if self.params['sInletType'] == 'lsl_from_file':
            self.main.player_panel.start_clicked.connect(self.restart_lsl_from_file)

        # create real fb list
        self.real_fb_number_list = []

        wait_bar.close()