Example #1
def cross_validation_loss(args, feature_network_path, predict_network_path,
                          num_layer, src_list, target_path, val_list,
                          class_num, resize_size, crop_size, batch_size,
                          use_gpu):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_cls_list:
    :param target_path:
    :param val_cls_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    option = 'resnet' + args.resnet
    G = ResBase(option)
    F1 = ResClassifier(num_layer=num_layer)

    G.load_state_dict(torch.load(feature_network_path))
    F1.load_state_dict(torch.load(predict_network_path))
    if use_gpu:
        G.cuda()
        F1.cuda()
    G.eval()
    F1.eval()

    val_list = seperate_data.dimension_rd(val_list)

    tar_list = open(target_path).readlines()
    cross_val_loss = 0

    prep_dict = prep.image_train(resize_size=resize_size, crop_size=crop_size)
    # load the images for the different classes

    dsets_src = ImageList(src_list, transform=prep_dict)
    dset_loaders_src = util_data.DataLoader(dsets_src,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    # prepare source feature

    iter_src = iter(dset_loaders_src)
    src_input, src_labels = next(iter_src)
    if use_gpu:
        src_input, src_labels = Variable(src_input).cuda(), Variable(
            src_labels).cuda()
    else:
        src_input, src_labels = Variable(src_input), Variable(src_labels)
    feature_val = G(src_input)
    src_feature_de = feature_val.detach().cpu().numpy()

    for _ in range(len(dset_loaders_src) - 1):
        src_input, src_labels = next(iter_src)
        if use_gpu:
            src_input, src_labels = Variable(src_input).cuda(), Variable(
                src_labels).cuda()
        else:
            src_input, src_labels = Variable(src_input), Variable(src_labels)
        src_feature_new = G(src_input)
        src_feature_new_de = src_feature_new.detach().cpu().numpy()
        src_feature_de = np.append(src_feature_de, src_feature_new_de, axis=0)
    print("Created Source feature: {}".format(src_feature_de.shape))

    # prepare target feature

    dsets_tar = ImageList(tar_list, transform=prep_dict)
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)
    iter_tar = iter(dset_loaders_tar)
    tar_input, _ = next(iter_tar)
    tar_input = Variable(tar_input).cuda() if use_gpu else Variable(tar_input)
    tar_feature = G(tar_input)
    tar_feature_de = tar_feature.detach().cpu().numpy()
    for _ in range(len(dset_loaders_tar) - 1):
        tar_input, _ = next(iter_tar)
        tar_input = Variable(tar_input).cuda() if use_gpu else Variable(tar_input)
        tar_feature_new = G(tar_input)
        tar_feature_new_de = tar_feature_new.detach().cpu().numpy()
        tar_feature_de = np.append(tar_feature_de, tar_feature_new_de, axis=0)
    print("Created Target feature: {}".format(tar_feature_de.shape))

    # prepare validation feature

    dsets_val = ImageList(val_list, transform=prep_dict)
    dset_loaders_val = util_data.DataLoader(dsets_val,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)
    iter_val = iter(dset_loaders_val)
    val_input, val_labels = next(iter_val)
    if use_gpu:
        val_input, val_labels = Variable(val_input).cuda(), Variable(
            val_labels).cuda()
    else:
        val_input, val_labels = Variable(val_input), Variable(val_labels)
    val_feature = G(val_input)
    pred_label = F1(val_feature)
    val_feature_de = val_feature.detach().cpu().numpy()
    w = pred_label[0].shape[0]
    error = np.zeros(1)
    error[0] = predict_loss(val_labels[0].item(),
                            pred_label[0].reshape(1, w)).item()
    error = error.reshape(1, 1)
    print("Before the final")
    print(pred_label.shape)

    for num_image in range(1, len(pred_label)):
        new_error = np.zeros(1)
        single_pred_label = pred_label[num_image]
        w = single_pred_label.shape[0]
        single_val_label = val_labels[num_image]
        new_error[0] = predict_loss(single_val_label.item(),
                                    single_pred_label.reshape(1, w)).item()
        new_error = new_error.reshape(1, 1)
        error = np.append(error, new_error, axis=0)
    for _ in range(len(dset_loaders_val) - 1):
        val_input, val_labels = next(iter_val)
        if use_gpu:
            val_input, val_labels = Variable(val_input).cuda(), Variable(
                val_labels).cuda()
        else:
            val_input, val_labels = Variable(val_input), Variable(val_labels)
        val_feature_new = G(val_input)
        val_feature_new_de = val_feature_new.detach().cpu().numpy()
        val_feature_de = np.append(val_feature_de, val_feature_new_de, axis=0)
        pred_label = F1(val_feature_new)
        for num_image in range(len(pred_label)):
            new_error = np.zeros(1)
            single_pred_label = pred_label[num_image]
            w = single_pred_label.shape[0]
            single_val_label = val_labels[num_image]
            new_error[0] = predict_loss(single_val_label.item(),
                                        single_pred_label.reshape(1,
                                                                  w)).item()
            new_error = new_error.reshape(1, 1)
            error = np.append(error, new_error, axis=0)

    print("Created Validation error shape: {}".format(error.shape))
    print("Created Validation feature: {}".format(val_feature_de.shape))

    feature_dir = args.save.split("/")[0] + "/feature_np/"
    if not os.path.exists(feature_dir):
        os.makedirs(feature_dir)

    np.save(feature_dir + "src_feature_de.npy", src_feature_de)
    np.save(feature_dir + "tar_feature_de.npy", tar_feature_de)
    np.save(feature_dir + "val_feature_de.npy", val_feature_de)
    # these paths must match the filenames saved just above
    src_feature_path = feature_dir + "src_feature_de.npy"
    tar_feature_path = feature_dir + "tar_feature_de.npy"
    val_feature_path = feature_dir + "val_feature_de.npy"
    weight = get_weight(src_feature_path, tar_feature_path, val_feature_path)
    cross_val_loss = cross_val_loss + get_dev_risk(weight, error)

    return cross_val_loss
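
The three batch loops above all follow one pattern: iterate a DataLoader through the feature network and stack the detached outputs. A minimal sketch of that pattern in modern PyTorch (a hypothetical helper, assuming the model returns a 2-D feature tensor and the loader yields (inputs, labels) pairs):

import numpy as np
import torch

def extract_features(model, loader, use_gpu):
    """Run every batch through `model` and stack the features row-wise."""
    chunks = []
    with torch.no_grad():  # no autograd graph needed at evaluation time
        for inputs, _ in loader:
            if use_gpu:
                inputs = inputs.cuda()
            chunks.append(model(inputs).cpu().numpy())
    # one concatenation at the end avoids the O(n^2) cost of repeated np.append
    return np.concatenate(chunks, axis=0)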
Example #2
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]

    # separate the source and validation set
    cls_source_list, cls_validation_list = sep.split_set(
        data_config["source"]["list_path"],
        config["network"]["params"]["class_num"])
    source_list = sep.dimension_rd(cls_source_list)

    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=4, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                shuffle=False, num_workers=4) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    best_acc = 0.0
    best_model = 0

    ## set base network
    net_config = config["network"]
    if config["load_module"] == "":
        base_network = net_config["name"](**net_config["params"])
        base_network = base_network.cuda()

        ## add additional network for some methods
        if config["loss"]["random"]:
            random_layer = network.RandomLayer(
                [base_network.output_num(), class_num],
                config["loss"]["random_dim"])
            ad_net = network.AdversarialNetwork(config["loss"]["random_dim"],
                                                1024)
        else:
            random_layer = None
            ad_net = network.AdversarialNetwork(
                base_network.output_num() * class_num, 1024)
        if config["loss"]["random"]:
            random_layer.cuda()
        ad_net = ad_net.cuda()
        parameter_list = base_network.get_parameters() + ad_net.get_parameters()

        ## set optimizer
        optimizer_config = config["optimizer"]
        optimizer = optimizer_config["type"](parameter_list, \
                                             **(optimizer_config["optim_params"]))
        param_lr = []
        for param_group in optimizer.param_groups:
            param_lr.append(param_group["lr"])
        schedule_param = optimizer_config["lr_param"]
        lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

        gpus = config['gpu'].split(',')
        if len(gpus) > 1:
            ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
            base_network = nn.DataParallel(base_network,
                                           device_ids=[int(i) for i in gpus])

        ## train
        len_train_source = len(dset_loaders["source"])
        len_train_target = len(dset_loaders["target"])
        transfer_loss_value = classifier_loss_value = total_loss_value = 0.0

        for i in range(config["num_iterations"]):
            if i % config["test_interval"] == config["test_interval"] - 1:
                base_network.train(False)
                temp_acc = image_classification_test(dset_loaders, \
                                                     base_network, test_10crop=prep_config["test_10crop"])
                temp_model = nn.Sequential(base_network)
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    best_model = temp_model
                log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
                f.write('Accuracy is:' + log_str + '\n')  # f: assumed module-level log file
                config["out_file"].write(log_str + "\n")
                config["out_file"].flush()
                print(log_str)
            if i % config["snapshot_interval"] == 0:
                torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                                 "iter_{:05d}_model.pth".format(i)))

            loss_params = config["loss"]
            ## train one iter
            base_network.train(True)
            ad_net.train(True)
            optimizer = lr_scheduler(optimizer, i, **schedule_param)
            optimizer.zero_grad()
            if i % len_train_source == 0:
                iter_source = iter(dset_loaders["source"])  # restart source epoch
            if i % len_train_target == 0:
                iter_target = iter(dset_loaders["target"])  # restart target epoch
            inputs_source, labels_source = next(iter_source)
            inputs_target, labels_target = next(iter_target)
            inputs_source, inputs_target, labels_source = \
                inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
            features_source, outputs_source = base_network(inputs_source)
            features_target, outputs_target = base_network(inputs_target)
            features = torch.cat((features_source, features_target), dim=0)
            outputs = torch.cat((outputs_source, outputs_target), dim=0)
            softmax_out = nn.Softmax(dim=1)(outputs)
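            # CDAN conditions the domain discriminator on the outer product of
            # the features and softmax predictions; the "+E" variant further
            # weights samples by prediction entropy (Long et al., 2018).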
            if config['method'] == 'CDAN+E':
                entropy = loss.Entropy(softmax_out)
                transfer_loss = loss.CDAN([features, softmax_out], ad_net,
                                          entropy, network.calc_coeff(i),
                                          random_layer)
            elif config['method'] == 'CDAN':
                transfer_loss = loss.CDAN([features, softmax_out], ad_net,
                                          None, None, random_layer)
            elif config['method'] == 'DANN':
                transfer_loss = loss.DANN(features, ad_net)
            else:
                raise ValueError('Method cannot be recognized.')
            classifier_loss = nn.CrossEntropyLoss()(outputs_source,
                                                    labels_source)
            total_loss = loss_params[
                "trade_off"] * transfer_loss + classifier_loss
            total_loss.backward()
            optimizer.step()
        torch.save(best_model, osp.join(config["output_path"],
                                        "best_model.pth"))
    else:
        base_network = torch.load(config["load_module"])
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            base_network = base_network.cuda()
        base_network.train(False)
        temp_acc = image_classification_test(
            dset_loaders, base_network, test_10crop=prep_config["test_10crop"])
        temp_model = nn.Sequential(base_network)
        if temp_acc > best_acc:
            best_acc = temp_acc
            best_model = temp_model
        log_str = "iter: {:05d}, precision: {:.5f}".format(
            config["test_interval"], temp_acc)
        config["out_file"].write(log_str + "\n")
        config["out_file"].flush()
        print(log_str)
        if config["val_method"] == "Source_Risk":
            cv_loss = source_risk.cross_validation_loss(
                base_network, base_network, cls_source_list,
                data_config["target"]["list_path"], cls_validation_list,
                class_num, prep_config["params"]["resize_size"],
                prep_config["params"]["crop_size"],
                data_config["target"]["batch_size"], use_gpu)
        elif config["val_method"] == "Dev_icml":
            cv_loss = dev_icml.cross_validation_loss(
                config["load_module"], config["load_module"], source_list,
                data_config["target"]["list_path"], cls_validation_list,
                class_num, prep_config["params"]["resize_size"],
                prep_config["params"]["crop_size"],
                data_config["target"]["batch_size"], use_gpu)
        elif config["val_method"] == "Dev":
            cv_loss = dev.cross_validation_loss(
                config["load_module"], config["load_module"], cls_source_list,
                data_config["target"]["list_path"], cls_validation_list,
                class_num, prep_config["params"]["resize_size"],
                prep_config["params"]["crop_size"],
                data_config["target"]["batch_size"], use_gpu)
        print(cv_loss)
        f.write(config["val_method"] + ' Validation is:' + str(cv_loss) + '\n')

    return best_acc
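
The `if i % len_train_source == 0` re-initialization in the loop above is the standard trick for drawing endlessly from a finite DataLoader. The same idea as an explicit generator (a sketch, not part of the original code):

def forever(loader):
    """Yield batches endlessly, restarting the DataLoader after each epoch."""
    while True:
        for batch in loader:
            yield batch

# usage sketch:
# src_iter = forever(dset_loaders["source"])
# inputs_source, labels_source = next(src_iter)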
Example #3
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])

    ## set loss
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    loss_params = config["loss"]

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]

    # separate the source and validation set
    cls_source_list, cls_validation_list = sep.split_set(
        data_config["source"]["list_path"],
        config["network"]["params"]["class_num"])
    source_list = sep.dimension_rd(cls_source_list)

    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=4)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=4)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"+str(i)] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

            dsets["target"+str(i)] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

        dsets["target_test"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=4)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.bottleneck.parameters(), "lr":10}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr":1}, \
                            {"params":base_network.fc.parameters(), "lr":10}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr": 1}]

    ## add additional network for some methods
    class_weight = torch.from_numpy(np.array([1.0] * class_num))
    if use_gpu:
        class_weight = class_weight.cuda()
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(
        high_value=config["high"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr": 10})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model = 0
    for i in range(config["num_iterations"]):
        if (i + 1) % config["test_interval"] == 0:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"], \
                gpu=use_gpu)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            print(log_str)
        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                "iter_{:05d}_model.pth.tar".format(i)))

        # if i % loss_params["update_iter"] == loss_params["update_iter"] - 1:
        #     base_network.train(False)
        #     target_fc8_out = image_classification_predict(dset_loaders, base_network, softmax_param=config["softmax_param"])
        #     class_weight = torch.mean(target_fc8_out, 0)
        #     class_weight = (class_weight / torch.mean(class_weight)).cuda().view(-1)
        #     class_criterion = nn.CrossEntropyLoss(weight = class_weight)

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        features, outputs = base_network(inputs)

        # #
        # if i % 100 == 0:
        #     check = dev.get_label_list(open(data_config["source"]["list_path"]).readlines(),
        #                                base_network,
        #                                prep_config["resize_size"],
        #                                prep_config["crop_size"],
        #                                data_config["target"]["batch_size"],
        #                                use_gpu)
        #     f = open("Class_result.txt", "a+")
        #     f.close()
        #     for cls in range(class_num):
        #         count = 0
        #         for j in check:
        #             if int(j.split(" ")[1].replace("\n", "")) == cls:
        #                 count = count + 1
        #         f = open("Source_result.txt", "a+")
        #         f.write("Source_Class: " + str(cls) + "\n" + "Number of images: " + str(count) + "\n")
        #         f.close()
        #
        #     check = dev.get_label_list(open(data_config["target"]["list_path"]).readlines(),
        #                                base_network,
        #                                prep_config["resize_size"],
        #                                prep_config["crop_size"],
        #                                data_config["target"]["batch_size"],
        #                                use_gpu)
        #     f = open("Class_result.txt", "a+")
        #     f.write("Iteration: " + str(i) + "\n")
        #     f.close()
        #     for cls in range(class_num):
        #         count = 0
        #         for j in check:
        #             if int(j.split(" ")[1].replace("\n", "")) == cls:
        #                 count = count + 1
        #         f = open("Class_result.txt", "a+")
        #         f.write("Target_Class: " + str(cls) + "\n" + "Number of images: " + str(count) + "\n")
        #         f.close()

        #
        # #
        # print("Training test:")
        # print(features)
        # print(features.shape)
        # print(outputs)
        # print(outputs.shape)

        softmax_out = nn.Softmax(dim=1)(outputs).detach()
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        # label_numpy = labels_source.data.cpu().numpy()
        # for j in range(int(inputs.size(0) / 2)):
        #     weight_ad[j] = class_weight[int(label_numpy[j])]
        # weight_ad = weight_ad / torch.max(weight_ad[0:int(inputs.size(0)/2)])
        # for j in range(int(inputs.size(0) / 2), inputs.size(0)):
        #     weight_ad[j] = 1.0
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                           weight_ad, use_gpu)

        classifier_loss = class_criterion(
            outputs.narrow(0, 0, int(inputs.size(0) / 2)), labels_source)

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss
        total_loss.backward()
        optimizer.step()
    cv_loss = dev.cross_validation_loss(
        base_network, base_network, cls_source_list,
        data_config["target"]["list_path"], cls_validation_list, class_num,
        prep_config["resize_size"], prep_config["crop_size"],
        data_config["target"]["batch_size"], use_gpu)
    print(cv_loss)
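
`lr_schedule.schedule_dict` is not shown in this listing. A typical "inv" schedule compatible with the `lr_scheduler(param_lr, optimizer, i, **schedule_param)` call above would look like this sketch (hypothetical defaults):

def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.001, power=0.75,
                     init_lr=0.001):
    """Decay the base lr by (1 + gamma * iter)^(-power), keeping group ratios."""
    lr = init_lr * (1 + gamma * iter_num) ** (-power)
    for i, group in enumerate(optimizer.param_groups):
        group["lr"] = lr * param_lr[i]  # param_lr holds each group's multiplier
    return optimizer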
Example #4
def cross_validation_loss(args, feature_network_path, predict_network_path,
                          num_layer, src_list, target_path, val_list,
                          class_num, resize_size, crop_size, batch_size,
                          use_gpu):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_cls_list:
    :param target_path:
    :param val_cls_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    option = 'resnet' + args.resnet
    G = ResBase(option)
    F1 = ResClassifier(num_layer=num_layer)

    G.load_state_dict(torch.load(feature_network_path))
    F1.load_state_dict(torch.load(predict_network_path))
    if use_gpu:
        G.cuda()
        F1.cuda()
    G.eval()
    F1.eval()

    val_list = seperate_data.dimension_rd(val_list)

    tar_list = open(target_path).readlines()
    cross_val_loss = 0

    prep_dict = prep.image_train(resize_size=resize_size, crop_size=crop_size)
    # load the images for the different classes

    dsets_src = ImageList(src_list, transform=prep_dict)
    dset_loaders_src = util_data.DataLoader(dsets_src,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    # prepare source feature
    iter_src = iter(dset_loaders_src)
    src_input, src_labels = next(iter_src)
    if use_gpu:
        src_input, src_labels = Variable(src_input).cuda(), Variable(
            src_labels).cuda()
    else:
        src_input, src_labels = Variable(src_input), Variable(src_labels)
    # src_feature, _ = feature_network(src_input)

    feature_val = G(src_input)
    src_feature_de = feature_val.detach().cpu().numpy()

    for _ in range(len(dset_loaders_src) - 1):
        src_input, src_labels = next(iter_src)
        if use_gpu:
            src_input, src_labels = Variable(src_input).cuda(), Variable(
                src_labels).cuda()
        else:
            src_input, src_labels = Variable(src_input), Variable(src_labels)
        # src_feature_new, _ = feature_network(src_input)
        # print("Src_input: {}".format(src_input))
        # print("Src_shape: {}".format(src_input.shape))
        src_feature_new = G(src_input)
        src_feature_new_de = src_feature_new.detach().cpu().numpy()
        src_feature_de = np.append(src_feature_de, src_feature_new_de, axis=0)
        # src_feature = torch.cat((src_feature, src_feature_new), 0)
    print("Pass Source")

    # prepare target feature

    dsets_tar = ImageList(tar_list, transform=prep_dict)
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    iter_tar = iter(dset_loaders_tar)
    tar_input, _ = next(iter_tar)
    tar_input = Variable(tar_input).cuda() if use_gpu else Variable(tar_input)
    # tar_feature, _ = feature_network(tar_input)
    tar_feature = G(tar_input)
    tar_feature_de = tar_feature.detach().cpu().numpy()
    for _ in range(len(dset_loaders_tar) - 1):
        tar_input, _ = next(iter_tar)
        tar_input = Variable(tar_input).cuda() if use_gpu else Variable(tar_input)
        # tar_feature_new, _ = feature_network(tar_input)
        tar_feature_new = G(tar_input)
        tar_feature_new_de = tar_feature_new.detach().cpu().numpy()
        tar_feature_de = np.append(tar_feature_de, tar_feature_new_de, axis=0)
        # tar_feature = torch.cat((tar_feature, tar_feature_new), 0)

    print("Pass Target")
    # prepare validation feature and predicted label for validation

    dsets_val = ImageList(val_list, transform=prep_dict)
    dset_loaders_val = util_data.DataLoader(dsets_val,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    iter_val = iter(dset_loaders_val)
    val_input, val_labels = next(iter_val)
    if use_gpu:
        val_input, val_labels = Variable(val_input).cuda(), Variable(
            val_labels).cuda()
    else:
        val_input, val_labels = Variable(val_input), Variable(val_labels)
    # val_feature, _ = feature_network(val_input)
    # _, pred_label = predict_network(val_input)
    val_feature = G(val_input)
    pred_label = F1(val_feature)
    val_feature_de = val_feature.detach().cpu().numpy()

    w = pred_label[0].shape[0]
    error = np.zeros(1)

    error[0] = predict_loss(val_labels[0].item(),
                            pred_label[0].reshape(1, w)).item()
    error = error.reshape(1, 1)
    print("Before the final")
    print(pred_label.shape)
    print(len(val_feature_de))
    for num_image in range(1, len(pred_label)):
        new_error = np.zeros(1)
        single_pred_label = pred_label[num_image]
        w = single_pred_label.shape[0]
        single_val_label = val_labels[num_image]
        new_error[0] = predict_loss(single_val_label.item(),
                                    single_pred_label.reshape(1, w)).item()
        new_error = new_error.reshape(1, 1)
        error = np.append(error, new_error, axis=0)

    for _ in range(len(dset_loaders_val) - 1):
        val_input, val_labels = next(iter_val)
        if use_gpu:
            val_input, val_labels = Variable(val_input).cuda(), Variable(
                val_labels).cuda()
        else:
            val_input, val_labels = Variable(val_input), Variable(val_labels)
        # val_feature_new, _ = feature_network(val_input)
        val_feature_new = G(val_input)

        val_feature_new_de = val_feature_new.detach().cpu().numpy()
        val_feature_de = np.append(val_feature_de, val_feature_new_de, axis=0)
        # val_feature = torch.cat((val_feature, val_feature_new), 0)
        # _, pred_label = predict_network(val_input)

        pred_label = F1(val_feature_new)
        for num_image in range(len(pred_label)):
            new_error = np.zeros(1)
            single_pred_label = pred_label[num_image]
            w = single_pred_label.shape[0]
            single_val_label = val_labels[num_image]
            new_error[0] = predict_loss(single_val_label.item(),
                                        single_pred_label.reshape(1,
                                                                  w)).item()
            new_error = new_error.reshape(1, 1)
            error = np.append(error, new_error, axis=0)
    print("Pass validation")
    #     print("Insides the for loop")
    #     print(len(error))
    #     print(len(val_feature_de))
    #
    # print("Input for scrore calculation: ")
    # print(len(error))
    # print(len(val_feature_de))
    print(src_feature_de)
    print(tar_feature_de)
    print(val_feature_de)
    weight = get_weight(src_feature_de, tar_feature_de, val_feature_de)
    print(weight)
    print(error)
    cross_val_loss = cross_val_loss + get_dev_risk(weight, error)
    print(cross_val_loss)

    return cross_val_loss
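
`get_weight` is defined elsewhere in this codebase. In DEV-style validation it estimates an importance weight for each validation point by discriminating source from target features; a minimal sketch of that idea (using a logistic-regression discriminator, not necessarily the authors' implementation):

import numpy as np
from sklearn.linear_model import LogisticRegression

def get_weight_sketch(src_feature, tar_feature, val_feature):
    """Estimate w(x) ~ p_target(x) / p_source(x) on the validation features."""
    X = np.concatenate([src_feature, tar_feature], axis=0)
    y = np.concatenate([np.zeros(len(src_feature)), np.ones(len(tar_feature))])
    clf = LogisticRegression(max_iter=1000).fit(X, y)  # 0 = source, 1 = target
    p_tar = clf.predict_proba(val_feature)[:, 1]
    ratio = len(src_feature) / len(tar_feature)  # correct for set-size imbalance
    return (ratio * p_tar / (1.0 - p_tar)).reshape(-1, 1)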
Example #5
def cross_validation_loss(feature_network_path, predict_network_path, src_list,
                          target_path, val_list, class_num, resize_size,
                          crop_size, batch_size, use_gpu, opt):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_cls_list:
    :param target_path:
    :param val_cls_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    netF = models._netF(opt)
    netC = models._netC(opt, class_num)
    netF.load_state_dict(torch.load(feature_network_path))
    netC.load_state_dict(torch.load(predict_network_path))
    if use_gpu:
        netF.cuda()
        netC.cuda()

    val_list = seperate_data.dimension_rd(val_list)

    tar_list = open(target_path).readlines()
    cross_val_loss = 0

    mean = np.array([0.44, 0.44, 0.44])
    std = np.array([0.19, 0.19, 0.19])
    transform_target = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # prep_dict = prep.image_train(resize_size=resize_size, crop_size=crop_size)
    # load the images for the different classes

    dsets_src = ImageList(src_list, transform=transform_target)
    dset_loaders_src = util_data.DataLoader(dsets_src,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=2)
    dsets_val = ImageList(val_list, transform=transform_target)
    dset_loaders_val = util_data.DataLoader(dsets_val,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=2)
    dsets_tar = ImageList(tar_list, transform=transform_target)
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=2)

    # prepare source feature
    iter_src = iter(dset_loaders_src)
    src_input, src_labels = next(iter_src)
    if use_gpu:
        src_input, src_labels = Variable(src_input).cuda(), Variable(
            src_labels).cuda()
    else:
        src_input, src_labels = Variable(src_input), Variable(src_labels)
    # src_feature, _ = feature_network(src_input)
    src_feature = netF(src_input)
    src_feature_de = src_feature.detach().cpu().numpy()
    for _ in range(len(dset_loaders_src) - 1):
        src_input, src_labels = next(iter_src)
        if use_gpu:
            src_input, src_labels = Variable(src_input).cuda(), Variable(
                src_labels).cuda()
        else:
            src_input, src_labels = Variable(src_input), Variable(src_labels)
        # src_feature_new, _ = feature_network(src_input)
        src_feature_new = netF(src_input)
        src_feature_new_de = src_feature_new.detach().cpu().numpy()
        src_feature_de = np.append(src_feature_de, src_feature_new_de, axis=0)
        # src_feature = torch.cat((src_feature, src_feature_new), 0)

    # prepare target feature
    iter_tar = iter(dset_loaders_tar)
    tar_input, _ = next(iter_tar)
    tar_input = Variable(tar_input).cuda() if use_gpu else Variable(tar_input)
    # tar_feature, _ = feature_network(tar_input)
    tar_feature = netF(tar_input)
    tar_feature_de = tar_feature.detach().cpu().numpy()
    k = 0
    for _ in range(len(dset_loaders_tar) - 1):
        k = k + 1
        print(k)
        tar_input, _ = next(iter_tar)
        tar_input = Variable(tar_input).cuda() if use_gpu else Variable(tar_input)
        # tar_feature_new, _ = feature_network(tar_input)
        tar_feature_new = netF(tar_input)
        tar_feature_new_de = tar_feature_new.detach().cpu().numpy()
        tar_feature_de = np.append(tar_feature_de, tar_feature_new_de, axis=0)
        # tar_feature = torch.cat((tar_feature, tar_feature_new), 0)

    # prepare validation feature and predicted label for validation
    iter_val = iter(dset_loaders_val)
    val_input, val_labels = next(iter_val)
    if use_gpu:
        val_input, val_labels = Variable(val_input).cuda(), Variable(
            val_labels).cuda()
    else:
        val_input, val_labels = Variable(val_input), Variable(val_labels)
    # val_feature, _ = feature_network(val_input)
    # _, pred_label = predict_network(val_input)
    val_feature = netF(val_input)
    pred_label = netC(val_feature)  # reuse the features computed above
    val_feature_de = val_feature.detach().cpu().numpy()

    w = pred_label[0].shape[0]
    error = np.zeros(1)

    error[0] = predict_loss(val_labels[0].item(),
                            pred_label[0].reshape(1, w)).item()
    error = error.reshape(1, 1)
    print("Before the final")
    print(pred_label.shape)
    print(len(val_feature_de))
    for num_image in range(1, len(pred_label)):
        single_pred_label = pred_label[num_image]
        w = single_pred_label.shape[0]
        single_val_label = val_labels[num_image]
        error = np.append(error, [[
            predict_loss(single_val_label.item(),
                         single_pred_label.reshape(1, w)).item()
        ]],
                          axis=0)

    for _ in range(len(dset_loaders_val) - 1):
        val_input, val_labels = next(iter_val)
        if use_gpu:
            val_input, val_labels = Variable(val_input).cuda(), Variable(
                val_labels).cuda()
        else:
            val_input, val_labels = Variable(val_input), Variable(val_labels)
        # val_feature_new, _ = feature_network(val_input)
        val_feature_new = netF(val_input)
        val_feature_new_de = val_feature_new.detach().cpu().numpy()
        val_feature_de = np.append(val_feature_de, val_feature_new_de, axis=0)
        # val_feature = torch.cat((val_feature, val_feature_new), 0)
        # _, pred_label = predict_network(val_input)
        pred_label = netC(val_feature_new)  # reuse the features computed above
        for num_image in range(len(pred_label)):
            single_pred_label = pred_label[num_image]
            w = single_pred_label.shape[0]
            single_val_label = val_labels[num_image]
            error = np.append(error, [[
                predict_loss(single_val_label.item(),
                             single_pred_label.reshape(1, w)).item()
            ]],
                              axis=0)
    #     print("Insides the for loop")
    #     print(len(error))
    #     print(len(val_feature_de))
    #
    # print("Input for scrore calculation: ")
    # print(len(error))
    # print(len(val_feature_de))
    weight = get_weight(src_feature_de, tar_feature_de, val_feature_de)
    cross_val_loss = cross_val_loss + get_dev_risk(weight, error)

    return cross_val_loss
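
`get_dev_risk` is likewise defined elsewhere. The DEV estimator (You et al., ICML 2019) is the importance-weighted validation error plus a control variate that reduces its variance; a sketch under that reading:

import numpy as np

def get_dev_risk_sketch(weight, error):
    """weight, error: (N, 1) arrays -> variance-reduced weighted risk."""
    w, l = weight.ravel(), error.ravel()
    cov = np.cov(np.stack([w * l, w]))  # 2x2 sample covariance matrix
    eta = -cov[0, 1] / cov[1, 1]        # control-variate coefficient
    return float(np.mean(w * l) + eta * np.mean(w) - eta)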
Example #6
def cross_validation_loss(args, feature_network_path, predict_network_path, num_layer, src_list, target_path, val_list, class_num,
                          resize_size, crop_size, batch_size, use_gpu):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_cls_list:
    :param target_path:
    :param val_cls_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    target_list_no_label = open(target_path).readlines()
    cross_val_loss = 0
    save_path = args.save
    # load network
    option = 'resnet' + args.resnet
    G = ResBase(option)
    F1 = ResClassifier(num_layer=num_layer)
    G.load_state_dict(torch.load(feature_network_path))
    F1.load_state_dict(torch.load(predict_network_path))
    if use_gpu:
        G.cuda()
        F1.cuda()
    G.eval()
    F1.eval()
    print("Loaded network")
    # add pseudo-labels to the target data
    print("Separating target data")
    tar_list = []
    dsets_tar = ImageList(target_list_no_label,
                          transform=prep.image_train(resize_size=resize_size, crop_size=crop_size))
    dset_loaders_tar = util_data.DataLoader(dsets_tar, batch_size=batch_size, shuffle=False, num_workers=4)
    len_train_target = len(dset_loaders_tar)
    iter_target = iter(dset_loaders_tar)
    count = 0
    for i in range(len_train_target):
        input_tar, label_tar = next(iter_target)
        if use_gpu:
            input_tar, label_tar = Variable(input_tar).cuda(), Variable(label_tar).cuda()
        else:
            input_tar, label_tar = Variable(input_tar), Variable(label_tar)
        tar_feature = G(input_tar)
        predict_score = F1(tar_feature)
        _, pre_lab = torch.max(predict_score, 1)
        predict_label = pre_lab.detach()
        for num in range(len(predict_label.cpu())):
            # strip the original label (one or two digits) from the
            # "path label\n" line, then append the predicted pseudo-label
            if target_list_no_label[count][-3] == ' ':
                ind = -2
            else:
                ind = -3
            tar_list.append(target_list_no_label[count][:ind])
            tar_list[count] = tar_list[count] + str(predict_label[num].cpu().numpy()) + "\n"
            count += 1
    val_list = seperate_data.dimension_rd(val_list)
    print("Seperated")
    # load the dataloader for whole data
    prep_dict = prep.image_train(resize_size=resize_size, crop_size=crop_size)
    dsets_src = ImageList(src_list, transform=prep_dict)
    dset_loaders_src = util_data.DataLoader(dsets_src, batch_size=batch_size, shuffle=True, num_workers=4)
    dsets_val = ImageList(val_list, transform=prep_dict)
    dset_loaders_val = util_data.DataLoader(dsets_val, batch_size=batch_size, shuffle=True, num_workers=4)
    dsets_tar = ImageList(tar_list, transform=prep_dict)
    dset_loaders_tar = util_data.DataLoader(dsets_tar, batch_size=batch_size, shuffle=True, num_workers=4)

    # iterate through different classes
    for cls in range(class_num):
        # prepare source feature
        count_src = 0
        src_feature_de = np.array([])
        iter_src = iter(dset_loaders_src)
        while src_feature_de.size == 0:
            src_input, src_labels = next(iter_src)
            for i in range(len(src_labels)):
                if src_labels[i].item() == cls:
                    a, b, c = src_input[i].shape
                    if use_gpu:
                        src_pre_input = Variable(src_input[i]).cuda()
                    else:
                        src_pre_input = Variable(src_input[i])
                    src_input_final = src_pre_input.reshape(1, a, b, c)
                    if src_feature_de.size == 0:
                        src_feature = G(src_input_final)
                        src_feature_de = src_feature.detach().cpu().numpy()
                    else:
                        src_feature_new = G(src_input_final)
                        src_feature_new_de = src_feature_new.detach().cpu().numpy()
                        src_feature_de = np.append(src_feature_de, src_feature_new_de, axis=0)
            count_src = count_src + 1
        for _ in range(len(dset_loaders_src) - count_src):
            src_input, src_labels = next(iter_src)
            for i in range(len(src_labels)):
                if src_labels[i].item() == cls:
                    a, b, c = src_input[i].shape
                    if use_gpu:
                        src_pre_input = Variable(src_input[i]).cuda()
                    else:
                        src_pre_input = Variable(src_input[i])
                    src_input_final = src_pre_input.reshape(1, a, b, c)
                    src_feature_new = G(src_input_final)
                    src_feature_new_de = src_feature_new.detach().cpu().numpy()
                    src_feature_de = np.append(src_feature_de, src_feature_new_de, axis=0)
        print("Pass Source for Class{}".format(cls + 1))
        print("Created feature: {}".format(src_feature_de.shape))

        # prepare target feature
        count_tar = 0
        tar_feature_de = np.array([])
        iter_tar = iter(dset_loaders_tar)
        while tar_feature_de.size == 0:
            tar_input, tar_labels = next(iter_tar)
            for i in range(len(tar_labels)):
                if tar_labels[i].item() == cls:
                    a, b, c = tar_input[i].shape
                    if use_gpu:
                        tar_pre_input = Variable(tar_input[i]).cuda()
                    else:
                        tar_pre_input = Variable(tar_input[i])
                    tar_input_final = tar_pre_input.reshape(1, a, b, c)
                    tar_feature = G(tar_input_final)
                    if tar_feature_de.size == 0:
                        tar_feature_de = tar_feature.detach().cpu().numpy()
                    else:
                        tar_feature_new_de = tar_feature.detach().cpu().numpy()
                        tar_feature_de = np.append(tar_feature_de, tar_feature_new_de, axis=0)
            count_tar = count_tar + 1
        for _ in range(len(dset_loaders_tar) - count_tar):
            tar_input, tar_labels = next(iter_tar)
            for i in range(len(tar_labels)):
                if tar_labels[i].item() == cls:
                    a, b, c = tar_input[i].shape
                    if use_gpu:
                        tar_pre_input = Variable(tar_input[i]).cuda()
                    else:
                        tar_pre_input = Variable(tar_input[i])
                    tar_input_final = tar_pre_input.reshape(1, a, b, c)
                    tar_feature_new = G(tar_input_final)
                    tar_feature_new_de = tar_feature_new.detach().cpu().numpy()
                    tar_feature_de = np.append(tar_feature_de, tar_feature_new_de, axis=0)
        print("Pass Target for Class: {}".format(cls + 1))
        print("Created feature: {}".format(tar_feature_de.shape))
        # prepare validation feature and errors

        count_val = 0
        val_feature_de = np.array([])
        iter_val = iter(dset_loaders_val)
        while val_feature_de.size == 0:
            val_input, val_labels = next(iter_val)
            for i in range(len(val_labels)):
                if val_labels[i].item() == cls:
                    a, b, c = val_input[i].shape
                    if use_gpu:
                        val_pre_input, val_labels_final = Variable(val_input[i]).cuda(), Variable(val_labels[i]).cuda()
                    else:
                        val_pre_input, val_labels_final = Variable(val_input[i]), Variable(val_labels[i])
                    val_input_final = val_pre_input.reshape(1, a, b, c)
                    val_feature = G(val_input_final)
                    pred_label = F1(val_feature)
                    w = pred_label.shape[1]
                    if val_feature_de.size == 0:
                        # feature and error
                        val_feature_de = val_feature.detach().cpu().numpy()
                        error = np.zeros(1)
                        error[0] = predict_loss(val_labels_final.item(), pred_label.reshape(1, w)).item()
                        error = error.reshape(1, 1)
                        print(error)
                    else:
                        # feature and error
                        val_feature_new_de = val_feature.detach().cpu().numpy()
                        val_feature_de = np.append(val_feature_de, val_feature_new_de, axis=0)
                        new_error = np.zeros(1)
                        new_error = new_error.reshape(1, 1)
                        new_error[0] = predict_loss(val_labels_final.item(), pred_label.reshape(1, w)).item()
                        error = np.append(error, new_error, axis=0)
            count_val = count_val + 1
        for _ in range(len(dset_loaders_val) - count_val):
            val_input, val_labels = next(iter_val)
            for i in range(len(val_labels)):
                if val_labels[i].item() == cls:
                    a, b, c = val_input[i].shape
                    if use_gpu:
                        val_pre_input, val_labels_final = Variable(val_input[i]).cuda(), Variable(val_labels[i]).cuda()
                    else:
                        val_pre_input, val_labels_final = Variable(val_input[i]), Variable(val_labels[i])
                    val_input_final = val_pre_input.reshape(1, a, b, c)
                    val_feature = G(val_input_final)
                    pred_label = F1(val_feature)
                    w = pred_label.shape[1]
                    val_feature_new_de = val_feature.detach().cpu().numpy()
                    val_feature_de = np.append(val_feature_de, val_feature_new_de, axis=0)
                    new_error = np.zeros(1)
                    new_error = new_error.reshape(1, 1)
                    new_error[0] = predict_loss(val_labels_final.item(), pred_label.reshape(1, w)).item()
                    error = np.append(error, new_error, axis=0)
        print("Pass Validation for Class: {}".format(cls + 1))
        print("Created error shape: {}".format(error.shape))
        print("Created feature: {}".format(val_feature_de.shape))
        # calculating the weight and the score for each class

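        # The per-class features are round-tripped through .npy files and the
        # paths handed to get_weight; Example #10 below passes the in-memory
        # arrays to get_weight directly instead.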
        feature_dir = os.path.join(args.save.split("/")[0], "feature_np")
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        src_feature_path = os.path.join(feature_dir, str(cls) + "_src_feature_de.npy")
        tar_feature_path = os.path.join(feature_dir, str(cls) + "_tar_feature_de.npy")
        val_feature_path = os.path.join(feature_dir, str(cls) + "_val_feature_de.npy")
        np.save(src_feature_path, src_feature_de)
        np.save(tar_feature_path, tar_feature_de)
        np.save(val_feature_path, val_feature_de)
        weight = get_weight(src_feature_path, tar_feature_path, val_feature_path)
        cross_val_loss = cross_val_loss + get_dev_risk(weight, error) / class_num

    return cross_val_loss
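
A note on the two imported helpers: get_weight and get_dev_risk are not shown in these examples. For orientation, below is a minimal sketch of a density-ratio weight estimator in the spirit of Deep Embedded Validation; the name get_weight_sketch, the scikit-learn LogisticRegression discriminator, and the array-based interface are illustrative assumptions, not the actual helper (the example above passes .npy paths to get_weight, while Example #10 below passes arrays).

# Hypothetical sketch of get_weight: estimate importance weights
# w(x) = p_target(x) / p_source(x) for the validation features by
# training a source-vs-target domain discriminator.
import numpy as np
from sklearn.linear_model import LogisticRegression  # assumed dependency

def get_weight_sketch(source_feature, target_feature, validation_feature):
    N_s, N_t = len(source_feature), len(target_feature)
    # Label source features 1 and target features 0, then fit a discriminator.
    X = np.concatenate((source_feature, target_feature), axis=0)
    y = np.concatenate((np.ones(N_s), np.zeros(N_t)))
    clf = LogisticRegression(max_iter=1000).fit(X, y)
    # predict_proba columns follow sorted class labels: [:, 0] is P(target|x),
    # [:, 1] is P(source|x).
    prob = clf.predict_proba(validation_feature)
    p_target, p_source = prob[:, 0], prob[:, 1]
    # Density ratio with the sample-size correction N_s / N_t.
    return ((N_s / N_t) * p_target / p_source).reshape(-1, 1)
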
Example #7
def cross_validation_loss(args, feature_network_path, predict_network_path,
                          num_layer, src_cls_list, target_path, val_cls_list,
                          class_num, resize_size, crop_size, batch_size,
                          use_gpu):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_cls_list:
    :param target_path:
    :param val_cls_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    option = 'resnet' + args.resnet
    G = ResBase(option)
    F1 = ResClassifier(num_layer=num_layer)
    G.load_state_dict(torch.load(feature_network_path))
    F1.load_state_dict(torch.load(predict_network_path))

    if use_gpu:
        G.cuda()
        F1.cuda()
    G.eval()
    F1.eval()

    val_cls_list = seperate_data.dimension_rd(val_cls_list)
    prep_dict_val = prep.image_train(resize_size=resize_size,
                                     crop_size=crop_size)
    # load different class's image
    dsets_val = ImageList(val_cls_list, transform=prep_dict_val)
    dset_loaders_val = util_data.DataLoader(dsets_val,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    # prepare validation feature and predicted label for validation
    iter_val = iter(dset_loaders_val)
    val_input, val_labels = next(iter_val)
    if use_gpu:
        val_input, val_labels = Variable(val_input).cuda(), Variable(
            val_labels).cuda()
    else:
        val_input, val_labels = Variable(val_input), Variable(val_labels)

    feature_val = G(val_input)
    pred_label = F1(feature_val)
    w = pred_label[0].shape[0]
    error = np.zeros(1)
    error[0] = predict_loss(val_labels[0], pred_label[0].view(1, w)).item()
    error = error.reshape(1, 1)
    print("Error: {}".format(error[0]))
    for num_image in range(1, len(pred_label)):
        new_error = np.zeros(1)
        single_pred_label = pred_label[num_image]
        w = single_pred_label.shape[0]
        single_val_label = val_labels[num_image]
        new_error[0] = predict_loss(single_val_label,
                                    single_pred_label.view(1, w)).item()
        new_error = new_error.reshape(1, 1)
        error = np.append(error, new_error, axis=0)

    for _ in range(len(dset_loaders_val) - 1):
        val_input, val_labels = next(iter_val)
        if use_gpu:
            val_input, val_labels = Variable(val_input).cuda(), Variable(
                val_labels).cuda()
        else:
            val_input, val_labels = Variable(val_input), Variable(val_labels)
        feature_val = G(val_input)
        pred_label = F1(feature_val)
        # _, pred_label = predict_network(val_input)
        for num_image in range(len(pred_label)):
            new_error = np.zeros(1)

            single_pred_label = pred_label[num_image]
            w = single_pred_label.shape[0]
            single_val_label = val_labels[num_image]

            new_error[0] = predict_loss(single_val_label,
                                        single_pred_label.view(1, w)).item()
            new_error = new_error.reshape(1, 1)

            error = np.append(error, new_error, axis=0)
    print("Error: {}".format(error))
    cross_val_loss = error.sum()
    return cross_val_loss
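
Appending one (1, 1) row per sample with np.append copies the whole error array on every call. As a sketch only, the entire error-building section above could be replaced by the fragment below, reusing G, F1, dset_loaders_val, use_gpu and predict_loss from this example and assuming PyTorch >= 0.4 for torch.no_grad():

import numpy as np
import torch

# Sketch: collect one scalar loss per validation sample in a Python list,
# then build the (n, 1) error array in a single conversion.
errors = []
with torch.no_grad():  # inference only; replaces the Variable wrapping
    for val_input, val_labels in dset_loaders_val:
        if use_gpu:
            val_input, val_labels = val_input.cuda(), val_labels.cuda()
        pred_label = F1(G(val_input))  # (batch, class_num) scores
        for i in range(len(pred_label)):
            w = pred_label[i].shape[0]
            errors.append(predict_loss(val_labels[i],
                                       pred_label[i].view(1, w)).item())
error = np.asarray(errors).reshape(-1, 1)  # same shape as the np.append version
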
Example #8
def cross_validation_loss(feature_network, predict_network, src_cls_list,
                          target_path, val_cls_list, class_num, resize_size,
                          crop_size, batch_size, use_gpu):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_cls_list:
    :param target_path:
    :param val_cls_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    val_cls_list = seperate_data.dimension_rd(val_cls_list)
    prep_dict_val = prep.image_train(resize_size=resize_size,
                                     crop_size=crop_size)
    # load different class's image
    dsets_val = ImageList(val_cls_list, transform=prep_dict_val)
    dset_loaders_val = util_data.DataLoader(dsets_val,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    # prepare validation feature and predicted label for validation
    iter_val = iter(dset_loaders_val)
    val_input, val_labels = next(iter_val)
    if use_gpu:
        val_input, val_labels = Variable(val_input).cuda(), Variable(
            val_labels).cuda()
    else:
        val_input, val_labels = Variable(val_input), Variable(val_labels)

    _, pred_label = predict_network(val_input)

    w = pred_label[0].shape[0]

    error = np.zeros(1)
    error[0] = predict_loss(val_labels[0].item(),
                            pred_label[0].reshape(1, w)).item()
    error = error.reshape(1, 1)
    for num_image in range(1, len(pred_label)):
        single_pred_label = pred_label[num_image]
        w = single_pred_label.shape[0]
        single_val_label = val_labels[num_image]
        error = np.append(error, [[
            predict_loss(single_val_label.item(),
                         single_pred_label.reshape(1, w)).item()
        ]],
                          axis=0)

    for _ in range(len(dset_loaders_val) - 1):
        val_input, val_labels = next(iter_val)
        if use_gpu:
            val_input, val_labels = Variable(val_input).cuda(), Variable(
                val_labels).cuda()
        else:
            val_input, val_labels = Variable(val_input), Variable(val_labels)
        _, pred_label = predict_network(val_input)
        for num_image in range(len(pred_label)):
            single_pred_label = pred_label[num_image]
            w = single_pred_label.shape[0]
            single_val_label = val_labels[num_image]
            error = np.append(error, [[
                predict_loss(single_val_label.item(),
                             single_pred_label.reshape(1, w)).item()
            ]],
                              axis=0)

    cross_val_loss = error.sum()
    # for cls in range(class_num):
    #
    #     dsets_val = ImageList(val_cls_list[cls], transform=prep_dict_val)
    #     dset_loaders_val = util_data.DataLoader(dsets_val, batch_size=batch_size, shuffle=True, num_workers=4)
    #
    #     # prepare validation feature and predicted label for validation
    #     iter_val = iter(dset_loaders_val)
    #     val_input, val_labels = iter_val.next()
    #     if use_gpu:
    #         val_input, val_labels = Variable(val_input).cuda(), Variable(val_labels).cuda()
    #     else:
    #         val_input, val_labels = Variable(val_input), Variable(val_labels)
    #     val_feature, _ = feature_network(val_input)
    #     _, pred_label = predict_network(val_input)
    #     w, h = pred_label.shape
    #     error = np.zeros(1)
    #     error[0] = predict_loss(cls, pred_label.reshape(1, w*h)).numpy()
    #     error = error.reshape(1,1)
    #     for _ in range(len(val_cls_list[cls]) - 1):
    #         val_input, val_labels = iter_val.next()
    #         # val_feature1 = feature_network(val_input)
    #         val_feature_new, _ = feature_network(val_input)
    #         val_feature = np.append(val_feature, val_feature_new, axis=0)
    #         error = np.append(error, [[predict_loss(cls, predict_network(val_input)[1]).numpy()]], axis=0)
    #
    #     print('The class is {}\n'.format(cls))
    #
    #     cross_val_loss = cross_val_loss + error.sum()
    return cross_val_loss
Example #9
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot',
                        required=True,
                        help='path to source dataset')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batchSize',
                        type=int,
                        default=100,
                        help='input batch size')
    parser.add_argument(
        '--imageSize',
        type=int,
        default=32,
        help='the height / width of the input image to network')
    parser.add_argument('--nz',
                        type=int,
                        default=512,
                        help='size of the latent z vector')
    parser.add_argument(
        '--ngf',
        type=int,
        default=64,
        help='Number of filters to use in the generator network')
    parser.add_argument(
        '--ndf',
        type=int,
        default=64,
        help='Number of filters to use in the discriminator network')
    parser.add_argument('--gpu',
                        type=int,
                        default=1,
                        help='GPU to use, -1 for CPU training')
    parser.add_argument('--checkpoint_dir',
                        default='results/models',
                        help='folder to load model checkpoints from')
    parser.add_argument('--method',
                        default='GTA',
                        help='Method to evaluate| GTA, sourceonly')
    parser.add_argument(
        '--model_best',
        type=int,
        default=0,
        help=
        'Flag to specify whether to use the best validation model or last checkpoint| 1-model best, 0-current checkpoint'
    )
    parser.add_argument('--src_path',
                        type=str,
                        default='digits/server_svhn_list.txt',
                        help='path for source dataset txt file')
    parser.add_argument('--tar_path',
                        type=str,
                        default='digits/server_mnist_list.txt',
                        help='path for target dataset txt file')
    parser.add_argument('--val_method',
                        type=str,
                        default='Source_Risk',
                        choices=['Source_Risk', 'Dev_icml', 'Dev'])
    opt = parser.parse_args()

    # GPU/CPU flags
    cudnn.benchmark = True
    if torch.cuda.is_available() and opt.gpu == -1:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --gpu [gpu id]"
        )
    if opt.gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)
        use_gpu = True
    else:
        use_gpu = False

    # Creating data loaders
    mean = np.array([0.44, 0.44, 0.44])
    std = np.array([0.19, 0.19, 0.19])

    if 'svhn' in opt.src_path and 'mnist' in opt.tar_path:
        test_adaptation = 's->m'
    elif 'usps' in opt.src_path and 'mnist' in opt.tar_path:
        test_adaptation = 'u->m'
    else:
        test_adaptation = 'm->u'

    if test_adaptation == 'u->m' or test_adaptation == 's->m':
        target_root = os.path.join(opt.dataroot, 'mnist/trainset')

        transform_target = transforms.Compose([
            transforms.Resize(opt.imageSize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        target_test = dset.ImageFolder(root=target_root,
                                       transform=transform_target)
        nclasses = len(target_test.classes)
    elif test_adaptation == 'm->u':
        transform_usps = transforms.Compose([
            transforms.Resize(opt.imageSize),
            transforms.Grayscale(3),
            transforms.ToTensor(),
            transforms.Normalize((0.44, ), (0.19, ))
        ])
        target_test = usps.USPS(root=opt.dataroot,
                                train=True,
                                transform=transform_usps,
                                download=True)
        nclasses = 10
    # target_root = os.path.join(opt.dataroot, 'mnist/trainset')
    #
    # transform_target = transforms.Compose([transforms.Resize(opt.imageSize), transforms.ToTensor(), transforms.Normalize(mean,std)])
    # target_test = dset.ImageFolder(root=target_root, transform=transform_target)
    targetloader = torch.utils.data.DataLoader(target_test,
                                               batch_size=opt.batchSize,
                                               shuffle=False,
                                               num_workers=2)

    # Creating and loading models

    netF = models._netF(opt)
    netC = models._netC(opt, nclasses)

    if opt.method == 'GTA':
        if opt.model_best == 0:
            netF_path = os.path.join(opt.checkpoint_dir, 'netF.pth')
            netC_path = os.path.join(opt.checkpoint_dir, 'netC.pth')
        else:
            netF_path = os.path.join(opt.checkpoint_dir, 'model_best_netF.pth')
            netC_path = os.path.join(opt.checkpoint_dir, 'model_best_netC.pth')

    elif opt.method == 'sourceonly':
        if opt.model_best == 0:
            netF_path = os.path.join(opt.checkpoint_dir, 'netF_sourceonly.pth')
            netC_path = os.path.join(opt.checkpoint_dir, 'netC_sourceonly.pth')
        else:
            netF_path = os.path.join(opt.checkpoint_dir,
                                     'model_best_netF_sourceonly.pth')
            netC_path = os.path.join(opt.checkpoint_dir,
                                     'model_best_netC_sourceonly.pth')
    else:
        raise ValueError('method argument should be sourceonly or GTA')

    netF.load_state_dict(torch.load(netF_path))
    netC.load_state_dict(torch.load(netC_path))

    if opt.gpu >= 0:
        netF.cuda()
        netC.cuda()

    # Testing

    netF.eval()
    netC.eval()

    total = 0
    correct = 0

    for i, datas in enumerate(targetloader):
        inputs, labels = datas
        if opt.gpu >= 0:
            inputs, labels = inputs.cuda(), labels.cuda()
        inputv, labelv = Variable(inputs, volatile=True), Variable(labels)

        outC = netC(netF(inputv))
        _, predicted = torch.max(outC.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    test_acc = 100 * float(correct) / total
    print('Test Accuracy: %f %%' % (test_acc))
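
Variable(..., volatile=True) is the pre-0.4 PyTorch idiom for inference-only forward passes; on current PyTorch it has no effect and torch.no_grad() is used instead. A sketch of the same accuracy loop in the modern style, reusing netF, netC, targetloader and opt from this example:

import torch

# Sketch: post-0.4 equivalent of the evaluation loop above.
netF.eval()
netC.eval()
total, correct = 0, 0
with torch.no_grad():  # replaces Variable(inputs, volatile=True)
    for inputs, labels in targetloader:
        if opt.gpu >= 0:
            inputs, labels = inputs.cuda(), labels.cuda()
        outC = netC(netF(inputs))
        predicted = outC.argmax(dim=1)  # index of the max score per row
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Test Accuracy: %f %%' % (100.0 * correct / total))
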

    cls_source_list, cls_validation_list = sep.split_set(
        opt.src_path, nclasses)
    source_list = sep.dimension_rd(cls_source_list)
    # outC = netC(netF(inputv)) computes the classification output
    # outF = netF(inputv) computes the feature
    # netF.load_state_dict(torch.load(netF_path)) is how the network is loaded
    # crop size is unused
    if opt.val_method == 'Source_Risk':
        cv_loss = source_risk.cross_validation_loss(
            netF_path, netC_path, cls_source_list, opt.tar_path,
            cls_validation_list, nclasses, opt.imageSize, 224, opt.batchSize,
            use_gpu, opt)
    elif opt.val_method == 'Dev_icml':
        cv_loss = dev_icml.cross_validation_loss(netF_path, netC_path,
                                                 source_list, opt.tar_path,
                                                 cls_validation_list, nclasses,
                                                 opt.imageSize, 224,
                                                 opt.batchSize, use_gpu, opt)
    elif opt.val_method == 'Dev':
        cv_loss = dev.cross_validation_loss(netF_path, netC_path,
                                            cls_source_list, opt.tar_path,
                                            cls_validation_list, nclasses,
                                            opt.imageSize, 224, opt.batchSize,
                                            use_gpu, opt)
    print(cv_loss)
Example #10
def cross_validation_loss(feature_network, predict_network, src_list,
                          target_path, val_list, class_num, resize_size,
                          crop_size, batch_size, use_gpu):
    """
    Main function for computing the CV loss
    :param feature_network:
    :param predict_network:
    :param src_list:
    :param target_path:
    :param val_list:
    :param class_num:
    :param resize_size:
    :param crop_size:
    :param batch_size:
    :return:
    """
    val_list = seperate_data.dimension_rd(val_list)

    tar_list = open(target_path).readlines()
    cross_val_loss = 0

    prep_dict = prep.image_train(resize_size=resize_size, crop_size=crop_size)
    # load different class's image

    dsets_src = ImageList(src_list, transform=prep_dict)
    dset_loaders_src = util_data.DataLoader(dsets_src,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)
    dsets_val = ImageList(val_list, transform=prep_dict)
    dset_loaders_val = util_data.DataLoader(dsets_val,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)
    dsets_tar = ImageList(tar_list, transform=prep_dict)
    dset_loaders_tar = util_data.DataLoader(dsets_tar,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=4)

    # prepare source feature
    iter_src = iter(dset_loaders_src)
    src_input, src_labels = next(iter_src)
    if use_gpu:
        src_input, src_labels = Variable(src_input).cuda(), Variable(
            src_labels).cuda()
    else:
        src_input, src_labels = Variable(src_input), Variable(src_labels)
    src_feature, _ = feature_network(src_input)
    src_feature_de = src_feature.detach().cpu().numpy()
    for _ in range(len(dset_loaders_src) - 1):
        src_input, src_labels = next(iter_src)
        if use_gpu:
            src_input, src_labels = Variable(src_input).cuda(), Variable(
                src_labels).cuda()
        else:
            src_input, src_labels = Variable(src_input), Variable(src_labels)
        src_feature_new, _ = feature_network(src_input)
        src_feature_new_de = src_feature_new.detach().cpu().numpy()
        src_feature_de = np.append(src_feature_de, src_feature_new_de, axis=0)
        # src_feature = torch.cat((src_feature, src_feature_new), 0)

    # prepare target feature
    iter_tar = iter(dset_loaders_tar)
    tar_input, _ = next(iter_tar)
    if use_gpu:
        tar_input = Variable(tar_input).cuda()
    else:
        tar_input = Variable(tar_input)
    tar_feature, _ = feature_network(tar_input)
    tar_feature_de = tar_feature.detach().cpu().numpy()
    for _ in range(len(dset_loaders_tar) - 1):
        tar_input, _ = next(iter_tar)
        if use_gpu:
            tar_input = Variable(tar_input).cuda()
        else:
            tar_input = Variable(tar_input)
        tar_feature_new, _ = feature_network(tar_input)
        tar_feature_new_de = tar_feature_new.detach().cpu().numpy()
        tar_feature_de = np.append(tar_feature_de, tar_feature_new_de, axis=0)
        # tar_feature = torch.cat((tar_feature, tar_feature_new), 0)

    # prepare validation feature and predicted label for validation
    iter_val = iter(dset_loaders_val)
    val_input, val_labels = next(iter_val)
    if use_gpu:
        val_input, val_labels = Variable(val_input).cuda(), Variable(
            val_labels).cuda()
    else:
        val_input, val_labels = Variable(val_input), Variable(val_labels)
    val_feature, _ = feature_network(val_input)
    _, pred_label = predict_network(val_input)
    val_feature_de = val_feature.detach().cpu().numpy()

    w = pred_label[0].shape[0]
    error = np.zeros(1)

    error[0] = predict_loss(val_labels[0].item(),
                            pred_label[0].reshape(1, w)).item()
    error = error.reshape(1, 1)
    print("Before the final")
    print(pred_label.shape)
    print(len(val_feature_de))
    for num_image in range(1, len(pred_label)):
        single_pred_label = pred_label[num_image]
        w = single_pred_label.shape[0]
        single_val_label = val_labels[num_image]
        error = np.append(error, [[
            predict_loss(single_val_label.item(),
                         single_pred_label.reshape(1, w)).item()
        ]],
                          axis=0)

    for _ in range(len(dset_loaders_val) - 1):
        val_input, val_labels = next(iter_val)
        if use_gpu:
            val_input, val_labels = Variable(val_input).cuda(), Variable(
                val_labels).cuda()
        else:
            val_input, val_labels = Variable(val_input), Variable(val_labels)
        val_feature_new, _ = feature_network(val_input)

        val_feature_new_de = val_feature_new.detach().cpu().numpy()
        val_feature_de = np.append(val_feature_de, val_feature_new_de, axis=0)
        # val_feature = torch.cat((val_feature, val_feature_new), 0)
        _, pred_label = predict_network(val_input)
        for num_image in range(len(pred_label)):
            single_pred_label = pred_label[num_image]
            w = single_pred_label.shape[0]
            single_val_label = val_labels[num_image]
            error = np.append(error, [[
                predict_loss(single_val_label.item(),
                             single_pred_label.reshape(1, w)).item()
            ]],
                              axis=0)
    #     print("Insides the for loop")
    #     print(len(error))
    #     print(len(val_feature_de))
    #
    # print("Input for scrore calculation: ")
    # print(len(error))
    # print(len(val_feature_de))
    weight = get_weight(src_feature_de, tar_feature_de, val_feature_de)
    cross_val_loss = cross_val_loss + get_dev_risk(weight, error)

    return cross_val_loss
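
For completeness, a sketch of the other helper, get_dev_risk: my reading of the control-variate risk estimator from Deep Embedded Validation, written against the (n, 1) weight and error arrays produced above. Treat the name and the formula details as assumptions, not the verified implementation behind these examples.

import numpy as np

def get_dev_risk_sketch(weight, error):
    """Weighted risk with a control variate on the weights (DEV estimator).

    weight, error: float arrays of shape (n, 1), as built in the examples above.
    """
    assert weight.shape == error.shape
    weighted_error = weight * error
    # eta that minimizes the variance of weighted_error + eta * weight.
    cov = np.cov(np.concatenate((weighted_error, weight), axis=1),
                 rowvar=False)[0][1]
    eta = -cov / np.var(weight, ddof=1)
    # E[w * e] + eta * (E[w] - 1); E[w] = 1 when the density ratio is exact.
    return np.mean(weighted_error) + eta * np.mean(weight) - eta
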
Example #11
def train(num_epoch, option, num_layer, test_load, cuda):
    criterion = nn.CrossEntropyLoss().cuda() if cuda else nn.CrossEntropyLoss()
    if not test_load:
        for ep in range(num_epoch):
            G.train()
            F1.train()
            F2.train()
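            # Steps A-C below follow the Maximum Classifier Discrepancy (MCD)
            # scheme: A fits G, F1, F2 on source labels, B maximizes the F1/F2
            # disagreement on target w.r.t. the classifiers, and C minimizes
            # that disagreement w.r.t. the generator for num_k inner steps.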
            for batch_idx, data in enumerate(dataset):
                if batch_idx * batch_size > 30000:
                    break
                data1 = data['S']
                target1 = data['S_label']
                data2 = data['T']
                target2 = data['T_label']
                if args.cuda:
                    data1, target1 = data1.cuda(), target1.cuda()
                    data2, target2 = data2.cuda(), target2.cuda()
                # when pretraining network source only
                eta = 1.0
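                # eta scales the discrepancy term subtracted in Step B's
                # objective (F_loss below); 1.0 keeps it at full strength.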
                data = Variable(torch.cat((data1, data2), 0))
                target1 = Variable(target1)
                # Step A train all networks to minimize loss on source
                optimizer_g.zero_grad()
                optimizer_f.zero_grad()
                output = G(data)
                output1 = F1(output)
                output2 = F2(output)

                output_s1 = output1[:batch_size, :]
                output_s2 = output2[:batch_size, :]
                output_t1 = output1[batch_size:, :]
                output_t2 = output2[batch_size:, :]
                output_t1 = F.softmax(output_t1, dim=1)
                output_t2 = F.softmax(output_t2, dim=1)

                entropy_loss = -torch.mean(
                    torch.log(torch.mean(output_t1, 0) + 1e-6))
                entropy_loss -= torch.mean(
                    torch.log(torch.mean(output_t2, 0) + 1e-6))

                loss1 = criterion(output_s1, target1)
                loss2 = criterion(output_s2, target1)
                all_loss = loss1 + loss2 + 0.01 * entropy_loss
                all_loss.backward()
                optimizer_g.step()
                optimizer_f.step()

                # Step B train classifier to maximize discrepancy
                optimizer_g.zero_grad()
                optimizer_f.zero_grad()

                output = G(data)
                output1 = F1(output)
                output2 = F2(output)
                output_s1 = output1[:batch_size, :]
                output_s2 = output2[:batch_size, :]
                output_t1 = output1[batch_size:, :]
                output_t2 = output2[batch_size:, :]
                output_t1 = F.softmax(output_t1, dim=1)
                output_t2 = F.softmax(output_t2, dim=1)
                loss1 = criterion(output_s1, target1)
                loss2 = criterion(output_s2, target1)
                entropy_loss = -torch.mean(
                    torch.log(torch.mean(output_t1, 0) + 1e-6))
                entropy_loss -= torch.mean(
                    torch.log(torch.mean(output_t2, 0) + 1e-6))
                loss_dis = torch.mean(torch.abs(output_t1 - output_t2))
                F_loss = loss1 + loss2 - eta * loss_dis + 0.01 * entropy_loss
                F_loss.backward()
                optimizer_f.step()
                # Step C train generator to minimize discrepancy
                for i in range(num_k):
                    optimizer_g.zero_grad()
                    output = G(data)
                    output1 = F1(output)
                    output2 = F2(output)

                    output_s1 = output1[:batch_size, :]
                    output_s2 = output2[:batch_size, :]
                    output_t1 = output1[batch_size:, :]
                    output_t2 = output2[batch_size:, :]

                    loss1 = criterion(output_s1, target1)
                    loss2 = criterion(output_s2, target1)
                    output_t1 = F.softmax(output_t1, dim=1)
                    output_t2 = F.softmax(output_t2, dim=1)
                    loss_dis = torch.mean(torch.abs(output_t1 - output_t2))
                    entropy_loss = -torch.mean(
                        torch.log(torch.mean(output_t1, 0) + 1e-6))
                    entropy_loss -= torch.mean(
                        torch.log(torch.mean(output_t2, 0) + 1e-6))

                    loss_dis.backward()
                    optimizer_g.step()
                if batch_idx % args.log_interval == 0:
                    print(
                        'Train Ep: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\tLoss2: {:.6f}\t Dis: {:.6f} Entropy: {:.6f}'
                        .format(ep, batch_idx * len(data), 70000,
                                100. * batch_idx / 70000, loss1.item(),
                                loss2.item(), loss_dis.item(),
                                entropy_loss.item()))
                if batch_idx == 1 and ep > 1:
                    test(ep)
                    G.train()
                    F1.train()
                    F2.train()
    else:
        G_load = ResBase(option)
        F1_load = ResClassifier(num_layer=num_layer)
        F2_load = ResClassifier(num_layer=num_layer)

        F1_load.apply(weights_init)
        F2_load.apply(weights_init)
        G_path = args.load_network_path + 'G.pth'
        F1_path = args.load_network_path + 'F1.pth'
        F2_path = args.load_network_path + 'F2.pth'
        G_load.load_state_dict(torch.load(G_path))
        F1_load.load_state_dict(torch.load(F1_path))
        F2_load.load_state_dict(torch.load(F2_path))
        #
        # G_load = torch.load('whole_model_G.pth')
        # F1_load = torch.load('whole_model_F1.pth')
        # F2_load = torch.load('whole_model_F2.pth')
        if cuda:
            G_load.cuda()
            F1_load.cuda()
            F2_load.cuda()
        G_load.eval()
        F1_load.eval()
        F2_load.eval()
        test_loss = 0
        correct = 0
        correct2 = 0
        size = 0

        val = False
        for batch_idx, data in enumerate(dataset_test):
            if batch_idx * batch_size > 5000:
                break
            if args.cuda:
                data2 = data['T']
                target2 = data['T_label']
                if val:
                    data2 = data['S']
                    target2 = data['S_label']
                data2, target2 = data2.cuda(), target2.cuda()
            data1, target1 = Variable(data2, volatile=True), Variable(target2)
            output = G_load(data1)
            output1 = F1_load(output)
            output2 = F2_load(output)
            # print("Feature: {}\n Predict_value: {}".format(output, output2))
            # nll_loss expects log-probabilities, so apply log_softmax first
            test_loss += F.nll_loss(F.log_softmax(output1, dim=1),
                                    target1).item()
            pred = output1.data.max(1)[1]  # index of the max log-probability
            correct += pred.eq(target1.data).cpu().sum()
            pred = output2.data.max(1)[1]  # index of the max log-probability
            k = target1.data.size()[0]
            correct2 += pred.eq(target1.data).cpu().sum()

            size += k
        test_loss /= len(test_loader)  # loss function already averages over batch size
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) ({:.0f}%)\n'
            .format(test_loss, correct, size, 100. * correct / size,
                    100. * correct2 / size))
        acc = 100. * correct / size
        f.write('Accuracy is:' + str(acc) + '%' + '\n')
        value = max(100. * correct / size, 100. * correct2 / size)
        print("Value: {}".format(value))
        if args.cuda:
            use_gpu = True
        else:
            use_gpu = False
        # configure network path

        if (100. * correct / size) > (100. * correct2 / size):
            predict_network_path = F1_path
        else:
            predict_network_path = F2_path

        feature_network_path = G_path

        # configure the datapath for target and test
        source_path = 'chn_training_list.txt'
        target_path = 'chn_validation_list.txt'
        cls_source_list, cls_validation_list = sep.split_set(
            source_path, class_num)
        source_list = sep.dimension_rd(cls_source_list)

        if args.validation_method == 'Source_Risk':
            cv_loss = source_risk.cross_validation_loss(
                args, feature_network_path, predict_network_path, num_layer,
                cls_source_list, target_path, cls_validation_list, class_num,
                256, 224, batch_size, use_gpu)
        elif args.validation_method == 'Dev_icml':
            cv_loss = dev_icml.cross_validation_loss(
                args, feature_network_path, predict_network_path, num_layer,
                source_list, target_path, cls_validation_list, class_num, 256,
                224, batch_size, use_gpu)
        else:
            cv_loss = dev.cross_validation_loss(args, feature_network_path,
                                                predict_network_path,
                                                num_layer, source_list,
                                                target_path,
                                                cls_validation_list, class_num,
                                                256, 224, batch_size, use_gpu)
        print(cv_loss)
        f.write(args.validation_method + ' Validation is:' + str(cv_loss) +
                '\n')