Example #1
 def __init__(self,
              encoder,
              bottleneck,
              bottleneck_latent_dim,
              latent_mixer,
              decoder,
              minimum_annealing_factor,
              annealing_epochs,
              reconstruction_field='features',
              **kwargs):
     super(RepresentationLearnerContinuousBN, self).__init__(**kwargs)
     self.encoder = utils.construct_from_kwargs(encoder)
     self.bottleneck = utils.construct_from_kwargs(
         bottleneck,
         additional_parameters={'latent_dim': bottleneck_latent_dim})
     self.latent_mixer = utils.construct_from_kwargs(
         latent_mixer,
         additional_parameters={'in_channels': bottleneck_latent_dim})
     self.latent_mixer.eval()
     mixer_out_channels = self.latent_mixer(
         torch.empty((1, 500, 1, bottleneck_latent_dim))).size(3)
     self.decoder = utils.construct_from_kwargs(
         decoder, additional_parameters={'in_channels': mixer_out_channels})
     self.minimum_annealing_factor = minimum_annealing_factor
     self.annealing_epochs = annealing_epochs
     self.num_mini_batches = len(self.dataloader)
     self.batch_id = 0
     self.start_of_training = True
     self.reconstruction_field = reconstruction_field
     self.pre_bottleneck = encoders.Identity()
     self.post_bottleneck = encoders.Identity()
     self.post_latent_mixer = encoders.Identity()
     self.input_layer = encoders.Identity()
     self.add_probes()
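
Every snippet on this page goes through utils.construct_from_kwargs, which turns a config dict holding a 'class_name' plus keyword arguments into an instance. The following is only a minimal, assumed sketch of such a helper for orientation; the real distsup.utils implementation may resolve modules and handle errors differently.

import importlib


def construct_from_kwargs(object_kwargs, default_module=None,
                          additional_parameters=None):
    """Build an object from a config dict of the form used above (sketch)."""
    object_kwargs = dict(object_kwargs)  # do not mutate the caller's config
    class_path = object_kwargs.pop('class_name')
    if isinstance(class_path, str):
        module_name, _, class_name = class_path.rpartition('.')
        module = (importlib.import_module(module_name) if module_name
                  else default_module)
        klass = getattr(module, class_name)
    else:
        klass = class_path  # a class object was passed directly
    if additional_parameters:
        object_kwargs.update(additional_parameters)
    return klass(**object_kwargs)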
Example #2
    def __init__(self,
                 encoder,
                 bottleneck,
                 bottleneck_latent_dim,
                 latent_mixer,
                 decoder,
                 minimum_annealing_factor,
                 annealing_epochs,
                 cpc=None,
                 gru_hidden_dim=64,
                 **kwargs):
        super(RepresentationLearnerContinuousBN_CPC, self).__init__(**kwargs)
        self.encoder = utils.construct_from_kwargs(encoder)
        self.bottleneck = utils.construct_from_kwargs(
            bottleneck,
            additional_parameters={'latent_dim': bottleneck_latent_dim})
        self.latent_mixer = utils.construct_from_kwargs(
            latent_mixer,
            additional_parameters={'in_channels': bottleneck_latent_dim})
        self.latent_mixer.eval()
        mixer_out_channels = self.latent_mixer(
            torch.empty((1, 100, 1, bottleneck_latent_dim))).size(3)

        self.gru_input_dim = 256
        self.gru_hidden_dim = gru_hidden_dim
        if cpc is not None:
            self.cpc = utils.construct_from_kwargs(
                cpc,
                additional_parameters={
                    "input_dim": self.gru_input_dim,
                    "output_dim": None,
                    "gru_hidden_dim": gru_hidden_dim,
                },
            )
        else:
            self.cpc = None

        # even if CPC is disabled, run the GRU so results with/without CPC can be compared
        self.CPCgru = cpc_module.GruVariableLength(self.gru_input_dim,
                                                   self.gru_hidden_dim)

        self.decoder = utils.construct_from_kwargs(
            decoder, additional_parameters={'in_channels': mixer_out_channels})
        self.minimum_annealing_factor = minimum_annealing_factor
        self.annealing_epochs = annealing_epochs
        self.num_mini_batches = len(self.dataloader)
        self.batch_id = 0
        self.start_of_training = True
        self.pre_bottleneck = encoders.Identity()
        self.post_bottleneck = encoders.Identity()
        self.post_latent_mixer = encoders.Identity()
        self.input_layer = encoders.Identity()
        self.add_probes()
Example #3
    def __init__(self,
                 cond_channels=(),
                 image_height=32,
                 hid_channels=128,
                 normalization=dict(class_name=nn.BatchNorm2d),
                 quantizer=dict(class_name=quantizers.L1Loss)):
        super(UpsamplingConv2d, self).__init__()

        assert len(cond_channels) <= 2
        if len(cond_channels) == 1:
            assert normalization['class_name'] is torch.nn.BatchNorm2d
        else:
            assert normalization['class_name'] is not torch.nn.BatchNorm2d, (
                "You can't use batchnorm with {} conditionings".format(
                    len(cond_channels)))
            normalization['cond_channels'] = cond_channels[1]['cond_dim']
        self.quantizer = utils.construct_from_kwargs(quantizer)

        scale = cond_channels[0]['reduction_factor']
        dim_vq = cond_channels[0]['cond_dim']

        self.conv = nn.Conv1d(dim_vq,
                              image_height // scale * 64,
                              kernel_size=1)
        num_strided = int(math.log2(scale))

        self.decoder = [
            nn.Conv2d(64, hid_channels, kernel_size=5, padding=2),
            utils.construct_from_kwargs({
                'num_features': hid_channels,
                **normalization
            }),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for i in range(num_strided):
            self.decoder += [
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.Conv2d(hid_channels, hid_channels, kernel_size=5,
                          padding=2),
                utils.construct_from_kwargs({
                    'num_features': hid_channels,
                    **normalization
                }),
                nn.ReLU(inplace=True)
            ]
        self.decoder.append(
            nn.Conv2d(hid_channels, self.quantizer.num_levels, 1))
        self.decoder.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*self.decoder)
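
The cond_channels argument used by this decoder (and several below) is a sequence of small spec dicts; the first entry drives the upsampling and an optional second one feeds a conditional normalization layer instead of BatchNorm2d. The values below are purely illustrative of that format.

# Illustrative conditioning spec (assumed values, not from a real config):
cond_channels = (
    {'cond_dim': 64, 'reduction_factor': 4},    # latent codes, 4x shorter than the image
    # {'cond_dim': 16, 'reduction_factor': 0},  # optional extra conditioning -> conditional norm
)
scale = cond_channels[0]['reduction_factor']    # 4 -> int(math.log2(4)) = 2 strided upsamplings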
Example #4
 def __init__(self,
              image_height,
              reconstruction_channel=None,
              wave_net=dict(class_name=wavenet.WaveNet, ),
              quantizer=dict(class_name=quantizers.BinaryXEntropy),
              **kwargs):
     super(ColumnwiseWaveNet, self).__init__()
     self.quantizer = utils.construct_from_kwargs(quantizer)
     if 'in_channels' in wave_net or 'out_channels' in wave_net:
         raise ValueError('Channels for ColumnwiseWaveNet are auto set')
     wave_net['in_channels'] = image_height
     wave_net['out_channels'] = image_height * self.quantizer.num_levels
     wave_net.update(kwargs)
     self.wave_net = utils.construct_from_kwargs(wave_net)
     self.reconstruction_channel = reconstruction_channel
Example #5
    def __init__(self, dataset, chunk_len, feature_field="features", every_n=1):
        self.dataset = utils.construct_from_kwargs(dataset)
        self.chunk_len = chunk_len
        self.feature_field = feature_field
        self.alphabet = self.dataset.alphabet
        self.every_n = every_n

        self.item_id_shift = []

        self.all_targets = []
        for i in range(len(self.dataset)):
            self.item_id_shift.append(len(self.all_targets))
            item = self.dataset[i]
            text = item["text"]
            alignment_rle = item["alignment_rle"]
            data = item[self.feature_field]

            for unit_id, (start, end) in zip(text, alignment_rle):
                for pos in range(int(start), int(end + 1)):
                    if (
                        pos - self.chunk_len // 2 >= 0
                        and pos - self.chunk_len // 2 + self.chunk_len - 1 < len(data)
                    ):
                        self.all_targets.append(unit_id)

        self.all_targets = torch.tensor(self.all_targets)
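
The boundary test above guarantees that a chunk_len window centered on an aligned frame pos fits inside the utterance, so the actual extraction (shown here only as an illustrative fragment, not the project's __getitem__) reduces to a plain slice:

# Sketch: cut a chunk_len window centered on frame `pos` of `data`; the check
# in __init__ above ensures the slice never runs off either end.
start = pos - chunk_len // 2
chunk = data[start:start + chunk_len]
assert len(chunk) == chunk_len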
Example #6
    def __init__(self,
                 encoder,
                 num_classes=-1,
                 allow_too_long_transcripts=False,
                 alignment_name="",
                 forced_alignment=False,
                 verbose=0,
                 **kwargs):
        super(CTCModel, self).__init__(**kwargs)

        # Determine number of classes for output layer
        alternativeNumClasses = len(self.dataset.alphabet)
        if alternativeNumClasses > 0 and alternativeNumClasses != num_classes:
            print ("CTCModel __init__() override yaml num_classes (" + str(num_classes) + \
                    ") with dataset alphabet num_classes (" + str(alternativeNumClasses) + ")" )
            num_classes = len(self.dataset.alphabet)
        assert num_classes > 0

        # Determine from .yaml if we need to produce an output path (alignment)
        # As we may need alignments for MNIST and ScribbleLens, we put the aligner in the model, not dataset
        self.aligner = None
        if alignment_name != "":
            if forced_alignment:
                self.aligner = aligner.ForcedAligner(alignment_name,
                                                     self.dataset.alphabet)
            else:
                self.aligner = aligner.Aligner(alignment_name,
                                               self.dataset.alphabet)

        self.encoder = utils.construct_from_kwargs(encoder)
        self.verbose = verbose
        self.projection = torch.nn.Linear(self.encoder.output_dim, num_classes)
        self.ctc = torch.nn.CTCLoss(blank=0,
                                    reduction='sum',
                                    zero_infinity=allow_too_long_transcripts)
Example #7
    def __init__(
        self,
        dataset,
        varlen_fields=("features", "targets", "alignment"),
        rename_dict=None,
        collate_fn=None,
        ratio=None,
        **kwargs,
    ):
        assert not collate_fn
        if rename_dict is None:
            rename_dict = {}
        dataset = utils.construct_from_kwargs(dataset)
        collate_fn = PaddingCollater(
            varlen_fields=varlen_fields, rename_dict=rename_dict
        )
        sampler = None
        if ratio is not None:
            sampler = get_partial_sampler(dataset, ratio)
        super(PaddedDatasetLoader, self).__init__(
            dataset=dataset, collate_fn=collate_fn, sampler=sampler, **kwargs
        )

        self.metadata = {
            rename_dict.get(k, k): v for k, v in self.dataset.metadata.items()
        }
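
PaddingCollater itself is not shown here; the sketch below captures the assumed behaviour (pad the variable-length fields to the batch maximum, record their lengths, and apply rename_dict) and may differ from distsup's implementation, for instance in how the length fields are named.

import torch


def pad_and_collate(batch, varlen_fields, rename_dict):
    """Rough stand-in for a padding collater (assumption, for illustration)."""
    out = {}
    for key in batch[0]:
        new_key = rename_dict.get(key, key)
        values = [item[key] for item in batch]
        if key in varlen_fields:
            out[new_key + '_len'] = torch.tensor([v.shape[0] for v in values])
            out[new_key] = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
        else:
            out[new_key] = torch.utils.data.default_collate(values)
    return out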
Example #8
 def __init__(self,
              image_height,
              cond_channels,
              stack=dict(class_name=convolutional.ConvStack1D,
                         num_postproc=2),
              out_channels=1,
              reconstruction_channel=None,
              **kwargs):
     """
     Args:
         image_heigth: image_heigth of the reconstruction
         cond_channels: simensionality of conditioning
         stack: the rec stack to use
         out_channels: number of output channels
         reconstruction_channel: limit the reconsturction to only this
                                 channel. Forces out_channels = 1
     """
     super(StackReconstructor, self).__init__(**kwargs)
     self.reconstruction_channel = reconstruction_channel
     if self.reconstruction_channel is not None:
         assert out_channels == 1
     in_channels = sum([c['cond_dim'] for c in cond_channels])
     stack['in_channels'] = in_channels
     self.stack = utils.construct_from_kwargs(stack)
     self.stack.eval()
     stack_out_shape = self.stack(torch.empty(
         (1, 100, 1, in_channels))).size()
     self.proj = nn.Linear(
         stack_out_shape[-1],
         image_height * out_channels * self.quantizer.num_levels)
Example #9
 def __init__(self,
              reconstruction_channel=None,
              quantizer=dict(class_name=quantizers.SoftmaxUniformQuantizer,
                             num_levels=4),
              **kwargs):
     super(BasePixelCNN, self).__init__(**kwargs)
     self.quantizer = utils.construct_from_kwargs(quantizer)
     self.reconstruction_channel = reconstruction_channel
Example #10
 def _getitem(self, key, additional_parameters=None):
     if key not in self.cache:
         # make a copy since we may change the dict in the end
         opts = dict(get_val(self.objects_config, key, self.name))
         if 'class_name' not in opts:
             opts['class_name'] = self.default_class_dict[key]
         self.cache[key] = utils.construct_from_kwargs(
             opts, self.default_modules_dict.get(key),
             additional_parameters)
     return self.cache[key]
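
For reference, this is the caching behaviour the factory gives you (the factory object and its configuration are hypothetical):

first = factory._getitem('encoder')    # constructed via construct_from_kwargs
second = factory._getitem('encoder')   # served straight from self.cache
assert first is second
# Note: additional_parameters only matter on the call that actually builds the
# object; once it is cached, later values are silently ignored.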
Example #11
    def __init__(self,
                 input_dim,
                 image_height,
                 num_layers,
                 in_channels=64,
                 out_channels=1,
                 hid_channels=64,
                 use_sigmoid=True,
                 quantizer=dict(class_name=quantizers.L1Loss),
                 normalization=dict(class_name=torch.nn.BatchNorm2d),
                 **kwargs):
        super(DownsamplingDecoder2D, self).__init__(**kwargs)
        self.quantizer = utils.construct_from_kwargs(quantizer)

        self.image_height = image_height
        self.hid_channels = hid_channels
        self.conv_input_dim = input_dim // image_height
        self.conv = nn.Conv1d(input_dim, 64 * image_height, kernel_size=1)

        conv_stack = []
        for _ in range(num_layers - 1):
            conv_stack += [
                DownsamplingResBlock(hid_channels, hid_channels, hid_channels),
                utils.construct_from_kwargs({
                    'num_features': hid_channels,
                    **normalization
                }),
                nn.ReLU(inplace=True)
            ]

        conv_stack += [
            nn.Conv2d(hid_channels,
                      out_channels * self.quantizer.num_levels,
                      stride=1,
                      kernel_size=3,
                      padding=1)
        ]
        if use_sigmoid:
            conv_stack += [nn.Sigmoid()]

        self.conv_stack = nn.Sequential(*conv_stack)
        self.apply(utils.conv_weights_xavier_init)
Example #12
    def __init__(
            self,
            condition_on_alignment=False,
            num_alignment_classes=11,  # For MNIST
            image_height=28,
            cond_mixer=dict(class_name=misc.IdentityForgetKWargs),
            reconstructor={
                # name: dict
                '': dict(class_name=reconstructors.ColumnGatedPixelCNN, ),
            },
            **kwargs):
        super(Autoregressive2D, self).__init__(**kwargs)
        self.condition_on_alignment = condition_on_alignment
        if self.condition_on_alignment:
            self.num_alignment_classes = num_alignment_classes
            self.cond_mixer = utils.construct_from_kwargs(cond_mixer)
            self.cond_mixer.eval()
            mixer_out_channels = self.cond_mixer(
                torch.empty((1, 100, 1, num_alignment_classes))).size(3)
            self.cond_channels = ({
                'cond_dim': mixer_out_channels,
                'reduction_factor': 1
            }, )
        else:
            self.cond_channels = ()

        rec_params = {
            'image_height': image_height,
            'cond_channels': self.cond_channels
        }
        # Compatibility with single-reconstructor checkpoints
        if 'class_name' in reconstructor:
            self.reconstructor = utils.construct_from_kwargs(
                reconstructor, additional_parameters=rec_params)
            self.reconstructors = {'': self.reconstructor}
        else:
            self.reconstructors = nn.ModuleDict({
                name:
                utils.construct_from_kwargs(rec,
                                            additional_parameters=rec_params)
                for name, rec in reconstructor.items()
            })
Example #13
 def test_all(self):
     conf = {
         'class_name': 'distsup.data.CachedDataset',
         'real_class_name': 'distsup.tests.test_data.DummyDataset',
     }
     ds = utils.construct_from_kwargs(conf)
     assert ds[0]['num'] == 0
     assert len(ds) == len(ds._wrapped) == 10
     for i in range(3):
         ds[1]
     assert ds.counts[0] == ds.counts[1] == 1
     assert ds.counts[2] == ds.counts[3] == 0
     assert ds.get_counts(0) == 1
     assert ds.count_prop[0] == 1
Example #14
 def __init__(
     self,
     dataset,
     chunk_len,
     varlen_fields=("features",),
     drop_fields=(),
     training=False,
     transform=None,
     oversample=1,
 ):
     self.dataset = utils.construct_from_kwargs(dataset)
     self.chunk_len = chunk_len
     self.varlen_fields = varlen_fields
     self.drop_fields = set(drop_fields)
     self.training = training
     if transform:
         self.transform = utils.construct_from_kwargs(transform)
     else:
         self.transform = None
     if not training:
         self.transform = None
     self.oversample = oversample
     assert self.training or self.oversample == 1
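
A __len__ consistent with the oversample bookkeeping above could look like the sketch below (an assumption, not the project's code); visiting each item oversample times only makes sense when random chunk positions are drawn during training, which is what the final assert enforces.

def __len__(self):
    # each underlying item is presented `oversample` times per epoch
    return len(self.dataset) * self.oversample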
Example #15
    def __init__(
        self, dataset, field_names=("features", "targets"), rename_dict=None,
        ratio=None, **kwargs
    ):
        self.field_names = field_names
        self.rename_dict = rename_dict or {}
        dataset = utils.construct_from_kwargs(dataset)
        sampler = None
        if ratio is not None:
            sampler = get_partial_sampler(dataset, ratio)
        super(FixedDatasetLoader, self).__init__(dataset=dataset,
                sampler=sampler, **kwargs)

        if hasattr(self.dataset, "metadata"):
            self.metadata = {
                self.rename_dict.get(k, k): v for k, v in self.dataset.metadata.items()
            }
Example #16
    def __init__(
            self,
            embedding_dim,
            image_height=28,
            len_reduction=4,
            reconstructor={
                "class_name":
                "distsup.modules.reconstructors.ColumnGatedPixelCNN",
                "quantizer": {
                    "class_name":
                    "distsup.modules.quantizers.SoftmaxUniformQuantizer",
                    "num_levels": 16,
                },
            },
            device="cpu",
            ignore_alignment=False,
            count_blanks_alignment=True,
            advantage_digits=True,
            alignment_noise_pbb=None,
            **kwargs):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.image_height = image_height
        self.len_reduction = len_reduction
        cond_channels_spec = [{
            "cond_dim": embedding_dim,
            "reduction_factor": len_reduction
        }]
        rec_params = {
            "image_height": image_height,
            "cond_channels": cond_channels_spec
        }

        self.device = device
        self.reconstructor = utils.construct_from_kwargs(
            reconstructor, additional_parameters=rec_params).to(device)

        self.ignore_alignment = ignore_alignment
        self.count_blanks_alignment = count_blanks_alignment
        self.advantage_digits = advantage_digits
        self.alignment_noise_pbb = alignment_noise_pbb
        self.mask = torch.ones(embedding_dim).long().to(device)
        if not self.count_blanks_alignment:
            self.mask[0] = 0
Example #17
    def __init__(self,
                 encoder,
                 num_classes=None,
                 allow_too_long_transcripts=False,
                 adv_size=0,
                 **kwargs):
        super(SimpleCTCModel, self).__init__(**kwargs)
        if num_classes is None:
            num_classes = self.dataloader.metadata['targets']['num_categories']
        self.encoder = utils.construct_from_kwargs(encoder)
        self.projection = torch.nn.Linear(self.encoder.output_dim, num_classes)
        self.ctc = torch.nn.CTCLoss(reduction='sum',
                                    zero_infinity=allow_too_long_transcripts)
        self.adversarial = None

        if adv_size != 0:
            self.adversarial = Adversarial(GlobalPredictor(
                self.encoder.output_dim, adv_size, aggreg=10),
                                           mode='maxent')
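
For context, the projection and torch.nn.CTCLoss set up above are typically combined as in the sketch below; the batch field names are assumptions, and the input lengths must describe the time axis after any length reduction done by the encoder.

def compute_ctc_loss(self, features, features_len, targets, targets_len):
    hidden = self.encoder(features)                       # (B, T', encoder.output_dim)
    log_probs = self.projection(hidden).log_softmax(-1)   # (B, T', num_classes)
    # torch.nn.CTCLoss expects (T, B, C) log-probabilities
    return self.ctc(log_probs.transpose(0, 1), targets, features_len, targets_len)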
Example #18
    def __init__(self,
                 classifier,
                 temperature,
                 frame_pair_batch_size=10000,
                 scoring_map=None,
                 lambda_fs=0.5,
                 verbose=0,
                 **kwargs):
        """
        Args:
            classifier: classifier that maps batches of feature chunks
                to output logits. See DistributionMatchingClassifier.
            temperature: temperature used in the distribution matching estimation
            frame_pair_batch_size: batch size used for the secondary 
                (frame similarity) loss.
            scoring_map: (optional) used for mapping output characters to
                other characters, for scoring purposes (needed for the
                stupid TIMIT evaluation scheme).      
        """
        super(DistributionMatchingModel, self).__init__(**kwargs)
        self.classifier = utils.construct_from_kwargs(classifier)
        self.alphabet = self.dataset.alphabet
        self.lm_order = self.dataset.order

        self.lm_probs = self.dataset.lm_probs
        self.lm_probs_ngrams = self.dataset.lm_probs_ngrams

        if Globals.cuda:
            self.lm_probs = self.lm_probs.to('cuda')
            self.lm_probs_ngrams = self.lm_probs_ngrams.to('cuda')

        self.temperature = temperature
        self.verbose = verbose
        self.lambda_fs = lambda_fs
        self.pair_batches = self.dataset.get_frame_pair_iter(
            frame_pair_batch_size)

        self.classifier.init_output_biases(
            self.dataset.output_frequencies.log())

        self.scoring_map = None
        if scoring_map is not None:
            self.scoring_map = read_scoring_map(scoring_map)
Example #19
 def __init__(self, real_class_name, **kwargs):
     super(CachedDataset, self).__init__()
     self._wrapped = utils.construct_from_kwargs(
         {"class_name": real_class_name}, additional_parameters=kwargs
     )
     self._cache = [None] * len(self._wrapped)
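
A plausible companion __getitem__/__len__ for this wrapper (a sketch; the real CachedDataset may differ) fills self._cache lazily, which is exactly what the unit test in Example #13 relies on:

def __getitem__(self, idx):
    if self._cache[idx] is None:
        self._cache[idx] = self._wrapped[idx]
    return self._cache[idx]


def __len__(self):
    return len(self._wrapped)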
Example #20
    def __init__(self,
                 image_height,
                 cond_channels=(),
                 stride=2,
                 hid_channels=256,
                 use_sigmoid=True,
                 use_pixelshuffle=True,
                 resblocks=2,
                 out_channels=1,
                 quantizer=dict(class_name=quantizers.L1Loss),
                 normalization=dict(class_name=torch.nn.BatchNorm2d),
                 **kwargs):
        super(Decoder_2d, self).__init__(**kwargs)
        self.quantizer = utils.construct_from_kwargs(quantizer)

        assert len(cond_channels) <= 2
        if len(cond_channels) == 1:
            assert normalization['class_name'] is torch.nn.BatchNorm2d
        else:
            assert normalization['class_name'] is not torch.nn.BatchNorm2d
            normalization['cond_channels'] = cond_channels[1]['cond_dim']

        scale = cond_channels[0]['reduction_factor']
        dim_vq = cond_channels[0]['cond_dim']

        self.conv = nn.Conv1d(dim_vq, (image_height + scale - 1) // scale * 64,
                              kernel_size=1)

        num_strided = int(math.log(scale) / math.log(stride))
        padding, output_padding = [], []
        if stride == 3:
            for i in range(num_strided):
                padding.append(0)
                output_padding.append(0)
        elif stride == 2:
            for i in range(num_strided):
                if i % 2 == 0:
                    padding.append(1)
                    output_padding.append(0)
                else:
                    padding.append(0)
                    output_padding.append(1)
            if num_strided % 2 == 1:
                output_padding[-1] = 1
        if num_strided == 0:
            decoder = [
                dict(in_channels=64,
                     out_channels=out_channels * self.quantizer.num_levels,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                     bias=True,
                     output_padding=0)
            ]
        else:
            decoder = [
                dict(in_channels=64,
                     out_channels=hid_channels,
                     padding=2,
                     kernel_size=5,
                     stride=1,
                     bias=False),
                *[
                    dict(in_channels=hid_channels,
                         out_channels=hid_channels if idx != num_strided - 1
                         else out_channels * self.quantizer.num_levels,
                         kernel_size=3,
                         padding=padding[idx],
                         output_padding=output_padding[idx],
                         stride=stride,
                         bias=False if idx != num_strided - 1 else True)
                    for idx in range(num_strided)
                ],
            ]

        self.decoder = []
        for conv in decoder[:-1]:
            #if len(self.decoder) > 0:
            #    self.decoder.extend([ResBlock(1, hid_channels, hid_channels)] * resblocks)
            if (not use_pixelshuffle):
                self.decoder += [
                    nn.ConvTranspose2d(**conv),
                    utils.construct_from_kwargs({
                        'num_features':
                        conv['out_channels'],
                        **normalization
                    }),
                    nn.ReLU(inplace=True)
                ]
            else:
                if conv['kernel_size'] == 5:
                    self.decoder += [
                        ResBlock(stride * stride, 64 // (stride * stride),
                                 hid_channels // (stride * stride)),
                        utils.construct_from_kwargs({
                            'num_features':
                            conv['out_channels'],
                            **normalization
                        }),
                        nn.ReLU(inplace=True)
                    ]
                else:
                    self.decoder += [
                        #nn.Conv2d(hid_channels, hid_channels * (stride * stride), kernel_size=3, padding=1),
                        ResBlock(stride * stride,
                                 hid_channels // (stride * stride),
                                 hid_channels),
                        nn.PixelShuffle(stride),
                        utils.construct_from_kwargs({
                            'num_features':
                            conv['out_channels'],
                            **normalization
                        }),
                        nn.ReLU(inplace=True)
                    ]
        if use_pixelshuffle:
            self.decoder.append(
                nn.Conv2d(hid_channels,
                          out_channels * self.quantizer.num_levels *
                          (stride * stride),
                          kernel_size=3,
                          padding=1))
            self.decoder.append(nn.PixelShuffle(stride))
        else:
            self.decoder.append(nn.ConvTranspose2d(**decoder[-1]))
        if use_sigmoid:
            self.decoder.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*self.decoder)

        self.apply(utils.conv_weights_xavier_init)
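
A quick worked example of the stride bookkeeping above, for a typical configuration:

# scale = 4, stride = 2 (illustrative numbers):
#   num_strided    = int(math.log(4) / math.log(2)) = 2
#   padding        = [1, 0]        (i = 0 even, i = 1 odd)
#   output_padding = [0, 1]
# i.e. the latent sequence is upsampled twice by a factor of 2, recovering an
# image height of roughly ceil(image_height / scale) * scale.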
Example #21
 def __init__(self, quantizer=dict(class_name=quantizers.L1Loss), **kwargs):
     super(BaseReconstructor, self).__init__(**kwargs)
     self.quantizer = utils.construct_from_kwargs(quantizer)
Example #22
    def add_probes(self):
        """Attach the probes defined in self.probes_dict to the model.

        Each entry of self.probes_dict maps a probe name to a dict containing:
         - layer: str, the name of the member variable on whose output to run the predictor,
         - target: str, the field of the batch where to find the target that needs to be predicted,
         - predictor: dict describing the predictor to be added,
         - **kwargs: other kwargs of the probe class.
        """

        probes_dict = self.probes_dict
        self.probes_dict = {}
        self.probes = torch.nn.ModuleDict()

        def _register_output_shape(module, _, mod_output, name):
            if isinstance(mod_output, (list, tuple)):
                mod_output = mod_output[0]

            if isinstance(mod_output, torch.Tensor):
                module.output_shape = mod_output.shape

            else:
                mod_logger.warning(f'Could not gather shape of module {name} / output #0')

        probe_hooks = []
        for name, m in self.named_modules():
            probe_hooks.append(m.register_forward_hook(lambda mod, in_mod, out_mod, name=name:
                                                       _register_output_shape(mod, in_mod, out_mod, name)))

        test_batch, _ = self.test_batch_forward()

        for hook in probe_hooks:
            hook.remove()

        # Attach each probe to its place and build the predictors
        for probe_name, probe_dict in probes_dict.items():
            required_keys = {'layer', 'target', 'predictor'}
            if required_keys - probe_dict.keys():
                mod_logger.error(f"Expecting the following keys in the '{probe_name}' probe"
                                 f" dictionary: {required_keys}. Only found: {probe_dict.keys()}. "
                                 f"Ignoring probe {probe_name}.")
                continue
            all_known_keys = required_keys | {'bp_to_main', 'learning_rate',
                                              'which_out', 'requires'}
            if set(probe_dict.keys()) - all_known_keys:
                raise Exception(f'Probe {probe_name} has unsupported arguments: '
                                f'{set(probe_dict.keys()) - all_known_keys}.')

            mod_logger.info(f'Adding probe {probe_name} to {probe_dict["layer"]}.')

            layer = self.get_named_module(probe_dict['layer'], probe_name)

            if probe_dict['target'] not in test_batch:
                mod_logger.error(f'Could not find target named {probe_dict["target"]} in test batch. '
                                 f'Available keys are: {test_batch.keys()}. '
                                 f'Ignoring probe {probe_name}.')
                continue

            if len(layer.output_shape) != 4:
                mod_logger.error(f'Expecting the data layout of the layer to be B x W x H x C. '
                                 f'The last two dimensions will be flattened when fed to the probe. '
                                 f'Currently obtained {layer.output_shape}.')

            input_dim = layer.output_shape[-2] * layer.output_shape[-1]

            additional_predictor_parameters = {'input_dim': input_dim}

            if 'requires' in probe_dict:
                try:
                    # Recursively retrieve, e.g. self.bottleneck.num_tokens
                    val = reduce(lambda obj,attr: getattr(obj, attr),
                                 probe_dict['requires'].split('.'), self)
                    name = probe_dict['requires'].split('.')[-1]
                    additional_predictor_parameters.update({name: val})
                except AttributeError:
                    logger.warning(f"{probe_name} disabled; "
                                   f"{probe_dict['requires']} not available")
                    continue

            dataloader = self.dataloader
            if (hasattr(dataloader, 'metadata') and
                    (probe_dict['target'] in dataloader.metadata) and
                    (dataloader.metadata[probe_dict['target']]['type'] == 'categorical')):
                target_size = dataloader.metadata[probe_dict['target']]['num_categories']
                additional_predictor_parameters['output_dim'] = target_size

            predictor_dict = probe_dict['predictor']
            predictor = utils.construct_from_kwargs(predictor_dict,
                                                    additional_parameters=additional_predictor_parameters)
            probe = attach_auxiliary(layer,
                                     predictor,
                                     bp_to_main=probe_dict.get('bp_to_main', False),
                                     which_out=probe_dict.get('which_out', 0))

            self.probes_dict[probe_name] = probe_dict
            self.probes[probe_name] = probe
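
An example of the probe specification this method consumes (class names and values are illustrative, not the project's exact modules):

probes_dict = {
    'bottleneck_to_chars': {
        'layer': 'post_bottleneck',    # module whose output is probed
        'target': 'alignment',         # batch field the probe should predict
        'predictor': {'class_name': 'distsup.modules.predictors.FramewisePredictor'},
        'bp_to_main': False,           # keep probe gradients out of the main model
    },
}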
Example #23
 def __init__(self, model, prob=0.12, **kwargs):
     super(Jitter, self).__init__()
     self.model = utils.construct_from_kwargs(model,
                                              additional_parameters=kwargs)
     self.prob = prob
Example #24
    def run(self, save_dir, model, train_dataset, eval_datasets=None,
            saved_state=None, debug_skip_training=False,
            probe_train_data=None):
        if saved_state:
            model.load_state_dict(saved_state['state_dict'])
            for k in saved_state:
                if k.startswith('avg_state_dict'):
                    print("Loading poyak's ", k)
                    setattr(model, k, saved_state[k])
        if eval_datasets is None:
            eval_datasets = {}
        if Globals.cuda:
            model.cuda()
            GPUs = [f"{i}) {torch.cuda.get_device_name(i)}"
                    for i in range(torch.cuda.device_count())]
            print(f"Trainer using GPUs: {','.join(GPUs)}.")
        else:
            print("Trainer not using GPU.")
        proto = getattr(torch.optim, self.optimizer_name)

        if self.codebook_lr is not None:
            optimizer = proto(
                [{'params': model.bottleneck.embedding.parameters(),
                  'lr': self.codebook_lr},
                 {'params': model.get_parameters_for_optimizer(with_codebook=False),
                  'lr': self.learning_rate}],
                lr=self.learning_rate, **self.optimizer_kwargs)
        else:
            optimizer = proto(
                model.get_parameters_for_optimizer(with_codebook=True),
                lr=self.learning_rate, **self.optimizer_kwargs)

        self.lr_scheduler_params['optimizer'] = optimizer
        lr_scheduler = utils.construct_from_kwargs(self.lr_scheduler_params)
        if saved_state:
            optimizer.load_state_dict(saved_state['optimizer'])
            lr_scheduler.load_state_dict(saved_state['lr_scheduler'])
        print(f"Optimizer: {optimizer}")

        if saved_state:
            self.current_iteration = saved_state['current_iteration']
            start_epoch = saved_state['epoch'] + 1
        else:
            self.current_iteration = 0
            start_epoch = 1

        self.checkpointer.set_save_dir(save_dir)

        if self.log_layers_stats:
            self.dbg = DebugStats.attach(model, logger)

        for epoch in range(start_epoch, self.num_epochs+1):
            Globals.epoch = epoch
            self.iterate_epoch(
                epoch, save_dir, model, train_dataset, eval_datasets,
                optimizer, lr_scheduler,
                debug_skip_training=debug_skip_training)

        if probe_train_data is not None:
            print(f"re-train all probes")
            probe_parameters = []
            for p in model.parameters():
                p.requires_grad = False
            for _, probe in model.probes.items():
                print(probe)
                for name, layer in probe.named_modules():
                    has_parameters = False
                    for param_name, param in layer.named_parameters():
                        if "." not in param_name:
                            has_parameters = True
                    if not has_parameters:
                        continue
                    if hasattr(layer, 'reset_parameters'):
                        layer.reset_parameters()
                    else:
                        print("WARNING: skip layer {0}".format(name))
                probe_parameters.extend(list(probe.parameters()))
                for name, param in probe.named_parameters():
                    print("re-train {0}".format(name))
                    param.requires_grad = True
            optimizer = proto(probe_parameters,
                              lr=self.learning_rate, **self.optimizer_kwargs)
            print("learning rate set to {0}".format(self.learning_rate))
            self.lr_scheduler_params['optimizer'] = optimizer
            lr_scheduler = utils.construct_from_kwargs(self.lr_scheduler_params)
            tmp = probe_train_data
            self.checkpointer.enabled = False

            logger.end_log()
            for epoch in range(1, 11):
                Globals.epoch = epoch
                #with torch.backends.cudnn.flags(enabled=False):
                self.iterate_epoch(
                    epoch, save_dir + "/probe_train/", model, probe_train_data, eval_datasets,
                    optimizer, lr_scheduler,
                    debug_skip_training=debug_skip_training,
                    only_train_probes=True)
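
For reference, the names resolved above map onto plain PyTorch objects; the values below are illustrative:

proto = getattr(torch.optim, 'Adam')                    # self.optimizer_name
lr_scheduler_params = {
    'class_name': 'torch.optim.lr_scheduler.StepLR',    # injected with the optimizer at run time
    'step_size': 10,
    'gamma': 0.5,
}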
Example #25
    def __init__(
            self,
            root='data/scribblelens.corpus.v1.zip',
            dataframe_filename='data/scribblelens.corpus.v1.pkl',
            alignment_root="",  # Default empty i.e. unused
            split=None,
            slice=None,  # tasman, kieft, brouwers
            slice_query=None,
            slice_filename=None,
            colormode='bw',
            vocabulary="",  # The alphabet filename in json format
            vocabulary_query=None,
            write_vocabulary=False,
            transcript_mode=2,
            target_height=32,
            target_width=-1,
            transform=None):
        """
        Args:
            root (string): Root directory of the dataset.
            alignmentRoot (string): Root directory of the path alignments. There should be one .ali file per image.
            split (string): The subset of data to provide.
                Choices are: train, test, supervised, unsupervised.
            slice_filename (string): Don't use existing slice and use a custom slice from a filename. The file
                should use the same format as in the dataset.
            colormode (string): The color of data to provide.
                Choices are: bw, color, gray.
            alphabet (dictionary): Pass in a pre-build alphabet from external source, or build during training if empty
            transcript_mode(int): Defines how we process space in target text, and blanks in targets [1..5]
            target_height (int, None): The height in pixels to which to resize the images.
                Use None for the original size, -1 for proportional scaling
            target_width (int, None): The width in pixels to which to resize the images.
                Use None for the original size, -1 for proportional scaling
            transform (callable, optional): Optional transform to be applied
                on a sample.
        Note:
            The alphabet or vocabulary needs to be generated with an extra tool like
                     generateAlphabet.py egs/scribblelens/yamls/tasman.yaml
        """
        self.root = root
        self.write_vocabulary = write_vocabulary
        self.vocabulary_query = vocabulary_query

        self.file = zipfile.ZipFile(root)

        self.target_width = target_width
        self.target_height = target_height
        if transform:
            self.pre_transform = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                construct_from_kwargs(
                    transform, additional_parameters={'scribblelens': True}),
                torchvision.transforms.ToPILImage(),
            ])
        else:
            self.pre_transform = None

        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Grayscale(),
            torchvision.transforms.ToTensor(),
        ])

        logging.debug(f"ScribbleLensDataset() constructor for split = {split}")

        self.dataframe_filename = dataframe_filename
        df = pd.read_pickle(self.dataframe_filename)
        df['alignment'] = np.nan
        df['alignment_rle'] = np.nan
        df['alignment_text'] = np.nan
        df['text'] = np.nan

        df['alignment'] = df['alignment'].astype(object)
        df['alignment_rle'] = df['alignment_rle'].astype(object)
        df['alignment_text'] = df['alignment_text'].astype(object)
        df['text'] = df['text'].astype(object)

        # 'vocabulary' Filename from .yaml. alphabet has the vocabulary as a dictionary for CTC output targets.
        self.transcriptMode = transcript_mode
        assert (1 <= self.transcriptMode <= 5)

        self.vocabulary = vocabulary
        if self.vocabulary != "" and not os.path.isfile(self.vocabulary):
            logging.error(
                f"You specified a vocabulary that does not exist: {self.vocabulary}"
            )
            sys.exit(4)

        self.alphabet = Alphabet(self.vocabulary)
        self.vocab_size = len(self.alphabet)
        self.must_create_alphabet = ((self.vocabulary == '')
                                     or self.write_vocabulary)
        self.nLines = 0

        authors = {
            'tasman', 'zeewijck', 'brouwer.chili', 'craen.de.vos.ijszee',
            'van.neck.tweede', 'van.neck.vierde', 'kieft'
        }

        assert (colormode in {'bw', 'color', 'gray'})
        assert (split is None or split
                in {'all', 'train', 'test', 'supervised', 'unsupervised'})
        assert (slice is None
                or slice in ({'empty', 'query', 'custom'} | authors))
        assert target_height != -1 or target_width != -1

        self.slice = slice
        self.split = split

        if self.must_create_alphabet:
            if self.vocabulary_query is None:
                self.vocabulary_query = "split == 'train'"
                logging.warning(
                    f'The vocabulary query has not been set. Setting it to "{self.vocabulary_query}"'
                )

            # Build the alphabet using the training data only
            self.build_alphabet(df.query(self.vocabulary_query), self.alphabet,
                                transcript_mode)

        # Select the data

        if slice_filename is not None and self.slice != 'custom':
            logging.error(f'Slice filename set to "{slice_filename}" '
                          f'yet slice is not "custom" but "{self.slice}".')
            sys.exit(1)

        if slice_query is not None and self.slice != 'query':
            logging.error(f'Slice query set to "{slice_query}" '
                          f'yet slice is not "query" but "{self.slice}".')
            sys.exit(1)

        df_selection = df
        if self.slice == 'query':
            # Select with query
            df_selection = df.query(slice_query)

        elif self.slice == 'custom':
            assert slice_filename is not None
            # Select with query
            with open(slice_filename) as f:
                img_filenames = [
                    l.strip().split()[0] for l in f.read().splitlines()
                    if l.strip()
                ]

            df_selection = df[df['img_filename'].isin(set(img_filenames))]

        elif self.slice in authors:
            df_selection = df[df['scribe'] == self.slice]

        elif self.slice is None:
            pass

        else:
            raise ValueError(f'Slice "{self.slice}" not available')

        if self.split is not None:
            if self.split == 'supervised':
                df_selection = df_selection[df_selection['transcribed']]

            elif self.split == 'unsupervised':
                df_selection = df_selection[~df_selection['transcribed']]

            elif self.split in {'train', 'test'}:
                df_selection = df_selection[df_selection['split'] ==
                                            self.split]

            else:
                raise ValueError(f'Split "{self.split}" not available.')

        self.get_transcriptions(df_selection, self.alphabet, transcript_mode)

        self.get_alignments(df_selection, alignment_root)

        self.metadata = {
            'alignment': {
                'type': 'categorical',
                'num_categories': len(self.alphabet)
            },
            'text': {
                'type': 'categorical',
                'num_categories': len(self.alphabet)
            },
        }

        self.file.close()
        self.file = None

        self.df = df_selection
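
A hypothetical configuration for this dataset as it could appear inside a construct_from_kwargs call (module path and file names are assumptions):

dataset_config = {
    'class_name': 'distsup.data.ScribbleLensDataset',
    'split': 'train',
    'slice': 'query',
    'slice_query': "scribe == 'tasman'",
    'vocabulary': 'data/scribblelens.vocab.json',
    'target_height': 32,
}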
Example #26
    def __init__(
            self,
            image_height=28,
            in_channels=1,
            encoder=dict(class_name=convolutional.ConvStack1D,
                         hid_channels=64,
                         num_strided=2,
                         num_dilated=2,
                         num_postdil=3),
            bottleneck=dict(class_name=bottlenecks.VQBottleneck,
                            num_tokens=16),
            bottleneck_latent_dim=64,
            latent_mixer=dict(class_name=convolutional.ConvStack1D,
                              hid_channels=64,
                              num_dilated=2,
                              num_postdil=2),
            reconstructor={
                # name: dict
                '': dict(class_name=reconstructors.ColumnGatedPixelCNN, ),
            },
            reconstructor_field=None,
            side_info_encoder=None,
            bottleneck_cond=None,
            **kwargs):
        super(RepresentationLearner, self).__init__(**kwargs)
        self.encoder = utils.construct_from_kwargs(encoder,
                                                   additional_parameters={
                                                       'in_channels':
                                                       in_channels,
                                                       'image_height':
                                                       image_height
                                                   })
        # prevent affecting the encoder by the dummy minibatch
        self.encoder.eval()
        enc_out_shape = self.encoder(
            torch.empty((1, 500, image_height, in_channels))).size()

        self.bottleneck = utils.construct_from_kwargs(
            bottleneck,
            additional_parameters=dict(in_dim=enc_out_shape[2] *
                                       enc_out_shape[3],
                                       latent_dim=bottleneck_latent_dim))

        self.latent_mixer = utils.construct_from_kwargs(
            latent_mixer,
            additional_parameters={'in_channels': bottleneck_latent_dim})
        # prevent affecting the latent_mixer by the dummy minibatch
        self.latent_mixer.eval()
        mixer_out_channels = self.latent_mixer(
            torch.empty((1, 500, 1, bottleneck_latent_dim))).size(3)

        cond_channels_spec = [{
            'cond_dim': mixer_out_channels,
            'reduction_factor': self.encoder.length_reduction
        }]

        self.side_info_encoder = None
        if side_info_encoder is not None:
            self.side_info_encoder = utils.construct_from_kwargs(
                side_info_encoder)
            cond_channels_spec.append({
                'cond_dim':
                side_info_encoder['embedding_dim'],
                'reduction_factor':
                0
            })
        self.bottleneck_cond = lambda x: None
        if bottleneck_cond is not None:
            self.bottleneck_cond = utils.construct_from_kwargs(bottleneck_cond)

        rec_params = {
            'image_height': image_height,
            'cond_channels': cond_channels_spec
        }

        # Compatibility with single-reconstructor checkpoints
        if 'class_name' in reconstructor:
            self.reconstructor = utils.construct_from_kwargs(
                reconstructor, additional_parameters=rec_params)
            self.reconstructors = {'': self.reconstructor}
        else:
            self.reconstructors = nn.ModuleDict({
                name:
                utils.construct_from_kwargs(rec,
                                            additional_parameters=rec_params)
                for name, rec in reconstructor.items()
            })

        if reconstructor_field is None:
            self.reconstructors_fields = [
                'features' for _ in self.reconstructors
            ]

        elif isinstance(reconstructor_field, str):
            self.reconstructors_fields = [
                reconstructor_field for _ in self.reconstructors
            ]

        elif isinstance(reconstructor_field, list):
            self.reconstructors_fields = reconstructor_field

        else:
            raise ValueError(
                f"'reconstructor_field' must be None, a str, or a list. Currently {reconstructor_field}"
            )

        assert len(self.reconstructors_fields) == len(self.reconstructors), \
            'The reconstructor_field parameter should have as many elements as there are reconstructors.'

        self.input_layer = encoders.Identity()
        self.add_probes()
Example #27
    def __init__(
            self,
            root='data/scribblelens.corpus.v1.2.zip',
            alignment_root="",  # Default empty i.e. unused
            split='supervised',
            slice='empty',  # tasman, kieft, brouwers
            slice_filename=None,
            colormode='bw',
            vocabulary="",  # The alphabet filename in json format
            transcript_mode=2,
            target_height=32,
            target_width=-1,
            transform=None):
        """
        Args:
            root (string): Root directory of the dataset.
            alignmentRoot (string): Root directory of the path alignments. There should be one .ali file per image.
            split (string): The subset of data to provide.
                Choices are: train, test, supervised, unsupervised.
            slice_filename (string): Don't use existing slice and use a custom slice from a filename. The file
                should use the same format as in the dataset.
            colormode (string): The color of data to provide.
                Choices are: bw, color, gray.
            alphabet (dictionary): Pass in a pre-build alphabet from external source, or build during training if empty
            transcript_mode(int): Defines how we process space in target text, and blanks in targets [1..5]
            target_height (int, None): The height in pixels to which to resize the images.
                Use None for the original size, -1 for proportional scaling
            target_width (int, None): The width in pixels to which to resize the images.
                Use None for the original size, -1 for proportional scaling
            transform (callable, optional): Optional transform to be applied
                on a sample.
        Note:
            The alphabet or vocabulary needs to be generated with an extra tool like
                     generateAlphabet.py egs/scribblelens/yamls/tasman.yaml
        """
        self.root = root

        self.file = zipfile.ZipFile(root)
        root = 'scribblelens.corpus.v1'

        self.target_width = target_width
        self.target_height = target_height
        if transform:
            self.pre_transform = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                construct_from_kwargs(
                    transform, additional_parameters={'scribblelens': True}),
                torchvision.transforms.ToPILImage(),
            ])
        else:
            self.pre_transform = None

        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Grayscale(),
            torchvision.transforms.ToTensor(),
        ])
        self.scribes = []

        self.trainingMode = (split == 'train')

        logging.debug(f"ScribbleLensDataset() constructor for split = {split}")

        # 'vocabulary' Filename from .yaml. alphabet has the vocabulary as a dictionary for CTC output targets.
        self.transcriptMode = transcript_mode
        assert (self.transcriptMode >= 1) and (self.transcriptMode <= 5)

        self.vocabulary = vocabulary
        if self.vocabulary != "" and not os.path.isfile(self.vocabulary):
            print("ERROR: You specified a vocabulary that does not exist: " +
                  str(self.vocabulary))
            sys.exit(4)
        self.alphabet = Alphabet(self.vocabulary)
        self.vocab_size = len(self.alphabet)
        self.must_create_alphabet = (self.vocabulary == '')
        self.nLines = 0

        self.alignmentFile = None  # Optional
        if alignment_root != "":
            self.alignmentFile = zipfile.ZipFile(alignment_root)
            alignment_root = os.path.basename(
                alignment_root)[:-4]  # Remove .zip ext
            self.pathAligner = distsup.aligner.Aligner(
                "none", self.alphabet)  # Path I/O

        assert (colormode in {'bw', 'color', 'gray'})
        assert (split in {'train', 'test', 'supervised', 'unsupervised'})
        assert (slice in {
            'empty', 'custom', 'tasman', 'zeewijck', 'brouwer.chili',
            'craen.de.vos.ijszee', 'van.neck.tweede', 'van.neck.vierde',
            'kieft'
        })
        assert target_height != -1 or target_width != -1

        fnm_pattern = r'(?P<item_dir>(.*?)/'\
                      r'(?P<scribe>[^/]*)/'\
                      r'((?P<page>\d+)(\.(?P<page_side>\d+))?/)?)'\
                      r'line(\.\D+)?(?P<line_index>\d+)\.jpg'
        fnm_matcher = re.compile(fnm_pattern)

        if slice_filename is not None and slice != 'custom':
            logging.error(
                'If you want a custom slice_filename you should use "custom" as slice.'
            )
            sys.exit(1)

        corpora_filenames = []
        if split == 'unsupervised':
            corpora_filenames = [
                os.path.join(root, 'corpora',
                             f'all.nl.{colormode}.lines.{split}.dat')
            ]
        elif split == 'supervised':
            for mode in ['train', 'test']:
                corpora_filenames.append(
                    os.path.join(root, 'corpora',
                                 f'nl.{colormode}.lines.{mode}.dat'))
        elif slice in {
                'tasman', 'kieft', 'zeewijck', 'brouwer.chili',
                'craen.de.vos.ijszee', 'van.neck.tweede', 'van.neck.vierde'
        }:
            corpora_filenames.append(
                os.path.join(root, 'corpora',
                             f'nl.{colormode}.lines.{slice}.{split}.dat'))
        elif slice != 'custom':
            corpora_filenames = [
                os.path.join(root, 'corpora',
                             f'nl.{colormode}.lines.{split}.dat')
            ]

        custom_corpora_filenames = []
        if slice == 'custom':
            custom_corpora_filenames.append(slice_filename)

        initialAlphabetSize = len(self.alphabet)

        self.data = []
        for corpora_filename in corpora_filenames:
            szBefore = len(self.alphabet)
            with self.file.open(corpora_filename) as f:
                self._read_corpora_filename(f, fnm_matcher, root,
                                            initialAlphabetSize, split)
                print("ScribbleLensDataset()  datafile: " +
                      str(corpora_filename) +
                      " and alphabet sizes before and after reading are " +
                      str(szBefore) + " and " + str(len(self.alphabet)))

        # FIXME: Ugly solution to using custom slices from external file
        for corpora_filename in custom_corpora_filenames:
            with open(corpora_filename, 'rb') as f:
                szBefore = len(self.alphabet)
                self._read_corpora_filename(f, fnm_matcher, "",
                                            initialAlphabetSize, split)
                print("ScribbleLensDataset() custom datafile: " +
                      str(corpora_filename) +
                      " and alphabet sizes before and after reading are " +
                      str(szBefore) + " and " + str(len(self.alphabet)))

        if self.vocabulary and self.must_create_alphabet:
            logging.warning(
                f'Vocabulary {self.vocabulary} not found. Serializing the newly generated alphabet...'
            )
            egs.scribblelens.utils.writeDictionary(self.vocabulary)

        self.data_frame = pd.DataFrame(self.data)
        self.data_frame['scribe_id'] = np.nan
        self.data_frame['scribe'] = np.nan

        # Obtain the scribe ID
        scribe_pats = pd.read_csv(io.StringIO(SCRIBE_RULES),
                                  delimiter=None,
                                  encoding='utf8',
                                  sep=r'\s+',
                                  comment='#')
        for index, row in scribe_pats.iterrows():
            selection = self.data_frame['img_filename'].str.match(
                row['directory'])
            self.data_frame.loc[selection, 'scribe_id'] = row['ID']
            self.data_frame.loc[selection, 'scribe'] = row['writer-name']

        # Special treatment of Roggeveen
        roggeveen_mapping = {}
        for line in SCRIBE_RULES_ROGGEVEEN.split('\n')[1:]:
            if not line.strip():
                continue

            directory, scribe_id = line.split()
            roggeveen_mapping.setdefault(scribe_id, []).append(directory)

        roggeveen_selection = self.data_frame['item_dir'].str.startswith(
            'scribblelens.corpus.v1/nl/unsupervised/roggeveen/')
        for k, v in roggeveen_mapping.items():
            selection = self.data_frame['item_dir'].apply(
                lambda x: x.rsplit('/')[-2]).isin(v)

            self.data_frame.loc[roggeveen_selection & selection,
                                'scribe_id'] = int(k)
            self.data_frame.loc[roggeveen_selection & selection,
                                'scribe'] = f'roggeveen.{k}'

        self.data_frame['scribe_id'] = self.data_frame['scribe_id'].astype(int)
        """
        df2 = self.data_frame.copy()
        df2['scribe_dir'] = df2['item_dir'].apply(lambda x: x.rsplit('/', 1)[-2])
        df2[['scribe', 'scribe_id', 'item_dir']].drop_duplicates().sort_values(['scribe_id', 'item_dir'])\
            .to_csv('scribblelens.scribes', index=False, sep='\t')
        """

        self.metadata = {
            'alignment': {
                'type': 'categorical',
                'num_categories': len(self.alphabet)
            },
            'text': {
                'type': 'categorical',
                'num_categories': len(self.alphabet)
            },
        }

        self.file.close()
        self.file = None
        if self.alignmentFile is not None:
            self.alignmentFile.close()
            self.alignmentFile = None
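The scribe assignment above boils down to matching each image filename against a small whitespace-separated rules table with pandas. Below is a minimal, self-contained sketch of that idiom; the rules table and filenames are invented for illustration and are not the real SCRIBE_RULES.

import io

import numpy as np
import pandas as pd

# Hypothetical rules table: same column layout as the one parsed above
# (directory pattern, numeric scribe ID, writer name), but invented rows.
HYPOTHETICAL_SCRIBE_RULES = """\
directory ID writer-name
nl/unsupervised/tasman/ 1 tasman
nl/unsupervised/kieft/ 2 kieft
"""

frame = pd.DataFrame({'img_filename': [
    'nl/unsupervised/tasman/57/line.5.jpg',   # invented paths
    'nl/unsupervised/kieft/3/line.12.jpg',
]})
frame['scribe_id'] = np.nan
frame['scribe'] = np.nan

rules = pd.read_csv(io.StringIO(HYPOTHETICAL_SCRIBE_RULES),
                    sep=r'\s+', comment='#')
for _, row in rules.iterrows():
    # str.match anchors the directory pattern at the start of each filename
    selection = frame['img_filename'].str.match(row['directory'])
    frame.loc[selection, 'scribe_id'] = row['ID']
    frame.loc[selection, 'scribe'] = row['writer-name']

print(frame[['img_filename', 'scribe_id', 'scribe']])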
Example No. 28
0
    def __init__(
        self,
        modalities_opts,
        transform=None,
        text_file=None,
        vocabulary_file=None,
        utt2spk_file=None,
        ali_file=None,
        split_by_space=False,
        cmvn_normalize_var=False,
    ):
        self.uttids = {}
        self.features = {}
        self.feature_delta_dim = {}
        self.cmvn = {}
        self.utt_centered = {}

        assert "features" in modalities_opts, f"Currently modalities_opts requires at least 'features' modality."

        for modality, modality_opts in modalities_opts.items():
            feature_file = modality_opts.pop('file')
            feature_delta_dim = modality_opts.pop('feature_delta_dim')
            cmvn_file = modality_opts.pop('cmvn_file')
            utt_centered = modality_opts.pop('utt_centered')

            feature_dir = os.path.dirname(feature_file)
            self.features[modality] = OrderedDict(
                [
                    (uttid, self._to_absolute_path(feature_dir, f))
                    for uttid, f in self._read_scp_file(feature_file)
                ]
            )
            self.uttids[modality] = list(self.features[modality].keys())

            self.feature_delta_dim[modality] = feature_delta_dim
            self.utt_centered[modality] = utt_centered

            if cmvn_file:
                self.cmvn[modality] = self._load_cmvn(cmvn_file)
            else:
                self.cmvn[modality] = None

            assert not modality_opts, f"modality_opts keys {modality_opts.keys()} not used."

        self._keep_common_uttids()

        self.split_by_space = split_by_space

        self.utt2spk = None
        if utt2spk_file is not None:
            self.utt2spk = dict(self._read_scp_file(utt2spk_file))
            self._restrict_uttids(self.utt2spk, f"utt2spk ({utt2spk_file})")
            self.speakers = list(set(self.utt2spk.values()))
            self.speakers_to_idx = {s: i for i, s in enumerate(self.speakers)}
            print(len(self.speakers), "speakers found")

        self.cmvn_normalize_var = cmvn_normalize_var

        if text_file:
            self.text = dict(self._read_scp_file(text_file))
            self._restrict_uttids(self.text, f"text ({text_file})")
            assert vocabulary_file is not None
            self.alphabet = self._read_vocabulary_file(vocabulary_file)
            self.text_int = self._tokenize_text(self.text, self.alphabet)
        else:
            self.text = None

        if ali_file:
            self.align = self._read_align_file(ali_file)
            self._restrict_uttids(self.align, f"align ({ali_file})")
        else:
            self.align = None
        super(SpeechDataset, self).__init__()

        self.metadata = {
            "alignment": {"type": "categorical", "num_categories": len(self.alphabet)},
            "targets": {"type": "categorical", "num_categories": len(self.alphabet)},
        }

        if transform:
            self.transform = utils.construct_from_kwargs(transform)
        else:
            self.transform = None
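The dataset above relies on a couple of small helpers (_read_scp_file, _to_absolute_path) whose definitions are not shown. As a rough sketch only (the real helpers may differ), a typical Kaldi-style scp reader splits each non-empty line into an utterance ID and a value at the first whitespace:

import os


def read_scp_file(path):
    """Return (uttid, value) pairs from a Kaldi-style scp file.

    Sketch of what _read_scp_file plausibly does; not the repository's code.
    """
    pairs = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            uttid, value = line.split(maxsplit=1)
            pairs.append((uttid, value))
    return pairs


def to_absolute_path(base_dir, path):
    # Resolve feature paths relative to the directory holding the scp file,
    # mirroring how feature_dir is used above.
    return path if os.path.isabs(path) else os.path.join(base_dir, path)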
Example No. 29
0
    def __init__(self, wrapped_class_name, **kwargs):
        super(RightToLeftReconstructor, self).__init__()
        self.reconstructor = utils.construct_from_kwargs(
            dict(class_name=wrapped_class_name), additional_parameters=kwargs)
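utils.construct_from_kwargs appears in every example here but is never shown. A common pattern for such a factory, and only a guess at its actual behaviour, is to look the class up by its dotted class_name and instantiate it with the remaining options merged with additional_parameters:

import importlib


def construct_from_kwargs(opts, additional_parameters=None):
    """Instantiate the class named by opts['class_name'] with the other keys.

    Illustrative sketch; the real utils.construct_from_kwargs may differ
    (default module prefixes, nested construction, validation, ...).
    Assumes a fully qualified name such as 'torch.nn.Linear'.
    """
    opts = dict(opts)                      # do not mutate the caller's dict
    module_name, _, class_name = opts.pop('class_name').rpartition('.')
    cls = getattr(importlib.import_module(module_name), class_name)
    kwargs = dict(opts, **(additional_parameters or {}))
    return cls(**kwargs)


# Example with a made-up config:
#   layer = construct_from_kwargs({'class_name': 'torch.nn.Linear',
#                                  'in_features': 8},
#                                 additional_parameters={'out_features': 4})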
Example No. 30
0
    def __init__(
        self,
        dataset,
        chunk_len,
        order=5,
        lm_probs=None,
        num_ngrams=10000,
        feature_field="image",
    ):
        self.dataset = utils.construct_from_kwargs(dataset)
        self.chunk_len = chunk_len
        self.order = order
        self.feature_field = feature_field
        self.alphabet = self.dataset.alphabet

        # For distribution matching, we need to know the list of output
        # symbols, without the blank.
        # We create a mapping between outputs and alphabet items
        self.alphabet_to_output = {}
        self.outputs = []
        for char, idx in self.dataset.alphabet.chars.items():
            if (
                char not in self.dataset.alphabet.blank
                and char != self.dataset.alphabet.blank
            ):
                self.outputs.append(idx)
                self.alphabet_to_output[idx] = len(self.alphabet_to_output)

        self.output_frequencies = torch.zeros(len(self.outputs))

        # If lm_probs is given, read the n-gram probabilities from that file
        if lm_probs is not None:
            logging.info(f"Reading LM n-gram probabilities from {lm_probs}")

            self.lm_probs = []
            self.lm_probs_ngrams = []
            with open(lm_probs) as f:
                for line in f:
                    ss = line.split()
                    assert self.order == len(ss) - 1
                    lm_prob = float(ss[0])
                    self.lm_probs.append(lm_prob)
                    self.lm_probs_ngrams.append(
                        [self.alphabet_to_output[self.alphabet.ch2idx(p)] for p in ss[1:]]
                    )

                    for p in ss[1:]:
                        self.output_frequencies[
                            self.alphabet_to_output[self.alphabet.ch2idx(p)]
                        ] += 1
            self.output_frequencies /= sum(self.output_frequencies)
            self.lm_probs = torch.tensor(np.array(self.lm_probs)).float()
            self.lm_probs_ngrams = torch.tensor(np.array(self.lm_probs_ngrams))

        else:
            self.lm_probs = None
            ngram_counter = Counter()

        self.segments = []
        self.frame_pairs = []
        logging.info("Iterating over the dataset to count and index segments")
        for i, item in enumerate(self.dataset):
            ali_rle = item["alignment_rle"]
            # we take overlapping segments
            for j in range(len(ali_rle) - order + 1):
                self.segments.append(
                    SegmentInfo(
                        i,
                        torch.tensor([int(a[0]) for a in ali_rle[j : j + order]]),
                        torch.tensor([int(a[1]) for a in ali_rle[j : j + order]]),
                    )
                )

            for j in range(len(ali_rle)):
                for k in range(
                    max(self.chunk_len // 2, int(ali_rle[j, 0])),
                    min(
                        len(item[self.feature_field]) - self.chunk_len // 2 - 1,
                        int(ali_rle[j, 1]),
                    ),
                ):
                    self.frame_pairs.append((i, k))

            # If LM probs are not given from outside, we use the 'text' field
            # of the underlying dataset and count the n-grams.
            # FIXME: use a trie if the number of n-grams gets large
            if self.lm_probs is None:
                ngrams = zip(
                    *[
                        [self.alphabet_to_output[idx.item()] for idx in item["text"]][
                            i:
                        ]
                        for i in range(self.order)
                    ]
                )
                ngram_counter.update(ngrams)

                for output in [
                    self.alphabet_to_output[idx.item()] for idx in item["text"]
                ]:
                    self.output_frequencies[output] += 1

        if lm_probs is None:
            logging.info("Computing LM n-gram probabilities from reference text")
            self.lm_probs = []
            self.lm_probs_ngrams = []
            num_total_ngrams = sum(ngram_counter.values())
            for ngram, count in ngram_counter.most_common(num_ngrams):
                self.lm_probs.append(count / num_total_ngrams)
                self.lm_probs_ngrams.append(ngram)

            self.lm_probs = torch.tensor(np.array(self.lm_probs)).float()
            self.lm_probs_ngrams = torch.tensor(np.array(self.lm_probs_ngrams))

            self.output_frequencies /= sum(self.output_frequencies)

        if Globals.cuda:
            self.lm_probs = self.lm_probs.to("cuda")
            self.lm_probs_ngrams = self.lm_probs_ngrams.to("cuda")

        self.stddev = np.std(
            np.array(
                [
                    segment.ends[i] - segment.starts[i] + 1
                    for segment in self.segments
                    for i in range(len(segment.starts))
                ]
            )
        )
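The n-gram indexing above uses the zip-of-shifted-sequences idiom to enumerate overlapping n-grams. A tiny standalone illustration of the same idea, with an invented symbol sequence:

from collections import Counter


def count_ngrams(symbols, order):
    # Zipping `order` shifted copies of the sequence yields its overlapping n-grams.
    return Counter(zip(*[symbols[i:] for i in range(order)]))


# Invented symbol sequence, e.g. integer character indices.
counts = count_ngrams([0, 1, 2, 1, 2, 3], order=2)
total = sum(counts.values())
probs = {ngram: count / total for ngram, count in counts.items()}
print(probs)  # {(0, 1): 0.2, (1, 2): 0.4, (2, 1): 0.2, (2, 3): 0.2}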