Args: l: Local feature map. m: Multiple globals feature map. measure: Type of f-divergence. For use with mode `fd`. mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Varadhan `dv`. l and m are encoded by self.nets.local_net / self.nets.multi_net; the local encoding is expected to be 4-D, and both encodings are reshaped to (N, local_units, -1) before the loss is computed. Returns: torch.Tensor: Loss. Raises: NotImplementedError: If mode is not `fd`, `nce`, or `dv`. ''' l_enc = self.nets.local_net(l) m_enc = self.nets.multi_net(m) N, local_units, dim_x, dim_y = l_enc.size() l_enc = l_enc.view(N, local_units, -1) m_enc = m_enc.view(N, local_units, -1) if mode == 'fd': loss = multi_fenchel_dual_loss(l_enc, m_enc, measure=measure) elif mode == 'nce': loss = multi_nce_loss(l_enc, m_enc) elif mode == 'dv': loss = multi_donsker_varadhan_loss(l_enc, m_enc) else: raise NotImplementedError(mode) return loss if __name__ == '__main__': run(DIM())
classifier = self.nets.classifier outputs = classifier(inputs) predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1] loss = criterion(outputs, targets) correct = 100. * predicted.eq( targets.data).cpu().sum() / targets.size(0) self.losses.classifier = loss self.results.accuracy = correct def predict(self, inputs): '''Return the predicted class index for each input (argmax over dim 1 of the classifier outputs).''' classifier = self.nets.classifier outputs = classifier(inputs) predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1] return predicted def visualize(self, images, inputs, targets): '''Add `images` to the visualization under the name `gt_pred`, labelled with the ground-truth targets and the classes predicted from `inputs`.''' predicted = self.predict(inputs) self.add_image(images.data, labels=(targets.data, predicted.data), name='gt_pred') if __name__ == '__main__': classifier = MyClassifier() run(model=classifier)
''' self.nets.encoder = encoder X = self.inputs('data.images') self.task_idx = task_idx out = self.nets.encoder(X, return_all_activations=True)[self.task_idx] config = decoder_configs.get(config) config.update(**args) super().build(out.size()[1:], args=config) def routine(self, outs=None): '''Run the parent routine on the input images and the detached encoder activation at `task_idx`. Encoder activations are computed from the images unless `outs` is given.''' X = self.inputs('data.images') if outs is None: outs = self.nets.encoder(X, return_all_activations=True) out = outs[self.task_idx] super().routine(X, out.detach()) def visualize(self, inputs): '''Pass the encoder activation at `task_idx` for `inputs` to the parent class visualize.''' out = self.nets.encoder(inputs, return_all_activations=True)[self.task_idx] super().visualize(out) if __name__ == '__main__': run(MSSSIMEval())
# Build a one-line description of each available model for the --help text.
infos = []
for k, mode in mode_dict.items():
    # Keep only the first line of the docstring; tolerate classes without one.
    info = (mode.__doc__ or '').split('\n', 1)[0]
    infos.append('{}: {}'.format(k, info))
infos = '\n\t'.join(infos)

parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
parser.add_argument('models', nargs='+', choices=names,
                    help='Models used in Deep InfoMax. Choices are: \n\t{}'.format(infos))

# Model names come first on the command line. Scan forward to the first
# option flag; everything from there on is left in sys.argv for later
# processing (presumably by ``run``). A leading --help/-h is kept for this
# parser so it can list the model choices. Bounding the scan by
# len(sys.argv) avoids an IndexError when no option follows the model names.
i = 1
while i < len(sys.argv):
    arg = sys.argv[i]
    if arg in ('--help', '-h') and i == 1:
        i += 1
        break
    if arg.startswith('-'):
        break  # options have begun
    i += 1

args = parser.parse_args(sys.argv[1:i])
# Deduplicate repeated model names and map each one to its class.
models = {k: mode_dict[k] for k in set(args.models)}
sys.argv = [sys.argv[0]] + sys.argv[i:]

controller = Controller(inputs=dict(inputs='data.images'), **models)
run(controller)
inputs: torch.Tensor ae_criterion: function, called as ae_criterion(outputs, inputs, size_average=False); the result is divided by the batch size and stored in self.losses.ae. NOTE(review): if ae_criterion is a torch.nn.functional loss, size_average is deprecated there (reduction='sum' is the equivalent) - confirm. Returns: None """ encoded = self.nets.ae.encoder(inputs) outputs = self.nets.ae.decoder(encoded) r_loss = ae_criterion( outputs, inputs, size_average=False) / inputs.size(0) self.losses.ae = r_loss def visualize(self, inputs): """ Add the autoencoder reconstructions (name 'reconstruction') and the original inputs (name 'ground truth') to the visualization. Args: inputs: torch.Tensor Returns: None """ encoded = self.nets.ae.encoder(inputs) outputs = self.nets.ae.decoder(encoded) self.add_image(outputs, name='reconstruction') self.add_image(inputs, name='ground truth') if __name__ == '__main__': autoencoder = AE() run(model=autoencoder)
inputs: torch.Tensor ae_criterion: function, called as ae_criterion(outputs, inputs, size_average=False); the result is divided by the batch size and stored in self.losses.ae. NOTE(review): if ae_criterion is a torch.nn.functional loss, size_average is deprecated there (reduction='sum' is the equivalent) - confirm. Returns: None """ encoded = self.nets.ae.encoder(inputs) outputs = self.nets.ae.decoder(encoded) r_loss = ae_criterion(outputs, inputs, size_average=False) / inputs.size(0) self.losses.ae = r_loss def visualize(self, inputs): """ Add the autoencoder reconstructions (name 'reconstruction') and the original inputs (name 'ground truth') to the visualization. Args: inputs: torch.Tensor Returns: None """ encoded = self.nets.ae.encoder(inputs) outputs = self.nets.ae.decoder(encoded) self.add_image(outputs, name='reconstruction') self.add_image(inputs, name='ground truth') if __name__ == '__main__': autoencoder = AE() run(model=autoencoder)
for i, idx in enumerate(self.classification_idx): name = self.classifier_names[i] contract = dict(nets=dict(classifier='classifier_{}'.format(name))) setattr(self, 'classifier_{}'.format(name), Classifier(**contract)) classifier = getattr(self, 'classifier_{}'.format(name)) input_shape = outs[idx].size()[1:] kwargs = classifier_kwargs[i] classifier.build(input_shape, n_labels, **kwargs) def routine(self, inputs, targets): '''Classification routine. For each index in self.classification_idx, run the matching classifier's routine on the detached encoder output at that index. NOTE(review): the encoder forward pass is recomputed once per classifier; computing it once before the loop would avoid redundant work. ''' for i, idx in enumerate(self.classification_idx): name = self.classifier_names[i] classifier = getattr(self, 'classifier_{}'.format(name)) output = self.nets.encoder(inputs)[idx] classifier.routine(output.detach(), targets) def visualize(self, inputs, targets): '''Visualization. For each index in self.classification_idx, pass the inputs, the encoder output at that index, and the targets to the matching classifier's visualize. ''' for i, idx in enumerate(self.classification_idx): name = self.classifier_names[i] classifier = getattr(self, 'classifier_{}'.format(name)) output = self.nets.encoder(inputs)[idx] classifier.visualize(inputs, output, targets) if __name__ == '__main__': run(ClassificationEvaluator())
def routine(self, outs=None, measure='KL', nonlinearity=''): ''' Args: measure: Type of measure to use for NDM. nonlinearity: Nonlinearity to use on output of encoder. Returns: ''' if outs is None: inputs = self.inputs('data.images') outs = self.nets.encoder(inputs, return_all_activations=True) out = outs[self.task_idx] if nonlinearity != '': out = getattr(nn, nonlinearity)()(out) super().routine(out.detach(), measure=measure) def visualize(self, inputs, nonlinearity=None, measure=None): out = self.nets.encoder(inputs, return_all_activations=True)[self.task_idx] if nonlinearity != '': out = getattr(nn, nonlinearity)()(out) super().visualize(out, measure=measure) if __name__ == '__main__': run(NDMEval())
"""Script entry point: set the MLflow tracking URI for SSL runs and run a
FixMatchController through cortex."""

import logging

import mlflow
from cortex.main import run

from src.models.fix_match.controller import FixMatchController
# Not referenced below; presumably imported for a side effect (e.g. registering
# the dataset plugin). Verify before removing.
from src.data.dataset_plugins import SSLDatasetPlugin  # noqa: F401
from src import MLFLOW_SSL_URI

logger = logging.getLogger('ssl_evaluation')


if __name__ == '__main__':
    # Set the tracking URI before run() so MLflow calls made during the run
    # (if any) go to the SSL tracking server.
    mlflow.set_tracking_uri(MLFLOW_SSL_URI)
    controller = FixMatchController()
    run(model=controller)