Beispiel #1
0
def recovery_likelihood_training(energy,
                                 data,
                                 optimizer_kwargs=None,
                                 level_weight=1.0,
                                 loss_scale=1.0,
                                 level_distribution=None,
                                 noise_distribution=None,
                                 ema_weight=0.999,
                                 **kwargs):
    """Set up a recovery-likelihood training context for an energy model.

    Builds on ``base_sm_training`` (which registers ``energy``, ``data`` and
    an ``energy_target`` copy on the context device), then adds the
    recovery-likelihood update step for the energy network plus an EMA step.

    Args:
        energy: energy network to train.
        data: training data source.
        optimizer_kwargs: extra keyword arguments forwarded to ``Update``.
        level_weight: weight applied per noise level by the step function.
        loss_scale: global scaling of the loss.
        level_distribution: distribution of noise levels; defaults to
            ``NormalNoise(linear_noise(1e-3, 10.0))``.
        noise_distribution: per-sample noise distribution; defaults to
            ``NormalNoise(1e-3)``.
        ema_weight: exponential-moving-average decay for ``_ema_step``.
        **kwargs: filtered and forwarded to ``base_sm_training``.

    Returns:
        The configured training context.
    """
    # Fix for mutable default arguments: the original defaults were
    # NormalNoise instances created once at import time and shared across
    # every call. Construct fresh instances per invocation instead.
    if level_distribution is None:
        level_distribution = NormalNoise(linear_noise(1e-3, 10.0))
    if noise_distribution is None:
        noise_distribution = NormalNoise(1e-3)

    opt = filter_kwargs(kwargs, ctx=base_sm_training)
    ctx = base_sm_training(energy, data, **opt.ctx)

    ctx.add(recovery_likelihood_step=UpdateStep(
        partial(
            recovery_likelihood_step,
            ctx.energy,
            ctx.data,
            level_weight=level_weight,
            loss_scale=loss_scale,
            level_distribution=level_distribution,
            noise_distribution=noise_distribution,
        ),
        Update(
            [ctx.energy], optimizer=ctx.optimizer, **(optimizer_kwargs or {})),
        ctx=ctx))
    # Maintain the EMA target copy registered by base_sm_training.
    ctx.add(ema_step=partial(_ema_step, ema_weight=ema_weight, ctx=ctx))

    return ctx
Beispiel #2
0
def conditional_mle_training(model,
                             data,
                             valid_data=None,
                             optimizer=torch.optim.Adam,
                             optimizer_kwargs=None,
                             eval_no_grad=True,
                             **kwargs):
    """Configure a maximum-likelihood training loop for a conditional model.

    Registers ``model`` and ``data`` on the context device, installs an
    optimizing train step, and — when ``valid_data`` is given — a periodic
    evaluation step run every ``ctx.report_interval`` iterations.

    Returns:
        The configured training context.
    """
    filtered = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**filtered.ctx)
    ctx.optimizer = optimizer

    # move data and model to the configured device
    ctx.register(data=to_device(data, ctx.device),
                 model=to_device(model, ctx.device))

    train_fn = partial(maximum_likelihood_step, ctx.model, ctx.data)
    model_update = Update([ctx.model],
                          optimizer=ctx.optimizer,
                          **(optimizer_kwargs or {}))
    ctx.add(train_step=UpdateStep(train_fn, model_update, ctx=ctx))

    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        eval_fn = partial(maximum_likelihood_step, ctx.model, ctx.valid_data)
        ctx.add(valid_step=EvalStep(eval_fn,
                                    modules=[ctx.model],
                                    no_grad=eval_no_grad,
                                    ctx=ctx),
                every=ctx.report_interval)
    return ctx
Beispiel #3
0
def supervised_training(net,
                        data,
                        valid_data=None,
                        losses=None,
                        optimizer=torch.optim.Adam,
                        optimizer_kwargs=None,
                        eval_no_grad=True,
                        **kwargs):
    """Configure a supervised training loop for ``net`` on ``data``.

    Registers the network and data on the context device, installs a
    supervised update step driven by ``losses``, and — when ``valid_data``
    is given — a periodic evaluation step run every
    ``ctx.report_interval`` iterations.

    Returns:
        The configured training context.
    """
    filtered = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**filtered.ctx)
    ctx.optimizer = optimizer
    ctx.losses = losses

    # move data and network to the configured device
    ctx.register(data=to_device(data, ctx.device),
                 net=to_device(net, ctx.device))

    train_fn = partial(supervised_step, ctx.net, ctx.data, losses=ctx.losses)
    net_update = Update([ctx.net],
                        optimizer=ctx.optimizer,
                        **(optimizer_kwargs or {}))
    ctx.add(train_step=UpdateStep(train_fn, net_update, ctx=ctx))

    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        eval_fn = partial(supervised_step,
                          ctx.net,
                          ctx.valid_data,
                          losses=ctx.losses)
        ctx.add(valid_step=EvalStep(eval_fn,
                                    modules=[ctx.net],
                                    no_grad=eval_no_grad,
                                    ctx=ctx),
                every=ctx.report_interval)
    return ctx
Beispiel #4
0
def base_sm_training(energy, data, optimizer=torch.optim.Adam, **kwargs):
    """Create the shared training context for score-matching style setups.

    Registers ``data``, ``energy``, and a deep copy of the energy network
    (``energy_target``, presumably used as an EMA/target network by
    callers — confirm against the step functions) on the context device.
    No training steps are added here.

    Returns:
        The configured training context.
    """
    filtered = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**filtered.ctx)
    ctx.optimizer = optimizer

    # independent copy of the energy network, registered alongside it
    target_copy = deepcopy(energy)
    ctx.register(data=to_device(data, ctx.device),
                 energy=to_device(energy, ctx.device),
                 energy_target=to_device(target_copy, ctx.device))

    return ctx
Beispiel #5
0
def density_ratio_training(energy,
                           base,
                           data,
                           optimizer_kwargs=None,
                           density_ratio_step=density_ratio_step,
                           **kwargs):
    """Configure density-ratio-estimation training of ``energy``.

    Delegates context construction (device placement of ``energy``,
    ``base``, ``data`` and optional base training) to ``base_dre_training``
    and adds the density-ratio update step for the energy network.

    Returns:
        The configured training context.
    """
    filtered = filter_kwargs(kwargs, ctx=base_dre_training)
    ctx = base_dre_training(energy, base, data, **filtered.ctx)

    # density_ratio_step is a parameter (default: the module-level step),
    # so callers can substitute their own loss function.
    loss_fn = partial(density_ratio_step, ctx.energy, ctx.base, ctx.data)
    energy_update = Update([ctx.energy],
                           optimizer=ctx.optimizer,
                           **(optimizer_kwargs or {}))
    ctx.add(dre_step=UpdateStep(loss_fn, energy_update, ctx=ctx))

    return ctx
Beispiel #6
0
def telescoping_density_ratio_training(energy,
                                       base,
                                       data,
                                       mixing=None,
                                       optimizer_kwargs=None,
                                       telescoping_step=tdre_step,
                                       verbose=True,
                                       **kwargs):
    """Configure telescoping density-ratio-estimation (TDRE) training.

    Delegates context construction to ``base_dre_training`` and adds the
    telescoping update step for the energy network.

    Args:
        mixing: forwarded to the step function; semantics defined by
            ``telescoping_step`` (default ``tdre_step``).
        verbose: forwarded to the step function.

    Returns:
        The configured training context.
    """
    filtered = filter_kwargs(kwargs, ctx=base_dre_training)
    ctx = base_dre_training(energy, base, data, **filtered.ctx)

    loss_fn = partial(telescoping_step,
                      ctx.energy,
                      ctx.base,
                      ctx.data,
                      mixing=mixing,
                      verbose=verbose)
    energy_update = Update([ctx.energy],
                           optimizer=ctx.optimizer,
                           **(optimizer_kwargs or {}))
    ctx.add(tdre_step=UpdateStep(loss_fn, energy_update, ctx=ctx))

    return ctx
Beispiel #7
0
def base_dre_training(energy,
                      base,
                      data,
                      train_base=True,
                      base_step=maximum_likelihood_step,
                      optimizer=torch.optim.Adam,
                      base_optimizer_kwargs=None,
                      **kwargs):
    """Create the shared context for density-ratio-estimation training.

    Registers ``data``, ``base``, and ``energy`` on the context device.
    When ``train_base`` is true, also installs an update step that trains
    the base model via ``base_step`` (default: maximum likelihood).

    Returns:
        The configured training context.
    """
    filtered = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**filtered.ctx)
    ctx.optimizer = optimizer

    # move data and both networks to the configured device
    ctx.register(data=to_device(data, ctx.device),
                 base=to_device(base, ctx.device),
                 energy=to_device(energy, ctx.device))

    if train_base:
        base_fn = partial(base_step, ctx.base, ctx.data)
        base_update = Update([ctx.base],
                             optimizer=ctx.optimizer,
                             **(base_optimizer_kwargs or {}))
        ctx.add(base_step=UpdateStep(base_fn, base_update, ctx=ctx))
    return ctx