def valid_callback(args, ctx: TrainingContext = None):
    """Log the validation batch and group its images by predicted class.

    Args:
        args: Validation step results; expects ``args.sample`` (a batch of
            images) and ``args.prediction`` (per-class scores reduced with
            ``argmax(dim=1)``). Exact tensor shapes are assumed from usage —
            TODO confirm against the caller.
        ctx: Training context providing ``log``. If ``None`` (the declared
            default), the callback is a no-op.
    """
    # Fix: the parameter default is None, but the original body dereferenced
    # ctx unconditionally and raised AttributeError when the default applied.
    if ctx is None:
        return
    ctx.log(images=LogImage(args.sample))
    labels = args.prediction.argmax(dim=1)
    # Assumes a 10-class problem (e.g. digits) — TODO confirm.
    for idx in range(10):
        positive = args.sample[labels == idx]
        if positive.size(0) != 0:
            ctx.log(**{f"classified {idx}": LogImage(positive)})
def valid_callback(args, ctx: TrainingContext = None):
    """Log the conditioning batch and group its images by the class the
    model's output distribution assigns.

    Args:
        args: Validation step results; expects ``args.condition`` (a batch of
            images) and ``args.distribution`` with a ``logits`` tensor reduced
            with ``argmax(dim=1)``. Exact shapes are assumed from usage —
            TODO confirm against the caller.
        ctx: Training context providing ``log``. If ``None`` (the declared
            default), the callback is a no-op.
    """
    # Fix: the parameter default is None, but the original body dereferenced
    # ctx unconditionally and raised AttributeError when the default applied.
    if ctx is None:
        return
    ctx.log(images=LogImage(args.condition))
    labels = args.distribution.logits.argmax(dim=1)
    # Assumes a 10-class problem (e.g. digits) — TODO confirm.
    for idx in range(10):
        positive = args.condition[labels == idx]
        if positive.size(0) != 0:
            ctx.log(**{f"classified {idx}": LogImage(positive)})
def base_sm_training(energy, data, optimizer=torch.optim.Adam, **kwargs):
    """Build a training context for score-matching-style training.

    Registers ``data``, the ``energy`` network, and an independent deep copy
    of it (``energy_target``) on the context's device, and records the
    optimizer class on the context.

    Args:
        energy: Energy network to train.
        data: Data source to register on the context.
        optimizer: Optimizer class stored as ``ctx.optimizer``.
        **kwargs: Forwarded to ``TrainingContext`` after ``filter_kwargs``.

    Returns:
        The configured ``TrainingContext``.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    training = TrainingContext(**options.ctx)
    training.optimizer = optimizer
    # The target network starts out as a detached copy of the energy model.
    target = deepcopy(energy)
    device = training.device
    training.register(
        data=to_device(data, device),
        energy=to_device(energy, device),
        energy_target=to_device(target, device),
    )
    return training
def SupervisedTraining(net, data, valid_data, losses=None, **kwargs):
    """Assemble a supervised training context with train and valid steps.

    Wraps ``data``/``valid_data`` in ``DataDistribution`` loaders, registers
    the network for checkpointing, and installs a training step (with an Adam
    update) plus a validation step on the context's loop.

    Args:
        net: Network to train; moved to the context's device.
        data: Training data source.
        valid_data: Validation data source.
        losses: Loss specification forwarded to ``canned_supervised``.
        **kwargs: Must contain ``network_name``; forwarded to
            ``TrainingContext``.

    Returns:
        The configured ``TrainingContext``.
    """
    ctx = TrainingContext(kwargs["network_name"], **kwargs)
    net = to_device(net, ctx.device)
    # Both loaders share the context's batching configuration.
    loader_opts = dict(batch_size=ctx.batch_size, num_workers=ctx.num_workers)
    train_source = DataDistribution(data, **loader_opts)
    valid_source = DataDistribution(valid_data, **loader_opts)
    ctx.checkpoint.add_checkpoint(net=net)
    # Build steps in the original order: net.train() first, net.eval() second.
    train_step = UpdateStep(
        canned_supervised(ctx, net.train(), train_source, losses),
        Update(net, optimizer=torch.optim.Adam),
    )
    valid_step = Step(
        canned_supervised(ctx, net.eval(), valid_source, losses)
    )
    ctx.loop.add(train=train_step).add(valid=valid_step)
    return ctx
def base_dre_training(energy, base, data, train_base=True, base_step=maximum_likelihood_step, optimizer=torch.optim.Adam, base_optimizer_kwargs=None, **kwargs):
    """Build a training context for density-ratio estimation.

    Registers ``data``, the ``base`` model, and the ``energy`` network on the
    context's device. When ``train_base`` is true, additionally installs an
    update step that trains the base model with ``base_step``.

    Args:
        energy: Energy network; registered on the context.
        base: Base density model; registered and optionally trained.
        data: Data source to register.
        train_base: Whether to add a base-model update step.
        base_step: Step function applied to ``(ctx.base, ctx.data)``.
        optimizer: Optimizer class stored as ``ctx.optimizer``.
        base_optimizer_kwargs: Extra kwargs for the base model's ``Update``.
        **kwargs: Forwarded to ``TrainingContext`` after ``filter_kwargs``.

    Returns:
        The configured ``TrainingContext``.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    device = ctx.device
    # Move networks and data to the context's device.
    ctx.register(
        data=to_device(data, device),
        base=to_device(base, device),
        energy=to_device(energy, device),
    )
    if not train_base:
        return ctx
    base_update = Update(
        [ctx.base],
        optimizer=ctx.optimizer,
        **(base_optimizer_kwargs or {}),
    )
    step = UpdateStep(partial(base_step, ctx.base, ctx.data), base_update, ctx=ctx)
    ctx.add(base_step=step)
    return ctx
def conditional_mle_training(model, data, valid_data=None, optimizer=torch.optim.Adam, optimizer_kwargs=None, eval_no_grad=True, **kwargs):
    """Build a training context for maximum-likelihood training of a model.

    Registers ``model`` and ``data`` on the context's device and installs a
    maximum-likelihood update step. If ``valid_data`` is given, also installs
    an evaluation step run every ``ctx.report_interval`` iterations.

    Args:
        model: Model to train; registered on the context.
        data: Training data source.
        valid_data: Optional validation data source.
        optimizer: Optimizer class stored as ``ctx.optimizer``.
        optimizer_kwargs: Extra kwargs forwarded to ``Update``.
        eval_no_grad: Whether the evaluation step runs without gradients.
        **kwargs: Forwarded to ``TrainingContext`` after ``filter_kwargs``.

    Returns:
        The configured ``TrainingContext``.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    device = ctx.device
    # Move networks and data to the context's device.
    ctx.register(data=to_device(data, device), model=to_device(model, device))
    model_update = Update(
        [ctx.model],
        optimizer=ctx.optimizer,
        **(optimizer_kwargs or {}),
    )
    ctx.add(
        train_step=UpdateStep(
            partial(maximum_likelihood_step, ctx.model, ctx.data),
            model_update,
            ctx=ctx,
        )
    )
    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, device))
        eval_step = EvalStep(
            partial(maximum_likelihood_step, ctx.model, ctx.valid_data),
            modules=[ctx.model],
            no_grad=eval_no_grad,
            ctx=ctx,
        )
        ctx.add(valid_step=eval_step, every=ctx.report_interval)
    return ctx
def supervised_training(net, data, valid_data=None, losses=None, optimizer=torch.optim.Adam, optimizer_kwargs=None, eval_no_grad=True, **kwargs):
    """Build a training context for supervised training of a network.

    Registers ``net`` and ``data`` on the context's device, records the
    optimizer class and losses, and installs a supervised update step. If
    ``valid_data`` is given, also installs an evaluation step run every
    ``ctx.report_interval`` iterations.

    Args:
        net: Network to train; registered on the context.
        data: Training data source.
        valid_data: Optional validation data source.
        losses: Loss specification stored as ``ctx.losses`` and forwarded to
            ``supervised_step``.
        optimizer: Optimizer class stored as ``ctx.optimizer``.
        optimizer_kwargs: Extra kwargs forwarded to ``Update``.
        eval_no_grad: Whether the evaluation step runs without gradients.
        **kwargs: Forwarded to ``TrainingContext`` after ``filter_kwargs``.

    Returns:
        The configured ``TrainingContext``.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    ctx.losses = losses
    device = ctx.device
    # Move networks and data to the context's device.
    ctx.register(data=to_device(data, device), net=to_device(net, device))
    net_update = Update(
        [ctx.net],
        optimizer=ctx.optimizer,
        **(optimizer_kwargs or {}),
    )
    ctx.add(
        train_step=UpdateStep(
            partial(supervised_step, ctx.net, ctx.data, losses=ctx.losses),
            net_update,
            ctx=ctx,
        )
    )
    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, device))
        eval_step = EvalStep(
            partial(supervised_step, ctx.net, ctx.valid_data, losses=ctx.losses),
            modules=[ctx.net],
            no_grad=eval_no_grad,
            ctx=ctx,
        )
        ctx.add(valid_step=eval_step, every=ctx.report_interval)
    return ctx