Example 1
def create_model(pretrained=True, architecture="resnet34", is_train=True):

    if architecture == "resnet18":
        if pretrained:
            model = torchvision.models.resnet18(pretrained=pretrained)
            model.fc = torch.nn.Linear(512 * 1, NUM_CATEGORY)
        else:
            model = torchvision.models.resnet18(pretrained=pretrained, num_classes=NUM_CATEGORY)
        model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "resnet34":
        if pretrained:
            model = torchvision.models.resnet34(pretrained=pretrained)
            model.fc = torch.nn.Linear(512 * 1, NUM_CATEGORY)
        else:
            model = torchvision.models.resnet34(pretrained=pretrained, num_classes=NUM_CATEGORY)
        model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "resnet50":
        if pretrained:
            model = torchvision.models.resnet50(pretrained=pretrained)
            model.fc = torch.nn.Linear(512 * 4, NUM_CATEGORY)
        else:
            model = torchvision.models.resnet50(pretrained=pretrained, num_classes=NUM_CATEGORY)
        model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "mobilenetv2":
        model = MobileNetV2.MobileNetV2(n_class=NUM_CATEGORY, input_size=IMG_HEIGHT)
    elif architecture == "se_resnet50":
        if pretrained:
            model = great_networks.se_resnet50(pretrained=pretrained)
            model.last_linear = torch.nn.Linear(512 * 4, NUM_CATEGORY)
        else:
            model = great_networks.se_resnet50(num_classes=NUM_CATEGORY, pretrained=None)
        model.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "se_resnext50":
        if pretrained:
            model = great_networks.se_resnext50_32x4d(pretrained=pretrained)
            model.last_linear = torch.nn.Linear(512 * 4, NUM_CATEGORY)
        else:
            model = great_networks.se_resnext50_32x4d(num_classes=NUM_CATEGORY, pretrained=None)
        model.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
    elif architecture == "original":
        model = network.ResNet(network.BasicBlock, [3, 4, 6, 3], num_classes=NUM_CATEGORY)
    else:
        raise ValueError('unknown architecture: {}'.format(architecture))

    model.to(DEVICE)

    if is_train:
        model.train()
    else:
        model.eval()

    return model
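
A minimal usage sketch for the factory above, assuming NUM_CATEGORY and DEVICE are defined at module level as in the snippet; the input size is a placeholder:

import torch

# Hypothetical usage; DEVICE and NUM_CATEGORY come from the surrounding module.
model = create_model(pretrained=True, architecture="resnet50", is_train=False)
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224).to(DEVICE))  # shape: (1, NUM_CATEGORY)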
Example 2
def train_init_irm(args):
    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(open(args.source_list).readlines(), \
                                transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(args.target_list).readlines(), \
                                transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)

    dsets["test"] = ImageList(open(args.target_list).readlines(), \
                              transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size, \
                                      shuffle=False, num_workers=4)

    #model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius,
                           trainable_radius=args.trainable_radius).cuda()
    parameter_list = model.get_parameters()
    #pdb.set_trace()
    optimizer = torch.optim.SGD(parameter_list,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        # No adversarial net is used in this routine (cf. train_init below),
        # so only the classifier model is wrapped for multi-GPU.
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            if not os.path.exists('snapshot'):
                os.mkdir('snapshot')
            if not os.path.exists('snapshot/save'):
                os.mkdir('snapshot/save')
            torch.save(best_model, 'snapshot/save/initial_model.pk')

        model.train(True)
        if (args.lr_decay):
            optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = \
            inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()

        if args.irm_type == 'batch':
            scale_source = torch.tensor(1.).cuda().requires_grad_()
            scale_target = torch.tensor(1.).cuda().requires_grad_()
        elif args.irm_type == 'sample':
            scale_source = torch.ones(inputs_source.size(0), 1).cuda().requires_grad_()
            scale_target = torch.ones(inputs_target.size(0), 1).cuda().requires_grad_()
        else:
            raise ValueError('unknown irm_type: {}'.format(args.irm_type))

        if (args.irm_feature == 'last_hidden'):
            features_source, outputs_source = model.forward_mul(
                inputs_source, scale_source)
            features_target, outputs_target = model.forward_mul(
                inputs_target, scale_target)
        elif (args.irm_feature == 'logit'):
            features_source, outputs_source = model(inputs_source)
            features_target, outputs_target = model(inputs_target)
            outputs_source = outputs_source * scale_source
            outputs_target = outputs_target * scale_target

        features = torch.cat((features_source, features_target), dim=0)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)

        if args.baseline == 'MSTN':
            lam = network.calc_coeff(i)
        elif args.baseline == 'DANN':
            lam = 0.0
        pseu_labels_target = torch.argmax(outputs_target, dim=1)

        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source,
                                                 features_target,
                                                 labels_source,
                                                 pseu_labels_target, Cs_memory,
                                                 Ct_memory)
        total_loss = classifier_loss + lam * loss_sm

        irm_loss = 0
        if (i > args.irm_warmup_step):
            if ('MSTN' in args.init_method):
                irm_loss += sum(
                    penalty_loss_scales(total_loss,
                                        [scale_source, scale_target]))
            else:
                source_irm_loss = sum(
                    penalty_loss_scales(classifier_loss, [scale_source]))
                irm_loss += source_irm_loss

            if ('target' in args.init_method):
                classifier_loss_target = nn.CrossEntropyLoss()(
                    outputs_target, pseu_labels_target.detach())
                target_irm_loss = sum(
                    penalty_loss_scales(classifier_loss_target,
                                        [scale_target]))
                irm_loss += target_irm_loss

        total_loss += args.irm_weight * irm_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if args.trainable_radius:
            print('step:{:d},\tclass_loss:{:.4f},\tirm_loss:{:.4f},\tradius:{:.4f}'.format(
                i, classifier_loss.item(), float(irm_loss), float(model.radius)))
        else:
            print('step:{:d},\tclass_loss:{:.4f},\tirm_loss:{:.4f}'.format(
                i, classifier_loss.item(), float(irm_loss)))
        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model
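
penalty_loss_scales is defined elsewhere in the original project. For orientation, a minimal sketch of the IRMv1-style penalty it appears to compute (the squared gradient of the risk with respect to each dummy scale fixed at 1.0); this is an assumption, not the project's actual implementation:

import torch

def penalty_loss_scales(loss, scales):
    # Sketch: one squared-gradient penalty per dummy scale; create_graph=True
    # keeps the penalty differentiable so it can be added to the training loss.
    grads = torch.autograd.grad(loss, scales, create_graph=True)
    return [torch.sum(g ** 2) for g in grads]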
Example 3
def train_init(args):
    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(open(args.source_list).readlines(), \
                                transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(args.target_list).readlines(), \
                                transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)

    dsets["test"] = ImageList(open(args.target_list).readlines(), \
                              transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size, \
                                      shuffle=False, num_workers=4)

    #model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius,
                           trainable_radius=args.trainable_radius).cuda()
    adv_net = network.AdversarialNetwork(in_feature=model.output_num(),
                                         hidden_size=1024).cuda()
    parameter_list = model.get_parameters() + adv_net.get_parameters()
    optimizer = torch.optim.SGD(parameter_list,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        adv_net = nn.DataParallel(adv_net, device_ids=[int(i) for i in gpus])
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            if not os.path.exists('snapshot'):
                os.mkdir('snapshot')
            if not os.path.exists('snapshot/save'):
                os.mkdir('snapshot/save')
            torch.save(best_model, 'snapshot/save/initial_model.pk')

        model.train(True)
        adv_net.train(True)
        if (args.lr_decay):
            optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = \
            inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = model(inputs_source)
        features_target, outputs_target = model(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        adv_loss = utils.loss_adv(features, adv_net)

        if args.baseline == 'MSTN':
            lam = network.calc_coeff(i)
        elif args.baseline == 'DANN':
            lam = 0.0
        pseu_labels_target = torch.argmax(outputs_target, dim=1)
        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source,
                                                 features_target,
                                                 labels_source,
                                                 pseu_labels_target, Cs_memory,
                                                 Ct_memory)
        total_loss = classifier_loss + adv_loss + lam * loss_sm
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print('step:{:d},\tclass_loss:{:.4f},\tadv_loss:{:.4f}'.format(
            i, classifier_loss.item(), adv_loss.item()))
        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model
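
utils.loss_adv is also project code. A common DANN-style reading, sketched under the assumption that adv_net applies gradient reversal internally and ends in a sigmoid, with source samples in the first half of the concatenated batch:

import torch
import torch.nn as nn

def loss_adv(features, adv_net):
    # Sketch of a binary domain-confusion loss (assumed, not the project's code).
    domain_pred = adv_net(features)
    half = features.size(0) // 2
    domain_label = torch.cat([torch.ones(half, 1),
                              torch.zeros(half, 1)]).to(features.device)
    return nn.BCELoss()(domain_pred, domain_label)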
Example 4
def train_irm_feat(args):
    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(open(args.source_list).readlines(), \
                                transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(args.save_path).readlines(),
                                transform=image_train(),
                                pseudo=True)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=4, drop_last=True)

    dsets["test"] = ImageList(open(args.target_list).readlines(), \
                              transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size, \
                                      shuffle=False, num_workers=4)

    #model
    model = network.ResNet(class_num=args.num_class,
                           radius=args.radius_refine,
                           trainable_radius=args.trainable_radius).cuda()
    parameter_classifier = model.get_parameters()
    optimizer_classifier = torch.optim.SGD(parameter_classifier,
                                           lr=args.lr_refine,
                                           momentum=0.9,
                                           weight_decay=0.005)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)

    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\n iter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(
                i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        if i % args.snapshot_interval == args.snapshot_interval - 1:
            if not os.path.exists('snapshot'):
                os.mkdir('snapshot')
            if not os.path.exists('snapshot/save'):
                os.mkdir('snapshot/save')
            torch.save(best_model, 'snapshot/save/best_model.pk')

        model.train(True)
        if (args.lr_decay_refine):
            optimizer_classifier = lr_schedule.inv_lr_scheduler(
                optimizer_classifier, i)

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, pseudo_labels_target, weights = next(iter_target)
        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        inputs_target, pseudo_labels_target = \
            inputs_target.cuda(), pseudo_labels_target.cuda()
        weights = weights.float().cuda()

        scale_source = torch.tensor(1.).cuda().requires_grad_()
        scale_target = torch.tensor(1.).cuda().requires_grad_()

        features_source, outputs_source = model.forward_mul(
            inputs_source, scale_source)
        features_target, outputs_target = model.forward_mul(
            inputs_target, scale_target)

        features = torch.cat((features_source, features_target), dim=0)

        source_class_loss = nn.CrossEntropyLoss()(outputs_source,
                                                  labels_source)
        target_robust_loss = utils.robust_pseudo_loss(outputs_target,
                                                      pseudo_labels_target,
                                                      weights)
        classifier_loss = source_class_loss + target_robust_loss

        if args.baseline == 'MSTN':
            lam = network.calc_coeff(i, max_iter=2000)
        elif args.baseline == 'DANN':
            lam = 0.0
        # pseu_labels_target = torch.argmax(outputs_target, dim=1)
        # loss_sm, Cs_memory, Ct_memory = utils.SM(features_source, features_target, labels_source, pseu_labels_target,
        #                                         Cs_memory, Ct_memory)
        # feature_loss = classifier_loss + lam*loss_sm + lam*H

        # irm_loss = 0
        # if('MSTN' in args.irm_feature):
        #     if(i>args.irm_warmup_step):
        #         irm_loss += sum(penalty_loss_scales(feature_loss, [scale_source, scale_target]))
        # else:
        source_irm_loss = penalty_loss_scale(source_class_loss, scale_source)
        target_irm_loss = penalty_loss_scale(target_robust_loss, scale_target)
        irm_loss = (source_irm_loss + target_irm_loss)

        feature_loss = classifier_loss + args.irm_weight * irm_loss

        optimizer_classifier.zero_grad()
        # Backpropagate the combined objective; calling backward() on
        # classifier_loss alone would silently drop the IRM penalty above.
        feature_loss.backward()
        optimizer_classifier.step()

        print('step:{:d},\tsource_class_loss:{:.4f},\ttarget_robust_loss:{:.4f}'.format(
            i, source_class_loss.item(), target_robust_loss.item()))

        #Cs_memory.detach_()
        #Ct_memory.detach_()

    return best_acc, best_model
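
utils.robust_pseudo_loss is not shown either. Given that the target loader yields a per-sample weight next to each pseudo-label, one plausible reading is a weight-modulated cross-entropy (an assumption, shown only for orientation):

import torch
import torch.nn.functional as F

def robust_pseudo_loss(outputs, pseudo_labels, weights):
    # Sketch: per-sample cross-entropy against pseudo-labels, down-weighted
    # by the confidence weights attached at pseudo-labeling time.
    per_sample = F.cross_entropy(outputs, pseudo_labels, reduction='none')
    return torch.mean(weights * per_sample)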
Example 5
from utils import *
from graphics import Graphics
from state import State
from torch.autograd import Variable
import network
import torch

current = State(None, None)
net = network.ResNet()
focusMoves = []
focus = 0


def MCTS(root):
    global net

    if root.Expand():
        # Leaf: evaluate the newly expanded child with the network.
        data = torch.FloatTensor(StateToImg(root.child[-1])).unsqueeze(0)
        delta = net(Variable(data)).data[0, 0]
        root.child[-1].v = delta
        delta *= -1  # flip sign: the value was from the child's perspective
    else:
        best = root.BestChild()
        if best is None:
            delta = -1  # no child to descend into
        else:
            delta = -MCTS(best)  # recurse, negating between plies
    # Back up the result along the visited path.
    root.v += delta
    root.n += 1
    return delta
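
A short driver sketch for the recursive search above, assuming the State objects expose child lists and visit counts n as used inside MCTS; the simulation budget is a placeholder:

# Hypothetical usage: run a fixed number of simulations from the current
# position, then pick the most-visited child as the move to play.
for _ in range(800):
    MCTS(current)
best_child = max(current.child, key=lambda c: c.n)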
Example 6
def __initialize_model(self):
    self.model = network.ResNet()
    if self.cuda:
        self.model = self.model.cuda()
Example 7
def train(args):

    ## init logger
    logger = Logger(ckpt_path=args.ckpt_path, tsbd_path=args.vis_path)

    ## pre process
    train_transforms = prep.image_train(augmentation=args.augmentation)
    valid_transforms = prep.image_test()

    train_dset = ImageList(open(args.train_list).readlines(),
                           datadir='',
                           transform=train_transforms)
    valid_dset = ImageList(open(args.valid_list).readlines(),
                           datadir='',
                           transform=valid_transforms)

    train_loader = DataLoader(train_dset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4,
                              drop_last=False)
    valid_loader = DataLoader(valid_dset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=4,
                              drop_last=False)

    ## set the model
    net = None
    if args.net == 'MyOwn':
        ## model C
        net = network.ResNet(network.Bottleneck, [3, 4, 6, 3],
                             weight_init=args.weight_init,
                             use_bottleneck=args.bottleneck,
                             num_classes=args.class_num,
                             weight=args.weight)
    else:
        ## model A -> Resnet50 == pretrained
        ## model B -> Resnet50 == not pretrained
        net = network.ResNetFc(resnet_name=args.net,
                               pretrained=args.pretrained,
                               weight_init=args.weight_init,
                               use_bottleneck=args.bottleneck,
                               new_cls=True,
                               class_num=args.class_num)

    net = net.cuda()
    parameter_list = net.get_parameters()

    ## set optimizer and learning scheduler
    if args.opt_type == 'SGD':
        optimizer = optim.SGD(parameter_list,
                              lr=1.0,
                              momentum=args.momentum,
                              weight_decay=0.0005,
                              nesterov=True)
    else:
        raise ValueError('unsupported opt_type: {}'.format(args.opt_type))
    lr_param = {'lr': args.lr, "gamma": 0.001, "power": 0.75}
    lr_scheduler = lr_schedule.inv_lr_scheduler

    ## gpu
    gpus = args.gpu_id.split(',')
    if len(gpus) > 0:
        print('gpus: ', [int(i) for i in gpus])
        net = nn.DataParallel(net, device_ids=[int(i) for i in gpus])

    ## for save model
    model = {}
    model['net'] = net

    ## log
    logger.reset()

    ## progress bar
    total_epochs = 1000
    total_progress_bar = tqdm.tqdm(desc='Train iter',
                                   total=total_epochs * len(train_loader))

    ## begin train
    it = 0
    for epoch in range(total_epochs):
        for img, label, path in train_loader:
            ## update log
            it += 1
            logger.step(1)
            total_progress_bar.update(1)

            ## validate
            if it % args.test_interval == 1:
                ## validate
                acc = validate(model, valid_loader)

                ## utils
                logger.add_scalar('accuracy', acc * 100)
                logger.save_ckpt(state={'net': net.state_dict()},
                                 cur_metric_val=acc)
                log_str = "iter: {:05d}, precision: {:.5f}".format(it, acc)
                args.log.write(log_str + '\n')
                args.log.flush()

            ## train the model
            net.train(True)
            optimizer = lr_scheduler(optimizer, it, **lr_param)
            optimizer.zero_grad()

            ## cuda
            img = img.cuda()
            label = label.cuda()

            feature, output = net(img)

            loss = nn.CrossEntropyLoss()(output, label)

            loss.backward()
            optimizer.step()

            ## vis
            logger.add_scalar('loss', loss.item())
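
validate is defined elsewhere. A minimal top-1 accuracy pass consistent with the loaders above (which yield img, label, path triples) might look like this; the model dict layout follows the snippet, the rest is an assumption:

import torch

def validate(model, loader):
    # Sketch: top-1 accuracy of model['net'] over a validation loader.
    net = model['net']
    net.train(False)
    correct, total = 0, 0
    with torch.no_grad():
        for img, label, _ in loader:
            _, output = net(img.cuda())
            correct += (output.argmax(dim=1).cpu() == label).sum().item()
            total += label.size(0)
    return correct / total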
Example 8
def train_distill(teacher, args):
    # prepare data
    dsets = {}
    dset_loaders = {}
    dsets["source"] = ImageList(open(args.source_list).readlines(), \
                                transform=image_train())
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=2, drop_last=True)
    dsets["target"] = ImageList(open(args.target_list).readlines(), \
                                transform=image_train(),  params=args)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, \
                                        shuffle=True, num_workers=2, drop_last=True)

    dsets["test"] = ImageList(open(args.target_list).readlines(), \
                              transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=2 * args.batch_size, \
                                      shuffle=False, num_workers=2)

    #model
    model = network.ResNet(class_num=args.num_class).cuda()
    adv_net = network.AdversarialNetwork(in_feature=model.output_num(),
                                         hidden_size=1024,
                                         max_iter=args.max_iter).cuda()
    parameter_list = model.get_parameters() + adv_net.get_parameters()
    optimizer = torch.optim.SGD(parameter_list, lr=args.lr,
                                momentum=0.9, weight_decay=0.005)
    # model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    gpus = args.gpu_id.split(',')
    if len(gpus) > 1:
        adv_net = nn.DataParallel(adv_net, device_ids=[int(i) for i in gpus])
        model = nn.DataParallel(model, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(model)
    print_interval = (args.test_interval // 10)
    nt_cent = utils.NTXentLoss('cuda', args.batch_size, 0.2, True)
    
    Cs_memory = torch.zeros(args.num_class, 256).cuda()
    Ct_memory = torch.zeros(args.num_class, 256).cuda()

    max_batch = 100
    queue_size = args.batch_size * max_batch
    queue_data = [torch.randn(queue_size, 256).cuda(), torch.randn(queue_size, args.num_class).cuda()]
    queue_data_w = [torch.randn(queue_size, 256).cuda(), torch.randn(queue_size, args.num_class).cuda()]
    # queue_data = [torch.randn(queue_size, 256).cuda(), torch.randn(queue_size, 256).cuda()]

    queue_labels = [torch.ones(queue_size).cuda() * (args.num_class+1), torch.ones(queue_size).cuda() * (args.num_class+1)]
    queue_ptr = torch.zeros(1, dtype=torch.long)

    queue_weight = np.power(np.linspace(.0, 1.0, max_batch), 3)

    queue_weight = np.repeat(queue_weight, args.batch_size)


    best_ema_acc = 0.0
    for i in range(args.max_iter):
        if i % args.test_interval == args.test_interval - 1:
            model.train(False)
            temp_acc = image_classification_test(dset_loaders, model)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(i, temp_acc, best_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)

            
            temp_acc = image_classification_test(dset_loaders, teacher)
            if temp_acc > best_ema_acc:
                best_ema_acc = temp_acc
                # best_model = copy.deepcopy(model)
            log_str = "\niter: {:05d}, \t precision: {:.4f},\t best_ema_acc:{:.4f}".format(i, temp_acc, best_ema_acc)
            args.log_file.write(log_str)
            args.log_file.flush()
            print(log_str)
        # if i % args.snapshot_interval == args.snapshot_interval -1:
        #     if not os.path.exists(args.save_dir):
        #         os.mkdir(args.save_dir)
        #     torch.save(best_model,os.path.join(args.save_dir, 'initial_model.pk'))

        model.train(True)
        adv_net.train(True)
        teacher.train(False)
        optimizer = lr_schedule.inv_lr_scheduler(optimizer, i)

        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, _, inputs_target_mosaic_w, inputs_target_mosaic_s, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = \
            inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        inputs_target_mosaic_w, inputs_target_mosaic_s = \
            inputs_target_mosaic_w.cuda(), inputs_target_mosaic_s.cuda()
        features_source, outputs_source = model(inputs_source)
        features_target, outputs_target = model(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        with torch.no_grad():
            features_target_teacher, outputs_target_teacher = teacher(inputs_target)

        adv_loss = utils.loss_adv(features, adv_net)

        H = torch.mean(utils.Entropy(F.softmax(outputs_target, dim=1)))

        if args.baseline == 'MSTN':
            lam = network.calc_coeff(i)
        elif args.baseline == 'DANN':
            lam = 0.0
        prob_max, pseu_labels_target = torch.max(F.softmax(outputs_target, dim=1), dim=1)
        loss_sm, Cs_memory, Ct_memory = utils.SM(features_source, features_target, labels_source,
                                                 pseu_labels_target, Cs_memory, Ct_memory)

        # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        # Distillation: temperature-4 soft targets from the teacher plus the
        # supervised source cross-entropy.
        classifier_loss = (4 * utils.cross_entropy_with_logits(
            outputs_target / 4.0, F.softmax(outputs_target_teacher / 4.0, dim=1))
            + nn.CrossEntropyLoss()(outputs_source, labels_source))
        total_loss = (classifier_loss + lam * loss_sm + adv_loss
                      + network.calc_coeff(i - 100, high=0.1, max_iter=100) * H)

        optimizer.zero_grad()

        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        mosaic_loss_target = torch.zeros(1)

        if i < args.max_iter // 5 * 2:
            alpha = 0.0
        else:
            alpha = 0.5
        with _disable_tracking_bn_stats(model):
            mosaic_features_target_w, mosaic_outputs_target_w = model(inputs_target_mosaic_w)
            mosaic_features_target_s, mosaic_outputs_target_s = model(inputs_target_mosaic_s)
            with torch.no_grad():
                
                features_list_w = [mosaic_features_target_w, F.softmax(mosaic_outputs_target_w, dim=1)]

                features_target_, outputs_target_ = model(inputs_target)
                outputs_target = alpha * outputs_target_ + (1. - alpha) * outputs_target_teacher

                prob_max, pseu_labels_target = torch.max(F.softmax(outputs_target, dim=1), dim=1)
                features_list = [features_target_, F.softmax(outputs_target, dim=1)]
                labels_list = [pseu_labels_target, pseu_labels_target]

                utils.rightshift(queue_weight, args.batch_size)
                for j in range(len(features_list)):
                    queue_data[j][queue_ptr:queue_ptr+args.batch_size, :] = features_list[j]
                    queue_data_w[j][queue_ptr:queue_ptr+args.batch_size, :] = features_list_w[j]
                    queue_labels[j][queue_ptr:queue_ptr+args.batch_size] = labels_list[j]
                pre_ptr = int(queue_ptr)
                ptr = ((i+1) % max_batch) * args.batch_size
                queue_ptr[0] = ptr


            mosaic_loss_target = (nt_cent(queue_data[1].detach(), F.softmax(mosaic_outputs_target_w, dim=1), queue_labels[1], 
                pseu_labels_target.float(), queue_weight, pre_ptr, class_level=False) +
                                    1.*nt_cent(queue_data_w[1].detach(), F.softmax(mosaic_outputs_target_s, dim=1), queue_labels[1], 
                pseu_labels_target.float(), queue_weight, pre_ptr, class_level=False)) * network.calc_coeff(i, high=0.3, max_iter=50)
            mosaic_loss = mosaic_loss_target * 1.0

            # mosaic_loss = utils.cross_entropy_with_logits(mosaic_outputs_target, F.softmax(outputs_target*1.5, dim=1)) * (network.calc_coeff(i, high=0.5, max_iter=2000))
            # mosaic_loss += 0.4*(torch.abs(F.softmax(outputs_target, dim=1).detach() - F.softmax(mosaic_outputs_target, dim=1)).sum(1)).mean(0)

            mosaic_loss.backward()
            optimizer.step()


        if i % print_interval == 0:
            log_str = 'step:{:d},\tclass_loss:{:.4f},\tadv_loss:{:.4f},\tmosaic_loss:{:.4f},\tmean_prob:{:.4f}'.format(
                i, classifier_loss.item(), adv_loss.item(), mosaic_loss_target.item(), prob_max.mean().item())
            print(log_str)
            args.log_file.write('\n'+log_str)
            args.log_file.flush()

        Cs_memory.detach_()
        Ct_memory.detach_()

    return best_acc, best_model
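
The teacher passed into train_distill is kept in eval mode and never updated here; in mean-teacher style distillation it is typically maintained elsewhere as an exponential moving average of the student. A generic sketch of such an update (assumed, not taken from this project):

import torch

@torch.no_grad()
def update_ema_teacher(student, teacher, momentum=0.999):
    # Sketch: blend each teacher parameter toward the student parameter.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1.0 - momentum)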