Ejemplo n.º 1
0
        Args:
            l: Local feature map.
            m: Multiple globals feature map.
            measure: Type of f-divergence. For use with mode `fd`
            mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Varadhan `dv`.

        Returns:
            torch.Tensor: Loss.

        '''
        l_enc = self.nets.local_net(l)
        m_enc = self.nets.multi_net(m)
        N, local_units, dim_x, dim_y = l_enc.size()
        l_enc = l_enc.view(N, local_units, -1)
        m_enc = m_enc.view(N, local_units, -1)

        if mode == 'fd':
            loss = multi_fenchel_dual_loss(l_enc, m_enc, measure=measure)
        elif mode == 'nce':
            loss = multi_nce_loss(l_enc, m_enc)
        elif mode == 'dv':
            loss = multi_donsker_varadhan_loss(l_enc, m_enc)
        else:
            raise NotImplementedError(mode)

        return loss


if __name__ == '__main__':
    # Script entry point: only when executed directly, build a DIM instance
    # and pass it to `run`.
    run(DIM())
Ejemplo n.º 2
0
        classifier = self.nets.classifier

        outputs = classifier(inputs)
        predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]

        loss = criterion(outputs, targets)
        correct = 100. * predicted.eq(
            targets.data).cpu().sum() / targets.size(0)

        self.losses.classifier = loss
        self.results.accuracy = correct

    def predict(self, inputs):
        '''Return the predicted class index for each input.

        Args:
            inputs: Batch passed to `self.nets.classifier`.

        Returns:
            Indices of the maximum log-softmax value along dimension 1.

        '''
        logits = self.nets.classifier(inputs)
        log_probs = F.log_softmax(logits, dim=1).data
        # torch.max over a dimension yields (values, indices); keep the indices.
        _, labels = torch.max(log_probs, 1)
        return labels

    def visualize(self, images, inputs, targets):
        '''Add the images, labelled with targets and predictions, to the visualization.

        Args:
            images: Object whose `.data` is passed to `add_image`.
            inputs: Batch forwarded to `predict`.
            targets: Object whose `.data` is the first entry of the label pair.

        '''
        guesses = self.predict(inputs)
        image_data = images.data
        # Label pair is (ground truth, prediction), in that order.
        label_pair = (targets.data, guesses.data)
        self.add_image(image_data, labels=label_pair, name='gt_pred')


if __name__ == '__main__':
    # Script entry point: only when executed directly, build the classifier
    # model and pass it to `run` as `model`.
    classifier = MyClassifier()
    run(model=classifier)
Ejemplo n.º 3
0
        classifier = self.nets.classifier

        outputs = classifier(inputs)
        predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]

        loss = criterion(outputs, targets)
        correct = 100. * predicted.eq(
            targets.data).cpu().sum() / targets.size(0)

        self.losses.classifier = loss
        self.results.accuracy = correct

    def predict(self, inputs):
        '''Predict a class index for every row of the classifier output.

        Args:
            inputs: Batch passed to `self.nets.classifier`.

        Returns:
            The argmax indices (along dimension 1) of the log-softmax scores.

        '''
        net = self.nets.classifier
        log_probs = F.log_softmax(net(inputs), dim=1)
        # Element [1] of torch.max(..., dim) holds the argmax indices.
        return torch.max(log_probs.data, 1)[1]

    def visualize(self, images, inputs, targets):
        '''Show the images together with their target and predicted labels.

        Args:
            images: Object whose `.data` is passed to `add_image`.
            inputs: Batch forwarded to `predict`.
            targets: Object whose `.data` is paired with the predictions.

        '''
        predictions = self.predict(inputs)
        image_kwargs = dict(labels=(targets.data, predictions.data),
                            name='gt_pred')
        self.add_image(images.data, **image_kwargs)


if __name__ == '__main__':
    # Script entry point: only when executed directly, build the classifier
    # model and pass it to `run` as `model`.
    classifier = MyClassifier()
    run(model=classifier)
Ejemplo n.º 4
0
        '''

        self.nets.encoder = encoder

        X = self.inputs('data.images')
        self.task_idx = task_idx
        out = self.nets.encoder(X, return_all_activations=True)[self.task_idx]

        config = decoder_configs.get(config)
        config.update(**args)

        super().build(out.size()[1:], args=config)

    def routine(self, outs=None):
        '''Run the parent routine on the images and one detached encoder activation.

        Args:
            outs: Encoder activations. When None, they are computed by calling
                the encoder on the 'data.images' inputs with
                `return_all_activations=True`.

        '''
        images = self.inputs('data.images')
        activations = outs
        if activations is None:
            activations = self.nets.encoder(images, return_all_activations=True)

        # Select the activation for this task and detach it before handing it on.
        selected = activations[self.task_idx].detach()
        super().routine(images, selected)

    def visualize(self, inputs):
        '''Visualize the encoder activation selected by `self.task_idx`.

        Args:
            inputs: Batch fed to `self.nets.encoder`.

        '''
        activations = self.nets.encoder(inputs, return_all_activations=True)
        selected = activations[self.task_idx]
        super().visualize(selected)


if __name__ == '__main__':
    # Script entry point: only when executed directly, build an MSSSIMEval
    # instance and pass it to `run`.
    run(MSSSIMEval())
Ejemplo n.º 5
0
    infos = []
    for k in mode_dict.keys():
        mode = mode_dict[k]
        info = mode.__doc__.split('\n', 1)[0]  # Keep only first line of doctstring.
        infos.append('{}: {}'.format(k, info))
    infos = '\n\t'.join(infos)

    models = []
    parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
    parser.add_argument('models', nargs='+', choices=names,
                        help='Models used in Deep InfoMax. Choices are: \n\t{}'.format(infos))
    i = 1
    while True:
        arg = sys.argv[i]
        if arg in ('--help', '-h') and i == 1:
            i += 1
            break

        if arg.startswith('-'):
            break  # argument have begun
        i += 1
    args = parser.parse_args(sys.argv[1:i])
    models = args.models
    models = list(set(models))
    models = dict((k, mode_dict[k]) for k in models)
    sys.argv = [sys.argv[0]] + sys.argv[i:]
    controller = Controller(inputs=dict(inputs='data.images'), **models)

    run(controller)
Ejemplo n.º 6
0
            inputs: torch.Tensor
            ae_criterion: function

        Returns: None

        """
        encoded = self.nets.ae.encoder(inputs)
        outputs = self.nets.ae.decoder(encoded)
        r_loss = ae_criterion(
            outputs, inputs, size_average=False) / inputs.size(0)
        self.losses.ae = r_loss

    def visualize(self, inputs):
        """
        Add the reconstruction and the original inputs to the
        visualization.
        Args:
            inputs: torch.Tensor
        Returns: None

        """
        autoencoder = self.nets.ae
        latent = autoencoder.encoder(inputs)
        reconstruction = autoencoder.decoder(latent)
        # Reconstruction first, then the inputs it was built from.
        panels = ((reconstruction, 'reconstruction'), (inputs, 'ground truth'))
        for image, title in panels:
            self.add_image(image, name=title)


if __name__ == '__main__':
    # Script entry point: only when executed directly, build the autoencoder
    # model and pass it to `run` as `model`.
    autoencoder = AE()
    run(model=autoencoder)
Ejemplo n.º 7
0
            inputs: torch.Tensor
            ae_criterion: function

        Returns: None

        """
        encoded = self.nets.ae.encoder(inputs)
        outputs = self.nets.ae.decoder(encoded)
        r_loss = ae_criterion(outputs, inputs,
                              size_average=False) / inputs.size(0)
        self.losses.ae = r_loss

    def visualize(self, inputs):
        """
        Encode and decode the inputs, then add both the decoded
        images and the inputs to the visualization.
        Args:
            inputs: torch.Tensor
        Returns: None

        """
        ae = self.nets.ae
        decoded = ae.decoder(ae.encoder(inputs))
        self.add_image(decoded, name='reconstruction')
        self.add_image(inputs, name='ground truth')


if __name__ == '__main__':
    # Script entry point: only when executed directly, build the autoencoder
    # model and pass it to `run` as `model`.
    autoencoder = AE()
    run(model=autoencoder)
Ejemplo n.º 8
0
        for i, idx in enumerate(self.classification_idx):
            name = self.classifier_names[i]
            contract = dict(nets=dict(classifier='classifier_{}'.format(name)))
            setattr(self, 'classifier_{}'.format(name), Classifier(**contract))
            classifier = getattr(self, 'classifier_{}'.format(name))
            input_shape = outs[idx].size()[1:]
            kwargs = classifier_kwargs[i]
            classifier.build(input_shape, n_labels, **kwargs)

    def routine(self, inputs, targets):
        '''Classification routine.

        Calls every registered classifier's `routine` on the detached encoder
        output found at that classifier's index.

        Args:
            inputs: Batch fed to `self.nets.encoder`.
            targets: Labels handed unchanged to every classifier.

        '''
        indices = list(self.classification_idx)
        if not indices:
            return  # no classifiers: skip the encoder forward pass entirely

        # Run the encoder once and share its outputs across classifiers rather
        # than repeating the same forward pass for each of them.
        outs = self.nets.encoder(inputs)
        for i, idx in enumerate(indices):
            name = self.classifier_names[i]
            classifier = getattr(self, 'classifier_{}'.format(name))
            # Detached output is what each classifier trains on
            # (presumably so no gradient reaches the encoder -- confirm).
            classifier.routine(outs[idx].detach(), targets)

    def visualize(self, inputs, targets):
        '''Visualization.

        Calls every registered classifier's `visualize` with the raw inputs,
        the encoder output at that classifier's index, and the targets.

        Args:
            inputs: Batch fed to `self.nets.encoder` and passed on to each
                classifier.
            targets: Labels handed unchanged to every classifier.

        '''
        indices = list(self.classification_idx)
        if not indices:
            return  # no classifiers: skip the encoder forward pass entirely

        # Run the encoder once and share its outputs across classifiers rather
        # than repeating the same forward pass for each of them.
        outs = self.nets.encoder(inputs)
        for i, idx in enumerate(indices):
            name = self.classifier_names[i]
            classifier = getattr(self, 'classifier_{}'.format(name))
            classifier.visualize(inputs, outs[idx], targets)


if __name__ == '__main__':
    # Script entry point: only when executed directly, build a
    # ClassificationEvaluator instance and pass it to `run`.
    run(ClassificationEvaluator())
Ejemplo n.º 9
0
    def routine(self, outs=None, measure='KL', nonlinearity=''):
        '''Run the parent routine on one detached encoder activation.

        Args:
            outs: Encoder activations. When None, they are computed by calling
                the encoder on the 'data.images' inputs with
                `return_all_activations=True`.
            measure: Type of measure to use for NDM.
            nonlinearity: Nonlinearity to use on output of encoder. Name of an
                `nn` class; '' applies nothing.

        Returns:
            None

        '''
        activations = outs
        if activations is None:
            images = self.inputs('data.images')
            activations = self.nets.encoder(images, return_all_activations=True)

        features = activations[self.task_idx]
        if nonlinearity != '':
            # Look the class up by name on `nn`, instantiate it, then apply it.
            activation_cls = getattr(nn, nonlinearity)
            features = activation_cls()(features)
        super().routine(features.detach(), measure=measure)

    def visualize(self, inputs, nonlinearity=None, measure=None):
        '''Visualize the encoder activation selected by `self.task_idx`.

        Args:
            inputs: Batch fed to `self.nets.encoder`.
            nonlinearity: Name of an `nn` class to instantiate and apply to the
                activation. None or '' applies nothing.
            measure: Forwarded to the parent `visualize`.

        '''
        out = self.nets.encoder(inputs,
                                return_all_activations=True)[self.task_idx]
        # The default is None, which a `!= ''` comparison lets through to
        # getattr(nn, None) and a TypeError; treat every empty value as
        # "no nonlinearity".
        if nonlinearity:
            out = getattr(nn, nonlinearity)()(out)
        super().visualize(out, measure=measure)


if __name__ == '__main__':
    # Script entry point: only when executed directly, build an NDMEval
    # instance and pass it to `run`.
    run(NDMEval())
Ejemplo n.º 10
0
import mlflow
import logging
from cortex.main import run

from src.models.fix_match.controller import FixMatchController
from src.data.dataset_plugins import SSLDatasetPlugin
from src import MLFLOW_SSL_URI

# Logger named 'ssl_evaluation'; it is not used in the visible part of this script.
logger = logging.getLogger('ssl_evaluation')

if __name__ == '__main__':

    # if exp.ARGS
    # NOTE(review): the line above looks like an unfinished conditional stub --
    # confirm its intent or remove it.
    # Point mlflow at the tracking URI imported from `src` as MLFLOW_SSL_URI
    # before the run starts.
    mlflow.set_tracking_uri(MLFLOW_SSL_URI)
    controller = FixMatchController()

    # Hand the FixMatch controller to cortex's `run` as the model.
    run(model=controller)