Example #1
def test(epoch):

    snet.eval()
    if args.model == 'VID':
        VID_NET1.eval()
        VID_NET2.eval()
    elif args.model == 'OFD':
        OFD_NET1.eval()
        OFD_NET2.eval()
    elif args.model == 'AFD':
        AFD_NET1.eval()
        AFD_NET2.eval()
    else:
        pass
    PrivateTest_loss = 0
    t_prediction = 0
    conf_mat = np.zeros((NUM_CLASSES, NUM_CLASSES))

    for batch_idx, (img, target) in enumerate(PrivateTestloader):
        t = time.time()
        test_bs, ncrops, c, h, w = np.shape(img)
        img = img.view(-1, c, h, w)
        if args.cuda:
            img = img.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        img, target = Variable(img), Variable(target)

        with torch.no_grad():
            rb1_s, rb2_s, rb3_s, mimic_s, out_s = snet(img)

        outputs_avg = out_s.view(test_bs, ncrops, -1).mean(1)

        loss = Cls_crit(outputs_avg, target)
        t_prediction += (time.time() - t)
        PrivateTest_loss += loss.item()

        conf_mat += losses.confusion_matrix(outputs_avg, target, NUM_CLASSES)
        acc = sum([conf_mat[i, i]
                   for i in range(conf_mat.shape[0])]) / conf_mat.sum()
        precision = [
            conf_mat[i, i] / (conf_mat[i].sum() + 1e-10)
            for i in range(conf_mat.shape[0])
        ]
        mAP = sum(precision) / len(precision)

        recall = [
            conf_mat[i, i] / (conf_mat[:, i].sum() + 1e-10)
            for i in range(conf_mat.shape[0])
        ]
        precision = np.array(precision)
        recall = np.array(recall)
        f1 = 2 * precision * recall / (precision + recall + 1e-10)
        F1_score = f1.mean()

        #utils.progress_bar(batch_idx, len(PrivateTestloader), 'Loss: %.3f | Acc: %.3f%% | mAP: %.3f%% | F1: %.3f%%'
        #% (PrivateTest_loss / (batch_idx + 1), 100.*acc, 100.* mAP, 100.* F1_score))

    return (PrivateTest_loss / (batch_idx + 1), 100. * acc, 100. * mAP,
            100. * F1_score)
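
Note: every example here accumulates metrics through losses.confusion_matrix(outputs, targets, NUM_CLASSES), whose implementation is not shown. Below is a minimal sketch of what such a helper could look like, assuming it returns a count matrix with predicted classes on the rows and true classes on the columns (the convention under which the row-normalised diagonal matches the "precision" and the column-normalised diagonal the "recall" computed above).

import numpy as np
import torch

def confusion_matrix(outputs, targets, num_classes):
    # outputs: (batch, num_classes) logits; targets: (batch,) integer labels
    preds = outputs.argmax(dim=1).view(-1).cpu().numpy()
    trues = targets.view(-1).cpu().numpy()
    mat = np.zeros((num_classes, num_classes))
    for p, t in zip(preds, trues):
        mat[p, t] += 1  # rows: predicted class, columns: true class (assumed)
    return mat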
Example #2
def PrivateTest(epoch):
    global PrivateTest_acc
    global best_PrivateTest_acc
    global best_PrivateTest_acc_epoch
    global total_prediction_fps
    global total_prediction_n
    net.eval()
    PrivateTest_loss = 0
    t_prediction = 0
    conf_mat = np.zeros((NUM_CLASSES, NUM_CLASSES))

    for batch_idx, (inputs, targets) in enumerate(PrivateTestloader):
        t = time.time()
        test_bs, ncrops, c, h, w = np.shape(inputs)
        inputs = inputs.view(-1, c, h, w)
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs), Variable(targets)
        with torch.no_grad():
            _, _, _, _, outputs = net(inputs)
        outputs_avg = outputs.view(test_bs, ncrops, -1).mean(1)  # avg over crops
        t_prediction += (time.time() - t)

        loss = criterion(outputs_avg, targets)
        PrivateTest_loss += loss.item()

        conf_mat += losses.confusion_matrix(outputs_avg, targets, NUM_CLASSES)
        acc = sum([conf_mat[i, i]
                   for i in range(conf_mat.shape[0])]) / conf_mat.sum()
        precision = [
            conf_mat[i, i] / (conf_mat[i].sum() + 1e-10)
            for i in range(conf_mat.shape[0])
        ]
        mAP = sum(precision) / len(precision)

        recall = [
            conf_mat[i, i] / (conf_mat[:, i].sum() + 1e-10)
            for i in range(conf_mat.shape[0])
        ]
        precision = np.array(precision)
        recall = np.array(recall)
        f1 = 2 * precision * recall / (precision + recall + 1e-10)
        F1_score = f1.mean()

        #utils.progress_bar(batch_idx, len(PrivateTestloader), 'Loss: %.3f | Acc: %.3f%% | mAP: %.3f%% | F1: %.3f%%'
        #% (PrivateTest_loss / (batch_idx + 1), 100.*acc, 100.* mAP, 100.* F1_score))

    total_prediction_fps = total_prediction_fps + (
        1 / (t_prediction / len(PrivateTestloader)))
    total_prediction_n = total_prediction_n + 1
    print('Prediction time: %.2f' % t_prediction + ', Average : %.5f/image' %
          (t_prediction / len(PrivateTestloader)) + ', Speed : %.2fFPS' %
          (1 / (t_prediction / len(PrivateTestloader))))

    # Save checkpoint.
    return (PrivateTest_loss / (batch_idx + 1), 100. * acc, 100. * mAP,
            100. * F1_score)
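
The (test_bs, ncrops, c, h, w) batches in both test functions come from a ten-crop style evaluation transform; the loaders themselves are not shown. A minimal sketch of that setup, assuming torchvision's TenCrop with a placeholder crop size and a model that returns plain logits rather than the 5-tuple used above:

import torch
import torchvision.transforms as transforms

eval_transform = transforms.Compose([
    transforms.TenCrop(44),  # placeholder crop size
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),  # -> (ncrops, c, h, w)
])

def tencrop_forward(net, inputs):
    # inputs: (bs, ncrops, c, h, w); fold the crops into the batch dimension,
    # run the network once, then average the per-crop logits for each image
    bs, ncrops, c, h, w = inputs.shape
    outputs = net(inputs.view(-1, c, h, w))
    return outputs.view(bs, ncrops, -1).mean(1)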
Example #3
def train(epoch):
    print('\nEpoch: %d' % epoch)
    snet.train()
    decoder.train()
    train_loss = 0
    train_cls_loss = 0

    conf_mat = np.zeros((NUM_CLASSES, NUM_CLASSES))
    conf_mat_a = np.zeros((NUM_CLASSES, NUM_CLASSES))
    conf_mat_b = np.zeros((NUM_CLASSES, NUM_CLASSES))

    if epoch > learning_rate_decay_start and learning_rate_decay_start >= 0:
        frac = (epoch - learning_rate_decay_start) // learning_rate_decay_every
        decay_factor = learning_rate_decay_rate**frac
        current_lr = args.lr * decay_factor
        utils.set_lr(optimizer, current_lr)  # set the decayed rate
    else:
        current_lr = args.lr
    print('learning_rate: %s' % str(current_lr))

    for batch_idx, (img_teacher, img_student,
                    target) in enumerate(trainloader):

        if args.cuda:
            img_teacher = img_teacher.cuda(non_blocking=True)
            img_student = img_student.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        optimizer.zero_grad()

        if args.augmentation:
            img_teacher, teacher_target_a, teacher_target_b, teacher_lam = mixup_data(
                img_teacher, target, 0.6)
            img_teacher, teacher_target_a, teacher_target_b = map(
                Variable, (img_teacher, teacher_target_a, teacher_target_b))

            img_student, student_target_a, student_target_b, student_lam = mixup_data(
                img_student, target, 0.6)
            img_student, student_target_a, student_target_b = map(
                Variable, (img_student, student_target_a, student_target_b))
        else:
            img_teacher, img_student, target = Variable(img_teacher), Variable(
                img_student), Variable(target)

        rb1_s, rb2_s, rb3_s, mimic_s, out_s = snet(img_student)
        rb1_t, rb2_t, rb3_t, mimic_t, out_t = tnet(img_teacher)

        if args.augmentation:
            cls_loss = mixup_criterion(Cls_crit, out_s, student_target_a,
                                       student_target_b, student_lam)
        else:
            cls_loss = Cls_crit(out_s, target)

        kd_loss = KD_T_crit(out_t, out_s)

        if args.distillation == 'KD':
            loss = 0.2 * cls_loss + 0.8 * kd_loss
        elif args.distillation == 'DE':
            new_rb1_s = decoder(rb1_s)
            decoder_loss = losses.styleLoss(img_teacher, new_rb1_s.cuda(),
                                            MSE_crit)
            loss = 0.2 * cls_loss + 0.8 * kd_loss + 0.1 * decoder_loss
        elif args.distillation == 'AS':
            rb2_loss = losses.Absdiff_Similarity(rb2_t, rb2_s).cuda()
            loss = 0.2 * cls_loss + 0.8 * kd_loss + 0.9 * rb2_loss
        elif args.distillation == 'DEAS':
            new_rb1_s = decoder(rb1_s)
            decoder_loss = losses.styleLoss(img_teacher, new_rb1_s.cuda(),
                                            MSE_crit)
            rb2_loss = losses.Absdiff_Similarity(rb2_t, rb2_s).cuda()
            loss = 0.2 * cls_loss + 0.8 * kd_loss + 0.1 * decoder_loss + 0.9 * rb2_loss
        elif args.distillation == 'SSDEAS':
            new_rb1_s = decoder(rb1_s)
            decoder_loss = losses.styleLoss(img_teacher, new_rb1_s.cuda(),
                                            MSE_crit)
            rb2_loss = losses.Absdiff_Similarity(rb2_t, rb2_s).cuda()
            loss = 0 * cls_loss + 0 * kd_loss + 0.1 * decoder_loss + 0.9 * rb2_loss
        else:
            raise ValueError('Invalid distillation name...')

        loss.backward()
        utils.clip_gradient(optimizer, 0.1)
        optimizer.step()
        train_loss += loss.item()
        train_cls_loss += cls_loss.item()

        if args.augmentation:
            conf_mat_a += losses.confusion_matrix(out_s, student_target_a,
                                                  NUM_CLASSES)
            acc_a = sum([conf_mat_a[i, i] for i in range(conf_mat_a.shape[0])
                         ]) / conf_mat_a.sum()
            precision_a = np.array([
                conf_mat_a[i, i] / (conf_mat_a[i].sum() + 1e-10)
                for i in range(conf_mat_a.shape[0])
            ])
            recall_a = np.array([
                conf_mat_a[i, i] / (conf_mat_a[:, i].sum() + 1e-10)
                for i in range(conf_mat_a.shape[0])
            ])
            mAP_a = sum(precision_a) / len(precision_a)
            F1_score_a = (2 * precision_a * recall_a /
                          (precision_a + recall_a + 1e-10)).mean()

            conf_mat_b += losses.confusion_matrix(out_s, student_target_b,
                                                  NUM_CLASSES)
            acc_b = sum([conf_mat_b[i, i] for i in range(conf_mat_b.shape[0])
                         ]) / conf_mat_b.sum()
            precision_b = np.array([
                conf_mat_b[i, i] / (conf_mat_b[i].sum() + 1e-10)
                for i in range(conf_mat_b.shape[0])
            ])
            recall_b = np.array([
                conf_mat_b[i, i] / (conf_mat_b[:, i].sum() + 1e-10)
                for i in range(conf_mat_b.shape[0])
            ])
            mAP_b = sum(precision_b) / len(precision_b)
            F1_score_b = (2 * precision_b * recall_b /
                          (precision_b + recall_b + 1e-10)).mean()

            acc = student_lam * acc_a + (1 - student_lam) * acc_b
            mAP = student_lam * mAP_a + (1 - student_lam) * mAP_b
            F1_score = student_lam * F1_score_a + (1 -
                                                   student_lam) * F1_score_b

        else:
            conf_mat += losses.confusion_matrix(out_s, target, NUM_CLASSES)
            acc = sum([conf_mat[i, i]
                       for i in range(conf_mat.shape[0])]) / conf_mat.sum()
            precision = [
                conf_mat[i, i] / (conf_mat[i].sum() + 1e-10)
                for i in range(conf_mat.shape[0])
            ]
            mAP = sum(precision) / len(precision)

            recall = [
                conf_mat[i, i] / (conf_mat[:, i].sum() + 1e-10)
                for i in range(conf_mat.shape[0])
            ]
            precision = np.array(precision)
            recall = np.array(recall)
            f1 = 2 * precision * recall / (precision + recall + 1e-10)
            F1_score = f1.mean()

        #utils.progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% | mAP: %.3f%% | F1: %.3f%%'
        #% (train_loss/(batch_idx+1), 100.*acc, 100.* mAP, 100.* F1_score))

    return (train_cls_loss / (batch_idx + 1), 100. * acc, 100. * mAP,
            100. * F1_score)
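
The augmentation branch above relies on mixup_data and mixup_criterion, which are not defined in these examples. A minimal sketch following the standard mixup formulation (alpha is the Beta-distribution parameter; 0.6 is the value passed in the calls above):

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # blend each sample with a randomly permuted partner from the same batch;
    # the mixing coefficient lambda is drawn from Beta(alpha, alpha)
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # the loss is the same convex combination of the losses against both labels
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)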
Example #4
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0

    conf_mat = np.zeros((NUM_CLASSES, NUM_CLASSES))
    conf_mat_a = np.zeros((NUM_CLASSES, NUM_CLASSES))
    conf_mat_b = np.zeros((NUM_CLASSES, NUM_CLASSES))

    if epoch > learning_rate_decay_start and learning_rate_decay_start >= 0:
        frac = (epoch - learning_rate_decay_start) // learning_rate_decay_every
        decay_factor = learning_rate_decay_rate**frac
        current_lr = args.lr * decay_factor
        utils.set_lr(optimizer, current_lr)  # set the decayed rate
    else:
        current_lr = args.lr
    print('learning_rate: %s' % str(current_lr))

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()

        if args.augmentation:
            inputs, targets_a, targets_b, lam = mixup_data(
                inputs, targets, 0.6)
            inputs, targets_a, targets_b = map(Variable,
                                               (inputs, targets_a, targets_b))
        else:
            inputs, targets = Variable(inputs), Variable(targets)

        _, _, _, _, outputs = net(inputs)

        if args.augmentation:
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b,
                                   lam)
        else:
            loss = criterion(outputs, targets)

        loss.backward()
        utils.clip_gradient(optimizer, 0.1)
        optimizer.step()
        train_loss += loss.item()

        if args.augmentation:
            conf_mat_a += losses.confusion_matrix(outputs, targets_a,
                                                  NUM_CLASSES)
            acc_a = sum([conf_mat_a[i, i] for i in range(conf_mat_a.shape[0])
                         ]) / conf_mat_a.sum()
            precision_a = np.array([
                conf_mat_a[i, i] / (conf_mat_a[i].sum() + 1e-10)
                for i in range(conf_mat_a.shape[0])
            ])
            recall_a = np.array([
                conf_mat_a[i, i] / (conf_mat_a[:, i].sum() + 1e-10)
                for i in range(conf_mat_a.shape[0])
            ])
            mAP_a = sum(precision_a) / len(precision_a)
            F1_score_a = (2 * precision_a * recall_a /
                          (precision_a + recall_a + 1e-10)).mean()

            conf_mat_b += losses.confusion_matrix(outputs, targets_b,
                                                  NUM_CLASSES)
            acc_b = sum([conf_mat_b[i, i] for i in range(conf_mat_b.shape[0])
                         ]) / conf_mat_b.sum()
            precision_b = np.array([
                conf_mat_b[i, i] / (conf_mat_b[i].sum() + 1e-10)
                for i in range(conf_mat_b.shape[0])
            ])
            recall_b = np.array([
                conf_mat_b[i, i] / (conf_mat_b[:, i].sum() + 1e-10)
                for i in range(conf_mat_b.shape[0])
            ])
            mAP_b = sum(precision_b) / len(precision_b)
            F1_score_b = (2 * precision_b * recall_b /
                          (precision_b + recall_b + 1e-10)).mean()

            acc = lam * acc_a + (1 - lam) * acc_b
            mAP = lam * mAP_a + (1 - lam) * mAP_b
            F1_score = lam * F1_score_a + (1 - lam) * F1_score_b

        else:
            conf_mat += losses.confusion_matrix(outputs, targets, NUM_CLASSES)
            acc = sum([conf_mat[i, i]
                       for i in range(conf_mat.shape[0])]) / conf_mat.sum()
            precision = [
                conf_mat[i, i] / (conf_mat[i].sum() + 1e-10)
                for i in range(conf_mat.shape[0])
            ]
            mAP = sum(precision) / len(precision)

            recall = [
                conf_mat[i, i] / (conf_mat[:, i].sum() + 1e-10)
                for i in range(conf_mat.shape[0])
            ]
            precision = np.array(precision)
            recall = np.array(recall)
            f1 = 2 * precision * recall / (precision + recall + 1e-10)
            F1_score = f1.mean()

        #utils.progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% | mAP: %.3f%% | F1: %.3f%%'
        #% (train_loss/(batch_idx+1), 100.*acc, 100.* mAP, 100.* F1_score))

    return (train_loss / (batch_idx + 1), 100. * acc, 100. * mAP,
            100. * F1_score)
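
The step-decay schedule and gradient clipping in these training loops go through utils.set_lr and utils.clip_gradient, which are not shown. A minimal sketch of what those helpers are assumed to do (element-wise gradient clamping is an assumption):

def set_lr(optimizer, lr):
    # write the decayed learning rate into every parameter group
    for group in optimizer.param_groups:
        group['lr'] = lr

def clip_gradient(optimizer, grad_clip):
    # clamp each parameter gradient element-wise to [-grad_clip, grad_clip]
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)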
Example #5
def train(epoch):
    print('\nEpoch: %d' % epoch)
    snet.train()
    if args.model == 'VID':
        VID_NET1.train()
        VID_NET2.train()
    elif args.model == 'OFD':
        OFD_NET1.train()
        OFD_NET2.train()
    elif args.model == 'AFD':
        AFD_NET1.train()
        AFD_NET2.train()
    else:
        pass
    train_loss = 0
    train_cls_loss = 0

    conf_mat = np.zeros((NUM_CLASSES, NUM_CLASSES))
    conf_mat_a = np.zeros((NUM_CLASSES, NUM_CLASSES))
    conf_mat_b = np.zeros((NUM_CLASSES, NUM_CLASSES))

    if epoch > learning_rate_decay_start and learning_rate_decay_start >= 0:
        frac = (epoch - learning_rate_decay_start) // learning_rate_decay_every
        decay_factor = learning_rate_decay_rate**frac
        current_lr = args.lr * decay_factor
        utils.set_lr(optimizer, current_lr)  # set the decayed rate
    else:
        current_lr = args.lr
    print('learning_rate: %s' % str(current_lr))

    for batch_idx, (img_teacher, img_student,
                    target) in enumerate(trainloader):

        if args.cuda:
            img_teacher = img_teacher.cuda(non_blocking=True)
            img_student = img_student.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        optimizer.zero_grad()

        if args.augmentation:
            img_teacher, teacher_target_a, teacher_target_b, teacher_lam = mixup_data(
                img_teacher, target, 0.6)
            img_teacher, teacher_target_a, teacher_target_b = map(
                Variable, (img_teacher, teacher_target_a, teacher_target_b))

            img_student, student_target_a, student_target_b, student_lam = mixup_data(
                img_student, target, 0.6)
            img_student, student_target_a, student_target_b = map(
                Variable, (img_student, student_target_a, student_target_b))
        else:
            img_teacher, img_student, target = Variable(img_teacher), Variable(
                img_student), Variable(target)

        rb1_s, rb2_s, rb3_s, mimic_s, out_s = snet(img_student)
        rb1_t, rb2_t, rb3_t, mimic_t, out_t = tnet(img_teacher)

        if args.augmentation:
            cls_loss = mixup_criterion(Cls_crit, out_s, student_target_a,
                                       student_target_b, student_lam)
        else:
            cls_loss = Cls_crit(out_s, target)

        kd_loss = KD_T_crit(out_t, out_s)

        if args.model == 'Fitnet':
            #FITNETS: Hints for Thin Deep Nets
            if args.stage == 'Block1':
                Fitnet1_loss = other.Fitnet(rb1_t, rb1_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * Fitnet1_loss
            elif args.stage == 'Block2':
                Fitnet2_loss = other.Fitnet(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * Fitnet2_loss
            else:
                Fitnet1_loss = other.Fitnet(rb1_t, rb1_s).cuda()
                Fitnet2_loss = other.Fitnet(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * Fitnet1_loss + args.delta * Fitnet2_loss

        elif args.model == 'AT':  # An activation-based attention transfer with the sum of absolute values raised to the power of 2.
            #Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
            if args.stage == 'Block1':
                AT1_loss = other.AT(rb1_t, rb1_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * AT1_loss
            elif args.stage == 'Block2':
                AT2_loss = other.AT(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * AT2_loss
            else:
                AT1_loss = other.AT(rb1_t, rb1_s).cuda()
                AT2_loss = other.AT(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * AT1_loss + args.delta * AT2_loss

        elif args.model == 'NST':  # NST (poly)
            #Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
            if args.stage == 'Block1':
                NST1_loss = other.NST(rb1_t, rb1_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * NST1_loss
            elif args.stage == 'Block2':
                NST2_loss = other.NST(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * NST2_loss
            else:
                NST1_loss = other.NST(rb1_t, rb1_s).cuda()
                NST2_loss = other.NST(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * NST1_loss + args.delta * NST2_loss

        elif args.model == 'PKT':  # PKT
            #Learning Deep Representations with Probabilistic Knowledge Transfer
            if args.stage == 'Block1':
                PKT1_loss = other.PKT(rb1_t, rb1_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * PKT1_loss
            elif args.stage == 'Block2':
                PKT2_loss = other.PKT(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * PKT2_loss
            else:
                PKT1_loss = other.PKT(rb1_t, rb1_s).cuda()
                PKT2_loss = other.PKT(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * PKT1_loss + args.delta * PKT2_loss

        elif args.model == 'AB':  # AB
            #Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
            if args.stage == 'Block1':
                AB1_loss = other.AB(rb1_t, rb1_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * AB1_loss
            elif args.stage == 'Block2':
                AB2_loss = other.AB(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * AB2_loss
            else:
                AB1_loss = other.AB(rb1_t, rb1_s).cuda()
                AB2_loss = other.AB(rb2_t, rb2_s).cuda()
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * AB1_loss + args.delta * AB2_loss

        elif args.model == 'CCKD':  #
            #Correlation Congruence for Knowledge Distillation
            if args.stage == 'Block1':
                CCKD1_loss = other.CCKD().cuda()(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * CCKD1_loss
            elif args.stage == 'Block2':
                CCKD2_loss = other.CCKD().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * CCKD2_loss
            else:
                CCKD1_loss = other.CCKD().cuda()(rb1_t, rb1_s)
                CCKD2_loss = other.CCKD().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * CCKD1_loss + args.delta * CCKD2_loss

        elif args.model == 'RKD':  # RKD-DA
            #Relational Knowledge Disitllation
            if args.stage == 'Block1':
                RKD1_loss = other.RKD().cuda()(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * RKD1_loss
            elif args.stage == 'Block2':
                RKD2_loss = other.RKD().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * RKD2_loss
            else:
                RKD1_loss = other.RKD().cuda()(rb1_t, rb1_s)
                RKD2_loss = other.RKD().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * RKD1_loss + args.delta * RKD2_loss

        elif args.model == 'SP':  # SP
            #Similarity-Preserving Knowledge Distillation
            if args.stage == 'Block1':
                SP1_loss = other.SP().cuda()(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * SP1_loss
            elif args.stage == 'Block2':
                SP2_loss = other.SP().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * SP2_loss
            else:
                SP1_loss = other.SP().cuda()(rb1_t, rb1_s)
                SP2_loss = other.SP().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * SP1_loss + args.delta * SP2_loss

        elif args.model == 'VID':  # VID-I
            #Variational Information Distillation for Knowledge Transfer
            if args.stage == 'Block1':
                VID1_loss = VID_NET1(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * VID1_loss
            elif args.stage == 'Block2':
                VID2_loss = VID_NET2(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * VID2_loss
            else:
                VID1_loss = VID_NET1(rb1_t, rb1_s)
                VID2_loss = VID_NET2(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * VID1_loss + args.delta * VID2_loss

        elif args.model == 'OFD':  # OFD
            #A Comprehensive Overhaul of Feature Distillation
            if args.stage == 'Block1':
                OFD1_loss = OFD_NET1(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * OFD1_loss
            elif args.stage == 'Block2':
                OFD2_loss = OFD_NET2(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * OFD2_loss
            else:
                OFD1_loss = OFD_NET1(rb1_t, rb1_s)
                OFD2_loss = OFD_NET2(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * OFD1_loss + args.delta * OFD2_loss

        elif args.model == 'AFD':  #
            #Pay Attention to Features, Transfer Learn Faster CNNs
            if args.stage == 'Block1':
                AFD1_loss = AFD_NET1(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * AFD1_loss
            elif args.stage == 'Block2':
                AFD2_loss = AFD_NET2(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * AFD2_loss
            else:
                AFD1_loss = AFD_NET1(rb1_t, rb1_s)
                AFD2_loss = AFD_NET2(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * AFD1_loss + args.delta * AFD2_loss

        elif args.model == 'FT':  #
            #Paraphrasing Complex Network: Network Compression via Factor Transfer
            if args.stage == 'Block1':
                FT1_loss = other.FT().cuda()(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * FT1_loss
            elif args.stage == 'Block2':
                FT2_loss = other.FT().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.delta * FT2_loss
            else:
                FT1_loss = other.FT().cuda()(rb1_t, rb1_s)
                FT2_loss = other.FT().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * FT1_loss + args.delta * FT2_loss

        elif args.model == 'CD':  # CD+GKD+CE
            #Channel Distillation: Channel-Wise Attention for Knowledge Distillation
            if args.stage == 'Block1':
                kd_loss_v2 = other.KDLossv2(args.T).cuda()(out_t, out_s,
                                                           target)
                CD1_loss = other.CD().cuda()(rb1_t, rb1_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss_v2 + args.gamma * CD1_loss
            elif args.stage == 'Block2':
                kd_loss_v2 = other.KDLossv2(args.T).cuda()(out_t, out_s,
                                                           target)
                CD2_loss = other.CD().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss_v2 + args.delta * CD2_loss
            else:
                kd_loss_v2 = other.KDLossv2(args.T).cuda()(out_t, out_s,
                                                           target)
                CD1_loss = other.CD().cuda()(rb1_t, rb1_s)
                CD2_loss = other.CD().cuda()(rb2_t, rb2_s)
                loss = args.alpha * cls_loss + args.beta * kd_loss_v2 + args.gamma * CD1_loss + args.delta * CD2_loss

        elif args.model == 'FAKD':  # DS+TS+SA
            #FAKD: Feature-Affinity Based Knowledge Distillation for Efficient Image Super-Resolution
            if args.stage == 'Block1':
                FAKD_DT_loss = other.FAKD_DT().cuda()(out_t, out_s, target,
                                                      NUM_CLASSES)
                FAKD_SA1_loss = other.FAKD_SA().cuda()(rb1_t, rb1_s)
                loss = args.alpha * FAKD_DT_loss + args.gamma * FAKD_SA1_loss  # No T
            elif args.stage == 'Block2':
                FAKD_DT_loss = other.FAKD_DT().cuda()(out_t, out_s, target,
                                                      NUM_CLASSES)
                FAKD_SA2_loss = other.FAKD_SA().cuda()(rb2_t, rb2_s)
                loss = args.alpha * FAKD_DT_loss + args.gamma * FAKD_SA2_loss
            else:
                FAKD_DT_loss = other.FAKD_DT().cuda()(out_t, out_s, target,
                                                      NUM_CLASSES)
                FAKD_SA1_loss = other.FAKD_SA().cuda()(rb1_t, rb1_s)
                FAKD_SA2_loss = other.FAKD_SA().cuda()(rb2_t, rb2_s)
                loss = args.alpha * FAKD_DT_loss + args.gamma * FAKD_SA1_loss + args.delta * FAKD_SA2_loss

        elif args.model == 'VKD':  #
            #Robust Re-Identification by Multiple Views Knowledge Distillation
            if args.stage == 'Block1':
                VKD_Similarity1_loss = other.VKD_SimilarityDistillationLoss(
                ).cuda()(rb1_t, rb1_s)
                VKD_OnlineTriplet1_loss = other.VKD_OnlineTripletLoss().cuda()(
                    rb1_s, target)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * VKD_Similarity1_loss \
                                             + args.delta * VKD_OnlineTriplet1_loss
            elif args.stage == 'Block2':
                VKD_Similarity2_loss = other.VKD_SimilarityDistillationLoss(
                ).cuda()(rb2_t, rb2_s)
                VKD_OnlineTriplet2_loss = other.VKD_OnlineTripletLoss().cuda()(
                    rb2_s, target)
                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * VKD_Similarity2_loss \
                                             + args.delta * VKD_OnlineTriplet2_loss
            else:
                VKD_Similarity1_loss = other.VKD_SimilarityDistillationLoss(
                ).cuda()(rb1_t, rb1_s)
                VKD_OnlineTriplet1_loss = other.VKD_OnlineTripletLoss().cuda()(
                    rb1_s, target)

                VKD_Similarity2_loss = other.VKD_SimilarityDistillationLoss(
                ).cuda()(rb2_t, rb2_s)
                VKD_OnlineTriplet2_loss = other.VKD_OnlineTripletLoss().cuda()(
                    rb2_s, target)

                loss = args.alpha * cls_loss + args.beta * kd_loss + args.gamma * VKD_Similarity1_loss \
                           + args.delta * VKD_OnlineTriplet1_loss  + args.gamma * VKD_Similarity2_loss \
                                             + args.delta * VKD_OnlineTriplet2_loss

        elif args.model == 'RAD':  # RAD:  Resolution-Adapted Distillation
            # Efficient Low-Resolution Face Recognition via Bridge Distillation
            distance = mimic_t - mimic_s
            RAD_loss = torch.pow(distance, 2).sum(dim=(0, 1), keepdim=False)
            loss = RAD_loss + cls_loss
        else:
            raise ValueError('Invalid model name...')

        loss.backward()
        utils.clip_gradient(optimizer, 0.1)
        optimizer.step()
        train_loss += loss.item()
        train_cls_loss += cls_loss.item()

        if args.augmentation:
            conf_mat_a += losses.confusion_matrix(out_s, student_target_a,
                                                  NUM_CLASSES)
            acc_a = sum([conf_mat_a[i, i] for i in range(conf_mat_a.shape[0])
                         ]) / conf_mat_a.sum()
            precision_a = np.array([
                conf_mat_a[i, i] / (conf_mat_a[i].sum() + 1e-10)
                for i in range(conf_mat_a.shape[0])
            ])
            recall_a = np.array([
                conf_mat_a[i, i] / (conf_mat_a[:, i].sum() + 1e-10)
                for i in range(conf_mat_a.shape[0])
            ])
            mAP_a = sum(precision_a) / len(precision_a)
            F1_score_a = (2 * precision_a * recall_a /
                          (precision_a + recall_a + 1e-10)).mean()

            conf_mat_b += losses.confusion_matrix(out_s, student_target_b,
                                                  NUM_CLASSES)
            acc_b = sum([conf_mat_b[i, i] for i in range(conf_mat_b.shape[0])
                         ]) / conf_mat_b.sum()
            precision_b = np.array([
                conf_mat_b[i, i] / (conf_mat_b[i].sum() + 1e-10)
                for i in range(conf_mat_b.shape[0])
            ])
            recall_b = np.array([
                conf_mat_b[i, i] / (conf_mat_b[:, i].sum() + 1e-10)
                for i in range(conf_mat_b.shape[0])
            ])
            mAP_b = sum(precision_b) / len(precision_b)
            F1_score_b = (2 * precision_b * recall_b /
                          (precision_b + recall_b + 1e-10)).mean()

            acc = student_lam * acc_a + (1 - student_lam) * acc_b
            mAP = student_lam * mAP_a + (1 - student_lam) * mAP_b
            F1_score = student_lam * F1_score_a + (1 -
                                                   student_lam) * F1_score_b

        else:
            conf_mat += losses.confusion_matrix(out_s, target, NUM_CLASSES)
            acc = sum([conf_mat[i, i]
                       for i in range(conf_mat.shape[0])]) / conf_mat.sum()
            precision = [
                conf_mat[i, i] / (conf_mat[i].sum() + 1e-10)
                for i in range(conf_mat.shape[0])
            ]
            mAP = sum(precision) / len(precision)

            recall = [
                conf_mat[i, i] / (conf_mat[:, i].sum() + 1e-10)
                for i in range(conf_mat.shape[0])
            ]
            precision = np.array(precision)
            recall = np.array(recall)
            f1 = 2 * precision * recall / (precision + recall + 1e-10)
            F1_score = f1.mean()

        #utils.progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% | mAP: %.3f%% | F1: %.3f%%'
        #% (train_loss/(batch_idx+1), 100.*acc, 100.* mAP, 100.* F1_score))

    return (train_cls_loss / (batch_idx + 1), 100. * acc, 100. * mAP,
            100. * F1_score)
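
All of the distillation variants above combine the student's classification loss with kd_loss = KD_T_crit(out_t, out_s), whose definition is not included. A minimal sketch, assuming a Hinton-style soft-target loss with temperature T (the class name and default temperature are placeholders):

import torch.nn as nn
import torch.nn.functional as F

class SoftTargetKD(nn.Module):
    def __init__(self, T=4.0):
        super().__init__()
        self.T = T  # softening temperature

    def forward(self, out_t, out_s):
        # KL divergence between temperature-softened teacher and student
        # distributions, scaled by T^2 to keep gradients comparable across T
        p_t = F.softmax(out_t.detach() / self.T, dim=1)
        log_p_s = F.log_softmax(out_s / self.T, dim=1)
        return F.kl_div(log_p_s, p_t, reduction='batchmean') * (self.T ** 2)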