Example #1
def main():

    global args
    best_prec1, best_epoch = 0.0, 0
    
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224
    
    print(args.arch)    
    model = getattr(models, args.arch)(args)
    args.num_exits = len(model.classifier)
    global n_flops

    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model
    
    print(args)
    with open('{}/args.txt'.format(args.save), 'w') as f:
        print(args, file=f)

    model = getattr(models, args.arch)(args)
    model = torch.nn.DataParallel(model.cuda())
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)

        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        elif args.evalmode == 'dynamic':
            dynamic_evaluate(model, test_loader, val_loader, args)
        else:
            validate(test_loader, model, criterion)
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
              '\tval_prec1\ttrain_prec5\tval_prec5']

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_prec1, train_prec5, lr = train(train_loader, model, criterion, optimizer, epoch)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_prec1, val_prec1, train_prec5, val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

        model_path = '%s/save_models/checkpoint_%03d.pth.tar' % (args.save, epoch-1)
        if os.path.exists(model_path):
            os.remove(model_path)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model
    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)

    return
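The save_checkpoint and load_checkpoint helpers are defined elsewhere in the repository and are not part of these excerpts. Below is a minimal sketch of what they might look like, inferred from the call sites above (the save_models subdirectory comes from the cleanup step in the training loop; the scores.tsv and model_best.pth.tar file names are assumptions):

import os
import shutil
import torch

def save_checkpoint(state, args, is_best, filename, result):
    # Sketch: write the score log, save the state dict, and keep a copy
    # of the best-performing checkpoint.
    model_dir = os.path.join(args.save, 'save_models')
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
        print('\n'.join(result), file=f)
    model_path = os.path.join(model_dir, filename)
    torch.save(state, model_path)
    if is_best:
        shutil.copyfile(model_path,
                        os.path.join(model_dir, 'model_best.pth.tar'))

def load_checkpoint(args):
    # Sketch: return the most recent checkpoint in args.save, or None.
    model_dir = os.path.join(args.save, 'save_models')
    if not os.path.isdir(model_dir):
        return None
    checkpoints = sorted(name for name in os.listdir(model_dir)
                         if name.startswith('checkpoint_'))
    if not checkpoints:
        return None
    return torch.load(os.path.join(model_dir, checkpoints[-1]))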
Example #2
def main(args):
    #######################################################################################
    ##   Note:
    ##   load the model
    #######################################################################################
    best_prec1, best_epoch = 0.0, 0
    model = getattr(models, args.arch)(args)
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model
    model = getattr(models, args.arch)(args)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    #######################################################################################
    ##   Note:
    ##   set up the criterion
    #######################################################################################
    criterion = nn.CrossEntropyLoss().cuda()

    #######################################################################################
    ##   Note:
    ##   set up the optimizer
    #######################################################################################
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    #######################################################################################
    ##   Note:
    ##   resume an interrupted training run
    #######################################################################################
    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    #######################################################################################
    ##   Note:
    ##   load the datasets
    #######################################################################################
    train_loader, val_loader, test_loader = get_dataloaders(args)

    #######################################################################################
    ##   Note:
    ##   select the inference mode (imagenet -- dynamic)
    #######################################################################################
    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)

        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = [
        'epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
        '\tval_prec1\ttrain_prec5\tval_prec5'
    ]

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_prec1, train_prec5, lr = train(
            train_loader, model, criterion, optimizer, epoch)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(
            ('{}\t{:.3f}' + '\t{:.4f}' * 6).format(epoch, lr, train_loss,
                                                   val_loss, train_prec1,
                                                   val_prec1, train_prec5,
                                                   val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model

    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)

    return
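Note that torch.nn.DataParallel registers the wrapped model under a module. attribute, so every key in the saved state_dict gains a "module." prefix; a checkpoint saved from a wrapped model will not load into an unwrapped one. Example #6 below strips the prefix by hand (key[7:]); the same idea as a small standalone helper:

def strip_module_prefix(state_dict):
    # Drop the 'module.' prefix added by torch.nn.DataParallel so the
    # weights can be loaded into an unwrapped model.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}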
Example #3
def main():

    global args
    best_prec1, best_epoch = 0.0, 0

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224

    model = getattr(models, args.arch)(args)
    # compute FLOPs and params from the model structure
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    # save the model's FLOPs
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model

    model = getattr(models, args.arch)(args)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # resume from a checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    # two evalmodes: anytime and dynamic
    if args.evalmode is not None:
        # args.evaluate_from is the path where the model is stored
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)
        # each mode is handled differently
        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return
    # training starts here
    scores = [
        'epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
        '\tval_prec1\ttrain_prec5\tval_prec5'
    ]

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_prec1, train_prec5, lr = train(
            train_loader, model, criterion, optimizer, epoch)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(
            ('{}\t{:.3f}' + '\t{:.4f}' * 6).format(epoch, lr, train_loss,
                                                   val_loss, train_prec1,
                                                   val_prec1, train_prec5,
                                                   val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model

    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)

    return
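The two evalmode settings reflect MSDNet's two inference regimes: "anytime" reports the accuracy of every exit on the test set, while "dynamic" (budgeted) evaluation lets each sample leave at the first exit whose prediction is confident enough, with thresholds tuned on the validation split. Here is a minimal sketch of the confidence-based early exit for a single-image batch, assuming the model returns one logit tensor per exit (the threshold value is illustrative, not the tuned one):

import torch
import torch.nn.functional as F

@torch.no_grad()
def early_exit_predict(model, x, threshold=0.9):
    # Walk the exits in order and stop at the first confident one.
    outputs = model(x)                # list of logits, one per exit
    for exit_idx, logits in enumerate(outputs):
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(dim=1)
        if conf.item() >= threshold:
            return pred.item(), exit_idx
    return pred.item(), exit_idx      # fall back to the last exit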
Example #4
    args.scale_list = '1-2-3-4'

    args.reduction = 0.5

    args.grFactor = list(map(int, args.grFactor.split('-')))
    args.bnFactor = list(map(int, args.bnFactor.split('-')))
    args.scale_list = list(map(int, args.scale_list.split('-')))
    args.nScales = len(args.grFactor)
    # print(args.grFactor)
    if args.use_valid:
        args.splits = ['train', 'val', 'test']
    else:
        args.splits = ['train', 'val']

    if args.data == 'cifar10':
        args.num_classes = 10
    elif args.data == 'cifar100':
        args.num_classes = 100
    else:
        args.num_classes = 1000

    inp_c = torch.rand(16, 3, 224, 224)

    model = MSDNet(args)
    # output = model(inp_c)
    # oup = net_head(inp_c)
    # print(len(oup))

    n_flops, n_params = measure_model(model, 224, 224)
    # net = _BlockNormal(num_layers = 4, nIn = 64, growth_rate = 24, reduction_rate = 0.5, trans_down = True)
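measure_model is defined elsewhere in the repository; counters of this kind are usually built from forward hooks that accumulate per-layer multiply-adds during one dummy forward pass. A rough sketch of the approach, counting only Conv2d and Linear layers and assuming the model lives on the CPU when measured (as in the excerpts, which call it before .cuda()); this is an illustration, not the repository's actual implementation:

import torch
import torch.nn as nn

def count_flops(model, h, w):
    # Register hooks, run one dummy forward pass, and sum the
    # multiply-adds recorded for Conv2d and Linear layers.
    flops = []

    def conv_hook(module, inputs, output):
        out_c, out_h, out_w = output.shape[1:]
        kh, kw = module.kernel_size
        flops.append(out_c * out_h * out_w * kh * kw *
                     module.in_channels // module.groups)

    def linear_hook(module, inputs, output):
        flops.append(module.in_features * module.out_features)

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
    with torch.no_grad():
        model(torch.randn(1, 3, h, w))
    for handle in handles:
        handle.remove()
    n_params = sum(p.numel() for p in model.parameters())
    return sum(flops), n_params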
Example #5
def main():

    global args
    best_prec1, best_epoch = 0.0, 0

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224

    model = getattr(models, args.arch)(args)
    global n_flops

    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)

    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model

    with open('{}/args.txt'.format(args.save), 'w') as f:
        print(args, file=f)  # open the file and write the args

    pred_model = getattr(models, args.arch)(args).cuda()
    pred_model = torch.nn.DataParallel(pred_model)
    state_dict = torch.load(args.pretrained)['state_dict']
    pred_model.load_state_dict(state_dict)
    print("=> loaded pretrained checkpoint '{}'".format(args.pretrained))

    model = getattr(models, args.arch)(args).cuda()
    data = torch.randn(2, 3, IM_SIZE, IM_SIZE).cuda()
    model.eval()
    with torch.no_grad():
        _, feat = model(data)

    trainable_list = nn.ModuleList([])
    trainable_list.append(model)
    pad_reg = nn.ModuleList([
        AD(feat[j].size(1), feat[-1].size(1)).cuda()
        for j in range(args.nBlocks)
    ])
    trainable_list.append(pad_reg)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(trainable_list.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
    else:
        model = torch.nn.DataParallel(model)

    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)

        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = [
        'epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
        '\tval_prec1\ttrain_prec5\tval_prec5'
    ]

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_prec1, train_prec5, lr = train(
            train_loader, model, pred_model, criterion, optimizer, epoch,
            pad_reg)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(
            ('{}\t{:.3f}' + '\t{:.4f}' * 6).format(epoch, lr, train_loss,
                                                   val_loss, train_prec1,
                                                   val_prec1, train_prec5,
                                                   val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({              #save_checkpoint(state, args, is_best, filename, result)
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

        model_path = '%s/save_models/checkpoint_%03d.pth.tar' % (args.save,
                                                                 epoch - 1)
        if os.path.exists(model_path):
            os.remove(model_path)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model

    print('********** {} Final prediction results **********'.format(
        time.strftime('%m%d_%H:%M:%S')))
    validate(test_loader, model, criterion)

    return
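The AD module used for pad_reg in Example #5 is not defined in the excerpt; its constructor takes the channel count of an early exit's feature map and that of the final one, which suggests a channel-matching adapter for distilling features against the pretrained pred_model. A hypothetical minimal version built on a 1x1 convolution:

import torch.nn as nn

class AD(nn.Module):
    # Hypothetical adapter: project an early-exit feature map to the
    # final exit's channel width so the two can be compared directly.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.proj(x)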
Example #6
def train(config):
    # set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    # prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"] * 2
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(
        data_config["source"]["list_path"]).readlines(),
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"],
                                        batch_size=int(train_bs / 2),
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)
    dsets["target"] = ImageList(open(
        data_config["target"]["list_path"]).readlines(),
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=int(train_bs / 2),
                                        shuffle=True,
                                        num_workers=4,
                                        drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [
                ImageList(open(data_config["test"]["list_path"]).readlines(),
                          transform=prep_dict["test"][i]) for i in range(10)
            ]
            dset_loaders["test"] = [
                DataLoader(dset,
                           batch_size=test_bs,
                           shuffle=False,
                           num_workers=4) for dset in dsets['test']
            ]
            dsets["validation"] = [
                ImageList(open(data_config["target"]["list_path"]).readlines(),
                          transform=prep_dict["test"][i]) for i in range(10)
            ]
            dset_loaders["validation"] = [
                DataLoader(dset,
                           batch_size=test_bs,
                           shuffle=False,
                           num_workers=4) for dset in dsets['validation']
            ]
    else:
        dsets["test"] = ImageList(open(
            data_config["test"]["list_path"]).readlines(),
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"],
                                          batch_size=test_bs,
                                          shuffle=False,
                                          num_workers=4)
        dsets["validation"] = ImageList(open(
            data_config["target"]["list_path"]).readlines(),
                                        transform=prep_dict["test"])
        dset_loaders["validation"] = DataLoader(dsets["validation"],
                                                batch_size=test_bs,
                                                shuffle=False,
                                                num_workers=4)

    prep_dict["detection"] = prep.image_test(**config["prep"]['params'])
    dsets["detection"] = ImageList(open(
        data_config["target"]["list_path"]).readlines(),
                                   transform=prep_dict["detection"])
    dset_loaders["detection"] = DataLoader(dsets["detection"],
                                           batch_size=test_bs,
                                           shuffle=False,
                                           num_workers=0)

    # set base network
    net_config = config["network"]

    base_network = msdnet.MSDNet(net_config)
    if net_config["pattern"] == "budget":
        IM_SIZE = 224
        n_flops, n_params = measure_model(base_network, IM_SIZE, IM_SIZE)
        torch.save(n_flops, os.path.join(config["output_path"], 'flops.pth'))
    base_network = base_network.cuda()
    state_dict = torch.load(net_config["preTrained"])['state_dict']
    state_dict_adapt = {}
    for key in state_dict.keys():
        if key[:17] == "module.classifier":
            pass
        else:
            state_dict_adapt[key[7:]] = state_dict[key]

    # set base_network
    base_network.load_state_dict(state_dict_adapt, strict=False)
    # set classifier
    classifier = network.GroupClassifiers(
        nblocks=base_network.get_nBlocks(),
        num_classes=config["network"]["params"]["class_num"],
        channel=base_network.output_num())
    classifier = classifier.cuda()

    # set adversarial
    ad_net = network.GroupAdversarialNetworks(
        nblocks=base_network.get_nBlocks(), channel=base_network.output_num())
    ad_net = ad_net.cuda()

    parameter_list = (base_network.get_parameters() +
                      classifier.get_parameters() + ad_net.get_parameters())

    # set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    # set self-training
    st_select = 0
    st_config = config["self-training"]
    has_pseudo = False

    # train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    source_iter = 0
    pseudo_iter = 0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            temp_acc = image_classification_test(
                dset_loaders,
                base_network,
                classifier,
                test_10crop=prep_config["test_10crop"])
            if config['dataset'] == 'visda':
                log_str = "iter: {:05d}".format(i)
                correct = [0.0 for i in range(base_network.get_nBlocks())]
                allsample = [0.0 for i in range(base_network.get_nBlocks())]
                meanclass = [0.0 for i in range(base_network.get_nBlocks())]
                for j in range(base_network.get_nBlocks()):
                    log_str += "\n {} classifier:".format(j)
                    for k in visda_classes:
                        correct[j] += temp_acc[j][k][0]
                        allsample[j] += temp_acc[j][k][1]
                        meanclass[j] += 100. * temp_acc[j][k][0] / temp_acc[j][k][1]
                        log_str += '\t{}: [{}/{}] ({:.6f}%)'.format(
                            k, temp_acc[j][k][0], temp_acc[j][k][1],
                            100. * temp_acc[j][k][0] / temp_acc[j][k][1])
                log_str += "\nall: "
                for j in range(base_network.get_nBlocks()):
                    log_str += "{:02d}-pre:{:.05f}".format(
                        j, 100. * correct[j] / allsample[j])
                log_str += "\ncls: "
                for j in range(base_network.get_nBlocks()):
                    log_str += "{:02d}-pre:{:.05f}".format(
                        j, meanclass[j] /
                        config["network"]["params"]["class_num"])
                config["out_file"].write(log_str + "\n")
                config["out_file"].flush()
            else:
                log_str = "iter: {:05d}".format(i)
                for j in range(base_network.get_nBlocks()):
                    log_str += " {:02d}-pre:{:.05f}".format(j, temp_acc[j])
                config["out_file"].write(log_str + "\n")
            print(log_str)

        if (i + 1) % config["snapshot_interval"] == 0 and config["save_model"]:
            torch.save(
                base_network,
                osp.join(config["output_path"],
                         "iter_{:05d}_model.pth.tar".format(i)))

        # dynamic evaluation
        if (i + 1) % 3000 == 0 and net_config["pattern"] == "budget":
            torch.multiprocessing.set_sharing_strategy('file_system')
            dynamic_evaluate(base_network, classifier, dset_loaders["test"],
                             dset_loaders["validation"], config,
                             'target-validation@' + str(i))
            torch.multiprocessing.set_sharing_strategy('file_descriptor')

        if (source_iter % len_train_source == 0) and i >= st_config["start"]:
            st_select += 1
            has_pseudo = True
            pseudo_path, len_train_pseudo, correct = selfDetection(
                dset_loaders, base_network, classifier, config, i)
            log_str = "size: {} correct:{}".format(len_train_pseudo, correct)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
            # set batch
            source_batchsize = int(
                int(train_bs / 2) * len_train_source /
                (len_train_source + len_train_pseudo))
            if source_batchsize == int(train_bs / 2):
                source_batchsize -= 1
            if source_batchsize < int(int(train_bs / 2) / 2):
                source_batchsize = int(int(train_bs / 2) / 2)
            pseudo_batchsize = int(train_bs / 2) - source_batchsize

            dsets["source"] = ImageList(open(
                data_config["source"]["list_path"]).readlines(),
                                        transform=prep_dict["source"])
            dset_loaders["source"] = DataLoader(dsets["source"],
                                                batch_size=source_batchsize,
                                                shuffle=True,
                                                num_workers=4,
                                                drop_last=True)
            dsets["target"] = ImageList(open(
                data_config["target"]["list_path"]).readlines(),
                                        transform=prep_dict["target"])
            dset_loaders["target"] = DataLoader(dsets["target"],
                                                batch_size=source_batchsize,
                                                shuffle=True,
                                                num_workers=4,
                                                drop_last=True)

            dsets["pseudo"] = ImageList(open(pseudo_path).readlines(),
                                        transform=prep_dict["target"])
            dset_loaders["pseudo"] = DataLoader(dsets["pseudo"],
                                                batch_size=pseudo_batchsize,
                                                shuffle=True,
                                                num_workers=4,
                                                drop_last=True)
            dsets["source_pseudo"] = ImageList(open(
                data_config["source"]["list_path"]).readlines(),
                                               transform=prep_dict["source"])
            dset_loaders["source_pseudo"] = DataLoader(
                dsets["source_pseudo"],
                batch_size=pseudo_batchsize,
                shuffle=True,
                num_workers=4,
                drop_last=True)

            len_train_source = len(dset_loaders["source"])
            len_train_target = len(dset_loaders["target"])
            len_train_pseudo = len(dset_loaders["pseudo"])
            len_train_source_pseudo = len(dset_loaders["source_pseudo"])

            source_iter = 0
            pseudo_iter = 0

            iter_source = iter(dset_loaders["source"])
            iter_target = iter(dset_loaders["target"])
            iter_pseudo = iter(dset_loaders["pseudo"])
            iter_source_pseudo = iter(dset_loaders["source_pseudo"])
            # set the self-training optimizer
            if st_config["is_lr"] and st_select == 0:
                param_lr = []
                optimizer = optimizer_config["type"](
                    parameter_list, **(st_config["optimizer"]["optim_params"]))
                for param_group in optimizer.param_groups:
                    param_lr.append(param_group["lr"])
                schedule_param = st_config["optimizer"]["lr_param"]
                lr_scheduler = lr_schedule.schedule_dict[st_config["optimizer"]
                                                         ["lr_type"]]
        loss_params = config["loss"]
        # train one iter
        base_network.train(True)
        classifier.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        if has_pseudo:
            if source_iter % len_train_source == 0:
                iter_source = iter(dset_loaders["source"])
            if source_iter % len_train_target == 0:
                iter_target = iter(dset_loaders["target"])
            if pseudo_iter % len_train_pseudo == 0:
                iter_pseudo = iter(dset_loaders["pseudo"])
            if pseudo_iter % len_train_source_pseudo == 0:
                iter_source_pseudo = iter(dset_loaders["source_pseudo"])

            source_iter += 1
            pseudo_iter += 1

            inputs_source, labels_source, _, _ = next(iter_source)
            inputs_target, _, _, _ = next(iter_target)

            inputs_pseudo, labels_pseudo, _, randcls = next(iter_pseudo)
            inputs_source_pseudo, _, _, _ = next(iter_source_pseudo)

            inputs_source = torch.cat((inputs_source, inputs_pseudo), dim=0)
            labels_source = torch.cat((labels_source, labels_pseudo), dim=0)

            inputs_target = torch.cat((inputs_target, inputs_source_pseudo),
                                      dim=0)
        else:
            if i % len_train_source == 0:
                iter_source = iter(dset_loaders["source"])
            if i % len_train_target == 0:
                iter_target = iter(dset_loaders["target"])

            inputs_source, labels_source, _, _ = next(iter_source)
            inputs_target, _, _, _ = next(iter_target)

        inputs_source, inputs_target, labels_source = inputs_source.cuda(), \
            inputs_target.cuda(), \
            labels_source.cuda()

        features_source = base_network(inputs_source)
        outputs_source = classifier(features_source)
        domain_source = ad_net(features_source)
        features_target = base_network(inputs_target)
        domain_target = ad_net(features_target)

        classifier_loss = 0.0
        transfer_loss = 0.0
        if has_pseudo and st_config["is_weight"]:
            mean_pseudo = nn.Softmax(dim=1)(
                outputs_source[0][source_batchsize:].detach())
            for j in range(base_network.get_nBlocks()):
                if j != 0:
                    mean_pseudo += nn.Softmax(dim=1)(
                        outputs_source[j][source_batchsize:].detach())
            mean_pseudo = mean_pseudo / base_network.get_nBlocks()
        for j in range(base_network.get_nBlocks()):
            if has_pseudo and st_config["is_weight"]:
                source_mask = torch.FloatTensor([1.] * source_batchsize)
                pseudo_mask = [1 if k == j else 0 for k in randcls]
                pseudo_mask = torch.tensor(pseudo_mask).float()
                mask = torch.cat((source_mask, pseudo_mask), dim=0).cuda()
                classifier_loss += loss.Weighted_loss(outputs_source[j],
                                                      labels_source, mask)
            else:
                classifier_loss += nn.CrossEntropyLoss()(outputs_source[j],
                                                         labels_source)
            domain = torch.cat((domain_source[j], domain_target[j]), dim=0)
            batch_size = domain.size(0) // 2
            if has_pseudo:
                dc_target = torch.from_numpy(
                    np.array([[1]] * source_batchsize +
                             [[0]] * pseudo_batchsize +
                             [[0]] * source_batchsize +
                             [[1]] * pseudo_batchsize)).float().cuda()
            else:
                dc_target = torch.from_numpy(
                    np.array([[1]] * batch_size +
                             [[0]] * batch_size)).float().cuda()
            transfer_loss += nn.BCELoss()(domain, dc_target)

        total_loss = classifier_loss + loss_params["trade_off"] * transfer_loss
        total_loss.backward()
        optimizer.step()
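In the adversarial loss above, dc_target assigns each position in the concatenated batch its true domain (source = 1, target = 0); because the "source" batch has pseudo-labeled target samples appended and the "target" batch has source samples appended, the label pattern alternates within each half. A tiny self-contained check of the label layout (batch sizes are illustrative):

import numpy as np

source_batchsize, pseudo_batchsize = 3, 2
# First half: source samples (1) followed by pseudo-labeled target (0).
# Second half: target samples (0) followed by mixed-in source (1).
dc_target = np.array([[1]] * source_batchsize + [[0]] * pseudo_batchsize +
                     [[0]] * source_batchsize + [[1]] * pseudo_batchsize,
                     dtype=np.float32)
assert dc_target.shape == (2 * (source_batchsize + pseudo_batchsize), 1)
print(dc_target.ravel())  # [1. 1. 1. 0. 0. 0. 0. 0. 1. 1.]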