def create_optimization(self):
    if self.args.loss_function == 'FocalLoss':
        self.loss = FocalLoss(gamma=self.args.gamma)
    else:
        self.loss = nn.CrossEntropyLoss()
    if self.args.cuda:
        self.loss.cuda()

    if self.args.classify:
        # Plain classification: only the backbone's parameters are optimized.
        self.optimizer = RMSprop(self.model.parameters(),
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)
    else:
        # Metric learning: the ArcMargin head is created here so its
        # parameters can be optimized alongside the backbone's.
        self.metric_fc = ArcMarginModel(self.args)
        self.optimizer = RMSprop([{'params': self.model.parameters()},
                                  {'params': self.metric_fc.parameters()}],
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)
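# `FocalLoss` is imported from elsewhere in the repo and is not shown here.
# For reference, a minimal sketch of the standard focal-loss formulation
# (Lin et al., 2017) that would satisfy the FocalLoss(gamma=...) calls in this
# file -- an assumption about the actual implementation, not a copy of it:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Per-sample cross entropy, kept unreduced so the focal modulation
        # can be applied element-wise.
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)  # probability assigned to the true class
        # Down-weight easy examples by (1 - pt)^gamma, then average.
        return ((1.0 - pt) ** self.gamma * ce).mean()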
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            from mobilenet_v2 import MobileNetV2
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(args.network))
        metric_fc = ArcMarginModel(args)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        # Wrap for multi-GPU only on fresh initialization; a resumed checkpoint
        # already stores the wrapped modules.
        model = nn.DataParallel(model)
        metric_fc = nn.DataParallel(metric_fc)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # Halve the learning rate if there is no improvement for 2 consecutive
        # validation rounds, and terminate training after 10.
        if epochs_since_improvement == 10:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            metric_fc = checkpoint['metric_fc']
            optimizer = checkpoint['optimizer']
            adjust_learning_rate(optimizer, 0.5)

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)

        lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))
        # print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        if epoch % 5 == 0:
            # One epoch's validation
            megaface_acc = megaface_test(model)
            writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

            # Check if there was an improvement
            is_best = megaface_acc > best_acc
            best_acc = max(megaface_acc, best_acc)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
            else:
                epochs_since_improvement = 0

            # Save checkpoint
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc, is_best)
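# `adjust_learning_rate` is assumed to shrink the optimizer's learning rate by
# a multiplicative factor; a minimal sketch consistent with the
# adjust_learning_rate(optimizer, 0.5) call above:
def adjust_learning_rate(optimizer, shrink_factor):
    # Scale the lr of every parameter group in place.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("The new learning rate is %f" % (optimizer.param_groups[0]['lr'],))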
def test():
    filename = 'metric_fc.pt'
    metric_fc = ArcMarginModel()
    metric_fc.load_state_dict(torch.load(filename))
    metric_fc = metric_fc.to('cpu')
    metric_fc.eval()

    train_batch_size = 30
    eval_batch_size = 30

    scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
    scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'

    data_loader, data_loader_test = prepare_data_loaders()
    criterion = nn.CrossEntropyLoss()

    filename = 'image_matching_mobile.pt'
    print('loading {}...'.format(filename))
    start = time.time()
    float_model = load_model(filename)
    print('elapsed {} sec'.format(time.time() - start))

    print('\n Inverted Residual Block: Before fusion \n\n', float_model.features[1].conv)
    float_model.eval()

    # Fuses modules
    float_model.fuse_model()

    # Note fusion of Conv+BN+Relu and Conv+Relu
    print('\n Inverted Residual Block: After fusion \n\n', float_model.features[1].conv)

    num_eval_batches = 10

    print("Size of baseline model")
    print_size_of_model(float_model)

    top1 = evaluate(float_model, metric_fc, criterion, data_loader_test,
                    neval_batches=num_eval_batches)
    print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))
    torch.jit.save(torch.jit.script(float_model), scripted_float_model_file)

    # 4. Post-training static quantization
    print(bcolors.HEADER + '\nPost-training static quantization' + bcolors.ENDC)
    num_calibration_batches = 10
    myModel = load_model(filename).to('cpu')
    myModel.eval()

    # Fuse Conv, bn and relu
    myModel.fuse_model()

    # Specify quantization configuration
    # Start with simple min/max range estimation and per-tensor quantization of weights
    myModel.qconfig = torch.quantization.default_qconfig
    print(myModel.qconfig)
    torch.quantization.prepare(myModel, inplace=True)

    # Calibrate first
    print('Post Training Quantization Prepare: Inserting Observers')
    print('\n Inverted Residual Block: After observer insertion \n\n', myModel.features[1].conv)

    # Calibrate with the training set
    print('Calibrate with the training set')
    evaluate(myModel, metric_fc, criterion, data_loader,
             neval_batches=num_calibration_batches)
    print('Post Training Quantization: Calibration done')

    # Convert to quantized model
    torch.quantization.convert(myModel, inplace=True)
    print('Post Training Quantization: Convert done')
    print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',
          myModel.features[1].conv)

    print("Size of model after quantization")
    print_size_of_model(myModel)

    top1 = evaluate(myModel, metric_fc, criterion, data_loader_test,
                    neval_batches=num_eval_batches)
    print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))

    per_channel_quantized_model = load_model(filename)
    per_channel_quantized_model.eval()
    per_channel_quantized_model.fuse_model()
    per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    print(per_channel_quantized_model.qconfig)

    torch.quantization.prepare(per_channel_quantized_model, inplace=True)
    print('Calibrate with the training set')
    evaluate(per_channel_quantized_model, metric_fc, criterion, data_loader,
             neval_batches=num_calibration_batches)
    torch.quantization.convert(per_channel_quantized_model, inplace=True)
    top1 = evaluate(per_channel_quantized_model, metric_fc, criterion, data_loader_test,
                    neval_batches=num_eval_batches)
    print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))
    torch.jit.save(torch.jit.script(per_channel_quantized_model), scripted_quantized_model_file)

    # Speedup from quantization
    print(bcolors.HEADER + '\nSpeedup from quantization' + bcolors.ENDC)
    run_benchmark(scripted_float_model_file, data_loader_test)
    run_benchmark(scripted_quantized_model_file, data_loader_test)

    # 5. Quantization-aware training
    print(bcolors.HEADER + '\nQuantization-aware training' + bcolors.ENDC)
    qat_model = load_model(filename)
    qat_model.fuse_model()

    optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    torch.quantization.prepare_qat(qat_model, inplace=True)
    print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',
          qat_model.features[1].conv)

    num_train_batches = 20

    # Train and check accuracy after each epoch
    for nepoch in range(8):
        train_one_epoch(qat_model, criterion, optimizer, data_loader,
                        torch.device('cpu'), num_train_batches)
        if nepoch > 3:
            # Freeze quantizer parameters
            qat_model.apply(torch.quantization.disable_observer)
        if nepoch > 2:
            # Freeze batch norm mean and variance estimates
            qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # Check the accuracy after each epoch
        quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
        quantized_model.eval()
        top1 = evaluate(quantized_model, metric_fc, criterion, data_loader_test,
                        neval_batches=num_eval_batches)
        print('Epoch %d: Evaluation accuracy on %d images, %2.2f' %
              (nepoch, num_eval_batches * eval_batch_size, top1.avg))
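# `print_size_of_model` and `run_benchmark` are used above but not defined in
# this snippet. Sketches in the spirit of the PyTorch static-quantization
# tutorial, assuming the data loader yields (images, target) batches -- the
# temp filename and batch count are illustrative assumptions:
import os
import time
import torch

def print_size_of_model(model):
    # Serialize the state dict to disk and report the file size in MB.
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")

def run_benchmark(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Time the scripted model on a few batches of images.
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            start = time.time()
            output = model(images)
            elapsed = elapsed + (time.time() - start)
        else:
            break
    num_images = images.size()[0] * num_batches
    print('Elapsed time: %3.0f ms per image' % (elapsed / num_images * 1000))
    return elapsed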
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(args.network))
        # print(model)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      metric_fc=metric_fc,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      logger=logger)
        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_acc', train_acc, epoch)

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('model/valid_acc', lfw_acc, epoch)
        writer.add_scalar('model/valid_thres', threshold, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)

        scheduler.step(epoch)
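# `get_logger` is shared by several of the training variants in this file but
# defined elsewhere; a minimal sketch of a plausible console logger (an
# assumption, not the repo's actual implementation):
import logging

def get_logger():
    logger = logging.getLogger()
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger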
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        elif args.network == 'mr18':
            print("mr18")
            model = myResnet18()
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        scheduler.step()

        if args.full_log:
            lfw_acc, threshold = lfw_test(model)
            writer.add_scalar('LFW_Accuracy', lfw_acc, epoch)
            full_log(epoch)

        start = datetime.now()

        # One epoch's training
        train_loss, train_top5_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger,
                                            writer=writer)
        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Train_Top5_Accuracy', train_top5_accs, epoch)

        end = datetime.now()
        delta = end - start
        print('{} seconds'.format(delta.seconds))

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('LFW_Accuracy', lfw_acc, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
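# `ArcMarginModel` is the metric-learning head every variant above optimizes
# alongside the backbone. It is defined elsewhere in the repo; below is a
# minimal sketch of the standard ArcFace additive-angular-margin head
# consistent with how metric_fc is used (embeddings plus labels in, scaled
# logits out). The arg names emb_size, num_classes, margin_m, margin_s and
# easy_margin are assumptions about the repo's config:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginModel(nn.Module):
    def __init__(self, args):
        super(ArcMarginModel, self).__init__()
        # One weight row per identity, compared against L2-normalized embeddings.
        self.weight = nn.Parameter(torch.FloatTensor(args.num_classes, args.emb_size))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = args.easy_margin
        self.m = args.margin_m  # additive angular margin
        self.s = args.margin_s  # feature re-scale factor
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        # Cosine similarity between normalized embeddings and class weights.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Apply the margin only to the ground-truth class, then re-scale.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        return self.s * (one_hot * phi + (1.0 - one_hot) * cosine)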
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        else:
            raise TypeError('network {} is not supported.'.format(args.network))

        if args.pretrained:
            model.load_state_dict(torch.load('insight-face-v3.pt'))
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        nesterov=True,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma)
    else:
        criterion = nn.CrossEntropyLoss()

    # Custom dataloaders
    # train_dataset = ArcFaceDataset('train')
    # train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
    #                                            num_workers=num_workers)
    train_dataset = ArcFaceDatasetBatched('train', img_batch_size)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size // img_batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               collate_fn=batched_collate_fn)

    scheduler = MultiStepLR(optimizer, milestones=[8, 16, 24, 32], gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        lr = optimizer.param_groups[0]['lr']
        logger.info('\nCurrent effective learning rate: {}\n'.format(lr))
        # print('Step num: {}\n'.format(optimizer.step_num))
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)
        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)

        scheduler.step(epoch)

        # One epoch's validation
        if args.eval_ds == "LFW":
            from lfw_eval import lfw_test
            accuracy, threshold = lfw_test(model)
        elif args.eval_ds == "Megaface":
            from megaface_eval import megaface_test
            accuracy = megaface_test(model)
        else:
            accuracy = -1
        writer.add_scalar('model/evaluation_accuracy', accuracy, epoch)

        # Check if there was an improvement
        is_best = accuracy > best_acc
        best_acc = max(accuracy, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best, scheduler)
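# `batched_collate_fn` pairs with ArcFaceDatasetBatched above: each dataset
# item is assumed to already hold img_batch_size images, so the collate
# function concatenates rather than stacks. A sketch of that assumed contract,
# not the repo's actual implementation:
def batched_collate_fn(batch):
    # Each item is an (images, labels) pair of tensors with a leading
    # img_batch_size dimension; merge them along the batch axis.
    images = torch.cat([item[0] for item in batch], dim=0)
    labels = torch.cat([item[1] for item in batch], dim=0)
    return images, labels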
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            # optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
            #                             lr=args.lr, momentum=args.mom, weight_decay=args.weight_decay)
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{'params': model.parameters()},
                                 {'params': metric_fc.parameters()}],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)

        print('\nCurrent effective learning rate: {}\n'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', optimizer.lr, epoch)

        # One epoch's validation
        megaface_acc = megaface_test(model)
        writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

        # Check if there was an improvement
        is_best = megaface_acc > best_acc
        best_acc = max(megaface_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
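# `InsightFaceOptimizer` wraps the SGD optimizer and, judging by the
# optimizer.lr and optimizer.step_num reads above, tracks its own step count
# and effective learning rate. A plausible sketch assuming a step-wise 10x
# decay at fixed iteration milestones; the milestone values are assumptions,
# not the repo's actual schedule:
class InsightFaceOptimizer(object):
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.step_num = 0
        self.lr = optimizer.param_groups[0]['lr']

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.step_num += 1
        # Decay the lr by 10x at assumed iteration milestones.
        if self.step_num in (100000, 160000, 220000):
            self.lr = self.lr / 10
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.optimizer.step()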
def train_net(args):
    torch.manual_seed(7)  # torch random seed, used by torch.randn
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()  # TensorBoard
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            # optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
            #                             lr=args.lr, momentum=args.mom, weight_decay=args.weight_decay)
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{'params': model.parameters()},
                                 {'params': metric_fc.parameters()}],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)  # the saved objects still have to be unpacked manually here
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = Dataset(root=args.train_path,
                            phase='train',
                            input_shape=(3, 112, 112))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        # Wrapping one epoch of training in a single function keeps this loop
        # very concise -- a pattern worth copying.
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)

        print('\nCurrent effective learning rate: {}\n'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', optimizer.lr, epoch)

        # Save checkpoint
        if epoch % 10 == 0:
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc)
def train_net(args):
    torch.manual_seed(9527)
    np.random.seed(9527)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = MobileNetMatchModel()
        metric_fc = ArcMarginModel(args)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': metric_fc.parameters()}],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)
        else:
            optimizer = torch.optim.Adam([{'params': model.parameters()},
                                          {'params': metric_fc.parameters()}],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        model = nn.DataParallel(model)
        metric_fc = nn.DataParallel(metric_fc)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    cudnn.benchmark = True

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(FrameDataset('train'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)

    # scheduler = MultiStepLR(optimizer, milestones=[5, 10, 15, 20], gamma=0.1)
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_top5_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)
        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top5_accs, epoch)

        lr = optimizer.param_groups[0]['lr']
        print('\nLearning rate: {}'.format(lr))
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's validation
        val_acc, thres = test(model)
        writer.add_scalar('model/valid_accuracy', val_acc, epoch)
        writer.add_scalar('model/valid_threshold', thres, epoch)

        scheduler.step(epoch)

        # Check if there was an improvement
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
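# `save_checkpoint` is shared by all the training loops above; a minimal
# sketch matching the common (epoch, epochs_since_improvement, model,
# metric_fc, optimizer, best_acc, is_best) call signature. Some variants pass
# extra arguments (e.g. the scheduler) or omit is_best; the filenames below
# follow the 'BEST_checkpoint.tar' name that one loop reloads, while
# 'checkpoint.tar' is an assumption:
def save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                    optimizer, best_acc, is_best):
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'best_acc': best_acc,
             'model': model,
             'metric_fc': metric_fc,
             'optimizer': optimizer}
    torch.save(state, 'checkpoint.tar')
    # Keep a separate copy of the best model so far.
    if is_best:
        torch.save(state, 'BEST_checkpoint.tar')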