def recovery_likelihood_training(energy, data,
                                 optimizer_kwargs=None,
                                 level_weight=1.0,
                                 loss_scale=1.0,
                                 level_distribution=NormalNoise(
                                   linear_noise(1e-3, 10.0)),
                                 noise_distribution=NormalNoise(1e-3),
                                 ema_weight=0.999,
                                 **kwargs):
  opt = filter_kwargs(kwargs, ctx=base_sm_training)
  ctx = base_sm_training(energy, data, **opt.ctx)
  ctx.add(recovery_likelihood_step=UpdateStep(
    partial(
      recovery_likelihood_step,
      ctx.energy, ctx.data,
      level_weight=level_weight,
      loss_scale=loss_scale,
      level_distribution=level_distribution,
      noise_distribution=noise_distribution),
    Update([ctx.energy], optimizer=ctx.optimizer,
           **(optimizer_kwargs or {})),
    ctx=ctx))
  # keep ctx.energy_target as an exponential moving average of ctx.energy
  ctx.add(ema_step=partial(_ema_step, ema_weight=ema_weight, ctx=ctx))
  return ctx
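# A minimal usage sketch (illustrative, not part of the module): builds a toy
# energy network and wires it into `recovery_likelihood_training`. The
# two-argument energy signature (sample batch plus noise level) is an
# assumption suggested by the level/noise distribution arguments above, as is
# the use of a DataLoader for `data`.
def _example_recovery_likelihood_training():
  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, TensorDataset

  class ToyEnergy(nn.Module):
    def __init__(self, size=16):
      super().__init__()
      self.blocks = nn.Sequential(
        nn.Linear(size + 1, 128), nn.SiLU(), nn.Linear(128, 1))

    def forward(self, data, level):
      # concatenate the noise level to each sample before scoring it
      return self.blocks(torch.cat([data, level], dim=1))

  loader = DataLoader(TensorDataset(torch.randn(256, 16)), batch_size=32)
  return recovery_likelihood_training(
    ToyEnergy(), loader,
    level_weight=1.0, ema_weight=0.999,
    optimizer_kwargs={"lr": 1e-4})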
def conditional_mle_training(model, data, valid_data=None,
                             optimizer=torch.optim.Adam,
                             optimizer_kwargs=None,
                             eval_no_grad=True,
                             **kwargs):
  opt = filter_kwargs(kwargs, ctx=TrainingContext)
  ctx = TrainingContext(**opt.ctx)
  ctx.optimizer = optimizer

  # networks to device
  ctx.register(data=to_device(data, ctx.device),
               model=to_device(model, ctx.device))
  ctx.add(train_step=UpdateStep(
    partial(maximum_likelihood_step, ctx.model, ctx.data),
    Update([ctx.model], optimizer=ctx.optimizer,
           **(optimizer_kwargs or {})),
    ctx=ctx))
  if valid_data is not None:
    # evaluate on held-out data once every report interval
    ctx.register(valid_data=to_device(valid_data, ctx.device))
    ctx.add(valid_step=EvalStep(
      partial(maximum_likelihood_step, ctx.model, ctx.valid_data),
      modules=[ctx.model],
      no_grad=eval_no_grad,
      ctx=ctx), every=ctx.report_interval)
  return ctx
def supervised_training(net, data, valid_data=None,
                        losses=None,
                        optimizer=torch.optim.Adam,
                        optimizer_kwargs=None,
                        eval_no_grad=True,
                        **kwargs):
  opt = filter_kwargs(kwargs, ctx=TrainingContext)
  ctx = TrainingContext(**opt.ctx)
  ctx.optimizer = optimizer
  ctx.losses = losses

  # networks to device
  ctx.register(data=to_device(data, ctx.device),
               net=to_device(net, ctx.device))
  ctx.add(train_step=UpdateStep(
    partial(supervised_step, ctx.net, ctx.data, losses=ctx.losses),
    Update([ctx.net], optimizer=ctx.optimizer,
           **(optimizer_kwargs or {})),
    ctx=ctx))
  if valid_data is not None:
    ctx.register(valid_data=to_device(valid_data, ctx.device))
    ctx.add(valid_step=EvalStep(
      partial(supervised_step, ctx.net, ctx.valid_data, losses=ctx.losses),
      modules=[ctx.net],
      no_grad=eval_no_grad,
      ctx=ctx), every=ctx.report_interval)
  return ctx
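# A minimal usage sketch (illustrative, not part of the module): a toy
# classifier trained with `supervised_training`. That `losses` pairs one loss
# per network output and that `data` may be a DataLoader of (input, target)
# batches are assumptions about `supervised_step`, not guarantees.
def _example_supervised_training():
  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, TensorDataset

  net = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))
  inputs = torch.randn(256, 16)
  targets = torch.randint(0, 10, (256,))
  loader = DataLoader(TensorDataset(inputs, targets),
                      batch_size=32, shuffle=True)
  valid = DataLoader(TensorDataset(inputs[:64], targets[:64]), batch_size=32)
  return supervised_training(
    net, loader, valid_data=valid,
    losses=[nn.CrossEntropyLoss()],
    optimizer_kwargs={"lr": 1e-3})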
def base_sm_training(energy, data, optimizer=torch.optim.Adam, **kwargs):
  opt = filter_kwargs(kwargs, ctx=TrainingContext)
  ctx = TrainingContext(**opt.ctx)
  ctx.optimizer = optimizer

  # networks to device; energy_target is a separate copy of the energy,
  # kept in sync by the EMA step (see recovery_likelihood_training)
  energy_target = deepcopy(energy)
  ctx.register(data=to_device(data, ctx.device),
               energy=to_device(energy, ctx.device),
               energy_target=to_device(energy_target, ctx.device))
  return ctx
def density_ratio_training(energy, base, data,
                           optimizer_kwargs=None,
                           density_ratio_step=density_ratio_step,
                           **kwargs):
  opt = filter_kwargs(kwargs, ctx=base_dre_training)
  ctx = base_dre_training(energy, base, data, **opt.ctx)
  ctx.add(dre_step=UpdateStep(
    partial(density_ratio_step, ctx.energy, ctx.base, ctx.data),
    Update([ctx.energy], optimizer=ctx.optimizer,
           **(optimizer_kwargs or {})),
    ctx=ctx))
  return ctx
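# A minimal usage sketch (illustrative, not part of the module): density-ratio
# estimation between `data` and a trainable `base` model. Since the base is
# fitted by `maximum_likelihood_step`, it presumably exposes a log-likelihood;
# `ToyGaussianBase` below is a hypothetical stand-in for such a model, and the
# exact interface expected by `maximum_likelihood_step` is an assumption.
def _example_density_ratio_training():
  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, TensorDataset

  class ToyGaussianBase(nn.Module):
    # hypothetical learnable diagonal-Gaussian base density
    def __init__(self, size=16):
      super().__init__()
      self.mean = nn.Parameter(torch.zeros(size))
      self.log_std = nn.Parameter(torch.zeros(size))

    def log_prob(self, x):
      dist = torch.distributions.Normal(self.mean, self.log_std.exp())
      return dist.log_prob(x).sum(dim=-1)

  energy = nn.Sequential(nn.Linear(16, 128), nn.SiLU(), nn.Linear(128, 1))
  loader = DataLoader(TensorDataset(torch.randn(256, 16)), batch_size=32)
  return density_ratio_training(
    energy, ToyGaussianBase(), loader,
    optimizer_kwargs={"lr": 1e-4})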
def telescoping_density_ratio_training(energy, base, data,
                                       mixing=None,
                                       optimizer_kwargs=None,
                                       telescoping_step=tdre_step,
                                       verbose=True,
                                       **kwargs):
  opt = filter_kwargs(kwargs, ctx=base_dre_training)
  ctx = base_dre_training(energy, base, data, **opt.ctx)
  ctx.add(tdre_step=UpdateStep(
    partial(telescoping_step, ctx.energy, ctx.base, ctx.data,
            mixing=mixing, verbose=verbose),
    Update([ctx.energy], optimizer=ctx.optimizer,
           **(optimizer_kwargs or {})),
    ctx=ctx))
  return ctx
def base_dre_training(energy, base, data,
                      train_base=True,
                      base_step=maximum_likelihood_step,
                      optimizer=torch.optim.Adam,
                      base_optimizer_kwargs=None,
                      **kwargs):
  opt = filter_kwargs(kwargs, ctx=TrainingContext)
  ctx = TrainingContext(**opt.ctx)
  ctx.optimizer = optimizer

  # networks to device
  ctx.register(data=to_device(data, ctx.device),
               base=to_device(base, ctx.device),
               energy=to_device(energy, ctx.device))
  if train_base:
    ctx.add(base_step=UpdateStep(
      partial(base_step, ctx.base, ctx.data),
      Update([ctx.base], optimizer=ctx.optimizer,
             **(base_optimizer_kwargs or {})),
      ctx=ctx))
  return ctx