コード例 #1
0
    def setUp(self):
        """Prepare ``self.trainer`` for the tests in this class.

        A trainer is built around the model returned by
        ``self._make_test_model()``, given a Tensorboard logger writing under
        ``<ROOT_DIR>/logs``, configured to validate every epoch, checkpoint
        every two epochs into ``<ROOT_DIR>/saves`` and stop after two epochs.
        The CIFAR10 train/test loaders are then bound to it as the 'train'
        and 'validate' loaders.
        """
        model = self._make_test_model()
        root = self.ROOT_DIR

        # Assemble the trainer in one fluent chain; the object returned by the
        # last call in the chain is what gets stored on the test case.
        self.trainer = (
            Trainer(model)
            .build_logger(
                TensorboardLogger(send_image_at_batch_indices=0,
                                  send_image_at_channel_indices='all',
                                  log_images_every=(20, 'iterations')),
                log_directory=os.path.join(root, 'logs'),
            )
            .build_criterion('CrossEntropyLoss')
            .build_metric('CategoricalError')
            .build_optimizer('Adam')
            .validate_every((1, 'epochs'))
            .save_every((2, 'epochs'), to_directory=os.path.join(root, 'saves'))
            .save_at_best_validation_score()
            .set_max_num_epochs(2)
            .cuda()
            .set_precision(self.PRECISION)
        )

        # Fetch the CIFAR10 loaders (downloading only if the class asks for it).
        loaders = get_cifar10_loaders(root_directory=os.path.join(root, 'data'),
                                      download=self.DOWNLOAD_CIFAR)
        train_loader, test_loader = loaders

        # Attach the loaders; the second bind is issued on whatever the first returns.
        bound = self.trainer.bind_loader('train', train_loader)
        bound.bind_loader('validate', test_loader)
コード例 #2
0
ファイル: basic.py プロジェクト: BenJamesbabala/inferno
 def test_cifar(self):
     """Train the test model on CIFAR10 for two epochs and report wall time.

     The trainer uses Adam, cross-entropy loss and the categorical-error
     metric, validates and checkpoints every epoch (into
     ``<ROOT_DIR>/saves``). Depending on ``self.CUDA`` and
     ``self.HALF_PRECISION`` it is switched to CUDA and, optionally, to half
     precision before ``fit()`` is called.
     """
     from inferno.trainers.basic import Trainer
     from inferno.io.box.cifar10 import get_cifar10_loaders

     # Data first, then the model.
     loaders = get_cifar10_loaders(root_directory=join(self.ROOT_DIR, 'data'),
                                   download=self.DOWNLOAD_CIFAR)
     trainloader, testloader = loaders
     model = self._make_test_model()

     # The timer covers trainer construction as well as training.
     started = time.time()
     trainer = (
         Trainer(model=model)
         .build_optimizer('Adam')
         .build_criterion('CrossEntropyLoss')
         .build_metric('CategoricalError')
         .validate_every((1, 'epochs'))
         .save_every((1, 'epochs'), to_directory=join(self.ROOT_DIR, 'saves'))
         .save_at_best_validation_score()
         .set_max_num_epochs(2)
     )

     # Attach the datasets; the second bind is issued on the first one's result.
     bound = trainer.bind_loader('train', trainloader)
     bound.bind_loader('validate', testloader)

     # Pick device/precision, then fit on whatever object the last call returned.
     runner = trainer
     if self.CUDA:
         use_half = self.HALF_PRECISION
         runner = runner.cuda()
         if use_half:
             runner = runner.set_precision('half')
     runner.fit()

     elapsed = time.time() - started
     print("[*] Elapsed time: {} seconds.".format(elapsed))
コード例 #3
0
ファイル: basic.py プロジェクト: BenJamesbabala/inferno
    def test_multi_gpu(self):
        import torch
        if not torch.cuda.is_available():
            return

        from inferno.trainers.basic import Trainer
        from inferno.io.box.cifar10 import get_cifar10_loaders
        import os

        # Make model
        net = self._make_test_model()
        # Make trainer
        trainer = Trainer(model=net) \
            .build_optimizer('Adam') \
            .build_criterion('CrossEntropyLoss') \
            .build_metric('CategoricalError') \
            .validate_every((1, 'epochs')) \
            .save_every((1, 'epochs'), to_directory=os.path.join(self.ROOT_DIR, 'saves')) \
            .save_at_best_validation_score() \
            .set_max_num_epochs(2)\
            .cuda(devices=[0, 1, 2, 3])

        train_loader, validate_loader = get_cifar10_loaders(
            root_directory=self.ROOT_DIR, download=True)
        trainer.bind_loader('train', train_loader)
        trainer.bind_loader('validate', validate_loader)

        trainer.fit()
コード例 #4
0
 def test_serialization(self):
     """Save the fixture trainer, load a new trainer from disk and fit it.

     The fixture trainer is saved, a fresh ``Trainer`` is loaded from
     ``<ROOT_DIR>/saves``, the CIFAR10 loaders are bound to it and ``fit()``
     is called. Finally the fixture trainer's log directory is printed.
     """
     # Build the fixture on demand if setUp has not populated it yet.
     if not hasattr(self, 'trainer'):
         self.setUp()

     # Serialize
     self.trainer.save()

     # Unserialize into a brand-new trainer object
     blank = Trainer()
     trainer = blank.load(os.path.join(self.ROOT_DIR, 'saves'))

     # Loaders have to be bound again on the loaded trainer
     loaders = get_cifar10_loaders(root_directory=os.path.join(self.ROOT_DIR, 'data'),
                                   download=self.DOWNLOAD_CIFAR)
     train_loader, test_loader = loaders
     bound = trainer.bind_loader('train', train_loader)
     bound.bind_loader('validate', test_loader)

     trainer.fit()
     message = "Inspect logs at: {}".format(self.trainer.log_directory)
     trainer.print(message)
コード例 #5
0
# Example configuration.
# NOTE(review): SAVE_DIRECTORY and LOG_DIRECTORY are used below but are not
# defined in this excerpt — confirm they are set earlier in the file.
DATASET_DIRECTORY = 'data'
DOWNLOAD_CIFAR = True
USE_CUDA = False

# Build torch model: three ConvELU2D + max-pool stages, then a linear layer
# producing one score per CIFAR10 class.
model = nn.Sequential(
    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    Flatten(),
    nn.Linear(in_features=(256 * 4 * 4), out_features=10),
    # Normalize over the class (last) axis explicitly: calling nn.Softmax()
    # without ``dim`` relies on a deprecated implicit choice and emits a
    # warning in PyTorch.
    # NOTE(review): the trainer below is built with 'CrossEntropyLoss'; if that
    # maps to torch.nn.CrossEntropyLoss, it expects unnormalized logits, so
    # this Softmax layer is probably redundant — confirm before removing.
    nn.Softmax(dim=-1))

# Load loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                    download=DOWNLOAD_CIFAR)

# Build trainer: validate every 2 epochs, checkpoint every 5 epochs into
# SAVE_DIRECTORY, stop after 10 epochs, and log scalars (no images) to
# Tensorboard under LOG_DIRECTORY.
trainer = Trainer(model) \
  .build_criterion('CrossEntropyLoss') \
  .build_metric('CategoricalError') \
  .build_optimizer('Adam') \
  .validate_every((2, 'epochs')) \
  .save_every((5, 'epochs')) \
  .save_to_directory(SAVE_DIRECTORY) \
  .set_max_num_epochs(10) \
  .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                  log_images_every='never'),
                log_directory=LOG_DIRECTORY)

# Bind loaders
コード例 #6
0
# Model to train.
model = CNN_2_EDropout(BATCH_SIZE)

# Root directory for saved models. Created on demand, including any missing
# parent directories; an already existing directory is accepted, while an
# existing non-directory at this path still raises FileExistsError.
model_save_path = Path('./' + MODEL_SAVE_DIRECTORY)
model_save_path.mkdir(parents=True, exist_ok=True)

# create directory for this run, named after the current timestamp with the
# decimal point removed; mkdir() fails if that directory already exists.
run_directory_path = model_save_path / str(time.time()).replace('.', '')
run_directory_path.mkdir()

# logger directory
logger = Logger(LOGGER_SAVE_DIRECTORY)

# CIFAR10 loaders with separate batch sizes for training and validation.
train_loader, validate_loader = get_cifar10_loaders(
    DATASET_DIRECTORY,
    train_batch_size=BATCH_SIZE,
    test_batch_size=VALID_BATCH_SIZE,
    download=DOWNLOAD_CIFAR)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# SGD with momentum and cross-entropy loss.
# NOTE(review): ``Optimizer`` is capitalized like a class name; it is kept
# as-is because code outside this excerpt may refer to it by this name.
Optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

iterations = 0
for e in range(EPOCH):
    model.train()
    epoch_loss = 0
    e_start = timeit.default_timer()
    batch_idx = 0