Example #1
def test_trainer():
    # MyTrainer is a Trainer subclass (see the sketch below and Example #2).
    trainer = MyTrainer(Params())

    # Save a keypoint checkpoint at epoch index 3.
    trainer.params.eidx = 3
    fn = trainer.save_keypoint()

    # Training runs to completion, advancing eidx to the final epoch.
    trainer.train()
    assert trainer.params.eidx == trainer.params.epoch

    # Restoring the checkpoint rolls the epoch index back to 3.
    trainer.load_checkpoint(fn)
    assert trainer.params.eidx == 3
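
The test assumes a MyTrainer subclass and a Params instance; a minimal sketch of those prerequisites, following the bare-subclass pattern from Example #5 (Example #2 shows a full implementation):

from thexp import Trainer, Params


class MyTrainer(Trainer):
    # A bare subclass is enough to construct a trainer; Example #2
    # shows a full implementation with a model, optimizer, and loss.
    pass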
Example #2
import torch.nn as nn
from torch.optim import SGD

from thexp import Trainer, Meter


class MyTrainer(Trainer):
    # The original snippet begins mid-class; this enclosing definition is
    # an assumption. MyModel is an nn.Module defined elsewhere (see Example #5).
    def models(self, params):
        self.model = MyModel()
        self.optim = SGD(self.model.parameters(), lr=params.lr)
        self.cross = nn.CrossEntropyLoss()

    def train_batch(self, eidx, idx, global_step, batch_data, params, device):
        optim, cross = self.optim, self.cross
        meter = Meter()
        xs, ys = batch_data

        # Training logic: forward pass and loss.
        logits = self.model(xs)
        meter.loss = cross(logits, ys)

        # Backpropagation.
        meter.loss.backward()
        optim.step()
        optim.zero_grad()

        return meter


# Hyperparameters are plain attribute assignments on a Params instance.
params = Params()
params.epoch = 5
params.lr = 0.1

params.dataset = 'svhn'

trainer = MyTrainer(params)
trainer.train()
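
As a variation on the loop above, here is a hedged sketch that records accuracy alongside the loss by assigning a second attribute on the Meter; the `meter.acc` name and the subclass are illustrative, not from the source:

from thexp import Meter


class MyTrainerWithAcc(MyTrainer):  # hypothetical subclass for illustration
    def train_batch(self, eidx, idx, global_step, batch_data, params, device):
        meter = Meter()
        xs, ys = batch_data

        logits = self.model(xs)
        meter.loss = self.cross(logits, ys)
        # Extra scalars are recorded the same way the loss is.
        meter.acc = (logits.argmax(dim=-1) == ys).float().mean()

        meter.loss.backward()
        self.optim.step()
        self.optim.zero_grad()

        return meter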
Example #3
from thexp import Params


def get_params():
    p = Params()
    # Presumably disables thexp's automatic git commit for this run.
    p.git_commit = False
    return p
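
A hedged usage sketch for the factory above: git_commit presumably toggles thexp's automatic git snapshot of the working tree at launch, so turning it off is convenient for quick local runs (MyTrainer as in Example #2):

params = get_params()
params.epoch = 1  # a quick smoke-test configuration (illustrative value)
trainer = MyTrainer(params)
trainer.train()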
Example #4
"""

"""

from thexp import Params

# Dynamic attributes: a bare Params starts empty and accepts any assignment.
params = Params()
print(params)

params.epoch = 400
params.batch_size = 25
print(params)


# Subclassing: defaults declared in __init__ become the parameter set.
class MyParams(Params):
    def __init__(self):
        super().__init__()
        self.batch_size = 50
        self.topk = (1, 2, 3, 4)
        self.optim = dict(lr=0.009, moment=0.9)


params = MyParams()
print(params)

# Constrained declarations: choice() limits a value to an allowed set,
# and arange() bounds a number, here defaulting to 5 within [0, 20].
params = Params()
params.choice("dataset", "mnist", "cifar10", "cifar100", "svhn")
params.arange("thresh", 5, 0, 20)
print(params)
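
Combining the declaration styles above, a subclass can make constrained hyperparameters part of its defaults; a sketch using only the choice() and arange() calls already shown (their exact semantics are inferred from the snippet):

from thexp import Params


class SearchParams(Params):
    def __init__(self):
        super().__init__()
        self.batch_size = 50
        # choice() restricts a value to an allowed set; the first entry
        # presumably serves as the default.
        self.choice("dataset", "mnist", "cifar10", "cifar100", "svhn")
        # arange() bounds a number: default 5 within [0, 20], by the
        # argument order used above.
        self.arange("thresh", 5, 0, 20)


params = SearchParams()
print(params)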
Example #5
"""
    For commercial projects that require the ability to distribute
    the code of this program as part of a program that cannot be
    distributed under the GNU General Public License, please contact

            [email protected]

    to purchase a commercial license.
"""

from thexp import Trainer, Params
import random


class MyTrainer(Trainer):
    pass


trainer = MyTrainer(Params())

# The built-in logger records each message.
for i in range(50):
    trainer.logger.info(i)

# `writter` (sic: thexp's attribute name) offers TensorBoard-style
# scalar logging: add_scalar(tag, value, global_step).
for i in range(20):
    trainer.writter.add_scalar("test", random.random(), i)

# Other built-in components, not exercised here:
# trainer.saver
# trainer.rnd


# ======================================

import torch.nn as nn


class MyModel(nn.Module):
    ...  # the source snippet is truncated here; the model body is elided
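
For completeness, a minimal MyModel that would satisfy the training loop in Example #2; the architecture is an assumption (the source truncates before the model body), sized for 32x32 RGB inputs such as SVHN:

import torch.nn as nn


class MyModel(nn.Module):
    # Hypothetical minimal classifier; the real architecture is not shown.
    def __init__(self, n_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3 * 32 * 32, 128),  # assumes 32x32 RGB inputs (e.g. SVHN)
            nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, xs):
        return self.net(xs)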