Example #1
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]

    # separate the source and validation sets
    cls_source_list, cls_validation_list = sep.split_set(
        data_config["source"]["list_path"],
        config["network"]["params"]["class_num"])
    source_list = sep.dimension_rd(cls_source_list)

    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                            transform=prep_dict["test"][i]) for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                            shuffle=False, num_workers=4) for dset in dsets["test"]]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    best_acc = 0.0
    best_model = 0

    ## set base network
    net_config = config["network"]
    if config["load_module"] == "":
        base_network = net_config["name"](**net_config["params"])
        base_network = base_network.cuda()

        ## add additional network for some methods
        if config["loss"]["random"]:
            random_layer = network.RandomLayer(
                [base_network.output_num(), class_num],
                config["loss"]["random_dim"])
            ad_net = network.AdversarialNetwork(config["loss"]["random_dim"],
                                                1024)
        else:
            random_layer = None
            ad_net = network.AdversarialNetwork(
                base_network.output_num() * class_num, 1024)
        if config["loss"]["random"]:
            random_layer.cuda()
        ad_net = ad_net.cuda()
        parameter_list = base_network.get_parameters() + ad_net.get_parameters()

        ## set optimizer
        optimizer_config = config["optimizer"]
        optimizer = optimizer_config["type"](parameter_list, \
                                             **(optimizer_config["optim_params"]))
        param_lr = []
        for param_group in optimizer.param_groups:
            param_lr.append(param_group["lr"])
        schedule_param = optimizer_config["lr_param"]
        lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

        gpus = config['gpu'].split(',')
        if len(gpus) > 1:
            ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
            base_network = nn.DataParallel(base_network,
                                           device_ids=[int(i) for i in gpus])

        ## train
        len_train_source = len(dset_loaders["source"])
        len_train_target = len(dset_loaders["target"])
        transfer_loss_value = classifier_loss_value = total_loss_value = 0.0

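        # Main adversarial training loop: each iteration draws one source batch and one target batch.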
        for i in range(config["num_iterations"]):
            if i % config["test_interval"] == config["test_interval"] - 1:
                base_network.train(False)
                temp_acc = image_classification_test(dset_loaders, \
                                                     base_network, test_10crop=prep_config["test_10crop"])
                temp_model = nn.Sequential(base_network)
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    best_model = temp_model
                log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
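                # 'f' is assumed to be a log-file handle opened at module scope.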
                f.write('Accuracy is:' + log_str + '\n')
                config["out_file"].write(log_str + "\n")
                config["out_file"].flush()
                print(log_str)
            if i % config["snapshot_interval"] == 0:
                torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                                 "iter_{:05d}_model.pth".format(i)))

            loss_params = config["loss"]
            ## train one iter
            base_network.train(True)
            ad_net.train(True)
            optimizer = lr_scheduler(optimizer, i, **schedule_param)
            optimizer.zero_grad()
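            # Restart the loader iterators once an epoch's worth of batches has been consumed.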
            if i % len_train_source == 0:
                iter_source = iter(dset_loaders["source"])
            if i % len_train_target == 0:
                iter_target = iter(dset_loaders["target"])
            inputs_source, labels_source = next(iter_source)
            inputs_target, labels_target = next(iter_target)
            inputs_source, inputs_target, labels_source = inputs_source.cuda(), \
                inputs_target.cuda(), labels_source.cuda()
            features_source, outputs_source = base_network(inputs_source)
            features_target, outputs_target = base_network(inputs_target)
            features = torch.cat((features_source, features_target), dim=0)
            outputs = torch.cat((outputs_source, outputs_target), dim=0)
            softmax_out = nn.Softmax(dim=1)(outputs)
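            # Transfer-loss dispatch: CDAN conditions the domain discriminator on the
            # classifier's softmax predictions, CDAN+E additionally weights samples by
            # prediction entropy, and DANN discriminates on features alone.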
            if config['method'] == 'CDAN+E':
                entropy = loss.Entropy(softmax_out)
                transfer_loss = loss.CDAN([features, softmax_out], ad_net,
                                          entropy, network.calc_coeff(i),
                                          random_layer)
            elif config['method'] == 'CDAN':
                transfer_loss = loss.CDAN([features, softmax_out], ad_net,
                                          None, None, random_layer)
            elif config['method'] == 'DANN':
                transfer_loss = loss.DANN(features, ad_net)
            else:
                raise ValueError('Method cannot be recognized.')
            classifier_loss = nn.CrossEntropyLoss()(outputs_source,
                                                    labels_source)
            total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
            total_loss.backward()
            optimizer.step()
        torch.save(best_model, osp.join(config["output_path"],
                                        "best_model.pth"))
    else:
        base_network = torch.load(config["load_module"])
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            base_network = base_network.cuda()
        base_network.train(False)
        temp_acc = image_classification_test(
            dset_loaders, base_network, test_10crop=prep_config["test_10crop"])
        temp_model = nn.Sequential(base_network)
        if temp_acc > best_acc:
            best_acc = temp_acc
            best_model = temp_model
        log_str = "iter: {:05d}, precision: {:.5f}".format(
            config["test_interval"], temp_acc)
        config["out_file"].write(log_str + "\n")
        config["out_file"].flush()
        print(log_str)
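        # Model selection: estimate target performance either from held-out source
        # data (Source_Risk) or via importance-weighted validation (Dev / Dev_icml).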
        if config["val_method"] == "Source_Risk":
            cv_loss = source_risk.cross_validation_loss(
                base_network, base_network, cls_source_list,
                data_config["target"]["list_path"], cls_validation_list,
                class_num, prep_config["params"]["resize_size"],
                prep_config["params"]["crop_size"],
                data_config["target"]["batch_size"], use_gpu)
        elif config["val_method"] == "Dev_icml":
            cv_loss = dev_icml.cross_validation_loss(
                config["load_module"], config["load_module"], source_list,
                data_config["target"]["list_path"], cls_validation_list,
                class_num, prep_config["params"]["resize_size"],
                prep_config["params"]["crop_size"],
                data_config["target"]["batch_size"], use_gpu)
        elif config["val_method"] == "Dev":
            cv_loss = dev.cross_validation_loss(
                config["load_module"], config["load_module"], cls_source_list,
                data_config["target"]["list_path"], cls_validation_list,
                class_num, prep_config["params"]["resize_size"],
                prep_config["params"]["crop_size"],
                data_config["target"]["batch_size"], use_gpu)
        print(cv_loss)
        f.write(config["val_method"] + ' Validation is:' + str(cv_loss) + '\n')

    return best_acc
Example #2
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot',
                        required=True,
                        help='path to source dataset')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batchSize',
                        type=int,
                        default=100,
                        help='input batch size')
    parser.add_argument(
        '--imageSize',
        type=int,
        default=32,
        help='the height / width of the input image to network')
    parser.add_argument('--nz',
                        type=int,
                        default=512,
                        help='size of the latent z vector')
    parser.add_argument(
        '--ngf',
        type=int,
        default=64,
        help='Number of filters to use in the generator network')
    parser.add_argument(
        '--ndf',
        type=int,
        default=64,
        help='Number of filters to use in the discriminator network')
    parser.add_argument('--gpu',
                        type=int,
                        default=1,
                        help='GPU to use, -1 for CPU training')
    parser.add_argument('--checkpoint_dir',
                        default='results/models',
                        help='folder to load model checkpoints from')
    parser.add_argument('--method',
                        default='GTA',
                        help='Method to evaluate: GTA | sourceonly')
    parser.add_argument(
        '--model_best',
        type=int,
        default=0,
        help=
        'Flag to specify whether to use the best validation model or last checkpoint| 1-model best, 0-current checkpoint'
    )
    parser.add_argument('--src_path',
                        type=str,
                        default='digits/server_svhn_list.txt',
                        help='path for source dataset txt file')
    parser.add_argument('--tar_path',
                        type=str,
                        default='digits/server_mnist_list.txt',
                        help='path for target dataset txt file')
    parser.add_argument('--val_method',
                        type=str,
                        default='Source_Risk',
                        choices=['Source_Risk', 'Dev_icml', 'Dev'])
    opt = parser.parse_args()

    # GPU/CPU flags
    cudnn.benchmark = True
    if torch.cuda.is_available() and opt.gpu == -1:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --gpu [gpu id]"
        )
    use_gpu = False
    if opt.gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)
        use_gpu = True

    # Creating data loaders
    mean = np.array([0.44, 0.44, 0.44])
    std = np.array([0.19, 0.19, 0.19])

    if 'svhn' in opt.src_path and 'mnist' in opt.tar_path:
        test_adaptation = 's->m'
    elif 'usps' in opt.src_path and 'mnist' in opt.tar_path:
        test_adaptation = 'u->m'
    else:
        test_adaptation = 'm->u'

    if test_adaptation == 'u->m' or test_adaptation == 's->m':
        target_root = os.path.join(opt.dataroot, 'mnist/trainset')

        transform_target = transforms.Compose([
            transforms.Resize(opt.imageSize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        target_test = dset.ImageFolder(root=target_root,
                                       transform=transform_target)
        nclasses = len(target_test.classes)
    elif test_adaptation == 'm->u':
        transform_usps = transforms.Compose([
            transforms.Resize(opt.imageSize),
            transforms.Grayscale(3),
            transforms.ToTensor(),
            transforms.Normalize((0.44, ), (0.19, ))
        ])
        target_test = usps.USPS(root=opt.dataroot,
                                train=True,
                                transform=transform_usps,
                                download=True)
        nclasses = 10
    targetloader = torch.utils.data.DataLoader(target_test,
                                               batch_size=opt.batchSize,
                                               shuffle=False,
                                               num_workers=2)

    # Creating and loading models

    netF = models._netF(opt)
    netC = models._netC(opt, nclasses)
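    # netF is the feature extractor and netC the classifier head (naming follows the GTA codebase).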

    if opt.method == 'GTA':
        if opt.model_best == 0:
            netF_path = os.path.join(opt.checkpoint_dir, 'netF.pth')
            netC_path = os.path.join(opt.checkpoint_dir, 'netC.pth')
        else:
            netF_path = os.path.join(opt.checkpoint_dir, 'model_best_netF.pth')
            netC_path = os.path.join(opt.checkpoint_dir, 'model_best_netC.pth')

    elif opt.method == 'sourceonly':
        if opt.model_best == 0:
            netF_path = os.path.join(opt.checkpoint_dir, 'netF_sourceonly.pth')
            netC_path = os.path.join(opt.checkpoint_dir, 'netC_sourceonly.pth')
        else:
            netF_path = os.path.join(opt.checkpoint_dir,
                                     'model_best_netF_sourceonly.pth')
            netC_path = os.path.join(opt.checkpoint_dir,
                                     'model_best_netC_sourceonly.pth')
    else:
        raise ValueError('method argument should be sourceonly or GTA')

    netF.load_state_dict(torch.load(netF_path))
    netC.load_state_dict(torch.load(netC_path))

    if opt.gpu >= 0:
        netF.cuda()
        netC.cuda()

    # Testing

    netF.eval()
    netC.eval()

    total = 0
    correct = 0

    for i, datas in enumerate(targetloader):
        inputs, labels = datas
        if opt.gpu >= 0:
            inputs, labels = inputs.cuda(), labels.cuda()
        inputv, labelv = Variable(inputs, volatile=True), Variable(labels)

        outC = netC(netF(inputv))
        _, predicted = torch.max(outC.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

    test_acc = 100 * float(correct) / total
    print('Test Accuracy: %f %%' % (test_acc))

    cls_source_list, cls_validation_list = sep.split_set(
        opt.src_path, nclasses)
    source_list = sep.dimension_rd(cls_source_list)
    # outC = netC(netF(inputv)) computes the classification output
    # outF = netF(inputv) computes the features
    # netF.load_state_dict(torch.load(netF_path)) is how the network weights are loaded
    # crop size is not used
    if opt.val_method == 'Source_Risk':
        cv_loss = source_risk.cross_validation_loss(
            netF_path, netC_path, cls_source_list, opt.tar_path,
            cls_validation_list, nclasses, opt.imageSize, 224, opt.batchSize,
            use_gpu, opt)
    elif opt.val_method == 'Dev_icml':
        cv_loss = dev_icml.cross_validation_loss(netF_path, netC_path,
                                                 source_list, opt.tar_path,
                                                 cls_validation_list, nclasses,
                                                 opt.imageSize, 224,
                                                 opt.batchSize, use_gpu, opt)
    elif opt.val_method == 'Dev':
        cv_loss = dev.cross_validation_loss(netF_path, netC_path,
                                            cls_source_list, opt.tar_path,
                                            cls_validation_list, nclasses,
                                            opt.imageSize, 224, opt.batchSize,
                                            use_gpu, opt)
    print(cv_loss)
Example #3
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]

    # separate the source and validation sets
    cls_source_list, cls_validation_list = sep.split_set(
        data_config["source"]["list_path"],
        config["network"]["params"]["class_num"])
    source_list = sep.dimension_rd(cls_source_list)

    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=4)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"+str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

            dsets["target"+str(i)] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.bottleneck.parameters(), "lr":10}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
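    # The gradient-reversal layer flips gradients flowing back from ad_net so the
    # feature extractor is trained adversarially against the domain discriminator.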
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model = 0
    for i in range(config["num_iterations"]):
        if (i + 1) % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"], \
                gpu=use_gpu)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            print(log_str)
        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        # if i % loss_params["update_iter"] == loss_params["update_iter"] - 1:
        #     base_network.train(False)
        #     target_fc8_out = image_classification_predict(dset_loaders, base_network, softmax_param=config["softmax_param"])
        #     class_weight = torch.mean(target_fc8_out, 0)
        #     class_weight = (class_weight / torch.mean(class_weight)).cuda().view(-1)
        #     class_criterion = nn.CrossEntropyLoss(weight = class_weight)

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        # #
        # if i % 100 == 0:
        #     check = dev.get_label_list(open(data_config["source"]["list_path"]).readlines(),
        #                                base_network,
        #                                prep_config["resize_size"],
        #                                prep_config["crop_size"],
        #                                data_config["target"]["batch_size"],
        #                                use_gpu)
        #     f = open("Class_result.txt", "a+")
        #     f.close()
        #     for cls in range(class_num):
        #         count = 0
        #         for j in check:
        #             if int(j.split(" ")[1].replace("\n", "")) == cls:
        #                 count = count + 1
        #         f = open("Source_result.txt", "a+")
        #         f.write("Source_Class: " + str(cls) + "\n" + "Number of images: " + str(count) + "\n")
        #         f.close()
        #
        #     check = dev.get_label_list(open(data_config["target"]["list_path"]).readlines(),
        #                                base_network,
        #                                prep_config["resize_size"],
        #                                prep_config["crop_size"],
        #                                data_config["target"]["batch_size"],
        #                                use_gpu)
        #     f = open("Class_result.txt", "a+")
        #     f.write("Iteration: " + str(i) + "\n")
        #     f.close()
        #     for cls in range(class_num):
        #         count = 0
        #         for j in check:
        #             if int(j.split(" ")[1].replace("\n", "")) == cls:
        #                 count = count + 1
        #         f = open("Class_result.txt", "a+")
        #         f.write("Target_Class: " + str(cls) + "\n" + "Number of images: " + str(count) + "\n")
        #         f.close()


        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
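        # weight_ad stays uniform here; the commented lines below show PADA's
        # per-class re-weighting of source samples.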
        # label_numpy = labels_source.data.cpu().numpy()
        # for j in range(int(inputs.size(0) / 2)):
        #     weight_ad[j] = class_weight[int(label_numpy[j])]
        # weight_ad = weight_ad / torch.max(weight_ad[0:int(inputs.size(0)/2)])
        # for j in range(int(inputs.size(0) / 2), inputs.size(0)):
        #     weight_ad[j] = 1.0
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                           weight_ad, use_gpu)

        classifier_loss = class_criterion(
            outputs.narrow(0, 0, int(inputs.size(0) / 2)), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    cv_loss = dev.cross_validation_loss(
        base_network, base_network, cls_source_list,
        data_config["target"]["list_path"], cls_validation_list, class_num,
        prep_config["resize_size"], prep_config["crop_size"],
        data_config["target"]["batch_size"], use_gpu)
    print(cv_loss)
Example #4
def train(num_epoch, option, num_layer, test_load, cuda):
    criterion = nn.CrossEntropyLoss()
    if cuda:
        criterion = criterion.cuda()
    if not test_load:
        for ep in range(num_epoch):
            G.train()
            F1.train()
            F2.train()
            for batch_idx, data in enumerate(dataset):
                if batch_idx * batch_size > 30000:
                    break
                data1 = data['S']
                target1 = data['S_label']
                data2 = data['T']
                target2 = data['T_label']
                if args.cuda:
                    data1, target1 = data1.cuda(), target1.cuda()
                    data2, target2 = data2.cuda(), target2.cuda()
                # when pretraining network source only
                eta = 1.0
                data = Variable(torch.cat((data1, data2), 0))
                target1 = Variable(target1)
                # Step A train all networks to minimize loss on source
                optimizer_g.zero_grad()
                optimizer_f.zero_grad()
                output = G(data)
                output1 = F1(output)
                output2 = F2(output)

                output_s1 = output1[:batch_size, :]
                output_s2 = output2[:batch_size, :]
                output_t1 = output1[batch_size:, :]
                output_t2 = output2[batch_size:, :]
                output_t1 = F.softmax(output_t1, dim=1)
                output_t2 = F.softmax(output_t2, dim=1)
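                # The entropy term below pushes the batch-mean target prediction
                # toward a balanced class distribution.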

                entropy_loss = -torch.mean(
                    torch.log(torch.mean(output_t1, 0) + 1e-6))
                entropy_loss -= torch.mean(
                    torch.log(torch.mean(output_t2, 0) + 1e-6))

                loss1 = criterion(output_s1, target1)
                loss2 = criterion(output_s2, target1)
                all_loss = loss1 + loss2 + 0.01 * entropy_loss
                all_loss.backward()
                optimizer_g.step()
                optimizer_f.step()

                # Step B train classifier to maximize discrepancy
                optimizer_g.zero_grad()
                optimizer_f.zero_grad()

                output = G(data)
                output1 = F1(output)
                output2 = F2(output)
                output_s1 = output1[:batch_size, :]
                output_s2 = output2[:batch_size, :]
                output_t1 = output1[batch_size:, :]
                output_t2 = output2[batch_size:, :]
                output_t1 = F.softmax(output_t1, dim=1)
                output_t2 = F.softmax(output_t2, dim=1)
                loss1 = criterion(output_s1, target1)
                loss2 = criterion(output_s2, target1)
                entropy_loss = -torch.mean(
                    torch.log(torch.mean(output_t1, 0) + 1e-6))
                entropy_loss -= torch.mean(
                    torch.log(torch.mean(output_t2, 0) + 1e-6))
                loss_dis = torch.mean(torch.abs(output_t1 - output_t2))
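                # loss_dis is the L1 discrepancy between the two classifiers' target
                # predictions (maximized in Step B, minimized by G in Step C).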
                F_loss = loss1 + loss2 - eta * loss_dis + 0.01 * entropy_loss
                F_loss.backward()
                optimizer_f.step()
                # Step C train generator to minimize discrepancy
                for i in range(num_k):
                    optimizer_g.zero_grad()
                    output = G(data)
                    output1 = F1(output)
                    output2 = F2(output)

                    output_s1 = output1[:batch_size, :]
                    output_s2 = output2[:batch_size, :]
                    output_t1 = output1[batch_size:, :]
                    output_t2 = output2[batch_size:, :]

                    loss1 = criterion(output_s1, target1)
                    loss2 = criterion(output_s2, target1)
                    output_t1 = F.softmax(output_t1, dim=1)
                    output_t2 = F.softmax(output_t2, dim=1)
                    loss_dis = torch.mean(torch.abs(output_t1 - output_t2))
                    entropy_loss = -torch.mean(
                        torch.log(torch.mean(output_t1, 0) + 1e-6))
                    entropy_loss -= torch.mean(
                        torch.log(torch.mean(output_t2, 0) + 1e-6))

                    loss_dis.backward()
                    optimizer_g.step()
                if batch_idx % args.log_interval == 0:
                    print(
                        'Train Ep: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\tLoss2: {:.6f}\t Dis: {:.6f} Entropy: {:.6f}'
                        .format(ep, batch_idx * len(data), 70000,
                                100. * batch_idx / 70000, loss1.item(),
                                loss2.item(), loss_dis.item(),
                                entropy_loss.item()))
                if batch_idx == 1 and ep > 1:
                    test(ep)
                    G.train()
                    F1.train()
                    F2.train()
    else:
        G_load = ResBase(option)
        F1_load = ResClassifier(num_layer=num_layer)
        F2_load = ResClassifier(num_layer=num_layer)

        F1_load.apply(weights_init)
        F2_load.apply(weights_init)
        G_path = args.load_network_path + 'G.pth'
        F1_path = args.load_network_path + 'F1.pth'
        F2_path = args.load_network_path + 'F2.pth'
        G_load.load_state_dict(torch.load(G_path))
        F1_load.load_state_dict(torch.load(F1_path))
        F2_load.load_state_dict(torch.load(F2_path))
        #
        # G_load = torch.load('whole_model_G.pth')
        # F1_load = torch.load('whole_model_F1.pth')
        # F2_load = torch.load('whole_model_F2.pth')
        if cuda:
            G_load.cuda()
            F1_load.cuda()
            F2_load.cuda()
        G_load.eval()
        F1_load.eval()
        F2_load.eval()
        test_loss = 0
        correct = 0
        correct2 = 0
        size = 0

        val = False
        for batch_idx, data in enumerate(dataset_test):
            if batch_idx * batch_size > 5000:
                break
            data2 = data['T']
            target2 = data['T_label']
            if val:
                data2 = data['S']
                target2 = data['S_label']
            if args.cuda:
                data2, target2 = data2.cuda(), target2.cuda()
            data1, target1 = Variable(data2, volatile=True), Variable(target2)
            output = G_load(data1)
            output1 = F1_load(output)
            output2 = F2_load(output)
            # print("Feature: {}\n Predict_value: {}".format(output, output2))
            test_loss += F.nll_loss(output1, target1).item()
            pred = output1.data.max(1)[
                1]  # get the index of the max log-probability
            correct += pred.eq(target1.data).cpu().sum()
            pred = output2.data.max(1)[
                1]  # get the index of the max log-probability
            k = target1.data.size()[0]
            correct2 += pred.eq(target1.data).cpu().sum()

            size += k
        test_loss /= len(test_loader)  # loss function already averages over batch size
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) ({:.0f}%)\n'
            .format(test_loss, correct, size, 100. * correct / size,
                    100. * correct2 / size))
        acc = 100. * correct / size
        f.write('Accuracy is:' + str(acc) + '%' + '\n')
        value = max(100. * correct / size, 100. * correct2 / size)
        print("Value: {}".format(value))
        if args.cuda:
            use_gpu = True
        else:
            use_gpu = False
        # configure network path

        if (100. * correct / size) > (100. * correct2 / size):
            predict_network_path = F1_path
        else:
            predict_network_path = F2_path

        feature_network_path = G_path

        # configure the datapath for target and test
        source_path = 'chn_training_list.txt'
        target_path = 'chn_validation_list.txt'
        cls_source_list, cls_validation_list = sep.split_set(
            source_path, class_num)
        source_list = sep.dimension_rd(cls_source_list)

        if args.validation_method == 'Source_Risk':
            cv_loss = source_risk.cross_validation_loss(
                args, feature_network_path, predict_network_path, num_layer,
                cls_source_list, target_path, cls_validation_list, class_num,
                256, 224, batch_size, use_gpu)
        elif args.validation_method == 'Dev_icml':
            cv_loss = dev_icml.cross_validation_loss(
                args, feature_network_path, predict_network_path, num_layer,
                source_list, target_path, cls_validation_list, class_num, 256,
                224, batch_size, use_gpu)
        else:
            cv_loss = dev.cross_validation_loss(args, feature_network_path,
                                                predict_network_path,
                                                num_layer, source_list,
                                                target_path,
                                                cls_validation_list, class_num,
                                                256, 224, batch_size, use_gpu)
        print(cv_loss)
        f.write(args.validation_method + ' Validation is:' + str(cv_loss) +
                '\n')