def PI_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    w = 0
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar = x.to(device), x_bar.to(device)
            _, feat = model(x)
            _, feat_bar = model(x_bar)
            prob = feat2prob(feat, model.center)
            prob_bar = feat2prob(feat_bar, model.center)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
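# Sketch (not part of the original file): the trainers here call feat2prob(feat, model.center),
# which in DTC/DEC-style code maps features to soft cluster assignments with a Student's
# t-kernel. A minimal version, assuming `center` is an (n_clusters, feat_dim) tensor and
# alpha = 1.0, could look like this:
import torch

def feat2prob(feat, center, alpha=1.0):
    # Squared distance of every sample to every cluster center, then t-kernel similarity.
    q = 1.0 / (1.0 + torch.sum((feat.unsqueeze(1) - center) ** 2, dim=2) / alpha)
    q = q ** ((alpha + 1.0) / 2.0)
    # Normalize each row so it is a probability distribution over clusters.
    q = (q.t() / torch.sum(q, dim=1)).t()
    return q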
def train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            _, feat_g = model(g_x.to(device))
            prob = feat2prob(feat, model.center)
            prob_g = feat2prob(feat_g, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_g)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def TE_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, _, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            z_epoch[idx, :] = prob
            prob_bar = Variable(z_ema[idx, :], requires_grad=False)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_bar)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
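# Note on the temporal-ensembling bookkeeping used by the TE_train / TEP_train functions
# (Laine & Aila, 2017): the ensemble accumulator follows
#   Z <- alpha * Z + (1 - alpha) * z_epoch,   z_ema = Z / (1 - alpha ** (epoch + 1)),
# where the division corrects the startup bias of the exponential moving average, in the same
# way Adam corrects its moment estimates.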
def train(model, train_loader, unlabeled_eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)
            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)
            prob1, prob1_bar, prob2, prob2_bar = F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1)

            mask_lb = idx < train_loader.labeled_length

            rank_feat = (feat[~mask_lb]).detach()
            rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
            rank_idx1, rank_idx2 = PairEnum(rank_idx)
            rank_idx1, rank_idx2 = rank_idx1[:, :args.topk], rank_idx2[:, :args.topk]
            rank_idx1, _ = torch.sort(rank_idx1, dim=1)
            rank_idx2, _ = torch.sort(rank_idx2, dim=1)
            rank_diff = rank_idx1 - rank_idx2
            rank_diff = torch.sum(torch.abs(rank_diff), dim=1)
            target_ulb = torch.ones_like(rank_diff).float().to(device)
            target_ulb[rank_diff > 0] = -1

            prob1_ulb, _ = PairEnum(prob2[~mask_lb])
            _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])

            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)
            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)
            loss = loss_ce + loss_bce + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
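# Sketch (assumption, not defined in this file): PairEnum, used by the ranking-statistics
# pairwise BCE above, enumerates all pairs of rows of a 2-D tensor. A common implementation:
import torch

def PairEnum(x, mask=None):
    # Enumerate all pairs of rows in x: x1[i*N + j] = x[i], x2[i*N + j] = x[j].
    assert x.ndimension() == 2, 'Input dimension must be 2'
    x1 = x.repeat(x.size(0), 1)
    x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
    if mask is not None:
        xmask = mask.view(-1, 1).repeat(1, x.size(1))
        x1 = x1[xmask].view(-1, x.size(1))
        x2 = x2[xmask].view(-1, x.size(1))
    return x1, x2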
def train(model, model_ema, train_loader, labeled_eval_loader, unlabeled_eval_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)
            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)
            with torch.no_grad():
                output1_ema, output2_ema, feat_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _ = model_ema(x_bar)
            prob1, prob1_bar, prob2, prob2_bar = F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1)
            prob1_ema, prob1_bar_ema, prob2_ema, prob2_bar_ema = F.softmax(output1_ema, dim=1), F.softmax(output1_bar_ema, dim=1), F.softmax(output2_ema, dim=1), F.softmax(output2_bar_ema, dim=1)

            mask_lb = label < args.num_labeled_classes
            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar)
            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)
            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema  # + smooth_loss(feat, mask_lb) + MCR(feat, idx)
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99, epoch * len(train_loader) + batch_idx)
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        print('test on labeled classes')
        args.head = 'head1'
        test(model, labeled_eval_loader, args)
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
        test(model_ema, unlabeled_eval_loader, args)
def TE_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            z_epoch[idx, :] = prob
            prob_bar = Variable(z_ema[idx, :], requires_grad=False)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def TEP_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        acc_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(z_ema).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
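# Sketch (assumption): target_distribution, which the trainers above use to refresh
# args.p_targets, is typically the DEC auxiliary target
#   p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'),  with  f_j = sum_i q_ij,
# which sharpens the soft assignments while down-weighting large clusters:
import torch

def target_distribution(q):
    weight = q ** 2 / q.sum(0)           # square assignments, normalize by cluster frequency
    return (weight.t() / weight.sum(1)).t()  # renormalize each row to a distribution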
def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
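# Sketch (assumption): ramps.sigmoid_rampup, used throughout these trainers, is usually the
# exponential ramp-up from Laine & Aila (2017), rising smoothly from 0 to 1:
import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up: exp(-5 * (1 - t)^2) with t = current / rampup_length, clipped to [0, 1]."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))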
def train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def create_model(ema=False):
        # Network definition
        model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes)
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = BaseDataSets(base_dir=args.root_path, split="train", num=None,
                            transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = BaseDataSets(base_dir=args.root_path, split="val")

    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_path, args.labeled_num)
    print("Total slices is: {}, labeled slices is: {}".format(total_slices, labeled_slice))
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    model.train()
    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]

            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)

            T = 8
            _, _, w, h = unlabeled_volume_batch.shape
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, num_classes, w, h]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, num_classes, w, h)
            preds = torch.mean(preds, dim=0)
            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)

            loss_ce = ce_loss(outputs[:args.labeled_bs], label_batch[:args.labeled_bs][:].long())
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)

            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = losses.softmax_mse_loss(outputs[args.labeled_bs:], ema_output)  # (batch, 2, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_loss = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            loss = supervised_loss + consistency_weight * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss, iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight, iter_num)
            logging.info('iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                         (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[1, 0:1, :, :]
                writer.add_image('train/Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(outputs, dim=1), dim=1, keepdim=True)
                writer.add_image('train/Prediction', outputs[1, ...] * 50, iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for i_batch, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch["image"], sampled_batch["label"],
                                                  model, classes=num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                for class_i in range(num_classes - 1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i + 1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i + 1), metric_list[class_i, 1], iter_num)

                performance = np.mean(metric_list, axis=0)[0]
                mean_hd95 = np.mean(metric_list, axis=0)[1]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
                writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(snapshot_path,
                                                  'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                    save_best = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                logging.info('iteration %d : mean_dice : %f mean_hd95 : %f' % (iter_num, performance, mean_hd95))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
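# Sketch (assumption): update_ema_variables is the standard Mean Teacher EMA update, where the
# teacher's weights track an exponential moving average of the student's:
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use a smaller effective decay during the first steps so the teacher follows the student quickly.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)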
def train(train_loader, model, optimizer, epoch, ema_model=None, weak_mask=None, strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
        Should return 3 values: teacher input, student input, labels
    :param model: torch.Module, model to be trained, should return a weak and strong prediction
    :param optimizer: torch.Module, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.Module, student model, should return a weak and strong prediction
    :param weak_mask: mask the batch to get only the weak labeled data (used to calculate the loss)
    :param strong_mask: mask the batch to get only the strong labeled data (used to calculate the loss)
    """
    class_criterion = nn.BCELoss()
    ##################################################
    class_criterion1 = nn.BCELoss(reduction='none')
    ##################################################
    consistency_criterion = nn.MSELoss()
    # [class_criterion, consistency_criterion] = to_cuda_if_available(
    #     [class_criterion, consistency_criterion])
    [class_criterion, class_criterion1, consistency_criterion] = to_cuda_if_available(
        [class_criterion, class_criterion1, consistency_criterion])

    meters = AverageMeterSet()
    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2
    print("Train\n")
    # LOG.info("Weak[k] -> Weak[k]")
    # LOG.info("Weak[k] -> strong[k]")
    # print(weak_mask.start)
    # print(strong_mask.start)
    # exit()
    count = 0
    check_cus_weak = 0
    difficulty_loss = 0
    loss_w = 1
    LOG.info("loss parameter:{}".format(loss_w))

    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        # print(batch_input.shape)
        # print(ema_batch_input.shape)
        # exit()
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input, target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug("batch_input:{}".format(batch_input.mean()))
        # print(batch_input)
        # exit()

        # Outputs
        ##################################################
        # strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema, weak_pred_ema, sof_ema = ema_model(ema_batch_input)
        sof_ema = sof_ema.detach()
        ##################################################
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()

        ##################################################
        # strong_pred, weak_pred = model(batch_input)
        strong_pred, weak_pred, sof = model(batch_input)
        ##################################################

        ##################################################
        # custom_ema_loss = Custom_BCE_Loss(ema_batch_input, class_criterion1)
        if difficulty_loss == 0:
            LOG.info("############### Define Difficulty Loss ###############")
            difficulty_loss = 1
        custom_ema_loss = Custom_BCE_Loss_difficulty(ema_batch_input, class_criterion1, paramater=loss_w)
        custom_ema_loss.initialize(strong_pred_ema, sof_ema)
        # custom_loss = Custom_BCE_Loss(batch_input, class_criterion1)
        custom_loss = Custom_BCE_Loss_difficulty(batch_input, class_criterion1, paramater=loss_w)
        custom_loss.initialize(strong_pred, sof)
        ##################################################
        # print(strong_pred.shape)
        # print(strong_pred)
        # print(weak_pred.shape)
        # print(weak_pred)
        # exit()

        loss = None
        # Weak BCE Loss
        # Take the max in the time axis
        # torch.set_printoptions(threshold=10000)
        # print(target[-10])
        # print(target.max(-2))
        # print(target.max(-2)[0])
        # print(target.max(-1)[0][-10])
        # exit()
        target_weak = target.max(-2)[0]

        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask], target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask], target_weak[weak_mask])
            print("normal_weak:", class_criterion(weak_pred[weak_mask], target_weak[weak_mask]))
            ##################################################
            custom_weak_class_loss = custom_loss.weak(target_weak, weak_mask)
            custom_ema_class_loss = custom_ema_loss.weak(target_weak, weak_mask)
            print("custom_weak:", custom_weak_class_loss)
            ##################################################
            count += 1
            check_cus_weak += custom_weak_class_loss
            # print(custom_weak_class_loss.item())

            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(target_weak[weak_mask]))
                LOG.debug(custom_weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', custom_weak_class_loss.item())
            meters.update('Weak EMA loss', custom_ema_class_loss.item())

            # loss = weak_class_loss
            loss = custom_weak_class_loss

        ####################################################################################
        # weak_class_loss = class_criterion(strong_pred[weak_mask], target[weak_mask])
        # ema_class_loss = class_criterion(strong_pred_ema[weak_mask], target[weak_mask])
        # # if i == 0:
        # #     LOG.debug("target: {}".format(target.mean(-2)))
        # #     LOG.debug("Target_weak: {}".format(target))
        # #     LOG.debug("Target_weak mask: {}".format(target[weak_mask]))
        # #     LOG.debug(weak_class_loss)
        # #     LOG.debug("rampup_value: {}".format(rampup_value))
        # meters.update('weak_class_loss', weak_class_loss.item())
        # meters.update('Weak EMA loss', ema_class_loss.item())
        # loss = weak_class_loss
        ####################################################################################

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask], target[strong_mask])
            # meters.update('Strong loss', strong_class_loss.item())
            strong_ema_class_loss = class_criterion(strong_pred_ema[strong_mask], target[strong_mask])
            # meters.update('Strong EMA loss', strong_ema_class_loss.item())
            print("normal_strong:", class_criterion(strong_pred[strong_mask], target[strong_mask]))
            ##################################################
            custom_strong_class_loss = custom_loss.strong(target, strong_mask)
            meters.update('Strong loss', custom_strong_class_loss.item())
            custom_strong_ema_class_loss = custom_ema_loss.strong(target, strong_mask)
            meters.update('Strong EMA loss', custom_strong_ema_class_loss.item())
            print("custom_strong:", custom_strong_class_loss)
            ##################################################
            if loss is not None:
                # loss += strong_class_loss
                loss += custom_strong_class_loss
            else:
                # loss = strong_class_loss
                loss = custom_strong_class_loss

        # print("check_weak:", class_criterion1(weak_pred[weak_mask], target_weak[weak_mask]).mean())
        # print("check_strong:", class_criterion1(strong_pred[strong_mask], target[strong_mask]).mean())
        # print("\n")
        # exit()

        # Teacher-student consistency cost
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take consistency about strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            meters.update('Consistency weight', consistency_cost)
            # Take consistency about weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), 'Loss explosion: {}'.format(loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start
    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
    print("\ncheck_cus_weak:\n", check_cus_weak / count)
def get_current_consistency_weight(epoch):
    return sigmoid_rampup(epoch, 10)
def train(args, snapshot_path):
    base_lr = args.base_lr
    train_data_path = args.root_path
    batch_size = args.batch_size
    max_iterations = args.max_iterations
    num_classes = 2

    def create_model(ema=False):
        # Network definition
        net = net_factory_3d(net_type=args.model, in_chns=1, class_num=num_classes)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    db_train = BraTS2019(base_dir=train_data_path, split='train', num=None,
                         transform=transforms.Compose([
                             RandomRotFlip(),
                             RandomCrop(args.patch_size),
                             ToTensor(),
                         ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    labeled_idxs = list(range(0, args.labeled_num))
    unlabeled_idxs = list(range(args.labeled_num, 250))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4,
                             pin_memory=True, worker_init_fn=worker_init_fn)

    model.train()
    ema_model.train()

    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(2)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]

            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)

            T = 8
            _, _, d, w, h = unlabeled_volume_batch.shape
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, 2, d, w, h]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = torch.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 2, d, w, h)
            preds = torch.mean(preds, dim=0)
            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)

            loss_ce = ce_loss(outputs[:args.labeled_bs], label_batch[:args.labeled_bs][:])
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)

            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = losses.softmax_mse_loss(outputs[args.labeled_bs:], ema_output)  # (batch, 2, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_loss = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            loss = supervised_loss + consistency_weight * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_
            iter_num = iter_num + 1

            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss, iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight, iter_num)
            logging.info('iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                         (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))
            writer.add_scalar('loss/loss', loss, iter_num)

            if iter_num % 20 == 0:
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                avg_metric = test_all_case(model, args.root_path, test_list="val.txt", num_classes=2,
                                           patch_size=args.patch_size, stride_xy=64, stride_z=64)
                if avg_metric[:, 0].mean() > best_performance:
                    best_performance = avg_metric[:, 0].mean()
                    save_mode_path = os.path.join(snapshot_path,
                                                  'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                    save_best = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                writer.add_scalar('info/val_dice_score', avg_metric[0, 0], iter_num)
                writer.add_scalar('info/val_hd95', avg_metric[0, 1], iter_num)
                logging.info('iteration %d : dice_score : %f hd95 : %f' %
                             (iter_num, avg_metric[0, 0].mean(), avg_metric[0, 1].mean()))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
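# Sketch (assumption): losses.softmax_mse_loss, used as the consistency distance in the two
# trainers above, typically returns the element-wise squared difference between the softmaxed
# student and teacher logits without reduction, so it can be masked by the uncertainty map
# before averaging:
import torch

def softmax_mse_loss(input_logits, target_logits):
    assert input_logits.size() == target_logits.size()
    input_softmax = torch.softmax(input_logits, dim=1)
    target_softmax = torch.softmax(target_logits, dim=1)
    return (input_softmax - target_softmax) ** 2  # no reduction; masked and summed by the caller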
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 2, 112, 112, 80)
            preds = torch.mean(preds, dim=0)  # (batch, 2, 112,112,80)
            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)  # (batch, 1, 112,112,80)

            ## calculate the loss
            loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output)  # (batch, 2, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_dist = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            iter_num = iter_num + 1
            writer.add_scalar('uncertainty/mean', uncertainty[0, 0].mean(), iter_num)
            writer.add_scalar('uncertainty/max', uncertainty[0, 0].max(), iter_num)
            writer.add_scalar('uncertainty/min', uncertainty[0, 0].min(), iter_num)
            writer.add_scalar('uncertainty/mask_per', torch.sum(mask) / mask.numel(), iter_num)
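# Note on the uncertainty estimate in the fragment above (UA-MT style): `uncertainty` is the
# predictive entropy of the teacher's softmax averaged over T noise-perturbed forward passes,
#   u = -sum_c p_bar_c * log(p_bar_c + 1e-6),
# and only voxels whose entropy falls below the ramped threshold contribute to the masked
# consistency loss.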
    else:
        model_path_pretrain = os.path.join(model_directory, model_name_triplet, "epoch_" + str(f_args.epochs))
        print("path of model : " + model_path_pretrain)
        create_folder(os.path.join(model_directory, model_name_triplet))

    batch_size_classif = cfg.batch_size_classif
    # Hard coded because no semi_hard in this version
    semi_hard_embed = None
    semi_hard_input = None

    if not os.path.exists(model_path_pretrain) or cfg.recompute_embedding:
        margin = triplet_margin
        for epoch in range(f_args.epochs):
            t_start_epoch = time.time()
            if cfg.rampup_margin_length is not None:
                margin = sigmoid_rampup(epoch, cfg.rampup_margin_length) * triplet_margin
            model_triplet.train()
            model_triplet, loss_mean_triplet, ratio_used = train_triplet_epoch(triplet_loader,  # triplet_loader,
                                                                               model_triplet,
                                                                               optimizer,
                                                                               semi_hard_input,
                                                                               semi_hard_embed,
                                                                               pit=pit,
                                                                               margin=margin,
                                                                               swap=swap,
                                                                               acc_grad=segment)
            model_triplet.eval()
            loss_mean_triplet = np.mean(loss_mean_triplet)

        embed_dir = "stored_data/embeddings"
        embed_dir = os.path.join(embed_dir, model_name_triplet, "embeddings")
        create_folder(embed_dir)
                    labeled_outputs_soft[:, i, :, :, :], label_batch == i)
                loss_seg_dice += loss_mid
                print('dice score (1-dice_loss): {:.3f}'.format(1 - loss_mid))
            supervised_loss = (loss_seg + loss_seg_dice) / 2.0
            # supervised_loss = loss_seg_dice
            # if epoch_num == 20 and i_batch == 5:
            #     import pdb
            #     pdb.set_trace()

            ## calculate the loss (only for unlabeled samples)
            consistency_weight = args.consistency  # get_current_consistency_weight(epoch_num/20)
            consistency_dist = consistency_criterion(outputs[:, 1:, :, :, :], ema_outputs[:, 1:, :, :, :])  # (batch, num_classes, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch_num, 20)) * np.log(num_classes)  # the maximum uncertainty of an N-class problem is log(N)
            mask = (uncertainty < threshold).int()
            consistency_dist = torch.sum(mask * consistency_dist) / (torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            # The standard three steps of a PyTorch training iteration:
            # zero the gradients, backpropagate the loss, then take an optimizer step.
            optimizer.zero_grad()  # set the gradients of the model parameters to zero
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, epoch_num)
            iter_num = iter_num + 1
def train(cfg, train_loader, model, optimizer, epoch, ema_model=None, weak_mask=None, strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
        Should return 3 values: teacher input, student input, labels
    :param model: torch.Module, model to be trained, should return a weak and strong prediction
    :param optimizer: torch.Module, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.Module, student model, should return a weak and strong prediction
    :param weak_mask: mask the batch to get only the weak labeled data (used to calculate the loss)
    :param strong_mask: mask the batch to get only the strong labeled data (used to calculate the loss)
    """
    class_criterion = nn.BCELoss()
    consistency_criterion_strong = nn.MSELoss()
    lds_criterion = LDSLoss(xi=cfg.vat_xi, eps=cfg.vat_eps, n_power_iter=cfg.vat_n_power_iter)

    [class_criterion, consistency_criterion_strong, lds_criterion] = to_cuda_if_available(
        [class_criterion, consistency_criterion_strong, lds_criterion])

    meters = AverageMeterSet()
    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2

    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input, target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug(batch_input.mean())

        # Outputs
        strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()
        strong_pred, weak_pred = model(batch_input)

        loss = None
        # Weak BCE Loss
        # Take the max in axis 2 (assumed to be time)
        if len(target.shape) > 2:
            target_weak = target.max(-2)[0]
        else:
            target_weak = target

        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask], target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask], target_weak[weak_mask])
            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(target_weak[weak_mask]))
                LOG.debug(weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', weak_class_loss.item())
            meters.update('Weak EMA loss', ema_class_loss.item())
            loss = weak_class_loss

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask], target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())
            strong_ema_class_loss = class_criterion(strong_pred_ema[strong_mask], target[strong_mask])
            meters.update('Strong EMA loss', strong_ema_class_loss.item())
            if loss is not None:
                loss += strong_class_loss
            else:
                loss = strong_class_loss

        # Teacher-student consistency cost
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take only the consistency with weak and unlabeled data
            consistency_loss_strong = consistency_cost * consistency_criterion_strong(strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            meters.update('Consistency weight', consistency_cost)
            # Take only the consistency with weak and unlabeled data
            consistency_loss_weak = consistency_cost * consistency_criterion_strong(weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        # LDS loss
        if cfg.vat_enabled:
            lds_loss = cfg.vat_coeff * lds_criterion(model, batch_input, weak_pred)
            LOG.info('loss: {:.3f}, lds loss: {:.3f}'.format(loss, cfg.vat_coeff * lds_loss.detach().cpu().numpy()))
            loss += lds_loss
        else:
            if i % 25 == 0:
                LOG.info('loss: {:.3f}'.format(loss))

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), 'Loss explosion: {}'.format(loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start
    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
                    label_batch == i)
                loss_seg_dice += loss_mid
                print('dice score (1-dice_loss): {:.3f}'.format(1 - loss_mid))
            # import pdb
            # pdb.set_trace()
            # print('dicetotal:{:.3f}'.format(loss_seg_dice))
            # loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

            # only for unlabeled samples
            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = consistency_criterion(unlabeled_outputs, ema_output)  # (batch, num_classes, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num, max_iterations)) * np.sqrt(3)  # author's note: the maximum uncertainty of an N-class problem is sqrt(N)
            mask = (uncertainty < threshold).float()
            # print("consistency_dist:", consistency_dist.item())
            asd = np.prod(list(mask.shape))
            # print("mask:", np.sum(mask.item()) / asd)
            consistency_dist = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            # The standard three steps of a PyTorch training iteration:
            # zero the gradients, backpropagate the loss, then take an optimizer step.
            optimizer.zero_grad()  # set the gradients of the model parameters to zero
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)
def get_current_consistency_weight(epoch):
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
def train(model, model_ema, memorybank, labeled_eval_loader_train, unlabeled_eval_loader_test, unlabeled_eval_loader_train, args):
    labeled_train_loader = CIFAR10Loader_iter(root=args.dataset_root, batch_size=args.batch_size // 2,
                                              split='train', aug='twice', shuffle=True,
                                              target_list=range(args.num_labeled_classes))
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    criterion3 = CrossEntropyLabelSmooth(num_classes=args.num_unlabeled_classes)

    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        iters = 400

        if epoch % 5 == 0:
            args.head = 'head2'
            feats, feats_mb, _ = test(model_ema, unlabeled_eval_loader_train, args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            feats_mb = F.normalize(torch.cat(feats_mb, dim=0), dim=1)
            cluster = faiss.Kmeans(512, 5, niter=300, verbose=True, gpu=True)
            moving_avg_features = feats.numpy()
            cluster.train(moving_avg_features)
            _, labels_ = cluster.index.search(moving_avg_features, 1)
            labels = labels_ + 5
            target_label = labels.reshape(-1).tolist()
            # centers = faiss.vector_to_array(cluster.centroids).reshape(5, 512)
            centers = cluster.centroids

            # Memory bank by zkc
            # if epoch == 0: memorybank.features = torch.cat((F.normalize(torch.tensor(centers).cuda(), dim=1), feats), dim=0).cuda()
            # memorybank.labels = torch.cat((torch.arange(args.num_unlabeled_classes), torch.Tensor(target_label).long())).cuda()
            if epoch == 0:
                memorybank.features = feats_mb.cuda()
                memorybank.labels = torch.Tensor(labels_.reshape(-1).tolist()).long().cuda()

            model.memory.prototypes[args.num_labeled_classes:] = F.normalize(torch.tensor(centers).cuda(), dim=1)
            model_ema.memory.prototypes[args.num_labeled_classes:] = F.normalize(torch.tensor(centers).cuda(), dim=1)

            feats, _, labels = test(model_ema, labeled_eval_loader_train, args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            centers = torch.zeros(args.num_labeled_classes, 512)
            for i in range(args.num_labeled_classes):
                idx = torch.where(torch.tensor(labels) == i)[0]
                centers[i] = torch.mean(feats[idx], 0)
            model.memory.prototypes[:args.num_labeled_classes] = torch.tensor(centers).cuda()
            model_ema.memory.prototypes[:args.num_labeled_classes] = torch.tensor(centers).cuda()

            unlabeled_train_loader = CIFAR10Loader_iter(root=args.dataset_root, batch_size=args.batch_size // 2,
                                                        split='train', aug='twice', shuffle=True,
                                                        target_list=range(args.num_labeled_classes, num_classes),
                                                        new_labels=target_label)
            # model.head2.weight.data.copy_(
            #     torch.from_numpy(F.normalize(target_centers, axis=1)).float().cuda())

        # labeled_train_loader.new_epoch()
        # unlabeled_train_loader.new_epoch()
        # for batch_idx, _ in enumerate(range(iters)):
        #     ((x_l, x_bar_l), label_l, idx) = labeled_train_loader.next()
        #     ((x_u, x_bar_u), label_u, idx) = unlabeled_train_loader.next()
        for batch_idx, (((x_l, x_bar_l), label_l, idx_l),
                        ((x_u, x_bar_u), label_u, idx_u)) in enumerate(zip(labeled_train_loader, unlabeled_train_loader)):
            x = torch.cat([x_l, x_u], dim=0)
            x_bar = torch.cat([x_bar_l, x_bar_u], dim=0)
            label = torch.cat([label_l, label_u], dim=0)
            idx = torch.cat([idx_l, idx_u], dim=0)
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)

            output1, output2, feat, feat_mb = model(x)
            output1_bar, output2_bar, _, _ = model(x_bar)
            with torch.no_grad():
                output1_ema, output2_ema, feat_ema, feat_mb_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _, _ = model_ema(x_bar)

            prob1, prob1_bar, prob2, prob2_bar = F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1)
            prob1_ema, prob1_bar_ema, prob2_ema, prob2_bar_ema = F.softmax(output1_ema, dim=1), F.softmax(output1_bar_ema, dim=1), F.softmax(output2_ema, dim=1), F.softmax(output2_bar_ema, dim=1)

            mask_lb = label < args.num_labeled_classes
            loss_ce_label = criterion1(output1[mask_lb], label[mask_lb])
            loss_ce_unlabel = criterion1(output2[~mask_lb], label[~mask_lb])  # torch.tensor(0)#
            loss_in_unlabel = torch.tensor(0)  # memorybank(feat_mb[~mask_lb], feat_mb_ema[~mask_lb], label[~mask_lb], idx[~mask_lb])
            loss_ce = loss_ce_label + loss_ce_unlabel
            loss_bce = rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar)  # torch.tensor(0)#
            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)
            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema + loss_in_unlabel

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99, epoch * iters + batch_idx)

            if batch_idx % 200 == 0:
                print('Train Epoch: {}, iter {}/{} unl-CE Loss: {:.4f}, unl-instance Loss: {:.4f}, l-CE Loss: {:.4f}, BCE Loss: {:.4f}, CL Loss: {:.4f}, Avg Loss: {:.4f}'
                      .format(epoch, batch_idx, 400, loss_ce_unlabel.item(), loss_in_unlabel.item(),
                              loss_ce_label.item(), loss_bce.item(), consistency_loss.item(), loss_record.avg))

        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        # print('test on labeled classes')
        # args.head = 'head1'
        # test(model, labeled_eval_loader_test, args)
        # print('test on unlabeled classes')
        args.head = 'head2'
        # test(model, unlabeled_eval_loader_train, args)
        test(model_ema, unlabeled_eval_loader_test, args)