def train(params, log):
    """Train a model built by the project factories, checkpointing the best run.

    Args:
        params: nested config dict; reads 'TRAIN' (epochs), 'LOG'
            (do_checkpoint), and 'device'.
        log: project logger exposing log_global / log_epoch and a
            'global' timer with current2str().
    """
    # specify dataset
    data = DatasetFactory.create(params)
    # specify model
    model = ModelFactory.create(params)
    # define loss function (criterion)
    criterion = set_loss_function(params)
    # optimizer & scheduler & load from checkpoint
    optimizer, scheduler, start_epoch, best_prec = set_optimizer_scheduler(
        params, model, log)

    # log setup details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log_string += "\n" + "==== SCHEDULER:\n" + str(scheduler) + "\n"
    log_string += "\n" + "==== DATASET (TRAIN):\n" + str(data.dataset['train']) + "\n"
    # FIX: was repr(...) while every other section above uses str(...);
    # use str() consistently for the validation dataset too.
    log_string += "\n" + "==== DATASET (VAL):\n" + str(data.dataset['val']) + "\n"
    log.log_global(log_string)

    # train
    for epoch in range(start_epoch, params['TRAIN']['epochs']):
        # train for one epoch
        _, _ = train_epoch(data.loader['train'], model, criterion, optimizer,
                           scheduler, epoch, params['device'], log)
        # evaluate on train set
        acc_train, loss_train = validate(data.loader['train'], model,
                                         criterion, params['device'])
        # evaluate on validation set
        acc_val, loss_val = validate(data.loader['val'], model,
                                     criterion, params['device'])

        # remember best prec@1
        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)

        # save checkpoint (optional, config-gated)
        if params['LOG']['do_checkpoint']:
            save_checkpoint({'epoch': epoch + 1,
                             'state_dict': model.state_dict(),
                             'best_prec': best_prec,
                             'optimizer': optimizer.state_dict(),
                             'scheduler': scheduler},
                            model, params, is_best)

        # logging results
        time_string = log.timers['global'].current2str()  # get current time
        log.log_epoch(epoch + 1, acc_train, loss_train, acc_val, loss_val,
                      is_best, time_string)
def train(params, log, time_keeper):
    """Train with SGD (optional StepLR schedule) and checkpoint every epoch.

    Args:
        params: nested config dict; reads 'TRAIN' (lr, momentum, resume,
            epochs, optional lr_schedule_step / lr_schedule_gamma) and 'device'.
        log: project logger exposing log_global / log_epoch.
        time_keeper: timer object exposing start() and get_current_str().
    """
    # specify dataset
    data = DatasetFactory.create(params)
    # specify model
    model = ModelFactory.create(params)
    model = model.to(params['device'])

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(params['device'])
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=params['TRAIN']['lr'],
                                momentum=params['TRAIN']['momentum'])

    # resume from a checkpoint (empty resume string means a fresh run)
    if len(params['TRAIN']['resume']) > 0:
        start_epoch, best_prec = load_checkpoint(
            log, model, params['TRAIN']['resume'], optimizer)
    else:
        start_epoch = 0
        best_prec = 0

    # scheduler (if any)
    if 'lr_schedule_step' in params['TRAIN']:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=params['TRAIN']['lr_schedule_step'],
            gamma=params['TRAIN']['lr_schedule_gamma'])
    else:
        scheduler = None

    # log setup details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log.log_global(log_string)

    time_keeper.start()
    # train
    for epoch in range(start_epoch, params['TRAIN']['epochs']):
        # train for one epoch
        # FIX: was `timer` (undefined name -> NameError); the parameter
        # is `time_keeper`.
        _, _ = train_epoch(data.loader['train'], model, criterion, optimizer,
                           epoch, params['device'], log, time_keeper)

        # FIX: scheduler.step() used to run *before* the epoch's optimizer
        # steps (pre-1.1.0 PyTorch ordering), shifting the LR schedule by
        # one epoch; per current PyTorch docs it must follow optimizer
        # updates.
        if scheduler:
            scheduler.step()

        # evaluate on train set
        acc_train, loss_train = validate(data.loader['train'], model,
                                         criterion, params['device'])
        # evaluate on validation set
        acc_val, loss_val = validate(data.loader['val'], model,
                                     criterion, params['device'])

        # remember best prec@1 and save checkpoint
        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_prec': best_prec,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler},
                        model, params, is_best)

        # logging results
        time_string = time_keeper.get_current_str()  # get current time
        log.log_epoch(epoch + 1, acc_train, loss_train, acc_val, loss_val,
                      is_best, time_string)
def train(params, log, time_keeper, tboard_exp_path):
    """Train with per-class accuracy tracking and TensorBoard logging.

    Args:
        params: nested config dict; reads 'MODEL' (name), 'TRAIN'
            (lr, momentum, resume, epochs, optional lr_schedule_*), 'device'.
        log: project logger exposing log_global / log_epoch.
        time_keeper: timer object exposing start() and get_current_str().
        tboard_exp_path: directory for the TensorBoard SummaryWriter.

    NOTE(review): relies on a module-level `classes` sequence for the
    per-class accuracy report — confirm it is defined where this runs.
    """
    writer = SummaryWriter(tboard_exp_path)
    # specify dataset
    data = DatasetFactory.create(params)
    # specify model
    model = ModelFactory.create(params)
    model = model.to(params['device'])

    # define loss function (criterion) and optimizer; fine-tuning setups
    # optimize only a subset of parameters depending on the backbone
    criterion = nn.CrossEntropyLoss().to(params['device'])
    if params['MODEL']['name'] == 'resnet18':
        # train only the final classifier layer
        optimizer = torch.optim.SGD(model.fc.parameters(),
                                    lr=params['TRAIN']['lr'],
                                    momentum=params['TRAIN']['momentum'])
    elif params['MODEL']['name'] == 'vgg19':
        # train only the layers left unfrozen (requires_grad == True)
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=params['TRAIN']['lr'],
            momentum=params['TRAIN']['momentum'])
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=params['TRAIN']['lr'],
                                    momentum=params['TRAIN']['momentum'])

    # DRAFT FOR EXPERIMENT
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=6, gamma=0.1)

    # resume from a checkpoint (empty resume string means a fresh run)
    if len(params['TRAIN']['resume']) > 0:
        start_epoch, best_prec = load_checkpoint(
            log, model, params['TRAIN']['resume'], optimizer)
    else:
        start_epoch = 0
        best_prec = 0

    # scheduler (if any)
    if 'lr_schedule_step' in params['TRAIN']:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=params['TRAIN']['lr_schedule_step'],
            gamma=params['TRAIN']['lr_schedule_gamma'])
    else:
        scheduler = None

    # log setup details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log.log_global(log_string)

    time_keeper.start()
    # train
    for epoch in range(start_epoch, params['TRAIN']['epochs']):
        # adjust_learning_rate
        if scheduler:
            scheduler.step()
        # train for one epoch
        # FIX: was `timer` (undefined name -> NameError); the parameter
        # is `time_keeper`.
        _, _ = train_epoch(data.loader['train'], model, criterion, optimizer,
                           epoch, params['device'], log, time_keeper,
                           writer, exp_lr_scheduler)
        # evaluate on train set
        acc_train, loss_train = validate(data.loader['train'], model,
                                         criterion, params['device'])
        # evaluate on validation set
        acc_val, loss_val = validate(data.loader['val'], model,
                                     criterion, params['device'])

        # per-class accuracy over the train loader
        correct = 0
        total = 0
        class_correct = list(0. for i in range(len(classes)))
        class_total = list(0. for i in range(len(classes)))
        with torch.no_grad():
            for (input_data, target) in data.loader['train']:
                images = input_data.to(params['device'])
                labels = target.to(params['device'])
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                c = (predicted == labels).squeeze()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                # FIX: was `range(4)` — a hard-coded batch size that both
                # ignored samples beyond index 3 and raised IndexError on a
                # final partial batch; iterate the actual batch.
                for j in range(labels.size(0)):
                    label = labels[j]
                    class_correct[label] += c[j].item()
                    class_total[label] += 1
        # FIX: format string was '% epoch' — '% e' is parsed as a
        # space-flagged scientific-notation float, mangling the output.
        print('Accuracy of the network on the %d epoch test images: %d %%'
              % (epoch, 100 * correct / total))
        for k in range(len(classes)):
            # FIX: guard against division by zero when a class never
            # appears in the loader.
            class_acc = (100 * class_correct[k] / class_total[k]
                         if class_total[k] > 0 else 0.0)
            print('Accuracy of %5s : %2d %%' % (classes[k], class_acc))
            writer.add_scalar('Accuracy of %5s' % classes[k],
                              class_acc, epoch + 1)

        # remember best prec@1 and save checkpoint
        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_prec': best_prec,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler},
                        model, params, is_best)

        # logging results
        time_string = time_keeper.get_current_str()  # get current time
        log.log_epoch(epoch + 1, acc_train, loss_train, acc_val, loss_val,
                      is_best, time_string)
        writer.add_scalar("Train accuracy ", acc_train, epoch + 1)
        writer.add_scalar("Train Loss ", loss_train, epoch + 1)
        writer.add_scalar("Test accuracy ", acc_val, epoch + 1)
        writer.add_scalar("Test Loss ", loss_val, epoch + 1)
        exp_lr_scheduler.step()
def train(params, log, time_keeper):
    """Train via pytorch-ignite engines, checkpointing the best model.

    Args:
        params: nested config dict; reads 'TRAIN' (lr, momentum, resume,
            epochs) and 'device'.
        log: project logger exposing log_global / log_iter / log_epoch.
        time_keeper: timer object exposing start() and get_current_str().
    """
    # specify dataset
    dataset = DatasetFactory.create(params)
    # specify model
    model = ModelFactory.create(params)
    model = model.to(params['device'])

    # optimizer
    optimizer = SGD(model.parameters(),
                    lr=params['TRAIN']['lr'],
                    momentum=params['TRAIN']['momentum'])
    # scheduler
    scheduler = None

    # optionally resume from a checkpoint
    # FIX: load_checkpoint was called unconditionally; the other train()
    # variants in this project guard on a non-empty resume path.
    checkpoint_file = params['TRAIN']['resume']
    if len(checkpoint_file) > 0:
        start_epoch, best_prec = load_checkpoint(
            log, model, checkpoint_file, optimizer, scheduler)
    else:
        start_epoch = 0
        best_prec = 0
    # NOTE(review): start_epoch is not forwarded to trainer.run(), so
    # ignite always restarts at epoch 1 — confirm resumed runs are
    # expected to re-train the full epoch count.

    trainer = create_supervised_trainer(model, optimizer, F.cross_entropy,
                                        device=params['device'])
    # evaluator
    evaluator = create_supervised_evaluator(
        model,
        metrics={'accuracy': CategoricalAccuracy(),
                 'cross_entropy': Loss(F.cross_entropy)},
        device=params['device'])

    # log setup details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log.log_global(log_string)

    # end-of-iteration events
    @trainer.on(Events.ITERATION_COMPLETED)
    def on_iter(engine):
        num_iter = len(dataset.loader['train'])
        # FIX: `iteration % num_iter` reported 0 instead of num_iter on the
        # last iteration of every epoch; map the 1-based global counter to
        # a 1..num_iter position within the epoch.
        iter_current = (engine.state.iteration - 1) % num_iter + 1
        epoch_current = engine.state.epoch
        loss = engine.state.output
        # logging
        time_string = time_keeper.get_current_str()  # get current time
        log.log_iter(iter_current, epoch_current - 1, num_iter, loss,
                     time_string)

    # end-of-epoch events
    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch(engine):
        nonlocal best_prec
        # current epoch (ignite's state.epoch is already 1-based)
        epoch_current = engine.state.epoch

        # evaluation on train set
        evaluator.run(dataset.loader['train'])
        acc_train = evaluator.state.metrics['accuracy'] * 100
        loss_train = evaluator.state.metrics['cross_entropy']
        # evaluation on val set
        evaluator.run(dataset.loader['val'])
        acc_val = evaluator.state.metrics['accuracy'] * 100
        loss_val = evaluator.state.metrics['cross_entropy']

        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)
        # FIX: was epoch_current + 1 — ignite's epoch counter is already
        # 1-based at EPOCH_COMPLETED, so the +1 over-counted by one and
        # disagreed with the log.log_epoch(epoch_current, ...) call below.
        save_checkpoint({'epoch': epoch_current,
                         'state_dict': model.state_dict(),
                         'best_prec': best_prec,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler},
                        model, params, is_best)

        # logging results
        time_string = time_keeper.get_current_str()  # get current time
        log.log_epoch(epoch_current, acc_train, loss_train, acc_val,
                      loss_val, is_best, time_string)

    time_keeper.start()
    trainer.run(dataset.loader['train'], max_epochs=params['TRAIN']['epochs'])