Example #1
def data_load(args): 
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()

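    # the target path argument holds several list files, which are
    # concatenated into a single target (and test) list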
    txt_tar = []
    for i in range(len(args.t_dset_path)):
        tmp = open(args.t_dset_path[i]).readlines()
        txt_tar.extend(tmp)
    txt_test = txt_tar.copy()

    if args.trte == 'val':
        dsize = len(txt_src)
        tr_size = int(0.9*dsize)
        print(dsize, tr_size, dsize - tr_size)
        tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
    else:
        tr_txt = txt_src
        te_txt = txt_src   

    dsets['source_tr'] = ImageList(tr_txt, transform=image_train())
    dset_loaders['source_tr'] = DataLoader(dsets['source_tr'], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets['source_te'] = ImageList(te_txt, transform=image_test())
    dset_loaders['source_te'] = DataLoader(dsets['source_te'], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets['target'] = ImageList_idx(txt_tar, transform=image_train())
    dset_loaders['target'] = DataLoader(dsets['target'], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets['test'] = ImageList(txt_test, transform=image_test())
    dset_loaders['test'] = DataLoader(dsets['test'], batch_size=train_bs*2, shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders
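All of these snippets lean on an ImageList dataset that parses "path label" lines from the list files. The repositories' actual class is not shown here; as a rough sketch of the assumed behaviour (the name ImageListSketch and the details are illustrative, not the real implementation):

from PIL import Image
from torch.utils.data import Dataset

class ImageListSketch(Dataset):
    """Minimal stand-in for ImageList: one 'path label' pair per line."""
    def __init__(self, image_list, transform=None, mode='RGB'):
        # each line looks like "relative/path/to/img.jpg 12"
        self.samples = [(line.split()[0], int(line.split()[1]))
                        for line in image_list]
        self.transform = transform
        self.mode = mode  # PIL mode: 'RGB' for colour, 'L' for grayscale

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = Image.open(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        return img, label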
Example #2
def get_mnist(train):
    """Get MNIST dataset loader."""
    # image pre-processing
    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
    ])

    if train:
        source_list = '/data/mnist/mnist_train.txt'
        mnist_data_loader = torch.utils.data.DataLoader(
            ImageList(open(source_list).readlines(),
                      transform=pre_process,
                      mode='L'),
            batch_size=params.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True)
    else:
        test_list = '/data/mnist/mnist_test.txt'
        mnist_data_loader = torch.utils.data.DataLoader(
            ImageList(open(test_list).readlines(),
                      transform=pre_process,
                      mode='L'),
            batch_size=params.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True)
    return mnist_data_loader
Example #3
def data_load(args, txt_src, txt_tgt):
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.RandomCrop(224),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    dsets = {}
    dsets["source"] = ImageList(txt_src, transform=train_transform)
    dsets["target"] = ImageList_twice(txt_tgt, transform=[train_transform, train_transform])

    txt_test = open(args.test_dset_path).readlines()
    dsets["test"] = ImageList(txt_test, transform=test_transform)

    dset_loaders = {}
    dset_loaders["source"] = torch.utils.data.DataLoader(dsets["source"], batch_size=args.batch_size,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["target"] = torch.utils.data.DataLoader(dsets["target"], batch_size=args.batch_size,
        shuffle=True, num_workers=args.worker, drop_last=True)
    dset_loaders["test"] = torch.utils.data.DataLoader(dsets["test"], batch_size=args.batch_size*3,
        shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders
Example #4
def get_dsets_loader_imagenet(mode='test', data_batch_size=16):
    prep_dict = {}
    dsets = {}

    dset_loaders = {}
    config = {}

    prep_dict["train_set1"] = trans_train_resize_imagenet()
    prep_dict["database"] = trans_train_resize_imagenet()
    prep_dict["test"] = trans_train_resize_imagenet()

    config["data"] = {"database": {"list_path": "../data/imagenet/database.txt", "batch_size": data_batch_size}, \
                      "test": {"list_path": "../data/imagenet/test.txt", "batch_size": data_batch_size}}
    data_config = config["data"]

    dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                              transform=prep_dict["test"])
    dsets["database"] = ImageList(open(data_config["database"]["list_path"]).readlines(), \
                                  transform=prep_dict["database"])

    dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                                batch_size=data_config["test"]["batch_size"], \
                                                shuffle=False, num_workers=16)

    dset_loaders["database"] = util_data.DataLoader(dsets["database"], \
                                                    batch_size=data_config["database"]["batch_size"], \
                                                    shuffle=False, num_workers=16)
    return dsets, dset_loaders
Example #5
def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

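    # "val" mode holds out 30% of the source list for validation;
    # otherwise the full source list is used for both training and evaluation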
    if args.trte == "val":
        dsize = len(txt_src)
        tr_size = int(0.7*dsize)
        print(dsize, tr_size, dsize - tr_size)
        tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
    else:
        tr_txt = txt_src
        te_txt = txt_src

    dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["source_te"] = ImageList(te_txt, transform=image_test())
    dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)

    dsets["test"] = ImageList(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders
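The target split above uses ImageList_idx rather than ImageList, presumably so each batch also carries the sample indices (handy for per-sample pseudo-label bookkeeping). Building on the illustrative ImageListSketch stub after Example #1, the assumed behaviour would be:

class ImageListIdxSketch(ImageListSketch):
    """Assumed ImageList_idx behaviour: also yield each sample's index."""
    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        return img, label, index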
Example #6
def predict(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    if prep_config["test_10crop"]:
        prep_dict["database"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["database"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
               
    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["database"+str(i)] = ImageList(open(data_config["database"]["list_path"]).readlines(), \
                                transform=prep_dict["database"]["val"+str(i)])
            dset_loaders["database"+str(i)] = util_data.DataLoader(dsets["database"+str(i)], \
                                batch_size=data_config["database"]["batch_size"], \
                                shuffle=False, num_workers=4)
            dsets["test"+str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

    else:
        dsets["database"] = ImageList(open(data_config["database"]["list_path"]).readlines(), \
                                transform=prep_dict["database"])
        dset_loaders["database"] = util_data.DataLoader(dsets["database"], \
                                batch_size=data_config["database"]["batch_size"], \
                                shuffle=False, num_workers=4)
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)
    ## set base network
    base_network = torch.load(config["snapshot_path"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    database_codes, database_labels = code_predict(dset_loaders, base_network, "database", test_10crop=prep_config["test_10crop"], gpu=use_gpu)
    test_codes, test_labels = code_predict(dset_loaders, base_network, "test", test_10crop=prep_config["test_10crop"], gpu=use_gpu)

    return {"database_code":database_codes.numpy(), "database_labels":database_labels.numpy(), \
            "test_code":test_codes.numpy(), "test_labels":test_labels.numpy()}
Example #7
def run(config,path):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = 34
    test_bs = 34
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dset_loaders["test1"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=False, num_workers=1, drop_last=False)

    if prep_config["test_10crop"]:
        # one dataset and loader per crop of the 10-crop test transform
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    n_class = config["network"]["params"]["class_num"]
    # for p in [499,999,1499,1999,2499,2999]:
    #     PATH = path + '/'+ str(p) + '_model.pth.tar'
    #
    #     model = load_model(PATH)
    #     base_network = model.cuda()
    #     fun1(dset_loaders, base_network, n_class)

    PATH = path+'/2999_model.pth.tar'

    model = load_model(PATH)
    base_network = model.cuda()
    # homo_cl = train_homo_cl(dset_loaders, base_network)

    fun1(dset_loaders, base_network, n_class, dsets)
Example #8
def data_load(args): 
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

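    # outside plain unsupervised DA ('uda'), remap the shared source classes to a
    # contiguous 0..K-1 range and send target-only classes to one extra index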
    if args.da != 'uda':
        label_map_s = {}
        for i in range(len(args.src_classes)):
            label_map_s[args.src_classes[i]] = i
        
        new_src = []
        for i in range(len(txt_src)):
            rec = txt_src[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.src_classes:
                line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'   
                new_src.append(line)
        txt_src = new_src.copy()

        new_tar = []
        for i in range(len(txt_test)):
            rec = txt_test[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.tar_classes:
                if int(reci[1]) in args.src_classes:
                    line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'   
                    new_tar.append(line)
                else:
                    line = reci[0] + ' ' + str(len(label_map_s)) + '\n'   
                    new_tar.append(line)
        txt_test = new_tar.copy()

    if args.trte == "val":
        dsize = len(txt_src)
        tr_size = int(0.9*dsize)
        # print(dsize, tr_size, dsize - tr_size)
        tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
    else:
        dsize = len(txt_src)
        tr_size = int(0.9*dsize)
        _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
        tr_txt = txt_src

    dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["source_te"] = ImageList(te_txt, transform=image_test())
    dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False)

    return dset_loaders
Example #9
def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

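    # hold out the first 3 samples encountered per class as the source validation set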
    count = np.zeros(args.class_num)
    tr_txt = []
    te_txt = []
    for i in range(len(txt_src)):
        line = txt_src[i]
        reci = line.strip().split(' ')
        if count[int(reci[1])] < 3:
            count[int(reci[1])] += 1
            te_txt.append(line)
        else:
            tr_txt.append(line)

    dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"],
                                           batch_size=train_bs,
                                           shuffle=True,
                                           num_workers=args.worker,
                                           drop_last=False)
    dsets["source_te"] = ImageList(te_txt, transform=image_test())
    dset_loaders["source_te"] = DataLoader(dsets["source_te"],
                                           batch_size=train_bs,
                                           shuffle=True,
                                           num_workers=args.worker,
                                           drop_last=False)
    dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=train_bs,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=False)
    dsets["target_te"] = ImageList(txt_tar, transform=image_test())
    dset_loaders["target_te"] = DataLoader(dsets["target_te"],
                                           batch_size=train_bs,
                                           shuffle=False,
                                           num_workers=args.worker,
                                           drop_last=False)
    dsets["test"] = ImageList(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"],
                                      batch_size=train_bs * 2,
                                      shuffle=False,
                                      num_workers=args.worker,
                                      drop_last=False)

    return dset_loaders
Example #10
def get_train_dsets(config):
    # prepare dataset for hashnet training
    import os
    import torchvision.datasets as dset
    from data_list import ImageList
    ## set pre-process
    prep_dict = {}
    job_dataset = config["dataset"]
    prep_dict["train_set1"] = get_trans(job_dataset=job_dataset, mode='test')
    prep_dict["train_set2"] = get_trans(job_dataset=job_dataset, mode='test')

    dsets = {}
    data_config = config["data"]

    if 'cifar10' in config["dataset"]:
        # cifar10 for 32*32, cifar10resize for 224*224
        if 'cifar100' in job_dataset:
            root = '../data/cifar100'
            if not os.path.exists(root):
                os.mkdir(root)
            dsets['train_set1'] = dset.CIFAR100(root=root, train=True, transform=prep_dict["train_set1"], download=True)
            dsets['train_set2'] = dset.CIFAR100(root=root, train=True, transform=prep_dict["train_set2"], download=True)
        else:
            root = '../data/cifar10'
            if not os.path.exists(root):
                os.mkdir(root)
            dsets['train_set1'] = dset.CIFAR10(root=root, train=True, transform=prep_dict["train_set1"], download=True)
            dsets['train_set2'] = dset.CIFAR10(root=root, train=True, transform=prep_dict["train_set2"], download=True)
    elif config["dataset"] == 'mnist':
        root = '../data/mnist'
        if not os.path.exists(root):
            os.mkdir(root)
        dsets["train_set1"] = dset.MNIST(root=root, train=True, transform=prep_dict["train_set1"], download=True)
        dsets["train_set2"] = dset.MNIST(root=root, train=True, transform=prep_dict["train_set2"], download=True)
    elif config["dataset"] == 'fashion_mnist':
        root = '../data/fashion_mnist'
        if not os.path.exists(root):
            os.mkdir(root)
        dsets["train_set1"] = dset.FashionMNIST(root=root, train=True, transform=prep_dict["train_set1"],
                                                download=True)
        dsets["train_set2"] = dset.FashionMNIST(root=root, train=True, transform=prep_dict["train_set2"],
                                                download=True)
    else:
        dsets['train_set1'] = ImageList(open(data_config["train_set1"]["list_path"]).readlines(), \
                                        transform=prep_dict["train_set1"])
        dsets['train_set2'] = ImageList(open(data_config["train_set2"]["list_path"]).readlines(), \
                                        transform=prep_dict["train_set2"])
    return dsets
Example #11
def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

    dsets["target"] = ImageList(txt_tar, transform=image_train(), cfg=args, balance_sample=False)
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker,
                                        drop_last=False)
    dsets["test"] = ImageList(txt_test, transform=image_test(), cfg=args, balance_sample=False)
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker,
                                      drop_last=False)

    return dset_loaders, txt_tar
Example #12
def get_svhn(train):

    # image pre-processing
    # pre_process = transforms.Compose([transforms.ToTensor(),
    #                                   transforms.Normalize(
    #                                       mean=params.dataset_mean,
    #                                       std=params.dataset_std)])

    # # dataset and data loader
    # svhn_dataset = svhn(root=params.data_root,
    #                     train=train,
    #                     transform=pre_process,
    #                     download=True)

    # svhn_data_loader = torch.utils.data.DataLoader(
    #     dataset=svhn_dataset,
    #     batch_size=params.batch_size,
    #     shuffle=True)

    # return svhn_data_loader

    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std)
    ])

    if train:
        source_list = '/data/svhn/svhn_train.txt'
        svhn_data_loader = torch.utils.data.DataLoader(
            ImageList(open(source_list).readlines(),
                      transform=pre_process,
                      mode='L'),
            batch_size=params.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True)
    else:
        test_list = '/data/svhn/svhn_test.txt'
        svhn_data_loader = torch.utils.data.DataLoader(
            ImageList(open(test_list).readlines(),
                      transform=pre_process,
                      mode='L'),
            batch_size=params.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True)
    return svhn_data_loader
Example #13
def get_label_list(target_list, predict_network_path, feature_network_path,
                   class_num, resize_size, crop_size, batch_size, use_gpu,
                   opt):
    # done with debugging, works fine
    """
    Return the target list with pesudolabel
    :param target_list: list conatinging all target file path and a wrong label
    :param predict_network: network to perdict label for target image
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    label_list = []
    # net_config = predict_network_name
    # predict_network = net_config["name"](**net_config["params"])
    # if use_gpu:
    #     predict_network = predict_network.cuda()
    netF = models._netF(opt)
    netC = models._netC(opt, class_num)
    netF.load_state_dict(torch.load(feature_network_path))
    netC.load_state_dict(torch.load(predict_network_path))
    if use_gpu:
        netF.cuda()
        netC.cuda()

    mean = np.array([0.44, 0.44, 0.44])
    std = np.array([0.19, 0.19, 0.19])
    transform_target = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    dsets_tar = ImageList(target_list, transform=transform_target)
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)
    len_train_target = len(dset_loaders_tar)
    iter_target = iter(dset_loaders_tar)
    count = 0
    for i in range(len_train_target):
        input_tar, label_tar = next(iter_target)
        if use_gpu:
            input_tar, label_tar = Variable(input_tar).cuda(), Variable(
                label_tar).cuda()
        else:
            input_tar, label_tar = Variable(input_tar), Variable(label_tar)
        predict_score = netC(netF(input_tar))
        # _, predict_score = predict_network(input_tar)
        _, predict_label = torch.max(predict_score, 1)
        for num in range(len(predict_label.cpu())):
            label_list.append(target_list[count][:-2])
            label_list[count] = label_list[count] + str(
                predict_label[num].cpu().numpy()) + "\n"
            count += 1
    return label_list
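Since the loader is traversed exactly once, the manual iter()/next() bookkeeping above can be written as a plain for loop; a sketch of the equivalent inference pass, assuming the same netF/netC and loader from this example (and using torch.no_grad() in place of the legacy Variable wrappers):

import torch

with torch.no_grad():
    for input_tar, label_tar in dset_loaders_tar:
        if use_gpu:
            input_tar = input_tar.cuda()
        predict_score = netC(netF(input_tar))           # class logits per target image
        _, predict_label = torch.max(predict_score, 1)  # top-1 pseudo-label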
Example #14
def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

    if args.trte == "val":
        dsize = len(txt_src)
        tr_size = int(0.9 * dsize)
        tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])

    elif args.trte == "stratified":
        cls_dict = getClassDict(txt_src)
        val_sample_cls_dict = getSampleDict(cls_dict, 0.1)
        te_txt = []
        for k in val_sample_cls_dict.keys():
            te_txt.extend(val_sample_cls_dict[k])
        tr_txt = list(set(txt_src) - set(te_txt))

    else:
        dsize = len(txt_src)
        tr_size = int(0.8 * dsize)
        _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
        tr_txt = txt_src

    # training set
    if args.source_balanced:
        # balanced sampler of source train
        dsets["source_tr"] = ImageList(tr_txt, transform=image_train(), cfg=args, balance_sample=True)
    else:
        dsets["source_tr"] = ImageList(tr_txt, transform=image_train(), cfg=args, balance_sample=False)
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    # validation set
    dsets["source_te"] = ImageList(te_txt, transform=image_test(), cfg=args, balance_sample=False)
    dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    # test set
    dsets["test"] = ImageList(txt_test, transform=image_test(), cfg=args, balance_sample=False)
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 2, shuffle=True, num_workers=args.worker,
                                      drop_last=False)

    return dset_loaders
Example #15
def run(config, model):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = 34
    test_bs = 34
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=False)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=False)

    if prep_config["test_10crop"]:
        # one dataset and loader per crop of the 10-crop test transform
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    n_class = config["network"]["params"]["class_num"]

    model.train(False)
    temp_acc = image_classification_test(dset_loaders, model, n_class=n_class)
    log_str = "precision: {:.5f}".format(temp_acc)
    print(log_str)
Example #16
def get_label_list(args, target_list, feature_network_path,
                   predict_network_path, num_layer, resize_size, crop_size,
                   batch_size, use_gpu):
    """
    Return the target list with pesudolabel
    :param target_list: list conatinging all target file path and a wrong label
    :param predict_network: network to perdict label for target image
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    option = 'resnet' + args.resnet
    G = ResBase(option)
    F1 = ResClassifier(num_layer=num_layer)

    G.load_state_dict(torch.load(feature_network_path))
    F1.load_state_dict(torch.load(predict_network_path))
    if use_gpu:
        G.cuda()
        F1.cuda()
    G.eval()
    F1.eval()

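    # run the frozen feature extractor + classifier over the whole target list and
    # rewrite each line with the predicted pseudo-label in place of the original one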
    label_list = []
    dsets_tar = ImageList(target_list,
                          transform=prep.image_train(resize_size=resize_size,
                                                     crop_size=crop_size))
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=4)
    len_train_target = len(dset_loaders_tar)
    iter_target = iter(dset_loaders_tar)
    count = 0
    for i in range(len_train_target):
        input_tar, label_tar = next(iter_target)
        if use_gpu:
            input_tar, label_tar = Variable(input_tar).cuda(), Variable(
                label_tar).cuda()
        else:
            input_tar, label_tar = Variable(input_tar), Variable(label_tar)
        tar_feature = G(input_tar)
        predict_score = F1(tar_feature)
        _, pre_lab = torch.max(predict_score, 1)
        predict_label = pre_lab.detach()
        for num in range(len(predict_label.cpu())):
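            # strip the old label (1 digit if the char before it is a space,
            # otherwise 2 digits) plus the newline, keeping "path "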
            if target_list[count][-3] == ' ':
                ind = -2
            else:
                ind = -3
            label_list.append(target_list[count][:ind])
            label_list[count] = label_list[count] + str(
                predict_label[num].cpu().numpy()) + "\n"
            count += 1
    return label_list
Example #17
def get_dsets(job_dataset):
    # get the dsets for extract code
    # or can be used to softmax training with the dsets['train']

    import torchvision.datasets as dset
    dsets = {}
    trans_test = get_trans(job_dataset=job_dataset, mode='test')
    trans_train = get_trans(job_dataset=job_dataset, mode='train')
    if job_dataset == 'mnist':
        root = '../data/mnist'

        dsets['test'] = dset.MNIST(root=root, train=False, transform=trans_test, download=True)
        dsets['database'] = dset.MNIST(root=root, train=True, transform=trans_test, download=True)
        dsets['train'] = dset.MNIST(root=root, train=True, transform=trans_train, download=True)
    if 'cifar10' in job_dataset:
        if 'cifar100' in job_dataset:
            root = '../data/cifar100'
            dsets['test'] = dset.CIFAR100(root=root, train=False, transform=trans_test, download=True)
            dsets['database'] = dset.CIFAR100(root=root, train=True, transform=trans_test, download=True)
            dsets['train'] = dset.CIFAR100(root=root, train=True, transform=trans_train, download=True)
        else:
            root = '../data/cifar10'
            dsets['test'] = dset.CIFAR10(root=root, train=False, transform=trans_test, download=True)
            dsets['database'] = dset.CIFAR10(root=root, train=True, transform=trans_test, download=True)
            dsets['train'] = dset.CIFAR10(root=root, train=True, transform=trans_train, download=True)

    if job_dataset == 'fashion_mnist':
        root = '../data/fashion_mnist'
        dsets['test'] = dset.FashionMNIST(root=root, train=False, transform=trans_test, download=True)
        dsets['database'] = dset.FashionMNIST(root=root, train=True, transform=trans_test, download=True)
        dsets['train'] = dset.FashionMNIST(root=root, train=True, transform=trans_train, download=True)

    if 'imagenet' in job_dataset or 'nus_wide' in job_dataset or 'places365' in job_dataset:
        # load data file from path
        from publicVariables import data_list_path
        dsets["test"] = ImageList(open(data_list_path[job_dataset]["test"]).readlines(), \
                                  transform=trans_test)
        dsets["database"] = ImageList(open(data_list_path[job_dataset]["database"]).readlines(), \
                                      transform=trans_test)
        dsets["train"] = ImageList(open(data_list_path[job_dataset]["database"]).readlines(), \
                                      transform=trans_train)

    return dsets
Example #18
def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

    if args.da != 'uda':
        label_map_s = {}
        for i in range(len(args.src_classes)):
            label_map_s[args.src_classes[i]] = i

        new_tar = []
        for i in range(len(txt_tar)):
            rec = txt_tar[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.tar_classes:
                if int(reci[1]) in args.src_classes:
                    line = reci[0] + ' ' + str(label_map_s[int(
                        reci[1])]) + '\n'
                    new_tar.append(line)
                else:
                    line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
                    new_tar.append(line)
        txt_tar = new_tar.copy()
        txt_test = txt_tar.copy()

    dsets["target"] = ImageList(txt_tar, transform=image_test())
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=train_bs,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=False)
    dsets["test"] = ImageList(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"],
                                      batch_size=train_bs * 3,
                                      shuffle=False,
                                      num_workers=args.worker,
                                      drop_last=False)

    return dset_loaders
Example #19
def save(config, model, save_name):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = 34
    test_bs = 34
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=False)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=False)

    if prep_config["test_10crop"]:
        # one dataset and loader per crop of the 10-crop test transform
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                   transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                           shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    n_class = config["network"]["params"]["class_num"]

    model.train(False)
    save_feature(dset_loaders, model, './snapshot/model/', save_name)
Example #20
def load_ori_dsets(job_dataset):
    # Load the original dsets that have no transforms
    dsets = {}

    if job_dataset == 'mnist':
        import torchvision.datasets as dset
        root = '../../../../data/mnist'
        dsets["test"] = dset.MNIST(root=root, train=False, download=True)
        dsets["database"] = dset.MNIST(root=root, train=True, download=True)
        dsets["train"] = dsets["database"]
    if job_dataset == 'cifar10':
        import torchvision.datasets as dset
        root = '../../../../data/cifar10/pytorch_path'
        dsets['test'] = dset.CIFAR10(root=root, train=False, download=True)
        dsets['database'] = dset.CIFAR10(root=root, train=True, download=True)
        dsets["train"] = dsets["database"]
    if job_dataset == 'fashion_mnist':
        import torchvision.datasets as dset
        root = '../../../../data/mnist'
        dsets["test"] = dset.FashionMNIST(root=root,
                                          train=False,
                                          download=True)
        dsets["database"] = dset.FashionMNIST(root=root,
                                              train=True,
                                              download=True)
        dsets["train"] = dsets["database"]
    if 'imagenet' in job_dataset:
        dsets["train"] = ImageList(
            open('../data/imagenet/train.txt').readlines())
        dsets["test"] = ImageList(
            open('../data/imagenet/test.txt').readlines())
        dsets["database"] = ImageList(
            open('../data/imagenet/database.txt').readlines())
    if job_dataset == 'places365':
        import torchvision.datasets as dset
        dsets["database"] = ImageList(
            open("../data/places365_standard/database.txt").readlines())
        dsets['test'] = ImageList(
            open("../data/places365_standard/val.txt").readlines())

    return dsets
Example #21
def get_dataloader(image_dir):

    img_size = 299
    mean = [0.5] * 3
    std = [0.5] * 3
    normalize_transform = Normalize(mean, std)

    infer = open(os.path.join(image_dir, "infer_dir.txt")).readlines()
    inferset = ImageList(infer, normalize_transform, img_size)

    dataloader = DataLoader(inferset, batch_size=1)

    return dataloader
Example #22
    def get_dataloader(args):
        image_dir = args.image_dir

        train = open(os.path.join(image_dir, "train_dir.txt")).readlines()
        val = open(os.path.join(image_dir, "val_dir.txt")).readlines()
        test = open(os.path.join(image_dir, "test_dir.txt")).readlines()

        img_size = 299
        mean = [0.5] * 3
        std = [0.5] * 3
        normalize_transform = Normalize(mean, std)

        trainset = ImageList(train, normalize_transform, img_size)
        valset = ImageList(val, normalize_transform, img_size)
        testset = ImageList(test, normalize_transform, img_size)

        # look each split's dataset up by name rather than via eval()
        datasets = {"train": trainset, "val": valset, "test": testset}
        dataloader = dict()
        for split in ["train", "val", "test"]:
            dataloader[split] = DataLoader(datasets[split],
                                           batch_size=1,
                                           shuffle=True)

        return dataloader
Example #23
def perform_operation(file_path):
    torch.set_grad_enabled(False)  # disable autograd for this inference pass
    e_net.eval()
    a_net.eval()
    s_net.eval()
    fusion.eval()

    imDataset = ImageList(crop_size=args.IM_SIZE, path=file_path, img_path=args.img_path, NUM_CLASS=args.NUM_CLASS,
              phase='test', transform=prep.image_test(crop_size=args.IM_SIZE),
              target_transform=prep.land_transform(img_size=args.IM_SIZE))
    imDataLoader = torch.utils.data.DataLoader(imDataset, batch_size=args.Test_BATCH, num_workers=0)

    pbar = tqdm(total=len(imDataLoader))
    for batch_Idx, data in enumerate(imDataLoader):

        datablob, datalb, pos_para = data
        datablob = torch.autograd.Variable(datablob).cuda()
        y_lb = torch.autograd.Variable(datalb).view(datalb.size(0), -1).cuda()
        pos_para = torch.autograd.Variable(pos_para).cuda()

        pred_global = e_net(datablob)
        feat_data = e_net.predict_BN(datablob)
        pred_att_map, pred_conf = a_net(feat_data)
        slice_feat_data = prep_model_input(pred_att_map, pos_para)
        pred_local = s_net(slice_feat_data)
        cls_pred = fusion(pred_global + pred_local)

        cls_pred = cls_pred.data.cpu().float()
        y_lb = y_lb.data.cpu().float()

        if batch_Idx == 0:
            all_output = cls_pred
            all_label = y_lb
        else:
            all_output = torch.cat((all_output, cls_pred), 0)
            all_label = torch.cat((all_label, y_lb), 0)
        pbar.update()

    pbar.close()
    all_acc_scr = get_acc(all_output, all_label)
    all_f1_score = get_f1(all_output, all_label)

    print('f1 score: ', str(all_f1_score.numpy().tolist()))
    print('average f1 score: ', str(all_f1_score.mean().numpy().tolist()))
    print('acc score: ', str(all_acc_scr.numpy().tolist()))
    print('average acc score: ', str(all_acc_scr.mean().numpy().tolist()))
Example #24
def data_load(args): 
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()

    txt_tar = []
    for i in range(len(args.t_dset_path)):
        tmp = open(args.t_dset_path[i]).readlines()
        txt_tar.extend(tmp)
    txt_test = txt_tar.copy()

    dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
    dsets["test"] = ImageList(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=False, num_workers=args.worker, drop_last=False)

    return dset_loaders
Example #25
def get_label_list(target_list, predict_network_name, resize_size, crop_size,
                   batch_size, use_gpu):
    # done with debugging, works fine
    """
    Return the target list with pesudolabel
    :param target_list: list conatinging all target file path and a wrong label
    :param predict_network: network to perdict label for target image
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    label_list = []
    net_config = predict_network_name
    predict_network = net_config["name"](**net_config["params"])
    if use_gpu:
        predict_network = predict_network.cuda()

    dsets_tar = ImageList(target_list,
                          transform=prep.image_train(resize_size=resize_size,
                                                     crop_size=crop_size))
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)
    len_train_target = len(dset_loaders_tar)
    iter_target = iter(dset_loaders_tar)
    count = 0
    for i in range(len_train_target):
        input_tar, label_tar = next(iter_target)
        if use_gpu:
            input_tar, label_tar = Variable(input_tar).cuda(), Variable(
                label_tar).cuda()
        else:
            input_tar, label_tar = Variable(input_tar), Variable(label_tar)
        _, predict_score = predict_network(input_tar)
        _, predict_label = torch.max(predict_score, 1)
        for num in range(len(predict_label.cpu())):
            label_list.append(target_list[count][:-2])
            label_list[count] = label_list[count] + str(
                predict_label[num].cpu().numpy()) + "\n"
            count += 1
    return label_list
Example #26
File: dev.py Project: iriszc/CDAN
def get_label_list(target_list, predict_network, resize_size, crop_size,
                   batch_size, use_gpu):
    """
    Return the target list with pesudolabel
    :param target_list: list conatinging all target file path and a wrong label
    :param predict_network: network to perdict label for target image
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    label_list = []
    dsets_tar = ImageList(target_list,
                          transform=prep.image_train(resize_size=resize_size,
                                                     crop_size=crop_size))
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=4)
    len_train_target = len(dset_loaders_tar)
    iter_target = iter(dset_loaders_tar)
    count = 0
    for i in range(len_train_target):
        input_tar, label_tar = next(iter_target)
        if use_gpu:
            input_tar, label_tar = Variable(input_tar).cuda(), Variable(
                label_tar).cuda()
        else:
            input_tar, label_tar = Variable(input_tar), Variable(label_tar)
        _, predict_score = predict_network(input_tar)
        _, predict_label = torch.max(predict_score, 1)
        for num in range(len(predict_label.cpu())):
            if target_list[count][-3] == ' ':
                ind = -2
            else:
                ind = -3
            label_list.append(target_list[count][:ind])
            label_list[count] = label_list[count] + str(
                predict_label[num].cpu().numpy()) + "\n"
            print(label_list[count])
            count += 1
    return label_list
Example #27
def get_label_list(target_list, predict_network, resize_size, crop_size, batch_size):
    """
    Return the target list with pesudolabel
    :param target_list: list conatinging all target file path and a wrong label
    :param predict_network: network to perdict label for target image
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    label_list = []
    dsets_tar = ImageList(target_list, transform=prep.image_train(resize_size=resize_size, crop_size=crop_size))
    dset_loaders_tar = util_data.DataLoader(dsets_tar, batch_size=batch_size, shuffle=True, num_workers=4)
    len_train_target = len(dset_loaders_tar)
    iter_target = iter(dset_loaders_tar)
    for i in range(len_train_target):
        input_tar, label_tar = next(iter_target)
        predict_score = predict_network(input_tar)[1]
        label = np.argsort(-predict_score)[0]
        label_list.append(target_list[i][:-2])
        label_list[i] = label_list[i] + str(label) + "\n"
    return label_list
Example #28
def main(config):
    ## set loss criterion
    use_gpu = torch.cuda.is_available()
    au_weight = torch.from_numpy(np.loadtxt(config.train_path_prefix + '_weight.txt'))
    if use_gpu:
        au_weight = au_weight.float().cuda()
    else:
        au_weight = au_weight.float()

    ## prepare data
    dsets = {}
    dset_loaders = {}

    dsets['train'] = ImageList(crop_size=config.crop_size, path=config.train_path_prefix,
                                       transform=prep.image_train(crop_size=config.crop_size),
                                       target_transform=prep.land_transform(img_size=config.crop_size,
                                                                            flip_reflect=np.loadtxt(
                                                                                config.flip_reflect)))

    dset_loaders['train'] = util_data.DataLoader(dsets['train'], batch_size=config.train_batch_size,
                                                 shuffle=True, num_workers=config.num_workers)

    dsets['test'] = ImageList(crop_size=config.crop_size, path=config.test_path_prefix, phase='test',
                                       transform=prep.image_test(crop_size=config.crop_size),
                                       target_transform=prep.land_transform(img_size=config.crop_size,
                                                                            flip_reflect=np.loadtxt(
                                                                                config.flip_reflect))
                                       )

    dset_loaders['test'] = util_data.DataLoader(dsets['test'], batch_size=config.eval_batch_size,
                                                shuffle=False, num_workers=config.num_workers)

    ## set network modules
    region_learning = network.network_dict[config.region_learning](input_dim=3, unit_dim = config.unit_dim)
    align_net = network.network_dict[config.align_net](crop_size=config.crop_size, map_size=config.map_size,
                                                           au_num=config.au_num, land_num=config.land_num,
                                                           input_dim=config.unit_dim*8, fill_coeff=config.fill_coeff)
    local_attention_refine = network.network_dict[config.local_attention_refine](au_num=config.au_num, unit_dim=config.unit_dim)
    local_au_net = network.network_dict[config.local_au_net](au_num=config.au_num, input_dim=config.unit_dim*8,
                                                                                     unit_dim=config.unit_dim)
    global_au_feat = network.network_dict[config.global_au_feat](input_dim=config.unit_dim*8,
                                                                                     unit_dim=config.unit_dim)
    au_net = network.network_dict[config.au_net](au_num=config.au_num, input_dim = 12000, unit_dim = config.unit_dim)


    if config.start_epoch > 0:
        print('resuming model from epoch %d' %(config.start_epoch))
        region_learning.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/region_learning_' + str(config.start_epoch) + '.pth'))
        align_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/align_net_' + str(config.start_epoch) + '.pth'))
        local_attention_refine.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/local_attention_refine_' + str(config.start_epoch) + '.pth'))
        local_au_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/local_au_net_' + str(config.start_epoch) + '.pth'))
        global_au_feat.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/global_au_feat_' + str(config.start_epoch) + '.pth'))
        au_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/au_net_' + str(config.start_epoch) + '.pth'))

    if use_gpu:
        region_learning = region_learning.cuda()
        align_net = align_net.cuda()
        local_attention_refine = local_attention_refine.cuda()
        local_au_net = local_au_net.cuda()
        global_au_feat = global_au_feat.cuda()
        au_net = au_net.cuda()

    print(region_learning)
    print(align_net)
    print(local_attention_refine)
    print(local_au_net)
    print(global_au_feat)
    print(au_net)

    ## collect parameters
    region_learning_parameter_list = [{'params': filter(lambda p: p.requires_grad, region_learning.parameters()), 'lr': 1}]
    align_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, align_net.parameters()), 'lr': 1}]
    local_attention_refine_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, local_attention_refine.parameters()), 'lr': 1}]
    local_au_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, local_au_net.parameters()), 'lr': 1}]
    global_au_feat_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, global_au_feat.parameters()), 'lr': 1}]
    au_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, au_net.parameters()), 'lr': 1}]

    ## set optimizer
    optimizer = optim_dict[config.optimizer_type](itertools.chain(region_learning_parameter_list, align_net_parameter_list,
                                                                  local_attention_refine_parameter_list,
                                                                  local_au_net_parameter_list,
                                                                  global_au_feat_parameter_list,
                                                                  au_net_parameter_list),
                                                  lr=1.0, momentum=config.momentum, weight_decay=config.weight_decay,
                                                  nesterov=config.use_nesterov)
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])

    lr_scheduler = lr_schedule.schedule_dict[config.lr_type]

    if not os.path.exists(config.write_path_prefix + config.run_name):
        os.makedirs(config.write_path_prefix + config.run_name)
    if not os.path.exists(config.write_res_prefix + config.run_name):
        os.makedirs(config.write_res_prefix + config.run_name)

    res_file = open(
        config.write_res_prefix + config.run_name + '/AU_pred_' + str(config.start_epoch) + '.txt', 'w')

    ## train
    count = 0

    for epoch in range(config.start_epoch, config.n_epochs + 1):
        if epoch > config.start_epoch:
            print('taking snapshot ...')
            torch.save(region_learning.state_dict(),
                       config.write_path_prefix + config.run_name + '/region_learning_' + str(epoch) + '.pth')
            torch.save(align_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/align_net_' + str(epoch) + '.pth')
            torch.save(local_attention_refine.state_dict(),
                       config.write_path_prefix + config.run_name + '/local_attention_refine_' + str(epoch) + '.pth')
            torch.save(local_au_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/local_au_net_' + str(epoch) + '.pth')
            torch.save(global_au_feat.state_dict(),
                       config.write_path_prefix + config.run_name + '/global_au_feat_' + str(epoch) + '.pth')
            torch.save(au_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/au_net_' + str(epoch) + '.pth')

        # eval in the train
        if epoch > config.start_epoch:
            print('testing ...')
            region_learning.train(False)
            align_net.train(False)
            local_attention_refine.train(False)
            local_au_net.train(False)
            global_au_feat.train(False)
            au_net.train(False)

            local_f1score_arr, local_acc_arr, f1score_arr, acc_arr, mean_error, failure_rate = AU_detection_evalv2(
                dset_loaders['test'], region_learning, align_net, local_attention_refine,
                local_au_net, global_au_feat, au_net, use_gpu=use_gpu)
            print('epoch = %d, local f1 score mean = %f, local accuracy mean = %f, '
                  'f1 score mean = %f, accuracy mean = %f, mean error = %f, failure rate = %f' % (
                      epoch, local_f1score_arr.mean(), local_acc_arr.mean(), f1score_arr.mean(),
                      acc_arr.mean(), mean_error, failure_rate))
            print('%d\t%f\t%f\t%f\t%f\t%f\t%f' % (
                epoch, local_f1score_arr.mean(), local_acc_arr.mean(), f1score_arr.mean(),
                acc_arr.mean(), mean_error, failure_rate), file=res_file)

            region_learning.train(True)
            align_net.train(True)
            local_attention_refine.train(True)
            local_au_net.train(True)
            global_au_feat.train(True)
            au_net.train(True)

        if epoch >= config.n_epochs:
            break

        for i, batch in enumerate(dset_loaders['train']):
            if i % config.display == 0 and count > 0:
                print('[epoch = %d][iter = %d][total_loss = %f][loss_au_softmax = %f][loss_au_dice = %f]'
                      '[loss_local_au_softmax = %f][loss_local_au_dice = %f][loss_land = %f]' % (
                          epoch, i, total_loss.data.cpu().numpy(), loss_au_softmax.data.cpu().numpy(),
                          loss_au_dice.data.cpu().numpy(), loss_local_au_softmax.data.cpu().numpy(),
                          loss_local_au_dice.data.cpu().numpy(), loss_land.data.cpu().numpy()))
                # itertools.chain() passed to the optimizer yields a single parameter
                # group, so report whatever groups exist instead of indexing six of them
                print('learning rate =', ' '.join('%f' % g['lr'] for g in optimizer.param_groups))
                print('the number of training iterations is %d' % count)

            inputs, land, biocular, au = batch

            if use_gpu:
                inputs, land, biocular, au = inputs.cuda(), land.float().cuda(), \
                                             biocular.float().cuda(), au.long().cuda()
            else:
                au = au.long()

            optimizer = lr_scheduler(param_lr, optimizer, epoch, config.gamma, config.stepsize, config.init_lr)
            optimizer.zero_grad()

            region_feat = region_learning(inputs)
            align_feat, align_output, aus_map = align_net(region_feat)
            if use_gpu:
                aus_map = aus_map.cuda()
            output_aus_map = local_attention_refine(aus_map.detach())
            local_au_out_feat, local_aus_output = local_au_net(region_feat, output_aus_map)
            global_au_out_feat = global_au_feat(region_feat)
            concat_au_feat = torch.cat((align_feat, global_au_out_feat, local_au_out_feat.detach()), 1)
            aus_output = au_net(concat_au_feat)

            loss_au_softmax = au_softmax_loss(aus_output, au, weight=au_weight)
            loss_au_dice = au_dice_loss(aus_output, au, weight=au_weight)
            loss_au = loss_au_softmax + loss_au_dice

            loss_local_au_softmax = au_softmax_loss(local_aus_output, au, weight=au_weight)
            loss_local_au_dice = au_dice_loss(local_aus_output, au, weight=au_weight)
            loss_local_au = loss_local_au_softmax + loss_local_au_dice

            loss_land = landmark_loss(align_output, land, biocular)

            total_loss = config.lambda_au * (loss_au + loss_local_au) + \
                         config.lambda_land * loss_land

            total_loss.backward()
            optimizer.step()

            count += 1

    res_file.close()
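
Note: the call lr_scheduler(param_lr, optimizer, epoch, ...) in the loop above rescales each parameter group's learning rate from the base values captured in param_lr right after the optimizer was built with lr=1.0. The lr_schedule module itself is not shown in this example; the following is a minimal sketch of a compatible step-decay rule with the same signature (the decay formula is an assumption, not the actual schedule_dict entry):

def step_lr(param_lr, optimizer, epoch, gamma, stepsize, init_lr):
    # Sketch only: decay the initial rate by gamma every `stepsize` epochs,
    # then scale each group by the base lr recorded in param_lr.
    decay = gamma ** (epoch // stepsize)
    for group, base_lr in zip(optimizer.param_groups, param_lr):
        group['lr'] = init_lr * decay * base_lr
    return optimizer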
Example #29
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    #     if prep_config["test_10crop"]:
    #         for i in range(10):
    #             dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                                 transform=prep_dict["test"][i]) for i in range(10)]
    #             dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
    #                                 shuffle=False, num_workers=4) for dset in dsets['test']]
    #     else:
    #         dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
    #                                 transform=prep_dict["test"])
    #         dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
    #                                 shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer(
            [base_network.output_num(), class_num],
            config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(
            base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network,
                                       device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        #         if i % config["test_interval"] == config["test_interval"] - 1:
        #             base_network.train(False)
        #             temp_acc = image_classification_test(dset_loaders, \
        #                 base_network, test_10crop=prep_config["test_10crop"])
        #             temp_model = nn.Sequential(base_network)
        #             if temp_acc > best_acc:
        #                 best_acc = temp_acc
        #                 best_model = temp_model
        #             log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
        #             config["out_file"].write(log_str+"\n")
        #             config["out_file"].flush()
        #             print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        loss_params = config["loss"]
        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = \
            inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None,
                                      None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        if i % 10 == 0:
            print('iter: ', i, 'classifier_loss: ', classifier_loss.data,
                  'transfer_loss: ', transfer_loss.data)
        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    # the evaluation block above is commented out, so best_model is never assigned;
    # save the final network instead to avoid a NameError
    torch.save(nn.Sequential(base_network), osp.join(config["output_path"],
                                                     "best_model.pth.tar"))
    return best_acc
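
Note: loss.CDAN above conditions the domain discriminator on the classifier's predictions through a multilinear map of features and softmax outputs. Its implementation is not shown in this example; the sketch below illustrates the core idea without the entropy weighting or random layer (ad_net is assumed to emit one sigmoid domain probability per example, and the batch is assumed to be source examples followed by target examples, as in the training loop above):

import torch
import torch.nn as nn

def cdan_loss_sketch(features, softmax_out, ad_net):
    # Outer product per example: (B, C, 1) x (B, 1, D) -> (B, C, D),
    # flattened to (B, C*D) before the adversarial network.
    op = torch.bmm(softmax_out.detach().unsqueeze(2), features.unsqueeze(1))
    ad_out = ad_net(op.view(op.size(0), -1))
    # First half of the batch is source (label 1), second half target (label 0).
    half = features.size(0) // 2
    dc_target = torch.cat((torch.ones(half, 1), torch.zeros(half, 1)), 0).to(ad_out.device)
    return nn.BCELoss()(ad_out, dc_target)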
Example #30
0
import numpy as np
import torch
import torch.utils.data as util_data

from data_list import ImageList
import pre_process as prep
from resnet import make_resnet50_base

model_path = 'resnet50-base.pth'
cnn = make_resnet50_base()
cnn.load_state_dict(torch.load(model_path))
for param in cnn.parameters():
    param.requires_grad = False
cnn.cuda()
cnn.eval()
# output CSV for the extracted features (opened in append mode)
out_file = open('S_features.csv', 'a')
TEST_LIST = 'WEB_3D3_2.txt'
prep_test = prep.image_test(resize_size=256, crop_size=224)
dsets_test = ImageList(open(TEST_LIST).readlines(),
                       shape=(256, 256),
                       transform=prep_test,
                       train=False)
loaders_test = util_data.DataLoader(dsets_test,
                                    batch_size=512,
                                    shuffle=False,
                                    num_workers=16,
                                    pin_memory=True)
for batch_id, batch in enumerate(loaders_test, 1):
    print(batch_id)
    data, label = batch
    data = data.cuda()
    with torch.no_grad():  # gradients are not needed for feature extraction
        features = cnn(data)
    np.savetxt(out_file, features.cpu().numpy(), delimiter=',')
    torch.cuda.empty_cache()
out_file.close()
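
The loop appends one row of features per image to S_features.csv, so the dump can be read back as a single matrix, e.g.:

import numpy as np

feats = np.loadtxt('S_features.csv', delimiter=',')
print(feats.shape)  # (num_images, feature_dim); 2048 per row if cnn pools ResNet-50 features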