# Example No. 1
0
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and report open-set metrics.

    ``model`` must return a pair ``(output, output1)``. ``output`` holds the
    class logits, and ``output1`` is handed to ``accuracy_2`` and
    ``utils.PredictionEvaluator_2`` together with the softmax scores. It is
    presumably the "unknown" score; confirm against those helpers. Samples
    whose target equals ``args.num_classes`` are excluded from ``criterion``.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network to evaluate. It is put in eval mode for the pass and
            switched to train mode afterwards, even if the pass raises.
        criterion: loss applied to the known-class logits.

    Returns:
        ``(loss_avg, top1_avg, top5_avg, mean_aug_class_acc, aug_cls_acc,
        best_acc * 100)``. ``mean_aug_class_acc``/``aug_cls_acc`` come from
        the evaluator at threshold 0.5. ``best_acc`` is the best mean class
        accuracy over the thresholds 0.0, 0.1, ..., 0.9.
    """
    import torch  # also used module-wide; local import keeps this block self-contained

    batch_time = AverageMeter(0)
    losses = AverageMeter(0)
    top1 = AverageMeter(0)
    top5 = AverageMeter(0)

    logger = logging.getLogger('global_logger')
    eval_target = []
    eval_output = []
    eval_uk = []

    # Switch to evaluate mode; the finally below always switches back.
    model.train(mode=False)
    try:
        end = time.time()
        # `Variable(..., volatile=True)` has had no effect since PyTorch 0.4.
        # no_grad() is what actually stops autograd from recording this pass.
        with torch.no_grad():
            for i, (inputs, target) in enumerate(val_loader):
                inputs = inputs.cuda()
                target = target.cuda()

                # compute output
                output, output1 = model(inputs)
                softmax_output = F.softmax(output, dim=1)

                # Loss for known classes only. A batch with no known samples
                # is skipped: a mean-reduced loss over zero elements (e.g.
                # cross-entropy) is NaN and would poison the running average.
                known_ind = target != args.num_classes
                if known_ind.any():
                    loss = criterion(output[known_ind], target[known_ind])
                    losses.update(loss.item())

                eval_target.append(target.cpu().numpy())
                eval_output.append(softmax_output.cpu().numpy())
                eval_uk.append(output1.cpu().numpy())

                prec1, prec5 = accuracy_2(softmax_output, output1, target,
                                          0.5, topk=(1, 5))
                top1.update(prec1.item())
                top5.update(prec5.item())

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    logger.info('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                           i, len(val_loader), batch_time=batch_time, loss=losses,
                           top1=top1, top5=top5))
    finally:
        model.train(mode=True)

    eval_target = np.concatenate(eval_target, axis=0)
    eval_output = np.concatenate(eval_output, axis=0)
    eval_uk = np.concatenate(eval_uk, axis=0)
    evaluator = utils.PredictionEvaluator_2(eval_target, args.num_classes)
    mean_aug_class_acc, aug_cls_acc = evaluator.evaluate(eval_output, eval_uk, 0.5)

    # Sweep the evaluator threshold over 0.0 .. 0.9 and keep the best mean
    # class accuracy.
    best_acc = 0
    for i in range(10):
        eps = i * 0.1
        t_clss_acc, t_aug_cls_acc = evaluator.evaluate(eval_output, eval_uk, eps)
        best_acc = max(best_acc, t_clss_acc)
        print("epslion {:.2f}, mean_aug_class_acc {}, aug_cls_acc {}".format(eps, t_clss_acc, t_aug_cls_acc))

    logger.info(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg, mean_aug_class_acc, aug_cls_acc, best_acc*100
# Example No. 2
0
def train(train_source_loader, train_target_loader, val_loader,
          student_model, criterion, student_optimizer,
          lr_scheduler, start_iter, tb_logger,
          teacher_model=None,
          teacher_optimizer=None):
    """Run the student/teacher open-set adaptation training loop.

    Source and target loaders are consumed in lockstep via ``zip``, so the
    loop stops when the shorter one is exhausted. Each step optimizes the
    student on the sum of the following terms:

    * ``criterion`` on source samples whose label is not ``args.num_classes``
      (the known classes);
    * ``args.lambda_uk`` times BCE-with-logits of the second model head
      against a binary target that is 1 for unknown source samples;
    * ``args.lambda_entropy`` times the mean of ``p * log p`` over the
      unknown source samples;
    * ``args.lambda_aug`` times the mean of the consistency loss from
      ``utils.new_compute_aug_loss_enp``, which takes the student on the
      first target input and the teacher on the second;
    * ``args.cls_balance * args.lambda_aug`` times that helper's class-balance
      term.

    Every ``args.print_freq`` steps, running averages are logged. Every
    ``args.val_freq`` steps, the accumulated source predictions are scored,
    ``teacher_model`` is validated with ``validate``, and the teacher's
    state is checkpointed. Each validation updates the module-level
    ``best_prec1``.

    Args:
        train_source_loader: yields ``(input, label)`` source batches.
        train_target_loader: yields ``(input, input1, label)`` target
            batches. The target label is not used.
        val_loader: passed to ``validate`` at every validation step.
        student_model: network returning ``(class_logits, second_output)``.
        criterion: classification loss, also used by ``validate``.
        student_optimizer: optimizer stepped on the summed loss.
        lr_scheduler: stepped with the absolute step index each iteration.
        start_iter: absolute step index of the first iteration.
        tb_logger: object with ``add_scalar``, or None to skip scalar logging.
        teacher_model: required despite the ``None`` default; same output
            convention as ``student_model``.
        teacher_optimizer: required despite the ``None`` default.

    Raises:
        ValueError: if ``teacher_model`` or ``teacher_optimizer`` is None.
    """
    global best_prec1

    if teacher_model is None or teacher_optimizer is None:
        raise ValueError('train() requires teacher_model and teacher_optimizer')

    # Running meters. AverageMeter(10) presumably averages over a 10-update
    # window (validate uses AverageMeter(0)); confirm against its definition.
    batch_time = AverageMeter(10)
    data_time = AverageMeter(10)
    losses = AverageMeter(10)
    top1 = AverageMeter(10)
    top5 = AverageMeter(10)
    losses_bal = AverageMeter(10)
    losses_aug = AverageMeter(10)
    losses_entropy = AverageMeter(10)
    losses_cls_uk = AverageMeter(10)

    # switch to train mode
    student_model.train()

    logger = logging.getLogger('global_logger')
    criterion_uk = nn.BCEWithLogitsLoss()
    end = time.time()
    eval_output = []
    eval_target = []
    eval_uk = []
    for i, (batch_source, batch_target) in enumerate(
            zip(train_source_loader, train_target_loader)):
        input_source, label_source = batch_source
        # input_target feeds the student and input_target1 feeds the teacher
        # (presumably two augmentations of the same images).
        input_target, input_target1, _label_target = batch_target
        curr_step = start_iter + i

        lr_scheduler.step(curr_step)
        current_lr = lr_scheduler.get_lr()[0]
        # measure data loading time
        data_time.update(time.time() - end)

        # `async` is a reserved keyword since Python 3.7, so the old
        # `.cuda(async=True)` spelling is a SyntaxError. `non_blocking` is the
        # same option under its current name.
        label_source = label_source.cuda(non_blocking=True)
        input_source = input_source.cuda()
        input_target = input_target.cuda(non_blocking=True)
        input_target1 = input_target1.cuda(non_blocking=True)

        # ---- source batch -------------------------------------------------
        source_output, source_output2 = student_model(input_source)
        softmax_source_output = F.softmax(source_output, dim=1)

        known_ind = label_source != args.num_classes
        unknown_ind = label_source == args.num_classes

        # classification loss over known-class samples only
        if args.double_softmax:
            loss_cls = criterion(softmax_source_output[known_ind], label_source[known_ind])
        else:
            loss_cls = criterion(source_output[known_ind], label_source[known_ind])

        # second head: binary target, 1 for unknown samples and 0 otherwise
        uk_label = unknown_ind.float().unsqueeze(1)
        loss_uk = criterion_uk(source_output2, uk_label)

        # Mean of p * log p over unknown samples. log_softmax avoids the
        # 0 * log(0) = NaN that log(softmax(x)) yields when a probability
        # underflows. A batch without unknown samples contributes 0 instead
        # of the NaN mean of an empty tensor.
        if unknown_ind.any():
            log_softmax_source_output = F.log_softmax(source_output, dim=1)
            loss_entropy = torch.mean(
                softmax_source_output[unknown_ind]
                * log_softmax_source_output[unknown_ind])
        else:
            loss_entropy = source_output.new_zeros(())

        # ---- target batch -------------------------------------------------
        stu_out, _stu_out2 = student_model(input_target)
        tea_out, tea_out2 = teacher_model(input_target1)

        loss_aug, _conf_mask, loss_cls_bal = utils.new_compute_aug_loss_enp(
            stu_out, tea_out, args.aug_thresh, tea_out2, args.cls_balance, args)
        loss_aug = torch.mean(loss_aug)

        # Build the total out of place. The previous `loss = loss_cls;
        # loss += ...` added in place into the loss_cls tensor itself, so the
        # logged classification loss silently became the total loss.
        loss = (loss_cls
                + args.lambda_uk * loss_uk
                + args.lambda_entropy * loss_entropy
                + args.lambda_aug * loss_aug
                + args.cls_balance * args.lambda_aug * loss_cls_bal)

        student_optimizer.zero_grad()
        loss.backward()
        student_optimizer.step()
        # The teacher optimizer is stepped without its own zero_grad();
        # presumably an EMA-style optimizer tracking the student (TODO confirm).
        teacher_optimizer.step()

        # accumulate detached source predictions for the periodic evaluation
        detached_softmax = softmax_source_output.detach()
        detached_uk = source_output2.detach()
        eval_output.append(detached_softmax.cpu().numpy())
        eval_target.append(label_source.cpu().numpy())
        eval_uk.append(detached_uk.cpu().numpy())
        prec1, prec5 = accuracy_2(detached_softmax, detached_uk, label_source, 0.5, topk=(1, 5))

        losses.update(loss_cls.item())
        top1.update(prec1.item())
        top5.update(prec5.item())
        losses_cls_uk.update(loss_uk.item())
        losses_entropy.update(loss_entropy.item())
        losses_aug.update(loss_aug.item())
        # This was previously fed loss_cls, so 'loss_bal' duplicated the
        # classification loss. float() accepts a 0-dim tensor or a number.
        losses_bal.update(float(loss_cls_bal))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if curr_step % args.print_freq == 0:
            if tb_logger is not None:
                tb_logger.add_scalar('loss_train', losses.avg, curr_step)
                tb_logger.add_scalar('acc1_train', top1.avg, curr_step)
                tb_logger.add_scalar('acc5_train', top5.avg, curr_step)
                tb_logger.add_scalar('lr', current_lr, curr_step)
            print(args.exp_name)
            logger.info('Iter: [{0}/{1}]\t'
                        'Time: {batch_time.avg:.3f}\t'
                        'Data: {data_time.avg:.3f}\t'
                        'loss: {loss.avg:.4f}\t'
                        'loss_uk: {loss_uk.avg:.4f}\t'
                        'loss_aug: {loss_aug.avg:.4f}\t'
                        'loss_bal: {loss_bal.avg:.4f}\t'
                        'loss_entropy: {loss_entropy.avg:.4f}\t'
                        'Prec@1: {top1.avg:.3f}\t'
                        'Prec@5: {top5.avg:.3f}\t'
                        'lr: {lr:.6f}'.format(
                   curr_step, len(train_source_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   loss_uk=losses_cls_uk,
                   top1=top1, top5=top5,
                   loss_aug=losses_aug,
                   loss_bal=losses_bal,
                   loss_entropy=losses_entropy,
                   lr=current_lr))

        if (curr_step + 1) % args.val_freq == 0:
            # score the source predictions accumulated since the last validation
            evaluator = utils.PredictionEvaluator_2(
                np.concatenate(eval_target, axis=0), args.num_classes)
            mean_aug_class_acc, aug_cls_acc = evaluator.evaluate(
                np.concatenate(eval_output, axis=0),
                np.concatenate(eval_uk, axis=0), 0.5)
            eval_target = []
            eval_output = []
            eval_uk = []
            logger.info("mean_cls_acc: {} cls {}".format(mean_aug_class_acc, aug_cls_acc))

            if tb_logger is not None:
                for cls in range(args.num_classes):
                    tb_logger.add_scalar('acc_'+args.class_name[cls], aug_cls_acc[cls], curr_step)

            val_loss, prec1, prec5, mean_aug_class_acc, aug_cls_acc, best_acc = validate(
                val_loader, teacher_model, criterion)
            if tb_logger is not None:
                tb_logger.add_scalar('loss_val', val_loss, curr_step)
                tb_logger.add_scalar('acc1_val', prec1, curr_step)
                tb_logger.add_scalar('acc5_val', prec5, curr_step)
                tb_logger.add_scalar('best_acc_val', best_acc, curr_step)
            logger.info("evaluate step mean:{} class:{} best_acc {}".format(mean_aug_class_acc, aug_cls_acc, best_acc))

            # remember best prec@1 and save checkpoint
            is_best = best_acc > best_prec1
            best_prec1 = max(best_acc, best_prec1)
            logger.info("best val prec1 {}".format(best_prec1))
            save_checkpoint({
                'step': curr_step,
                'arch': args.arch,
                'state_dict': teacher_model.state_dict(),
                'best_prec1': best_prec1,
                'student_optimizer' : student_optimizer.state_dict(),
            }, is_best, args.save_path+'/ckpt' )