Example #1
0
def validate(valloader, iteration):
    """Evaluate the module-level ``net`` on ``valloader`` and log the results.

    For every batch the three sets of logits returned by ``net`` are scored
    with ``multitask_loss`` and ``pairwise_ranking_loss``. The mean of the
    per-batch losses and the accuracy of each set of logits are written to
    ``logger`` at step ``iteration + 1``; the accuracies are also printed.

    Args:
        valloader: iterable yielding ``(images, labels)`` batches. It must
            yield at least one batch, because the per-batch losses are
            combined with ``torch.stack``.
        iteration: training step the metrics are attributed to.

    The network is switched to eval mode for the pass and back to train mode
    afterwards, including when the pass raises.
    """
    net.eval()
    try:
        with torch.no_grad():
            # One running count of correct predictions per set of logits.
            corrects = [0, 0, 0]
            cnt = 0
            val_cls_losses = []
            val_apn_losses = []
            for val_images, val_labels in valloader:
                if args.cuda:
                    val_images = val_images.cuda()
                    val_labels = val_labels.cuda()
                cnt += val_labels.size(0)

                logits, _, _ = net(val_images)
                # For each sample, the score every set of logits assigns to
                # the ground-truth class.
                preds = [[logit[i][label] for logit in logits]
                         for i, label in enumerate(val_labels)]

                # Keep the per-batch result in its own name: rebinding
                # ``val_cls_losses`` here would discard the losses collected
                # from the previous batches.
                val_cls_loss = multitask_loss(logits, val_labels)
                val_apn_loss = pairwise_ranking_loss(preds)

                val_cls_losses.append(sum(val_cls_loss))
                val_apn_losses.append(val_apn_loss)

                for scale in range(3):
                    _, predicted = torch.max(logits[scale], 1)
                    corrects[scale] += (predicted == val_labels).sum()

            mean_cls_loss = torch.stack(val_cls_losses).mean()
            mean_apn_loss = torch.stack(val_apn_losses).mean()
            accuracy1 = corrects[0].float() / cnt
            accuracy2 = corrects[1].float() / cnt
            accuracy3 = corrects[2].float() / cnt

            logger.scalar_summary('val_cls_loss', mean_cls_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('val_rank_loss', mean_apn_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('val_acc1', accuracy1.item(), iteration + 1)
            logger.scalar_summary('val_acc2', accuracy2.item(), iteration + 1)
            logger.scalar_summary('val_acc3', accuracy3.item(), iteration + 1)
            print(
                " [*] Iter %d || Val accuracy1: %.2f, Val accuracy2: %.2f, Val accuracy3: %.2f"
                %
                (iteration, accuracy1.item(), accuracy2.item(),
                 accuracy3.item()))
    finally:
        # Always hand the network back in training mode.
        net.train()
Example #2
0
def train_cls(trainloader, iteration):
    """Run one pass over ``trainloader`` optimising the classification loss.

    Each batch is passed through the module-level ``net``; the sum of the
    losses returned by ``multitask_loss`` is back-propagated and ``opt1`` is
    stepped. The classification losses and the (not optimised) pairwise
    ranking loss are written to ``logger`` at step ``iteration + 1``. Every
    tenth batch a progress line is printed and appended to
    ``./result_racnn123/cls.txt``.

    Args:
        trainloader: sized iterable yielding ``(images, labels)`` batches.
        iteration: index used for logging; every batch of this pass is
            logged at the same step, ``iteration + 1``.
    """
    all_iteration = len(trainloader)

    net.train()
    for index, (images, labels) in enumerate(trainloader):
        # The timer starts once the batch is already available, so the
        # reported time covers the forward/backward pass and the optimiser
        # step, not data loading.
        t0 = time.time()
        images.requires_grad_()
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()

        # Forward pass.
        logits, _, _ = net(images)
        opt1.zero_grad()

        # Classification loss: sum of the losses returned by multitask_loss.
        new_cls_losses = multitask_loss(logits, labels)
        new_cls_loss = sum(new_cls_losses)
        new_cls_loss.backward()
        opt1.step()

        t1 = time.time()

        # Visualisation.
        logger.scalar_summary('cls_loss', new_cls_loss.item(), iteration + 1)
        logger.scalar_summary('cls_loss1', new_cls_losses[0].item(),
                              iteration + 1)
        logger.scalar_summary('cls_loss12', new_cls_losses[1].item(),
                              iteration + 1)
        logger.scalar_summary('cls_loss123', new_cls_losses[2].item(),
                              iteration + 1)

        # Ranking loss, computed for monitoring only (never back-propagated
        # here): the inputs are the scores each set of logits gives to the
        # ground-truth class of every sample.
        preds = [[logit[i][label] for logit in logits]
                 for i, label in enumerate(labels)]
        new_apn_loss = pairwise_ranking_loss(preds)
        logger.scalar_summary('rank_loss', new_apn_loss.item(), iteration + 1)

        if index % 10 == 0:
            message = (
                " [*] cls_epoch[%d] || dataIter[%d] %d || cls_loss: %.4f || Timer: %.4fsec"
                % (iteration, all_iteration, index, new_cls_loss.item(),
                   (t1 - t0)))
            print(message)
            # Append the same line to the log file. The newline keeps each
            # record on its own line instead of running them together.
            with open('./result_racnn123/%s.txt' % 'cls', 'a') as acc_file:
                acc_file.write(message + "\n")
Example #3
0
def validate(valloader, iteration, type):
    """Evaluate the module-level ``net`` on ``valloader`` and report accuracy.

    For every batch the three sets of logits returned by ``net`` are scored
    with ``multitask_loss`` and ``pairwise_ranking_loss``. Mean losses and
    per-logit accuracies are written to ``logger`` at step ``iteration + 1``,
    printed, and appended to a log file under ``./result_racnn/``.

    Args:
        valloader: sized iterable yielding ``(images, labels)`` batches. It
            must yield at least one batch, because the per-batch losses are
            combined with ``torch.stack``.
        iteration: training step the metrics are attributed to.
        type: ``'cls'`` selects the classification report and log file
            (``cls_val_accuracy.txt``); any other value selects the APN
            report and log file (``apn_val_accuracy.txt``).

    Returns:
        list of three floats: the accuracy of each set of logits.

    The network is switched to eval mode for the pass and back to train mode
    afterwards, including when the pass raises.
    """
    # Created once, before the loop, so it is defined whatever the loader
    # yields and is not re-created on every batch.
    accuracy_list = []

    net.eval()
    try:
        with torch.no_grad():
            # One running count of correct predictions per set of logits.
            corrects = [0, 0, 0]
            cnt = 0
            val_cls_losses = []
            val_apn_losses = []

            t0 = time.time()
            for j, (val_images, val_labels) in enumerate(valloader):
                if args.cuda:
                    val_images = val_images.cuda()
                    val_labels = val_labels.cuda()
                cnt += val_labels.size(0)

                logits, _, _ = net(val_images)
                # For each sample, the score every set of logits assigns to
                # the ground-truth class.
                preds = [[logit[i][label] for logit in logits]
                         for i, label in enumerate(val_labels)]

                val_cls_loss = multitask_loss(logits, val_labels)
                val_apn_loss = pairwise_ranking_loss(preds)

                val_cls_losses.append(sum(val_cls_loss))
                val_apn_losses.append(val_apn_loss)

                for scale in range(3):
                    _, predicted = torch.max(logits[scale], 1)
                    corrects[scale] += (predicted == val_labels).sum()

                if j % 20 == 0:
                    print(" || validate %s [*] Iter[%d] %d || " %
                          (type, len(valloader), j))
            t1 = time.time()

            mean_cls_loss = torch.stack(val_cls_losses).mean()
            mean_apn_loss = torch.stack(val_apn_losses).mean()
            for correct in corrects:
                accuracy_list.append((correct.float() / cnt).item())

            logger.scalar_summary('val_cls_loss', mean_cls_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('val_rank_loss', mean_apn_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('val_acc1', accuracy_list[0], iteration + 1)
            logger.scalar_summary('val_acc2', accuracy_list[1], iteration + 1)
            logger.scalar_summary('val_acc3', accuracy_list[2], iteration + 1)

            stats = (iteration, accuracy_list[0], accuracy_list[1],
                     accuracy_list[2], (t1 - t0))
            if type == 'cls':
                message = (
                    " cls: [*] Iter %d || Val accuracy1: %.2f || Val accuracy2: %.2f || Val accuracy3: %.2f || Timer: %.4fsec"
                    % stats)
                print(message)
                with open('./result_racnn/%s.txt' % 'cls_val_accuracy',
                          'a') as acc_file:
                    acc_file.write(message + "\n")
            else:
                print(
                    " apn: [*] Iter %d || Val accuracy1: %.2f, Val accuracy2: %.2f, Val accuracy3: %.2f || Timer: %.4fsec"
                    % stats)
                with open('./result_racnn/%s.txt' % 'apn_val_accuracy',
                          'a') as acc_file:
                    # The record is tagged "apn" to match the file it goes
                    # into and the console line above.
                    acc_file.write(
                        " apn: [*] Iter %d || Val accuracy1: %.2f || Val accuracy2: %.2f || Val accuracy3: %.2f || Timer: %.4fsec\n"
                        % stats)
    finally:
        # Always hand the network back in training mode.
        net.train()
    return accuracy_list
Example #4
0
def train(trainset, trainloader, valloader):
    """Alternate classification and APN (ranking) training of the global ``net``.

    The APN state is restored from ``./racnn_models/checkpoint_pre_APN.pth.tar``
    when that file exists, otherwise it is produced by ``pretrainAPN``. The
    outer loop then repeats, until either loss stops changing between
    consecutive measurements or 50000 iterations have been counted:

    1. classification phase - step ``opt1`` on the summed ``multitask_loss``;
    2. one monitoring batch for the ranking loss, then validation;
    3. APN phase - step ``opt2`` on ``pairwise_ranking_loss``;
    4. one monitoring batch for the classification loss, validation, an
       accuracy record in ``./result_racnn/accuracy.txt`` and a checkpoint.

    Args:
        trainset: sized training dataset; ``len(trainset) // 4`` is used as
            the number of iterations per epoch.
        trainloader: iterable yielding ``(images, labels)`` training batches.
        valloader: validation loader passed on to ``validate``.
    """
    net.train()

    print(" [*] Loading dataset...")

    def fetch(batch_iter):
        """Return ``(batch, iterator)``, restarting the loader if exhausted."""
        try:
            return next(batch_iter), batch_iter
        except StopIteration:
            batch_iter = iter(trainloader)
            return next(batch_iter), batch_iter

    def true_class_scores(logits, labels):
        """Per sample, the score each set of logits gives the true class."""
        return [[logit[i][label] for logit in logits]
                for i, label in enumerate(labels)]

    resume_apn = True
    checkpoint_path = './racnn_models/checkpoint_pre_APN.pth.tar'

    # pretrain APN
    if resume_apn and os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        apn_iter = checkpoint['apn_iter']
        apn_epoch = checkpoint['apn_epoch']
        apn_steps = checkpoint['apn_steps']
        net.load_state_dict(checkpoint['state_dict'])
    else:
        # Also taken when resuming was requested but no checkpoint exists,
        # so the APN counters are always defined below.
        apn_iter, apn_epoch, apn_steps = pretrainAPN(trainset, trainloader)

    cls_iter, cls_epoch, cls_steps = 0, 0, 1
    switch_step = 0
    old_cls_loss, new_cls_loss = 2, 1
    old_apn_loss, new_apn_loss = 2, 1
    iteration = 0  # counts every batch consumed, in both phases
    # NOTE(review): the divisor 4 presumably mirrors the loader's batch
    # size - confirm against where trainloader is built.
    epoch_size = len(trainset) // 4
    cls_tol = 0
    apn_tol = 0
    batch_iterator = iter(trainloader)

    while ((old_cls_loss - new_cls_loss)**2 > 1e-7) and (
        (old_apn_loss - new_apn_loss)**2 > 1e-7) and (iteration < 50000):
        # until the two type of losses no longer change
        print(' [*] Swtich optimize parameters to Class')
        # Runs until the loss has been flat for 10 steps or cls_iter reaches
        # a multiple of 500 (so it is skipped while cls_iter is still 0).
        while ((cls_tol < 10) and (cls_iter % 500 != 0)):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)

            if cls_iter % epoch_size == 0:
                cls_epoch += 1
                if cls_epoch in decay_steps:
                    cls_steps += 1
                    adjust_learning_rate(opt1, 0.1, cls_steps, args.lr)

            old_cls_loss = new_cls_loss

            (images, labels), batch_iterator = fetch(batch_iterator)
            images.requires_grad_()
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            t0 = time.time()
            logits, _, _ = net(images)

            opt1.zero_grad()
            new_cls_losses = multitask_loss(logits, labels)
            new_cls_loss = sum(new_cls_losses)
            new_cls_loss.backward()
            opt1.step()
            t1 = time.time()

            if (old_cls_loss - new_cls_loss)**2 < 1e-6:
                cls_tol += 1
            else:
                cls_tol = 0

            logger.scalar_summary('cls_loss', new_cls_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss1', new_cls_losses[0].item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss12', new_cls_losses[1].item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss123', new_cls_losses[2].item(),
                                  iteration + 1)
            iteration += 1
            cls_iter += 1
            if (cls_iter % 10) == 0:
                message = (
                    " [*] cls_epoch[%d], Iter %d || cls_iter %d || cls_loss: %.4f || Timer: %.4fsec"
                    % (cls_epoch, iteration, cls_iter, new_cls_loss.item(),
                       (t1 - t0)))
                print(message)
                # Record the progress line in the log file.
                with open('./result_racnn/%s.txt' % 'cls', 'a') as acc_file:
                    acc_file.write(message + "\n")

        # Monitoring batch: measure the ranking loss after the classification
        # phase. Nothing is back-propagated, so no graph is needed.
        (images, labels), batch_iterator = fetch(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            logits, _, _ = net(images)
            new_apn_loss = pairwise_ranking_loss(
                true_class_scores(logits, labels))
        logger.scalar_summary('rank_loss', new_apn_loss.item(), iteration + 1)
        iteration += 1
        print('cls validate class......')
        accuracy_list = validate(valloader, iteration, type='cls')

        print(' [*] Swtich optimize parameters to APN')
        switch_step += 1
        while ((apn_tol < 10) and apn_iter % 500 != 0):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)

            if apn_iter % epoch_size == 0:
                apn_epoch += 1
                if apn_epoch in decay_steps:
                    apn_steps += 1
                    adjust_learning_rate(opt2, 0.1, apn_steps, args.lr)

            old_apn_loss = new_apn_loss

            (images, labels), batch_iterator = fetch(batch_iterator)
            images.requires_grad_()
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            t0 = time.time()
            logits, _, _ = net(images)

            opt2.zero_grad()
            new_apn_loss = pairwise_ranking_loss(
                true_class_scores(logits, labels))
            new_apn_loss.backward()
            opt2.step()
            t1 = time.time()

            if (old_apn_loss - new_apn_loss)**2 < 1e-6:
                apn_tol += 1
            else:
                apn_tol = 0

            logger.scalar_summary('rank_loss', new_apn_loss.item(),
                                  iteration + 1)
            iteration += 1
            apn_iter += 1
            if (apn_iter % 20) == 0:
                message = (
                    " [*] apn_epoch[%d], Iter %d || apn_iter %d || apn_loss: %.4f || Timer: %.4fsec"
                    % (apn_epoch, iteration, apn_iter, new_apn_loss.item(),
                       (t1 - t0)))
                print(message)
                # Record the progress line in the log file.
                with open('./result_racnn/%s.txt' % 'APN', 'a') as acc_file:
                    acc_file.write(message + "\n")

        switch_step += 1

        # Monitoring batch: measure the classification loss after the APN
        # phase. The batch must go through the network first; reusing the
        # logits of the previous batch would pair them with the wrong labels.
        (images, labels), batch_iterator = fetch(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            logits, _, _ = net(images)
            new_cls_losses = multitask_loss(logits, labels)
            new_cls_loss = sum(new_cls_losses)
        logger.scalar_summary('cls_loss', new_cls_loss.item(), iteration + 1)
        iteration += 1
        # Moving both counters off a multiple of 500 lets the two inner
        # loops run again on the next round.
        cls_iter += 1
        apn_iter += 1

        print('apn validate apn......')

        accuracy_list = validate(valloader, iteration, type='apn')
        print("accuracy_list: ", accuracy_list)

        # Index of the highest accuracy (first one on ties).
        number = max(range(len(accuracy_list)),
                     key=accuracy_list.__getitem__)
        # Record the accuracies in the log file. They are written with a
        # float format; an integer format would truncate them.
        with open('./result_racnn/%s.txt' % 'accuracy', 'a') as acc_file:
            acc_file.write(
                " accuracy1 = %.4f || accuracy2 =  %.4f || accuracy3 = %.4f || accuracy_best_number = %.4f \n"
                % (accuracy_list[0], accuracy_list[1], accuracy_list[2],
                   accuracy_list[number]))
        # save model
        state = net.state_dict()
        save_checkpoint(type='cls_and_rank', state=state)
Example #5
0
def train(trainset, trainloader, valloader):
    """Alternate classification and APN (ranking) training of the global ``net``.

    After ``pretrainAPN`` has run, the outer loop repeats, until either loss
    stops changing between consecutive measurements or 500000 iterations have
    been counted:

    1. classification phase - step ``opt1`` on the summed ``multitask_loss``;
    2. one monitoring batch for the ranking loss, then validation;
    3. APN phase - step ``opt2`` on ``pairwise_ranking_loss``;
    4. one monitoring batch for the classification loss, then validation.

    Args:
        trainset: sized training dataset; ``len(trainset) // 4`` is used as
            the number of iterations per epoch.
        trainloader: iterable yielding ``(images, labels)`` training batches.
        valloader: validation loader passed on to ``validate``.
    """
    net.train()

    print(" [*] Loading dataset...")

    def fetch(batch_iter):
        """Return ``(batch, iterator)``, restarting the loader if exhausted."""
        try:
            return next(batch_iter), batch_iter
        except StopIteration:
            batch_iter = iter(trainloader)
            return next(batch_iter), batch_iter

    def true_class_scores(logits, labels):
        """Per sample, the score each set of logits gives the true class."""
        return [[logit[i][label] for logit in logits]
                for i, label in enumerate(labels)]

    apn_iter, apn_epoch, apn_steps = pretrainAPN(trainset, trainloader)
    cls_iter, cls_epoch, cls_steps = 0, 0, 1
    switch_step = 0
    old_cls_loss, new_cls_loss = 2, 1
    old_apn_loss, new_apn_loss = 2, 1
    iteration = 0  # counts every batch consumed, in both phases
    # NOTE(review): the divisor 4 presumably mirrors the loader's batch
    # size - confirm against where trainloader is built.
    epoch_size = len(trainset) // 4
    cls_tol = 0
    apn_tol = 0
    batch_iterator = iter(trainloader)

    while ((old_cls_loss - new_cls_loss)**2 > 1e-7) and (
        (old_apn_loss - new_apn_loss)**2 > 1e-7) and (iteration < 500000):
        # until the two type of losses no longer change
        print(' [*] Swtich optimize parameters to Class')
        # Runs until the loss has been flat for 10 steps or cls_iter reaches
        # a multiple of 5000 (so it is skipped while cls_iter is still 0).
        while ((cls_tol < 10) and (cls_iter % 5000 != 0)):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)

            if cls_iter % epoch_size == 0:
                cls_epoch += 1
                if cls_epoch in decay_steps:
                    cls_steps += 1
                    adjust_learning_rate(opt1, 0.1, cls_steps, args.lr)

            old_cls_loss = new_cls_loss

            (images, labels), batch_iterator = fetch(batch_iterator)
            # Tensors carry autograd state themselves; no Variable wrapper
            # is needed to request gradients for the input.
            images.requires_grad_()
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            t0 = time.time()
            logits, _, _ = net(images)

            opt1.zero_grad()
            new_cls_losses = multitask_loss(logits, labels)
            new_cls_loss = sum(new_cls_losses)
            new_cls_loss.backward()
            opt1.step()
            t1 = time.time()

            if (old_cls_loss - new_cls_loss)**2 < 1e-6:
                cls_tol += 1
            else:
                cls_tol = 0

            logger.scalar_summary('cls_loss', new_cls_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss1', new_cls_losses[0].item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss12', new_cls_losses[1].item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss123', new_cls_losses[2].item(),
                                  iteration + 1)
            iteration += 1
            cls_iter += 1
            if (cls_iter % 20) == 0:
                print(
                    " [*] cls_epoch[%d], Iter %d || cls_iter %d || cls_loss: %.4f || Timer: %.4fsec"
                    % (cls_epoch, iteration, cls_iter, new_cls_loss.item(),
                       (t1 - t0)))

        # Monitoring batch: measure the ranking loss after the classification
        # phase. Nothing is back-propagated, so no graph is needed.
        (images, labels), batch_iterator = fetch(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            logits, _, _ = net(images)
            new_apn_loss = pairwise_ranking_loss(
                true_class_scores(logits, labels))
        logger.scalar_summary('rank_loss', new_apn_loss.item(), iteration + 1)
        iteration += 1
        validate(valloader, iteration)

        print(' [*] Swtich optimize parameters to APN')
        switch_step += 1

        while ((apn_tol < 10) and apn_iter % 5000 != 0):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)

            if apn_iter % epoch_size == 0:
                apn_epoch += 1
                if apn_epoch in decay_steps:
                    apn_steps += 1
                    adjust_learning_rate(opt2, 0.1, apn_steps, args.lr)

            old_apn_loss = new_apn_loss

            (images, labels), batch_iterator = fetch(batch_iterator)
            images.requires_grad_()
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            t0 = time.time()
            logits, _, _ = net(images)

            opt2.zero_grad()
            new_apn_loss = pairwise_ranking_loss(
                true_class_scores(logits, labels))
            new_apn_loss.backward()
            opt2.step()
            t1 = time.time()

            if (old_apn_loss - new_apn_loss)**2 < 1e-6:
                apn_tol += 1
            else:
                apn_tol = 0

            logger.scalar_summary('rank_loss', new_apn_loss.item(),
                                  iteration + 1)
            iteration += 1
            apn_iter += 1
            if (apn_iter % 20) == 0:
                print(
                    " [*] apn_epoch[%d], Iter %d || apn_iter %d || apn_loss: %.4f || Timer: %.4fsec"
                    % (apn_epoch, iteration, apn_iter, new_apn_loss.item(),
                       (t1 - t0)))

        switch_step += 1

        # Monitoring batch: measure the classification loss after the APN
        # phase. The batch must go through the network first; reusing the
        # logits of the previous batch would pair them with the wrong labels.
        (images, labels), batch_iterator = fetch(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            logits, _, _ = net(images)
            new_cls_losses = multitask_loss(logits, labels)
            new_cls_loss = sum(new_cls_losses)
        logger.scalar_summary('cls_loss', new_cls_loss.item(), iteration + 1)
        iteration += 1
        # Moving both counters off a multiple of 5000 lets the two inner
        # loops run again on the next round.
        cls_iter += 1
        apn_iter += 1
        validate(valloader, iteration)
Example #6
0
def train_test(trainset, trainloader, valloader):
    """Experimental training driver: an epoch loop followed by alternating
    classification / APN optimisation of the module-level ``net``.

    The function has two parts:

    1. An epoch loop that calls ``train_cls`` and ``validate`` and then runs
       checkpoint/stage bookkeeping.
    2. An alternating loop that steps ``opt1`` on the classification loss and
       ``opt2`` on the pairwise ranking loss, validating and checkpointing
       after each round.

    NOTE(review): the bookkeeping in part 1 references names that are not
    defined in this function (``train_loader``, ``model``, ``criterion``,
    ``optimizer``, ``val_loader``, ``file_name``, ``lr``, ``stage_epochs``).
    ``best_precision``, ``lowest_loss`` and ``stage`` are assigned here, so
    Python treats them as locals and reading them first raises
    ``UnboundLocalError``. As written, the first epoch cannot complete and
    part 2 is not reached - confirm the intended behaviour before relying
    on this function.

    Args:
        trainset: sized training dataset; ``len(trainset) // 4`` is used as
            the number of iterations per epoch in part 2.
        trainloader: iterable yielding ``(images, labels)`` training batches.
        valloader: validation loader passed on to ``validate``.
    """

    start_epoch = 0
    total_epochs = 10000

    net.train()

    print(" [*] Loading dataset...")
    batch_iterator = None

    resume_apn = True

    # pretrain APN
    # NOTE(review): the ``else`` below belongs to ``if resume_apn``; when the
    # checkpoint file does not exist, apn_iter/apn_epoch/apn_steps are left
    # unassigned.
    if resume_apn:
        checkpoint_path = './racnn_models/checkpoint_pre_APN.pth.tar'
        if os.path.isfile(checkpoint_path):
            print("=> loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path)
            apn_iter = checkpoint['apn_iter']
            apn_epoch = checkpoint['apn_epoch']
            apn_steps = checkpoint['apn_steps']
            net.load_state_dict(checkpoint['state_dict'])
    else:
        apn_iter, apn_epoch, apn_steps = pretrainAPN(trainset, trainloader)

    # batch_iterator = iter(trainloader)

    for epoch in range(start_epoch, total_epochs):
        # train class
        print(' [*] Swtich optimize parameters to Class')
        train_cls(trainloader, epoch)
        accuracy_list = validate(valloader, epoch, type='cls')

        # train for one epoch
        # NOTE(review): train_loader, model, criterion and optimizer are not
        # defined in this function - verify where they are expected to come
        # from.
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        # NOTE(review): this call passes different arguments from the
        # validate(valloader, epoch, type='cls') call above - confirm which
        # signature is intended.
        precision, avg_loss = validate(val_loader, model, criterion)

        # Record each epoch's precision and loss in the log file.
        with open('./result/%s.txt' % file_name, 'a') as acc_file:
            acc_file.write('Epoch: %2d, Precision: %.8f, Loss: %.8f\n' %
                           (epoch, precision, avg_loss))

        # writer.add_scalar('Loss', avg_loss, epoch+1)
        # writer.add_scalar('Accuracy', precision, epoch+1)

        # Track the best precision and the lowest loss; save the latest
        # model and the best model.
        # NOTE(review): best_precision and lowest_loss are read here before
        # this function assigns them (UnboundLocalError).
        is_best = precision > best_precision
        is_lowest_loss = avg_loss < lowest_loss
        best_precision = max(precision, best_precision)
        lowest_loss = min(avg_loss, lowest_loss)
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_precision': best_precision,
            'lowest_loss': lowest_loss,
            'stage': stage,
            'lr': lr,
        }
        save_checkpoint(state, is_best, is_lowest_loss)

        # Decide whether to move on to the next stage.
        if (epoch + 1) in np.cumsum(stage_epochs)[:-1]:
            stage += 1
            optimizer = adjust_learning_rate()
            model.load_state_dict(
                torch.load('./model/%s/model_best.pth.tar' %
                           file_name)['state_dict'])
            print('Step into next stage')
            with open('./result/%s.txt' % file_name, 'a') as acc_file:
                acc_file.write(
                    '---------------Step into next stage----------------\n')

    # ---- Part 2: alternating classification / APN optimisation ----
    conf_loss = 0
    loc_loss = 0

    print(" [*] Loading dataset...")
    batch_iterator = None

    resume_apn = True

    # pretrain APN (same checkpoint handling as at the top of the function)
    if resume_apn:
        checkpoint_path = './racnn_models/checkpoint_pre_APN.pth.tar'
        if os.path.isfile(checkpoint_path):
            print("=> loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path)
            apn_iter = checkpoint['apn_iter']
            apn_epoch = checkpoint['apn_epoch']
            apn_steps = checkpoint['apn_steps']
            net.load_state_dict(checkpoint['state_dict'])
    else:
        apn_iter, apn_epoch, apn_steps = pretrainAPN(trainset, trainloader)

    cls_iter, cls_epoch, cls_steps = 0, 0, 1
    switch_step = 0
    old_cls_loss, new_cls_loss = 2, 1
    old_apn_loss, new_apn_loss = 2, 1
    iteration = 0  # count the both of iteration
    # NOTE(review): the divisor 4 presumably mirrors the loader's batch
    # size - confirm.
    epoch_size = len(trainset) // 4
    cls_tol = 0
    apn_tol = 0
    batch_iterator = iter(trainloader)

    # while ((old_cls_loss - new_cls_loss)**2 > 1e-7) and ((old_apn_loss - new_apn_loss)**2 > 1e-7) and (iteration < 500000):
    while ((old_cls_loss - new_cls_loss)**2 > 1e-7) and (
        (old_apn_loss - new_apn_loss)**2 > 1e-7) and (iteration < 50000):
        # until the two type of losses no longer change
        print(' [*] Swtich optimize parameters to Class')
        # Classification phase: runs until the loss has been flat for 10
        # steps or cls_iter reaches a multiple of 500 (so it is skipped
        # while cls_iter is still 0).
        while ((cls_tol < 10) and (cls_iter % 500 != 0)):
            # print('cls_tol=%d, cls_iter=%d'%(cls_tol, cls_iter))
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)

            if cls_iter % epoch_size == 0:
                cls_epoch += 1
                if cls_epoch in decay_steps:
                    cls_steps += 1
                    adjust_learning_rate(opt1, 0.1, cls_steps, args.lr)

            old_cls_loss = new_cls_loss

            images, labels = next(batch_iterator)
            # images, labels = Variable(images, requires_grad = True), Variable(labels)
            images.requires_grad_()
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            t0 = time.time()
            logits, _, _ = net(images)

            opt1.zero_grad()
            new_cls_losses = multitask_loss(logits, labels)
            new_cls_loss = sum(new_cls_losses)
            # new_cls_loss = new_cls_losses[0]
            new_cls_loss.backward()
            opt1.step()
            t1 = time.time()

            # Count consecutive steps on which the loss barely moved.
            if (old_cls_loss - new_cls_loss)**2 < 1e-6:
                cls_tol += 1
            else:
                cls_tol = 0

            logger.scalar_summary('cls_loss', new_cls_loss.item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss1', new_cls_losses[0].item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss12', new_cls_losses[1].item(),
                                  iteration + 1)
            logger.scalar_summary('cls_loss123', new_cls_losses[2].item(),
                                  iteration + 1)
            iteration += 1
            cls_iter += 1
            if (cls_iter % 10) == 0:
                print(
                    " [*] cls_epoch[%d], Iter %d || cls_iter %d || cls_loss: %.4f || Timer: %.4fsec"
                    % (cls_epoch, iteration, cls_iter, new_cls_loss.item(),
                       (t1 - t0)))
                # Record the progress line in the log file.
                with open('./result_racnn/%s.txt' % 'cls', 'a') as acc_file:
                    acc_file.write(
                        " [*] cls_epoch[%d], Iter %d || cls_iter %d || cls_loss: %.4f || Timer: %.4fsec\n"
                        % (cls_epoch, iteration, cls_iter, new_cls_loss.item(),
                           (t1 - t0)))

        # Monitoring batch: measure the ranking loss after the classification
        # phase (no optimiser step).
        # NOTE(review): this next() is not guarded by the epoch-boundary
        # refresh used inside the inner loops, so it can raise StopIteration
        # if the iterator is exhausted at this point.
        images, labels = next(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        logits, _, _ = net(images)
        # For each sample, the score every set of logits gives to the
        # ground-truth class.
        preds = []
        for i in range(len(labels)):
            pred = [logit[i][labels[i]] for logit in logits]
            preds.append(pred)
        new_apn_loss = pairwise_ranking_loss(preds)
        logger.scalar_summary('rank_loss', new_apn_loss.item(), iteration + 1)
        iteration += 1
        # cls_iter += 1
        print('cls validate class......')
        accuracy_list = validate(valloader, iteration, type='cls')

        # continue
        print(' [*] Swtich optimize parameters to APN')
        switch_step += 1
        # APN phase: runs until the ranking loss has been flat for 10 steps
        # or apn_iter reaches a multiple of 500.
        while ((apn_tol < 10) and apn_iter % 500 != 0):
            if (not batch_iterator) or (iteration % epoch_size == 0):
                batch_iterator = iter(trainloader)

            if apn_iter % epoch_size == 0:
                apn_epoch += 1
                if apn_epoch in decay_steps:
                    apn_steps += 1
                    adjust_learning_rate(opt2, 0.1, apn_steps, args.lr)

            old_apn_loss = new_apn_loss

            images, labels = next(batch_iterator)
            # images, labels = Variable(images, requires_grad = True), Variable(labels)
            images.requires_grad_()
            if args.cuda:
                images, labels = images.cuda(), labels.cuda()

            t0 = time.time()
            logits, _, _ = net(images)

            opt2.zero_grad()
            preds = []
            for i in range(len(labels)):
                pred = [logit[i][labels[i]] for logit in logits]
                preds.append(pred)
            new_apn_loss = pairwise_ranking_loss(preds)
            new_apn_loss.backward()
            opt2.step()
            t1 = time.time()

            # Count consecutive steps on which the loss barely moved.
            if (old_apn_loss - new_apn_loss)**2 < 1e-6:
                apn_tol += 1
            else:
                apn_tol = 0

            logger.scalar_summary('rank_loss', new_apn_loss.item(),
                                  iteration + 1)
            iteration += 1
            apn_iter += 1
            if (apn_iter % 20) == 0:
                print(
                    " [*] apn_epoch[%d], Iter %d || apn_iter %d || apn_loss: %.4f || Timer: %.4fsec"
                    % (apn_epoch, iteration, apn_iter, new_apn_loss.item(),
                       (t1 - t0)))
                # Record the progress line in the log file.
                with open('./result_racnn/%s.txt' % 'APN', 'a') as acc_file:
                    acc_file.write(
                        " [*] apn_epoch[%d], Iter %d || apn_iter %d || apn_loss: %.4f || Timer: %.4fsec\n"
                        % (apn_epoch, iteration, apn_iter, new_apn_loss.item(),
                           (t1 - t0)))

        switch_step += 1

        # Monitoring batch: measure the classification loss after the APN
        # phase (no optimiser step).
        # NOTE(review): the batch fetched here is not passed through net;
        # ``logits`` still holds the output of the previous forward pass, so
        # it is scored against the labels of a different batch.
        images, labels = next(batch_iterator)
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        new_cls_losses = multitask_loss(logits, labels)
        new_cls_loss = sum(new_cls_losses)
        logger.scalar_summary('cls_loss', new_cls_loss.item(), iteration + 1)
        iteration += 1
        # Moving both counters off a multiple of 500 lets the two inner
        # loops run again on the next round.
        cls_iter += 1
        apn_iter += 1

        print('apn validate apn......')

        accuracy_list = validate(valloader, iteration, type='apn')
        print("accuracy_list: ", accuracy_list)
        # save model
        # torch.save(net.state_dict(), 'racnn_1_2_3.pth.tar')
        # NOTE(review): accuracy_max is reset to accuracy_list[0] on every
        # iteration, so ``number`` ends up as the last index whose accuracy
        # exceeds the first entry, not necessarily the index of the maximum.
        number = 0
        for i, acc in enumerate(accuracy_list):
            accuracy_max = accuracy_list[0]

            if acc > accuracy_max:
                number = i
        # Record the accuracies in the log file.
        # NOTE(review): if the accuracies are float fractions, the %d
        # conversions below truncate them to integers - confirm.
        with open('./result_racnn/%s.txt' % 'accuracy', 'a') as acc_file:
            acc_file.write(
                " accuracy1 = %d || accuracy2 =  %d || accuracy3 = %d || accuracy_best_number = %d \n"
                % (accuracy_list[0], accuracy_list[1], accuracy_list[2],
                   accuracy_list[number]))
        state = net.state_dict()
        save_checkpoint(type='cls_and_rank', state=state)