Example #1
    def _step(self, action):
        """
        _step receives an action and returns:
            a new observation, obs
            reward associated with the action, reward
            a boolean variable indicating whether the experiment has ended, done
            a dictionary with extra information:
                ground truth correct response, info['gt']
                boolean indicating the end of the trial, info['new_trial']
        """
        # ---------------------------------------------------------------------
        # Reward and observations
        # ---------------------------------------------------------------------
        trial = self.trial
        info = {'new_trial': False}
        info['gt'] = np.zeros((3, ))
        # rewards
        reward = 0
        # observations
        obs = np.zeros((3, ))
        if self.in_epoch(self.t, 'fixation'):
            info['gt'][0] = 1
            obs[0] = 1
            if self.actions[action] != 0:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            info['gt'][int((trial['ground_truth'] / 2 + 1.5))] = 1
            gt_sign = np.sign(trial['ground_truth'])
            action_sign = np.sign(self.actions[action])
            if gt_sign == action_sign:
                reward = self.R_CORRECT
            elif gt_sign == -action_sign:
                reward = self.R_FAIL
            info['new_trial'] = self.actions[action] != 0
        else:
            info['gt'][0] = 1

        # use an 'if' (not 'elif') so the stimulus and fixation periods can overlap
        if self.in_epoch(self.t, 'stimulus'):
            obs[0] = 1
            high = (trial['ground_truth'] > 0) + 1
            low = (trial['ground_truth'] < 0) + 1
            obs[high] = self.scale(trial['coh']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)
            obs[low] = self.scale(-trial['coh']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)
        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
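
All of the examples below share the step contract documented in the docstring above. As a quick orientation, here is a minimal sketch of a rollout loop driving `_step` from the outside. It relies only on what the docstring states (the `obs, reward, done, info` return and the `info['new_trial']` flag); the `env` instance, the action count, and the random policy are illustrative stand-ins, not part of the examples.

import numpy as np

def run_episode(env, n_actions=3, max_steps=10_000, seed=0):
    # Hypothetical driver: `env` stands for any task below exposing _step;
    # n_actions is assumed to match the task's action set (e.g. fixate/left/right).
    rng = np.random.default_rng(seed)
    total_reward = 0.0
    for _ in range(max_steps):
        action = int(rng.integers(n_actions))  # random policy, for illustration
        obs, reward, done, info = env._step(action)
        total_reward += reward
        if info['new_trial']:
            pass  # per-trial bookkeeping could go here, e.g. logging info['gt']
        if done:
            break
    return total_reward
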
Example #2
    def _step(self, action):
        trial = self.trial
        info = {'new_trial': False, 'gt': np.zeros((3, ))}
        reward = 0

        obs = np.zeros((3, ))
        if self.t == 0:  # at stage 1, if action==fixate, abort
            if action == 0:
                reward = self.R_ABORTED
                info['new_trial'] = True
            else:
                state = trial['transition'][action]
                obs[int(state)] = 1
                reward = trial['reward'][int(state - 1)]
        elif self.t == self.dt:
            obs[0] = 1
            if action != 0:
                reward = self.R_ABORTED
            info['new_trial'] = True
        else:
            raise ValueError('t is not 0 or dt')

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)
        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp

        return obs, reward, done, info
Example #3
    def _step(self, action):
        trial = self.trial
        # ---------------------------------------------------------------------
        # Reward and inputs
        # ---------------------------------------------------------------------
        info = {'new_trial': False}
        # ground truth signal is not well defined in this task
        info['gt'] = np.zeros((4, ))
        # rewards
        reward = 0
        # observations
        obs = np.zeros((4, ))
        if self.in_epoch(self.t, 'fixation'):
            obs[0] = 1
            if self.actions[action] != 0:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            if self.actions[action] == 2:
                if trial['wager']:
                    reward = self.R_SURE
                else:
                    reward = self.R_ABORTED
            else:
                gt_sign = np.sign(trial['ground_truth'])
                action_sign = np.sign(self.actions[action])
                if gt_sign == action_sign:
                    reward = self.R_CORRECT
                elif gt_sign == -action_sign:
                    reward = self.R_FAIL
            info['new_trial'] = self.actions[action] != 0

        if self.in_epoch(self.t, 'delay'):
            obs[0] = 1
        elif self.in_epoch(self.t, 'stimulus'):
            high = (trial['ground_truth'] > 0) + 1
            low = (trial['ground_truth'] < 0) + 1
            obs[high] = self.scale(+trial['coh']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)
            obs[low] = self.scale(-trial['coh']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)
        if trial['wager'] and self.in_epoch(self.t, 'sure'):
            obs[3] = 1

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)

        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
Example #4
    def _step(self, action):
        trial = self.trial

        # ---------------------------------------------------------------------
        # Reward and inputs
        # ---------------------------------------------------------------------
        # epochs = trial['epochs']
        info = {'new_trial': False}
        info['gt'] = np.zeros((2,))
        reward = 0
        obs = np.zeros((5,))
        if self.in_epoch(self.t, 'fixation'):
            info['gt'][0] = 1
            obs[0] = 1  # TODO: fixation cue only during fixation period?
            if self.actions[action] != -1:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            info['gt'][int((trial['ground_truth']/2+.5))] = 1
            gt_sign = np.sign(trial['ground_truth'])
            action_sign = np.sign(self.actions[action])
            if (action_sign > 0):
                info['new_trial'] = True
                if (gt_sign > 0):
                    reward = self.R_CORRECT
                else:
                    reward = self.R_INCORRECT
        else:
            info['gt'][0] = 1

        if self.in_epoch(self.t, 'dpa1'):
            dpa1, _ = trial['pair']
            obs[dpa1] = 1
        if self.in_epoch(self.t, 'dpa2'):
            _, dpa2 = trial['pair']
            obs[dpa2] = 1
        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(self.t, self.tmax,
                                                        self.dt,
                                                        info['new_trial'],
                                                        self.R_MISS, reward)

        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt
        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
Example #5
    def _step(self, action):
        # ---------------------------------------------------------------------
        # Reward and inputs
        # ---------------------------------------------------------------------
        trial = self.trial
        info = {'new_trial': False, 'gt': np.zeros((2, ))}
        reward = 0
        obs = np.zeros((3, ))
        if self.in_epoch(self.t, 'fixation'):
            obs[0] = 1
            info['gt'][0] = 1
            if self.actions[action] != -1:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        if self.in_epoch(self.t, 'production'):
            t_prod = self.t - trial['durations']['measure'][1]
            eps = abs(t_prod - trial['production'])
            if eps < self.dt / 2 + 1:
                info['gt'][1] = 1
            else:
                info['gt'][0] = 1
            if action == 1:
                info['new_trial'] = True  # terminate
                # actual production time
                eps_threshold = 0.2 * trial['production'] + 25
                if eps > eps_threshold:
                    reward = self.R_FAIL
                else:
                    reward = (1. - eps / eps_threshold)**1.5
                    reward = max(reward, 0.1)  # floor (not cap) the graded reward
                    reward *= self.R_CORRECT
        else:
            info['gt'][0] = 1

        if self.in_epoch(self.t, 'ready'):
            obs[1] = 1
        if self.in_epoch(self.t, 'set'):
            obs[2] = 1

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)
        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
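
The graded timing reward in the production branch above is easy to sanity-check in isolation. A small sketch of just that branch, with illustrative numbers (`r_correct` and `r_fail` are stand-ins for the snippet's R_CORRECT and R_FAIL, which are not defined here):

def timing_reward(eps, production, r_correct=1.0, r_fail=-1.0):
    # Mirrors example #5's decision branch: reward decays with the timing
    # error eps and is floored at 0.1 * r_correct inside the window.
    eps_threshold = 0.2 * production + 25
    if eps > eps_threshold:
        return r_fail
    reward = (1. - eps / eps_threshold) ** 1.5
    return max(reward, 0.1) * r_correct

print(timing_reward(0, 500))    # perfectly timed press: full reward, 1.0
print(timing_reward(120, 500))  # threshold is 125, so this lands on the 0.1 floor
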
Example #6
    def _step(self, action):
        # ---------------------------------------------------------------------
        # Reward
        # ---------------------------------------------------------------------
        trial = self.trial
        info = {'new_trial': False, 'gt': np.zeros((3, ))}
        reward = 0
        obs = np.zeros((3, ))

        if self.in_epoch(self.t, 'fixation'):
            info['gt'][0] = 1
            obs[0] = 1
            if self.actions[action] != 0:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            info['gt'][int((trial['ground_truth'] / 2 + 1.5))] = 1
            gt_sign = np.sign(trial['ground_truth'])
            action_sign = np.sign(self.actions[action])
            if gt_sign == action_sign:
                reward = self.R_CORRECT
            elif gt_sign == -action_sign:
                reward = self.R_FAIL
            info['new_trial'] = self.actions[action] != 0
        else:
            info['gt'][0] = 1

        # ---------------------------------------------------------------------
        # Inputs
        # ---------------------------------------------------------------------
        if self.in_epoch(self.t, 'sample'):
            obs[trial['sample'] + 1] = 1
        if self.in_epoch(self.t, 'test'):
            obs[trial['test'] + 1] = 1

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)

        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
Example #7
    def _step(self, action):
        # ---------------------------------------------------------------------
        # Reward and inputs
        # ---------------------------------------------------------------------
        trial = self.trial
        info = {'new_trial': False}
        # rewards
        reward = 0
        # observations
        obs = np.zeros((2, ))
        info['gt'] = np.zeros((3, ))
        if self.in_epoch(self.t, 'fixation'):
            info['gt'][0] = 1
            obs[0] = 1
            if self.actions[action] != 0:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            info['gt'][int((trial['ground_truth'] / 2 + 1.5))] = 1
            gt_sign = np.sign(trial['ground_truth'])
            action_sign = np.sign(self.actions[action])
            if gt_sign == action_sign:
                reward = self.R_CORRECT
            elif gt_sign == -action_sign:
                reward = self.R_FAIL
            info['new_trial'] = self.actions[action] != 0
        else:
            info['gt'][0] = 1

        if self.in_epoch(self.t, 'f1'):
            obs[1] = self.scale_p(trial['f1']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)
        elif self.in_epoch(self.t, 'f2'):
            obs[1] = self.scale_p(trial['f2']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)
        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
Example #8
    def _step(self, action):
        # ---------------------------------------------------------------------
        # Reward and inputs
        # ---------------------------------------------------------------------
        trial = self.trial
        info = {'new_trial': False}
        info['gt'] = np.zeros((2, ))
        reward = 0
        obs = np.zeros((3, ))
        if self.in_epoch(self.t, 'fixation'):
            info['gt'][0] = 1
            obs[0] = 1  # fixation cue only during fixation period
            if self.actions[action] != -1:
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            info['gt'][int((trial['ground_truth'] / 2 + .5))] = 1
            gt_sign = np.sign(trial['ground_truth'])
            action_sign = np.sign(self.actions[action])
            if (action_sign > 0):
                info['new_trial'] = True
                if (gt_sign > 0):
                    reward = self.R_CORRECT
                else:
                    reward = self.R_INCORRECT
        else:
            info['gt'][0] = 1

        if self.in_epoch(self.t, 'stimulus'):
            # observation
            stim = (trial['ground_truth'] > 0) + 1
            obs[stim] = 1

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)
        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
Example #9
    def _step(self, action):
        trial = self.trial

        # ---------------------------------------------------------------------
        # Reward
        # ---------------------------------------------------------------------

        # epochs = trial['epochs']
        info = {'new_trial': False}
        info['gt'] = np.zeros((3, ))
        reward = 0
        if (self.in_epoch(self.t, 'fixation')
                or self.in_epoch(self.t, 'offer-on')):
            if (action != self.actions['FIXATE']):
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            if action in [
                    self.actions['CHOOSE-LEFT'], self.actions['CHOOSE-RIGHT']
            ]:
                info['new_trial'] = True

                juiceL, juiceR = trial['juice']

                nB, nA = trial['offer']
                rA = nA * self.R_A
                rB = nB * self.R_B

                if juiceL == 'A':
                    rL, rR = rA, rB
                else:
                    rL, rR = rB, rA

                if action == self.actions['CHOOSE-LEFT']:
                    reward = rL
                elif action == self.actions['CHOOSE-RIGHT']:
                    reward = rR

        # ---------------------------------------------------------------------
        # Inputs
        # ---------------------------------------------------------------------
        obs = np.zeros(len(self.inputs))
        if not self.in_epoch(self.t, 'decision'):
            obs[self.inputs['FIXATION']] = 1
        if self.in_epoch(self.t, 'offer-on'):
            juiceL, juiceR = trial['juice']
            obs[self.inputs['L-' + juiceL]] = 1
            obs[self.inputs['R-' + juiceR]] = 1

            obs[self.inputs['N-L']] = self.scale(trial['nL']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)
            obs[self.inputs['N-R']] = self.scale(trial['nR']) +\
                self.rng.gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt)

        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)

        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
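
The payoff mapping in example #9 can be read off in a few lines: the offer is `(nB, nA)`, each juice has a per-unit reward, and which payoff sits on the left depends on where juice A is shown. A minimal sketch of that mapping (`r_a` and `r_b` are illustrative stand-ins for self.R_A and self.R_B):

def choice_rewards(juice, offer, r_a=2.0, r_b=1.0):
    # Mirrors example #9: returns (reward-for-left, reward-for-right).
    juiceL, juiceR = juice
    nB, nA = offer
    rA, rB = nA * r_a, nB * r_b
    return (rA, rB) if juiceL == 'A' else (rB, rA)

print(choice_rewards(('A', 'B'), (3, 1)))  # -> (2.0, 3.0): 1 unit of A left, 3 of B right
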
Example #10
    def _step(self, action):
        # -----------------------------------------------------------------
        # Reward
        # -----------------------------------------------------------------
        trial = self.trial
        dt = self.dt
        rng = self.rng

        # epochs = trial['epochs']
        info = {'new_trial': False}
        info['gt'] = np.zeros((3, ))
        reward = 0
        if self.in_epoch(self.t, 'fixation'):
            info['gt'][0] = 1
            if (action != self.actions['FIXATE']):
                info['new_trial'] = self.abort
                reward = self.R_ABORTED
        elif self.in_epoch(self.t, 'decision'):
            info['gt'][int((trial['ground_truth'] / 2 + 1.5))] = 1
            if action == self.actions['left']:
                info['new_trial'] = True
                if trial['context'] == 'm':
                    correct = (trial['left_right_m'] < 0)
                else:
                    correct = (trial['left_right_c'] < 0)
                if correct:
                    reward = self.R_CORRECT
            elif action == self.actions['right']:
                info['new_trial'] = True
                if trial['context'] == 'm':
                    correct = (trial['left_right_m'] > 0)
                else:
                    correct = (trial['left_right_c'] > 0)
                if correct:
                    reward = self.R_CORRECT
        else:
            info['gt'][0] = 1

        # ---------------------------------------------------------------------
        # Inputs
        # ---------------------------------------------------------------------

        if trial['context'] == 'm':
            context = self.inputs['motion']
        else:
            context = self.inputs['color']

        if trial['left_right_m'] < 0:
            high_m = self.inputs['m-left']
            low_m = self.inputs['m-right']
        else:
            high_m = self.inputs['m-right']
            low_m = self.inputs['m-left']

        if trial['left_right_c'] < 0:
            high_c = self.inputs['c-left']
            low_c = self.inputs['c-right']
        else:
            high_c = self.inputs['c-right']
            low_c = self.inputs['c-left']

        obs = np.zeros(len(self.inputs))
        if (self.in_epoch(self.t, 'fixation')
                or self.in_epoch(self.t, 'stimulus')
                or self.in_epoch(self.t, 'delay')):
            obs[context] = 1
        if self.in_epoch(self.t, 'stimulus'):
            obs[high_m] = self.scale(+trial['coh_m']) +\
                rng.gauss(mu=0, sigma=self.sigma) / np.sqrt(dt)
            obs[low_m] = self.scale(-trial['coh_m']) +\
                rng.gauss(mu=0, sigma=self.sigma) / np.sqrt(dt)
            obs[high_c] = self.scale(+trial['coh_c']) +\
                rng.gauss(mu=0, sigma=self.sigma) / np.sqrt(dt)
            obs[low_c] = self.scale(-trial['coh_c']) +\
                rng.gauss(mu=0, sigma=self.sigma) / np.sqrt(dt)
        # ---------------------------------------------------------------------
        # new trial?
        reward, info['new_trial'] = tasktools.new_trial(
            self.t, self.tmax, self.dt, info['new_trial'], self.R_MISS, reward)

        if info['new_trial']:
            self.t = 0
            self.num_tr += 1
        else:
            self.t += self.dt

        done = self.num_tr > self.num_tr_exp
        return obs, reward, done, info
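
One pattern worth noting across examples #1, #3, #7, #9, and #10: stimulus noise is drawn as gauss(mu=0, sigma=self.sigma)/np.sqrt(self.dt). Dividing the per-step noise by sqrt(dt) keeps the variance of time-integrated evidence independent of the step size, as in a discretized Wiener process. A quick numerical check of that property (the step sizes, horizon, and sigma here are illustrative, not taken from the tasks):

import numpy as np

def integrated_noise_std(dt, total_t=1000.0, sigma=1.0, n_rep=2000, seed=0):
    # Accumulate dt * (noise / sqrt(dt)) over the horizon; the standard
    # deviation of the sum should be ~ sigma * sqrt(total_t) regardless of dt.
    rng = np.random.default_rng(seed)
    n_steps = int(total_t / dt)
    noise = rng.normal(0.0, sigma / np.sqrt(dt), size=(n_rep, n_steps))
    return (dt * noise).sum(axis=1).std()

for dt in (10.0, 50.0, 100.0):
    print(dt, integrated_noise_std(dt))  # each ~ sqrt(1000) ≈ 31.6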