コード例 #1
0
ファイル: trainer_test.py プロジェクト: abhaikollara/flare
 def test_predict(self):
     """Smoke-test ``Trainer.predict`` on the training inputs.

     No assertion is made on the returned predictions; the test passes as
     long as building the trainer and calling ``predict`` does not raise.
     """
     net = SingleInputModel()
     loss_fn = nn.CrossEntropyLoss()
     optimizer = Adam(net.parameters())
     trainer = Trainer(net, loss_fn, optimizer, metrics=[Accuracy()])
     trainer.predict(train_x)
コード例 #2
0
ファイル: trainer_test.py プロジェクト: abhaikollara/flare
 def test_evaluate(self):
     """Smoke-test ``Trainer.evaluate`` on the training inputs and labels.

     No assertion is made on the returned value; the test passes as long
     as building the trainer and calling ``evaluate`` does not raise.
     """
     net = SingleInputModel()
     loss_fn = nn.CrossEntropyLoss()
     optimizer = Adam(net.parameters())
     trainer = Trainer(net, loss_fn, optimizer, metrics=[Accuracy()])
     trainer.evaluate(train_x, train_y)
コード例 #3
0
ファイル: trainer_test.py プロジェクト: abhaikollara/flare
 def test_train(self):
     """Train for two epochs and check the logged loss decreases.

     Reads the per-epoch ``'Loss'`` series from the returned history's
     ``train_logs`` and requires epoch 0 loss to exceed epoch 1 loss.
     """
     net = SingleInputModel()
     loss_fn = nn.CrossEntropyLoss()
     optimizer = Adam(net.parameters())
     trainer = Trainer(net, loss_fn, optimizer, metrics=[Accuracy()])
     losses = trainer.train(train_x, train_y, epochs=2).train_logs['Loss']
     first_epoch, second_epoch = losses[0], losses[1]
     assert first_epoch > second_epoch
コード例 #4
0
ファイル: trainer_test.py プロジェクト: abhaikollara/flare
    def test_train_with_validation_split(self):
        """Train two epochs with a 20% validation split; both losses must drop.

        Checks the ``'Loss'`` series in both ``train_logs`` and
        ``test_logs`` of the returned history: epoch 0 loss must exceed
        epoch 1 loss in each.
        """
        net = SingleInputModel()
        loss_fn = nn.CrossEntropyLoss()
        optimizer = Adam(net.parameters())
        trainer = Trainer(net, loss_fn, optimizer, metrics=[Accuracy()])
        history = trainer.train(
            train_x,
            train_y,
            epochs=2,
            validation_split=0.2,
        )
        train_loss = history.train_logs['Loss']
        val_loss = history.test_logs['Loss']

        assert train_loss[0] > train_loss[1]
        assert val_loss[0] > val_loss[1]
コード例 #5
0
ファイル: trainer_test.py プロジェクト: abhaikollara/flare
 def test_train_generator(self):
     """Build a Trainer for the generator-based training test.

     NOTE(review): as written this test only constructs the model, loss,
     optimizer and Trainer; it never calls a training method, so it only
     exercises construction. The body looks incomplete — confirm against
     the upstream test.
     """
     net = SingleInputModel()
     loss_fn = nn.CrossEntropyLoss()
     optimizer = Adam(net.parameters())
     trainer = Trainer(net, loss_fn, optimizer, metrics=[Accuracy()])
コード例 #6
0
ファイル: mnist_flare.py プロジェクト: abhaikollara/flare
        return F.log_softmax(x, dim=1)


def _load_split(path, x_key, y_key):
    """Load one MNIST split from an ``.npz`` archive and preprocess it.

    Args:
        path: Path to the ``.npz`` file. It is passed through
            ``os.path.abspath``, so relative paths resolve against the
            current working directory.
        x_key: Name of the archive member holding the inputs.
        y_key: Name of the archive member holding the labels.

    Returns:
        ``(x, y)``. ``x`` has a new axis inserted at position 1, is cast to
        ``float32``, and is scaled by ``1 / 255``. ``y`` is flattened to 1-D.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open.
    # Using it as a context manager releases the handle. Indexing the
    # NpzFile reads each member fully into memory, so the arrays stay valid
    # after the archive is closed.
    with np.load(os.path.abspath(path)) as archive:
        x, y = archive[x_key], archive[y_key]
    # Insert a singleton axis at position 1 (assumed to be the channel axis
    # Net expects — confirm against Net.forward).
    x = np.expand_dims(x, 1).astype('float32')
    x /= 255.0
    return x, y.reshape(-1)


train_x, train_y = _load_split("./data/mnist_train.npz", 'X_train', 'Y_train')
test_x, test_y = _load_split("./data/mnist_test.npz", 'X_test', 'Y_test')

model = Net()
trainer = Trainer(model,
                  F.nll_loss,
                  Adam(model.parameters()),
                  metrics=[Accuracy()])
# Hold out 20% of the training data for validation; its per-epoch metrics
# are printed from history.test_logs below.
history = trainer.train(train_x,
                        train_y,
                        batch_size=64,
                        epochs=2,
                        validation_split=0.2)
print(history.train_logs)
print(history.test_logs)