Code example #1 (score: 0) — file: cifar_cnce.py, project: mjendrusch/torchsupport
def valid_callback(args, ctx: TrainingContext = None):
    """Log validation samples, then log them again grouped by predicted class.

    Args:
        args: step result carrying `sample` (an indexable batch of images)
            and `prediction` (per-class scores, argmax'd over dim 1).
        ctx: active TrainingContext used for logging.
    """
    samples = args.sample
    ctx.log(images=LogImage(samples))
    predicted = args.prediction.argmax(dim=1)
    # Assumes a 10-class problem (CIFAR-10) -- TODO confirm.
    for class_id in range(10):
        members = samples[predicted == class_id]
        if members.size(0):
            ctx.log(**{f"classified {class_id}": LogImage(members)})
Code example #2 (score: 0)
def valid_callback(args, ctx: TrainingContext = None):
    """Log conditioning images, then log them again grouped by predicted class.

    Args:
        args: step result carrying `condition` (an indexable batch of images)
            and `distribution` (whose `logits` are argmax'd over dim 1).
        ctx: active TrainingContext used for logging.
    """
    conditions = args.condition
    ctx.log(images=LogImage(conditions))
    predicted = args.distribution.logits.argmax(dim=1)
    # Assumes a 10-class problem -- TODO confirm.
    for class_id in range(10):
        members = conditions[predicted == class_id]
        if members.size(0):
            ctx.log(**{f"classified {class_id}": LogImage(members)})
Code example #3 (score: 0)
def base_sm_training(energy, data, optimizer=torch.optim.Adam, **kwargs):
    """Create a TrainingContext holding an energy model, a copy of it as a
    target network, and the data — all moved to the context's device.

    Args:
        energy: energy network to be trained.
        data: training data source.
        optimizer: optimizer class stored on the context (default Adam).
        **kwargs: forwarded to TrainingContext via filter_kwargs.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer

    # Copy the energy network BEFORE any device moves, then place
    # everything on the training device.
    target_network = deepcopy(energy)
    device = ctx.device
    ctx.register(data=to_device(data, device),
                 energy=to_device(energy, device),
                 energy_target=to_device(target_network, device))

    return ctx
Code example #4 (score: 0)
def SupervisedTraining(net, data, valid_data, losses=None, **kwargs):
    """Wire a supervised train/valid loop into a TrainingContext.

    Args:
        net: network to train (moved to the context device).
        data: training dataset, wrapped in a DataDistribution.
        valid_data: validation dataset, wrapped likewise.
        losses: loss specification handed to canned_supervised.
        **kwargs: must include "network_name"; forwarded to TrainingContext.

    Returns:
        The configured TrainingContext.
    """
    ctx = TrainingContext(kwargs["network_name"], **kwargs)
    net = to_device(net, ctx.device)

    # Both datasets are batched identically.
    def as_distribution(dataset):
        return DataDistribution(dataset,
                                batch_size=ctx.batch_size,
                                num_workers=ctx.num_workers)

    train_data = as_distribution(data)
    eval_data = as_distribution(valid_data)

    ctx.checkpoint.add_checkpoint(net=net)

    # Training step updates the net; validation step only evaluates it.
    train_step = UpdateStep(
        canned_supervised(ctx, net.train(), train_data, losses),
        Update(net, optimizer=torch.optim.Adam))
    valid_step = Step(
        canned_supervised(ctx, net.eval(), eval_data, losses))
    ctx.loop.add(train=train_step).add(valid=valid_step)
    return ctx
Code example #5 (score: 0)
def base_dre_training(energy,
                      base,
                      data,
                      train_base=True,
                      base_step=maximum_likelihood_step,
                      optimizer=torch.optim.Adam,
                      base_optimizer_kwargs=None,
                      **kwargs):
    """Create a TrainingContext for density-ratio-style training with an
    energy network and a base model, optionally training the base model.

    Args:
        energy: energy network (registered on the context device).
        base: base model (registered on the context device).
        data: training data source.
        train_base: when True, add an UpdateStep that trains `base`.
        base_step: step function for the base model update.
        optimizer: optimizer class stored on the context (default Adam).
        base_optimizer_kwargs: extra kwargs for the base model's Update.
        **kwargs: forwarded to TrainingContext via filter_kwargs.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer

    # Place data and both networks on the training device.
    device = ctx.device
    ctx.register(data=to_device(data, device),
                 base=to_device(base, device),
                 energy=to_device(energy, device))

    if train_base:
        objective = partial(base_step, ctx.base, ctx.data)
        updater = Update([ctx.base],
                         optimizer=ctx.optimizer,
                         **(base_optimizer_kwargs or {}))
        ctx.add(base_step=UpdateStep(objective, updater, ctx=ctx))
    return ctx
Code example #6 (score: 0)
def conditional_mle_training(model,
                             data,
                             valid_data=None,
                             optimizer=torch.optim.Adam,
                             optimizer_kwargs=None,
                             eval_no_grad=True,
                             **kwargs):
    """Assemble a TrainingContext that fits `model` by maximum likelihood,
    with an optional periodic validation step.

    Args:
        model: network to train (registered on the context device).
        data: training data source.
        valid_data: optional validation data; enables a valid_step that
            runs every `ctx.report_interval` iterations.
        optimizer: optimizer class stored on the context (default Adam).
        optimizer_kwargs: extra kwargs for the training Update.
        eval_no_grad: run the validation step without gradients.
        **kwargs: forwarded to TrainingContext via filter_kwargs.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer

    # Place data and model on the training device.
    ctx.register(data=to_device(data, ctx.device),
                 model=to_device(model, ctx.device))

    train_objective = partial(maximum_likelihood_step, ctx.model, ctx.data)
    train_update = Update([ctx.model],
                          optimizer=ctx.optimizer,
                          **(optimizer_kwargs or {}))
    ctx.add(train_step=UpdateStep(train_objective, train_update, ctx=ctx))

    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        valid_objective = partial(maximum_likelihood_step, ctx.model,
                                  ctx.valid_data)
        ctx.add(valid_step=EvalStep(valid_objective,
                                    modules=[ctx.model],
                                    no_grad=eval_no_grad,
                                    ctx=ctx),
                every=ctx.report_interval)
    return ctx
Code example #7 (score: 0) — file: supervised.py, project: mjendrusch/torchsupport
def supervised_training(net,
                        data,
                        valid_data=None,
                        losses=None,
                        optimizer=torch.optim.Adam,
                        optimizer_kwargs=None,
                        eval_no_grad=True,
                        **kwargs):
    """Assemble a TrainingContext for supervised training of `net`, with an
    optional periodic validation step.

    Args:
        net: network to train (registered on the context device).
        data: training data source.
        valid_data: optional validation data; enables a valid_step that
            runs every `ctx.report_interval` iterations.
        losses: loss specification, stored on the context and passed to
            supervised_step.
        optimizer: optimizer class stored on the context (default Adam).
        optimizer_kwargs: extra kwargs for the training Update.
        eval_no_grad: run the validation step without gradients.
        **kwargs: forwarded to TrainingContext via filter_kwargs.

    Returns:
        The configured TrainingContext.
    """
    options = filter_kwargs(kwargs, ctx=TrainingContext)
    ctx = TrainingContext(**options.ctx)
    ctx.optimizer = optimizer
    ctx.losses = losses

    # Place data and network on the training device.
    ctx.register(data=to_device(data, ctx.device),
                 net=to_device(net, ctx.device))

    train_objective = partial(supervised_step, ctx.net, ctx.data,
                              losses=ctx.losses)
    train_update = Update([ctx.net],
                          optimizer=ctx.optimizer,
                          **(optimizer_kwargs or {}))
    ctx.add(train_step=UpdateStep(train_objective, train_update, ctx=ctx))

    if valid_data is not None:
        ctx.register(valid_data=to_device(valid_data, ctx.device))
        valid_objective = partial(supervised_step, ctx.net, ctx.valid_data,
                                  losses=ctx.losses)
        ctx.add(valid_step=EvalStep(valid_objective,
                                    modules=[ctx.net],
                                    no_grad=eval_no_grad,
                                    ctx=ctx),
                every=ctx.report_interval)
    return ctx