def __init__(self,
                 experiment_manifest_name,
                 subject_ids,
                 checkpoint_dir='',
                 restore_epoch=None,
                 SN_kwargs=(),
                 DG_kwargs=(),
                 RP_kwargs=(),
                 ES_kwargs=(),
                 VERBOSE=True,
                 **kwargs):
        """Build the ECoG subjects and the SequenceNetwork for an experiment.

        Parameters
        ----------
        experiment_manifest_name : str
            Name of a YAML file, joined onto ``text_dir``, whose top-level
            entries are keyed by subject id.
        subject_ids : sequence
            Keys into the experiment manifest.  The *last* id is the target
            subject: its manifest entry supplies ``token_type`` and configures
            the SequenceNetwork.  Subjects whose id differs from the last one
            are built with ``pretrain_all_blocks=True``.
        checkpoint_dir : str
            Assigned through the ``checkpoint_dir`` setter.
        restore_epoch
            Assigned through the ``restore_epoch`` setter.
        SN_kwargs, DG_kwargs, RP_kwargs, ES_kwargs
            Anything ``dict()`` accepts (a mapping or key/value pairs).
            SN_kwargs go to the SequenceNetwork, ES_kwargs to every
            ECoGSubject, DG_kwargs to every ECoGSubject as ``_DG_kwargs``, and
            RP_kwargs are stored on ``self._RP_kwargs``.
        VERBOSE : bool
            Passed to the SequenceNetwork and to the ``VERBOSE`` setter.
        **kwargs
            Forwarded to ``self.set_feature_lists``.

        Raises
        ------
        ValueError
            If `subject_ids` is empty, or if the target subject's
            ``token_type`` is not in ``TOKEN_TYPES``.
        """

        # load the experiment_manifest
        # NOTE(review): yaml.full_load is not yaml.safe_load; this assumes the
        #  manifest under text_dir is a trusted, locally-authored file.
        with open(os.path.join(text_dir, experiment_manifest_name)) as file:
            self.experiment_manifest = yaml.full_load(file)

        # checks
        # NB: explicit raises rather than `assert`, which `python -O` strips.
        #  len() rather than truthiness so array-like subject_ids also work.
        if len(subject_ids) == 0:
            raise ValueError('subject_ids must be non-empty')
        target_id = subject_ids[-1]
        token_type = self.experiment_manifest[target_id]['token_type']
        if token_type not in TOKEN_TYPES:
            raise ValueError('Unrecognized token_type!! -- jgm')

        # attribute
        self._token_type = token_type  # NB: changes will not propagate
        self._RP_kwargs = dict(RP_kwargs)

        # create ECoG subjects; only the target subject (the last id) gets
        #  pretrain_all_blocks=False
        self.ecog_subjects = [
            ECoGSubject(self.experiment_manifest[subject_id],
                        subject_id,
                        pretrain_all_blocks=(subject_id != target_id),
                        **dict(ES_kwargs),
                        _DG_kwargs=dict(DG_kwargs)
                        #####
                        # target_specs=target_specs
                        #####
                        ) for subject_id in subject_ids
        ]

        # create the SequenceNetwork according to the experiment_manifest
        self.net = sequence_networks.SequenceNetwork(
            self.experiment_manifest[target_id],
            EOS_token=EOS_token,
            pad_token=pad_token,
            OOV_token=OOV_token,
            training_GPUs=[0],
            TARGETS_ARE_SEQUENCES='sequence' in token_type,
            VERBOSE=VERBOSE,
            **dict(SN_kwargs))

        # invoke some setters
        # NB: these attributes adjust self.ecog_subjects and self.net, so they
        #  must be invoked *after* those are created.  Hence no auto_attribute!
        self.VERBOSE = VERBOSE
        self.checkpoint_dir = checkpoint_dir
        self.restore_epoch = restore_epoch

        # update the data_manifests for our case: copy any
        #  '<data_key>_penalty_scale' entry from the subject's manifest section
        for subject in self.ecog_subjects:
            for data_key, data_manifest in subject.data_manifests.items():
                # Only the manifest lookup may legitimately be missing; keep
                #  the setter outside the try so its own errors surface.
                try:
                    penalty_scale = self.experiment_manifest[
                        subject.subnet_id][data_key + '_penalty_scale']
                except KeyError:
                    continue
                data_manifest.penalty_scale = penalty_scale
        self.set_feature_lists(**kwargs)
# Example #2
0
    def __init__(
        self,
        experiment_manifest_name,
        subject_ids,
        checkpoint_dir='',
        restore_epoch=None,
        SN_kwargs=(),
        DG_kwargs=(),
        RP_kwargs=(),
        ES_kwargs=(),
        VERBOSE=True,
        **kwargs
    ):
        """Build the ECoG subjects and the network for an experiment.

        Under TensorFlow 2.x the network is a ``NeuralNetwork`` wrapping
        ``Seq2Seq``.  Under any other major version it is a
        ``sequence_networks.SequenceNetwork``.

        Parameters
        ----------
        experiment_manifest_name : str
            Name of a YAML file, joined onto ``text_dir``, whose top-level
            entries are keyed by subject id.
        subject_ids : sequence
            Keys into the experiment manifest.  The *last* id is the target
            subject: its manifest entry supplies ``token_type`` and configures
            the network.  Subjects whose id differs from the last one are
            built with ``pretrain_all_blocks=True``.
        checkpoint_dir : str
            Assigned through the ``checkpoint_dir`` setter.  It is assigned
            twice: before the data manifests are updated, and again after the
            network exists.
        restore_epoch
            Assigned through the ``restore_epoch`` setter.
        SN_kwargs, DG_kwargs, RP_kwargs, ES_kwargs
            Anything ``dict()`` accepts (a mapping or key/value pairs).
            SN_kwargs go to the network, ES_kwargs to every ECoGSubject,
            DG_kwargs to every ECoGSubject as ``_DG_kwargs``, and RP_kwargs
            are stored on ``self._RP_kwargs``.
        VERBOSE : bool
            Assigned through the ``VERBOSE`` setter.  It is passed to the
            network only on the non-TF2 path.
        **kwargs
            Forwarded to ``self.set_feature_lists``.

        Raises
        ------
        ValueError
            If `subject_ids` is empty, or if the target subject's
            ``token_type`` is not in ``TOKEN_TYPES``.
        """

        # load the experiment_manifest
        # NOTE(review): yaml.full_load is not yaml.safe_load; this assumes the
        #  manifest under text_dir is a trusted, locally-authored file.
        with open(os.path.join(text_dir, experiment_manifest_name)) as file:
            self.experiment_manifest = yaml.full_load(file)

        # checks
        # NB: explicit raises rather than `assert`, which `python -O` strips.
        #  len() rather than truthiness so array-like subject_ids also work.
        if len(subject_ids) == 0:
            raise ValueError('subject_ids must be non-empty')
        target_id = subject_ids[-1]
        token_type = self.experiment_manifest[target_id]['token_type']
        if token_type not in TOKEN_TYPES:
            raise ValueError('Unrecognized token_type!! -- jgm')
        targets_are_sequences = 'sequence' in token_type

        # attribute
        self._token_type = token_type   # NB: changes will not propagate
        self._RP_kwargs = dict(RP_kwargs)

        # create ECoG subjects; only the target subject (the last id) gets
        #  pretrain_all_blocks=False
        self.ecog_subjects = [
            ECoGSubject(
                self.experiment_manifest[subject_id],
                subject_id,
                pretrain_all_blocks=(subject_id != target_id),
                **dict(ES_kwargs),
                _DG_kwargs=dict(DG_kwargs)
                #####
                # target_specs=target_specs
                #####
            ) for subject_id in subject_ids]

        # invoke some setters
        # NB: these attributes adjust self.ecog_subjects, so they must be
        #  invoked *after* those are created (hence no auto_attribute).  But
        #  the changes to the ecog_subjects below in turn depend on the
        #  self.checkpoint_dir, so they have to be set after these lines.
        # NOTE(review): self.net does not exist yet here, so the
        #  checkpoint_dir setter presumably tolerates its absence; it is
        #  re-run at the end once the net exists.
        self.VERBOSE = VERBOSE
        self.checkpoint_dir = checkpoint_dir
        self.restore_epoch = restore_epoch

        # update the data_manifests for our case
        for subject in self.ecog_subjects:
            for data_key, data_manifest in subject.data_manifests.items():
                if data_key == 'decoder_targets' and targets_are_sequences:
                    data_manifest.APPEND_EOS = True
                # Copy any '<data_key>_penalty_scale' manifest entry.  Only
                #  the lookup may legitimately be missing; keep the setter
                #  outside the try so its own errors surface.
                try:
                    penalty_scale = self.experiment_manifest[
                        subject.subnet_id][data_key + '_penalty_scale']
                except KeyError:
                    continue
                data_manifest.penalty_scale = penalty_scale
        self.set_feature_lists(**kwargs)

        # create the SequenceNetwork according to the experiment_manifest
        if int(tf.__version__.split('.')[0]) == 2:
            self.net = NeuralNetwork(
                self.experiment_manifest[target_id],
                Seq2Seq,
                self.ecog_subjects[-1],  # temporary hack
                EOS_token=EOS_token,
                pad_token=pad_token,
                OOV_token=OOV_token,
                **dict(SN_kwargs)
            )
        else:
            self.net = sequence_networks.SequenceNetwork(
                self.experiment_manifest[target_id],
                EOS_token=EOS_token,
                pad_token=pad_token,
                OOV_token=OOV_token,
                training_GPUs=[0],
                TARGETS_ARE_SEQUENCES=targets_are_sequences,
                VERBOSE=VERBOSE,
                **dict(SN_kwargs)
            )

        # re-run to set the net's checkpoint_path
        self.checkpoint_dir = checkpoint_dir