Code Example #1
File: mixpul_cifar10.py  Project: Stomach-ache/MixPUL
def validate(eval_loader, model, global_step, epoch, ema=False, testing=False):
    class_criterion = nn.CrossEntropyLoss(reduction='sum',
                                          ignore_index=NO_LABEL).to(device)
    import utils
    import time
    meters = utils.AverageMeterSet()

    # switch to evaluate mode
    model.eval()

    output = []

    end = time.time()
    for i, (input, target) in enumerate(eval_loader):
        meters.update('data_time', time.time() - end)

        # Variable is a deprecated no-op wrapper; a single no_grad block is enough here
        with torch.no_grad():
            input_var = input.to(device)
            target_var = target.to(device)

        minibatch_size = len(target_var)
        #labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum().type(torch.cuda.FloatTensor)
        #assert labeled_minibatch_size > 0
        #meters.update('labeled_minibatch_size', labeled_minibatch_size)

        # compute output
        output1 = model(input_var.float())
        #print ("output1",output1)
        class_loss = class_criterion(output1,
                                     target_var.long()) / minibatch_size

        output = output + list(output1.data)
        # measure accuracy and record loss
        # note: topk=(1, 2), so the "prec5"/"top5" meters actually track top-2 accuracy
        prec1, prec5 = accuracy(output1.data, target_var.data, topk=(1, 2))
        meters.update('class_loss', class_loss.item(), minibatch_size)
        meters.update('top1', prec1[0], minibatch_size)
        meters.update('error1', 100.0 - prec1[0], minibatch_size)
        meters.update('top5', prec5[0], minibatch_size)
        meters.update('error5', 100.0 - prec5[0], minibatch_size)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

    print(' * Prec@1 {top1.avg:.3f}\tPrec@5 {top5.avg:.3f}\n'.format(
        top1=meters['top1'], top5=meters['top5']))

    val_class_loss_list.append(meters['class_loss'].avg)
    val_error_list.append(float(meters['error1'].avg))
    val_pre_list.append(meters['top1'].avg)
    return output, meters['top1'].avg
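
The validate() example relies on helpers that are defined elsewhere in mixpul_cifar10.py and are not shown on this page, most notably accuracy and utils.AverageMeterSet (assumed to be a dictionary-like collection of running averages keyed by name). The sketch below shows a conventional top-k accuracy helper that matches the way accuracy(output1.data, target_var.data, topk=(1, 2)) is called above; it is an assumption about the interface, not the project's actual implementation.

import torch

# Hypothetical sketch of the top-k accuracy helper assumed by validate().
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies (in percent) as a list of 1-element tensors."""
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest-scoring classes per sample
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res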
Code Example #2
File: mixpul_cifar10.py  Project: Stomach-ache/MixPUL
def train(trainloader, unlabelledloader, model, ema_model, optimizer, epoch):

    global global_step
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type

    # switch to train mode
    model.train()
    ema_model.train()

    import utils
    import time
    meters = utils.AverageMeterSet()
    #class_criterion=nn.MSELoss().cuda()
    class_criterion = nn.CrossEntropyLoss().to(device)
    i = -1

    for (input, target), (u, _) in zip(cycle(trainloader), unlabelledloader):

        #print ("target",target[0:6])
        i = i + 1
        if input.shape[0] != u.shape[0]:
            bt_size = np.minimum(input.shape[0], u.shape[0])
            input = input[0:bt_size]
            target = target[0:bt_size]
            u = u[0:bt_size]

        if args.mixup_sup_alpha:
            if use_cuda:
                input, target, u = input.to(device), target.to(device), u.to(
                    device)
            input_var, target_var, u_var = Variable(input), Variable(
                target), Variable(u)

            #IF False:
            if args.mixup_hidden:
                ### model
                output_mixed_l, target_a_var, target_b_var, lam = model(
                    input_var,
                    target_var,
                    mixup_hidden=True,
                    mixup_alpha=args.mixup_sup_alpha,
                    layers_mix=args.num_mix_layer)
                lam = lam[0]
            else:
                mixed_input, target_a, target_b, lam = mixup_data_sup(
                    input, target, args.mixup_sup_alpha)
                # if use_cuda:
                #    mixed_input, target_a, target_b  = mixed_input.cuda(), target_a.cuda(), target_b.cuda()
                mixed_input_var, target_a_var, target_b_var = Variable(
                    mixed_input), Variable(target_a), Variable(target_b)
                ### model
                output_mixed_l = model(mixed_input_var.float())

            loss_func = mixup_criterion(target_a_var, target_b_var, lam)
            class_loss = loss_func(class_criterion, output_mixed_l)
            output = output_mixed_l

        else:
            # Variable is a deprecated no-op wrapper; moving the tensors to the
            # device is sufficient (the loader tensors carry no gradient history)
            input_var = input.to(device)
            u_var = u.to(device)
            target_var = target.to(device)
            ### model
            output = model(input_var.float())

            # sharpening
            #output = output**2 / sum([x**2 for x in output])

            #print ("output",output[0:6])
            #print ("target",target[0:6])
            class_loss = class_criterion(output,
                                         target_var.long()) / len(output)

        #print("class_loss",class_loss)
        meters.update('class_loss', class_loss.item())

        ### get EMA loss. We use the actual samples (not the mixed-up samples) for calculating the EMA loss
        minibatch_size = len(target_var)
        if args.pseudo_label == 'single':
            ema_logit_unlabeled = model(u_var.float())
            ema_logit_labeled = model(input_var.float())
        else:
            ema_logit_unlabeled = ema_model(u_var.float())
            ema_logit_labeled = ema_model(input_var.float())
        if args.mixup_sup_alpha:
            class_logit = model(input_var.float())
        else:
            class_logit = output
        cons_logit = model(u_var.float())
        #print ("cons_logit",cons_logit)

        # detach() already returns a tensor that requires no gradient
        ema_logit_unlabeled = ema_logit_unlabeled.detach()

        # class_loss = class_criterion(class_logit, target_var) / minibatch_size

        ema_class_loss = class_criterion(ema_logit_labeled,
                                         target_var.long())  # / minibatch_size
        meters.update('ema_class_loss', ema_class_loss.item())
        #print ("ema_class_loss",ema_class_loss)

        ### get the unsupervised mixup loss###
        if args.mixup_consistency:
            if args.mixup_hidden:
                # output_u = model(u_var)
                output_mixed_u, target_a_var, target_b_var, lam = model(
                    u_var.float(),
                    ema_logit_unlabeled,
                    mixup_hidden=True,
                    mixup_alpha=args.mixup_sup_alpha,
                    layers_mix=args.num_mix_layer)
                # ema_logit_unlabeled
                lam = lam[0]
                mixedup_target = lam * target_a_var + (1 - lam) * target_b_var
            else:
                # output_u = model(u_var)
                mixedup_x, mixedup_target, lam = mixup_data(
                    u_var, ema_logit_unlabeled, args.mixup_usup_alpha)
                # mixedup_x, mixedup_target, lam = mixup_data(u_var, output_u, args.mixup_usup_alpha)
                output_mixed_u = model(mixedup_x.float())
            mixup_consistency_loss = consistency_criterion(
                output_mixed_u, mixedup_target
            ) / minibatch_size  # criterion_u(F.log_softmax(output_mixed_u,1), F.softmax(mixedup_target,1))
            meters.update('mixup_cons_loss', mixup_consistency_loss.item())
            if epoch < args.consistency_rampup_starts:
                mixup_consistency_weight = 0.0
            else:
                mixup_consistency_weight = get_current_consistency_weight(
                    args.mixup_consistency, epoch, i, len(unlabelledloader))
            meters.update('mixup_cons_weight', mixup_consistency_weight)
            mixup_consistency_loss = mixup_consistency_weight * mixup_consistency_loss
        else:
            mixup_consistency_loss = 0
            meters.update('mixup_cons_loss', 0)

        # labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum().type(torch.cuda.FloatTensor)
        # assert labeled_minibatch_size > 0

        #p_score, _ = output.topk(1, 1, True, True)
        #u_score, _ = cons_logit.topk(1, 1, True, True)
        p_score = output
        u_score = cons_logit
        gamma = 0.2
        # hinge-style margin term encouraging labelled (positive) scores to exceed
        # unlabelled scores by gamma; with the loss definition active below it is
        # computed but not added to the total loss
        pairwise_ranking_loss = max(
            0,
            u_score.view(-1).mean() - p_score.view(-1).mean() - gamma)

        #loss = mixup_consistency_loss
        #print(class_loss)
        #loss = pairwise_ranking_loss - 1 * mixup_consistency_loss
        #loss = 0 * pairwise_ranking_loss + 1 * mixup_consistency_loss
        loss = class_loss + 1 * mixup_consistency_loss
        #loss = class_loss + 1 * mixup_consistency_loss + 0 * pairwise_ranking_loss
        #print ('pairwise ranking loss: ', pairwise_ranking_loss)

        #print (class_loss)
        #print (mixup_consistency_loss)
        #print ("loss",loss)

        meters.update('loss', loss.item())

        prec1, prec5 = accuracy(class_logit.data, target_var.data, topk=(1, 2))

        #print ("prec1",prec1[0])
        #print ("prec5",prec5[0])
        meters.update('top1', prec1[0], minibatch_size)
        meters.update('error1', 100. - prec1[0], minibatch_size)
        meters.update('top5', prec5[0], minibatch_size)
        meters.update('error5', 100. - prec5[0], minibatch_size)

        ema_prec1, ema_prec5 = accuracy(ema_logit_labeled.data,
                                        target_var.data,
                                        topk=(1, 2))
        meters.update('ema_top1', ema_prec1[0], minibatch_size)
        meters.update('ema_error1', 100. - ema_prec1[0], minibatch_size)
        meters.update('ema_top5', ema_prec5[0], minibatch_size)
        meters.update('ema_error5', 100. - ema_prec5[0], minibatch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        update_ema_variables(model, ema_model, args.ema_decay, global_step)

        # measure elapsed time
        #meters.update('batch_time', time.time() - end)
        #end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Class {meters[class_loss]:.4f}\t'
                  'Mixup Cons {meters[mixup_cons_loss]:.4f}\t'
                  'Prec@1 {meters[top1]:.3f}\t'
                  'Prec@5 {meters[top5]:.3f}'.format(epoch,
                                                     i,
                                                     len(unlabelledloader),
                                                     meters=meters))
            # print ('lr:',optimizer.param_groups[0]['lr'])
    train_class_loss_list.append(meters['class_loss'].avg)
    train_error_list.append(float(meters['error1'].avg))
    train_pre_list.append(meters['top1'].avg)
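
train() also depends on update_ema_variables and get_current_consistency_weight, which are defined elsewhere in the file. The sketch below shows a Mean-Teacher-style exponential moving average update and a sigmoid ramp-up schedule that are consistent with how the two functions are called above; the function bodies, the sigmoid_rampup helper, and the ramp-up length of 30 are illustrative assumptions, not the project's exact code.

import numpy as np

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use a smaller decay early on so the EMA (teacher) model warms up quickly.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def sigmoid_rampup(current, rampup_length):
    # Exponential ramp-up from 0 to 1 over rampup_length epochs (Mean Teacher style).
    if rampup_length == 0:
        return 1.0
    current = float(np.clip(current, 0.0, rampup_length))
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

def get_current_consistency_weight(final_weight, epoch, step, steps_per_epoch):
    # Fractional epoch so the weight also increases within an epoch; the ramp-up
    # length is a placeholder (the project presumably derives it from
    # args.consistency_rampup_starts / args.consistency_rampup_ends).
    fractional_epoch = epoch + step / steps_per_epoch
    return final_weight * sigmoid_rampup(fractional_epoch, 30)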
Code Example #3
File: mixpul_cifar10.py  Project: Stomach-ache/MixPUL
def pre_train(trainloader, model, optimizer, epoch):
    import utils

    # switch to train mode
    model.train()

    meters = utils.AverageMeterSet()
    class_criterion = nn.CrossEntropyLoss().to(device)
    i = -1
    for (input, target) in trainloader:

        #if False:
        if args.mixup_sup_alpha:
            if use_cuda:
                input, target = input.to(device), target.to(device)
            input_var, target_var = Variable(input), Variable(target)

            if args.mixup_hidden:
                ### model
                output_mixed_l, target_a_var, target_b_var, lam = model(
                    input_var,
                    target_var,
                    mixup_hidden=True,
                    mixup_alpha=args.mixup_sup_alpha,
                    layers_mix=args.num_mix_layer)
                lam = lam[0]
            else:
                mixed_input, target_a, target_b, lam = mixup_data_sup(
                    input, target, args.mixup_sup_alpha)
                # if use_cuda:
                #    mixed_input, target_a, target_b  = mixed_input.cuda(), target_a.cuda(), target_b.cuda()
                mixed_input_var, target_a_var, target_b_var = Variable(
                    mixed_input), Variable(target_a), Variable(target_b)
                ### model
                output_mixed_l = model(mixed_input_var)

            loss_func = mixup_criterion(target_a_var, target_b_var, lam)
            class_loss = loss_func(class_criterion, output_mixed_l)

        else:
            # Variable is a deprecated no-op wrapper
            input_var = input.to(device)
            target_var = target.to(device)

            #    print (input_var.shape)
            #    print (type(input_var))
            output = model(input_var.float())

            class_loss = class_criterion(output, target_var.long())

        #print("class_loss",class_loss)
        meters.update('class_loss', class_loss.item())
        loss = class_loss
        #print ("loss",loss)

        ### get EMA loss. We use the actual samples (not the mixed-up samples) for calculating the EMA loss
        minibatch_size = len(target_var)

        if args.mixup_sup_alpha:
            class_logit = model(input_var)
        else:
            class_logit = output

        meters.update('loss', loss.item())

        prec1, prec5 = accuracy(class_logit.data, target_var.data, topk=(1, 2))

        meters.update('top1', prec1[0], minibatch_size)
        meters.update('error1', 100.0 - prec1[0], minibatch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_class_loss_list.append(meters['class_loss'].avg)
    train_error_list.append(float(meters['error1'].avg))
    train_pre_list.append(meters['top1'].avg)
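
Both train() and pre_train() call mixup_data_sup and mixup_criterion in the supervised mixup branch (train() additionally calls mixup_data for the unsupervised branch, which mixes soft targets the same way). These helpers are defined elsewhere in the project; the sketch below shows the standard input-mixup formulation that is consistent with the call sites above and should be read as an illustrative assumption rather than the project's exact code.

import numpy as np
import torch

def mixup_data_sup(x, y, alpha=1.0):
    # Draw a mixing coefficient from Beta(alpha, alpha) and mix random pairs.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(y_a, y_b, lam):
    # Return a closure that interpolates the losses of the two label sets,
    # matching the call loss_func(class_criterion, output_mixed_l) above.
    return lambda criterion, pred: (lam * criterion(pred, y_a.long()) +
                                    (1 - lam) * criterion(pred, y_b.long()))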