Example #1
# loss_fn = torch.nn.BCELoss()  # alternative loss for binary targets
loss_fn = torch.nn.CrossEntropyLoss()
metrics = {'fps': training.BatchTimer(), 'acc': training.accuracy}

# Train
writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10

print('\n\nInitial')
print('-' * 10)
net.eval()
training.pass_epoch(net,
                    loss_fn,
                    val_loader,
                    batch_metrics=metrics,
                    show_running=True,
                    device=device,
                    writer=writer)

for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    net.train()
    training.pass_epoch(net,
                        loss_fn,
                        train_loader,
                        optimizer,
                        scheduler,
                        batch_metrics=metrics,
                        show_running=True,
                        device=device,
                        writer=writer)
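
The snippet above is cut off inside the training loop. As a minimal sketch (assuming the same net, loss_fn, val_loader, metrics, device, and writer as above, and mirroring Example #2 below), each epoch typically closes with a validation pass, with the writer closed after the loop:

    net.eval()
    training.pass_epoch(net,
                        loss_fn,
                        val_loader,
                        batch_metrics=metrics,
                        show_running=True,
                        device=device,
                        writer=writer)

writer.close()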
Example #2
    def train(self, save_model=True):
        batch_size = 32
        epochs = 100
        workers = 0 if os.name == 'nt' else 8

        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        scheduler = MultiStepLR(optimizer, [5, 10])

        dataset = self.get_train_dataset()
        # Shuffle indices and split 80/20 into train / validation subsets
        img_inds = np.arange(len(dataset))
        np.random.shuffle(img_inds)
        train_inds = img_inds[:int(0.8 * len(img_inds))]
        val_inds = img_inds[int(0.8 * len(img_inds)):]

        train_loader = DataLoader(dataset,
                                  num_workers=workers,
                                  batch_size=batch_size,
                                  sampler=SubsetRandomSampler(train_inds))
        val_loader = DataLoader(dataset,
                                num_workers=workers,
                                batch_size=batch_size,
                                sampler=SubsetRandomSampler(val_inds))

        loss_fn = torch.nn.CrossEntropyLoss()
        metrics = {'fps': training.BatchTimer(), 'acc': training.accuracy}

        writer = SummaryWriter()
        writer.iteration, writer.interval = 0, 10

        print('\n\nInitial')
        print('-' * 10)
        self.model.eval()
        training.pass_epoch(self.model,
                            loss_fn,
                            val_loader,
                            batch_metrics=metrics,
                            show_running=True,
                            writer=writer)

        for epoch in tqdm(range(epochs)):
            print('\nEpoch {}/{}'.format(epoch + 1, epochs))
            print('-' * 10)

            self.model.train()
            training.pass_epoch(self.model,
                                loss_fn,
                                train_loader,
                                optimizer,
                                scheduler,
                                batch_metrics=metrics,
                                show_running=True,
                                writer=writer)

            self.model.eval()
            training.pass_epoch(self.model,
                                loss_fn,
                                val_loader,
                                batch_metrics=metrics,
                                show_running=True,
                                writer=writer)

        writer.close()

        if save_model:
            self.save_model()
Example #3
# (assumes dataset, img_inds, batch_size, workers, resnet, and device are defined earlier)
np.random.shuffle(img_inds)
train_inds = img_inds[:int(0.9 * len(img_inds))]
val_inds = img_inds[int(0.9 * len(img_inds)):]

train_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)

loss_fn = torch.nn.CrossEntropyLoss()
metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}

print('\n\nInitial')
print('-' * 10)
resnet.eval()
training.pass_epoch(
    resnet, loss_fn, val_loader,
    batch_metrics=metrics, show_running=True, device=device
)
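
Example #3 stops after the initial evaluation pass. A minimal sketch of the training loop that usually follows, assuming epochs, optimizer, scheduler, and train_loader are defined as in the other examples (an illustration, not part of the original snippet):

for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    resnet.train()
    training.pass_epoch(
        resnet, loss_fn, train_loader, optimizer, scheduler,
        batch_metrics=metrics, show_running=True, device=device
    )

    resnet.eval()
    training.pass_epoch(
        resnet, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True, device=device
    )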
Example #4
def train_model(db_id):
    start_epoch = 0
    batch_size = 32
    epochs = 5
    workers = 2
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization
    ])
    images, num_classes = get_dataset(db_id)
    dataset = MyCustomDataset(images, train_transform)
    train_loader = DataLoader(
        dataset,
        num_workers=workers,
        batch_size=batch_size
    )
    model = InceptionResnetV1(
        classify=True,
        num_classes=num_classes
    ).to(device)
    checkpoint_path, checkpoint_file, label_dict = get_saved_model(db_id)
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
    else:
        checkpoint_path = "./checkpoint"

    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = MultiStepLR(optimizer, [60, 120, 180])
    loss_fn = torch.nn.CrossEntropyLoss()
    metrics = {
      'fps': training.BatchTimer(),
      'acc': training.accuracy
    }

    writer = SummaryWriter(log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=600, filename_suffix='face_rec_log_')
    writer.iteration, writer.interval = 1, 10

    checkpoint_save_name = 'face_rec_test'
    ckp_dir = checkpoint_path
    ckp_name = ''
    for epoch in range(epochs):
        training.pass_epoch(
              model, loss_fn, train_loader, optimizer, scheduler,
              batch_metrics=metrics, show_running=False, device=device,
              writer=writer
        )

        if (epoch+1) % 50 == 0:
            print('Saving..')
            state = {
               'net': model.state_dict(),
               'epoch': epoch,
               'is_final' : 0
            }
            ckp_name = checkpoint_save_name + '_' + str(epoch + 1)
            os.makedirs(ckp_dir, exist_ok=True)
            torch.save(state, ckp_dir + '/' + ckp_name + '.pth')
    writer.close()

    
    state = {
        'net': model.state_dict(),
        'epoch': epochs,
        'is_final' : 1
    }
    ckp_name = checkpoint_save_name + '_final'
    os.makedirs(ckp_dir, exist_ok=True)
    save_path = ckp_dir + '/' + ckp_name + '.pth'
    torch.save(state, save_path)
    update_model(db_id, save_path)
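
The checkpoints written above store the model weights under the 'net' key together with 'epoch' and 'is_final'. A minimal sketch of restoring a model from such a file; save_path, num_classes, and device stand in for whatever the caller actually has (a hypothetical illustration, not part of the original code):

# Hypothetical: restore a model from a checkpoint produced by train_model
model = InceptionResnetV1(classify=True, num_classes=num_classes).to(device)
checkpoint = torch.load(save_path, map_location=device)  # file saved above
model.load_state_dict(checkpoint['net'])                  # weights stored under 'net'
start_epoch = checkpoint['epoch']                         # epoch to resume from
model.eval()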