def create_model(pretrained=True, architecture="resnet34", is_train=True):
    """Build a classifier with a ``NUM_CATEGORY``-way head and move it to ``DEVICE``.

    Args:
        pretrained: Forwarded to the backbone factory for the torchvision and
            SE networks. When truthy, the backbone is built with its default
            head and the final linear layer is then replaced by a fresh one
            with ``NUM_CATEGORY`` outputs. Not used for ``"mobilenetv2"`` and
            ``"original"``.
        architecture: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
            ``"mobilenetv2"``, ``"se_resnet50"``, ``"se_resnext50"`` or
            ``"original"``.
        is_train: Put the model in training mode when true, otherwise in
            evaluation mode.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """
    # Input width of the final linear layer for each torchvision ResNet.
    resnet_fc_inputs = {"resnet18": 512 * 1, "resnet34": 512 * 1, "resnet50": 512 * 4}
    # Public architecture name -> factory name inside ``great_networks``.
    se_factory_names = {"se_resnet50": "se_resnet50", "se_resnext50": "se_resnext50_32x4d"}

    # NOTE(review): the source was whitespace-collapsed; the pooling layer is
    # assumed to be replaced for both the pretrained and the non-pretrained
    # branch. Confirm against the original indentation.
    if architecture in resnet_fc_inputs:
        factory = getattr(torchvision.models, architecture)
        if pretrained:
            model = factory(pretrained=pretrained)
            model.fc = torch.nn.Linear(resnet_fc_inputs[architecture], NUM_CATEGORY)
        else:
            model = factory(pretrained=pretrained, num_classes=NUM_CATEGORY)
        model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "mobilenetv2":
        model = MobileNetV2.MobileNetV2(n_class=NUM_CATEGORY, input_size=IMG_HEIGHT)
    elif architecture in se_factory_names:
        factory = getattr(great_networks, se_factory_names[architecture])
        if pretrained:
            model = factory(pretrained=pretrained)
            model.last_linear = torch.nn.Linear(512 * 4, NUM_CATEGORY)
        else:
            model = factory(num_classes=NUM_CATEGORY, pretrained=None)
        model.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "original":
        model = network.ResNet(network.BasicBlock, [3, 4, 6, 3], num_classes=NUM_CATEGORY)
    else:
        # Name the offending value so a typo in a config is easy to spot.
        raise ValueError("unsupported architecture: {!r}".format(architecture))

    model.to(DEVICE)
    if is_train:
        model.train()
    else:
        model.eval()
    return model
def train_init_irm(args):
    """Train the initial model with a scale-based penalty term and track the best snapshot.

    Builds source/target/test loaders from ``args.source_list`` and
    ``args.target_list``, trains ``network.ResNet`` with SGD for
    ``args.max_iter`` iterations, evaluates every ``args.test_interval``
    iterations and saves the best model so far every
    ``args.snapshot_interval`` iterations to ``snapshot/save/initial_model.pk``.

    Args:
        args: Namespace providing the attributes read below (data lists,
            batch size, learning rate, ``irm_*`` options, ``baseline``,
            ``init_method``, ``gpu_id``, ``log_file`` ...).

    Returns:
        Tuple ``(best_acc, best_model)``.

    Raises:
        ValueError: If ``args.irm_type``, ``args.irm_feature`` or
            ``args.baseline`` has an unsupported value. Previously these
            surfaced as a NameError in the first training iteration.
    """
    if args.irm_type not in ('batch', 'sample'):
        raise ValueError("unsupported irm_type: {!r}".format(args.irm_type))
    if args.irm_feature not in ('last_hidden', 'logit'):
        raise ValueError("unsupported irm_feature: {!r}".format(args.irm_feature))
    if args.baseline not in ('MSTN', 'DANN'):
        raise ValueError("unsupported baseline: {!r}".format(args.baseline))

    def _read_lines(path):
        # Close the list file right away instead of leaking the handle.
        with open(path) as handle:
            return handle.readlines()

    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(_read_lines(args.source_list), transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(_read_lines(args.target_list), transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["test"] = ImageList(_read_lines(args.target_list), transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size,
                                      shuffle=False, num_workers=4)

    # model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius,
                           trainable_radius=args.trainable_radius).cuda()
    parameter_list = model.get_parameters()
    optimizer = torch.optim.SGD(parameter_list, lr=args.lr, momentum=0.9, weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        # This function has no adversarial network; the previous code wrapped
        # an undefined ``adv_net`` here and raised NameError on multi-GPU runs.
        # NOTE(review): ``model.forward_mul`` and ``model.radius`` are accessed
        # directly below; confirm they are reachable through DataParallel.
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    # train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            os.makedirs('snapshot/save', exist_ok=True)
            torch.save(best_model, 'snapshot/save/initial_model.pk')

        model.train(True)
        if args.lr_decay:
            optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        # Restart a loader's iterator each time it has been fully consumed.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on every DataLoader version; the ``.next()``
        # method is a Python 2 leftover that newer iterators do not provide.
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)
        inputs_source = inputs_source.cuda()
        inputs_target = inputs_target.cuda()
        labels_source = labels_source.cuda()

        # Dummy multipliers whose gradients feed the penalty term: one scalar
        # per batch, or one value per sample.
        if args.irm_type == 'batch':
            scale_source = torch.tensor(1.).cuda().requires_grad_()
            scale_target = torch.tensor(1.).cuda().requires_grad_()
        else:  # 'sample'
            scale_source = torch.ones(inputs_source.size(0), 1).cuda().requires_grad_()
            scale_target = torch.ones(inputs_target.size(0), 1).cuda().requires_grad_()

        if args.irm_feature == 'last_hidden':
            features_source, outputs_source = model.forward_mul(inputs_source, scale_source)
            features_target, outputs_target = model.forward_mul(inputs_target, scale_target)
        else:  # 'logit'
            features_source, outputs_source = model(inputs_source)
            features_target, outputs_target = model(inputs_target)
            outputs_source = outputs_source * scale_source
            outputs_target = outputs_target * scale_target

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)

        lam = network.calc_coeff(i) if args.baseline == 'MSTN' else 0.0

        pseu_labels_target = torch.argmax(outputs_target, dim=1)
        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source, features_target,
                                                 labels_source, pseu_labels_target,
                                                 Cs_memory, Ct_memory)
        total_loss = classifier_loss + lam * loss_sm

        irm_loss = 0
        if i > args.irm_warmup_step:
            if 'MSTN' in args.init_method:
                irm_loss += sum(
                    penalty_loss_scales(total_loss, [scale_source, scale_target]))
            else:
                source_irm_loss = sum(
                    penalty_loss_scales(classifier_loss, [scale_source]))
                irm_loss += source_irm_loss
                # NOTE(review): nesting reconstructed from whitespace-collapsed
                # source; the 'target' branch is assumed to apply only when
                # 'MSTN' is not in init_method. Confirm.
                if 'target' in args.init_method:
                    classifier_loss_target = nn.CrossEntropyLoss()(
                        outputs_target, pseu_labels_target.detach())
                    target_irm_loss = sum(
                        penalty_loss_scales(classifier_loss_target, [scale_target]))
                    irm_loss += target_irm_loss
        total_loss += args.irm_weight * irm_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if args.trainable_radius:
            print(
                'step:{: d},\t,class_loss:{:.4f},\t,irm_loss:{:.4f},\tradius:{:.4f}'
                .format(i, classifier_loss.item(), float(irm_loss),
                        float(model.radius)))
        else:
            print('step:{: d},\t,class_loss:{:.4f},\t,irm_loss:{:.4f}'.format(
                i, classifier_loss.item(), float(irm_loss)))

        # Cut the autograd history so the memories do not chain graphs across
        # iterations.
        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model
def train_init(args):
    """Train the initial model with an adversarial loss and track the best snapshot.

    Builds source/target/test loaders from ``args.source_list`` and
    ``args.target_list``, trains ``network.ResNet`` together with
    ``network.AdversarialNetwork`` using SGD for ``args.max_iter`` iterations,
    evaluates every ``args.test_interval`` iterations and saves the best model
    so far every ``args.snapshot_interval`` iterations to
    ``snapshot/save/initial_model.pk``.

    Args:
        args: Namespace providing the attributes read below (data lists,
            batch size, learning rate, ``baseline``, ``gpu_id``,
            ``log_file`` ...).

    Returns:
        Tuple ``(best_acc, best_model)``.

    Raises:
        ValueError: If ``args.baseline`` is neither ``'MSTN'`` nor ``'DANN'``.
            Previously this surfaced as a NameError on ``lam`` in the first
            training iteration.
    """
    if args.baseline not in ('MSTN', 'DANN'):
        raise ValueError("unsupported baseline: {!r}".format(args.baseline))

    def _read_lines(path):
        # Close the list file right away instead of leaking the handle.
        with open(path) as handle:
            return handle.readlines()

    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(_read_lines(args.source_list), transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(_read_lines(args.target_list), transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["test"] = ImageList(_read_lines(args.target_list), transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size,
                                      shuffle=False, num_workers=4)

    # model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius,
                           trainable_radius=args.trainable_radius).cuda()
    adv_net = network.AdversarialNetwork(in_feature=model.output_num(),
                                         hidden_size=1024).cuda()
    parameter_list = model.get_parameters() + adv_net.get_parameters()
    optimizer = torch.optim.SGD(parameter_list, lr=args.lr, momentum=0.9, weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        device_ids = [int(i) for i in gpus]
        adv_net = nn.DataParallel(adv_net, device_ids=device_ids)
        model = nn.DataParallel(model, device_ids=device_ids)

    # train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            os.makedirs('snapshot/save', exist_ok=True)
            torch.save(best_model, 'snapshot/save/initial_model.pk')

        model.train(True)
        adv_net.train(True)
        if args.lr_decay:
            optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        # Restart a loader's iterator each time it has been fully consumed.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on every DataLoader version; the ``.next()``
        # method is a Python 2 leftover that newer iterators do not provide.
        inputs_source, labels_source = next(iter_source)
        inputs_target, _ = next(iter_target)
        inputs_source = inputs_source.cuda()
        inputs_target = inputs_target.cuda()
        labels_source = labels_source.cuda()

        features_source, outputs_source = model(inputs_source)
        features_target, outputs_target = model(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        adv_loss = utils.loss_adv(features, adv_net)

        lam = network.calc_coeff(i) if args.baseline == 'MSTN' else 0.0

        pseu_labels_target = torch.argmax(outputs_target, dim=1)
        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source, features_target,
                                                 labels_source, pseu_labels_target,
                                                 Cs_memory, Ct_memory)
        total_loss = classifier_loss + adv_loss + lam * loss_sm

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print('step:{: d},\t,class_loss:{:.4f},\t,adv_loss:{:.4f}'.format(
            i, classifier_loss.item(), adv_loss.item()))

        # Cut the autograd history so the memories do not chain graphs across
        # iterations.
        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model
def train_irm_feat(args):
    """Refine the model on source labels plus weighted target pseudo-labels.

    The target training set is read from ``args.save_path`` with
    ``pseudo=True`` and yields ``(input, pseudo_label, weight)`` batches; the
    test set is read from ``args.target_list``. Trains ``network.ResNet`` with
    SGD for ``args.max_iter`` iterations, evaluates every
    ``args.test_interval`` iterations and saves the best model so far every
    ``args.snapshot_interval`` iterations to ``snapshot/save/best_model.pk``.

    Args:
        args: Namespace providing the attributes read below (data lists,
            ``save_path``, batch size, ``lr_refine``, ``radius_refine``,
            ``irm_weight``, ``gpu_id``, ``log_file`` ...).

    Returns:
        Tuple ``(best_acc, best_model)``.
    """

    def _read_lines(path):
        # Close the list file right away instead of leaking the handle.
        with open(path) as handle:
            return handle.readlines()

    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(_read_lines(args.source_list), transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(_read_lines(args.save_path),
                                transform=image_train(),
                                pseudo=True)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["test"] = ImageList(_read_lines(args.target_list), transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size,
                                      shuffle=False, num_workers=4)

    # model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius_refine,
                           trainable_radius=args.trainable_radius).cuda()
    parameter_classifier = model.get_parameters()
    optimizer_classifier = torch.optim.SGD(parameter_classifier,
                                           lr=args.lr_refine,
                                           momentum=0.9,
                                           weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        # NOTE(review): ``model.forward_mul`` is called directly below; confirm
        # it is reachable through the DataParallel wrapper.
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    # train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\n iter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            os.makedirs('snapshot/save', exist_ok=True)
            torch.save(best_model, 'snapshot/save/best_model.pk')

        model.train(True)
        if args.lr_decay_refine:
            optimizer_classifier = lr_schedule.inv_lr_scheduler(optimizer_classifier, i)

        # Restart a loader's iterator each time it has been fully consumed.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on every DataLoader version; the ``.next()``
        # method is a Python 2 leftover that newer iterators do not provide.
        inputs_source, labels_source = next(iter_source)
        inputs_target, pseudo_labels_target, weights = next(iter_target)
        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        inputs_target = inputs_target.cuda()
        pseudo_labels_target = pseudo_labels_target.cuda()
        weights = weights.type(torch.Tensor).cuda()

        # Dummy scalar multipliers whose gradients feed the penalty terms.
        scale_source = torch.tensor(1.).cuda().requires_grad_()
        scale_target = torch.tensor(1.).cuda().requires_grad_()
        features_source, outputs_source = model.forward_mul(inputs_source, scale_source)
        features_target, outputs_target = model.forward_mul(inputs_target, scale_target)

        source_class_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        target_robust_loss = utils.robust_pseudo_loss(outputs_target,
                                                      pseudo_labels_target, weights)
        classifier_loss = source_class_loss + target_robust_loss

        # A semantic-matching path (utils.SM with Cs/Ct memories, weighted by
        # network.calc_coeff(i, max_iter=2000) for the 'MSTN' baseline) used to
        # live here but was commented out in the source, so it is not built.

        source_irm_loss = penalty_loss_scale(source_class_loss, scale_source)
        target_irm_loss = penalty_loss_scale(target_robust_loss, scale_target)
        irm_loss = source_irm_loss + target_irm_loss
        feature_loss = classifier_loss + args.irm_weight * irm_loss

        optimizer_classifier.zero_grad()
        # NOTE(review): ``feature_loss`` (which includes the penalty) is built
        # above but the backward pass runs on ``classifier_loss`` only, so the
        # penalty never influences the gradients. Kept as in the source because
        # changing it alters the training objective; confirm which was intended.
        classifier_loss.backward()
        optimizer_classifier.step()

        print('step:{: d},\t,source_class_loss:{:.4f},\t,target_robust_loss:{:.4f}'
              ''.format(i, source_class_loss.item(), target_robust_loss.item()))

    return best_acc, best_model
from utils import *
from graphics import Graphics
from state import State
from torch.autograd import Variable
import network
import torch

# Module-level search state shared by the functions in this file.
current = State(None, None)
net = network.ResNet()
focusMoves = []
focus = 0


def MCTS(root):
    """Run one search step from ``root`` and return the value change applied to it.

    If ``root.Expand()`` reports a new child, the last child is evaluated with
    ``net`` and its value is stored on that child; otherwise the search
    recurses into ``root.BestChild()``. In both cases the sign is flipped
    before the value is added to ``root.v``, and ``root.n`` is incremented.

    Args:
        root: Node exposing ``Expand()``, ``BestChild()``, ``child``, ``v``
            and ``n``.

    Returns:
        The delta that was added to ``root.v`` (``-1`` when there is no best
        child).
    """
    if root.Expand():
        leaf = root.child[-1]
        data = torch.FloatTensor(StateToImg(leaf)).unsqueeze(0)
        # Convert to a plain float first: on current PyTorch the indexed
        # result is a 0-dim tensor, and negating it in place would also flip
        # the value just stored on the child (both names alias one tensor).
        value = float(net(Variable(data)).data[0, 0])
        leaf.v = value
        delta = -value
    else:
        best = root.BestChild()
        delta = -1 if best is None else -MCTS(best)
    root.v += delta
    root.n += 1
    return delta
def __initialize_model(self):
    """Create ``self.model`` as a ``network.ResNet`` and move it to the GPU when ``self.cuda`` is set."""
    self.model = network.ResNet()
    if not self.cuda:
        return
    self.model = self.model.cuda()
def train(args):
    """Train a classifier on ``args.train_list`` with periodic validation.

    Runs 1000 epochs over the training loader. Whenever
    ``it % args.test_interval == 1`` the model is validated on
    ``args.valid_list``, the accuracy is logged and a checkpoint is handed to
    the logger. Every iteration performs one SGD step on the cross-entropy
    loss with the learning rate set by ``lr_schedule.inv_lr_scheduler``.

    Args:
        args: Namespace providing the attributes read below (data lists,
            ``net``, ``opt_type``, ``lr``, ``momentum``, ``gpu_id``,
            ``test_interval``, ``log`` ...).

    Raises:
        ValueError: If ``args.opt_type`` is not ``'SGD'``. Previously the
            optimizer was simply left undefined and training crashed with a
            NameError at the first step.
    """
    if args.opt_type != 'SGD':
        raise ValueError("unsupported opt_type: {!r}".format(args.opt_type))

    def _read_lines(path):
        # Close the list file right away instead of leaking the handle.
        with open(path) as handle:
            return handle.readlines()

    ## init logger
    logger = Logger(ckpt_path=args.ckpt_path, tsbd_path=args.vis_path)

    ## pre process
    train_transforms = prep.image_train(augmentation=args.augmentation)
    valid_transforms = prep.image_test()
    train_dset = ImageList(_read_lines(args.train_list),
                           datadir='',
                           transform=train_transforms)
    valid_dset = ImageList(_read_lines(args.valid_list),
                           datadir='',
                           transform=valid_transforms)

    train_loader = DataLoader(train_dset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, drop_last=False)
    valid_loader = DataLoader(valid_dset, batch_size=args.batch_size, shuffle=False,
                              num_workers=4, drop_last=False)

    ## set the model
    if args.net == 'MyOwn':
        ## model C
        net = network.ResNet(network.Bottleneck, [3, 4, 6, 3],
                             weight_init=args.weight_init,
                             use_bottleneck=args.bottleneck,
                             num_classes=args.class_num,
                             weight=args.weight)
    else:
        ## model A -> Resnet50 == pretrained
        ## model B -> Resnet50 == not pretrained
        net = network.ResNetFc(resnet_name=args.net,
                               pretrained=args.pretrained,
                               weight_init=args.weight_init,
                               use_bottleneck=args.bottleneck,
                               new_cls=True,
                               class_num=args.class_num)

    net = net.cuda()
    parameter_list = net.get_parameters()

    ## set optimizer and learning scheduler
    # lr=1.0 is a placeholder; the scheduler is called before every step with
    # the real base rate in ``lr_param``.
    optimizer = optim.SGD(parameter_list, lr=1.0, momentum=args.momentum,
                          weight_decay=0.0005, nesterov=True)
    lr_param = {'lr': args.lr, "gamma": 0.001, "power": 0.75}
    lr_scheduler = lr_schedule.inv_lr_scheduler

    ## gpu
    gpus = args.gpu_id.split(',')
    # str.split always returns at least one element, so the model is always
    # wrapped in DataParallel (checkpoints carry the wrapper's state_dict).
    if len(gpus) > 0:
        device_ids = [int(i) for i in gpus]
        print('gpus: ', device_ids)
        net = nn.DataParallel(net, device_ids=device_ids)

    ## for save model
    model = {}
    model['net'] = net

    ## log
    logger.reset()

    ## progress bar
    total_epochs = 1000
    total_progress_bar = tqdm.tqdm(desc='Train iter',
                                   total=total_epochs * len(train_loader))

    ## begin train
    it = 0
    for epoch in range(total_epochs):
        for img, label, _ in train_loader:
            ## update log
            it += 1
            logger.step(1)
            total_progress_bar.update(1)

            ## validate
            # NOTE(review): with test_interval == 1 this condition is never
            # true (it % 1 is always 0); confirm that is acceptable.
            if it % args.test_interval == 1:
                acc = validate(model, valid_loader)

                ## utils
                logger.add_scalar('accuracy', acc * 100)
                logger.save_ckpt(state={'net': net.state_dict()}, cur_metric_val=acc)
                log_str = "iter: {:05d}, precision: {:.5f}".format(it, acc)
                args.log.write(log_str + '\n')
                args.log.flush()

            ## train the model
            net.train(True)
            optimizer = lr_scheduler(optimizer, it, **lr_param)
            optimizer.zero_grad()

            ## cuda
            img = img.cuda()
            label = label.cuda()

            feature, output = net(img)
            loss = nn.CrossEntropyLoss()(output, label)
            loss.backward()
            optimizer.step()

            ## vis
            # NOTE(review): the loss tensor itself is passed to the logger;
            # confirm Logger.add_scalar does not keep the autograd graph alive.
            logger.add_scalar('loss', loss)

    total_progress_bar.close()
def train_distill(teacher, args):
    """Train a student model against a fixed teacher with an additional mosaic loss.

    Each iteration performs two optimizer steps:

    1. a step on source cross-entropy + distillation from the teacher's
       softened outputs (temperature 4) + adversarial loss + semantic-matching
       loss + an entropy term;
    2. a step on the mosaic loss computed by ``utils.NTXentLoss`` between the
       queued target predictions and the outputs of the two mosaic views.

    Both student and teacher are evaluated every ``args.test_interval``
    iterations; only the student's best copy is returned.

    Args:
        teacher: Model evaluated without gradients to provide target outputs;
            it is kept in evaluation mode.
        args: Namespace providing the attributes read below (data lists,
            batch size, ``lr``, ``max_iter``, ``baseline``, ``gpu_id``,
            ``log_file`` ...).

    Returns:
        Tuple ``(best_acc, best_model)`` for the student.

    Raises:
        ValueError: If ``args.baseline`` is neither ``'MSTN'`` nor ``'DANN'``.
            Previously this surfaced as a NameError on ``lam``.
    """
    if args.baseline not in ('MSTN', 'DANN'):
        raise ValueError("unsupported baseline: {!r}".format(args.baseline))

    def _read_lines(path):
        # Close the list file right away instead of leaking the handle.
        with open(path) as handle:
            return handle.readlines()

    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(_read_lines(args.source_list), transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=2, drop_last=True)
    dsets["target"] = ImageList(_read_lines(args.target_list),
                                transform=image_train(), params=args)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                        shuffle=True, num_workers=2, drop_last=True)
    dsets["test"] = ImageList(_read_lines(args.target_list), transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size,
                                      shuffle=False, num_workers=2)

    # model
    model = network.ResNet(class_num=args.num_class).cuda()
    adv_net = network.AdversarialNetwork(in_feature=model.output_num(),
                                         hidden_size=1024,
                                         max_iter=args.max_iter).cuda()
    parameter_list = model.get_parameters() + adv_net.get_parameters()
    optimizer = torch.optim.SGD(parameter_list, lr=args.lr, momentum=0.9, weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        device_ids = [int(i) for i in gpus]
        adv_net = nn.DataParallel(adv_net, device_ids=device_ids)
        model = nn.DataParallel(model, device_ids=device_ids)

    # train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)
    # Clamp to 1 so a test_interval below 10 does not cause a modulo-by-zero
    # in the logging check.
    print_interval = max(1, args.test_interval // 10)
    nt_cent = utils.NTXentLoss('cuda', args.batch_size, 0.2, True)

    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    # Queues holding the last ``max_batch`` batches of target features (index
    # 0) and softmax outputs (index 1), plus their pseudo-labels.
    max_batch = 100
    queue_size = args.batch_size * max_batch
    queue_data = [torch.randn(queue_size, 256).cuda(),
                  torch.randn(queue_size, args.num_class).cuda()]
    queue_data_w = [torch.randn(queue_size, 256).cuda(),
                    torch.randn(queue_size, args.num_class).cuda()]
    # Labels start at num_class + 1, a value no pseudo-label (argmax index)
    # can take.
    queue_labels = [torch.ones(queue_size).cuda() * (args.num_class + 1),
                    torch.ones(queue_size).cuda() * (args.num_class + 1)]
    queue_ptr = torch.zeros(1, dtype=torch.long)
    queue_weight = np.power(np.linspace(.0, 1.0, max_batch), 3)
    queue_weight = np.repeat(queue_weight, args.batch_size)

    best_ema_acc = 0.0
    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)

            temp_acc = image_classification_test(dset_loaders, teacher)
            if temp_acc > best_ema_acc:
                best_ema_acc = temp_acc
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_ema_acc:{:.4f}".format(
                i, temp_acc, best_ema_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)

        model.train(True)
        adv_net.train(True)
        teacher.train(False)
        optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        # Restart a loader's iterator each time it has been fully consumed.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Built-in next() works on every DataLoader version; the ``.next()``
        # method is a Python 2 leftover that newer iterators do not provide.
        inputs_source, labels_source = next(iter_source)
        (inputs_target, _, inputs_target_mosaic_w,
         inputs_target_mosaic_s, _) = next(iter_target)
        inputs_source = inputs_source.cuda()
        inputs_target = inputs_target.cuda()
        labels_source = labels_source.cuda()
        inputs_target_mosaic_w = inputs_target_mosaic_w.cuda()
        inputs_target_mosaic_s = inputs_target_mosaic_s.cuda()

        # ---- step 1: classification / distillation / alignment ----
        features_source, outputs_source = model(inputs_source)
        features_target, outputs_target = model(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        with torch.no_grad():
            _, outputs_target_teacher = teacher(inputs_target)

        adv_loss = utils.loss_adv(features, adv_net)
        H = torch.mean(utils.Entropy(F.softmax(outputs_target, dim=1)))

        lam = network.calc_coeff(i) if args.baseline == 'MSTN' else 0.0

        prob_max, pseu_labels_target = torch.max(F.softmax(outputs_target, dim=1), dim=1)
        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source, features_target,
                                                 labels_source, pseu_labels_target,
                                                 Cs_memory, Ct_memory)
        distill_loss = 4 * utils.cross_entropy_with_logits(
            outputs_target / 4.0, F.softmax(outputs_target_teacher / 4.0, dim=1))
        classifier_loss = distill_loss + nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = (classifier_loss + lam * loss_sm + adv_loss
                      + network.calc_coeff((i - 100), high=0.1, max_iter=100) * H)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # ---- step 2: mosaic loss ----
        optimizer.zero_grad()
        # Pseudo-labels come from the teacher alone for the first 2/5 of
        # training, then from an even blend of student and teacher outputs.
        alpha = 0.0 if i < args.max_iter // 5 * 2 else 0.5

        # NOTE(review): the source was whitespace-collapsed; the extent of the
        # two ``with`` blocks below is a best-guess reconstruction. In
        # particular, confirm whether the no_grad forward pass is meant to run
        # inside ``_disable_tracking_bn_stats``.
        with _disable_tracking_bn_stats(model):
            mosaic_features_target_w, mosaic_outputs_target_w = model(inputs_target_mosaic_w)
            mosaic_features_target_s, mosaic_outputs_target_s = model(inputs_target_mosaic_s)
            with torch.no_grad():
                features_list_w = [mosaic_features_target_w,
                                   F.softmax(mosaic_outputs_target_w, dim=1)]
                features_target_, outputs_target_ = model(inputs_target)
                outputs_target = (alpha * outputs_target_
                                  + (1. - alpha) * outputs_target_teacher)
                prob_max, pseu_labels_target = torch.max(
                    F.softmax(outputs_target, dim=1), dim=1)
                features_list = [features_target_, F.softmax(outputs_target, dim=1)]
                labels_list = [pseu_labels_target, pseu_labels_target]

                # Write the current batch into its slot of each queue.
                utils.rightshift(queue_weight, args.batch_size)
                for j in range(len(features_list)):
                    queue_data[j][queue_ptr:queue_ptr + args.batch_size, :] = features_list[j]
                    queue_data_w[j][queue_ptr:queue_ptr + args.batch_size, :] = features_list_w[j]
                    queue_labels[j][queue_ptr:queue_ptr + args.batch_size] = labels_list[j]
                pre_ptr = int(queue_ptr)
                queue_ptr[0] = ((i + 1) % max_batch) * args.batch_size

        mosaic_loss_target = (
            nt_cent(queue_data[1].detach(), F.softmax(mosaic_outputs_target_w, dim=1),
                    queue_labels[1], pseu_labels_target.float(), queue_weight, pre_ptr,
                    class_level=False)
            + 1. * nt_cent(queue_data_w[1].detach(), F.softmax(mosaic_outputs_target_s, dim=1),
                           queue_labels[1], pseu_labels_target.float(), queue_weight, pre_ptr,
                           class_level=False)
        ) * network.calc_coeff(i, high=0.3, max_iter=50)
        mosaic_loss = mosaic_loss_target * 1.0

        mosaic_loss.backward()
        optimizer.step()

        if i % print_interval == 0:
            log_str = 'step:{: d},\t,class_loss:{:.4f},\t,adv_loss:{:.4f}\t,mosaic_loss:{:.4f}\t,mean_prob:{:.4f}'.format(
                i, classifier_loss.item(), adv_loss.item(),
                mosaic_loss_target.item(), prob_max.mean().item())
            print(log_str)
            args.log_file.write('\n' + log_str)
            args.log_file.flush()

        # Cut the autograd history so the memories do not chain graphs across
        # iterations.
        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model