# NOTE(review): this chunk is truncated — it ends mid-call to training.pass_epoch
# (trailing "batch_metrics=metrics,") with no continuation in the file, and it
# duplicates the logic of the train() method below (loss/metrics setup, initial
# eval pass, epoch loop). It looks like a stale paste or extraction artifact.
# Left byte-identical pending confirmation that it can be deleted.
# loss_fn = #loss_fn = torch.nn.BCELoss() loss_fn = torch.nn.CrossEntropyLoss() metrics = {'fps': training.BatchTimer(), 'acc': training.accuracy} # Train writer = SummaryWriter() writer.iteration, writer.interval = 0, 10 print('\n\nInitial') print('-' * 10) net.eval() training.pass_epoch(net, loss_fn, val_loader, batch_metrics=metrics, show_running=True, device=device, writer=writer) for epoch in range(epochs): print('\nEpoch {}/{}'.format(epoch + 1, epochs)) print('-' * 10) net.train() training.pass_epoch(net, loss_fn, train_loader, optimizer, scheduler, batch_metrics=metrics,
def train(self, save_model=True):
    """Fine-tune ``self.model`` on an 80/20 train/validation split of the
    training dataset, logging each epoch to TensorBoard.

    Runs one validation pass before training as a baseline, then alternates
    a train pass and a validation pass per epoch.

    Args:
        save_model: when True, persist the model via ``self.save_model()``
            after the final epoch.
    """
    batch_size = 32
    epochs = 100
    # Windows multiprocessing DataLoader workers are problematic; use 0 there.
    workers = 0 if os.name == 'nt' else 8

    optimizer = optim.Adam(self.model.parameters(), lr=0.001)
    scheduler = MultiStepLR(optimizer, [5, 10])

    dataset = self.get_train_dataset()
    indices = np.arange(len(dataset))
    np.random.shuffle(indices)
    split = int(0.8 * len(indices))  # first 80% -> train, rest -> validation

    train_loader = DataLoader(
        dataset,
        num_workers=workers,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(indices[:split]),
    )
    val_loader = DataLoader(
        dataset,
        num_workers=workers,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(indices[split:]),
    )

    loss_fn = torch.nn.CrossEntropyLoss()
    metrics = {'fps': training.BatchTimer(), 'acc': training.accuracy}

    writer = SummaryWriter()
    writer.iteration, writer.interval = 0, 10

    # Baseline validation pass before any weight updates.
    print('\n\nInitial')
    print('-' * 10)
    self.model.eval()
    training.pass_epoch(
        self.model, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True, writer=writer,
    )

    for epoch in tqdm(range(epochs)):
        print('\nEpoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)

        self.model.train()
        training.pass_epoch(
            self.model, loss_fn, train_loader, optimizer, scheduler,
            batch_metrics=metrics, show_running=True, writer=writer,
        )

        self.model.eval()
        training.pass_epoch(
            self.model, loss_fn, val_loader,
            batch_metrics=metrics, show_running=True, writer=writer,
        )

    writer.close()
    if save_model:
        self.save_model()
# Shuffle the dataset indices in place, then split 90/10 into train/validation.
# NOTE(review): `img_inds`, `dataset`, `workers`, `batch_size`, `resnet`, and
# `device` are defined earlier in the file (outside this chunk); `train_loader`,
# `loss_fn`, and `metrics` are presumably consumed by a training loop that
# follows it — confirm against the surrounding code.
np.random.shuffle(img_inds)
train_inds = img_inds[:int(0.9 * len(img_inds))]
val_inds = img_inds[int(0.9 * len(img_inds)):]
# SubsetRandomSampler restricts each loader to its index subset, so both
# loaders can share the same underlying dataset.
train_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)
loss_fn = torch.nn.CrossEntropyLoss()
# Per-batch metrics reported by training.pass_epoch.
metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}
# Baseline validation pass before training (model in eval mode).
print('\n\nInitial')
print('-' * 10)
resnet.eval()
training.pass_epoch(
    resnet, loss_fn, val_loader,
    batch_metrics=metrics, show_running=True, device=device
)
def train_model(db_id):
    """Train (or resume training of) the face-recognition classifier for
    database ``db_id``, checkpointing to disk and registering the final
    checkpoint via ``update_model``.

    If a saved checkpoint exists for ``db_id``, its weights and epoch counter
    are restored and training continues from that epoch instead of restarting.

    Args:
        db_id: identifier passed through to ``get_dataset`` /
            ``get_saved_model`` / ``update_model``.
    """
    start_epoch = 0
    batch_size = 32
    epochs = 5
    workers = 2

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        np.float32,  # cast to float32 before ToTensor
        transforms.ToTensor(),
        fixed_image_standardization
    ])

    images, num_classes = get_dataset(db_id)
    dataset = MyCustomDataset(images, train_transform)
    train_loader = DataLoader(
        dataset,
        num_workers=workers,
        batch_size=batch_size
    )

    model = InceptionResnetV1(
        classify=True,
        num_classes=num_classes
    ).to(device)

    checkpoint_path, checkpoint_file, label_dict = get_saved_model(db_id)
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        # NOTE(review): existence is checked on checkpoint_path but the load
        # uses checkpoint_file — confirm get_saved_model guarantees both refer
        # to a valid saved state together.
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
    else:
        checkpoint_path = "./checkpoint"

    # NOTE(review): optimizer/scheduler state is neither saved nor restored,
    # so a resumed run restarts the LR schedule from its initial point.
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = MultiStepLR(optimizer, [60, 120, 180])
    loss_fn = torch.nn.CrossEntropyLoss()
    metrics = {
        'fps': training.BatchTimer(),
        'acc': training.accuracy
    }

    writer = SummaryWriter(log_dir=None, comment='', purge_step=None,
                           max_queue=10, flush_secs=600,
                           filename_suffix='face_rec_log_')
    writer.iteration, writer.interval = 1, 10

    checkpoint_save_name = 'face_rec_test'
    ckp_dir = checkpoint_path

    # Bug fix: the loop previously ran range(epochs), silently discarding the
    # start_epoch restored from the checkpoint. Resuming now continues from
    # where the saved run left off.
    for epoch in range(start_epoch, epochs):
        training.pass_epoch(
            model, loss_fn, train_loader, optimizer, scheduler,
            batch_metrics=metrics, show_running=False, device=device,
            writer=writer
        )

        # Periodic (non-final) checkpoint every 50 completed epochs.
        if (epoch + 1) % 50 == 0:
            print('Saving..')
            state = {
                'net': model.state_dict(),
                'epoch': epoch,
                'is_final': 0
            }
            ckp_name = checkpoint_save_name + '_' + str(epoch + 1)
            os.makedirs(ckp_dir, exist_ok=True)
            torch.save(state, os.path.join(ckp_dir, ckp_name + '.pth'))

    writer.close()

    # Final checkpoint: 'epoch' == epochs marks the run as complete, so a
    # subsequent resume performs no further training epochs.
    state = {
        'net': model.state_dict(),
        'epoch': epochs,
        'is_final': 1
    }
    ckp_name = checkpoint_save_name + '_final'
    os.makedirs(ckp_dir, exist_ok=True)
    save_path = os.path.join(ckp_dir, ckp_name + '.pth')
    torch.save(state, save_path)
    update_model(db_id, save_path)