Example #1
def train_student(net, teacher, optimizer, criterion, scheduler):
    net.train()
    pbar = tqdm(train_loader)
    for images, labels in pbar:
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.long)

        # Both networks return (logits, list of intermediate activations).
        outputs_student, ints_student = net(images)
        outputs_teacher, ints_teacher = teacher(images)

        # Base loss: KD (soft targets + hard labels) or plain cross-entropy.
        if opts.loss_type in ('both', 'kd'):
            loss = utils.distillation(outputs_student, outputs_teacher, labels,
                                      opts.temperature, opts.alpha)
        else:
            loss = criterion(outputs_student, labels)

        # Attention-transfer term over intermediate feature maps.
        # Parenthesized explicitly: the original `a or b and c` parsed as
        # `a or (b and c)`, so loss_type == 'both' applied the AT loss even
        # when at_type == 'none'.
        if opts.loss_type in ('both', 'at') and opts.at_type != 'none':
            adjusted_beta = (opts.beta * 3) / len(ints_student)
            for s, t in zip(ints_student, ints_teacher):
                if t.shape[2] != s.shape[2]:
                    # Match spatial sizes before comparing attention maps.
                    t = F.interpolate(t, size=s.shape[2:], mode='bilinear',
                                      align_corners=False)
                loss += adjusted_beta * utils.at_loss(s, t)

        preds = outputs_student.detach().max(dim=1)[1].cpu().numpy()
        targets = labels.cpu().numpy()
        metrics.update(targets, preds)
        score = metrics.get_results()
        pbar.set_postfix({"IoU": score["Mean IoU"]})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
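
None of the examples show the two helpers they all rely on. Below is a minimal sketch of typical implementations of utils.distillation (Hinton-style knowledge distillation) and utils.at_loss (the attention-transfer loss of Zagoruyko & Komodakis), with signatures inferred from the call sites above; treat it as an illustration, not the exact code behind these snippets:

import torch.nn.functional as F

def distillation(y_s, y_t, labels, T, alpha):
    """Hinton-style KD: alpha-weighted mix of softened KL and hard-label CE."""
    p_s = F.log_softmax(y_s / T, dim=1)
    p_t = F.softmax(y_t / T, dim=1)
    # The T^2 factor keeps soft-target gradients comparable across temperatures.
    kd = F.kl_div(p_s, p_t, reduction='batchmean') * (T * T)
    return alpha * kd + (1.0 - alpha) * F.cross_entropy(y_s, labels)

def at(x):
    """Spatial attention map: channel-wise mean of squared activations, L2-normalized."""
    return F.normalize(x.pow(2).mean(dim=1).flatten(start_dim=1))

def at_loss(x, y):
    """MSE between student and teacher normalized attention maps."""
    return (at(x) - at(y)).pow(2).mean()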
Example #2
def h(sample):
    inputs = utils.cast(sample[0], opt.dtype).detach()
    targets = utils.cast(sample[1], 'long')
    if opt.teacher_id != '':
        # loss_groups holds the per-group attention-transfer losses.
        y_s, y_t, loss_groups = utils.data_parallel(
            f, inputs, params, sample[2], range(opt.ngpu))
        loss_groups = [v.sum() for v in loss_groups]
        # Accumulate the attention losses in the meters_at meters.
        for m, v in zip(meters_at, loss_groups):
            m.add(v.item())
        return utils.distillation(
            y_s, y_t, targets, opt.temperature,
            opt.alpha) + opt.beta * sum(loss_groups), y_s
    else:
        y = utils.data_parallel(f, inputs, params, sample[2],
                                range(opt.ngpu))[0]
        return F.cross_entropy(y, targets), y
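
Examples #2 through #9 all define the same closure h(sample) -> (loss, output) over f, params, and opt. In codebases of this shape (e.g. szagoruyko/attention-transfer) such a closure is driven by a torchnet Engine; a minimal usage sketch, assuming torchnet is installed and train_loader yields (inputs, targets) batches:

from torchnet.engine import Engine

engine = Engine()

def on_sample(state):
    # Append the training-mode flag that the h closures read as sample[2].
    state['sample'].append(state['train'])

engine.hooks['on_sample'] = on_sample
# The Engine calls h(sample), backprops the returned loss, and steps the optimizer.
engine.train(h, train_loader, opt.epochs, optimizer)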
Example #3
def h(sample):
    # sample[0] is the input batch, sample[1] the labels.
    inputs = utils.cast(sample[0], opt.dtype).detach()
    targets = utils.cast(sample[1], 'long')
    # Training the student: use the combined distillation loss.
    if opt.teacher_id != '':
        y_s, y_t, loss_groups = utils.data_parallel(
            f, inputs, params, sample[2], range(opt.ngpu))
        # Reduce each group's attention loss to a scalar.
        loss_groups = [v.sum() for v in loss_groups]
        # Log the per-group attention losses.
        for m, v in zip(meters_at, loss_groups):
            m.add(v.item())
        # First term: distillation (y_s: student logits, y_t: teacher logits,
        # targets: ground-truth labels). Second term: the AT loss, weighted
        # by beta. The student logits y_s are also returned.
        # With pure AT, alpha = 0 and the first term reduces to the
        # student/ground-truth cross-entropy; with pure KD, beta = 0 and
        # only the distillation loss remains.
        return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
                + opt.beta * sum(loss_groups), y_s
    # Training the teacher network: plain cross-entropy.
    else:
        # y is the network output.
        y = utils.data_parallel(f, inputs, params, sample[2],
                                range(opt.ngpu))[0]
        return F.cross_entropy(y, targets), y
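
To make the alpha/beta gating in those comments concrete, the two pure regimes correspond roughly to settings like these (illustrative values in the spirit of the attention-transfer paper, not taken from these snippets):

opt.alpha, opt.beta = 0.0, 1e3   # pure AT: cross-entropy + beta * sum of attention losses
opt.alpha, opt.beta = 0.9, 0.0   # pure KD: distillation only, softened by opt.temperature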
Example #4
def h(sample):
    inputs, targets, mode = sample
    inputs = inputs.cuda().detach()
    targets = targets.cuda().long().detach()
    y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, mode, range(opt.ngpu))
    loss_groups = [v.sum() for v in loss_groups]
    for m, v in zip(meters_at, loss_groups):
        m.add(v.item())
    return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
            + opt.beta * sum(loss_groups), y_s
Example #5
def h(sample):
    # Variable wrappers are deprecated since PyTorch 0.4 and dropped here.
    inputs = sample[0].cuda()
    targets = sample[1].cuda().long()
    y_s, y_t, loss_groups = data_parallel(f, inputs, params, stats,
                                          sample[2], np.arange(opt.ngpu))
    loss_groups = [v.sum() for v in loss_groups]
    for m, v in zip(meters_at, loss_groups):
        # v.item() replaces the pre-0.4 v.data[0] indexing.
        m.add(v.item())
    return distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
            + opt.beta * sum(loss_groups), y_s
Example #6
def h(sample):
    inputs = cast(sample[0], opt.dtype)
    targets = cast(sample[1], 'long')
    if opt.teacher_id != '':
        if opt.gamma:
            # Rocket-launching style joint training: student (ys), jointly
            # trained booster head (y_t_auto), and a fixed teacher (y_t).
            ys, y_t_auto, y_t = data_parallel(f, inputs, params,
                                              stats, sample[2],
                                              np.arange(opt.ngpu))[:3]
            loss_student = F.cross_entropy(ys, targets)
            loss_teacher = F.cross_entropy(y_t_auto, targets)
            loss_course = opt.beta * \
                ((y_t_auto - ys) * (y_t_auto - ys)).sum() / opt.batchSize
            # Block gradients through the booster side of the hint term.
            y_tech_temp = y_t_auto.detach()
            log_kd = rocket_distillation(ys, y_t, targets, opt.temperature,
                                         opt.alpha)
            hint = opt.beta * ((y_tech_temp - ys) *
                               (y_tech_temp - ys)).sum() / opt.batchSize
            return log_kd + loss_teacher + loss_student + hint, \
                (ys, y_t_auto, loss_student, loss_teacher, loss_course, log_kd)
        else:
            y_s, y_t, loss_groups = data_parallel(f, inputs, params, stats,
                                                  sample[2],
                                                  np.arange(opt.ngpu))
            loss_groups = [v.sum() for v in loss_groups]
            for m, v in zip(meters_at, loss_groups):
                m.add(v.item())
            return distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
                + opt.beta * sum(loss_groups), y_s
    else:
        if opt.gamma:
            # Joint student/booster training without a pretrained teacher.
            ys, y = data_parallel(f, inputs, params, stats, sample[2],
                                  np.arange(opt.ngpu))[:2]
            loss_student = F.cross_entropy(ys, targets)
            loss_teacher = F.cross_entropy(y, targets)
            loss_course = opt.beta * \
                ((y - ys) * (y - ys)).sum() / opt.batchSize
            # Optionally stop gradients through the booster in the hint term.
            y_course = y.detach() if opt.grad_block else y
            hint = opt.beta * ((y_course - ys) *
                               (y_course - ys)).sum() / opt.batchSize
            return loss_teacher + loss_student + hint, \
                (ys, y, loss_student, loss_teacher, loss_course)
        else:
            y = data_parallel(f, inputs, params, stats, sample[2],
                              np.arange(opt.ngpu))[0]
            return F.cross_entropy(y, targets), y
Example #7
def h(sample):
    inputs = cast(sample[0], opt.dtype)
    targets = cast(sample[1], 'long')
    if opt.teacher_id != '':
        y_s, y_t, loss_groups = data_parallel(f, inputs, params, stats, sample[2], np.arange(opt.ngpu))
        loss_groups = [v.sum() for v in loss_groups]
        for m, v in zip(meters_at, loss_groups):
            m.add(v.item())
        return distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
                + opt.beta * sum(loss_groups), y_s
    else:
        y = data_parallel(f, inputs, params, stats, sample[2], np.arange(opt.ngpu))[0]
        return F.cross_entropy(y, targets), y
Example #8
def h(sample):
    inputs, targets, mode = sample
    inputs = inputs.cuda().detach()
    targets = targets.cuda().long().detach()
    if opt.teacher_id != '':
        if opt.kt_method == "at":
            y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, mode, range(opt.ngpu))
            loss_groups = [v.sum() for v in loss_groups]
            for m, v in zip(meters_at, loss_groups):
                m.add(v.item())
            return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) + opt.beta * sum(loss_groups), y_s
        elif opt.kt_method == "st":
            # Logit matching: RMSE between student and teacher outputs.
            # mode is used here (equivalent to the original sample[2]) for
            # consistency with the unpacking above.
            y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, mode, range(opt.ngpu))
            return torch.sqrt(torch.mean((y_s - y_t) ** 2)), y_s
    else:
        y = utils.data_parallel(f, inputs, params, mode, range(opt.ngpu))[0]
        return F.cross_entropy(y, targets), y
Example #9
def h(sample):
    inputs = utils.cast(sample[0], opt.dtype).detach()
    targets = utils.cast(sample[1], 'long')
    if opt.teacher_id != '':
        if opt.kt_method == "at":
            y_s, y_t, loss_groups = utils.data_parallel(
                f, inputs, params, sample[2], range(opt.ngpu))
            loss_groups = [v.sum() for v in loss_groups]
            for m, v in zip(meters_at, loss_groups):
                m.add(v.item())
            return utils.distillation(
                y_s, y_t, targets, opt.temperature,
                opt.alpha) + opt.beta * sum(loss_groups), y_s
        elif opt.kt_method == "st":
            y_s, y_t, loss_list = utils.data_parallel(
                f, inputs, params, sample[2], range(opt.ngpu))
            loss_list = [v.sum() for v in loss_list]
            for m, v in zip(meters_st, loss_list):
                m.add(v.item())
            # Append the final logit-matching (RMSE) term and return the
            # whole list of losses for the caller to combine.
            fc_loss = torch.sqrt(torch.mean((y_s - y_t)**2))
            loss_list.append(fc_loss)
            return loss_list, y_s
    else:
        y = utils.data_parallel(f, inputs, params, sample[2],
                                range(opt.ngpu))[0]
        return F.cross_entropy(y, targets), y
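
Every example above routes its forward pass through a data_parallel helper that takes a functional forward f(inputs, params, mode) and a dict of parameter tensors, but none of them show it. A minimal single/multi-GPU sketch with the signature inferred from the calls above; Broadcast comes from a private PyTorch module, so treat this as an illustration rather than the exact helper behind these snippets:

from functools import partial

from torch.nn.parallel import scatter, parallel_apply, gather
from torch.nn.parallel._functions import Broadcast

def data_parallel(f, inputs, params, mode, device_ids):
    """Run the functional forward f(inputs, params, mode) across device_ids."""
    device_ids = list(device_ids)
    if len(device_ids) == 1:
        return f(inputs, params, mode)

    # Broadcast every parameter tensor to all devices, then rebuild one
    # params dict per replica (outputs are grouped device-major).
    flat = Broadcast.apply(device_ids, *params.values())
    n = len(params)
    params_replicas = [
        {k: flat[i + j * n] for i, k in enumerate(params)}
        for j in range(len(device_ids))
    ]

    # One partially applied forward per device, each fed its input shard.
    replicas = [partial(f, params=p, mode=mode) for p in params_replicas]
    shards = scatter([inputs], device_ids)
    outputs = parallel_apply(replicas, shards)
    return gather(outputs, device_ids[0])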