Example #1
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data."""
    # Compute the number of minibatches to use
    num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size /
                   cfg.NUM_GPUS)
    num_iter = min(num_iter, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    momentums = [bn.momentum for bn in bns]
    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = dist.scaled_all_reduce(running_means)
    running_vars = dist.scaled_all_reduce(running_vars)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
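Usage note: during training, BatchNorm running statistics are accumulated with momentum and therefore lag behind the final weights; the routine above replaces them with a plain average of per-batch statistics computed under the final weights. The self-contained toy below (an assumed example, independent of the project's cfg/dist modules) demonstrates the momentum trick it relies on: with momentum set to 1.0, a BN layer's running mean is exactly the current batch's mean, so averaging it over batches yields the precise estimate.

import torch

# Toy check of the momentum=1.0 trick used by compute_precise_bn_stats above.
torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(3)
bn.train()          # BN only updates running stats in training mode
bn.momentum = 1.0   # running stats now hold exactly the current batch's stats
x = torch.randn(8, 3, 4, 4)
_ = bn(x)
batch_mean = x.mean(dim=(0, 2, 3))  # per-channel mean over N, H, W
print(torch.allclose(bn.running_mean, batch_mean, atol=1e-6))  # True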
Example #2
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Evaluates the model on the test set."""

    # Enable eval mode
    model.eval()
    test_meter.iter_tic()

    for cur_iter, (inputs, labels) in enumerate(test_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the errors across the GPUs
        if cfg.NUM_GPUS > 1:
            top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Copy the errors from GPU to CPU (sync point)
        top1_err, top5_err = top1_err.item(), top5_err.item()
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err, top5_err,
                                inputs.size(0) * cfg.NUM_GPUS)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()

    # Log epoch stats
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()
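The dist.scaled_all_reduce call above is assumed to average each tensor across GPUs (sum via all-reduce, then scale by 1 / NUM_GPUS), which is why the reduced errors can be logged directly. A minimal sketch of that behavior with torch.distributed, assuming the process group is already initialized; this is an approximation, not the project's implementation:

import torch.distributed as torch_dist

def scaled_all_reduce_sketch(tensors):
    # Sum each tensor across all processes, then scale to get the average.
    world_size = torch_dist.get_world_size()
    if world_size == 1:
        return tensors  # no reduction needed for a single GPU
    handles = [torch_dist.all_reduce(t, async_op=True) for t in tensors]
    for handle in handles:
        handle.wait()
    return [t / world_size for t in tensors]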
Example #3
def test_epoch_semi(test_loader, model, test_meter, cur_epoch):
    """Evaluates the model on the test set."""
    # Enable eval mode
    model.eval()
    test_meter.iter_tic()
    total_ce_loss_1 = 0.0
    total_ce_loss_k = 0.0
    total_samples = 0
    for cur_iter, (inputs, labels) in enumerate(test_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)

        # Compute CE of normalized logits against the model's own argmax predictions
        total_samples += inputs.shape[0]
        probs = torch.softmax(preds, dim=1)
        _, lbs_u_guess = torch.max(probs, dim=1)
        criteria_u = nn.CrossEntropyLoss(reduction='none').cuda()

        normed_logits_1 = F.normalize(preds, p=2, dim=1)
        loss_CE = criteria_u(torch.softmax(normed_logits_1, dim=1), lbs_u_guess).mean()
        total_ce_loss_1 += loss_CE.item() * inputs.shape[0]

        normed_logits_k = F.normalize(preds, p=2, dim=1) * 10
        loss_CE = criteria_u(torch.softmax(normed_logits_k, dim=1), lbs_u_guess).mean()
        total_ce_loss_k += loss_CE.item() * inputs.shape[0]

        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the errors across the GPUs  (no reduction if 1 GPU used)
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Copy the errors from GPU to CPU (sync point)
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err, top5_err, mb_size)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()
    # Log epoch stats
    result = test_meter.get_epoch_stats(cur_epoch)
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()
    return result, [total_ce_loss_1 / total_samples, total_ce_loss_k / total_samples]
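The two totals returned above track the cross-entropy of L2-normalized logits (at scale 1 and scale 10) against the model's own argmax predictions; the scale controls how sharp the softmax is and therefore how small that self-consistency loss can get. A toy illustration of the effect (an assumed example; it applies cross-entropy to the scaled logits directly, whereas the code above feeds softmax probabilities into CrossEntropyLoss):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5], [0.2, 1.5, 1.4]])
pseudo = logits.argmax(dim=1)                 # the model's own predictions
for scale in (1.0, 10.0):
    normed = F.normalize(logits, p=2, dim=1) * scale
    ce = F.cross_entropy(normed, pseudo)
    print(scale, round(ce.item(), 4))         # larger scale -> sharper softmax -> lower CE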
Example #4
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                cur_epoch):
    """Performs one epoch of training."""
    # Shuffle the data
    loader.shuffle(train_loader, cur_epoch)
    # Update the learning rate
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    train_meter.reset()
    train_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce(
            [loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    print(f'{cfg.OUT_DIR}')

    if not hasattr(cfg, 'search_epoch'):
        stats = train_meter.get_epoch_stats(cur_epoch)
        stats = {k: v for k, v in stats.items() if isinstance(v, (int, float))}
        summary_dict2txtfig(stats,
                            prefix='train',
                            step=cur_epoch,
                            textlogger=textlogger,
                            save_fig_sec=60)
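meters.topk_errors(preds, labels, [1, 5]), used throughout these examples, is assumed to return the top-1 and top-5 error as percentage-valued scalar tensors (hence the later .item() calls). A standalone sketch of that computation, not the project's implementation:

import torch

def topk_errors_sketch(preds, labels, ks):
    # Top-k error in percent for each k: a sample counts as correct
    # if its label appears among the k highest-scoring classes.
    errs = []
    for k in ks:
        topk = preds.topk(k, dim=1).indices               # [N, k] predicted classes
        correct = topk.eq(labels.view(-1, 1)).any(dim=1)  # [N] hit/miss per sample
        errs.append((1.0 - correct.float().mean()) * 100.0)
    return errs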
Example #5
def train_epoch(loader, model, ema, loss_fun, optimizer, scaler, meter,
                cur_epoch):
    """Performs one epoch of training."""
    # Shuffle the data
    data_loader.shuffle(loader, cur_epoch)
    # Update the learning rate
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    ema.train()
    meter.reset()
    meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Convert labels to smoothed one-hot vector
        labels_one_hot = net.smooth_one_hot_labels(labels)
        # Apply mixup to the batch (no effect if mixup alpha is 0)
        inputs, labels_one_hot, labels = net.mixup(inputs, labels_one_hot)
        # Perform the forward pass and compute the loss
        with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
            preds = model(inputs)
            loss = loss_fun(preds, labels_one_hot)
        # Perform the backward pass and update the parameters
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Update ema weights
        net.update_model_ema(model, ema, cur_epoch, cur_iter)
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce(
            [loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
        meter.iter_toc()
        # Update and log stats
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        meter.log_iter_stats(cur_epoch, cur_iter)
        meter.iter_tic()
    # Log epoch stats
    meter.log_epoch_stats(cur_epoch)
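This variant trains with smoothed one-hot targets and mixup (net.smooth_one_hot_labels and net.mixup above), plus mixed precision via the autocast/GradScaler pair. Below is a rough standalone sketch of the two label-handling helpers; the smoothing eps and mixup alpha are assumed values (the real helpers presumably read them from cfg), and the real net.mixup also returns hard labels for the error computation:

import torch

def smooth_one_hot_sketch(labels, num_classes, eps=0.1):
    # Label smoothing: (1 - eps) + eps/C on the true class, eps/C elsewhere.
    off = eps / num_classes
    one_hot = torch.full((labels.size(0), num_classes), off)
    one_hot.scatter_(1, labels.view(-1, 1), 1.0 - eps + off)
    return one_hot

def mixup_sketch(inputs, labels_one_hot, alpha=0.2):
    # Mixup: convex combination of the batch with a shuffled copy of itself.
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0
    perm = torch.randperm(inputs.size(0))
    mixed_inputs = lam * inputs + (1.0 - lam) * inputs[perm]
    mixed_labels = lam * labels_one_hot + (1.0 - lam) * labels_one_hot[perm]
    return mixed_inputs, mixed_labels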
Example #6
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter,
                cur_epoch):
    """Performs one epoch of training."""

    # Shuffle the data
    loader.shuffle(train_loader, cur_epoch)
    # Update the learning rate
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()

    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the stats across the GPUs
        if cfg.NUM_GPUS > 1:
            loss, top1_err, top5_err = dist.scaled_all_reduce(
                [loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr,
                                 inputs.size(0) * cfg.NUM_GPUS)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()

    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
Example #7
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Evaluates the model on the test set."""
    # Enable eval mode
    model.eval()
    test_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(test_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the errors across the GPUs  (no reduction if 1 GPU used)
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Copy the errors from GPU to CPU (sync point)
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err, top5_err, mb_size)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()
    # Log epoch stats
    result = test_meter.get_epoch_stats(cur_epoch)
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()
    return result
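For cfg.TASK == "col" the predictions come out as an [N, C, H, W] map, so the snippet above flattens them to treat every pixel as an independent classification sample before computing top-k errors, and scales mb_size accordingly. A toy shape check with assumed sizes:

import torch

N, C, H, W = 2, 313, 4, 4                # hypothetical sizes; 313 color bins is an assumption
preds = torch.randn(N, C, H, W)
labels = torch.randint(0, C, (N, H, W))
preds = preds.permute(0, 2, 3, 1).reshape(-1, C)   # [N*H*W, C]
labels = labels.reshape(-1)                        # [N*H*W]
print(preds.shape, labels.shape)  # torch.Size([32, 313]) torch.Size([32])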
Example #8
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Evaluates the model on the test set."""
    # Enable eval mode
    model.eval()
    test_meter.reset()
    test_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(test_loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Compute the predictions
        preds = model(inputs)
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the errors across the GPUs  (no reduction if 1 GPU used)
        top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
        # Copy the errors from GPU to CPU (sync point)
        top1_err, top5_err = top1_err.item(), top5_err.item()
        test_meter.iter_toc()
        # Update and log stats
        test_meter.update_stats(top1_err, top5_err,
                                inputs.size(0) * cfg.NUM_GPUS)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()
    # Log epoch stats
    test_meter.log_epoch_stats(cur_epoch)

    stats = test_meter.get_epoch_stats(cur_epoch)
    if not hasattr(cfg, 'search_epoch'):
        stats = {k: v for k, v in stats.items() if isinstance(v, (int, float))}
        summary_dict2txtfig(stats,
                            prefix='test',
                            step=cur_epoch,
                            textlogger=textlogger,
                            save_fig_sec=60)

    return stats
Example #9
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of training."""
    # Update drop path prob for NAS
    if cfg.MODEL.TYPE == "nas":
        m = model.module if cfg.NUM_GPUS > 1 else model
        m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
    # Shuffle the data
    loader.shuffle(train_loader, cur_epoch)
    # Update the learning rate per epoch
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
            optim.set_lr(optimizer, lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
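When cfg.OPTIM.ITER_LR is set, the schedule above is evaluated at the fractional epoch cur_epoch + cur_iter / len(train_loader), so the learning rate changes smoothly within an epoch instead of once per epoch. A sketch of what such a per-iteration evaluation looks like for a cosine schedule; the cosine form and constants are assumptions, not the project's optim.get_epoch_lr:

import math

def cosine_lr_sketch(epoch_float, base_lr=0.1, max_epoch=100):
    # Cosine annealing evaluated at a (possibly fractional) epoch.
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch_float / max_epoch))

print(cosine_lr_sketch(10.0))   # start of epoch 10
print(cosine_lr_sketch(10.5))   # halfway through epoch 10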
Example #10
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of differentiable architecture search."""
    m = model.module if cfg.NUM_GPUS > 1 else model
    # Shuffle the data
    loader.shuffle(train_loader[0], cur_epoch)
    loader.shuffle(train_loader[1], cur_epoch)
    # Update the learning rate per epoch
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer[0], lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    trainB_iter = iter(train_loader[1])
    for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
            optim.set_lr(optimizer[0], lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Update architecture
        if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
            try:
                inputsB, labelsB = next(trainB_iter)
            except StopIteration:
                trainB_iter = iter(train_loader[1])
                inputsB, labelsB = next(trainB_iter)
            inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
            optimizer[1].zero_grad()
            loss = m._loss(inputsB, labelsB)
            loss.backward()
            optimizer[1].step()
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer[0].zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        # Update the parameters
        optimizer[0].step()
        # Compute the errors
        if cfg.TASK == "col":
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
    # Log genotype
    genotype = m.genotype()
    logger.info("genotype = %s", genotype)
    logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
    logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
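The search loop alternates between updating the network weights on split A (optimizer[0]) and the architecture parameters on split B (optimizer[1]), and at the end of the epoch logs the softmax of the architecture parameters together with the derived genotype. A toy illustration of that last step, with assumed operation names and shapes; the real derivation in m.genotype() is more involved:

import torch
import torch.nn.functional as F

ops = ["skip_connect", "sep_conv_3x3", "max_pool_3x3"]
alphas = torch.randn(4, len(ops))        # 4 edges, 3 candidate ops per edge
weights = F.softmax(alphas, dim=-1)      # per-edge mixing weights, as logged above
chosen = [ops[int(i)] for i in weights.argmax(dim=-1)]
print(chosen)                            # the highest-weight op on each edge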
Example #11
def train_epoch_pseudo(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of Semi-supervised training."""
    # Update drop path prob for NAS
    if cfg.MODEL.TYPE == "nas":
        m = model.module if cfg.NUM_GPUS > 1 else model
        m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
    # Update the learning rate per epoch
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    max_iter = max(len(train_loader[1]), len(train_loader[0]))
    loader.shuffle(train_loader[0], cur_epoch)
    loader.shuffle(train_loader[1], cur_epoch)
    label_iter = iter(train_loader[0])
    unlabel_iter = iter(train_loader[1])
    for cur_iter in range(max_iter):
        # Draw the next labeled batch, reshuffling and restarting the loader when exhausted
        try:
            label_im, _, labels = next(label_iter)
        except StopIteration:
            loader.shuffle(train_loader[0], cur_epoch)
            label_iter = iter(train_loader[0])
            label_im, _, labels = next(label_iter)
        # Draw the next unlabeled batch (two augmented views), restarting likewise
        try:
            unlabel_im1, unlabel_im2, _ = next(unlabel_iter)
        except StopIteration:
            loader.shuffle(train_loader[1], cur_epoch)
            unlabel_iter = iter(train_loader[1])
            unlabel_im1, unlabel_im2, _ = next(unlabel_iter)
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / max_iter)
            optim.set_lr(optimizer, lr)
        # Transfer the data to the current GPU device
        label_im, labels = label_im.cuda(), labels.cuda(non_blocking=True)
        unlabel_im1, unlabel_im2 = unlabel_im1.cuda(), unlabel_im2.cuda()
        imgs = torch.cat([label_im, unlabel_im1, unlabel_im2], dim=0)
        logits = model(imgs)
        logits_label = logits[:len(labels)]
        logits_unlabel1, logits_unlabel2 = torch.split(logits[len(labels):], unlabel_im1.shape[0])

        # Supervised loss on the labeled portion of the batch
        loss_label = loss_fun(logits_label, labels)

        # Pseudo-label the first unlabeled view and keep only confident predictions
        with torch.no_grad():
            probs = torch.softmax(logits_unlabel1, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(cfg.TRAIN.PSD_THRESHOLD).float()
        criteria_u = nn.CrossEntropyLoss(reduction='none').cuda()
        if cfg.TASK == 'psd':
            loss_unlabel = (criteria_u(logits_unlabel1, lbs_u_guess) * mask).mean()
        elif cfg.TASK == 'fix':
            loss_unlabel = (criteria_u(logits_unlabel2, lbs_u_guess) * mask).mean()
        else:
            loss_unlabel = 0

        loss = loss_label + loss_unlabel

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute the errors
        mb_size = label_im.size(0) * cfg.NUM_GPUS
        ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
        top1_err, top5_err = meters.topk_errors(logits_label, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss_label, top1_err, top5_err = dist.scaled_all_reduce([loss_label, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point); only the supervised loss is logged
        loss, top1_err, top5_err = loss_label.item(), top1_err.item(), top5_err.item()
        train_meter.iter_toc() 
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
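The unlabeled loss above is a pseudo-labeling scheme: predictions on the first unlabeled view provide hard pseudo-labels, and only samples whose confidence exceeds cfg.TRAIN.PSD_THRESHOLD contribute, either against the same view ('psd') or against the second view ('fix', FixMatch-style). A toy illustration of the confidence masking with an assumed threshold of 0.95:

import torch
import torch.nn as nn

logits_weak = torch.tensor([[4.0, 0.1, 0.1], [1.0, 1.1, 0.9]])
probs = torch.softmax(logits_weak, dim=1)
scores, pseudo = probs.max(dim=1)
mask = scores.ge(0.95).float()                # keep only confident pseudo-labels
per_sample = nn.CrossEntropyLoss(reduction='none')(logits_weak, pseudo)
loss_unlabeled = (per_sample * mask).mean()   # low-confidence samples contribute zero
print(mask, loss_unlabeled.item())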