def create_optimization(self):
    """Create the loss criterion and the RMSprop optimizer on ``self``.

    ``self.loss`` is ``FocalLoss(gamma=args.gamma)`` when
    ``args.loss_function == 'FocalLoss'`` and ``nn.CrossEntropyLoss()``
    otherwise; it is moved to the GPU when ``args.cuda`` is set.

    With ``args.classify`` the optimizer covers only the model's parameters.
    Otherwise an ``ArcMarginModel`` head is used as ``self.metric_fc``, and its
    parameters are optimized together with the model's. An existing
    ``self.metric_fc`` is reused; a new head is created only if none is set.
    """
    if self.args.loss_function == 'FocalLoss':
        self.loss = FocalLoss(gamma=self.args.gamma)
    else:
        self.loss = nn.CrossEntropyLoss()
    if self.args.cuda:
        self.loss.cuda()

    if self.args.classify:
        params = self.model.parameters()
    else:
        # The margin head is only needed (and only optimized) in this branch.
        # Previously it was created in the other branch, so this one raised
        # AttributeError on ``self.metric_fc``.
        if getattr(self, 'metric_fc', None) is None:
            self.metric_fc = ArcMarginModel(self.args)
        params = [{'params': self.model.parameters()},
                  {'params': self.metric_fc.parameters()}]

    self.optimizer = RMSprop(params,
                             self.args.learning_rate,
                             momentum=self.args.momentum,
                             weight_decay=self.args.weight_decay)
def train_net(args):
    """Train a backbone plus ArcMarginModel head, validating on MegaFace.

    Builds a fresh model/head/optimizer from ``args`` or resumes them from
    ``args.checkpoint``, then trains from ``start_epoch`` to ``args.end_epoch``.
    Every 5th epoch the model is evaluated with ``megaface_test``, the
    no-improvement counter is updated and a checkpoint is saved. When the
    counter reaches an even, non-zero value, 'BEST_checkpoint.tar' is reloaded
    and its learning rate halved (once per value). Training stops when the
    counter reaches 10.

    Raises:
        TypeError: if ``args.network`` names an unsupported backbone.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            from mobilenet_v2 import MobileNetV2
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        metric_fc = ArcMarginModel(args)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    # save_checkpoint below stores the DataParallel-wrapped modules, so a
    # resumed model is already wrapped; only wrap modules that are not.
    if not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)
    if not isinstance(metric_fc, nn.DataParallel):
        metric_fc = nn.DataParallel(metric_fc)

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    # Counter value at which the last reload/decay happened. The counter only
    # changes on validation epochs, so without this guard the best checkpoint
    # would be reloaded on every epoch until the next validation. Each reload
    # throws away the epoch of training just completed.
    decayed_at = None

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # Terminate after 10 consecutive validations without improvement.
        if epochs_since_improvement == 10:
            break
        # After every 2 consecutive validations without improvement, restart
        # from the best weights with half the learning rate.
        if (epochs_since_improvement > 0
                and epochs_since_improvement % 2 == 0
                and epochs_since_improvement != decayed_at):
            best_checkpoint = torch.load('BEST_checkpoint.tar')
            model = best_checkpoint['model']
            metric_fc = best_checkpoint['metric_fc']
            optimizer = best_checkpoint['optimizer']
            adjust_learning_rate(optimizer, 0.5)
            decayed_at = epochs_since_improvement

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)
        lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        if epoch % 5 == 0:
            # One epoch's validation
            megaface_acc = megaface_test(model)
            writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

            # Check if there was an improvement
            is_best = megaface_acc > best_acc
            best_acc = max(megaface_acc, best_acc)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" %
                      (epochs_since_improvement, ))
            else:
                epochs_since_improvement = 0
                # Allow a future decay when the counter climbs to 2 again.
                decayed_at = None

            # Save checkpoint
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc, is_best)
def test():
    """Compare float, post-training-quantized and QAT versions of the model.

    Loads an ``ArcMarginModel`` head from 'metric_fc.pt' and the model from
    'image_matching_mobile.pt' (via ``load_model``), then in order:
      1. fuses and evaluates the float model and saves it as TorchScript;
      2. post-training static quantization with ``default_qconfig``
         (per-tensor weight quantization), calibrated on the training loader;
      3. the same with the 'fbgemm' default qconfig (the variable name
         indicates per-channel weights), saved as TorchScript;
      4. benchmarks both scripted files on the test loader;
      5. 8 epochs of quantization-aware training, converting and evaluating a
         quantized copy after every epoch.
    Results are only printed; nothing is returned.
    """
    filename = 'metric_fc.pt'
    # NOTE(review): ArcMarginModel is constructed with ``args`` elsewhere in
    # this code; confirm the no-argument constructor is valid here.
    metric_fc = ArcMarginModel()
    metric_fc.load_state_dict(torch.load(filename))
    metric_fc = metric_fc.to('cpu')
    metric_fc.eval()
    # NOTE(review): train_batch_size is not used anywhere below.
    train_batch_size = 30
    # Only used to compute the image count in the printed messages.
    # NOTE(review): assumes prepare_data_loaders() also batches the test set by
    # 30 — confirm, otherwise the printed counts are wrong.
    eval_batch_size = 30
    scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
    scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'
    data_loader, data_loader_test = prepare_data_loaders()
    criterion = nn.CrossEntropyLoss()
    # ``filename`` is reused: from here on it names the model weights file.
    filename = 'image_matching_mobile.pt'
    print('loading {}...'.format(filename))
    start = time.time()
    float_model = load_model(filename)
    print('elapsed {} sec'.format(time.time() - start))
    print('\n Inverted Residual Block: Before fusion \n\n', float_model.features[1].conv)
    float_model.eval()
    # Fuses modules
    float_model.fuse_model()
    # Note fusion of Conv+BN+Relu and Conv+Relu
    print('\n Inverted Residual Block: After fusion\n\n', float_model.features[1].conv)
    num_eval_batches = 10
    print("Size of baseline model")
    print_size_of_model(float_model)
    top1 = evaluate(float_model, metric_fc, criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))
    torch.jit.save(torch.jit.script(float_model), scripted_float_model_file)

    # 4. Post-training static quantization
    print(bcolors.HEADER + '\nPost-training static quantization' + bcolors.ENDC)
    num_calibration_batches = 10
    myModel = load_model(filename).to('cpu')
    myModel.eval()
    # Fuse Conv, bn and relu
    myModel.fuse_model()
    # Specify quantization configuration
    # Start with simple min/max range estimation and per-tensor quantization of weights
    myModel.qconfig = torch.quantization.default_qconfig
    print(myModel.qconfig)
    # Insert observers; they collect activation ranges during calibration below.
    torch.quantization.prepare(myModel, inplace=True)
    print('Post Training Quantization Prepare: Inserting Observers')
    print('\n Inverted Residual Block:After observer insertion \n\n', myModel.features[1].conv)
    # Calibrate with the training set (the evaluate() result is discarded; the
    # pass only feeds data through the observers).
    print('Calibrate with the training set')
    evaluate(myModel, metric_fc, criterion, data_loader, neval_batches=num_calibration_batches)
    print('Post Training Quantization: Calibration done')
    # Convert to quantized model
    torch.quantization.convert(myModel, inplace=True)
    print('Post Training Quantization: Convert done')
    print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n', myModel.features[1].conv)
    print("Size of model after quantization")
    print_size_of_model(myModel)
    top1 = evaluate(myModel, metric_fc, criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))

    # Same prepare/calibrate/convert flow, but with the 'fbgemm' default qconfig.
    per_channel_quantized_model = load_model(filename)
    per_channel_quantized_model.eval()
    per_channel_quantized_model.fuse_model()
    per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    print(per_channel_quantized_model.qconfig)
    torch.quantization.prepare(per_channel_quantized_model, inplace=True)
    print('Calibrate with the training set')
    evaluate(per_channel_quantized_model, metric_fc, criterion, data_loader, num_calibration_batches)
    torch.quantization.convert(per_channel_quantized_model, inplace=True)
    top1 = evaluate(per_channel_quantized_model, metric_fc, criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))
    torch.jit.save(torch.jit.script(per_channel_quantized_model), scripted_quantized_model_file)

    # Speedup from quantization: benchmark the two scripted files saved above.
    print(bcolors.HEADER + '\nSpeedup from quantization' + bcolors.ENDC)
    run_benchmark(scripted_float_model_file, data_loader_test)
    run_benchmark(scripted_quantized_model_file, data_loader_test)

    # 5. Quantization-aware training
    print(bcolors.HEADER + '\nQuantization-aware training' + bcolors.ENDC)
    qat_model = load_model(filename)
    qat_model.fuse_model()
    optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(qat_model, inplace=True)
    print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n', qat_model.features[1].conv)
    num_train_batches = 20
    # Train and check accuracy after each epoch
    for nepoch in range(8):
        # NOTE(review): train_one_epoch is not given metric_fc while evaluate
        # is — confirm this is intended.
        train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
        if nepoch > 3:
            # Freeze quantizer parameters (from the 5th epoch on)
            qat_model.apply(torch.quantization.disable_observer)
        if nepoch > 2:
            # Freeze batch norm mean and variance estimates (from the 4th epoch on)
            qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        # Check the accuracy after each epoch on a converted copy. Note that
        # qat_model.eval() also switches qat_model itself to eval mode.
        # NOTE(review): presumably train_one_epoch puts it back into training
        # mode — confirm.
        quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
        quantized_model.eval()
        top1 = evaluate(quantized_model, metric_fc, criterion, data_loader_test, neval_batches=num_eval_batches)
        print('Epoch %d :Evaluation accuracy on %d images, %2.2f' % (
            nepoch, num_eval_batches * eval_batch_size, top1.avg))
def train_net(args):
    """Train a backbone plus ArcMarginModel head with a StepLR schedule.

    Builds a fresh model/head/optimizer from ``args`` or resumes them from
    ``args.checkpoint``. After each epoch the model is validated with
    ``lfw_test``, the improvement counter is updated, a checkpoint is saved and
    the learning-rate schedule is advanced to the next epoch.

    Raises:
        TypeError: if ``args.network`` names an unsupported backbone.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      metric_fc=metric_fc,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      logger=logger)

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_acc', train_acc, epoch)

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('model/valid_acc', lfw_acc, epoch)
        writer.add_scalar('model/valid_thres', threshold, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)

        # ``step(e)`` sets the LR for epoch ``e``. Having just finished ``epoch``,
        # set it for the next one; ``step(epoch)`` applied every decay one
        # epoch late.
        scheduler.step(epoch + 1)
def train_net(args):
    """Train a backbone plus ArcMarginModel head with a StepLR schedule.

    Builds a fresh model/head/optimizer from ``args`` (any unknown
    ``args.network`` falls back to ``resnet_face18``) or resumes them from
    ``args.checkpoint``. Each epoch: optional extra LFW logging
    (``args.full_log``), one training pass (timed), LFW validation,
    improvement bookkeeping and a checkpoint.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        elif args.network == 'mr18':
            print("mr18")
            model = myResnet18()
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if args.full_log:
            lfw_acc, _ = lfw_test(model)
            writer.add_scalar('LFW_Accuracy', lfw_acc, epoch)
            full_log(epoch)

        start = datetime.now()
        # One epoch's training
        train_loss, train_top5_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger,
                                            writer=writer)
        # Advance the schedule only after this epoch's optimizer steps.
        # Calling it before training (as before) skips the first LR value
        # since PyTorch 1.1, making every decay happen one epoch early.
        scheduler.step()

        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Train_Top5_Accuracy', train_top5_accs, epoch)

        end = datetime.now()
        delta = end - start
        # total_seconds() also counts whole days, which ``.seconds`` drops.
        print('{} seconds'.format(int(delta.total_seconds())))

        # One epoch's validation
        lfw_acc, _ = lfw_test(model)
        writer.add_scalar('LFW Accuracy', lfw_acc, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
def train_net(args):
    """Train a ResNet backbone plus ArcMarginModel head with MultiStepLR.

    Uses ``ArcFaceDatasetBatched`` so each loader item holds ``img_batch_size``
    images, and divides the loader batch size accordingly. After every epoch
    the model is evaluated on ``args.eval_ds``. "LFW" and "Megaface" are
    supported; any other value records an accuracy of -1. A checkpoint
    (including the scheduler) is then saved.

    Raises:
        TypeError: if ``args.network`` names an unsupported backbone.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        if args.pretrained:
            model.load_state_dict(torch.load('insight-face-v3.pt'))

        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        nesterov=True,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function, placed on the same device as the model. FocalLoss may
    # carry its own tensors.
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders: each dataset item is already a batch of
    # ``img_batch_size`` images, merged by ``batched_collate_fn``.
    train_dataset = ArcFaceDatasetBatched('train', img_batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size // img_batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=batched_collate_fn)
    scheduler = MultiStepLR(optimizer, milestones=[8, 16, 24, 32], gamma=0.1)

    # Resolve the evaluation routine once, before the training loop.
    if args.eval_ds == "LFW":
        from lfw_eval import lfw_test
    elif args.eval_ds == "Megaface":
        from megaface_eval import megaface_test

    # NOTE(review): ``logger`` is not created in this function (other entry
    # points call get_logger()); this assumes a module-level logger — confirm.

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        lr = optimizer.param_groups[0]['lr']
        logger.info('\nCurrent effective learning rate: {}\n'.format(lr))
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)

        # ``step(e)`` sets the LR for epoch ``e``. Having just finished ``epoch``,
        # set it for the next one; ``step(epoch)`` applied every milestone one
        # epoch late.
        scheduler.step(epoch + 1)

        # One epoch's validation
        if args.eval_ds == "LFW":
            accuracy, _ = lfw_test(model)
        elif args.eval_ds == "Megaface":
            accuracy = megaface_test(model)
        else:
            accuracy = -1

        writer.add_scalar('model/evaluation_accuracy', accuracy, epoch)

        # Check if there was an improvement
        is_best = accuracy > best_acc
        best_acc = max(accuracy, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("\nEpochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best, scheduler)
def train_net(args):
    """Train a backbone plus ArcMarginModel head, validating on MegaFace.

    With ``args.optimizer == 'sgd'`` the SGD optimizer is wrapped in
    ``InsightFaceOptimizer``; otherwise plain Adam is used. Each epoch trains
    once, logs loss/accuracy/LR, evaluates with ``megaface_test`` and saves a
    checkpoint.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{'params': model.parameters()},
                                 {'params': metric_fc.parameters()}],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)

        # InsightFaceOptimizer exposes ``lr`` and ``step_num``. The plain Adam
        # optimizer built above has neither attribute, so this used to raise
        # AttributeError; read its LR from the first parameter group instead.
        lr = getattr(optimizer, 'lr', None)
        if lr is None:
            lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))
        step_num = getattr(optimizer, 'step_num', None)
        if step_num is not None:
            print('Step num: {}\n'.format(step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's validation
        megaface_acc = megaface_test(model)
        writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

        # Check if there was an improvement
        is_best = megaface_acc > best_acc
        best_acc = max(megaface_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
def train_net(args):
    """Train a backbone plus ArcMarginModel head on ``Dataset(args.train_path)``.

    With ``args.optimizer == 'sgd'`` the SGD optimizer is wrapped in
    ``InsightFaceOptimizer``; otherwise plain Adam is used. No validation is
    run; loss/accuracy/LR are logged every epoch and a checkpoint is saved
    every 10th epoch.
    """
    # Seed torch's RNG (used by torch.randn and friends) and numpy's.
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    # NOTE(review): best_acc is never updated here; it is only forwarded to
    # save_checkpoint.
    best_acc = 0
    writer = SummaryWriter()  # TensorBoard writer
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{'params': model.parameters()},
                                 {'params': metric_fc.parameters()}],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        # The checkpoint holds whole pickled objects; they are unpacked below.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = Dataset(root=args.train_path,
                            phase='train',
                            input_shape=(3, 112, 112))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training (all per-batch work lives in ``train``).
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)

        # InsightFaceOptimizer exposes ``lr`` and ``step_num``. The plain Adam
        # optimizer built above has neither attribute, so this used to raise
        # AttributeError; read its LR from the first parameter group instead.
        lr = getattr(optimizer, 'lr', None)
        if lr is None:
            lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))
        step_num = getattr(optimizer, 'step_num', None)
        if step_num is not None:
            print('Step num: {}\n'.format(step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        # Save checkpoint every 10 epochs
        if epoch % 10 == 0:
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc)
def train_net(args):
    """Train ``MobileNetMatchModel`` plus ArcMarginModel head with MultiStepLR.

    Builds a fresh model/head/optimizer or resumes them from
    ``args.checkpoint``. Each epoch trains on ``FrameDataset('train')``, logs
    loss/accuracy/LR, validates with ``test(model)``, advances the LR schedule,
    updates the improvement counter and saves a checkpoint.
    """
    torch.manual_seed(9527)
    np.random.seed(9527)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = MobileNetMatchModel()
        metric_fc = ArcMarginModel(args)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

        model = nn.DataParallel(model)
        metric_fc = nn.DataParallel(metric_fc)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    cudnn.benchmark = True

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(FrameDataset('train'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)

    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_top5_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)
        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top5_accs, epoch)

        lr = optimizer.param_groups[0]['lr']
        print('\nLearning rate: {}'.format(lr))
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's validation
        val_acc, thres = test(model)
        writer.add_scalar('model/valid_accuracy', val_acc, epoch)
        writer.add_scalar('model/valid_threshold', thres, epoch)

        # ``step(e)`` sets the LR for epoch ``e``. Having just finished ``epoch``,
        # set it for the next one; ``step(epoch)`` applied every milestone one
        # epoch late.
        scheduler.step(epoch + 1)

        # Check if there was an improvement
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
class Train:
    """Epoch-based trainer for an image classifier.

    Builds the loss and RMSprop optimizer from ``args``, then tries to load
    pretrained weights (``args.pretrained_path``) and a checkpoint
    (``args.checkpoint_dir + args.resume_from``). Per-epoch metrics are logged
    to TensorBoard.
    """

    def __init__(self, model, trainloader, valloader, args):
        self.model = model
        # Snapshot of the freshly built model's state; load_pretrained_model
        # matches parameter names and shapes against it.
        self.model_dict = self.model.state_dict()
        self.trainloader = trainloader
        self.valloader = valloader
        self.args = args
        self.start_epoch = 0
        self.best_top1 = 0.0

        # Loss function and Optimizer
        self.loss = None
        self.optimizer = None
        self.metric_fc = None
        self.create_optimization()

        # Model Loading
        self.load_pretrained_model()
        self.load_checkpoint(self.args.resume_from)

        # Tensorboard Writer
        self.summary_writer = SummaryWriter()

    def _to_device(self, data, target):
        """Move a batch to the GPU when ``args.cuda`` is set."""
        if self.args.cuda:
            # ``non_blocking`` replaces the old ``async`` keyword argument;
            # ``async`` is a reserved word (a SyntaxError) since Python 3.7.
            data = data.cuda(non_blocking=self.args.async_loading)
            target = target.cuda(non_blocking=self.args.async_loading)
        return data, target

    def train(self):
        """Train from ``start_epoch`` up to ``args.num_epochs``.

        Each epoch: set the learning rate, run one pass over ``trainloader``,
        and log loss/top-1/top-5. Every ``args.test_every`` epochs, evaluate on
        ``valloader`` (if given). Finally save a checkpoint, marked best when
        the epoch's training top-1 accuracy improves.
        """
        for cur_epoch in range(self.start_epoch, self.args.num_epochs):
            # Initialize tqdm
            tqdm_batch = tqdm(self.trainloader,
                              desc="Epoch-" + str(cur_epoch) + "-")

            # Learning rate adjustment
            self.adjust_learning_rate(self.optimizer, cur_epoch)

            # Meters for tracking the average values
            loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

            # Set the model to be in training mode (for dropout and batchnorm)
            self.model.train()

            for data, target in tqdm_batch:
                data, target = self._to_device(data, target)

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Optimization step
                self.optimizer.zero_grad()
                cur_loss.backward()
                self.optimizer.step()

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.detach(),
                                                           target,
                                                           topk=(1, 5))
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())

            # Summary Writing
            self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg, cur_epoch)

            # Print in console
            tqdm_batch.close()
            print("Epoch-" + str(cur_epoch) + " | " + "loss: " +
                  str(loss.avg) + " - acc-top1: " + str(top1.avg)[:7] +
                  "- acc-top5: " + str(top5.avg)[:7])

            # Evaluate on Validation Set
            if cur_epoch % self.args.test_every == 0 and self.valloader:
                self.test(self.valloader, cur_epoch)

            # Checkpointing (based on the training top-1 accuracy)
            is_best = top1.avg > self.best_top1
            self.best_top1 = max(top1.avg, self.best_top1)
            self.save_checkpoint(
                {
                    'epoch': cur_epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'best_top1': self.best_top1,
                    'optimizer': self.optimizer.state_dict(),
                }, is_best)

    def test(self, testloader, cur_epoch=-1):
        """Evaluate on ``testloader`` and print loss / top-1 / top-5.

        The metrics are also written to TensorBoard unless ``cur_epoch`` is -1.
        """
        loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

        # Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()

        for data, target in testloader:
            data, target = self._to_device(data, target)

            # Forward pass
            with torch.no_grad():
                output = self.model(data)
                cur_loss = self.loss(output, target)

            # Top-1 and Top-5 Accuracy Calculation
            cur_acc1, cur_acc5 = self.compute_accuracy(output, target,
                                                       topk=(1, 5))
            loss.update(cur_loss.item())
            top1.update(cur_acc1.item())
            top5.update(cur_acc5.item())

        if cur_epoch != -1:
            # Summary Writing
            self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-1-acc", top1.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-5-acc", top5.avg, cur_epoch)

        print("Test Results" + " | " + "loss: " + str(loss.avg) +
              " - acc-top1: " + str(top1.avg)[:7] + "- acc-top5: " +
              str(top5.avg)[:7])

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        """Save ``state`` under ``args.checkpoint_dir``; copy it as the best model if ``is_best``."""
        torch.save(state, self.args.checkpoint_dir + filename)
        if is_best:
            shutil.copyfile(self.args.checkpoint_dir + filename,
                            self.args.checkpoint_dir + 'model_best.pth.tar')

    def compute_accuracy(self, output, target, topk=(1, )):
        """Compute the accuracy@k for each k in ``topk``.

        Returns a list of one-element tensors holding the fraction (0..1) of
        samples whose target is among the ``k`` highest-scoring classes.
        """
        maxk = max(topk)
        batch_size = target.size(0)

        _, idx = output.topk(maxk, 1, True, True)
        idx = idx.t()
        correct = idx.eq(target.view(1, -1).expand_as(idx))

        acc_arr = []
        for k in topk:
            # ``correct`` inherits the transposed (non-contiguous) layout of
            # ``idx``, so ``view`` can fail here; ``reshape`` copies only when
            # it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc_arr.append(correct_k.mul_(1.0 / batch_size))
        return acc_arr

    def adjust_learning_rate(self, optimizer, epoch):
        """Set every param group's LR to ``args.learning_rate * args.learning_rate_decay ** epoch``."""
        learning_rate = self.args.learning_rate * (
            self.args.learning_rate_decay**epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    def create_optimization(self):
        """Create ``self.loss`` and the RMSprop ``self.optimizer``.

        The loss is ``FocalLoss(gamma=args.gamma)`` when
        ``args.loss_function == 'FocalLoss'`` and ``nn.CrossEntropyLoss()``
        otherwise; it is moved to the GPU when ``args.cuda`` is set.

        With ``args.classify`` only the model's parameters are optimized.
        Otherwise an ``ArcMarginModel`` head (``self.metric_fc``; an existing
        one is reused) is optimized together with the model.
        """
        if self.args.loss_function == 'FocalLoss':
            self.loss = FocalLoss(gamma=self.args.gamma)
        else:
            self.loss = nn.CrossEntropyLoss()
        if self.args.cuda:
            self.loss.cuda()

        if self.args.classify:
            params = self.model.parameters()
        else:
            # The head was previously created in the other branch, so this
            # branch raised AttributeError on ``self.metric_fc``.
            if getattr(self, 'metric_fc', None) is None:
                self.metric_fc = ArcMarginModel(self.args)
            params = [{'params': self.model.parameters()},
                      {'params': self.metric_fc.parameters()}]

        self.optimizer = RMSprop(params,
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)

    def load_pretrained_model(self):
        """Best-effort copy of matching tensors from ``args.pretrained_path``.

        Only entries whose name exists in the model and whose shape matches are
        copied. Any error (e.g. a missing file) is reported and skipped.
        """
        try:
            print("Loading ImageNet pretrained weights...")
            pretrained_dict = torch.load(self.args.pretrained_path)
            # state_dict() tensors share storage with the model's parameters,
            # so copy_ writes straight into the model.
            model_state = self.model.state_dict()
            for params_name in pretrained_dict:
                if (params_name in self.model_dict
                        and pretrained_dict[params_name].size()
                        == self.model_dict[params_name].size()):
                    model_state[params_name].copy_(pretrained_dict[params_name])
            print("ImageNet pretrained weights loaded successfully.\n")
        except Exception:
            print("No ImageNet pretrained weights exist. Skipping...\n")

    def load_checkpoint(self, filename):
        """Best-effort restore of epoch, best top-1, model and optimizer state."""
        filename = self.args.checkpoint_dir + filename
        try:
            print("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)
            self.start_epoch = checkpoint['epoch']
            self.best_top1 = checkpoint['best_top1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("Checkpoint loaded successfully from '{}' at (epoch {})\n".
                  format(self.args.checkpoint_dir, checkpoint['epoch']))
        except Exception:
            print("No checkpoint exists from '{}'. Skipping...\n".format(
                self.args.checkpoint_dir))