示例#1
0
def get_metric_fn(conf,
                  metric_name,
                  cuda,
                  mode,
                  pred_key='pred',
                  target_key='target'):
    """Build a ``MetricFunction`` for the metric registered as ``metric_name``.

    The registry entry ``_METRICS[metric_name]`` is read positionally as
    ``(source, metric_type[, generic_kwargs])``:

    * if ``source`` is a string, the metric callable comes from
      ``_get_generic_metric_fn(source, **generic_kwargs)`` (no keyword
      arguments when the entry has only two elements);
    * otherwise ``source`` is called as ``source(conf, metric_name, cuda)``.

    Per-metric overrides are read from the ``<metric_name>_metric`` attribute
    of ``conf`` (an empty dict when absent):

    * ``pred_key`` / ``target_key`` replace the corresponding arguments;
    * ``transform`` names the output transform to use, where the string
      ``'none'`` disables the transform entirely. Without this key the
      transform for ``conf.application`` is used.

    Args:
        conf: Configuration object; must provide ``get_attr`` and
            ``application``.
        metric_name: Key into ``_METRICS``.
        cuda: Forwarded to specialized (non-string) metric constructors.
        mode: ``'train'`` or ``'test'``; forwarded to
            ``get_output_transform``.
        pred_key: Default prediction key, unless overridden by the config.
        target_key: Default target key, unless overridden by the config.

    Returns:
        A ``MetricFunction`` built from the metric callable, its type, the
        selected transform (possibly ``None``) and the pred/target keys.

    Raises:
        AssertionError: If ``mode`` is not ``'train'``/``'test'`` or
            ``metric_name`` is not registered (only while assertions are
            enabled).
    """
    assert mode in ('train', 'test')
    assert metric_name in _METRICS, 'Unknown metric {}'.format(metric_name)

    entry = _METRICS[metric_name]
    source, kind = entry[0], entry[1]

    if isinstance(source, str):
        # Generic metric looked up by name, with optional extra kwargs.
        extra = entry[2] if len(entry) > 2 else {}
        metric_fn = _get_generic_metric_fn(source, **extra)
    else:
        # Specialized metric: the registry holds a constructor.
        metric_fn = source(conf, metric_name, cuda)

    overrides = conf.get_attr('{}_metric'.format(metric_name), default={})
    pred_key = overrides['pred_key'] if 'pred_key' in overrides else pred_key
    target_key = (overrides['target_key'] if 'target_key' in overrides
                  else target_key)

    if 'transform' in overrides:
        transform_name = overrides['transform']
        # The literal 'none' explicitly disables any output transform.
        transform = (None if transform_name == 'none'
                     else get_output_transform(conf, transform_name, mode))
    else:
        transform = get_output_transform(conf, conf.application, mode)

    return MetricFunction(metric_fn, kind, transform, pred_key, target_key)
示例#2
0
def build_runner(conf, cuda, mode, resume=False):
  """Construct an ``AdversarialRunner`` for a generator/discriminator pair.

  The generator is always built from ``conf.generator_model``. In ``'train'``
  mode this also builds:

  * the discriminator;
  * the generator-adversarial, generator and discriminator-adversarial
    criteria;
  * one optimizer (plus an optional LR scheduler) per network;
  * the training metrics.

  Pretrained weights are loaded into either network when its model config
  declares ``pretrained_weights`` and ``resume`` is false. In any other mode
  the runner receives only the generator, the validation metrics and the
  transforms.

  NOTE(review): this definition is shadowed by a later ``build_runner`` in
  the same module. Its one-argument ``get_metric_fn(name)`` calls also do
  not match the ``get_metric_fn(conf, metric_name, cuda, mode)`` defined in
  this file, so it appears to target an older API. Confirm before reusing.

  Args:
    conf: Top-level configuration object.
    cuda: Device specification; when it is the empty string,
      ``utils.cudaify`` is not called.
    mode: ``'train'`` builds a full training runner; anything else builds
      an evaluation-only runner.
    resume: If true, pretrained weights are not loaded.

  Returns:
    The constructed ``AdversarialRunner``.
  """
  gen_model_conf = Configuration.from_dict(conf.generator_model)
  gen_model = construct_model(gen_model_conf, gen_model_conf.name)

  # Validation metrics and transforms are needed in every mode.
  val_metric_transform = get_output_transform(conf, conf.application, 'test')
  val_metric_fns = {name: get_metric_fn(name)
                    for name in conf.get_attr('validation_metrics',
                                              default=[])}
  output_transform = get_output_transform(conf, conf.application, 'output')

  if mode == 'train':
    disc_model_conf = Configuration.from_dict(conf.discriminator_model)
    disc_model = construct_model(disc_model_conf, disc_model_conf.name)

    # Adversarial criteria are built with the 'gen' / 'disc' variant passed
    # positionally; plain generator losses get no variant.
    gen_adv_criteria = {loss_name: get_criterion(conf, loss_name, cuda, 'gen')
                        for loss_name in conf.generator_adversarial_losses}
    gen_criteria = {loss_name: get_criterion(conf, loss_name, cuda)
                    for loss_name in conf.generator_losses}
    disc_adv_criteria = {loss_name: get_criterion(conf, loss_name, cuda,
                                                  'disc')
                         for loss_name in conf.discriminator_losses}

    # Move both networks and every criterion in one call.
    if cuda != '':
      utils.cudaify([gen_model, disc_model] +
                    list(gen_adv_criteria.values()) +
                    list(gen_criteria.values()) +
                    list(disc_adv_criteria.values()))

    # Important: construct optimizers after moving model to GPU!
    gen_opt_conf = Configuration.from_dict(conf.generator_optimizer)
    gen_optimizer = get_optimizer(gen_opt_conf, gen_opt_conf.name,
                                  gen_model.parameters())
    # The LR scheduler is optional; it is only built if configured.
    gen_lr_scheduler = None
    if gen_opt_conf.has_attr('lr_scheduler'):
      gen_lr_scheduler = get_lr_scheduler(gen_opt_conf,
                                          gen_opt_conf.lr_scheduler,
                                          gen_optimizer)

    disc_opt_conf = Configuration.from_dict(conf.discriminator_optimizer)
    disc_optimizer = get_optimizer(disc_opt_conf, disc_opt_conf.name,
                                   disc_model.parameters())
    disc_lr_scheduler = None
    if disc_opt_conf.has_attr('lr_scheduler'):
      disc_lr_scheduler = get_lr_scheduler(disc_opt_conf,
                                           disc_opt_conf.lr_scheduler,
                                           disc_optimizer)

    train_disc_metrics = conf.get_attr('train_discriminator_metrics',
                                       default=[])
    train_disc_metric_fns = {name: get_metric_fn(name)
                             for name in train_disc_metrics}

    # Only the generator metrics get a dedicated training transform.
    train_gen_metric_transform = get_output_transform(conf, conf.application,
                                                      'train')
    train_gen_metrics = conf.get_attr('train_generator_metrics', default=[])
    train_gen_metric_fns = {name: get_metric_fn(name)
                            for name in train_gen_metrics}

    input_method = disc_model_conf.get_attr('input_method',
                                            default=DEFAULT_INPUT_METHOD)

    # Arguments are positional: their order must match AdversarialRunner's
    # signature (not visible here).
    runner = AdversarialRunner(gen_model, disc_model,
                               gen_optimizer, disc_optimizer,
                               gen_lr_scheduler, disc_lr_scheduler,
                               gen_adv_criteria, gen_criteria,
                               disc_adv_criteria,
                               conf.get_attr('generator_loss_weights', {}),
                               conf.get_attr('discriminator_loss_weights', {}),
                               cuda,
                               train_gen_metric_fns,
                               train_gen_metric_transform,
                               train_disc_metric_fns,
                               val_metric_fns,
                               val_metric_transform,
                               output_transform,
                               input_method)
    # Pretrained weights are skipped when resuming.
    if gen_model_conf.has_attr('pretrained_weights') and not resume:
      runner.initialize_pretrained_model(gen_model_conf, runner.gen,
                                         cuda, conf.file)

    if disc_model_conf.has_attr('pretrained_weights') and not resume:
      runner.initialize_pretrained_model(disc_model_conf, runner.disc,
                                         cuda, conf.file)
  else:
    # Evaluation-only runner: generator, validation metrics and transforms.
    if cuda != '':
      utils.cudaify(gen_model)
    runner = AdversarialRunner(gen_model,
                               cuda=cuda,
                               val_metric_fns=val_metric_fns,
                               val_metric_transform=val_metric_transform,
                               output_transform=output_transform)

  return runner
示例#3
0
def build_runner(conf, cuda, mode='train', resume=False):
    """Construct a ``Runner`` for a single (non-adversarial) model.

    The model is built from ``conf.model``. In ``'train'`` mode the runner
    also receives:

    * the loss criteria and loss weights;
    * an optimizer and an optional learning-rate scheduler;
    * the training metrics.

    Pretrained weights are loaded into the model when its config declares
    ``pretrained_weights`` and ``resume`` is false. In any other mode only
    the model, the validation metrics and the transforms are passed on.

    NOTE(review): this definition is shadowed by a later ``build_runner`` in
    the same module. Its one-argument ``get_metric_fn(name)`` calls also do
    not match the ``get_metric_fn(conf, metric_name, cuda, mode)`` defined in
    this file, so it appears to target an older API. Confirm before reusing.

    Args:
        conf: Top-level configuration object.
        cuda: Device specification; when it is the empty string,
            ``utils.cudaify`` is not called.
        mode: ``'train'`` builds a training runner; anything else builds an
            evaluation-only runner.
        resume: If true, pretrained weights are not loaded.

    Returns:
        The constructed ``Runner``.

    Raises:
        AssertionError: In ``'train'`` mode, if no loss is configured (only
            while assertions are enabled).
    """
    model_conf = Configuration.from_dict(conf.model)
    model = construct_model(model_conf, model_conf.name)

    # Validation metrics and transforms are needed in every mode.
    val_metric_transform = get_output_transform(conf, conf.application, 'test')
    val_metric_fns = {
        metric: get_metric_fn(metric)
        for metric in conf.get_attr('validation_metrics', default=[])
    }
    output_transform = get_output_transform(conf, conf.application, 'output')

    if mode != 'train':
        # Evaluation-only runner: no losses, optimizer or training metrics.
        if cuda != '':
            utils.cudaify(model)
        return Runner(model,
                      cuda=cuda,
                      val_metric_fns=val_metric_fns,
                      val_metric_transform=val_metric_transform,
                      output_transform=output_transform)

    # A single `loss_name` takes precedence over the `losses` list.
    if conf.has_attr('loss_name'):
        loss_names = [conf.loss_name]
    else:
        loss_names = conf.losses
    criteria = {loss: get_criterion(conf, loss, cuda) for loss in loss_names}
    assert criteria, 'Need at least one loss to optimize something!'

    if cuda != '':
        utils.cudaify([model] + list(criteria.values()))

    # Important: construct the optimizer only after moving the model to GPU!
    opt_conf = Configuration.from_dict(conf.optimizer)
    optimizer = get_optimizer(opt_conf, opt_conf.name, model.parameters())
    if opt_conf.has_attr('lr_scheduler'):
        lr_scheduler = get_lr_scheduler(opt_conf, opt_conf.lr_scheduler,
                                        optimizer)
    else:
        lr_scheduler = None

    train_metric_transform = get_output_transform(conf, conf.application,
                                                  'train')
    train_metric_fns = {
        metric: get_metric_fn(metric)
        for metric in conf.get_attr('train_metrics', default=[])
    }

    runner = Runner(model, criteria, conf.get_attr('loss_weights', {}),
                    optimizer, lr_scheduler, cuda, train_metric_fns,
                    train_metric_transform, val_metric_fns,
                    val_metric_transform, output_transform)

    if model_conf.has_attr('pretrained_weights') and not resume:
        runner.initialize_pretrained_model(model_conf, runner.model, cuda,
                                           conf.file)
    return runner
def build_runner(conf, cuda, mode):
    """Construct an ``AdversarialRunner`` for a generator/discriminator pair.

    The generator is always built from ``conf.generator_model``. In
    ``'train'`` mode this also builds:

    * the discriminator;
    * the generator-adversarial, generator and discriminator-adversarial
      criteria;
    * one optimizer (plus an optional LR scheduler) per network;
    * the train/validation metrics for both networks;
    * the training and validation discriminator input functions.

    It then loads pretrained weights into either network whose model config
    declares ``pretrained_weights``. In any other mode the runner receives
    only the generator, the validation metrics, the output transform and the
    test input-batch transform.

    Args:
        conf: Top-level configuration object.
        cuda: Device specification forwarded to model, criterion and metric
            construction; when it is the empty string, ``utils.cudaify`` is
            not called.
        mode: ``'train'`` builds a full training runner; anything else builds
            an inference runner.

    Returns:
        The constructed ``AdversarialRunner``.
    """
    gen_model_conf = Configuration.from_dict(conf.generator_model, conf)
    gen_model = construct_model(gen_model_conf, gen_model_conf.name, cuda)

    # Validation metrics and transforms are needed in every mode.
    val_metric_fns = {
        name: get_metric_fn(conf, name, cuda, 'test')
        for name in conf.get_attr('validation_metrics', default=[])
    }
    output_transform = get_output_transform(conf, conf.application,
                                            'inference')
    test_input_batch_transform = get_input_batch_transform(
        conf, conf.application, 'test')

    if mode == 'train':
        disc_model_conf = Configuration.from_dict(conf.discriminator_model,
                                                  conf)
        disc_model = construct_model(disc_model_conf, disc_model_conf.name,
                                     cuda)

        # Adversarial criteria are built with loss_type 'gen' / 'disc';
        # plain generator losses use get_criterion's default loss_type.
        gen_adv_criteria = {
            loss_name: get_criterion(conf, loss_name, cuda, loss_type='gen')
            for loss_name in conf.generator_adversarial_losses
        }
        gen_criteria = {
            loss_name: get_criterion(conf, loss_name, cuda)
            for loss_name in conf.generator_losses
        }
        disc_adv_criteria = {
            loss_name: get_criterion(conf, loss_name, cuda, loss_type='disc')
            for loss_name in conf.discriminator_losses
        }

        if cuda != '':
            # Potentially split models over GPUs
            # (the returned models replace the local references here).
            gen_model, disc_model = utils.cudaify([gen_model, disc_model],
                                                  cuda)
            utils.cudaify(
                list(gen_adv_criteria.values()) + list(gen_criteria.values()) +
                list(disc_adv_criteria.values()))

        # Important: construct optimizers after moving model to GPU!
        gen_opt_conf = Configuration.from_dict(conf.generator_optimizer, conf)
        gen_optimizer = get_optimizer(gen_opt_conf, gen_opt_conf.name,
                                      gen_model.parameters())
        # The LR scheduler is optional; it is only built if configured.
        gen_lr_scheduler = None
        if gen_opt_conf.has_attr('lr_scheduler'):
            gen_lr_scheduler = get_lr_scheduler(gen_opt_conf,
                                                gen_opt_conf.lr_scheduler,
                                                gen_optimizer)

        disc_opt_conf = Configuration.from_dict(conf.discriminator_optimizer,
                                                conf)
        disc_optimizer = get_optimizer(disc_opt_conf, disc_opt_conf.name,
                                       disc_model.parameters())
        disc_lr_scheduler = None
        if disc_opt_conf.has_attr('lr_scheduler'):
            disc_lr_scheduler = get_lr_scheduler(disc_opt_conf,
                                                 disc_opt_conf.lr_scheduler,
                                                 disc_optimizer)

        train_input_batch_transform = get_input_batch_transform(
            conf, conf.application, 'train')
        train_disc_metrics = conf.get_attr('train_discriminator_metrics',
                                           default=[])
        train_disc_metric_fns = {
            name: get_metric_fn(conf, name, cuda, 'train')
            for name in train_disc_metrics
        }
        val_disc_metric_key = 'validation_discriminator_metrics'
        val_disc_metric_fns = {
            name: get_metric_fn(conf, name, cuda, 'test')
            for name in conf.get_attr(val_disc_metric_key, default=[])
        }

        train_gen_metrics = conf.get_attr('train_generator_metrics',
                                          default=[])
        train_gen_metric_fns = {
            name: get_metric_fn(conf, name, cuda, 'train')
            for name in train_gen_metrics
        }

        # Training and validation discriminator inputs differ only in
        # no_pool=True for the validation variant.
        disc_input_fn = get_discriminator_input_fn(conf, disc_model_conf)
        val_disc_input_fn = get_discriminator_input_fn(conf,
                                                       disc_model_conf,
                                                       no_pool=True)

        # Read without a default; presumably None when unset -- TODO confirm
        # against Configuration.get_attr.
        pretr_generator_epochs = conf.get_attr('pretrain_generator_epochs')
        pretr_discriminator_epochs = conf.get_attr(
            'pretrain_discriminator_epochs')

        # Arguments are positional: their order must match AdversarialRunner's
        # signature (not visible here).
        runner = AdversarialRunner(
            gen_model, disc_model, gen_optimizer, disc_optimizer,
            gen_lr_scheduler, disc_lr_scheduler, gen_adv_criteria,
            gen_criteria, disc_adv_criteria,
            conf.get_attr('generator_loss_weights', {}),
            conf.get_attr('discriminator_loss_weights', {}), cuda,
            train_gen_metric_fns, train_disc_metric_fns, val_metric_fns,
            val_disc_metric_fns, output_transform, train_input_batch_transform,
            test_input_batch_transform,
            gen_opt_conf.get_attr('updates_per_step', 1),
            disc_opt_conf.get_attr('updates_per_step',
                                   1), disc_input_fn, val_disc_input_fn,
            pretr_generator_epochs, pretr_discriminator_epochs)
        # No `resume` check here: pretrained weights are loaded whenever the
        # model config declares them.
        if gen_model_conf.has_attr('pretrained_weights'):
            initialize_pretrained_model(gen_model_conf, runner.gen, cuda,
                                        conf.file)

        if disc_model_conf.has_attr('pretrained_weights'):
            initialize_pretrained_model(disc_model_conf, runner.disc, cuda,
                                        conf.file)
    else:
        # NOTE(review): unlike the training branch, the return value of
        # utils.cudaify is discarded here and `cuda` is not passed. If
        # cudaify can return a different (e.g. split/wrapped) model, the
        # runner receives the original object -- confirm this is intended.
        if cuda != '':
            utils.cudaify(gen_model)
        runner = AdversarialRunner(
            gen_model,
            cuda=cuda,
            val_metric_fns=val_metric_fns,
            output_transform=output_transform,
            test_input_batch_transform=test_input_batch_transform)

    return runner