예제 #1
0
def main():
    """Train RetinaFace, logging losses to TensorBoard and saving checkpoints.

    Configuration comes from ``get_args()``; this function reads
    ``save_path``, ``data_path``, ``batch``, ``epochs``, ``verbose``,
    ``eval_step`` and ``save_step`` from it.

    Every ``verbose`` batches the loss terms are printed and written to
    TensorBoard; every ``eval_step`` epochs recall and precision are computed
    on the validation loader; every ``save_step`` epochs the model state dict
    is saved under ``save_path``.
    """
    args = get_args()

    # A single makedirs call creates save_path and save_path/log (plus any
    # missing parents) and does nothing when they already exist.
    log_path = os.path.join(args.save_path, 'log')
    os.makedirs(log_path, exist_ok=True)

    train_path = os.path.join(args.data_path, 'train/label.txt')
    val_path = os.path.join(args.data_path, 'val/label.txt')

    dataset_train = TrainDataset(
        train_path,
        transform=transforms.Compose([Resizer(), PadToSquare()]))
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=8,
                                  batch_size=args.batch,
                                  collate_fn=collater,
                                  shuffle=True)

    dataset_val = ValDataset(
        val_path,
        transform=transforms.Compose([Resizer(), PadToSquare()]))
    dataloader_val = DataLoader(dataset_val,
                                num_workers=8,
                                batch_size=args.batch,
                                collate_fn=collater)

    total_batch = len(dataloader_train)

    # Create torchvision model
    return_layers = {'layer2': 1, 'layer3': 2, 'layer4': 3}
    retinaface = torchvision_model.create_retinaface(return_layers)
    retinaface = torch.nn.DataParallel(retinaface.cuda())

    optimizer = optim.Adam(retinaface.parameters(), lr=1e-3)

    print('Start to train.')

    # Number of log events so far; multiplied by args.verbose it is used as
    # the TensorBoard step for the training losses.
    iteration = 0

    writer = SummaryWriter(log_dir=log_path)
    try:
        for epoch in range(args.epochs):
            retinaface.train()

            # Training
            for iter_num, data in enumerate(dataloader_train):
                optimizer.zero_grad()
                classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                    [data['img'].cuda().float(), data['annot']])
                # Each loss is averaged with .mean() before the three are summed.
                classification_loss = classification_loss.mean()
                bbox_regression_loss = bbox_regression_loss.mean()
                ldm_regression_loss = ldm_regression_loss.mean()

                loss = classification_loss + bbox_regression_loss + ldm_regression_loss

                loss.backward()
                optimizer.step()

                if iter_num % args.verbose == 0:
                    # Read each scalar once; it is used for both the console
                    # table and TensorBoard.
                    total_value = loss.item()
                    cls_value = classification_loss.item()
                    bbox_value = bbox_regression_loss.item()
                    ldm_value = ldm_regression_loss.item()

                    log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                        epoch, args.epochs, iter_num, total_batch)
                    table_data = [
                        ['loss name', 'value'],
                        ['total_loss', str(total_value)],
                        ['classification', str(cls_value)],
                        ['bbox', str(bbox_value)],
                        ['landmarks', str(ldm_value)],
                    ]
                    log_str += AsciiTable(table_data).table
                    print(log_str)

                    # write the log to tensorboard
                    writer.add_scalars(
                        'losses:', {
                            'total_loss': total_value,
                            'cls_loss': cls_value,
                            'bbox_loss': bbox_value,
                            'ldm_loss': ldm_value,
                        }, iteration * args.verbose)
                    iteration += 1

            # Eval
            if epoch % args.eval_step == 0:
                print('-------- RetinaFace Pytorch --------')
                print('Evaluating epoch {}'.format(epoch))
                recall, precision = eval_widerface.evaluate(
                    dataloader_val, retinaface)
                print('Recall:', recall)
                print('Precision:', precision)

            # Save model
            if (epoch + 1) % args.save_step == 0:
                torch.save(
                    retinaface.state_dict(),
                    os.path.join(args.save_path,
                                 'model_epoch_{}.pt'.format(epoch + 1)))
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()
예제 #2
0
def main():
    """Train RetinaFace, optionally resuming from a checkpoint, and keep the best model.

    Configuration comes from ``get_args()``; this function reads
    ``save_path``, ``data_path``, ``model_path``, ``batch``, ``epochs``,
    ``verbose``, ``eval_step`` and ``save_step`` from it.

    Every ``verbose`` batches the loss terms are printed and written to
    TensorBoard.  Every ``eval_step`` epochs recall and precision are
    computed; whenever precision exceeds the best value seen so far a
    ``model_Best_epoch_N.pt`` checkpoint is written.  Every ``save_step``
    epochs a regular ``model_epoch_N.pt`` checkpoint is written.
    """
    precision_global = 0
    args = get_args()

    # A single makedirs call creates save_path and save_path/log (plus any
    # missing parents) and does nothing when they already exist.
    log_path = os.path.join(args.save_path, 'log')
    os.makedirs(log_path, exist_ok=True)

    train_path = os.path.join(args.data_path, 'retina-train-splitTrain.txt')
    val_path = os.path.join(args.data_path, "retina-train-splitTest.txt")

    dataset_train = TrainDataset(
        train_path,
        transform=transforms.Compose([Resizer(), PadToSquare()]))
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=6,
                                  batch_size=args.batch,
                                  collate_fn=collater,
                                  shuffle=True)

    dataset_val = ValDataset(
        val_path,
        transform=transforms.Compose([Resizer(), PadToSquare()]))
    dataloader_val = DataLoader(dataset_val,
                                num_workers=8,
                                batch_size=args.batch,
                                collate_fn=collater)

    total_batch = len(dataloader_train)

    # Create torchvision model
    return_layers = {'layer2': 1, 'layer3': 2, 'layer4': 3}
    retinaface = torchvision_model.create_retinaface(return_layers)

    # Load trained model
    if args.model_path is not None:
        retina_dict = retinaface.state_dict()
        pre_state_dict = torch.load(args.model_path)
        # Checkpoints written below come from the DataParallel wrapper, so
        # their keys start with 'module.'.  Strip that prefix only when it is
        # actually present instead of unconditionally dropping 7 characters,
        # which would mangle the keys of a checkpoint saved without it.
        prefix = 'module.'
        pretrained_dict = {}
        for key, value in pre_state_dict.items():
            name = key[len(prefix):] if key.startswith(prefix) else key
            if name in retina_dict:
                pretrained_dict[name] = value
        retinaface.load_state_dict(pretrained_dict)

    retinaface = torch.nn.DataParallel(retinaface.cuda())

    optimizer = optim.Adam(retinaface.parameters(), lr=1e-3)

    print('Start to train.')

    # Number of log events so far; multiplied by args.verbose it is used as
    # the TensorBoard step for the training losses.
    iteration = 0

    writer = SummaryWriter(log_dir=log_path)
    try:
        for epoch in range(args.epochs):
            retinaface.train()

            # Training
            for iter_num, data in enumerate(dataloader_train):
                optimizer.zero_grad()
                classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                    [data['img'].cuda().float(), data['annot']])
                # Each loss is averaged with .mean() before the three are summed.
                classification_loss = classification_loss.mean()
                bbox_regression_loss = bbox_regression_loss.mean()
                ldm_regression_loss = ldm_regression_loss.mean()

                loss = classification_loss + bbox_regression_loss + ldm_regression_loss

                loss.backward()
                optimizer.step()

                if iter_num % args.verbose == 0:
                    # Read each scalar once; it is used for both the console
                    # table and TensorBoard.
                    total_value = loss.item()
                    cls_value = classification_loss.item()
                    bbox_value = bbox_regression_loss.item()
                    ldm_value = ldm_regression_loss.item()

                    log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                        epoch, args.epochs, iter_num, total_batch)
                    table_data = [
                        ['loss name', 'value'],
                        ['total_loss', str(total_value)],
                        ['classification', str(cls_value)],
                        ['bbox', str(bbox_value)],
                        ['landmarks', str(ldm_value)],
                    ]
                    log_str += AsciiTable(table_data).table
                    print(log_str)

                    # write the log to tensorboard
                    step = iteration * args.verbose
                    writer.add_scalar('losses:', total_value, step)
                    writer.add_scalar('class losses:', cls_value, step)
                    writer.add_scalar('box losses:', bbox_value, step)
                    writer.add_scalar('landmark losses:', ldm_value, step)
                    iteration += 1

            # Eval
            if epoch % args.eval_step == 0:
                print('-------- RetinaFace Pytorch --------')
                print('Evaluating epoch {}'.format(epoch))
                recall, precision = eval_widerface.evaluate(
                    dataloader_val, retinaface)
                if precision_global < precision:
                    precision_global = precision
                    torch.save(
                        retinaface.state_dict(),
                        os.path.join(
                            args.save_path,
                            'model_Best_epoch_{}.pt'.format(epoch + 1)))
                print('Recall:', recall)
                print('Precision:', precision, "best Precision:", precision_global)

                # `epoch` is already the epoch index of this evaluation, so it
                # is the step; multiplying it by eval_step again would stretch
                # the x-axis quadratically in eval_step.
                writer.add_scalar('Recall:', recall, epoch)
                writer.add_scalar('Precision:', precision, epoch)

            # Save model
            if (epoch + 1) % args.save_step == 0:
                torch.save(
                    retinaface.state_dict(),
                    os.path.join(args.save_path,
                                 'model_epoch_{}.pt'.format(epoch + 1)))
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()
예제 #3
0
def main():
    """Fine-tune RetinaFace from a fixed checkpoint with a cosine learning rate.

    Configuration comes from ``get_args()``; this function reads
    ``save_path``, ``batch``, ``epochs``, ``verbose``, ``eval_step`` and
    ``save_step`` from it.

    The dataset is split 70/30 into training and validation subsets.  The
    optimised loss is ``classification + 0.1 * landmarks``; the bbox loss is
    computed and printed but not optimised.  Evaluation results are printed
    and appended to ``bbb.txt`` every ``eval_step`` epochs.
    """
    args = get_args()

    # A single makedirs call creates save_path and save_path/log (plus any
    # missing parents) and does nothing when they already exist.
    log_path = os.path.join(args.save_path, 'log')
    os.makedirs(log_path, exist_ok=True)

    dataset_train = TrainDataset(
        transform=transforms.Compose([Rotate(), Resizer(), Color()]))
    len_train_set = int(len(dataset_train) * 0.7)
    len_val_set = len(dataset_train) - len_train_set

    # NOTE(review): both subsets share dataset_train and therefore its
    # Rotate/Color transforms, so validation images appear to be augmented
    # too - confirm this is intended.
    train_set, val_set = random_split(dataset_train,
                                      [len_train_set, len_val_set])
    dataloader_train = DataLoader(train_set,
                                  num_workers=8,
                                  batch_size=args.batch,
                                  collate_fn=collater,
                                  shuffle=True)
    dataloader_val = DataLoader(val_set,
                                num_workers=8,
                                batch_size=args.batch,
                                collate_fn=collater)

    total_batch = len(dataloader_train)

    # Create torchvision model
    return_layers = {'layer2': 1, 'layer3': 2, 'layer4': 3}
    retinaface = torchvision_model.create_retinaface(return_layers)
    retinaface = retinaface.cuda()

    base_lr = 1e-4
    optimizer = optim.Adam(retinaface.parameters(), lr=base_lr)

    retinaface = torch.nn.DataParallel(retinaface)
    retinaface.load_state_dict(torch.load("./out/mnas_epoch__ori111124.pt"))

    def lr_cos(n):
        """Return the cosine-annealed learning rate for epoch ``n``."""
        return float(0.5 * (1 + np.cos(n / args.epochs * np.pi)) * base_lr)

    print('Start to train.')

    for epoch in range(args.epochs):
        lr = lr_cos(epoch)
        # Apply the scheduled rate to the optimizer.  Previously the value was
        # only computed and printed, so training always ran at base_lr.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print("Current lr is {}".format(lr))
        retinaface.train()

        # Training
        for iter_num, data in enumerate(dataloader_train):
            optimizer.zero_grad()
            classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                [data['img'].cuda().float(), data['annot']])
            classification_loss = classification_loss.mean()
            bbox_regression_loss = bbox_regression_loss.mean()
            ldm_regression_loss = ldm_regression_loss.mean()

            # The bbox term is deliberately absent from the optimised loss
            # here; it is still reported in the table below.
            loss = classification_loss + 0.1 * ldm_regression_loss

            loss.backward()
            optimizer.step()

            if iter_num % args.verbose == 0:
                log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                    epoch, args.epochs, iter_num, total_batch)
                table_data = [
                    ['loss name', 'value'],
                    ['total_loss', str(loss.item())],
                    ['classification', str(classification_loss.item())],
                    ['bbox', str(bbox_regression_loss.item())],
                    ['landmarks', str(ldm_regression_loss.item())],
                ]
                log_str += AsciiTable(table_data).table
                print(log_str)

        # Eval
        if epoch % args.eval_step == 0:
            print('-------- RetinaFace Pytorch --------')
            print('Evaluating epoch {}'.format(epoch))
            recall, precision, landmakr, miss = eval_widerface.evaluate(
                dataloader_val, retinaface)
            print('Recall:', recall)
            print('Precision:', precision)
            print("landmark: ", str(landmakr))
            print("miss: " + str(miss))
            # The with-statement closes the file; no explicit close() needed.
            with open("bbb.txt", 'a') as f:
                f.write('-------- RetinaFace Pytorch --------(pretrain)' + '\n')
                f.write('Evaluating epoch {}'.format(epoch) + '\n')
                f.write('Recall:' + str(recall) + '\n')
                f.write('Precision:' + str(precision) + '\n')
                f.write("landmark: " + str(landmakr) + '\n')
                f.write("miss: " + str(miss) + '\n')

        # Save model
        if (epoch + 1) % args.save_step == 0:
            # The large constant offset in the file name is kept exactly as
            # in the original so existing checkpoint names stay the same.
            torch.save(
                retinaface.state_dict(),
                os.path.join(
                    args.save_path,
                    'pretrain{}.pt'.format(epoch + 1 + 5 + 1112222211100)))
def main():
    """Fine-tune (or only evaluate) RetinaFace starting from a saved checkpoint.

    Configuration comes from ``get_args()``; this function reads
    ``save_path``, ``batch``, ``epochs``, ``verbose``, ``eval_step``,
    ``save_step`` and ``training`` from it.

    When ``args.training`` equals ``False`` the loaded model is evaluated
    once and the function returns.  Otherwise the model is trained with a
    cosine learning-rate schedule on the loss
    ``classification + 0.15 * bbox + 0.25 * landmarks``; losses and
    evaluation metrics are appended to ``aaa.txt`` every ``eval_step``
    epochs.
    """
    args = get_args()

    # A single makedirs call creates save_path and save_path/log (plus any
    # missing parents) and does nothing when they already exist.
    log_path = os.path.join(args.save_path, 'log')
    os.makedirs(log_path, exist_ok=True)

    dataset_train = TrainDataset('./widerface/train/label.txt',
                                 transform=transforms.Compose([
                                     RandomErasing(),
                                     RandomFlip(),
                                     Rotate(),
                                     Color(),
                                     Resizer(),
                                     PadToSquare()
                                 ]))
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=8,
                                  batch_size=args.batch,
                                  collate_fn=collater,
                                  shuffle=True)
    # NOTE(review): the validation set is built from the *training* label
    # file, so the reported metrics are measured on training images - confirm
    # this is intended.
    dataset_val = TrainDataset('./widerface/train/label.txt',
                               transform=transforms.Compose(
                                   [Resizer(640), PadToSquare()]))
    dataloader_val = DataLoader(dataset_val,
                                num_workers=8,
                                batch_size=args.batch,
                                collate_fn=collater)

    total_batch = len(dataloader_train)

    # Create torchvision model
    return_layers = {'layer2': 1, 'layer3': 2, 'layer4': 3}
    retinaface = torchvision_model.create_retinaface(return_layers)
    retinaface = torch.nn.DataParallel(retinaface.cuda())
    base_lr = 1e-7

    retinaface.load_state_dict(
        torch.load(
            "/versa/elvishelvis/RetinaYang/out/stage_5_68_full_model_epoch_121.pt"
        ))

    # Freeze every parameter whose name contains 'Landmark'.
    for name, value in retinaface.named_parameters():
        if 'Landmark' in name:
            value.requires_grad = False

    def lr_cos(n):
        """Return the cosine-annealed learning rate for epoch ``n``."""
        return float(0.5 * (1 + np.cos(n / args.epochs * np.pi)) * base_lr)

    trainable_params = [p for p in retinaface.parameters() if p.requires_grad]
    frozen_params = [p for p in retinaface.parameters() if not p.requires_grad]
    # NOTE(review): the first group holds the frozen parameters; they receive
    # no gradient, so its 3x learning rate presumably has no effect as
    # written - confirm which group was meant to be frozen.
    optimizer = torch.optim.Adam([{
        'params': frozen_params,
        'lr': base_lr * 3
    }, {
        'params': trainable_params,
        'lr': base_lr
    }])
    # Per-group multipliers of the scheduled rate, in param-group order,
    # matching the ratios used to build the optimizer above.
    lr_scales = (3, 1)

    # Evaluate the current model and stop when training was not requested.
    if args.training == False:  # noqa: E712 - keep the original equality test
        print("not pretrain")
        recall, precision, landmakr, miss = eval_widerface.evaluate(
            dataloader_val, retinaface)
        print('Recall:', recall)
        print('Precision:', precision)
        print("landmark: ", str(landmakr))
        print("miss: " + str(miss))
        return

    print('Start to train.')

    for epoch in range(args.epochs):
        lr = lr_cos(epoch)
        # Apply the scheduled rate to the optimizer.  Previously the value was
        # only computed, so training always ran at the initial rates.
        for param_group, scale in zip(optimizer.param_groups, lr_scales):
            param_group['lr'] = lr * scale

        retinaface.train()

        # Training
        for iter_num, data in enumerate(dataloader_train):
            optimizer.zero_grad()
            classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                [data['img'].cuda().float(), data['annot']])
            classification_loss = classification_loss.mean()
            bbox_regression_loss = bbox_regression_loss.mean()
            ldm_regression_loss = ldm_regression_loss.mean()

            loss = classification_loss + 0.15 * bbox_regression_loss + 0.25 * ldm_regression_loss

            loss.backward()
            optimizer.step()

            if iter_num % args.verbose == 0:
                log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                    epoch, args.epochs, iter_num, total_batch)
                table_data = [
                    ['loss name', 'value'],
                    ['total_loss', str(loss.item())],
                    ['classification', str(classification_loss.item())],
                    ['bbox', str(bbox_regression_loss.item())],
                    ['landmarks', str(ldm_regression_loss.item())],
                ]
                log_str += AsciiTable(table_data).table
                print(log_str)

        # Eval
        if epoch % args.eval_step == 0:
            # The losses recorded here are those of the last training batch
            # of this epoch.  The with-statement closes the file.
            with open("aaa.txt", 'a') as f:
                f.write('-------- RetinaFace Pytorch --------' + '\n')
                f.write('Evaluating epoch {}'.format(epoch) + '\n')
                f.write('total_loss:' + str(loss.item()) + '\n')
                f.write('classification' + str(classification_loss.item()) +
                        '\n')
                f.write('bbox' + str(bbox_regression_loss.item()) + '\n')
                f.write('landmarks' + str(ldm_regression_loss.item()) + '\n')

            print('-------- RetinaFace Pytorch --------')
            print('Evaluating epoch {}'.format(epoch))
            recall, precision, landmakr, miss = eval_widerface.evaluate(
                dataloader_val, retinaface)
            print('Recall:', recall)
            print('Precision:', precision)
            print("landmark: ", str(landmakr))
            print("miss: " + str(miss))

            with open("aaa.txt", 'a') as f:
                f.write('-------- RetinaFace Pytorch --------(not pretrain)' +
                        '\n')
                f.write('Evaluating epoch {}'.format(epoch) + '\n')
                f.write('Recall:' + str(recall) + '\n')
                f.write('Precision:' + str(precision) + '\n')
                f.write("landmark: " + str(landmakr) + '\n')
                f.write("miss: " + str(miss) + '\n')

        # Save model
        if epoch % args.save_step == 0:
            torch.save(
                retinaface.state_dict(),
                os.path.join(
                    args.save_path,
                    'stage_5_68_full_model_epoch_{}.pt'.format(epoch + 1)))
예제 #5
0
def main():
    """Train RetinaFace with in-epoch validation, TensorBoard logging and checkpoints.

    Configuration comes from ``get_args()``; this function reads
    ``save_path``, ``data_path``, ``batch_size``, ``epochs``, ``verbose``,
    ``eval_step`` and ``save_step`` from it.

    Every ``verbose`` batches the loss terms are printed, written to
    TensorBoard, and ``validate`` is run on the held-out part of the training
    data.  Every ``eval_step`` epochs recall and precision are computed on
    the validation loader; every ``save_step`` epochs the model state dict is
    saved under ``save_path``.
    """
    args = get_args()

    # A single makedirs call creates save_path and save_path/log (plus any
    # missing parents) and does nothing when they already exist.
    log_path = os.path.join(args.save_path, 'log')
    os.makedirs(log_path, exist_ok=True)

    train_path = os.path.join(args.data_path, 'train/label.txt')
    val_path = os.path.join(args.data_path, 'val/label.txt')
    dataloader_train, dataloader_test = load_data(train_path,
                                                  args.batch_size,
                                                  split_train_test=True)
    dataloader_val = load_data(val_path, args.batch_size)

    total_batch = len(dataloader_train)

    # Create torchvision model
    retinaface = torchvision_model.create_retinaface().cuda()
    retinaface = torch.nn.DataParallel(retinaface)

    optimizer = optim.Adam(retinaface.parameters(), lr=1e-3)

    print('Start to train.')

    # Number of log events so far; multiplied by args.verbose it is used as
    # the TensorBoard step for the training losses.
    iteration = 0

    writer = SummaryWriter(log_dir=log_path)
    try:
        for epoch in range(args.epochs):
            retinaface.train()

            # Training
            for iter_num, data in enumerate(dataloader_train):
                optimizer.zero_grad()
                classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                    [data['img'].cuda().float(), data['annot']])
                classification_loss = classification_loss.mean()
                bbox_regression_loss = bbox_regression_loss.mean()
                ldm_regression_loss = ldm_regression_loss.mean()

                loss = classification_loss + bbox_regression_loss + 0.5 * ldm_regression_loss

                loss.backward()
                optimizer.step()

                if iter_num % args.verbose == 0:
                    # Read each scalar once; it is used for both the console
                    # table and TensorBoard.
                    total_value = loss.item()
                    cls_value = classification_loss.item()
                    bbox_value = bbox_regression_loss.item()
                    ldm_value = ldm_regression_loss.item()

                    log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                        epoch, args.epochs, iter_num, total_batch)
                    table_data = [
                        ['loss name', 'value'],
                        ['total_loss', str(total_value)],
                        ['classification', str(cls_value)],
                        ['bbox', str(bbox_value)],
                        ['landmarks', str(ldm_value)],
                    ]
                    log_str += AsciiTable(table_data).table
                    print("train loses:")
                    print(log_str)

                    # write the log to tensorboard
                    step = iteration * args.verbose
                    writer.add_scalar('losses:', total_value, step)
                    writer.add_scalar('class losses:', cls_value, step)
                    writer.add_scalar('box losses:', bbox_value, step)
                    writer.add_scalar('landmark losses:', ldm_value, step)
                    iteration += 1

                    validate(dataloader_test, retinaface)
                    # validate() is not visible here and may leave the model
                    # in eval mode; switch back so the remaining batches of
                    # this epoch still run in training mode.
                    retinaface.train()

            # Eval
            if epoch % args.eval_step == 0:
                print('-------- RetinaFace --------')
                print('Evaluating epoch {}'.format(epoch))
                recall, precision = eval_widerface.evaluate(
                    dataloader_val, retinaface)
                print('Recall:', recall)
                print('Precision:', precision)

                # `epoch` is already the epoch index of this evaluation, so it
                # is the step; multiplying it by eval_step again would stretch
                # the x-axis quadratically in eval_step.
                writer.add_scalar('Recall:', recall, epoch)
                writer.add_scalar('Precision:', precision, epoch)

            # Save model
            if (epoch + 1) % args.save_step == 0:
                torch.save(
                    retinaface.state_dict(),
                    os.path.join(args.save_path,
                                 'model_epoch_{}.pt'.format(epoch + 1)))
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()